# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Table to store information about mapped task instances (AIP-42)."""
from __future__ import annotations

import collections.abc
import enum
from typing import TYPE_CHECKING, Any, Collection

from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String

from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from airflow.utils.sqlalchemy import ExtendedJSON

if TYPE_CHECKING:
    from airflow.models.taskinstance import TaskInstance

class TaskMapVariant(enum.Enum):
    """Task map variant.

    Possible values are **dict** (for a key-value mapping) and **list** (for an
    ordered value sequence).
    """

    DICT = "dict"
    LIST = "list"


class TaskMap(Base):
    """Model to track dynamic task-mapping information.

    This is currently only populated by an upstream TaskInstance pushing an
    XCom that's pulled by a downstream for mapping purposes.
    """

    __tablename__ = "task_map"

    # Link to upstream TaskInstance creating this dynamic mapping information.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    map_index = Column(Integer, primary_key=True)

    # Number of items in the pushed value; must be non-negative (see CheckConstraint below).
    length = Column(Integer, nullable=False)
    # Keys of the pushed value when it was a mapping; None otherwise.
    keys = Column(ExtendedJSON, nullable=True)

    __table_args__ = (
        CheckConstraint(length >= 0, name="task_map_length_not_negative"),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_map_task_instance_fkey",
            # Deleting the referenced task_instance row deletes this row too.
            ondelete="CASCADE",
        ),
    )

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int,
        length: int,
        keys: list[Any] | None,
    ) -> None:
        self.dag_id = dag_id
        self.task_id = task_id
        self.run_id = run_id
        self.map_index = map_index
        self.length = length
        self.keys = keys

    @classmethod
    def from_task_instance_xcom(cls, ti: TaskInstance, value: Collection) -> TaskMap:
        """Build a task-map record from an upstream task instance and its XCom value.

        :param ti: The upstream task instance that pushed ``value``.
        :param value: The pushed collection. Its length is always recorded; if it
            is a mapping, its keys are recorded as well.
        :raises ValueError: If ``ti`` has no ``run_id``.
        """
        if ti.run_id is None:
            raise ValueError("cannot record task map for unrun task instance")
        # Only mappings carry keys; for any other collection ``keys`` stays None,
        # which ``variant`` reports as LIST.
        keys = list(value) if isinstance(value, collections.abc.Mapping) else None
        return cls(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=ti.run_id,
            map_index=ti.map_index,
            length=len(value),
            keys=keys,
        )

    @property
    def variant(self) -> TaskMapVariant:
        """Return ``DICT`` if keys were recorded for this mapping, else ``LIST``."""
        if self.keys is None:
            return TaskMapVariant.LIST
        return TaskMapVariant.DICT

