Source code for airflow.decorators.condition
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from airflow.decorators.base import Task, _TaskDecorator
from airflow.exceptions import AirflowSkipException
# Names below exist only for static type checking: ``TYPE_CHECKING`` is False
# at runtime, so these imports and aliases are never executed. They can still
# appear in annotations because the module uses ``from __future__ import annotations``.
if TYPE_CHECKING:
    from typing_extensions import TypeAlias
    from airflow.models.baseoperator import TaskPreExecuteHook
    from airflow.utils.context import Context
    # Condition that receives the task context and returns a plain boolean.
    BoolConditionFunc: TypeAlias = Callable[[Context], bool]
    # Condition returning ``(bool, optional message)``; a non-empty message
    # replaces the default skip message (see ``wrap_skip``).
    MsgConditionFunc: TypeAlias = "Callable[[Context], tuple[bool, str | None]]"
    # Either condition form is accepted by ``run_if`` / ``skip_if``.
    AnyConditionFunc: TypeAlias = "BoolConditionFunc | MsgConditionFunc"
# Public API of this module.
__all__ = ["run_if", "skip_if"]
# Type of the object the ``run_if`` / ``skip_if`` decorators accept and return.
# NOTE: the bound admits ``Task`` as well, but both decorators raise TypeError
# at runtime for anything that is not a ``_TaskDecorator``.
_T = TypeVar("_T", bound="Task[..., Any] | _TaskDecorator[..., Any, Any]")
def run_if(condition: AnyConditionFunc, skip_message: str | None = None) -> Callable[[_T], _T]:
    """
    Decorate a task to run only if a condition is met.

    The condition is evaluated inside the task's ``pre_execute`` hook, after any
    ``pre_execute`` hook already configured on the task. When the condition is
    falsy, the task is skipped by raising ``AirflowSkipException``.

    :param condition: A function that takes a context and returns a boolean,
        or a ``(bool, message)`` tuple whose non-empty message overrides
        ``skip_message``.
    :param skip_message: The message to log if the task is skipped.
        If None (or empty), a default message is used.
    :return: A decorator to apply on top of an ``@task``-decorated function.
    :raises TypeError: Raised by the returned decorator when the decorated object
        is not a ``_TaskDecorator``, i.e. ``@task`` was not applied first.
    """
    wrapped_condition = wrap_skip(
        condition, skip_message or "Task was skipped due to condition.", reverse=True
    )

    def decorator(task: _T) -> _T:
        if not isinstance(task, _TaskDecorator):
            error_msg = "run_if can only be used with task. decorate with @task before @run_if."
            raise TypeError(error_msg)
        pre_execute: TaskPreExecuteHook | None = task.kwargs.get("pre_execute")
        new_pre_execute = combine_hooks(pre_execute, wrapped_condition)
        # Rebind ``kwargs`` to a fresh dict instead of writing into it: this
        # mapping is not owned by this call and may be the same object held by
        # other tasks created from one reusable configured decorator (e.g.
        # ``retrying = task(retries=2)``). An in-place write would attach this
        # condition to those sibling tasks too.
        task.kwargs = {**task.kwargs, "pre_execute": new_pre_execute}
        return task  # type: ignore[return-value]

    return decorator
def skip_if(condition: AnyConditionFunc, skip_message: str | None = None) -> Callable[[_T], _T]:
    """
    Decorate a task to skip if a condition is met.

    The condition is evaluated inside the task's ``pre_execute`` hook, after any
    ``pre_execute`` hook already configured on the task. When the condition is
    truthy, the task is skipped by raising ``AirflowSkipException``.

    :param condition: A function that takes a context and returns a boolean,
        or a ``(bool, message)`` tuple whose non-empty message overrides
        ``skip_message``.
    :param skip_message: The message to log if the task is skipped.
        If None (or empty), a default message is used.
    :return: A decorator to apply on top of an ``@task``-decorated function.
    :raises TypeError: Raised by the returned decorator when the decorated object
        is not a ``_TaskDecorator``, i.e. ``@task`` was not applied first.
    """
    wrapped_condition = wrap_skip(
        condition, skip_message or "Task was skipped due to condition.", reverse=False
    )

    def decorator(task: _T) -> _T:
        if not isinstance(task, _TaskDecorator):
            error_msg = "skip_if can only be used with task. decorate with @task before @skip_if."
            raise TypeError(error_msg)
        pre_execute: TaskPreExecuteHook | None = task.kwargs.get("pre_execute")
        new_pre_execute = combine_hooks(pre_execute, wrapped_condition)
        # Rebind ``kwargs`` to a fresh dict instead of writing into it: this
        # mapping is not owned by this call and may be the same object held by
        # other tasks created from one reusable configured decorator (e.g.
        # ``retrying = task(retries=2)``). An in-place write would attach this
        # condition to those sibling tasks too.
        task.kwargs = {**task.kwargs, "pre_execute": new_pre_execute}
        return task  # type: ignore[return-value]

    return decorator
def wrap_skip(func: AnyConditionFunc, error_msg: str, *, reverse: bool) -> TaskPreExecuteHook:
    """
    Turn a condition function into a ``pre_execute`` hook that may skip the task.

    The returned hook calls ``func(context)``. The result is either a plain
    value tested for truthiness, or a ``(value, message)`` tuple whose non-empty
    message replaces ``error_msg``. With ``reverse=False`` the hook skips when
    the value is truthy; with ``reverse=True`` it skips when the value is falsy.

    :param func: Condition callable receiving the task context.
    :param error_msg: Skip message used when ``func`` supplies none.
    :param reverse: Invert the condition before deciding whether to skip.
    :return: A hook that takes the context and returns None; it carries
        ``func``'s metadata via ``functools.wraps``.
    :raises AirflowSkipException: From the returned hook when the (possibly
        inverted) condition is truthy.
    """

    @wraps(func)
    def pre_execute(context: Context) -> None:
        outcome = func(context)
        message = error_msg
        if isinstance(outcome, tuple):
            # ``(value, message)`` form: an empty/None message keeps the default.
            outcome, custom_message = outcome
            message = custom_message or message
        should_skip = (not outcome) if reverse else bool(outcome)
        if should_skip:
            raise AirflowSkipException(message)

    return pre_execute
def combine_hooks(*hooks: TaskPreExecuteHook | None) -> TaskPreExecuteHook:
    """
    Chain several ``pre_execute`` hooks into a single hook.

    ``None`` entries are ignored, so callers can pass an optional existing hook
    without checking it first. The remaining hooks are called in the order
    given, each with the same context; an exception raised by one propagates
    and stops the rest of the chain.

    :param hooks: Hooks to run, possibly including ``None`` placeholders.
    :return: A single hook that runs every non-None hook sequentially.
    """
    # ``hooks`` is an immutable tuple, so the None-filtering can be done once here.
    active_hooks = tuple(hook for hook in hooks if hook is not None)

    def pre_execute(context: Context) -> None:
        for active_hook in active_hooks:
            active_hook(context)

    return pre_execute