"""Branching operators"""
from __future__ import annotations

from typing import Iterable

from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.utils.context import Context

class BaseBranchOperator(BaseOperator, SkipMixin):
    """
    Base class for creating operators with branching functionality, similar to BranchPythonOperator.

    Users should subclass this operator and implement the function
    ``choose_branch(self, context)``. This should run whatever business logic
    is needed to determine the branch, and return either the task_id for a
    single task (as a str) or a list of task_ids.

    The operator will continue with the returned task_id(s), and all other
    tasks directly downstream of this operator will be skipped.
    """

    def choose_branch(self, context: Context) -> str | Iterable[str]:
        """
        Choose which downstream branch(es) to follow.

        Subclasses should implement this, running whatever logic is necessary
        to choose a branch and returning a task_id or list of task_ids.

        :param context: Context dictionary as passed to execute()
        :return: a single task_id (str) or an iterable of task_ids to continue with
        :raises NotImplementedError: always, unless a subclass overrides this method
        """
        raise NotImplementedError

    def execute(self, context: Context):
        """
        Run :meth:`choose_branch` and skip the directly-downstream tasks that were not chosen.

        :param context: Context dictionary; its ``'ti'`` entry is handed to
            ``skip_all_except`` together with the chosen task_id(s).
        :return: the value returned by :meth:`choose_branch`, unchanged.
        """
        branches_to_execute = self.choose_branch(context)
        # Skipping is delegated to SkipMixin.skip_all_except: everything except
        # the chosen task_id(s) is skipped, per the class contract above.
        self.skip_all_except(context['ti'], branches_to_execute)
        return branches_to_execute

