#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import collections
import collections.abc
import contextlib
import copy
import datetime
import warnings
from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union
import attr
import pendulum
from sqlalchemy.orm.session import Session
from airflow import settings
from airflow.compat.functools import cache
from airflow.exceptions import AirflowException, UnmappableOperator
from airflow.models.abstractoperator import (
    DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
    DEFAULT_OWNER,
    DEFAULT_POOL_SLOTS,
    DEFAULT_PRIORITY_WEIGHT,
    DEFAULT_QUEUE,
    DEFAULT_RETRIES,
    DEFAULT_RETRY_DELAY,
    DEFAULT_TRIGGER_RULE,
    DEFAULT_WEIGHT_RULE,
    AbstractOperator,
    NotMapped,
    TaskStateChangeCallback,
)
from airflow.models.expandinput import (
    DictOfListsExpandInput,
    ExpandInput,
    ListOfDictsExpandInput,
    OperatorExpandArgument,
    OperatorExpandKwargsArgument,
    is_mappable,
)
from airflow.models.param import ParamsDict
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.typing_compat import Literal
from airflow.utils.context import Context, context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
from airflow.utils.operator_resources import Resources
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
if TYPE_CHECKING:
    import jinja2  # Slow import.
    from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.models.xcom_arg import XComArg
    from airflow.utils.task_group import TaskGroup
# Names the entry point whose keyword arguments ``validate_mapping_kwargs`` is
# checking; the value is interpolated into its error messages.
[docs]ValidationSource = Union[Literal["expand"], Literal["partial"]] 
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
    """Check that *value* only holds keyword arguments accepted by *op*.

    Accepted parameter names are collected from every class in ``op.mro()``
    whose ``__init__`` exposes ``_BaseOperatorMeta__param_names``. When *func*
    is ``"expand"``, every recognised argument must additionally be mappable.

    :param op: The operator class being partial-ed or expanded.
    :param func: The calling entry point (``"partial"`` or ``"expand"``). It is
        used in error messages and decides whether mappability is checked.
    :param value: Keyword arguments given to that entry point. It is not mutated.
    :raises ValueError: If *func* is ``"expand"`` and a recognised argument is
        not mappable.
    :raises TypeError: If one or more arguments are not accepted by *op*.
    """
    # Work on a copy; dicts keep insertion order, so leftover (unknown)
    # arguments are reported in the order the caller wrote them.
    unknown_args = value.copy()
    for klass in op.mro():
        init = klass.__init__  # type: ignore[misc]
        try:
            param_names = init._BaseOperatorMeta__param_names
        except AttributeError:
            continue
        for name in param_names:
            # Bind to a dedicated name rather than rebinding the ``value`` parameter.
            arg = unknown_args.pop(name, NOTSET)
            if func != "expand" or arg is NOTSET or is_mappable(arg):
                continue
            type_name = type(arg).__name__
            error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
            raise ValueError(error)
        if not unknown_args:
            return  # If we have no args left to check: stop looking at the MRO chain.
    if not unknown_args:
        # Nothing to reject (e.g. empty *value* and no class in the MRO carrying
        # parameter metadata): don't raise a TypeError that lists zero names.
        return
    if len(unknown_args) == 1:
        error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
    else:
        names = ", ".join(repr(n) for n in unknown_args)
        error = f"unexpected keyword arguments {names}"
    raise TypeError(f"{op.__name__}.{func}() got {error}")
def ensure_xcomarg_return_value(arg: Any) -> None:
    """Reject XCom references that use a key other than the return-value key.

    *arg* is walked recursively: ``XComArg`` objects have each of their
    references checked, mappings are descended into via their values, and other
    containers are iterated. Anything that is not a container is ignored.

    :raises ValueError: On the first reference whose key is not
        ``XCOM_RETURN_KEY``.
    """
    from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg

    if isinstance(arg, XComArg):
        for ref_operator, ref_key in arg.iter_references():
            if ref_key == XCOM_RETURN_KEY:
                continue
            raise ValueError(f"cannot map over XCom with custom key {ref_key!r} from {ref_operator}")
        return
    if not is_container(arg):
        return
    if isinstance(arg, collections.abc.Mapping):
        children = arg.values()
    elif isinstance(arg, collections.abc.Iterable):
        children = arg
    else:
        return
    for child in children:
        ensure_xcomarg_return_value(child)
@attr.define(kw_only=True, repr=False)
class OperatorPartial:
    """An "intermediate state" returned by ``BaseOperator.partial()``.

    This only exists at DAG-parsing time; the only intended usage is for the
    user to call ``.expand()`` on it at some point (usually in a method chain) to
    create a ``MappedOperator`` to add into the DAG.
    """

    operator_class: type[BaseOperator]
    # Keyword arguments given to ``partial()``. Every method below reads this, and
    # ``_expand`` pops ``task_id``/``dag``/``task_group``/``start_date``/``end_date``
    # from a copy of it, so it must be a declared attrs field (``attr.define``
    # generates slotted classes, which reject undeclared attributes).
    kwargs: dict[str, Any]
    params: ParamsDict | dict

    _expand_called: bool = False  # Set when expand() is called to ease user debugging.

    def __attrs_post_init__(self):
        """Reject SubDagOperator and validate the ``partial()`` kwargs."""
        from airflow.operators.subdag import SubDagOperator

        if issubclass(self.operator_class, SubDagOperator):
            raise TypeError("Mapping over deprecated SubDagOperator is not supported")
        validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)

    def __repr__(self) -> str:
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.operator_class.__name__}.partial({args})"

    def __del__(self):
        """Warn if this partial is garbage-collected without being expanded."""
        # ``__del__`` also runs for objects whose attrs ``__init__`` failed before
        # any field was assigned; don't raise AttributeError from here for those.
        if getattr(self, "_expand_called", True):
            return
        try:
            task_id = repr(self.kwargs["task_id"])
        except KeyError:
            task_id = f"at {hex(id(self))}"
        warnings.warn(f"Task {task_id} was never mapped!")

    def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
        """Create a ``MappedOperator`` expanding over *mapped_kwargs*.

        :raises TypeError: If no arguments are given, if an argument is unknown
            to the operator class, or if it duplicates a ``partial()`` argument.
        :raises ValueError: If an argument value is not mappable.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to expand against")
        validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
        prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
        """Create a ``MappedOperator`` expanding over a list of kwargs dicts.

        :param kwargs: An ``XComArg``, or a sequence whose items are each an
            ``XComArg`` or a mapping.
        :param strict: Stored as ``disallow_kwargs_override`` on the result;
            when true, keys clashing with ``partial()`` kwargs fail at unmap time.
        :raises TypeError: If *kwargs* or one of its items has the wrong type.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(kwargs, collections.abc.Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, collections.abc.Mapping)):
                    # Name the offending element's type too; the container type
                    # alone ("not list") is misleading since a list is expected.
                    raise TypeError(
                        f"expected XComArg or list[dict], not {type(kwargs).__name__}[{type(item).__name__}]"
                    )
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
        """Build the ``MappedOperator`` from *expand_input* and ``partial()`` kwargs.

        ``task_id``, ``dag``, ``task_group``, ``start_date`` and ``end_date`` are
        taken out of (a copy of) ``self.kwargs`` and passed as dedicated fields;
        the remaining kwargs become ``partial_kwargs``.
        """
        from airflow.operators.empty import EmptyOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date")
        end_date = partial_kwargs.pop("end_date")

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        op = MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
        )
        return op
@attr.define(
    kw_only=True,
    # Disable custom __getstate__ and __setstate__ generation since it interacts
    # badly with Airflow's DAG serialization and pickling. When a mapped task is
    # deserialized, subclasses are coerced into MappedOperator, but when it goes
    # through DAG pickling, all attributes defined in the subclasses are dropped
    # by attrs's custom state management. Since attrs does not do anything too
    # special here (the logic is only important for slots=True), we use Python's
    # built-in implementation, which works (as proven by good old BaseOperator).
    getstate_setstate=False,
)
class MappedOperator(AbstractOperator):
    """Object representing a mapped operator in a DAG.

    Most operator attributes are exposed as read-only properties backed by
    ``partial_kwargs``; a concrete operator is only built by :meth:`unmap`.
    """

    # This attribute serves double purpose. For a "normal" operator instance
    # loaded from DAG, this holds the underlying non-mapped operator class that
    # can be used to create an unmapped operator for execution. For an operator
    # recreated from a serialized DAG, however, this holds the serialized data
    # that can be used to unmap this into a SerializedBaseOperator.
    operator_class: type[BaseOperator] | dict[str, Any]

    # Arguments given to expand()/expand_kwargs(). Read through
    # ``_expand_input_attr`` (see ``_get_specified_expand_input``); excluded
    # from ``get_serialized_fields``.
    expand_input: ExpandInput
    partial_kwargs: dict[str, Any]

    # Needed for serialization.
    task_id: str
    params: ParamsDict | dict
    deps: frozenset[BaseTIDep]
    operator_extra_links: Collection[BaseOperatorLink]
    template_ext: Sequence[str]
    template_fields: Collection[str]
    template_fields_renderers: dict[str, str]
    ui_color: str
    ui_fgcolor: str
    _is_empty: bool
    _task_module: str
    _task_type: str
    _operator_name: str

    dag: DAG | None
    task_group: TaskGroup | None
    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
    downstream_task_ids: set[str] = attr.ib(factory=set, init=False)

    _disallow_kwargs_override: bool
    """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.

    If *False*, values from ``expand_input`` under duplicate keys override those
    under corresponding keys in ``partial_kwargs``.
    """

    _expand_input_attr: str
    """Where to get kwargs to calculate expansion length against.

    This should be a name to call ``getattr()`` on.
    """

    subdag: None = None  # Since we don't support SubDagOperator, this is always None.

    HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
        (
            "parse_time_mapped_ti_count",
            "operator_class",
        )
    )

    def __hash__(self):
        # Identity hash, so instances can be used as keys (e.g. by @cache below).
        return id(self)

    def __repr__(self):
        return f"<Mapped({self._task_type}): {self.task_id}>"

    def __attrs_post_init__(self):
        """Register with the task group/DAG and wire XCom upstream relationships.

        :raises NotImplementedError: If placed inside an expanded task group.
        :raises AirflowException: If an SLA is set in ``partial_kwargs``.
        """
        from airflow.models.xcom_arg import XComArg

        if self.get_closest_mapped_task_group() is not None:
            raise NotImplementedError("operator expansion in an expanded task group is not yet supported")

        if self.task_group:
            self.task_group.add(self)
        if self.dag:
            self.dag.add_task(self)
        XComArg.apply_upstream_relationship(self, self.expand_input.value)
        # Only templated partial kwargs can carry XComArgs that resolve at runtime.
        for k, v in self.partial_kwargs.items():
            if k in self.template_fields:
                XComArg.apply_upstream_relationship(self, v)
        if self.partial_kwargs.get("sla") is not None:
            raise AirflowException(
                f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
                f"{self.task_id!r}."
            )

    @classmethod
    @cache
    def get_serialized_fields(cls):
        """Return the attrs field names that DAG serialization should store."""
        # Not using 'cls' here since we only want to serialize base fields.
        return frozenset(attr.fields_dict(MappedOperator)) - {
            "dag",
            "deps",
            "expand_input",  # This is needed to be able to accept XComArg.
            "subdag",
            "task_group",
            "upstream_task_ids",
        }

    @staticmethod
    @cache
    def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
        """Return *operator_class*'s deps plus ``MappedTaskIsExpanded``.

        :raises UnmappableOperator: If ``operator_class.deps`` is not a set.
        """
        operator_deps = operator_class.deps
        if not isinstance(operator_deps, collections.abc.Set):
            raise UnmappableOperator(
                f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
                f"not a {type(operator_deps).__name__}"
            )
        return operator_deps | {MappedTaskIsExpanded()}

    @property
    def task_type(self) -> str:
        """Implementing Operator."""
        return self._task_type

    @property
    def operator_name(self) -> str:
        return self._operator_name

    @property
    def inherits_from_empty_operator(self) -> bool:
        """Implementing Operator."""
        return self._is_empty

    @property
    def roots(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    @property
    def leaves(self) -> Sequence[AbstractOperator]:
        """Implementing DAGNode."""
        return [self]

    # The properties below read operator arguments from ``partial_kwargs``,
    # falling back to the shared defaults when ``partial()`` did not set them.

    @property
    def owner(self) -> str:  # type: ignore[override]
        return self.partial_kwargs.get("owner", DEFAULT_OWNER)

    @property
    def email(self) -> None | str | Iterable[str]:
        return self.partial_kwargs.get("email")

    @property
    def trigger_rule(self) -> TriggerRule:
        return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)

    @property
    def depends_on_past(self) -> bool:
        return bool(self.partial_kwargs.get("depends_on_past"))

    @property
    def ignore_first_depends_on_past(self) -> bool:
        value = self.partial_kwargs.get("ignore_first_depends_on_past", DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST)
        return bool(value)

    @property
    def wait_for_downstream(self) -> bool:
        return bool(self.partial_kwargs.get("wait_for_downstream"))

    @property
    def retries(self) -> int | None:
        return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

    @property
    def queue(self) -> str:
        return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

    @property
    def pool(self) -> str:
        return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)

    @property
    def pool_slots(self) -> str | None:  # NOTE(review): slot counts are presumably int — confirm.
        return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

    @property
    def execution_timeout(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("execution_timeout")

    @property
    def max_retry_delay(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("max_retry_delay")

    @property
    def retry_delay(self) -> datetime.timedelta:
        return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

    @property
    def retry_exponential_backoff(self) -> bool:
        return bool(self.partial_kwargs.get("retry_exponential_backoff"))

    @property
    def priority_weight(self) -> int:  # type: ignore[override]
        return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

    @property
    def weight_rule(self) -> int:  # type: ignore[override]
        return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)

    @property
    def sla(self) -> datetime.timedelta | None:
        return self.partial_kwargs.get("sla")

    @property
    def max_active_tis_per_dag(self) -> int | None:
        return self.partial_kwargs.get("max_active_tis_per_dag")

    @property
    def resources(self) -> Resources | None:
        return self.partial_kwargs.get("resources")

    @property
    def on_execute_callback(self) -> TaskStateChangeCallback | None:
        return self.partial_kwargs.get("on_execute_callback")

    @property
    def on_failure_callback(self) -> TaskStateChangeCallback | None:
        return self.partial_kwargs.get("on_failure_callback")

    @property
    def on_retry_callback(self) -> TaskStateChangeCallback | None:
        return self.partial_kwargs.get("on_retry_callback")

    @property
    def on_success_callback(self) -> TaskStateChangeCallback | None:
        return self.partial_kwargs.get("on_success_callback")

    @property
    def run_as_user(self) -> str | None:
        return self.partial_kwargs.get("run_as_user")

    @property
    def executor_config(self) -> dict:
        # NOTE(review): when unset, a fresh dict is returned each call, so
        # mutations of the returned value are not persisted.
        return self.partial_kwargs.get("executor_config", {})

    @property  # type: ignore[override]
    def inlets(self) -> list[Any]:  # type: ignore[override]
        return self.partial_kwargs.get("inlets", [])

    @inlets.setter
    def inlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["inlets"] = value

    @property  # type: ignore[override]
    def outlets(self) -> list[Any]:  # type: ignore[override]
        return self.partial_kwargs.get("outlets", [])

    @outlets.setter
    def outlets(self, value: list[Any]) -> None:  # type: ignore[override]
        self.partial_kwargs["outlets"] = value

    @property
    def doc(self) -> str | None:
        return self.partial_kwargs.get("doc")

    @property
    def doc_md(self) -> str | None:
        return self.partial_kwargs.get("doc_md")

    @property
    def doc_json(self) -> str | None:
        return self.partial_kwargs.get("doc_json")

    @property
    def doc_yaml(self) -> str | None:
        return self.partial_kwargs.get("doc_yaml")

    @property
    def doc_rst(self) -> str | None:
        return self.partial_kwargs.get("doc_rst")

    def get_dag(self) -> DAG | None:
        """Implementing Operator."""
        return self.dag

    @property
    def output(self) -> XComArg:
        """Return a reference to the XCom pushed by this operator."""
        from airflow.models.xcom_arg import XComArg

        return XComArg(operator=self)

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """Implementing DAGNode."""
        return DagAttributeTypes.OP, self.task_id

    def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Get the kwargs to create the unmapped operator.

        This exists because taskflow operators expand against op_kwargs, not the
        entire operator kwargs dict.
        """
        return self._get_specified_expand_input().resolve(context, session)

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """Get init kwargs to unmap the underlying operator class.

        :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
        :param strict: If true, keys shared by ``partial_kwargs`` and
            *mapped_kwargs* raise instead of being overridden.
        """
        if strict:
            prevent_duplicates(
                self.partial_kwargs,
                mapped_kwargs,
                fail_reason="unmappable or already specified",
            )

        # If params appears in the mapped kwargs, we need to merge it into the
        # partial params, overriding existing keys. An explicit membership test
        # (rather than suppressing KeyError) keeps errors raised by update() visible.
        params = copy.copy(self.params)
        if "params" in mapped_kwargs:
            params.update(mapped_kwargs["params"])

        # Ordering is significant; mapped kwargs should override partial ones,
        # and the specially handled params should be respected.
        return {
            "task_id": self.task_id,
            "dag": self.dag,
            "task_group": self.task_group,
            "start_date": self.start_date,
            "end_date": self.end_date,
            **self.partial_kwargs,
            **mapped_kwargs,
            "params": params,
        }

    def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
        """Get the "normal" Operator after applying the current mapping.

        The *resolve* argument is only used if ``operator_class`` is a real
        class, i.e. if this operator is not serialized. If ``operator_class`` is
        not a class (i.e. this DAG has been deserialized), this returns a
        SerializedBaseOperator that "looks like" the actual unmapping result.

        If *resolve* is a two-tuple (context, session), the information is used
        to resolve the mapped arguments into init arguments. If it is a mapping,
        no resolving happens, the mapping directly provides those init arguments
        resolved from mapped kwargs.

        :meta private:
        """
        if isinstance(self.operator_class, type):
            if isinstance(resolve, collections.abc.Mapping):
                kwargs = resolve
            elif resolve is not None:
                kwargs, _ = self._expand_mapped_kwargs(*resolve)
            else:
                raise RuntimeError("cannot unmap a non-serialized operator without context")
            kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
            op = self.operator_class(**kwargs, _airflow_from_mapped=True)
            # We need to overwrite task_id here because BaseOperator further
            # mangles the task_id based on the task hierarchy (namely, group_id
            # is prepended, and '__N' appended to deduplicate). This is hacky,
            # but better than duplicating the whole mangling logic.
            op.task_id = self.task_id
            return op

        # After a mapped operator is serialized, there's no real way to actually
        # unmap it since we've lost access to the underlying operator class.
        # This tries its best to simply "forward" all the attributes on this
        # mapped operator to a new SerializedBaseOperator instance.
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
        SerializedBaseOperator.populate_operator(op, self.operator_class)
        return op

    def _get_specified_expand_input(self) -> ExpandInput:
        """Input received from the expand call on the operator."""
        return getattr(self, self._expand_input_attr)

    def prepare_for_execution(self) -> MappedOperator:
        # Since a mapped operator cannot be used for execution, and an unmapped
        # BaseOperator needs to be created later (see render_template_fields),
        # we don't need to create a copy of the MappedOperator here.
        return self

    def iter_mapped_dependencies(self) -> Iterator[Operator]:
        """Upstream dependencies that provide XComs used by this task for task mapping."""
        from airflow.models.xcom_arg import XComArg

        for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
            yield operator

    # NOTE(review): functools.cache on a method keeps every instance it has seen
    # alive for the process lifetime; acceptable at DAG-parse scale, but note it.
    @cache
    def get_parse_time_mapped_ti_count(self) -> int:
        """Return this task's parse-time map length, multiplied by any parent's.

        If the parent lookup raises ``NotMapped``, only this task's count is returned.
        """
        current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
        try:
            parent_count = super().get_parse_time_mapped_ti_count()
        except NotMapped:
            return current_count
        return parent_count * current_count

    def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
        """Return this task's map length in *run_id*, multiplied by any parent's.

        If the parent lookup raises ``NotMapped``, only this task's count is returned.
        """
        current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
        try:
            parent_count = super().get_mapped_ti_count(run_id, session=session)
        except NotMapped:
            return current_count
        return parent_count * current_count

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        """Template all attributes listed in *self.template_fields*.

        This updates *context* to reference the map-expanded task and relevant
        information, without modifying the mapped operator. The expanded task
        in *context* is then rendered in-place.

        :param context: Context dict with values to apply on content.
        :param jinja_env: Jinja environment to use for rendering.
        """
        if not jinja_env:
            jinja_env = self.get_template_env()

        # Ideally we'd like to pass in session as an argument to this function,
        # but we can't easily change this function signature since operators
        # could override this. We can't use @provide_session since it closes and
        # expunges everything, which we don't want to do when we are so "deep"
        # in the weeds here. We don't close this session for the same reason.
        session = settings.Session()

        mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
        unmapped_task = self.unmap(mapped_kwargs)
        context_update_for_unmapped(context, unmapped_task)

        self._do_render_template_fields(
            parent=unmapped_task,
            template_fields=self.template_fields,
            context=context,
            jinja_env=jinja_env,
            seen_oids=seen_oids,
            session=session,
        )