# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportcontextlibimportinspectfromtypingimportTYPE_CHECKING,Any,Callable,Iterator,Mapping,Sequence,Union,overloadfromsqlalchemyimportfuncfromsqlalchemy.ormimportSessionfromairflow.exceptionsimportXComNotFoundfromairflow.models.abstractoperatorimportAbstractOperatorfromairflow.models.mappedoperatorimportMappedOperatorfromairflow.models.taskmixinimportDAGNode,DependencyMixinfromairflow.models.xcomimportXCOM_RETURN_KEYfromairflow.utils.contextimportContextfromairflow.utils.edgemodifierimportEdgeModifierfromairflow.utils.mixinsimportResolveMixinfromairflow.utils.sessionimportNEW_SESSION,provide_sessionfromairflow.utils.typesimportNOTSET,ArgNotSetifTYPE_CHECKING:fromairflow.models.dagimportDAGfromairflow.models.operatorimportOperator# Callable objects contained by MapXComArg. We only accept callables from# the user, but deserialize them into strings in a serialized XComArg for# safety (those callables are arbitrary user code).
class XComArg(ResolveMixin, DependencyMixin):
    """Reference to an XCom value pushed from another operator.

    Dependency declarations involving an XComArg are supported in every
    direction::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg  # By BaseOperator code
        op << xcomarg  # By BaseOperator code

    **Example**: As soon as any operator (decorated or regular) has been
    created, its result can be referenced::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    This object can be used in legacy Operators via Jinja.

    **Example**: The result can be made part of any generated string::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: Operator instance to which the XComArg references.
    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
        i.e. the referenced operator's return value.
    """

    @overload
    def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
        """Called when the user writes ``XComArg(...)`` directly."""

    @overload
    def __new__(cls: type[XComArg]) -> XComArg:
        """Called by Python internals from subclasses."""

    def __new__(cls, *args, **kwargs) -> XComArg:
        # Subclasses are allocated the regular way; only a direct
        # ``XComArg(...)`` call is redirected to the plain implementation.
        if cls is not XComArg:
            return super().__new__(cls)
        return PlainXComArg(*args, **kwargs)

    @staticmethod
    def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
        """Return XCom references in an arbitrary value.

        ``arg`` is traversed recursively: tuples, sets, lists and dict values
        are descended into, and for operators every attribute named in
        ``template_fields`` is inspected. Anything implementing
        ``ResolveMixin`` reports its own references.
        """
        if isinstance(arg, ResolveMixin):
            yield from arg.iter_references()
            return

        if isinstance(arg, (tuple, set, list)):
            children = arg
        elif isinstance(arg, dict):
            children = arg.values()
        elif isinstance(arg, AbstractOperator):
            # Lazily read each templated attribute as the traversal reaches it.
            children = (getattr(arg, field) for field in arg.template_fields)
        else:
            return

        for child in children:
            yield from XComArg.iter_xcom_references(child)

    @staticmethod
    def apply_upstream_relationship(op: Operator, arg: Any):
        """Set dependency for XComArgs.

        Every XComArg found "deeply" inside ``arg`` (see
        ``iter_xcom_references`` for what is searched) has its operator set
        as an upstream of ``op``.
        """
        for upstream, _key in XComArg.iter_xcom_references(arg):
            op.set_upstream(upstream)

    @property
    def roots(self) -> list[DAGNode]:
        """Required by TaskMixin"""
        return [task for task, _key in self.iter_references()]

    @property
    def leaves(self) -> list[DAGNode]:
        """Required by TaskMixin"""
        return [task for task, _key in self.iter_references()]

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
        for referenced, _key in self.iter_references():
            referenced.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
        for referenced, _key in self.iter_references():
            referenced.set_downstream(task_or_task_list, edge_modifier)

    def _serialize(self) -> dict[str, Any]:
        """Called by DAG serialization.

        Implementations return a data dict converted from this XComArg
        derivative, and should be the inverse of ``_deserialize``.

        DAG serialization does not call this directly, but goes through
        ``serialize_xcom_arg``, which adds the additional information needed
        to dispatch deserialization to the correct class.
        """
        raise NotImplementedError()

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Called when deserializing a DAG.

        Implementations rebuild the original XComArg from a data dict
        produced by ``_serialize``, and should be its inverse.

        DAG serialization relies on the additional information added in
        ``serialize_xcom_arg`` to dispatch data dicts to the correct
        ``_deserialize`` implementation, so this function does not need to
        validate whether the incoming data contains correct keys.
        """
        raise NotImplementedError()

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Inspect length of pushed value for task-mapping.

        The scheduler uses this to decide how many task instances to create
        for a downstream that uses this XComArg for task-mapping.

        *None* may be returned if the depended XCom has not been pushed.
        """
        raise NotImplementedError()

    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull XCom value.

        This should only be called during ``op.execute()`` with an appropriate
        context (e.g. generated from ``TaskInstance.get_template_context()``).
        The ``ResolveMixin`` parent also has a ``resolve`` protocol; this one
        adds the optional ``session`` argument that some subclasses need.

        :meta private:
        """
        raise NotImplementedError()
class PlainXComArg(XComArg):
    """Reference to one single XCom without any additional semantics.

    This class should not be accessed directly, but only through XComArg. The
    inheritance chain and ``__new__`` are arranged in this slightly convoluted
    way because we want to

    a. Allow the user to continue using XComArg directly for the simple
       semantics (see documentation of the base class for details).
    b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
       references.
    c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
       ``__str__``) to exist on other kinds of XComArg implementations since
       they don't make sense.

    :meta private:
    """

    def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
        self.operator = operator
        self.key = key

    def __getitem__(self, item: str) -> XComArg:
        """Implements xcomresult['some_result_key']"""
        if isinstance(item, str):
            return PlainXComArg(operator=self.operator, key=item)
        raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")

    def __iter__(self):
        """Override iterable protocol to raise error explicitly.

        Without an ``__iter__``, Python falls back to calling ``__getitem__``
        with 0, 1, 2, etc. until it hits an ``IndexError``. That does not work
        well with the custom ``__getitem__`` above, and a misplaced ``*``
        expansion in a DAG file would create an infinite loop consuming the
        entire DAG parser.

        Raising here makes an incorrectly written DAG fail fast instead of
        wasting resources on nonsensical iterating.
        """
        raise TypeError("'XComArg' object is not iterable")

    def __str__(self) -> str:
        """
        Backward compatibility for old-style jinja used in Airflow Operators

        **Example**: to use XComArg at BashOperator::

            BashOperator(cmd=f"... { xcomarg } ...")

        :return: a Jinja expression calling ``task_instance.xcom_pull`` for
            the referenced operator (and key, when one is set)
        """
        pull_args = [
            f"task_ids='{self.operator.task_id}'",
            f"dag_id='{self.operator.dag_id}'",
        ]
        if self.key is not None:
            pull_args.append(f"key='{self.key}'")
        joined_args = ", ".join(pull_args)
        # Quadruple braces are needed to emit a literal ``{{`` from an f-string.
        return f"{{{{ task_instance.xcom_pull({joined_args}) }}}}"

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Apply ``f`` to the referenced value; only allowed on the return-value key."""
        if self.key == XCOM_RETURN_KEY:
            return super().map(f)
        raise ValueError("cannot map against non-return XCom")

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Zip this reference with ``others``; only allowed on the return-value key."""
        if self.key == XCOM_RETURN_KEY:
            return super().zip(*others, fillvalue=fillvalue)
        raise ValueError("cannot map against non-return XCom")
def _get_callable_name(f: Callable | str) -> str:
    """Try to "describe" a callable by getting its name.

    :param f: Either an actual callable, or a string holding the source code
        of one.
    :return: The bare function name when it can be determined, otherwise the
        placeholder ``"<function>"``.
    """
    if callable(f):
        # Not every callable carries ``__name__`` (``functools.partial``
        # objects and instances defining ``__call__`` do not).
        return getattr(f, "__name__", "<function>")
    # Parse the source to find whatever is behind "def". For safety, we don't
    # want to evaluate the code in any meaningful way!
    with contextlib.suppress(Exception):
        kw, rest = f.lstrip().split(None, 1)
        if kw == "def":
            # Keep only the identifier; the parameter list (which may itself
            # contain whitespace) and the trailing colon are not the name.
            name = rest.partition("(")[0].strip()
            if name.isidentifier():
                return name
    return "<function>"


class _MapResult(Sequence):
    """Lazy view over a pulled XCom value with a chain of callables applied.

    Items are transformed on access: ``self.value[index]`` is looked up first
    and then passed through ``self.callables`` in order.
    """

    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
        self.value = value
        self.callables = callables

    def __getitem__(self, index: Any) -> Any:
        value = self.value[index]

        # In the worker, we can access all actual callables. Call them.
        if all(callable(f) for f in self.callables):
            for f in self.callables:
                value = f(value)
            return value

        # In the scheduler, we don't have access to the actual callables, nor do
        # we want to run it since it's arbitrary code. This builds a string to
        # represent the call chain in the UI or logs instead.
        for v in self.callables:
            value = f"{_get_callable_name(v)}({value})"
        return value

    def __len__(self) -> int:
        return len(self.value)
class MapXComArg(XComArg):
    """An XCom reference with ``map()`` call(s) applied.

    This wraps another XComArg and additionally carries a series of
    "transforms" that convert the pulled XCom value.

    :meta private:
    """

    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
        # Objects flagged as task decorators are rejected.
        if any(getattr(c, "_airflow_is_task_decorator", False) for c in callables):
            raise ValueError("map() argument must be a plain function, not a @task operator")
        self.arg = arg
        self.callables = callables

    def _serialize(self) -> dict[str, Any]:
        """Serialize the wrapped reference, storing real callables as source text."""
        serialized_arg = serialize_xcom_arg(self.arg)
        sources = []
        for c in self.callables:
            # Entries that are not callable (already strings) pass through as-is.
            sources.append(inspect.getsource(c) if callable(c) else c)
        return {"arg": serialized_arg, "callables": sources}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Rebuild from serialized data, leaving the callables as stored."""
        # We are deliberately NOT deserializing the callables. These are shown
        # in the UI, and displaying a function object is useless.
        inner = deserialize_xcom_arg(data["arg"], dag)
        return cls(inner, data["callables"])

    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Resolve the wrapped reference and wrap the result in a ``_MapResult``.

        :raises ValueError: if the resolved value is neither a sequence nor a dict
        """
        value = self.arg.resolve(context, session=session)
        if isinstance(value, (Sequence, dict)):
            return _MapResult(value, self.callables)
        raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
class ZipXComArg(XComArg):
    """An XCom reference with ``zip()`` applied.

    Built from multiple XComArg instances, this presents an iterable that
    "zips" them together like the built-in ``zip()`` (and
    ``itertools.zip_longest()`` if ``fillvalue`` is provided).
    """

    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
        if not args:
            raise ValueError("At least one input is required")
        self.args = args
        self.fillvalue = fillvalue

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Return the zipped length, or *None* if any input is not ready.

        Without a fill value the shortest input determines the length; with
        one, the longest does.
        """
        lengths = [arg.get_task_map_length(run_id, session=session) for arg in self.args]
        known = [length for length in lengths if length is not None]
        if len(known) != len(self.args):
            # If any of the referenced XComs is not ready, we are not ready either.
            return None
        pick = min if isinstance(self.fillvalue, ArgNotSet) else max
        return pick(known)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Resolve every input and combine them into a ``_ZipResult``.

        :raises ValueError: if any resolved value is neither a sequence nor a dict
        """
        values = [arg.resolve(context, session=session) for arg in self.args]
        # Validation happens only after every input has been resolved.
        for value in values:
            if isinstance(value, (Sequence, dict)):
                continue
            raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
        return _ZipResult(values, fillvalue=self.fillvalue)