# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.importwarningsfromabcimportABCMeta,abstractmethodfromtypingimportTYPE_CHECKING,Any,Iterable,Iterator,List,Optional,Sequence,Set,Tuple,Unionimportpendulumfromairflow.exceptionsimportAirflowExceptionfromairflow.serialization.enumsimportDagAttributeTypesifTYPE_CHECKING:fromloggingimportLoggerfromairflow.models.dagimportDAGfromairflow.models.mappedoperatorimportMappedOperatorfromairflow.utils.edgemodifierimportEdgeModifierfromairflow.utils.task_groupimportTaskGroup
class DependencyMixin:
    """Mixin implementing common dependency-setting methods like ``>>`` and ``<<``."""

    @property
    def roots(self) -> Sequence["DependencyMixin"]:
        """
        List of root nodes -- ones with no upstream dependencies.

        a.k.a. the "start" of this sub-graph
        """
        raise NotImplementedError()

    @property
    def leaves(self) -> Sequence["DependencyMixin"]:
        """
        List of leaf nodes -- ones with only upstream dependencies.

        a.k.a. the "end" of this sub-graph
        """
        raise NotImplementedError()

    @abstractmethod
    def set_upstream(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
        """Set a task or a task list to be directly upstream from the current task."""
        raise NotImplementedError()

    @abstractmethod
    def set_downstream(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
        """Set a task or a task list to be directly downstream from the current task."""
        raise NotImplementedError()

    def update_relative(self, other: "DependencyMixin", upstream: bool = True) -> None:
        """
        Update relationship information about another DependencyMixin.

        Default is no-op. Override if necessary.
        """

    def __lshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
        """Implement ``self << other``: make ``other`` upstream of ``self``.

        Returns ``other`` so that chains such as ``a << b << c`` keep working.
        """
        self.set_upstream(other)
        return other

    def __rshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
        """Implement ``self >> other``: make ``other`` downstream of ``self``.

        Returns ``other`` so that chains such as ``a >> b >> c`` keep working.
        """
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
        """Called for ``[Task] >> Task`` because lists don't have ``__rshift__`` operators."""
        # The left operand is upstream of self, hence the delegation to ``<<``.
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
        """Called for ``[Task] << Task`` because lists don't have ``__lshift__`` operators."""
        # The left operand is downstream of self, hence the delegation to ``>>``.
        self.__rshift__(other)
        return self


class TaskMixin(DependencyMixin):
    """Deprecated name for :class:`DependencyMixin`.

    Subclassing it emits a ``DeprecationWarning`` naming the offending class.
    The hook lives here rather than on ``DependencyMixin`` so that ordinary
    ``DependencyMixin`` subclasses do not warn.
    """

    def __init_subclass__(cls) -> None:
        # stacklevel=2 attributes the warning to the ``class ...`` statement
        # that triggered the hook rather than to this module.
        warnings.warn(
            f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}",
            category=DeprecationWarning,
            stacklevel=2,
        )
        return super().__init_subclass__()
[docs]classDAGNode(DependencyMixin,metaclass=ABCMeta):""" A base class for a node in the graph of a workflow -- an Operator or a Task Group, either mapped or unmapped. """
def _set_relatives(
    self,
    task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
    upstream: bool = False,
    edge_modifier: Optional["EdgeModifier"] = None,
) -> None:
    """Sets relatives for the task or task list.

    The argument is normalised to a list, ``update_relative`` is called on
    every entry, and each entry contributes its ``leaves`` (when ``upstream``
    is true) or its ``roots`` (otherwise). Every collected task is then linked
    to this node by updating ``upstream_task_ids`` / ``downstream_task_ids``
    on both sides.

    :param task_or_task_list: a single dependency object or a sequence of them.
    :param upstream: if true the given tasks become upstream of this node,
        otherwise downstream.
    :param edge_modifier: when given, its ``add_edge_info`` is called for every
        edge that is set.
    :raises AirflowException: if a collected relative is neither a
        ``BaseOperator`` nor a ``MappedOperator``, if the involved tasks belong
        to more than one DAG, or if none of them has a DAG.
    """
    # Imported inside the function rather than at module top.
    # NOTE(review): presumably to avoid circular imports -- confirm.
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator
    from airflow.models.operator import Operator

    # Accept a single object as well as a sequence of objects.
    if not isinstance(task_or_task_list, Sequence):
        task_or_task_list = [task_or_task_list]

    task_list: List[Operator] = []
    for task_object in task_or_task_list:
        # Tell the other side about this node; the direction is inverted
        # because it is seen from the other object's point of view.
        task_object.update_relative(self, not upstream)
        # Upstream objects connect through their leaves, downstream objects
        # through their roots.
        relatives = task_object.leaves if upstream else task_object.roots
        for task in relatives:
            if not isinstance(task, (BaseOperator, MappedOperator)):
                raise AirflowException(
                    f"Relationships can only be set between Operators; received {task.__class__.__name__}"
                )
            task_list.append(task)

    # relationships can only be set if the tasks share a single DAG. Tasks
    # without a DAG are assigned to that DAG.
    dags: Set["DAG"] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}

    if len(dags) > 1:
        raise AirflowException(f'Tried to set relationships between tasks in more than one DAG: {dags}')
    elif len(dags) == 1:
        dag = dags.pop()
    else:
        raise AirflowException(
            f"Tried to create relationships between tasks that don't have DAGs yet. "
            f"Set the DAG for at least one task and try again: {[self, *task_list]}"
        )

    if not self.has_dag():
        # If this task does not yet have a dag, add it to the same dag as the other task and
        # put it in the dag's root TaskGroup.
        self.dag = dag
        self.dag.task_group.add(self)

    def add_only_new(obj, item_set: Set[str], item: str) -> None:
        """Adds only new items to item set"""
        # A duplicate is logged as a warning and otherwise ignored.
        if item in item_set:
            self.log.warning('Dependency %s, %s already registered for DAG: %s', obj, item, dag.dag_id)
        else:
            item_set.add(item)

    for task in task_list:
        if dag and not task.has_dag():
            # If the other task does not yet have a dag, add it to the same dag as this task and
            # put it in the dag's root TaskGroup.
            dag.add_task(task)
            dag.task_group.add(task)
        if upstream:
            # Edge runs task -> self; record it on both endpoints.
            add_only_new(task, task.downstream_task_ids, self.node_id)
            add_only_new(self, self.upstream_task_ids, task.node_id)
            if edge_modifier:
                edge_modifier.add_edge_info(self.dag, task.node_id, self.node_id)
        else:
            # Edge runs self -> task; record it on both endpoints.
            add_only_new(self, self.downstream_task_ids, task.node_id)
            add_only_new(task, task.upstream_task_ids, self.node_id)
            if edge_modifier:
                edge_modifier.add_edge_info(self.dag, self.node_id, task.node_id)
def set_downstream(
    self,
    task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
    edge_modifier: Optional["EdgeModifier"] = None,
) -> None:
    """Register the given node (or nodes) as directly downstream of this node.

    :param task_or_task_list: a single dependency object or a sequence of them.
    :param edge_modifier: optional edge modifier forwarded to ``_set_relatives``.
    """
    self._set_relatives(
        task_or_task_list,
        upstream=False,
        edge_modifier=edge_modifier,
    )
def set_upstream(
    self,
    task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
    edge_modifier: Optional["EdgeModifier"] = None,
) -> None:
    """Register the given node (or nodes) as directly upstream of this node.

    :param task_or_task_list: a single dependency object or a sequence of them.
    :param edge_modifier: optional edge modifier forwarded to ``_set_relatives``.
    """
    self._set_relatives(
        task_or_task_list,
        upstream=True,
        edge_modifier=edge_modifier,
    )
@property
def downstream_list(self) -> Iterable["DAGNode"]:
    """Nodes directly downstream of this one, looked up by id in the owning DAG.

    :raises AirflowException: if this node has no DAG.
    """
    if self.dag:
        return [self.dag.get_task(task_id) for task_id in self.downstream_task_ids]
    raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
@property
def upstream_list(self) -> Iterable["DAGNode"]:
    """Nodes directly upstream of this one, looked up by id in the owning DAG.

    :raises AirflowException: if this node has no DAG.
    """
    if self.dag:
        return [self.dag.get_task(task_id) for task_id in self.upstream_task_ids]
    raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
    """Return the ids of this node's direct relatives.

    :param upstream: if true return ``upstream_task_ids``, otherwise
        ``downstream_task_ids``.
    """
    return self.upstream_task_ids if upstream else self.downstream_task_ids
def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]:
    """Return this node's direct relatives.

    :param upstream: if true return ``upstream_list``, otherwise
        ``downstream_list``.
    """
    # Only the selected attribute is evaluated, as in an if/else.
    return self.upstream_list if upstream else self.downstream_list
def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
    """Serialization hook that SerializedTaskGroup uses for a task group's content.

    This base implementation always raises.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
def _iter_all_mapped_downstreams(self) -> Iterator["MappedOperator"]:
    """Return mapped nodes that are direct downstream dependants of the current task.

    For now, this walks the entire DAG to find mapped nodes that have this
    current task as an upstream. We cannot use ``downstream_list`` since it
    only contains operators, not task groups. In the future, we should
    provide a way to record a DAG node's all downstream nodes instead.

    Note that this does not guarantee the returned tasks actually use the
    current task for task mapping, but only checks those tasks are mapped
    operators, and are downstreams of the current task.

    To get a list of tasks that use the current task for task mapping, use
    :meth:`iter_mapped_dependants` instead.

    :raises RuntimeError: on first iteration, if this node has no DAG.
    """
    from airflow.models.mappedoperator import MappedOperator
    from airflow.utils.task_group import TaskGroup

    def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
        """Recursively walk children in a task group.

        This yields all direct children (including both tasks and task
        groups), and all children of any task groups.
        """
        for key, child in group.children.items():
            yield key, child
            if isinstance(child, TaskGroup):
                yield from _walk_group(child)

    # Start from the DAG's root task group, not from this node's own group.
    # A mapped downstream may live in a sibling or parent group, and walking
    # only ``self.task_group`` would miss it.
    dag = self.dag if self.has_dag() else None
    if not dag:
        raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
    for key, child in _walk_group(dag.task_group):
        if key == self.node_id:
            continue
        if not isinstance(child, MappedOperator):
            continue
        if self.node_id in child.upstream_task_ids:
            yield child
[docs]defiter_mapped_dependants(self)->Iterator["MappedOperator"]:"""Return mapped nodes that depend on the current task the expansion. For now, this walks the entire DAG to find mapped nodes that has this current task as an upstream. We cannot use ``downstream_list`` since it only contains operators, not task groups. In the future, we should provide a way to record an DAG node's all downstream nodes instead. """return(downstreamfordownstreaminself._iter_all_mapped_downstreams()ifany(p.node_id==self.node_idforpindownstream.iter_mapped_dependencies())