# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Caching model wrapper for durable execution."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import structlog
from pydantic_ai.models.wrapper import WrapperModel
from airflow.providers.common.ai.durable.fingerprint import fingerprint_model_request
# Structured logger for this module, created with ``logger_name="task"``.
log = structlog.get_logger(logger_name="task")

if TYPE_CHECKING:
    # These names are only used in annotations; the imports run solely under
    # static type checking (annotations are lazy via ``from __future__ import annotations``).
    from pydantic_ai.messages import ModelMessage, ModelResponse
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.settings import ModelSettings

    from airflow.providers.common.ai.durable.step_counter import DurableStepCounter
    from airflow.providers.common.ai.durable.storage import DurableStorage


@dataclass(init=False)
class CachingModel(WrapperModel):
    """
    Wraps a model to cache responses in ObjectStorage for durable execution.

    On each ``request()`` call, checks if a cached response exists for the
    current step index and was produced by an equivalent request (same model,
    message history, settings, and tools -- compared via fingerprint). If so,
    returns the cached response without calling the underlying model.

    Otherwise, calls the model and caches the response. A fingerprint
    mismatch means the agent changed between attempts; the stale entry is
    discarded and the step re-runs live. A missing fingerprint (on the cached
    entry or on the current request) is never treated as a match, because
    equivalence cannot be verified without one.
    """

    # Backend used through ``load_model_response()`` / ``save_model_response()``.
    storage: DurableStorage = field(repr=False)
    # Supplies the step index via ``next_step()`` and carries the
    # ``replayed_model`` / ``cached_model`` tallies updated by ``request()``.
    counter: DurableStepCounter = field(repr=False)

    def __init__(
        self,
        wrapped: Any,
        *,
        storage: DurableStorage,
        counter: DurableStepCounter,
    ) -> None:
        """
        Wrap *wrapped* and attach the durable storage and step counter.

        :param wrapped: The model to delegate to; forwarded to ``WrapperModel.__init__``.
        :param storage: Where model responses are loaded from and saved to.
        :param counter: Step counter shared with the rest of the durable run.
        """
        super().__init__(wrapped)
        self.storage = storage
        self.counter = counter

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """
        Return the cached response for this step, or call the wrapped model and cache it.

        :param messages: Message history for this model call.
        :param model_settings: Per-request settings, or ``None``.
        :param model_request_parameters: Request parameters for this model call.
        :return: The replayed response when the stored fingerprint matches the
            current request's fingerprint, otherwise the live response from the
            wrapped model.

        Side effects: advances the step counter once per call; increments
        ``counter.replayed_model`` on replay, or saves the live response and
        increments ``counter.cached_model`` otherwise. Nothing is saved if the
        wrapped model raises.
        """
        step = self.counter.next_step()
        key = f"model_step_{step}"

        # Fingerprint the *prepared* request, not the raw arguments. Concrete
        # models call ``prepare_request()`` at the start of ``request()`` to merge
        # their model-level ``settings`` and apply profile-specific transforms
        # (thinking resolution, native-tool handling, output-mode defaults) before
        # the provider sees the request. Fingerprinting the raw arguments would
        # miss a change that lives only at the model level -- e.g. a different
        # temperature or thinking setting on the connection -- and replay a stale
        # response. The raw arguments are still passed to ``wrapped.request()``,
        # which re-runs ``prepare_request()`` itself (it is pure and idempotent).
        prepared_settings, prepared_parameters = self.wrapped.prepare_request(
            model_settings, model_request_parameters
        )
        fingerprint = fingerprint_model_request(
            f"{self.wrapped.system}:{self.wrapped.model_name}",
            messages,
            prepared_settings,
            prepared_parameters,
        )

        cached, cached_fingerprint = self.storage.load_model_response(key)
        if cached is not None:
            # Require a real fingerprint before replaying: a bare ``==`` would let
            # ``None == None`` pass and replay a response whose originating request
            # cannot be verified against the current one.
            if fingerprint is not None and cached_fingerprint == fingerprint:
                self.counter.replayed_model += 1
                log.debug("Durable: replayed cached model response", step=step)
                return cached
            log.warning(
                "Durable: cached model response does not match the current request; "
                "re-running this step instead of replaying",
                step=step,
                reason=(
                    "entry predates fingerprinting or the request could not be fingerprinted"
                    if fingerprint is None or cached_fingerprint is None
                    else "model, prompt, message history, settings, or tools changed since "
                    "the previous attempt"
                ),
            )

        # No usable cache entry: run the step live and store the result under this step's key.
        response = await self.wrapped.request(messages, model_settings, model_request_parameters)
        self.storage.save_model_response(key, response, fingerprint=fingerprint)
        self.counter.cached_model += 1
        log.debug("Durable: cached model response", step=step)
        return response