Source code for airflow.decorators.condition

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from airflow.decorators.base import Task, _TaskDecorator
from airflow.exceptions import AirflowSkipException

if TYPE_CHECKING:
    # These names appear only in annotations, which are lazy strings because of
    # ``from __future__ import annotations``, so they are imported for type checkers only.
    from typing_extensions import TypeAlias

    from airflow.models.baseoperator import TaskPreExecuteHook
    from airflow.utils.context import Context

    # A condition that only reports whether it holds.
    BoolConditionFunc: TypeAlias = Callable[[Context], bool]
    # A condition that also returns an optional message. In this module a non-empty
    # message replaces the default skip message (see ``wrap_skip``).
    MsgConditionFunc: TypeAlias = "Callable[[Context], tuple[bool, str | None]]"
    # Either form is accepted by ``run_if`` / ``skip_if``.
    AnyConditionFunc: TypeAlias = "BoolConditionFunc | MsgConditionFunc"

__all__ = ["run_if", "skip_if"]

# Type of the object the decorators receive and hand back unchanged. At runtime
# ``run_if`` / ``skip_if`` only accept ``_TaskDecorator`` instances.
_T = TypeVar("_T", bound="Task[..., Any] | _TaskDecorator[..., Any, Any]")


def run_if(condition: AnyConditionFunc, skip_message: str | None = None) -> Callable[[_T], _T]:
    """
    Decorate a task so that it runs only when ``condition`` holds.

    The condition is evaluated inside the task's ``pre_execute`` hook. When it is
    falsy, the task is skipped by raising ``AirflowSkipException``.

    :param condition: A callable that receives the task context. It returns either a
        boolean or a ``(bool, message)`` tuple. A non-empty ``message`` in the tuple
        overrides ``skip_message``.
    :param skip_message: The message to log if the task is skipped. If None, a
        default message is used.
    :return: A decorator to stack on top of an ``@task``-decorated function.
    :raises TypeError: Raised by the returned decorator when the decorated object is
        not a ``_TaskDecorator``, i.e. ``@task`` was not applied first.
    """
    message = skip_message or "Task was skipped due to condition."
    skip_hook = wrap_skip(condition, message, reverse=True)

    def decorator(task: _T) -> _T:
        if isinstance(task, _TaskDecorator):
            existing_hook: TaskPreExecuteHook | None = task.kwargs.get("pre_execute")
            # Any hook already present (user-supplied or from another stacked
            # condition decorator) runs first, then this condition.
            task.kwargs["pre_execute"] = combine_hooks(existing_hook, skip_hook)
            return task  # type: ignore[return-value]
        raise TypeError(
            "run_if can only be used with task. decorate with @task before @run_if."
        )

    return decorator
def skip_if(condition: AnyConditionFunc, skip_message: str | None = None) -> Callable[[_T], _T]:
    """
    Decorate a task so that it is skipped when ``condition`` holds.

    The condition is evaluated inside the task's ``pre_execute`` hook. When it is
    truthy, the task is skipped by raising ``AirflowSkipException``.

    :param condition: A callable that receives the task context. It returns either a
        boolean or a ``(bool, message)`` tuple. A non-empty ``message`` in the tuple
        overrides ``skip_message``.
    :param skip_message: The message to log if the task is skipped. If None, a
        default message is used.
    :return: A decorator to stack on top of an ``@task``-decorated function.
    :raises TypeError: Raised by the returned decorator when the decorated object is
        not a ``_TaskDecorator``, i.e. ``@task`` was not applied first.
    """
    message = skip_message or "Task was skipped due to condition."
    skip_hook = wrap_skip(condition, message, reverse=False)

    def decorator(task: _T) -> _T:
        if isinstance(task, _TaskDecorator):
            existing_hook: TaskPreExecuteHook | None = task.kwargs.get("pre_execute")
            # Any hook already present (user-supplied or from another stacked
            # condition decorator) runs first, then this condition.
            task.kwargs["pre_execute"] = combine_hooks(existing_hook, skip_hook)
            return task  # type: ignore[return-value]
        raise TypeError(
            "skip_if can only be used with task. decorate with @task before @skip_if."
        )

    return decorator
def wrap_skip(func: AnyConditionFunc, error_msg: str, *, reverse: bool) -> TaskPreExecuteHook:
    """
    Turn a condition callable into a ``pre_execute`` hook that may skip the task.

    :param func: Condition called with the task context. It may return a plain value,
        which is judged by truthiness, or a ``(value, message)`` tuple. A non-empty
        ``message`` replaces ``error_msg`` as the skip message.
    :param error_msg: Default message passed to the skip exception.
    :param reverse: If True, skip when the condition is falsy (used by ``run_if``).
        If False, skip when it is truthy (used by ``skip_if``).
    :return: A hook that takes the context and returns None. It raises
        ``AirflowSkipException`` when the task should be skipped. Via
        ``functools.wraps`` it carries ``func``'s name and metadata.
    """

    @wraps(func)
    def pre_execute(context: Context) -> None:
        outcome = func(context)
        message = error_msg
        if isinstance(outcome, tuple):
            outcome, custom_msg = outcome
            # An empty or None custom message keeps the default.
            message = custom_msg or message
        should_skip = not outcome if reverse else bool(outcome)
        if should_skip:
            raise AirflowSkipException(message)

    return pre_execute


def combine_hooks(*hooks: TaskPreExecuteHook | None) -> TaskPreExecuteHook:
    """
    Chain several ``pre_execute`` hooks into a single hook.

    ``None`` entries are ignored. The remaining hooks are called with the same context
    in the order given, and the first exception raised stops the chain and propagates.
    """
    # ``hooks`` is an immutable tuple, so the None-filtering can be done once here.
    active_hooks = tuple(hook for hook in hooks if hook is not None)

    def pre_execute(context: Context) -> None:
        for hook in active_hooks:
            hook(context)

    return pre_execute

Was this entry helpful?