Source code for airflow.models.expandinput

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import collections.abc
import functools
import operator
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Sized, Union

from airflow.compat.functools import cache
from airflow.utils.context import Context

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.models.xcom_arg import XComArg

# Either of the two storage types a mapped operator may keep its mapped kwargs in.
ExpandInput = Union[
    "DictOfListsExpandInput",
    "ListOfDictsExpandInput",
]
# Each keyword argument to expand() can be an XComArg, sequence, or dict (not
# any mapping since we need the value to be ordered).
OperatorExpandArgument = Union[
    "XComArg",
    Sequence,
    Dict[str, Any],
]
# The single argument of expand_kwargs() can be an XComArg, or a list with each
# element being either an XComArg or a dict.
OperatorExpandKwargsArgument = Union[
    "XComArg",
    Sequence[Union["XComArg", Mapping[str, Any]]],
]
# For isinstance() check.
@cache
def get_mappable_types() -> tuple[type, ...]:
    """Return the tuple of mappable types, for use in ``isinstance()`` checks.

    The result is cached, so the tuple is built only once. ``XComArg`` is
    imported inside the function rather than at module level; presumably this
    avoids a circular import (the module-level import of ``XComArg`` exists
    only under ``TYPE_CHECKING``) -- confirm before moving it.
    """
    from airflow.models.xcom_arg import XComArg

    return (XComArg, list, tuple, dict)
class NotFullyPopulated(RuntimeError):
    """Raised when the mapping metadata (map lengths) cannot all be populated.

    This typically happens when some upstream tasks have not finished yet at
    the time the lengths are requested.

    :param missing: Labels of the mapped arguments whose length is unknown.
    """

    def __init__(self, missing: set[str]) -> None:
        self.missing = missing

    def __str__(self) -> str:
        quoted = ", ".join(map(repr, sorted(self.missing)))
        return f"Failed to populate all mapping metadata; missing: {quoted}"
class DictOfListsExpandInput(NamedTuple):
    """Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand(**kwargs)``. ``value`` maps each keyword
    argument name to the value to expand over: either a literal known at
    parse time, or an ``XComArg`` resolved at run time. The total number of
    mapped task instances is the product of the lengths of all arguments.
    """

    value: dict[str, OperatorExpandArgument]

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Generate ``(name, value)`` pairs for kwargs whose values are available at parse time.

        Every value that is not an ``XComArg`` is treated as a parse-time literal.
        """
        from airflow.models.xcom_arg import XComArg

        return ((k, v) for k, v in self.value.items() if not isinstance(v, XComArg))

    def get_parse_time_mapped_ti_count(self) -> int | None:
        """Return the number of mapped task instances if computable at parse time.

        :return: 0 if there are no mapped kwargs; None if any kwarg is an
            ``XComArg`` (its length is only known at run time); otherwise the
            product of the lengths of all the literal kwargs.
        """
        if not self.value:
            return 0
        literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
        if len(literal_values) != len(self.value):
            return None  # Non-literal type encountered, so give up.
        return functools.reduce(operator.mul, literal_values, 1)

    def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
        """Return dict of argument name to map length.

        Literal arguments contribute ``len(value)``; ``XComArg`` arguments are
        looked up through ``get_task_map_length()``.

        :raises NotFullyPopulated: if the length of any argument is not known
            right now (e.g. its upstream task has not finished). A partial dict
            is never returned; all missing argument names are reported at once.
        """
        from airflow.models.xcom_arg import XComArg

        # TODO: This initiates one database call for each XComArg. Would it be
        # more efficient to do one single db call and unpack the value here?
        map_lengths_iterator = (
            (k, (v.get_task_map_length(run_id, session=session) if isinstance(v, XComArg) else len(v)))
            for k, v in self.value.items()
        )

        # A None length means the value is not available yet; drop those so the
        # difference below names every missing argument.
        map_lengths = {k: v for k, v in map_lengths_iterator if v is not None}
        if len(map_lengths) < len(self.value):
            raise NotFullyPopulated(set(self.value).difference(map_lengths))
        return map_lengths

    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
        """Return the total number of mapped task instances for *run_id*.

        This is the product of every argument's map length, or 0 if there are
        no mapped kwargs.

        :raises NotFullyPopulated: if any argument's length is not known yet.
        """
        if not self.value:
            return 0
        lengths = self._get_map_lengths(run_id, session=session)
        return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)

    def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
        """Return the element of *value* that the current task instance receives for *key*.

        The current ``map_index`` is decomposed across all mapped arguments,
        with the last argument varying fastest, to find the index for *key*.
        Sequences are indexed directly; dicts yield a ``(key, value)`` pair.

        :raises RuntimeError: if the task instance is not expanded
            (``map_index < 0``) or an argument is mapped to a length below 1.
        :raises TypeError: if *value* is neither a sequence nor a dict.
        :raises IndexError: if a dict has fewer items than the computed index.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(value, XComArg):
            value = value.resolve(context, session=session)
        map_index = context["ti"].map_index
        if map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")
        all_lengths = self._get_map_lengths(context["run_id"], session=session)

        def _find_index_for_this_field(index: int) -> int:
            # Need to use the original user input to retain argument order.
            # Walking from the last argument to the first, each step peels one
            # mixed-radix "digit" (radix = that argument's length) off the index.
            for mapped_key in reversed(list(self.value)):
                mapped_length = all_lengths[mapped_key]
                if mapped_length < 1:
                    raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
                if mapped_key == key:
                    return index % mapped_length
                index //= mapped_length
            return -1  # *key* is not one of the mapped arguments.

        found_index = _find_index_for_this_field(map_index)
        if found_index < 0:
            return value
        if isinstance(value, collections.abc.Sequence):
            return value[found_index]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        for i, (k, v) in enumerate(value.items()):
            if i == found_index:
                return k, v
        raise IndexError(f"index {map_index} is over mapped length")

    def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Resolve the kwargs for the current task instance.

        :return: A tuple of (1) a dict mapping each argument name to the value
            this task instance receives, and (2) the ``id()`` of each such value
            that came from a non-literal (``XComArg``) argument.
        """
        data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
        literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
        return data, resolved_oids
def _describe_type(value: Any) -> str: if value is None: return "None" return type(value).__name__
class ListOfDictsExpandInput(NamedTuple):
    """Storage type of a mapped operator's mapped kwargs.

    Built by ``expand_kwargs(xcom_arg)``. ``value`` is either an ``XComArg``
    resolving to a list of dicts at run time, or a literal list whose items
    are dicts or ``XComArg`` objects. The item at a task instance's
    ``map_index`` supplies that task instance's kwargs.
    """

    value: OperatorExpandKwargsArgument

    def get_parse_time_mapped_ti_count(self) -> int | None:
        """Return the number of mapped task instances if known at parse time, else None."""
        if not isinstance(self.value, collections.abc.Sized):
            return None
        return len(self.value)

    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
        """Return the number of mapped task instances for *run_id*.

        :raises NotFullyPopulated: if the upstream length is not known yet.
        """
        source = self.value
        if isinstance(source, collections.abc.Sized):
            return len(source)
        upstream_length = source.get_task_map_length(run_id, session=session)
        if upstream_length is not None:
            return upstream_length
        raise NotFullyPopulated({"expand_kwargs() argument"})

    def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Return the kwargs for the current task instance and the ids of their values.

        :raises RuntimeError: if the task instance has not been expanded.
        :raises ValueError: if the input is not a list of dicts with ``str`` keys.
        """
        ti_index = context["ti"].map_index
        if ti_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        kwargs_for_ti: Any
        if not isinstance(self.value, collections.abc.Sized):
            # The whole list comes from upstream: resolve it, then pick our item.
            resolved = self.value.resolve(context, session)
            if not isinstance(resolved, collections.abc.Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(resolved)}")
            kwargs_for_ti = resolved[ti_index]
        else:
            # Literal list: the selected item may itself still need resolving.
            kwargs_for_ti = self.value[ti_index]
            if not isinstance(kwargs_for_ti, collections.abc.Mapping):
                kwargs_for_ti = kwargs_for_ti.resolve(context, session)

        if not isinstance(kwargs_for_ti, collections.abc.Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(kwargs_for_ti)}]")

        non_str_keys = [name for name in kwargs_for_ti if not isinstance(name, str)]
        if non_str_keys:
            bad = non_str_keys[0]
            raise ValueError(
                f"expand_kwargs() input dict keys must all be str, "
                f"but {bad!r} is of type {_describe_type(bad)}"
            )
        return kwargs_for_ti, {id(item) for item in kwargs_for_ti.values()}
# Sentinel value: an expand input with no mapped arguments.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput(value={})
# Registry of expand input kinds, keyed by the string name used by
# get_map_type_key() and create_expand_input(). Insertion order is preserved.
_EXPAND_INPUT_TYPES = dict(
    [
        ("dict-of-lists", DictOfListsExpandInput),
        ("list-of-dicts", ListOfDictsExpandInput),
    ]
)
def get_map_type_key(expand_input: ExpandInput) -> str:
    """Return the registry key naming the storage type of *expand_input*.

    This is the reverse of the ``kind`` lookup done by ``create_expand_input``.
    Only the exact registered types match; subclasses are not recognized.

    :raises TypeError: if the type of *expand_input* is not registered.
    """
    input_type = type(expand_input)
    for key, registered_type in _EXPAND_INPUT_TYPES.items():
        if registered_type is input_type:
            return key
    # Raise explicitly instead of letting StopIteration escape from next():
    # a stray StopIteration silently ends an enclosing loop or becomes an
    # unrelated-looking RuntimeError inside generators (PEP 479).
    raise TypeError(f"unknown expand input type: {input_type.__name__}")
def create_expand_input(kind: str, value: Any) -> ExpandInput:
    """Build the expand input registered under *kind*, wrapping *value*.

    :raises KeyError: if *kind* is not a registered expand input kind.
    """
    expand_input_class = _EXPAND_INPUT_TYPES[kind]
    return expand_input_class(value)

Was this entry helpful?