Source code for airflow.models.xcom

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import logging
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import (
    JSON,
    Column,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    delete,
    func,
    select,
    text,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, relationship

from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils import timezone
from airflow.utils.db import LazySelectSequence
from airflow.utils.helpers import is_container
from airflow.utils.json import XComDecoder, XComEncoder
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

# XCom constants below are needed for providers backward compatibility,
# which should import the constants directly after apache-airflow>=2.6.0
from airflow.utils.xcom import (
    MAX_XCOM_SIZE,  # noqa: F401
    XCOM_RETURN_KEY,
)

if TYPE_CHECKING:
    # Only needed for type annotations; this branch is skipped at runtime.
    from sqlalchemy.engine import Row
    from sqlalchemy.orm import Session
    from sqlalchemy.sql.expression import Select, TextClause

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class XComModel(TaskInstanceDependencies):
    """
    ORM model for the ``xcom`` table.

    Also provides helpers to store, query, clear and (de)serialize XCom values.
    """

    __tablename__ = "xcom"

    dag_run_id = Column(Integer(), nullable=False, primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
    key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)

    # Denormalized for easier lookup.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

    # Generic JSON column, stored as JSONB on PostgreSQL.
    value = Column(JSON().with_variant(postgresql.JSONB, "postgresql"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __table_args__ = (
        # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
        # but it goes over MySQL's index length limit. So we instead index 'key'
        # separately, and enforce uniqueness with DagRun.id instead.
        Index("idx_xcom_key", key),
        Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
        PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="xcom_pkey"),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="xcom_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )

    # The owning DagRun, eagerly loaded (lazy="joined") with each XCom row.
    dag_run = relationship(
        "DagRun",
        primaryjoin="XComModel.dag_run_id == foreign(DagRun.id)",
        uselist=False,
        lazy="joined",
        passive_deletes="all",
    )

    # Exposes ``dag_run.logical_date`` directly on the XCom row.
    logical_date = association_proxy("dag_run", "logical_date")

    @classmethod
    @provide_session
    def clear(
        cls,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Clear all XCom data from the database for the given task instance.

        .. note:: This **will not** purge any data from a custom XCom backend.

        :param dag_id: ID of DAG to clear the XCom for.
        :param task_id: ID of task to clear the XCom for.
        :param run_id: ID of DAG run to clear the XCom for.
        :param map_index: If given, only clear XCom from this particular mapped task. The default
            ``None`` clears *all* XComs from the task.
        :param session: Database session. If not given, a new session will be created for this function.
        :raises TypeError: If ``dag_id`` or ``task_id`` is explicitly passed as ``None``.
        :raises ValueError: If ``run_id`` is empty or ``None``.
        """
        # All three identifiers are required keyword-only arguments, so these checks
        # only trigger when a caller explicitly passes ``None`` / an empty run_id.
        if dag_id is None:
            raise TypeError("clear() missing required argument: dag_id")
        if task_id is None:
            raise TypeError("clear() missing required argument: task_id")
        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

        query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
        if map_index is not None:
            query = query.filter_by(map_index=map_index)

        # Rows are loaded and deleted one by one through the ORM session.
        for xcom in query:
            session.delete(xcom)

        session.commit()

    @classmethod
    @provide_session
    def set(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Store an XCom value.

        Any existing XCom with the same key for the same task instance is deleted
        before the new row is inserted. The session is flushed but not committed.

        :param key: Key to store the XCom.
        :param value: XCom value to store.
        :param dag_id: DAG ID.
        :param task_id: Task ID.
        :param run_id: DAG run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises ValueError: If ``run_id`` is empty, if no matching DAG run exists,
            or if the value cannot be serialized.
        """
        from airflow.models.dagrun import DagRun

        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

        dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
        if dag_run_id is None:
            raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

        # Seamlessly resolve LazySelectSequence to a list. This intends to work
        # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
        # it's pushed into XCom, the user should be aware of the performance
        # implications, and this avoids leaking the implementation detail.
        if isinstance(value, LazySelectSequence):
            warning_message = (
                "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                "to list, which may degrade performance. Review resource "
                "requirements for this operation, and call list() to suppress "
                "this message. See Dynamic Task Mapping documentation for "
                "more information about lazy proxy objects."
            )
            log.warning(
                warning_message,
                "return value" if key == XCOM_RETURN_KEY else f"value {key}",
                task_id,
                dag_id,
                run_id,
            )
            value = list(value)

        value = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )

        # Remove duplicate XComs and insert a new one.
        session.execute(
            delete(cls).where(
                cls.key == key,
                cls.run_id == run_id,
                cls.task_id == task_id,
                cls.dag_id == dag_id,
                cls.map_index == map_index,
            )
        )
        new = cast("Any", cls)(  # Work around Mypy complaining model not defining '__init__'.
            dag_run_id=dag_run_id,
            key=key,
            value=value,
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
            map_index=map_index,
        )
        session.add(new)
        session.flush()

    @classmethod
    @provide_session
    def get_many(
        cls,
        *,
        run_id: str,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """
        Composes a query to get one or more XCom entries.

        This function returns an SQLAlchemy query of full XCom objects. If you just want one
        stored value, use :meth:`get_one` instead.

        Results are ordered newest DAG run first (by logical date, falling back to
        ``run_after`` for runs without one), then newest XCom timestamp first.

        :param run_id: DAG run ID for the task.
        :param key: A key for the XComs. If provided, only XComs with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param task_ids: Only XComs from task with matching IDs will be pulled.
            Pass *None* (default) to remove the filter.
        :param dag_ids: Only pulls XComs from specified DAGs. Pass *None* (default) to remove the filter.
        :param map_indexes: Only XComs from matching map indexes will be pulled.
            Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XComs from the
            specified DAG run are returned. If *True*, all matching XComs are
            returned regardless of the run it belongs to, as long as the run's
            date is not later than that of the run identified by ``run_id``.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :param limit: Limiting returning XComs
        :raises ValueError: If ``run_id`` is empty or ``None``.
        """
        from airflow.models.dagrun import DagRun

        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

        query = session.query(cls).join(XComModel.dag_run)

        if key:
            query = query.filter(XComModel.key == key)

        if is_container(task_ids):
            query = query.filter(cls.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(cls.task_id == task_ids)

        if is_container(dag_ids):
            query = query.filter(cls.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(cls.dag_id == dag_ids)

        # A contiguous range is translated to a cheap BETWEEN-style filter instead of IN (...).
        if isinstance(map_indexes, range) and map_indexes.step == 1:
            query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
        elif is_container(map_indexes):
            query = query.filter(cls.map_index.in_(map_indexes))
        elif map_indexes is not None:
            query = query.filter(cls.map_index == map_indexes)

        if include_prior_dates:
            # MAX() makes the subquery yield exactly one row. The subquery is joined
            # into FROM, and ``run_id`` is only unique per DAG (``set`` looks a run up
            # by both dag_id and run_id). Several matching runs would otherwise
            # multiply every XCom row. ``x <= max(d)`` keeps exactly the rows that
            # ``x <= d`` kept for any of the d values, just without the duplicates.
            # With no matching run, MAX() is NULL and nothing matches, as before.
            dr = (
                session.query(
                    func.max(func.coalesce(DagRun.logical_date, DagRun.run_after)).label(
                        "logical_date_or_run_after"
                    )
                )
                .filter(DagRun.run_id == run_id)
                .subquery()
            )

            query = query.filter(
                func.coalesce(DagRun.logical_date, DagRun.run_after) <= dr.c.logical_date_or_run_after
            )
        else:
            query = query.filter(cls.run_id == run_id)

        # Order by the same coalesced date the prior-dates filter uses. Runs without a
        # logical date are then ranked by ``run_after`` rather than by NULL, whose sort
        # position differs between databases.
        query = query.order_by(
            func.coalesce(DagRun.logical_date, DagRun.run_after).desc(),
            cls.timestamp.desc(),
        )
        if limit:
            return query.limit(limit)
        return query

    @staticmethod
    def serialize_value(
        value: Any,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> str:
        """
        Serialize XCom value to JSON str using ``XComEncoder``.

        The keyword-only context arguments are accepted but not used by this
        implementation.

        :param value: The value to serialize.
        :return: The JSON-encoded string.
        :raises ValueError: If ``value`` cannot be encoded.
        """
        try:
            return json.dumps(value, cls=XComEncoder)
        except (ValueError, TypeError) as err:
            # Chain the original error so the offending type stays visible in the traceback.
            raise ValueError("XCom value must be JSON serializable") from err

    @staticmethod
    def deserialize_value(result) -> Any:
        """
        Deserialize XCom value from a database result.

        If deserialization fails, the raw value is returned, which must still be a valid Python JSON-compatible
        type (e.g., ``dict``, ``list``, ``str``, ``int``, ``float``, or ``bool``).

        XCom values are stored as JSON in the database, and SQLAlchemy automatically handles
        serialization (``json.dumps``) and deserialization (``json.loads``). However, we
        use a custom encoder for serialization (``serialize_value``) and deserialization to handle special
        cases, such as encoding tuples via the Airflow Serialization module. These must be decoded
        using ``XComDecoder`` to restore original types.

        Some XCom values, such as those set via the Task Execution API, bypass ``serialize_value``
        and are stored directly in JSON format. Since these values are already deserialized
        by SQLAlchemy, they are returned as-is.

        **Example: Handling a tuple**:

        .. code-block:: python

            original_value = (1, 2, 3)
            serialized_value = XComModel.serialize_value(original_value)
            print(serialized_value)
            # '{"__classname__": "builtins.tuple", "__version__": 1, "__data__": [1, 2, 3]}'

        This serialized value is stored in the database. When deserialized, the value is restored to the original tuple.

        :param result: The XCom database row or object containing a ``value`` attribute.
        :return: The deserialized Python object.
        """
        if result.value is None:
            return None

        try:
            return json.loads(result.value, cls=XComDecoder)
        except (ValueError, TypeError):
            # Already deserialized (e.g., set via Task Execution API)
            return result.value
class LazyXComSelectSequence(LazySelectSequence[Any]):
    """
    List-like interface to lazily access XCom values.

    :meta private:
    """

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        # Build a SELECT of the XCom ``value`` column, then bind it to the given textual statement.
        value_select = select(XComModel.value)
        return value_select.from_statement(stmt)

    @staticmethod
    def _process_row(row: Row) -> Any:
        # Each fetched row is decoded exactly like a regular XCom read.
        decoded = XComModel.deserialize_value(row)
        return decoded
def __getattr__(name: str):
    """
    Lazily resolve ``BaseXCom`` and ``XCom`` as attributes of this module.

    The resolved class is cached in the module globals, so later lookups of the
    same name no longer go through this function.

    :param name: The attribute name being looked up on this module.
    :raises AttributeError: For any name other than ``BaseXCom`` or ``XCom``.
    """
    if name == "BaseXCom":
        from airflow.sdk.bases.xcom import BaseXCom as resolved
    elif name == "XCom":
        from airflow.sdk.execution_time.xcom import XCom as resolved
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    globals()[name] = resolved
    return resolved

Was this entry helpful?