## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportcollectionsimportcollections.abcimportdatetimeimportwarningsfromtypingimportTYPE_CHECKING,Any,ClassVar,Collection,Iterable,Iterator,Mapping,Sequence,Unionimportattrimportpendulumfromsqlalchemyimportfunc,or_fromsqlalchemy.orm.sessionimportSessionfromairflowimportsettingsfromairflow.compat.functoolsimportcache,cached_propertyfromairflow.exceptionsimportAirflowException,UnmappableOperatorfromairflow.models.abstractoperatorimport(DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,DEFAULT_OWNER,DEFAULT_POOL_SLOTS,DEFAULT_PRIORITY_WEIGHT,DEFAULT_QUEUE,DEFAULT_RETRIES,DEFAULT_RETRY_DELAY,DEFAULT_TRIGGER_RULE,DEFAULT_WEIGHT_RULE,AbstractOperator,TaskStateChangeCallback,)fromairflow.models.expandinputimport(DictOfListsExpandInput,ExpandInput,ListOfDictsExpandInput,NotFullyPopulated,OperatorExpandArgument,OperatorExpandKwargsArgument,get_mappable_types,)fromairflow.models.poolimportPoolfromairflow.serialization.enumsimportDagAttributeTypesfromairflow.ti_deps.deps.base_ti_depimportBaseTIDepfromairflow.ti_deps.deps.mapped_task_expandedimportMappedTaskIsExpandedfromairflow.typing_compatimportLiteralfromairflow.utils.contextimportContextfromairflow.utils.helpersimportis_containerfromairflow.utils.operator_re
sourcesimportResourcesfromairflow.utils.stateimportState,TaskInstanceStatefromairflow.utils.trigger_ruleimportTriggerRulefromairflow.utils.typesimportNOTSETifTYPE_CHECKING:importjinja2# Slow import.fromairflow.models.baseoperatorimportBaseOperator,BaseOperatorLinkfromairflow.models.dagimportDAGfromairflow.models.operatorimportOperatorfromairflow.models.taskinstanceimportTaskInstancefromairflow.models.xcom_argimportXComArgfromairflow.utils.task_groupimportTaskGroup
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Check the keyword arguments handed to ``partial()`` or ``expand()``.

    Walks ``op.mro()``; for every class whose ``__init__`` exposes a
    ``_BaseOperatorMeta__param_names`` attribute, the listed parameter names
    are removed from the set of arguments that are still unaccounted for.

    :param op: Operator class the arguments are meant for.
    :param func: Name of the calling API. It is interpolated into the error
        message, and per-value type checks only run when it equals ``"expand"``.
    :param value: Keyword arguments to validate. A copy is consumed; the
        caller's dict is not modified.
    :raises ValueError: If *func* is ``"expand"`` and a recognised argument is
        not an instance of one of ``get_mappable_types()``.
    :raises TypeError: If arguments remain that no inspected ``__init__`` names.
    """
    # use a dict so order of args is same as code order
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            # This class's __init__ carries no parameter-name record; skip it.
            continue
        for name in param_names:
            # A separate local is used so the ``value`` parameter is not rebound.
            arg = unknown_args.pop(name, NOTSET)
            if func != "expand":
                continue
            if arg is NOTSET:
                continue
            if isinstance(arg, get_mappable_types()):
                continue
            type_name = type(arg).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
def ensure_xcomarg_return_value(arg: Any) -> None:
    """Reject XCom references that use a key other than ``XCOM_RETURN_KEY``.

    An ``XComArg`` is checked directly via ``iter_references()``. Anything for
    which ``is_container()`` is false is accepted as-is. Mappings are checked
    value by value and other iterables element by element, recursively.

    :param arg: Value (possibly nested) to inspect.
    :raises ValueError: If an ``XComArg`` reference with a custom key is found.
    """
    from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg

    if isinstance(arg, XComArg):
        for operator, key in arg.iter_references():
            if key != XCOM_RETURN_KEY:
                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
    elif not is_container(arg):
        return
    elif isinstance(arg, collections.abc.Mapping):
        for v in arg.values():
            ensure_xcomarg_return_value(v)
    elif isinstance(arg, collections.abc.Iterable):
        for v in arg:
            ensure_xcomarg_return_value(v)
@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain)
    to create a ``MappedOperator`` to add into the DAG.
    """

    # Both attributes are read by every method below, so they must be declared
    # for the attrs-generated ``__init__`` to accept them.
    operator_class: type[BaseOperator]
    kwargs: dict[str, Any]

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        """Validate the partial arguments right after construction.

        :raises TypeError: If ``operator_class`` is a ``SubDagOperator``
            subclass, or if ``validate_mapping_kwargs`` rejects ``kwargs``.
        """
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        """Render as ``<OperatorClass>.partial(<key>=<value>, ...)``."""
        # ``repr=False`` on the decorator disables the attrs-generated repr.
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        """Warn if this object is discarded without ever being expanded."""
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!")

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Create a ``MappedOperator`` that maps over the given keyword arguments.

        :param mapped_kwargs: Arguments to expand against; at least one is required.
        :raises TypeError: If no argument is given, or if
            ``validate_mapping_kwargs`` finds an unexpected argument name.
        :raises ValueError: If ``validate_mapping_kwargs`` finds an unmappable value type.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Create a ``MappedOperator`` from a sequence of kwargs mappings or an ``XComArg``.

        :param kwargs: Either an ``XComArg``, or a sequence whose items are each
            an ``XComArg`` or a mapping.
        :param strict: Forwarded to the mapped operator as ``disallow_kwargs_override``.
        :raises TypeError: If *kwargs* or one of its items has an unsupported type.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` shared by ``expand()`` and ``expand_kwargs()``.

        :param expand_input: Wrapped mapping input.
        :param strict: Passed on as ``disallow_kwargs_override``.
        """
        from airflow.operators.empty import EmptyOperator

        # Mark as expanded first so __del__ does not warn about this object.
        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        # These entries are passed to MappedOperator as dedicated arguments, so
        # they are taken out of the copied partial kwargs.
        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        params = partial_kwargs.pop("params")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
        )
        return op
@attr.define(kw_only=True,# Disable custom __getstate__ and __setstate__ generation since it interacts# badly with Airflow's DAG serialization and pickling. When a mapped task is# deserialized, subclasses are coerced into MappedOperator, but when it goes# through DAG pickling, all attributes defined in the subclasses are dropped# by attrs's custom state management. Since attrs does not do anything too# special here (the logic is only important for slots=True), we use Python's# built-in implementation, which works (as proven by good old BaseOperator).getstate_setstate=False,
[docs])classMappedOperator(AbstractOperator):"""Object representing a mapped operator in a DAG."""# This attribute serves double purpose. For a "normal" operator instance# loaded from DAG, this holds the underlying non-mapped operator class that# can be used to create an unmapped operator for execution. For an operator# recreated from a serialized DAG, however, this holds the serialized data# that can be used to unmap this into a SerializedBaseOperator.
_disallow_kwargs_override:bool"""Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. If *False*, values from ``expand_input`` under duplicate keys override those under corresponding keys in ``partial_kwargs``. """_expand_input_attr:str"""Where to get kwargs to calculate expansion length against. This should be a name to call ``getattr()`` on. """
def __attrs_post_init__(self):
    """Register the new mapped operator and wire up its upstream relationships.

    :raises AirflowException: If ``partial_kwargs`` carries a non-``None`` ``sla``.
    """
    from airflow.models.xcom_arg import XComArg

    # Reject the unsupported configuration before touching the task group or
    # the DAG, so a failing operator is not left registered in either.
    if self.partial_kwargs.get('sla') is not None:
        raise AirflowException(
            f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
            f"{self.task_id!r}."
        )

    if self.task_group:
        self.task_group.add(self)
    if self.dag:
        self.dag.add_task(self)
    XComArg.apply_upstream_relationship(self, self.expand_input.value)
    # Only partial kwargs that are also template fields are passed on.
    for k, v in self.partial_kwargs.items():
        if k in self.template_fields:
            XComArg.apply_upstream_relationship(self, v)


@classmethod
@cache
def get_serialized_fields(cls):
    """Return the attrs field names of ``MappedOperator`` minus a fixed exclusion set."""
    # Not using 'cls' here since we only want to serialize base fields.
    return frozenset(attr.fields_dict(MappedOperator)) - {
        "dag",
        "deps",
        "is_mapped",
        "expand_input",  # This is needed to be able to accept XComArg.
        "subdag",
        "task_group",
        "upstream_task_ids",
    }


@staticmethod
@cache
def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
    """Return ``operator_class.deps`` extended with a ``MappedTaskIsExpanded`` dep.

    :param operator_class: Operator class whose class-level ``deps`` is read.
    :raises UnmappableOperator: If ``operator_class.deps`` is not a set.
    """
    operator_deps = operator_class.deps
    if not isinstance(operator_deps, collections.abc.Set):
        raise UnmappableOperator(
            f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
            f"not a {type(operator_deps).__name__}"
        )
    return operator_deps | {MappedTaskIsExpanded()}
# NOTE(review): no decorator is visible on this definition in the source at
# hand; if callers read ``.output`` as an attribute, confirm whether a
# ``@property`` was lost.
def output(self) -> XComArg:
    """Returns reference to XCom pushed by current operator"""
    from airflow.models.xcom_arg import XComArg

    return XComArg(operator=self)
def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
    """Get the kwargs to create the unmapped operator.

    This exists because taskflow operators expand against op_kwargs, not the
    entire operator kwargs dict.

    :return: Whatever ``resolve(context, session)`` returns on the expand
        input selected by ``_get_specified_expand_input()``.
    """
    return self._get_specified_expand_input().resolve(context, session)


def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
    """Get init kwargs to unmap the underlying operator class.

    :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
    :param strict: If true, ``prevent_duplicates`` is called on the partial and
        mapped kwargs before they are merged.
    :return: A new dict; on duplicate keys, later sources win.
    """
    if strict:
        prevent_duplicates(
            self.partial_kwargs,
            mapped_kwargs,
            fail_reason="unmappable or already specified",
        )
    # Ordering is significant; mapped kwargs should override partial ones.
    return {
        "task_id": self.task_id,
        "dag": self.dag,
        "task_group": self.task_group,
        "params": self.params,
        "start_date": self.start_date,
        "end_date": self.end_date,
        **self.partial_kwargs,
        **mapped_kwargs,
    }
def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
    """Get the "normal" Operator after applying the current mapping.

    The *resolve* argument is only used if ``operator_class`` is a real class,
    i.e. if this operator is not serialized. If ``operator_class`` is not a
    class (i.e. this DAG has been deserialized), this returns a
    SerializedBaseOperator that "looks like" the actual unmapping result.

    If *resolve* is a two-tuple (context, session), the information is used to
    resolve the mapped arguments into init arguments. If it is a mapping, no
    resolving happens, the mapping directly provides those init arguments
    resolved from mapped kwargs.

    :raises RuntimeError: If ``operator_class`` is a real class and *resolve*
        is ``None``.

    :meta private:
    """
    if isinstance(self.operator_class, type):
        if isinstance(resolve, collections.abc.Mapping):
            kwargs = resolve
        elif resolve is not None:
            kwargs, _ = self._expand_mapped_kwargs(*resolve)
        else:
            raise RuntimeError("cannot unmap a non-serialized operator without context")
        kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
        op = self.operator_class(**kwargs, _airflow_from_mapped=True)
        # We need to overwrite task_id here because BaseOperator further
        # mangles the task_id based on the task hierarchy (namely, group_id
        # is prepended, and '__N' appended to deduplicate). This is hacky,
        # but better than duplicating the whole mangling logic.
        op.task_id = self.task_id
        return op

    # After a mapped operator is serialized, there's no real way to actually
    # unmap it since we've lost access to the underlying operator class.
    # This tries its best to simply "forward" all the attributes on this
    # mapped operator to a new SerializedBaseOperator instance.
    from airflow.serialization.serialized_objects import SerializedBaseOperator

    op = SerializedBaseOperator(task_id=self.task_id, _airflow_from_mapped=True)
    SerializedBaseOperator.populate_operator(op, self.operator_class)
    return op
def _get_specified_expand_input(self) -> ExpandInput:
    """Input received from the expand call on the operator.

    The attribute to read is named by ``self._expand_input_attr``.
    """
    return getattr(self, self._expand_input_attr)
def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
    """Create the mapped task instances for mapped task.

    :param run_id: Run the task instances belong to.
    :param session: Database session used for every query, merge and the final flush.
    :return: The newly created mapped TaskInstances (if any) in ascending order
        by map index, and the maximum map_index (``-1`` when the map length is
        zero or cannot be determined).
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.settings import task_instance_mutation_hook

    total_length: int | None
    try:
        total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
    except NotFullyPopulated as e:
        self.log.info(
            "Cannot expand %r for run %s; missing upstream values: %s",
            self,
            run_id,
            sorted(e.missing),
        )
        total_length = None

    state: TaskInstanceState | None = None
    unmapped_ti: TaskInstance | None = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == -1,
            or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
        )
        .one_or_none()
    )

    all_expanded_tis: list[TaskInstance] = []

    if unmapped_ti:
        # The unmapped task instance still exists and is unfinished, i.e. we
        # haven't tried to run it before.
        if total_length is None:
            # If the map length cannot be calculated (due to unavailable
            # upstream sources), fail the unmapped task.
            unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            indexes_to_map: Iterable[int] = ()
        elif total_length < 1:
            # If the upstream maps this to a zero-length value, simply mark
            # the unmapped task instance as SKIPPED (if needed).
            self.log.info(
                "Marking %s as SKIPPED since the map has %d values to expand",
                unmapped_ti,
                total_length,
            )
            unmapped_ti.state = TaskInstanceState.SKIPPED
            indexes_to_map = ()
        else:
            # Otherwise convert this into the first mapped index, and create
            # TaskInstance for other indexes.
            unmapped_ti.map_index = 0
            self.log.debug("Updated in place to become %s", unmapped_ti)
            all_expanded_tis.append(unmapped_ti)
            indexes_to_map = range(1, total_length)
        # Newly created task instances start in the same state as the original one.
        state = unmapped_ti.state
    elif not total_length:
        # Nothing to fixup.
        indexes_to_map = ()
    else:
        # Only create "missing" ones.
        current_max_mapping = (
            session.query(func.max(TaskInstance.map_index))
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
            )
            .scalar()
        )
        if current_max_mapping is None:
            # MAX() over zero rows is NULL; treat that like "nothing mapped
            # yet" instead of failing on ``None + 1``.
            current_max_mapping = -1
        indexes_to_map = range(current_max_mapping + 1, total_length)

    for index in indexes_to_map:
        # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
        ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
        self.log.debug("Expanding TIs upserted %s", ti)
        task_instance_mutation_hook(ti)
        ti = session.merge(ti)
        ti.refresh_from_task(self)  # session.merge() loses task information.
        all_expanded_tis.append(ti)

    # Coerce the None case to 0 -- these two are almost treated identically,
    # except the unmapped ti (if exists) is marked to different states.
    total_expanded_ti_count = total_length or 0

    # Set to "REMOVED" any (old) TaskInstances with map indices greater
    # than the current map value
    session.query(TaskInstance).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id == self.task_id,
        TaskInstance.run_id == run_id,
        TaskInstance.map_index >= total_expanded_ti_count,
    ).update({TaskInstance.state: TaskInstanceState.REMOVED})

    session.flush()
    return all_expanded_tis, total_expanded_ti_count - 1
def prepare_for_execution(self) -> MappedOperator:
    """Return this same object; no copy is made."""
    # Since a mapped operator cannot be used for execution, and an unmapped
    # BaseOperator needs to be created later (see render_template_fields),
    # we don't need to create a copy of the MappedOperator here.
    return self
def iter_mapped_dependencies(self) -> Iterator[Operator]:
    """Upstream dependencies that provide XComs used by this task for task mapping.

    Yields the operator of every reference reported by each ``XComArg`` found
    in the expand input; an operator referenced more than once is yielded
    more than once.
    """
    from airflow.models.xcom_arg import XComArg

    for ref in XComArg.iter_xcom_args(self._get_specified_expand_input()):
        for operator, _ in ref.iter_references():
            yield operator
@cached_property
def parse_time_mapped_ti_count(self) -> int | None:
    """Number of mapped TaskInstances that can be created at DagRun create time.

    :return: None if non-literal mapped arg encountered, or the total number
        of mapped TIs this task should have.
    """
    return self._get_specified_expand_input().get_parse_time_mapped_ti_count()
# Deliberately not memoized: the result depends on database state that changes
# while a run progresses, and a cache keyed on ``self`` and ``session`` would
# also keep both objects alive for the lifetime of the process.
def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
    """Number of mapped TaskInstances that can be created at run time.

    :param run_id: Run to compute the map length for.
    :param session: Database session forwarded to ``get_total_map_length``.
    :return: None if upstream tasks are not complete yet, or the total number
        of mapped TIs this task should have.
    """
    try:
        return self._get_specified_expand_input().get_total_map_length(run_id, session=session)
    except NotFullyPopulated:
        return None
def render_template_fields(
    self,
    context: Context,
    jinja_env: jinja2.Environment | None = None,
) -> BaseOperator | None:
    """Unmap this operator and render the template fields on the unmapped task.

    :param context: Context used both to resolve the mapped kwargs and to render templates.
    :param jinja_env: Environment to render with; when falsy,
        ``self.get_template_env()`` is used.
    :return: The unmapped operator that the rendered values were applied to.
    """
    if not jinja_env:
        jinja_env = self.get_template_env()

    # Ideally we'd like to pass in session as an argument to this function,
    # but we can't easily change this function signature since operators
    # could override this. We can't use @provide_session since it closes and
    # expunges everything, which we don't want to do when we are so "deep"
    # in the weeds here. We don't close this session for the same reason.
    session = settings.Session()

    mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
    unmapped_task = self.unmap(mapped_kwargs)
    self._do_render_template_fields(
        parent=unmapped_task,
        template_fields=self.template_fields,
        context=context,
        jinja_env=jinja_env,
        seen_oids=seen_oids,
        session=session,
    )
    return unmapped_task