#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from urllib.parse import urlsplit

import sqlalchemy_jsonfield
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    Table,
    text,
)
from sqlalchemy.orm import relationship

from airflow.datasets import Dataset
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime


class DatasetModel(Base):
    """
    A table to store datasets.

    :param uri: a string that uniquely identifies the dataset
    :param extra: JSON field for arbitrary extra info
    """
    id = Column(Integer, primary_key=True, autoincrement=True)
    uri = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation="latin1_general_cs",
            ),
            "mysql",
        ),
        nullable=False,
    )
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
    is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0")

    consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
    producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")

    __tablename__ = "dataset"
    __table_args__ = (
        Index("idx_uri_unique", uri, unique=True),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    @classmethod
    def from_public(cls, obj: Dataset) -> DatasetModel:
        return cls(uri=obj.uri, extra=obj.extra)

    def __init__(self, uri: str, **kwargs):
        # Dataset URIs must be plain ASCII and may not use the reserved `airflow` scheme.
        try:
            uri.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("URI must be ascii")
        parsed = urlsplit(uri)
        if parsed.scheme and parsed.scheme.lower() == "airflow":
            raise ValueError("Scheme `airflow` is reserved.")
        super().__init__(uri=uri, **kwargs)

    def __eq__(self, other):
        if isinstance(other, (self.__class__, Dataset)):
            return self.uri == other.uri
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.uri)

    def __repr__(self):
        return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"


class DagScheduleDatasetReference(Base):
    """References from a DAG to a dataset of which it is a consumer."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="consuming_dags")
    # Queue entries (DatasetDagRunQueue) matching this reference's (dataset_id, dag_id) pair.
    queue_records = relationship(
        "DatasetDagRunQueue",
        primaryjoin="""and_(
            DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
            DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
        )""",
        cascade="all, delete, delete-orphan",
    )

    __tablename__ = "dag_schedule_dataset_reference"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="dsdr_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="dsdr_dag_id_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"


class TaskOutletDatasetReference(Base):
    """References from a task to a dataset that it updates / produces."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="producing_tasks")

    __tablename__ = "task_outlet_dataset_reference"
    __table_args__ = (
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="todr_dataset_fkey",
            ondelete="CASCADE",
        ),
        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="todr_dag_id_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.dataset_id == other.dataset_id
                and self.dag_id == other.dag_id
                and self.task_id == other.task_id
            )
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"


class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="ddrq_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id,),
            ["dag.dag_id"],
            name="ddrq_dag_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
association_table = Table(
    "dagrun_dataset_event",
    Base.metadata,
    Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True),
    Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
    Index("idx_dagrun_dataset_events_event_id", "event_id"),
)


class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """
    id = Column(Integer, primary_key=True, autoincrement=True)
    dataset_id = Column(Integer, nullable=False)
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index("idx_dataset_id_timestamp", dataset_id, timestamp),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )

    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        return self.dataset.uri

    def __repr__(self) -> str:
        args = []
        for attr in [
            "id",
            "dataset_id",
            "extra",
            "source_task_id",
            "source_dag_id",
            "source_run_id",
            "source_map_index",
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"