# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import collections
import collections.abc
import datetime
import functools
import operator
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Collection,
    Dict,
    FrozenSet,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import attr
import pendulum
from sqlalchemy import func, or_
from sqlalchemy.orm.session import Session

from airflow import settings
from airflow.compat.functools import cache, cached_property
from airflow.exceptions import AirflowException, UnmappableOperator
from airflow.models.abstractoperator import (
    DEFAULT_OWNER,
    DEFAULT_POOL_SLOTS,
    DEFAULT_PRIORITY_WEIGHT,
    DEFAULT_QUEUE,
    DEFAULT_RETRIES,
    DEFAULT_RETRY_DELAY,
    DEFAULT_TRIGGER_RULE,
    DEFAULT_WEIGHT_RULE,
    AbstractOperator,
    TaskStateChangeCallback,
)
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.typing_compat import Literal
from airflow.utils.context import Context
from airflow.utils.helpers import is_container
from airflow.utils.operator_resources import Resources
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
    import jinja2  # Slow import.

    from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance
    from airflow.models.xcom_arg import XComArg
    from airflow.utils.task_group import TaskGroup

# BaseOperator.expand() can be called on an XComArg, sequence, or dict (not
# any mapping since we need the value to be ordered).
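# The definitions below are referenced throughout this module (``Mappable``
# in signatures, ``get_mappable_types()`` in validate_mapping_kwargs(), and
# ``MAPPABLE_LITERAL_TYPES`` in parse_time_mapped_ti_count()) but were missing
# from this excerpt. This is a hedged reconstruction based on how they are
# used; treat the exact shapes as assumptions.
Mappable = Union["XComArg", Sequence, dict]

MAPPABLE_LITERAL_TYPES = (dict, list)

ValidationSource = Union[Literal["expand"], Literal["partial"]]


@cache
def get_mappable_types() -> Tuple[type, ...]:
    from airflow.models.xcom_arg import XComArg

    return (XComArg, list, dict)

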
def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, value: Dict[str, Any]) -> None:
    # use a dict so order of args is same as code order
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            value = unknown_args.pop(name, NOTSET)
            if func != "expand":
                continue
            if value is NOTSET:
                continue
            if isinstance(value, get_mappable_types()):
                continue
            type_name = type(value).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.

    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")

def ensure_xcomarg_return_value(arg: Any) -> None:
    from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg

    if isinstance(arg, XComArg):
        if arg.key != XCOM_RETURN_KEY:
            raise ValueError(f"cannot map over XCom with custom key {arg.key!r} from {arg.operator}")
    elif not is_container(arg):
        return
    elif isinstance(arg, collections.abc.Mapping):
        for v in arg.values():
            ensure_xcomarg_return_value(v)
    elif isinstance(arg, collections.abc.Iterable):
        for v in arg:
            ensure_xcomarg_return_value(v)

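# ``prevent_duplicates()`` is called from OperatorPartial._expand() below but
# its definition was missing from this excerpt. The sketch below is a
# plausible reconstruction (an assumption): it rejects kwargs that appear in
# both partial() and expand(), matching how the call site uses it.
def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> None:
    duplicated_keys = set(kwargs1).intersection(kwargs2)
    if not duplicated_keys:
        return
    if len(duplicated_keys) == 1:
        raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
    duplicated_keys_display = ", ".join(sorted(duplicated_keys))
    raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
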

@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain)
    to create a ``MappedOperator`` to add into the DAG.
    """

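    # The attr field declarations here were dropped from this excerpt; the two
    # fields below are reconstructed from how the methods of this class use
    # ``self.operator_class`` and ``self.kwargs``. The exact annotations are
    # assumptions.
    operator_class: Type["BaseOperator"]
    kwargs: Dict[str, Any]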
    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __del__(self):
        if not self._expand_called:
            try:
                task_id = repr(self.kwargs["task_id"])
            except KeyError:
                task_id = f"at {hex(id(self))}"
            warnings.warn(f"Task {task_id} was never mapped!")

    def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        return self._expand(**mapped_kwargs)

    def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
        self._expand_called = True

        from airflow.operators.empty import EmptyOperator

        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="mapping already partial")
        ensure_xcomarg_return_value(mapped_kwargs)

        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        params = partial_kwargs.pop("params")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        op = MappedOperator(
            operator_class=self.operator_class,
            mapped_kwargs=mapped_kwargs,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            # For classic operators, this points to mapped_kwargs because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expansion_kwargs_attr="mapped_kwargs",
        )
        return op
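

# An illustration of the flow above (a hedged sketch; the operator choice and
# every argument value are made up for the example):
#
#     partial_op = BashOperator.partial(task_id="consume", retries=2)
#     mapped_op = partial_op.expand(bash_command=["echo 1", "echo 2"])
#
# ``partial()`` captures the constant kwargs in an OperatorPartial, and
# ``expand()`` validates the mapped kwargs, rejects duplicates of arguments
# already given to ``partial()``, and combines both into one MappedOperator.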

@attr.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's DAG serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through DAG pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a DAG."""

    # This attribute serves double purpose. For a "normal" operator instance
    # loaded from DAG, this holds the underlying non-mapped operator class that
    # can be used to create an unmapped operator for execution. For an operator
    # recreated from a serialized DAG, however, this holds the serialized data
    # that can be used to unmap this into a SerializedBaseOperator.
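    # (The attr field declarations of this class were missing from this
    # excerpt. The names below are reconstructed from how the rest of the
    # module uses them -- e.g. OperatorPartial._expand() above passes each of
    # these to the constructor -- but the exact annotations are assumptions.)
    operator_class: Union[Type["BaseOperator"], Dict[str, Any]]

    mapped_kwargs: Dict[str, "Mappable"]
    partial_kwargs: Dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: Optional[dict]
    deps: FrozenSet[BaseTIDep]
    operator_extra_links: Collection["BaseOperatorLink"]
    template_ext: Collection[str]
    template_fields: Collection[str]
    template_fields_renderers: Dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool
    _task_module: str
    _task_type: str

    dag: Optional["DAG"]
    task_group: Optional["TaskGroup"]
    start_date: Optional[pendulum.DateTime]
    end_date: Optional[pendulum.DateTime]
    upstream_task_ids: Set[str] = attr.ib(factory=set, init=False)
    downstream_task_ids: Set[str] = attr.ib(factory=set, init=False)

    # Name of the attribute holding the kwargs to expand against; see
    # _get_expansion_kwargs() below.
    _expansion_kwargs_attr: str

    is_mapped: ClassVar[bool] = True
    subdag: None = None  # Since we don't support SubDagOperator, this is always None.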

    def __attrs_post_init__(self):
        from airflow.models.xcom_arg import XComArg

        self._validate_argument_count()
        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        for k, v in self.mapped_kwargs.items():
            XComArg.apply_upstream_relationship(self, v)
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)
        if self.partial_kwargs.get("sla") is not None:
            raise AirflowException(
                f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
                f"{self.task_id!r}."
            )

    @classmethod
    @cache
    def get_serialized_fields(cls):
        # Not using 'cls' here since we only want to serialize base fields.
        return frozenset(attr.fields_dict(MappedOperator)) - {
            "dag",
            "deps",
            "is_mapped",
            "mapped_kwargs",  # This is needed to be able to accept XComArg.
            "subdag",
            "task_group",
            "upstream_task_ids",
        }

    @staticmethod
    @cache
    def deps_for(operator_class: Type["BaseOperator"]) -> FrozenSet[BaseTIDep]:
        operator_deps = operator_class.deps
        if not isinstance(operator_deps, collections.abc.Set):
            raise UnmappableOperator(
                f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
                f"not a {type(operator_deps).__name__}"
            )
        return operator_deps | {MappedTaskIsExpanded()}

    def _validate_argument_count(self) -> None:
        """Validate mapping arguments by unmapping with mocked values.

        This ensures the user passed enough arguments in the DAG definition for
        the operator to work in the task runner. This does not guarantee the
        arguments are *valid* (that depends on the actual mapping values), but
        makes sure there are *enough* of them.
        """
        if not isinstance(self.operator_class, type):
            return  # No need to validate deserialized operator.
        self.operator_class.validate_mapped_arguments(**self._get_unmap_kwargs())
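
    # ``_get_unmap_kwargs()`` is called by _validate_argument_count(), unmap()
    # and render_template_fields() but its definition was missing from this
    # excerpt. The sketch below is a plausible reconstruction (an assumption):
    # merge the identity fields with the partial and mapped kwargs so the
    # unmapped constructor receives every argument.
    def _get_unmap_kwargs(self) -> Dict[str, Any]:
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "params": self.params,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **self.mapped_kwargs,
        }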

    def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator":
        """Get the "normal" Operator after applying the current mapping.

        If ``operator_class`` is not a class (i.e. this DAG has been
        deserialized), this returns a SerializedBaseOperator that aims to
        "look like" the real operator.

        :param unmap_kwargs: Override the args to pass to the Operator
            constructor. Only used when ``operator_class`` is still an actual
            class.

        :meta private:
        """
        if isinstance(self.operator_class, type):
            # We can't simply specify task_id here because BaseOperator further
            # mangles the task_id based on the task hierarchy (namely, group_id
            # is prepended, and '__N' appended to deduplicate). Instead of
            # recreating the whole logic here, we just overwrite task_id later.
            if unmap_kwargs is None:
                unmap_kwargs = self._get_unmap_kwargs()
            op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True)
            op.task_id = self.task_id
            return op

        # After a mapped operator is serialized, there's no real way to actually
        # unmap it since we've lost access to the underlying operator class.
        # This tries its best to simply "forward" all the attributes on this
        # mapped operator to a new SerializedBaseOperator instance.
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        op = SerializedBaseOperator(task_id=self.task_id, _airflow_from_mapped=True)
        SerializedBaseOperator.populate_operator(op, self.operator_class)
        return op

    def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]:
        """The kwargs to calculate expansion length against."""
        return getattr(self, self._expansion_kwargs_attr)

    def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
        """Return dict of argument name to map length.

        If any arguments are not known right now (upstream task not finished),
        they will not be present in the dict.
        """
        # TODO: Find a way to cache this.
        from airflow.models.taskmap import TaskMap
        from airflow.models.xcom import XCom
        from airflow.models.xcom_arg import XComArg

        expansion_kwargs = self._get_expansion_kwargs()

        # Populate literal mapped arguments first.
        map_lengths: Dict[str, int] = collections.defaultdict(int)
        map_lengths.update((k, len(v)) for k, v in expansion_kwargs.items() if not isinstance(v, XComArg))

        # Build a reverse mapping of what arguments each task contributes to.
        mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
        non_mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
        for k, v in expansion_kwargs.items():
            if not isinstance(v, XComArg):
                continue
            if v.operator.is_mapped:
                mapped_dep_keys[v.operator.task_id].add(k)
            else:
                non_mapped_dep_keys[v.operator.task_id].add(k)
            # TODO: It's not possible now, but in the future (AIP-42 Phase 2)
            # we will add support for depending on one single mapped task
            # instance. When that happens, we need to further analyze the mapped
            # case to contain only tasks we depend on "as a whole", and put
            # those we only depend on individually to the non-mapped lookup.

        # Collect lengths from unmapped upstreams.
        taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
            TaskMap.dag_id == self.dag_id,
            TaskMap.run_id == run_id,
            TaskMap.task_id.in_(non_mapped_dep_keys),
            TaskMap.map_index < 0,
        )
        for task_id, length in taskmap_query:
            for mapped_arg_name in non_mapped_dep_keys[task_id]:
                map_lengths[mapped_arg_name] += length

        # Collect lengths from mapped upstreams.
        xcom_query = (
            session.query(XCom.task_id, func.count(XCom.map_index))
            .group_by(XCom.task_id)
            .filter(
                XCom.dag_id == self.dag_id,
                XCom.run_id == run_id,
                XCom.task_id.in_(mapped_dep_keys),
                XCom.map_index >= 0,
            )
        )
        for task_id, length in xcom_query:
            for mapped_arg_name in mapped_dep_keys[task_id]:
                map_lengths[mapped_arg_name] += length

        return map_lengths

    @cache
    def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
        """Return dict of argument name to map length, or throw if some are not resolvable."""
        expansion_kwargs = self._get_expansion_kwargs()
        map_lengths = self._get_map_lengths(run_id, session=session)
        if len(map_lengths) < len(expansion_kwargs):
            keys = ", ".join(repr(k) for k in sorted(set(expansion_kwargs).difference(map_lengths)))
            raise RuntimeError(f"Failed to populate all mapping metadata; missing: {keys}")
        return map_lengths
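
    # A worked example of the bookkeeping above (all values hypothetical): for
    # ``expand(x=[1, 2], y=some_task.output)`` where the upstream task pushed a
    # 3-item XCom, _get_map_lengths() yields {"x": 2, "y": 3}, and the number
    # of expanded task instances is the product of the lengths, 6, which is
    # exactly what expand_mapped_task() below computes as ``total_length``.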

    def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
        """Create the mapped task instances for this mapped task.

        :return: The newly created mapped TaskInstances (if any) in ascending
            order by map index, and the maximum map index.
        """
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        total_length = functools.reduce(
            operator.mul, self._resolve_map_lengths(run_id, session=session).values()
        )

        state: Optional[TaskInstanceState] = None
        unmapped_ti: Optional[TaskInstance] = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
            .one_or_none()
        )

        all_expanded_tis: List[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                # Otherwise convert this into the first mapped index, and create
                # TaskInstances for the other indexes.
                unmapped_ti.map_index = 0
                self.log.debug("Updated in place to become %s", unmapped_ti)
                all_expanded_tis.append(unmapped_ti)
                state = unmapped_ti.state
            indexes_to_map = range(1, total_length)
        else:
            # Only create "missing" ones.
            current_max_mapping = (
                session.query(func.max(TaskInstance.map_index))
                .filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                )
                .scalar()
            )
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Set to "REMOVED" any (old) TaskInstances with map indices greater
        # than the current map value.
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_length,
        ).update({TaskInstance.state: TaskInstanceState.REMOVED})

        session.flush()
        return all_expanded_tis, total_length

    def prepare_for_execution(self) -> "MappedOperator":
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    def render_template_fields(
        self,
        context: Context,
        jinja_env: Optional["jinja2.Environment"] = None,
    ) -> Optional["BaseOperator"]:
        """Template all attributes listed in template_fields.

        Different from the BaseOperator implementation, this renders the
        template fields on the *unmapped* BaseOperator.

        :param context: Dict with values to apply on content.
        :param jinja_env: Jinja environment.
        :return: The unmapped, populated BaseOperator.
        """
        if not jinja_env:
            jinja_env = self.get_template_env()

        # Before we unmap we have to resolve the mapped arguments, otherwise the
        # real operator constructor could be called with an XComArg, rather than
        # the value it resolves to.
        #
        # We also need to resolve _all_ mapped arguments, even if they aren't
        # marked as templated.
        kwargs = self._get_unmap_kwargs()
        template_fields = set(self.template_fields)

        # Ideally we'd like to pass in session as an argument to this function,
        # but since operators _could_ override this we can't easily change this
        # function signature.
        # We can't use @provide_session, as that closes and expunges everything,
        # which we don't want to do when we are so "deep" in the weeds here.
        #
        # Nor do we want to close the session -- that would expunge all the
        # things from the internal cache, which we don't want to do either.
        session = settings.Session()

        self._resolve_expansion_kwargs(kwargs, template_fields, context, session)

        unmapped_task = self.unmap(unmap_kwargs=kwargs)
        self._do_render_template_fields(
            parent=unmapped_task,
            template_fields=template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=set(),
            session=session,
        )
        return unmapped_task

    def _resolve_expansion_kwargs(
        self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: Session
    ) -> None:
        """Update mapped fields in place in the kwargs dict."""
        from airflow.models.xcom_arg import XComArg

        expansion_kwargs = self._get_expansion_kwargs()

        for k, v in expansion_kwargs.items():
            if isinstance(v, XComArg):
                v = v.resolve(context, session=session)
            v = self._expand_mapped_field(k, v, context, session=session)
            template_fields.discard(k)
            kwargs[k] = v

    def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
        map_index = context["ti"].map_index
        if map_index < 0:
            return value
        expansion_kwargs = self._get_expansion_kwargs()
        all_lengths = self._resolve_map_lengths(context["run_id"], session=session)

        def _find_index_for_this_field(index: int) -> int:
            # Need to use self.mapped_kwargs for the original argument order.
            for mapped_key in reversed(list(expansion_kwargs)):
                mapped_length = all_lengths[mapped_key]
                if mapped_length < 1:
                    raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
                if mapped_key == key:
                    return index % mapped_length
                index //= mapped_length
            return -1

        found_index = _find_index_for_this_field(map_index)
        if found_index < 0:
            return value
        if isinstance(value, collections.abc.Sequence):
            return value[found_index]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        for i, (k, v) in enumerate(value.items()):
            if i == found_index:
                return k, v
        raise IndexError(f"index {map_index} is over mapped length")
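
    # The nested _find_index_for_this_field() above decomposes the flat
    # map_index as a mixed-radix number, with the last expansion kwarg varying
    # fastest. A worked example with hypothetical lengths {"x": 2, "y": 3} and
    # map_index 4: "y" resolves to 4 % 3 == 1, then "x" resolves to
    # (4 // 3) % 2 == 1, so instance 4 pairs x[1] with y[1].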

    def iter_mapped_dependencies(self) -> Iterator["Operator"]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.models.xcom_arg import XComArg

        for ref in XComArg.iter_xcom_args(self._get_expansion_kwargs()):
            yield ref.operator

    @cached_property
    def parse_time_mapped_ti_count(self) -> Optional[int]:
        """Number of mapped TaskInstances that can be created at DagRun create time.

        :return: None if a non-literal mapped arg is encountered, or else the
            total number of mapped TIs this task should have.
        """
        total = 0

        for value in self._get_expansion_kwargs().values():
            if not isinstance(value, MAPPABLE_LITERAL_TYPES):
                # A non-literal type was encountered, so give up.
                return None
            if total == 0:
                total = len(value)
            else:
                total *= len(value)
        return total
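
    # For example (hypothetical values): ``expand(a=[1, 2], b=["x", "y", "z"])``
    # gives a parse-time count of 2 * 3 == 6, while ``expand(a=some_task.output)``
    # yields None because an XComArg's length is unknown until run time; the
    # run-time counterpart below resolves such lengths from the database.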

    @cache
    def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> Optional[int]:
        """Number of mapped TaskInstances that can be created at run time, or
        None if upstream tasks are not complete yet.

        :return: None if upstream tasks are not complete yet, or else the total
            number of mapped TIs this task should have.
        """
        lengths = self._get_map_lengths(run_id, session=session)
        expansion_kwargs = self._get_expansion_kwargs()

        if not lengths or not expansion_kwargs:
            return None

        total = 1
        for name in expansion_kwargs:
            val = lengths.get(name)
            if val is None:
                return None
            total *= val
        return total