Source code for airflow.models.mappedoperator

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import collections
import contextlib
import copy
import datetime
import warnings
from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union

import attr
import pendulum
from sqlalchemy.orm.session import Session

from airflow import settings
from airflow.compat.functools import cache
from airflow.exceptions import AirflowException, UnmappableOperator
from airflow.models.abstractoperator import (
from airflow.models.expandinput import (
from airflow.models.param import ParamsDict
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.typing_compat import Literal
from airflow.utils.context import Context, context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
from airflow.utils.operator_resources import Resources
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET

    import jinja2  # Slow import.

    from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.models.xcom_arg import XComArg
    from airflow.utils.task_group import TaskGroup

[docs]ValidationSource = Union[Literal["expand"], Literal["partial"]]
[docs]def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: # use a dict so order of args is same as code order unknown_args = value.copy() for klass in op.mro(): init = klass.__init__ # type: ignore[misc] try: param_names = init._BaseOperatorMeta__param_names except AttributeError: continue for name in param_names: value = unknown_args.pop(name, NOTSET) if func != "expand": continue if value is NOTSET: continue if is_mappable(value): continue type_name = type(value).__name__ error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}" raise ValueError(error) if not unknown_args: return # If we have no args left to check: stop looking at the MRO chain. if len(unknown_args) == 1: error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}" else: names = ", ".join(repr(n) for n in unknown_args) error = f"unexpected keyword arguments {names}" raise TypeError(f"{op.__name__}.{func}() got {error}")
[docs]def ensure_xcomarg_return_value(arg: Any) -> None: from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg if isinstance(arg, XComArg): for operator, key in arg.iter_references(): if key != XCOM_RETURN_KEY: raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}") elif not is_container(arg): return elif isinstance(arg, for v in arg.values(): ensure_xcomarg_return_value(v) elif isinstance(arg, for v in arg: ensure_xcomarg_return_value(v)
@attr.define(kw_only=True, repr=False)
[docs]class OperatorPartial: """An "intermediate state" returned by ``BaseOperator.partial()``. This only exists at DAG-parsing time; the only intended usage is for the user to call ``.expand()`` on it at some point (usually in a method chain) to create a ``MappedOperator`` to add into the DAG. """
[docs] operator_class: type[BaseOperator]
[docs] kwargs: dict[str, Any]
[docs] params: ParamsDict | dict
_expand_called: bool = False # Set when expand() is called to ease user debugging.
[docs] def __attrs_post_init__(self): from airflow.operators.subdag import SubDagOperator if issubclass(self.operator_class, SubDagOperator): raise TypeError("Mapping over deprecated SubDagOperator is not supported") validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)
[docs] def __repr__(self) -> str: args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) return f"{self.operator_class.__name__}.partial({args})"
[docs] def __del__(self): if not self._expand_called: try: task_id = repr(self.kwargs["task_id"]) except KeyError: task_id = f"at {hex(id(self))}" warnings.warn(f"Task {task_id} was never mapped!")
[docs] def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: if not mapped_kwargs: raise TypeError("no arguments to expand against") validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") # Since the input is already checked at parse time, we can set strict # to False to skip the checks on execution. return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
[docs] def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator: from airflow.models.xcom_arg import XComArg if isinstance(kwargs, for item in kwargs: if not isinstance(item, (XComArg, raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") elif not isinstance(kwargs, XComArg): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: from airflow.operators.empty import EmptyOperator self._expand_called = True ensure_xcomarg_return_value(expand_input.value) partial_kwargs = self.kwargs.copy() task_id = partial_kwargs.pop("task_id") dag = partial_kwargs.pop("dag") task_group = partial_kwargs.pop("task_group") start_date = partial_kwargs.pop("start_date") end_date = partial_kwargs.pop("end_date") try: operator_name = self.operator_class.custom_operator_name # type: ignore except AttributeError: operator_name = self.operator_class.__name__ op = MappedOperator( operator_class=self.operator_class, expand_input=expand_input, partial_kwargs=partial_kwargs, task_id=task_id, params=self.params, deps=MappedOperator.deps_for(self.operator_class), operator_extra_links=self.operator_class.operator_extra_links, template_ext=self.operator_class.template_ext, template_fields=self.operator_class.template_fields, template_fields_renderers=self.operator_class.template_fields_renderers, ui_color=self.operator_class.ui_color, ui_fgcolor=self.operator_class.ui_fgcolor, is_empty=issubclass(self.operator_class, EmptyOperator), task_module=self.operator_class.__module__, task_type=self.operator_class.__name__, operator_name=operator_name, dag=dag, task_group=task_group, start_date=start_date, end_date=end_date, disallow_kwargs_override=strict, # For classic operators, this points to expand_input because kwargs # to BaseOperator.expand() contribute to operator arguments. expand_input_attr="expand_input", ) return op
@attr.define( kw_only=True, # Disable custom __getstate__ and __setstate__ generation since it interacts # badly with Airflow's DAG serialization and pickling. When a mapped task is # deserialized, subclasses are coerced into MappedOperator, but when it goes # through DAG pickling, all attributes defined in the subclasses are dropped # by attrs's custom state management. Since attrs does not do anything too # special here (the logic is only important for slots=True), we use Python's # built-in implementation, which works (as proven by good old BaseOperator). getstate_setstate=False,
[docs]) class MappedOperator(AbstractOperator): """Object representing a mapped operator in a DAG.""" # This attribute serves double purpose. For a "normal" operator instance # loaded from DAG, this holds the underlying non-mapped operator class that # can be used to create an unmapped operator for execution. For an operator # recreated from a serialized DAG, however, this holds the serialized data # that can be used to unmap this into a SerializedBaseOperator.
[docs] operator_class: type[BaseOperator] | dict[str, Any]
[docs] expand_input: ExpandInput
[docs] partial_kwargs: dict[str, Any]
# Needed for serialization.
[docs] task_id: str
[docs] params: ParamsDict | dict
[docs] deps: frozenset[BaseTIDep]
[docs] template_ext: Sequence[str]
[docs] template_fields: Collection[str]
[docs] template_fields_renderers: dict[str, str]
[docs] ui_color: str
[docs] ui_fgcolor: str
_is_empty: bool _task_module: str _task_type: str _operator_name: str
[docs] dag: DAG | None
[docs] task_group: TaskGroup | None
[docs] start_date: pendulum.DateTime | None
[docs] end_date: pendulum.DateTime | None
[docs] upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
[docs] downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
_disallow_kwargs_override: bool """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. If *False*, values from ``expand_input`` under duplicate keys override those under corresponding keys in ``partial_kwargs``. """ _expand_input_attr: str """Where to get kwargs to calculate expansion length against. This should be a name to call ``getattr()`` on. """
[docs] subdag: None = None # Since we don't support SubDagOperator, this is always None.
[docs] HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( ( "parse_time_mapped_ti_count", "operator_class",
) )
[docs] def __hash__(self): return id(self)
[docs] def __repr__(self): return f"<Mapped({self._task_type}): {self.task_id}>"
[docs] def __attrs_post_init__(self): from airflow.models.xcom_arg import XComArg if self.get_closest_mapped_task_group() is not None: raise NotImplementedError("operator expansion in an expanded task group is not yet supported") if self.task_group: self.task_group.add(self) if self.dag: self.dag.add_task(self) XComArg.apply_upstream_relationship(self, self.expand_input.value) for k, v in self.partial_kwargs.items(): if k in self.template_fields: XComArg.apply_upstream_relationship(self, v) if self.partial_kwargs.get("sla") is not None: raise AirflowException( f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
f"{self.task_id!r}." ) @classmethod @cache
[docs] def get_serialized_fields(cls): # Not using 'cls' here since we only want to serialize base fields. return frozenset(attr.fields_dict(MappedOperator)) - { "dag", "deps", "expand_input", # This is needed to be able to accept XComArg. "subdag", "task_group", "upstream_task_ids",
} @staticmethod @cache
[docs] def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]: operator_deps = operator_class.deps if not isinstance(operator_deps, raise UnmappableOperator( f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, " f"not a {type(operator_deps).__name__}" ) return operator_deps | {MappedTaskIsExpanded()}
[docs] def task_type(self) -> str: """Implementing Operator.""" return self._task_type
[docs] def operator_name(self) -> str: return self._operator_name
[docs] def inherits_from_empty_operator(self) -> bool: """Implementing Operator.""" return self._is_empty
[docs] def roots(self) -> Sequence[AbstractOperator]: """Implementing DAGNode.""" return [self]
[docs] def leaves(self) -> Sequence[AbstractOperator]: """Implementing DAGNode.""" return [self]
[docs] def owner(self) -> str: # type: ignore[override] return self.partial_kwargs.get("owner", DEFAULT_OWNER)
[docs] def email(self) -> None | str | Iterable[str]: return self.partial_kwargs.get("email")
[docs] def trigger_rule(self) -> TriggerRule: return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
[docs] def depends_on_past(self) -> bool: return bool(self.partial_kwargs.get("depends_on_past"))
[docs] def ignore_first_depends_on_past(self) -> bool: value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST) return bool(value)
[docs] def wait_for_downstream(self) -> bool: return bool(self.partial_kwargs.get("wait_for_downstream"))
[docs] def retries(self) -> int | None: return self.partial_kwargs.get("retries", DEFAULT_RETRIES)
[docs] def queue(self) -> str: return self.partial_kwargs.get("queue", DEFAULT_QUEUE)
[docs] def pool(self) -> str: return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)
[docs] def pool_slots(self) -> str | None: return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)
[docs] def execution_timeout(self) -> datetime.timedelta | None: return self.partial_kwargs.get("execution_timeout")
[docs] def max_retry_delay(self) -> datetime.timedelta | None: return self.partial_kwargs.get("max_retry_delay")
[docs] def retry_delay(self) -> datetime.timedelta: return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)
[docs] def retry_exponential_backoff(self) -> bool: return bool(self.partial_kwargs.get("retry_exponential_backoff"))
[docs] def priority_weight(self) -> int: # type: ignore[override] return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)
[docs] def weight_rule(self) -> int: # type: ignore[override] return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
[docs] def sla(self) -> datetime.timedelta | None: return self.partial_kwargs.get("sla")
[docs] def max_active_tis_per_dag(self) -> int | None: return self.partial_kwargs.get("max_active_tis_per_dag")
[docs] def resources(self) -> Resources | None: return self.partial_kwargs.get("resources")
[docs] def on_execute_callback(self) -> TaskStateChangeCallback | None: return self.partial_kwargs.get("on_execute_callback")
@on_execute_callback.setter def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None: self.partial_kwargs["on_execute_callback"] = value @property
[docs] def on_failure_callback(self) -> TaskStateChangeCallback | None: return self.partial_kwargs.get("on_failure_callback")
@on_failure_callback.setter def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None: self.partial_kwargs["on_failure_callback"] = value @property
[docs] def on_retry_callback(self) -> TaskStateChangeCallback | None: return self.partial_kwargs.get("on_retry_callback")
@on_retry_callback.setter def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None: self.partial_kwargs["on_retry_callback"] = value @property
[docs] def on_success_callback(self) -> TaskStateChangeCallback | None: return self.partial_kwargs.get("on_success_callback")
@on_success_callback.setter def on_success_callback(self, value: TaskStateChangeCallback | None) -> None: self.partial_kwargs["on_success_callback"] = value @property
[docs] def run_as_user(self) -> str | None: return self.partial_kwargs.get("run_as_user")
[docs] def executor_config(self) -> dict: return self.partial_kwargs.get("executor_config", {})
@property # type: ignore[override]
[docs] def inlets(self) -> list[Any]: # type: ignore[override] return self.partial_kwargs.get("inlets", [])
@inlets.setter def inlets(self, value: list[Any]) -> None: # type: ignore[override] self.partial_kwargs["inlets"] = value @property # type: ignore[override]
[docs] def outlets(self) -> list[Any]: # type: ignore[override] return self.partial_kwargs.get("outlets", [])
@outlets.setter def outlets(self, value: list[Any]) -> None: # type: ignore[override] self.partial_kwargs["outlets"] = value @property
[docs] def doc(self) -> str | None: return self.partial_kwargs.get("doc")
[docs] def doc_md(self) -> str | None: return self.partial_kwargs.get("doc_md")
[docs] def doc_json(self) -> str | None: return self.partial_kwargs.get("doc_json")
[docs] def doc_yaml(self) -> str | None: return self.partial_kwargs.get("doc_yaml")
[docs] def doc_rst(self) -> str | None: return self.partial_kwargs.get("doc_rst")
[docs] def get_dag(self) -> DAG | None: """Implementing Operator.""" return self.dag
[docs] def output(self) -> XComArg: """Returns reference to XCom pushed by current operator""" from airflow.models.xcom_arg import XComArg return XComArg(operator=self)
[docs] def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Implementing DAGNode.""" return DagAttributeTypes.OP, self.task_id
def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]: """Get the kwargs to create the unmapped operator. This exists because taskflow operators expand against op_kwargs, not the entire operator kwargs dict. """ return self._get_specified_expand_input().resolve(context, session) def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: """Get init kwargs to unmap the underlying operator class. :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``. """ if strict: prevent_duplicates( self.partial_kwargs, mapped_kwargs, fail_reason="unmappable or already specified", ) # If params appears in the mapped kwargs, we need to merge it into the # partial params, overriding existing keys. params = copy.copy(self.params) with contextlib.suppress(KeyError): params.update(mapped_kwargs["params"]) # Ordering is significant; mapped kwargs should override partial ones, # and the specially handled params should be respected. return { "task_id": self.task_id, "dag": self.dag, "task_group": self.task_group, "start_date": self.start_date, "end_date": self.end_date, **self.partial_kwargs, **mapped_kwargs, "params": params, }
[docs] def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: """Get the "normal" Operator after applying the current mapping. The *resolve* argument is only used if ``operator_class`` is a real class, i.e. if this operator is not serialized. If ``operator_class`` is not a class (i.e. this DAG has been deserialized), this returns a SerializedBaseOperator that "looks like" the actual unmapping result. If *resolve* is a two-tuple (context, session), the information is used to resolve the mapped arguments into init arguments. If it is a mapping, no resolving happens, the mapping directly provides those init arguments resolved from mapped kwargs. :meta private: """ if isinstance(self.operator_class, type): if isinstance(resolve, kwargs = resolve elif resolve is not None: kwargs, _ = self._expand_mapped_kwargs(*resolve) else: raise RuntimeError("cannot unmap a non-serialized operator without context") kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) op = self.operator_class(**kwargs, _airflow_from_mapped=True) # We need to overwrite task_id here because BaseOperator further # mangles the task_id based on the task hierarchy (namely, group_id # is prepended, and '__N' appended to deduplicate). This is hacky, # but better than duplicating the whole mangling logic. op.task_id = self.task_id return op # After a mapped operator is serialized, there's no real way to actually # unmap it since we've lost access to the underlying operator class. # This tries its best to simply "forward" all the attributes on this # mapped operator to a new SerializedBaseOperator instance. from airflow.serialization.serialized_objects import SerializedBaseOperator op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True) SerializedBaseOperator.populate_operator(op, self.operator_class) return op
def _get_specified_expand_input(self) -> ExpandInput: """Input received from the expand call on the operator.""" return getattr(self, self._expand_input_attr)
[docs] def prepare_for_execution(self) -> MappedOperator: # Since a mapped operator cannot be used for execution, and an unmapped # BaseOperator needs to be created later (see render_template_fields), # we don't need to create a copy of the MappedOperator here. return self
[docs] def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this task for task mapping.""" from airflow.models.xcom_arg import XComArg for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): yield operator
[docs] def get_parse_time_mapped_ti_count(self) -> int: current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count() try: parent_count = super().get_parse_time_mapped_ti_count() except NotMapped: return current_count return parent_count * current_count
[docs] def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session) try: parent_count = super().get_mapped_ti_count(run_id, session=session) except NotMapped: return current_count return parent_count * current_count
[docs] def render_template_fields( self, context: Context, jinja_env: jinja2.Environment | None = None, ) -> None: """Template all attributes listed in *self.template_fields*. This updates *context* to reference the map-expanded task and relevant information, without modifying the mapped operator. The expanded task in *context* is then rendered in-place. :param context: Context dict with values to apply on content. :param jinja_env: Jinja environment to use for rendering. """ if not jinja_env: jinja_env = self.get_template_env() # Ideally we'd like to pass in session as an argument to this function, # but we can't easily change this function signature since operators # could override this. We can't use @provide_session since it closes and # expunges everything, which we don't want to do when we are so "deep" # in the weeds here. We don't close this session for the same reason. session = settings.Session() mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) unmapped_task = self.unmap(mapped_kwargs) context_update_for_unmapped(context, unmapped_task) self._do_render_template_fields( parent=unmapped_task, template_fields=self.template_fields, context=context, jinja_env=jinja_env, seen_oids=seen_oids, session=session,

Was this entry helpful?