#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from urllib.parse import urlsplit
import sqlalchemy_jsonfield
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
ForeignKeyConstraint,
Index,
Integer,
PrimaryKeyConstraint,
String,
Table,
text,
)
from sqlalchemy.orm import relationship
from airflow.datasets import Dataset
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime
class DatasetModel(Base):
    """
    A table to store datasets.

    :param uri: a string that uniquely identifies the dataset; it must be ASCII
        and must not use the reserved ``airflow`` scheme (enforced in ``__init__``)
    :param extra: JSON field for arbitrary extra info
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    uri = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation="latin1_general_cs",
            ),
            "mysql",
        ),
        nullable=False,
    )
    # JSON payload documented above. It is written by ``from_public`` and read by
    # ``__repr__``, so it must be a mapped column on the model.
    extra = Column(
        sqlalchemy_jsonfield.JSONField(json=json),
        nullable=False,
        default={},
    )
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
    is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0")

    consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
    producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")

    __tablename__ = "dataset"
    __table_args__ = (
        Index("idx_uri_unique", uri, unique=True),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    @classmethod
    def from_public(cls, obj: Dataset) -> DatasetModel:
        """Build a model instance from a public ``Dataset`` (copies ``uri`` and ``extra``)."""
        return cls(uri=obj.uri, extra=obj.extra)

    def __init__(self, uri: str, **kwargs):
        """
        Validate ``uri`` and delegate the remaining columns to the ORM constructor.

        :raises ValueError: if ``uri`` contains non-ASCII characters, or if its
            scheme is ``airflow`` (compared case-insensitively).
        """
        try:
            uri.encode("ascii")
        except UnicodeEncodeError:
            # The encode error adds no information beyond the message below.
            raise ValueError("URI must be ascii") from None
        parsed = urlsplit(uri)
        if parsed.scheme and parsed.scheme.lower() == "airflow":
            raise ValueError("Scheme `airflow` is reserved.")
        super().__init__(uri=uri, **kwargs)

    def __eq__(self, other):
        """Compare by ``uri`` against another ``DatasetModel`` or a public ``Dataset``."""
        if isinstance(other, (self.__class__, Dataset)):
            return self.uri == other.uri
        else:
            return NotImplemented

    def __hash__(self):
        # Consistent with __eq__, which compares only ``uri``.
        return hash(self.uri)

    def __repr__(self):
        return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
class DagScheduleDatasetReference(Base):
    """References from a DAG to a dataset of which it is a consumer."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="consuming_dags")

    # Queue rows are matched on both (dataset_id, dag_id). There is no FK between
    # the tables, hence the explicit ``foreign()`` annotations in the join.
    queue_records = relationship(
        "DatasetDagRunQueue",
        primaryjoin="""and_(
            DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
            DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
        )""",
        cascade="all, delete, delete-orphan",
    )

    __tablename__ = "dag_schedule_dataset_reference"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="dsdr_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="dsdr_dag_id_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        """Two references are equal when their primary-key values match."""
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
        else:
            return NotImplemented

    def __hash__(self):
        # Hash the same primary-key values that __eq__ compares.
        # ``self.__mapper__.primary_key`` belongs to the mapper, so it is the same
        # object for every instance. Hashing it made all instances collide.
        return hash((self.dataset_id, self.dag_id))

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
class TaskOutletDatasetReference(Base):
    """References from a task to a dataset that it updates / produces."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="producing_tasks")

    __tablename__ = "task_outlet_dataset_reference"
    __table_args__ = (
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="todr_dataset_fkey",
            ondelete="CASCADE",
        ),
        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="todr_dag_id_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        """Two references are equal when their primary-key values match."""
        if isinstance(other, self.__class__):
            return (
                self.dataset_id == other.dataset_id
                and self.dag_id == other.dag_id
                and self.task_id == other.task_id
            )
        else:
            return NotImplemented

    def __hash__(self):
        # Hash the same primary-key values that __eq__ compares.
        # ``self.__mapper__.primary_key`` belongs to the mapper, so it is the same
        # object for every instance. Hashing it made all instances collide.
        return hash((self.dataset_id, self.dag_id, self.task_id))

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="ddrq_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id,),
            ["dag.dag_id"],
            name="ddrq_dag_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        """Two queue records are equal when their primary-key values match."""
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        # Hash the same primary-key values that __eq__ compares.
        # ``self.__mapper__.primary_key`` belongs to the mapper, so it is the same
        # object for every instance. Hashing it made all instances collide.
        return hash((self.dataset_id, self.target_dag_id))

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
# Link table pairing DagRun rows (dag_run.id) with DatasetEvent rows
# (dataset_event.id). Deleting either side cascades to the link row.
# Index entries name columns by string, so they must follow the Column entries.
association_table = Table(
    "dagrun_dataset_event",
    Base.metadata,
    Column(
        "dag_run_id",
        ForeignKey("dag_run.id", ondelete="CASCADE"),
        primary_key=True,
    ),
    Column(
        "event_id",
        ForeignKey("dataset_event.id", ondelete="CASCADE"),
        primary_key=True,
    ),
    Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
    Index("idx_dagrun_dataset_events_event_id", "event_id"),
)
class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Deliberately no ForeignKey (see class docstring): the event survives deletion
    # of the DatasetModel it points at.
    dataset_id = Column(Integer, nullable=False)
    # JSON payload documented above and rendered by __repr__.
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index("idx_dataset_id_timestamp", dataset_id, timestamp),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )

    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        """
        Return the URI of the referenced dataset.

        Returns ``None`` when no DatasetModel row matches ``dataset_id``. Events
        are intentionally kept after their dataset is deleted (``dataset_id`` has no
        foreign key), so this case is expected and must not raise AttributeError.
        """
        dataset = self.dataset
        return dataset.uri if dataset is not None else None

    def __repr__(self) -> str:
        args = []
        for attr in [
            "id",
            "dataset_id",
            "extra",
            "source_task_id",
            "source_dag_id",
            "source_run_id",
            "source_map_index",
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"