#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import logging
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import (
    JSON,
    Column,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    delete,
    func,
    select,
    text,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, relationship

from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils import timezone
from airflow.utils.db import LazySelectSequence
from airflow.utils.helpers import is_container
from airflow.utils.json import XComDecoder, XComEncoder
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

# XCom constants below are needed for providers backward compatibility,
# which should import the constants directly after apache-airflow>=2.6.0
from airflow.utils.xcom import (
    MAX_XCOM_SIZE,  # noqa: F401
    XCOM_RETURN_KEY,
)

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

log = logging.getLogger(__name__)
class XComModel(TaskInstanceDependencies):
    __table_args__ = (
        # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
        # but it goes over MySQL's index length limit. So we instead index 'key'
        # separately, and enforce uniqueness with DagRun.id instead.
        Index("idx_xcom_key", key),
        Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
        PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="xcom_pkey"),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="xcom_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )
    @classmethod
    @provide_session
    def clear(
        cls,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Clear all XCom data from the database for the given task instance.

        .. note:: This **will not** purge any data from a custom XCom backend.

        :param dag_id: ID of DAG to clear the XCom for.
        :param task_id: ID of task to clear the XCom for.
        :param run_id: ID of DAG run to clear the XCom for.
        :param map_index: If given, only clear XCom from this particular mapped task.
            The default ``None`` clears *all* XComs from the task.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """
        # Given the historic order of this function's arguments (logical_date used to be
        # the first one), adding a new optional param means we need default values for
        # everything :(
        if dag_id is None:
            raise TypeError("clear() missing required argument: dag_id")
        if task_id is None:
            raise TypeError("clear() missing required argument: task_id")
        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

        query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
        if map_index is not None:
            query = query.filter_by(map_index=map_index)

        for xcom in query:
            # print(f"Clearing XCOM {xcom} with value {xcom.value}")
            session.delete(xcom)

        session.commit()
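    # Usage sketch (illustrative, not part of the model): assuming an open session on
    # the Airflow metadata database; ``example_dag``/``example_task`` and the run_id
    # below are hypothetical identifiers:
    #
    #     XComModel.clear(
    #         dag_id="example_dag",
    #         task_id="example_task",
    #         run_id="manual__2025-01-01T00:00:00+00:00",
    #         map_index=0,      # omit (or pass None) to clear XComs for every map index
    #         session=session,
    #     )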
    @classmethod
    @provide_session
    def set(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Store an XCom value.

        :param key: Key to store the XCom.
        :param value: XCom value to store.
        :param dag_id: DAG ID.
        :param task_id: Task ID.
        :param run_id: DAG run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        :param session: Database session. If not given, a new session will be
            created for this function.
        """
        from airflow.models.dagrun import DagRun

        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

        dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
        if dag_run_id is None:
            raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

        # Seamlessly resolve LazySelectSequence to a list. This intends to work
        # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
        # it's pushed into XCom, the user should be aware of the performance
        # implications, and this avoids leaking the implementation detail.
        if isinstance(value, LazySelectSequence):
            warning_message = (
                "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                "to list, which may degrade performance. Review resource "
                "requirements for this operation, and call list() to suppress "
                "this message. See Dynamic Task Mapping documentation for "
                "more information about lazy proxy objects."
            )
            log.warning(
                warning_message,
                "return value" if key == XCOM_RETURN_KEY else f"value {key}",
                task_id,
                dag_id,
                run_id,
            )
            value = list(value)

        value = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )

        # Remove duplicate XComs and insert a new one.
        session.execute(
            delete(cls).where(
                cls.key == key,
                cls.run_id == run_id,
                cls.task_id == task_id,
                cls.dag_id == dag_id,
                cls.map_index == map_index,
            )
        )
        new = cast("Any", cls)(  # Work around Mypy complaining model not defining '__init__'.
            dag_run_id=dag_run_id,
            key=key,
            value=value,
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
            map_index=map_index,
        )
        session.add(new)
        session.flush()
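    # Usage sketch (illustrative): assuming an open session and an existing DAG run;
    # ``my_dag``/``push_task`` and the run_id are hypothetical. The value must be
    # JSON-serializable via ``XComEncoder``:
    #
    #     XComModel.set(
    #         key="training_config",
    #         value={"epochs": 10, "batch_size": 32},
    #         dag_id="my_dag",
    #         task_id="push_task",
    #         run_id="manual__2025-01-01T00:00:00+00:00",
    #         session=session,
    #     )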
    @classmethod
    @provide_session
    def get_many(
        cls,
        *,
        run_id: str,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """
        Compose a query to get one or more XCom entries.

        This function returns an SQLAlchemy query of full XCom objects. If you
        just want one stored value, use :meth:`get_one` instead.

        :param run_id: DAG run ID for the task.
        :param key: A key for the XComs. If provided, only XComs with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param task_ids: Only XComs from tasks with matching IDs will be
            pulled. Pass *None* (default) to remove the filter.
        :param dag_ids: Only pulls XComs from specified DAGs. Pass *None*
            (default) to remove the filter.
        :param map_indexes: Only XComs from matching map indexes will be pulled.
            Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XComs from the
            specified DAG run are returned. If *True*, all matching XComs are
            returned regardless of the run they belong to.
        :param limit: If given, limit the number of XComs returned.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """
        from airflow.models.dagrun import DagRun

        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

        query = session.query(cls).join(XComModel.dag_run)

        if key:
            query = query.filter(XComModel.key == key)

        if is_container(task_ids):
            query = query.filter(cls.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(cls.task_id == task_ids)

        if is_container(dag_ids):
            query = query.filter(cls.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(cls.dag_id == dag_ids)

        if isinstance(map_indexes, range) and map_indexes.step == 1:
            query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
        elif is_container(map_indexes):
            query = query.filter(cls.map_index.in_(map_indexes))
        elif map_indexes is not None:
            query = query.filter(cls.map_index == map_indexes)

        if include_prior_dates:
            dr = (
                session.query(
                    func.coalesce(DagRun.logical_date, DagRun.run_after).label("logical_date_or_run_after")
                )
                .filter(DagRun.run_id == run_id)
                .subquery()
            )
            query = query.filter(
                func.coalesce(DagRun.logical_date, DagRun.run_after) <= dr.c.logical_date_or_run_after
            )
        else:
            query = query.filter(cls.run_id == run_id)

        query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc())
        if limit:
            return query.limit(limit)
        return query
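    # Usage sketch (illustrative): pull the return-value XComs pushed by two
    # hypothetical tasks of ``my_dag`` in one run. ``get_many`` returns a query of
    # XComModel rows, so values still need ``deserialize_value``:
    #
    #     rows = XComModel.get_many(
    #         run_id="manual__2025-01-01T00:00:00+00:00",
    #         key=XCOM_RETURN_KEY,
    #         task_ids=["extract", "transform"],
    #         dag_ids="my_dag",
    #         session=session,
    #     )
    #     values = [XComModel.deserialize_value(row) for row in rows]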
    @staticmethod
    def serialize_value(
        value: Any,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> str:
        """Serialize XCom value to JSON str."""
        try:
            return json.dumps(value, cls=XComEncoder)
        except (ValueError, TypeError):
            raise ValueError("XCom value must be JSON serializable")
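    # Usage sketch (illustrative): plain JSON types pass straight through
    # ``json.dumps``, while ``XComEncoder`` handles extra cases such as tuples
    # (see the example in ``deserialize_value`` below); anything the encoder
    # cannot handle surfaces as ``ValueError``:
    #
    #     XComModel.serialize_value({"a": 1})    # -> '{"a": 1}'
    #     XComModel.serialize_value((1, 2, 3))   # -> tuple encoded by the Airflow serializer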
    @staticmethod
    def deserialize_value(result) -> Any:
        """
        Deserialize XCom value from a database result.

        If deserialization fails, the raw value is returned, which must still be a
        valid Python JSON-compatible type (e.g., ``dict``, ``list``, ``str``, ``int``,
        ``float``, or ``bool``).

        XCom values are stored as JSON in the database, and SQLAlchemy automatically
        handles serialization (``json.dumps``) and deserialization (``json.loads``).
        However, we use a custom encoder for serialization (``serialize_value``) and
        deserialization to handle special cases, such as encoding tuples via the
        Airflow Serialization module. These must be decoded using ``XComDecoder`` to
        restore original types.

        Some XCom values, such as those set via the Task Execution API, bypass
        ``serialize_value`` and are stored directly in JSON format. Since these values
        are already deserialized by SQLAlchemy, they are returned as-is.

        **Example: Handling a tuple**:

        .. code-block:: python

            original_value = (1, 2, 3)
            serialized_value = XComModel.serialize_value(original_value)
            print(serialized_value)
            # '{"__classname__": "builtins.tuple", "__version__": 1, "__data__": [1, 2, 3]}'

        This serialized value is stored in the database. When deserialized, the value
        is restored to the original tuple.

        :param result: The XCom database row or object containing a ``value`` attribute.
        :return: The deserialized Python object.
        """
        if result.value is None:
            return None
        try:
            return json.loads(result.value, cls=XComDecoder)
        except (ValueError, TypeError):
            # Already deserialized (e.g., set via Task Execution API)
            return result.value
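    # Usage sketch (illustrative): ``result`` only needs a ``value`` attribute, so a
    # stand-in object is enough to show the round trip without a real database row:
    #
    #     from types import SimpleNamespace
    #
    #     row = SimpleNamespace(value=XComModel.serialize_value((1, 2, 3)))
    #     XComModel.deserialize_value(row)   # -> (1, 2, 3), restored via XComDecoder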
def __getattr__(name: str):
    if name == "BaseXCom":
        from airflow.sdk.bases.xcom import BaseXCom

        globals()[name] = BaseXCom
        return BaseXCom
    if name == "XCom":
        from airflow.sdk.execution_time.xcom import XCom

        globals()[name] = XCom
        return XCom
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
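# Usage sketch (illustrative): the module-level ``__getattr__`` above (PEP 562) lazily
# resolves the legacy names and caches them in ``globals()``, so existing imports keep
# working without eagerly importing the task SDK:
#
#     from airflow.models.xcom import XCom      # resolved via __getattr__ on first access
#     from airflow.models.xcom import BaseXCom  # likewise, then served from globals()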