#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import logging
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import (
    JSON,
    Column,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    delete,
    func,
    select,
    text,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, relationship
from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils import timezone
from airflow.utils.db import LazySelectSequence
from airflow.utils.helpers import is_container
from airflow.utils.json import XComDecoder, XComEncoder
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
# The XCom constants below are kept importable from this module for providers' backward
# compatibility; providers should import them directly from airflow.utils.xcom once they
# require apache-airflow>=2.6.0.
from airflow.utils.xcom import (
    MAX_XCOM_SIZE,  # noqa: F401
    XCOM_RETURN_KEY,
)
log = logging.getLogger(__name__)
if TYPE_CHECKING:
    from sqlalchemy.engine import Row
    from sqlalchemy.orm import Session
    from sqlalchemy.sql.expression import Select, TextClause
class XComModel(TaskInstanceDependencies):
    """XCom model class. Contains table and some utilities."""
    __tablename__ = "xcom"

    dag_run_id = Column(Integer(), nullable=False, primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True) 
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1")) 
    key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) 
    # Denormalized for easier lookup.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) 
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) 
    value = Column(JSON().with_variant(postgresql.JSONB, "postgresql")) 
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False) 
    __table_args__ = (
        # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
        # but it goes over MySQL's index length limit. So we instead index 'key'
        # separately, and enforce uniqueness with DagRun.id instead.
        Index("idx_xcom_key", key),
        Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
        PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="xcom_pkey"),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="xcom_task_instance_fkey",
            ondelete="CASCADE",
        ),
    ) 
    dag_run = relationship(
        "DagRun",
        primaryjoin="XComModel.dag_run_id == foreign(DagRun.id)",
        uselist=False,
        lazy="joined",
        passive_deletes="all",
    ) 
    logical_date = association_proxy("dag_run", "logical_date") 
    @classmethod
    @provide_session
    def clear(
        cls,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Clear all XCom data from the database for the given task instance.
        .. note:: This **will not** purge any data from a custom XCom backend.
        :param dag_id: ID of DAG to clear the XCom for.
        :param task_id: ID of task to clear the XCom for.
        :param run_id: ID of DAG run to clear the XCom for.
        :param map_index: If given, only clear XCom from this particular mapped
            task. The default ``None`` clears *all* XComs from the task.
        :param session: Database session. If not given, a new session will be
            created for this function.
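
        **Example** (a minimal sketch; the DAG, task, and run IDs below are illustrative):

        .. code-block:: python

            XComModel.clear(
                dag_id="example_dag",
                task_id="example_task",
                run_id="manual__2025-01-01T00:00:00+00:00",
            )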
        """
        # Given the historic signature of this function (logical_date used to be the first
        # argument), adding new optional parameters required defaults for everything, so the
        # required arguments are validated explicitly below.
        if dag_id is None:
            raise TypeError("clear() missing required argument: dag_id")
        if task_id is None:
            raise TypeError("clear() missing required argument: task_id")
        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
        query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
        if map_index is not None:
            query = query.filter_by(map_index=map_index)
        for xcom in query:
            session.delete(xcom)
        session.commit()
    @classmethod
    @provide_session
    def set(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Store an XCom value.
        :param key: Key to store the XCom.
        :param value: XCom value to store.
        :param dag_id: DAG ID.
        :param task_id: Task ID.
        :param run_id: DAG run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        :param session: Database session. If not given, a new session will be
            created for this function.
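
        **Example** (a minimal sketch; the IDs below are illustrative and the DAG run must
        already exist):

        .. code-block:: python

            XComModel.set(
                key="my_key",
                value={"rows": 42},
                dag_id="example_dag",
                task_id="example_task",
                run_id="manual__2025-01-01T00:00:00+00:00",
            )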
        """
        from airflow.models.dagrun import DagRun
        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
        dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
        if dag_run_id is None:
            raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
        # Seamlessly resolve LazySelectSequence to a list. This intends to work
        # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
        # it's pushed into XCom, the user should be aware of the performance
        # implications, and this avoids leaking the implementation detail.
        if isinstance(value, LazySelectSequence):
            warning_message = (
                "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                "to list, which may degrade performance. Review resource "
                "requirements for this operation, and call list() to suppress "
                "this message. See Dynamic Task Mapping documentation for "
                "more information about lazy proxy objects."
            )
            log.warning(
                warning_message,
                "return value" if key == XCOM_RETURN_KEY else f"value {key}",
                task_id,
                dag_id,
                run_id,
            )
            value = list(value)
        value = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )
        # Remove duplicate XComs and insert a new one.
        session.execute(
            delete(cls).where(
                cls.key == key,
                cls.run_id == run_id,
                cls.task_id == task_id,
                cls.dag_id == dag_id,
                cls.map_index == map_index,
            )
        )
        new = cast("Any", cls)(  # Work around Mypy complaining model not defining '__init__'.
            dag_run_id=dag_run_id,
            key=key,
            value=value,
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
            map_index=map_index,
        )
        session.add(new)
        session.flush() 
    @classmethod
    @provide_session
    def get_many(
        cls,
        *,
        run_id: str,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """
        Composes a query to get one or more XCom entries.
        This function returns an SQLAlchemy query of full XCom objects. If you
        just want one stored value, use :meth:`get_one` instead.
        :param run_id: DAG run ID for the task.
        :param key: A key for the XComs. If provided, only XComs with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param task_ids: Only XComs from task with matching IDs will be pulled.
            Pass *None* (default) to remove the filter.
        :param dag_ids: Only pulls XComs from specified DAGs. Pass *None*
            (default) to remove the filter.
        :param map_indexes: Only XComs from matching map indexes will be pulled.
            Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XComs from the
            specified DAG run are returned. If *True*, all matching XComs are
            returned regardless of the run they belong to.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :param limit: If given, return at most this many XComs.
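
        **Example** (a minimal sketch; the IDs are illustrative and ``session`` is assumed to be
        an existing SQLAlchemy session):

        .. code-block:: python

            query = XComModel.get_many(
                run_id="manual__2025-01-01T00:00:00+00:00",
                key="my_key",
                task_ids=["task_a", "task_b"],
                session=session,
            )
            values = [XComModel.deserialize_value(x) for x in query]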
        """
        from airflow.models.dagrun import DagRun
        if not run_id:
            raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
        query = session.query(cls).join(XComModel.dag_run)
        if key:
            query = query.filter(XComModel.key == key)
        if is_container(task_ids):
            query = query.filter(cls.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(cls.task_id == task_ids)
        if is_container(dag_ids):
            query = query.filter(cls.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(cls.dag_id == dag_ids)
        if isinstance(map_indexes, range) and map_indexes.step == 1:
            query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
        elif is_container(map_indexes):
            query = query.filter(cls.map_index.in_(map_indexes))
        elif map_indexes is not None:
            query = query.filter(cls.map_index == map_indexes)
        if include_prior_dates:
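            # Compare against the reference run's logical date (falling back to run_after when
            # logical_date is NULL) so that XComs from earlier runs are also returned.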
            dr = (
                session.query(
                    func.coalesce(DagRun.logical_date, DagRun.run_after).label("logical_date_or_run_after")
                )
                .filter(DagRun.run_id == run_id)
                .subquery()
            )
            query = query.filter(
                func.coalesce(DagRun.logical_date, DagRun.run_after) <= dr.c.logical_date_or_run_after
            )
        else:
            query = query.filter(cls.run_id == run_id)
        query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc())
        if limit:
            return query.limit(limit)
        return query 
    @staticmethod
    def serialize_value(
        value: Any,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> str:
        """Serialize XCom value to JSON str."""
        try:
            return json.dumps(value, cls=XComEncoder)
        except (ValueError, TypeError):
            raise ValueError("XCom value must be JSON serializable") 
    @staticmethod
    def deserialize_value(result) -> Any:
        """
        Deserialize XCom value from a database result.
        If deserialization fails, the raw value is returned, which must still be a valid Python JSON-compatible
        type (e.g., ``dict``, ``list``, ``str``, ``int``, ``float``, or ``bool``).
        XCom values are stored as JSON in the database, and SQLAlchemy automatically handles
        serialization (``json.dumps``) and deserialization (``json.loads``). However, we
        use a custom encoder for serialization (``serialize_value``) and deserialization to handle special
        cases, such as encoding tuples via the Airflow Serialization module. These must be decoded
        using ``XComDecoder`` to restore original types.
        Some XCom values, such as those set via the Task Execution API, bypass ``serialize_value``
        and are stored directly in JSON format. Since these values are already deserialized
        by SQLAlchemy, they are returned as-is.
        **Example: Handling a tuple**:
        .. code-block:: python
            original_value = (1, 2, 3)
            serialized_value = XComModel.serialize_value(original_value)
            print(serialized_value)
            # '{"__classname__": "builtins.tuple", "__version__": 1, "__data__": [1, 2, 3]}'
        This serialized value is stored in the database. When deserialized, the value is restored to the original tuple.
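
        For illustration, the stored string can be decoded the same way this method does it
        internally, using ``XComDecoder``:

        .. code-block:: python

            deserialized_value = json.loads(serialized_value, cls=XComDecoder)
            print(deserialized_value)
            # (1, 2, 3)
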
        :param result: The XCom database row or object containing a ``value`` attribute.
        :return: The deserialized Python object.
        """
        if result.value is None:
            return None
        try:
            return json.loads(result.value, cls=XComDecoder)
        except (ValueError, TypeError):
            # Already deserialized (e.g., set via Task Execution API)
            return result.value 
 
class LazyXComSelectSequence(LazySelectSequence[Any]):
    """
    List-like interface to lazily access XCom values.
    :meta private:
    """
    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        return select(XComModel.value).from_statement(stmt)
    @staticmethod
    def _process_row(row: Row) -> Any:
        return XComModel.deserialize_value(row)
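

# PEP 562 module-level ``__getattr__``: lazily import ``BaseXCom`` and ``XCom`` from the Task SDK
# on first attribute access, preserving backward compatibility for code that still imports them
# from this module.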
def __getattr__(name: str):
    if name == "BaseXCom":
        from airflow.sdk.bases.xcom import BaseXCom
        globals()[name] = BaseXCom
        return BaseXCom
    if name == "XCom":
        from airflow.sdk.execution_time.xcom import XCom
        globals()[name] = XCom
        return XCom
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")