## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportcollectionsimportcollections.abcimportcontextlibimportcopyimportdatetimeimportwarningsfromtypingimportTYPE_CHECKING,Any,ClassVar,Collection,Iterable,Iterator,Mapping,Sequence,Unionimportattrimportpendulumfromsqlalchemy.orm.sessionimportSessionfromairflowimportsettingsfromairflow.compat.functoolsimportcachefromairflow.exceptionsimportAirflowException,UnmappableOperatorfromairflow.models.abstractoperatorimport(DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,DEFAULT_OWNER,DEFAULT_POOL_SLOTS,DEFAULT_PRIORITY_WEIGHT,DEFAULT_QUEUE,DEFAULT_RETRIES,DEFAULT_RETRY_DELAY,DEFAULT_TRIGGER_RULE,DEFAULT_WEIGHT_RULE,AbstractOperator,NotMapped,TaskStateChangeCallback,)fromairflow.models.expandinputimport(DictOfListsExpandInput,ExpandInput,ListOfDictsExpandInput,OperatorExpandArgument,OperatorExpandKwargsArgument,is_mappable,)fromairflow.models.paramimportParamsDictfromairflow.models.poolimportPoolfromairflow.serialization.enumsimportDagAttributeTypesfromairflow.ti_deps.deps.base_ti_depimportBaseTIDepfromairflow.ti_deps.deps.mapped_task_expandedimportMappedTaskIsExpandedfromairflow.typing_compatimportLiteralfromairflow.utils.contextimportContext,context_update_for_unmappedfromairflow.utils.helpersimportis_con
tainer,prevent_duplicatesfromairflow.utils.operator_resourcesimportResourcesfromairflow.utils.trigger_ruleimportTriggerRulefromairflow.utils.typesimportNOTSETifTYPE_CHECKING:importjinja2# Slow import.fromairflow.models.baseoperatorimportBaseOperator,BaseOperatorLinkfromairflow.models.dagimportDAGfromairflow.models.operatorimportOperatorfromairflow.models.xcom_argimportXComArgfromairflow.utils.task_groupimportTaskGroup
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Check that keyword arguments given to ``partial()``/``expand()`` are accepted by *op*.

    Walks ``op.mro()`` and, for every class whose ``__init__`` carries a
    ``_BaseOperatorMeta__param_names`` attribute, consumes the keywords listed there.
    When *func* is ``"expand"``, every consumed value must also pass ``is_mappable``.

    :param op: The operator class being mapped.
    :param func: Name of the calling method (``"partial"`` or ``"expand"``). It selects
        the extra mappability check and is used in error messages.
    :param value: Keyword arguments to validate. This dict is not modified.
    :raises ValueError: If *func* is ``"expand"`` and a consumed value is not mappable.
    :raises TypeError: If some keyword is not declared by any class in the MRO.
    """
    # Use a dict so the order of args in error messages is the same as in the call.
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            arg = unknown_args.pop(name, NOTSET)
            # Only expand() needs its values to be mappable. Missing args are fine.
            if func != "expand" or arg is NOTSET or is_mappable(arg):
                continue
            type_name = type(arg).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    # Reached also when no class in the MRO declared parameter names. With nothing
    # left over there is nothing to report, so do not raise an empty error.
    if not unknown_args:
        return
    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
def ensure_xcomarg_return_value(arg: Any) -> None:
    """Reject mapping over an XCom that was pushed under a non-default key.

    Any ``XComArg`` found in *arg* is checked directly. Containers (as recognized
    by ``is_container``) are searched recursively: mappings through their values,
    other iterables element by element.

    :raises ValueError: If an ``XComArg`` references a key other than ``XCOM_RETURN_KEY``.
    """
    from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg

    if isinstance(arg, XComArg):
        for source_operator, xcom_key in arg.iter_references():
            if xcom_key != XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {xcom_key!r} from {source_operator}")
        return

    if not is_container(arg):
        return

    if isinstance(arg, collections.abc.Mapping):
        children = arg.values()
    elif isinstance(arg, collections.abc.Iterable):
        children = arg
    else:
        return

    for child in children:
        ensure_xcomarg_return_value(child)
@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the DAG.
    """

    # Every method below reads these three attributes. Only fields declared here
    # become attrs constructor arguments (and, with attrs' default slots, assignable
    # attributes), so they must be declared explicitly.
    operator_class: type[BaseOperator]
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        """Reject unsupported operator classes and validate the partial kwargs.

        :raises TypeError: If ``operator_class`` is a ``SubDagOperator`` subclass, or if
            ``kwargs`` contains arguments the operator class does not accept.
        """
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __del__(self):
        # A partial only turns into a task via expand()/expand_kwargs(). Warn if this
        # one is collected without that ever happening, to help users spot the mistake.
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!")

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Create a ``MappedOperator`` that expands over each given keyword argument.

        Keys already given to ``partial()`` are rejected through ``prevent_duplicates``.

        :raises TypeError: If no arguments are given, or an argument is not accepted
            by ``operator_class``.
        :raises ValueError: If an argument value is not mappable.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Create a ``MappedOperator`` that expands over a list of keyword-argument dicts.

        :param kwargs: An ``XComArg``, or a sequence whose items are each a mapping
            or an ``XComArg``.
        :param strict: Forwarded as ``disallow_kwargs_override``. When true, a key in an
            expanded dict that duplicates a partial argument is an error at execution.
        :raises TypeError: If *kwargs*, or one of its items, has an unsupported type.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    # Name the offending item's type too, not only the container's.
                    raise TypeError(
                        f"expected XComArg or list[dict], not {type(kwargs).__name__}[{type(item).__name__}]"
                    )
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` for this partial and *expand_input*.

        :param expand_input: What to expand over. It is a dict-of-lists input from
            ``expand()`` or a list-of-dicts input from ``expand_kwargs()``.
        :param strict: Passed to the mapped operator as ``disallow_kwargs_override``.
        :raises ValueError: If the expand input references an XCom with a custom key.
        """
        from airflow.operators.empty import EmptyOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        # Work on a copy so self.kwargs is untouched. The popped keys become
        # dedicated MappedOperator fields; the rest is forwarded as partial_kwargs.
        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
        )
        return op
@attr.define(
    kw_only=True,
    # Turn off attrs' generated __getstate__/__setstate__. It interacts badly with
    # Airflow's DAG serialization and pickling: a deserialized mapped task has its
    # subclass coerced into MappedOperator, and when it then goes through DAG
    # pickling, attrs' own state management drops every attribute defined in the
    # subclass. attrs adds nothing essential here (that logic only matters for
    # slots=True), so Python's built-in implementation is used instead; it is
    # known to work, as BaseOperator shows.
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a DAG."""

    # This attribute has two roles. On an operator loaded from a DAG file it holds
    # the real, non-mapped operator class, which is used to build an unmapped
    # operator for execution. On an operator recreated from a serialized DAG it
    # holds the serialized data, which is used to unmap into a SerializedBaseOperator.
    # NOTE(review): the attribute described above (presumably ``operator_class``,
    # which ``unmap`` tests with ``isinstance(self.operator_class, type)``) and the
    # other keyword fields passed by ``OperatorPartial._expand`` (``expand_input``,
    # ``partial_kwargs``, ``task_id``, ...) are not declared in this view of the
    # class -- TODO confirm they are declared.

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    subdag: None = None  # SubDagOperator cannot be mapped, so this is always None.
[docs]def__attrs_post_init__(self):fromairflow.models.xcom_argimportXComArgifself.get_closest_mapped_task_group()isnotNone:raiseNotImplementedError("operator expansion in an expanded task group is not yet supported")ifself.task_group:self.task_group.add(self)ifself.dag:self.dag.add_task(self)XComArg.apply_upstream_relationship(self,self.expand_input.value)fork,vinself.partial_kwargs.items():ifkinself.template_fields:XComArg.apply_upstream_relationship(self,v)ifself.partial_kwargs.get("sla")isnotNone:raiseAirflowException(f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
f"{self.task_id!r}.")@classmethod@cache
[docs]defget_serialized_fields(cls):# Not using 'cls' here since we only want to serialize base fields.returnfrozenset(attr.fields_dict(MappedOperator))-{"dag","deps","expand_input",# This is needed to be able to accept XComArg."subdag","task_group","upstream_task_ids",
}@staticmethod@cache
[docs]defdeps_for(operator_class:type[BaseOperator])->frozenset[BaseTIDep]:operator_deps=operator_class.depsifnotisinstance(operator_deps,collections.abc.Set):raiseUnmappableOperator(f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "f"not a {type(operator_deps).__name__}")returnoperator_deps|{MappedTaskIsExpanded()}
[docs]defoutput(self)->XComArg:"""Returns reference to XCom pushed by current operator"""fromairflow.models.xcom_argimportXComArgreturnXComArg(operator=self)
def_expand_mapped_kwargs(self,context:Context,session:Session)->tuple[Mapping[str,Any],set[int]]:"""Get the kwargs to create the unmapped operator. This exists because taskflow operators expand against op_kwargs, not the entire operator kwargs dict. """returnself._get_specified_expand_input().resolve(context,session)def_get_unmap_kwargs(self,mapped_kwargs:Mapping[str,Any],*,strict:bool)->dict[str,Any]:"""Get init kwargs to unmap the underlying operator class. :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. """ifstrict:prevent_duplicates(self.partial_kwargs,mapped_kwargs,fail_reason="unmappable or already specified",)# If params appears in the mapped kwargs, we need to merge it into the# partial params, overriding existing keys.params=copy.copy(self.params)withcontextlib.suppress(KeyError):params.update(mapped_kwargs["params"])# Ordering is significant; mapped kwargs should override partial ones,# and the specially handled params should be respected.return{"task_id":self.task_id,"dag":self.dag,"task_group":self.task_group,"start_date":self.start_date,"end_date":self.end_date,**self.partial_kwargs,**mapped_kwargs,"params":params,}
[docs]defunmap(self,resolve:None|Mapping[str,Any]|tuple[Context,Session])->BaseOperator:"""Get the "normal" Operator after applying the current mapping. The *resolve* argument is only used if ``operator_class`` is a real class, i.e. if this operator is not serialized. If ``operator_class`` is not a class (i.e. this DAG has been deserialized), this returns a SerializedBaseOperator that "looks like" the actual unmapping result. If *resolve* is a two-tuple (context, session), the information is used to resolve the mapped arguments into init arguments. If it is a mapping, no resolving happens, the mapping directly provides those init arguments resolved from mapped kwargs. :meta private: """ifisinstance(self.operator_class,type):ifisinstance(resolve,collections.abc.Mapping):kwargs=resolveelifresolveisnotNone:kwargs,_=self._expand_mapped_kwargs(*resolve)else:raiseRuntimeError("cannot unmap a non-serialized operator without context")kwargs=self._get_unmap_kwargs(kwargs,strict=self._disallow_kwargs_override)op=self.operator_class(**kwargs,_airflow_from_mapped=True)# We need to overwrite task_id here because BaseOperator further# mangles the task_id based on the task hierarchy (namely, group_id# is prepended, and '__N' appended to deduplicate). This is hacky,# but better than duplicating the whole mangling logic.op.task_id=self.task_idreturnop# After a mapped operator is serialized, there's no real way to actually# unmap it since we've lost access to the underlying operator class.# This tries its best to simply "forward" all the attributes on this# mapped operator to a new SerializedBaseOperator instance.fromairflow.serialization.serialized_objectsimportSerializedBaseOperatorop=SerializedBaseOperator(task_id=self.task_id,params=self.params,_airflow_from_mapped=True)SerializedBaseOperator.populate_operator(op,self.operator_class)returnop
def_get_specified_expand_input(self)->ExpandInput:"""Input received from the expand call on the operator."""returngetattr(self,self._expand_input_attr)
[docs]defprepare_for_execution(self)->MappedOperator:# Since a mapped operator cannot be used for execution, and an unmapped# BaseOperator needs to be created later (see render_template_fields),# we don't need to create a copy of the MappedOperator here.returnself
[docs]defiter_mapped_dependencies(self)->Iterator[Operator]:"""Upstream dependencies that provide XComs used by this task for task mapping."""fromairflow.models.xcom_argimportXComArgforoperator,_inXComArg.iter_xcom_references(self._get_specified_expand_input()):yieldoperator
[docs]defrender_template_fields(self,context:Context,jinja_env:jinja2.Environment|None=None,)->None:"""Template all attributes listed in *self.template_fields*. This updates *context* to reference the map-expanded task and relevant information, without modifying the mapped operator. The expanded task in *context* is then rendered in-place. :param context: Context dict with values to apply on content. :param jinja_env: Jinja environment to use for rendering. """ifnotjinja_env:jinja_env=self.get_template_env()# Ideally we'd like to pass in session as an argument to this function,# but we can't easily change this function signature since operators# could override this. We can't use @provide_session since it closes and# expunges everything, which we don't want to do when we are so "deep"# in the weeds here. We don't close this session for the same reason.session=settings.Session()mapped_kwargs,seen_oids=self._expand_mapped_kwargs(context,session)unmapped_task=self.unmap(mapped_kwargs)context_update_for_unmapped(context,unmapped_task)# Since the operators that extend `BaseOperator` are not subclasses of# `MappedOperator`, we need to call `_do_render_template_fields` from# the unmapped task in order to call the operator method when we override# it to customize the parsing of nested fields.unmapped_task._do_render_template_fields(parent=unmapped_task,template_fields=self.template_fields,context=context,jinja_env=jinja_env,seen_oids=seen_oids,session=session,