## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""Branching operators."""from__future__importannotationsfromtypingimportTYPE_CHECKING,Iterablefromairflow.models.baseoperatorimportBaseOperatorfromairflow.models.skipmixinimportSkipMixinifTYPE_CHECKING:fromairflow.modelsimportTaskInstancefromairflow.serialization.pydantic.taskinstanceimportTaskInstancePydanticfromairflow.utils.contextimportContext
class BranchMixIn(SkipMixin):
    """Utility helper which handles the branching as one-liner."""

    def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
        """
        Implement the handling of branching including logging.

        Any task group id among ``branches_to_execute`` is expanded into the task ids of the
        group's root tasks. The resulting ids are handed to ``skip_all_except`` so that only
        the chosen branches are followed.

        :param context: Context dictionary as passed to ``execute()``; ``context["ti"]`` is the
            task instance used for the task-group lookup and the skipping.
        :param branches_to_execute: A task id or task group id (as a str), an iterable of them,
            or ``None``. A one-shot iterator (e.g. a generator) is materialized into a list first.
        :return: The branches that were followed, as passed in (converted to a list when a
            one-shot iterator was passed).
        """
        from collections.abc import Iterator

        # A one-shot iterator would be exhausted by the task-group expansion below, leaving an
        # unhelpful log line and an empty iterator as the return value; materialize it once.
        if isinstance(branches_to_execute, Iterator):
            branches_to_execute = list(branches_to_execute)

        self.log.info("Branch into %s", branches_to_execute)
        # _expand_task_group_roots is a lazy generator; hand the callee a concrete list instead.
        branch_task_ids = list(self._expand_task_group_roots(context["ti"], branches_to_execute))
        self.skip_all_except(context["ti"], branch_task_ids)
        return branches_to_execute

    def _expand_task_group_roots(
        self, ti: TaskInstance | TaskInstancePydantic, branches_to_execute: str | Iterable[str]
    ) -> Iterable[str]:
        """
        Expand any task group into its root task ids.

        This is a generator, so ids are produced lazily while the result is iterated.

        :param ti: Task instance whose task's DAG (``ti.task.dag``) is used to resolve task groups.
        :param branches_to_execute: A single id, an iterable of ids, or ``None``. Each id may be
            a task id or a task group id of that DAG.
        :return: Yields task ids. A task group id is replaced by the task ids of the group's
            root tasks; any other id is yielded unchanged. Yields nothing for ``None``.
        """
        if TYPE_CHECKING:
            assert ti.task

        task = ti.task
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        if branches_to_execute is None:
            return
        elif isinstance(branches_to_execute, str) or not isinstance(branches_to_execute, Iterable):
            # A bare string (or any other non-iterable) is a single branch, not a sequence of ids.
            branches_to_execute = [branches_to_execute]

        for branch in branches_to_execute:
            if branch in dag.task_group_dict:
                tg = dag.task_group_dict[branch]
                root_ids = [root.task_id for root in tg.roots]
                self.log.info("Expanding task group %s into %s", tg.group_id, root_ids)
                yield from root_ids
            else:
                yield branch
class BaseBranchOperator(BaseOperator, BranchMixIn):
    """
    A base class for creating operators with branching functionality, like BranchPythonOperator.

    Users should create a subclass from this operator and implement the function
    `choose_branch(self, context)`. This should run whatever business logic
    is needed to determine the branch, and return one of the following:

    - A single task_id (as a str)
    - A single task_group_id (as a str)
    - A list containing a combination of task_ids and task_group_ids

    The operator will continue with the returned task_id(s) and/or task_group_id(s), and all other
    tasks directly downstream of this operator will be skipped.
    """

    def choose_branch(self, context: Context) -> str | Iterable[str]:
        """
        Abstract method to choose which branch to run.

        Subclasses should implement this, running whatever logic is necessary to choose a
        branch and returning a single task_id, a single task_group_id, or a list containing
        a combination of task_ids and task_group_ids.

        :param context: Context dictionary as passed to execute()
        :raises NotImplementedError: If a subclass does not override this method.
        """
        # Name the concrete subclass so a missing override is easy to spot in task logs.
        raise NotImplementedError(f"{type(self).__name__} must implement choose_branch()")