## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportwarningsfromtypingimportTYPE_CHECKING,Iterable,Sequencefromairflow.exceptionsimportAirflowException,RemovedInAirflow3Warningfromairflow.models.taskinstanceimportTaskInstancefromairflow.utilsimporttimezonefromairflow.utils.log.logging_mixinimportLoggingMixinfromairflow.utils.sessionimportNEW_SESSION,create_session,provide_sessionfromairflow.utils.stateimportStateifTYPE_CHECKING:frompendulumimportDateTimefromsqlalchemyimportSessionfromairflow.models.dagrunimportDagRunfromairflow.models.operatorimportOperatorfromairflow.models.taskmixinimportDAGNode# The key used by SkipMixin to store XCom data.
class SkipMixin(LoggingMixin):
    """A Mixin to skip Tasks Instances."""

    def _set_state_to_skipped(
        self,
        dag_run: DagRun,
        tasks: Iterable[Operator],
        session: Session,
    ) -> None:
        """
        Used internally to set state of task instances to skipped from the same dag run.

        Issues a single bulk UPDATE that sets the state of the matching task instances
        to skipped and stamps both ``start_date`` and ``end_date`` with the current UTC
        time. This method does not commit the session.

        :param dag_run: the DagRun whose task instances should be updated
        :param tasks: the tasks whose task instances should be marked as skipped
        :param session: db session to use
        """
        # Materialize the IDs once: ``tasks`` may be a one-shot iterable, and a plain
        # list is the documented input form for ``in_()``.
        task_ids = [d.task_id for d in tasks]
        if not task_ids:
            # Nothing to skip, so do not emit an UPDATE with an empty IN clause.
            return

        now = timezone.utcnow()
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.run_id == dag_run.run_id,
            TaskInstance.task_id.in_(task_ids),
        ).update(
            {
                TaskInstance.state: State.SKIPPED,
                TaskInstance.start_date: now,
                TaskInstance.end_date: now,
            },
            synchronize_session=False,
        )

    @provide_session
    def skip(
        self,
        dag_run: DagRun,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        session: Session = NEW_SESSION,
    ):
        """
        Sets tasks instances to skipped from the same dag run.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :raises ValueError: if ``execution_date`` disagrees with ``dag_run.execution_date``,
            or if no dag run is given or can be resolved
        """
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        if execution_date and not dag_run:
            from airflow.models.dagrun import DagRun

            warnings.warn(
                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )

            # Deprecated path: resolve the dag run from the first task's DAG and the date.
            dag_run = (
                session.query(DagRun)
                .filter(
                    DagRun.dag_id == task_list[0].dag_id,
                    DagRun.execution_date == execution_date,
                )
                .one()
            )
        elif execution_date and dag_run and execution_date != dag_run.execution_date:
            raise ValueError(
                "execution_date has a different value to dag_run.execution_date -- please only pass dag_run"
            )

        if dag_run is None:
            raise ValueError("dag_run is required")

        self._set_state_to_skipped(dag_run, task_list, session)
        session.commit()

        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        if task_id is not None:
            from airflow.models.xcom import XCom

            XCom.set(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
                task_id=task_id,
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
                session=session,
            )

    def skip_all_except(self, ti: TaskInstance, branch_task_ids: None | str | Iterable[str]):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.

        :param ti: the task instance of the branching task
        :param branch_task_ids: the task ID(s) to follow; None follows nothing
        :raises AirflowException: if ``branch_task_ids`` has an unsupported type, contains
            non-string items, or names task IDs that are not part of the DAG
        """
        self.log.info("Following branch %s", branch_task_ids)
        if isinstance(branch_task_ids, str):
            branch_task_id_set = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            branch_task_id_set = set(branch_task_ids)
            # Validate the materialized set rather than re-iterating the argument: a
            # one-shot iterable (e.g. a generator) is already exhausted at this point
            # and would silently bypass the type check.
            invalid_task_ids_type = {
                (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
            }
            if invalid_task_ids_type:
                raise AirflowException(
                    f"'branch_task_ids' expected all task IDs are strings. "
                    f"Invalid tasks found: {invalid_task_ids_type}."
                )
        elif branch_task_ids is None:
            branch_task_id_set = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        dag_run = ti.get_dagrun()
        task = ti.task
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch  ----->  join
            #   \            ^
            #     v        /
            #       task1
            #
            # Iterate over a snapshot because the set is extended inside the loop.
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set]
            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            with create_session() as session:
                self._set_state_to_skipped(dag_run, skip_tasks, session=session)
                # For some reason, session.commit() needs to happen before xcom_push.
                # Otherwise the session is not committed.
                session.commit()
                ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})