Source code for airflow.models.dataset

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from urllib.parse import urlparse

import sqlalchemy_jsonfield
from sqlalchemy import (
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    Table,
    text,
)
from sqlalchemy.orm import relationship

from airflow.datasets import Dataset
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime


class DatasetModel(Base):
    """
    A table to store datasets.

    :param uri: a string that uniquely identifies the dataset
    :param extra: JSON field for arbitrary extra info
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    uri = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation='latin1_general_cs',
            ),
            'mysql',
        ),
        nullable=False,
    )
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
    producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")

    __tablename__ = "dataset"
    __table_args__ = (
        Index('idx_uri_unique', uri, unique=True),
        {'sqlite_autoincrement': True},  # ensures PK values not reused
    )

    @classmethod
    def from_public(cls, obj: Dataset) -> DatasetModel:
        return cls(uri=obj.uri, extra=obj.extra)

    def __init__(self, uri: str, **kwargs):
        try:
            uri.encode('ascii')
        except UnicodeEncodeError:
            raise ValueError('URI must be ascii')
        parsed = urlparse(uri)
        if parsed.scheme and parsed.scheme.lower() == 'airflow':
            raise ValueError("Scheme `airflow` is reserved.")
        super().__init__(uri=uri, **kwargs)

    def __eq__(self, other):
        if isinstance(other, (self.__class__, Dataset)):
            return self.uri == other.uri
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.uri)

    def __repr__(self):
        return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
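
# Illustrative usage (added; not part of the original module). A minimal sketch
# of round-tripping between the public Dataset class and this ORM model; the
# URI and extra values below are hypothetical.
#
#   from airflow.datasets import Dataset
#
#   dm = DatasetModel.from_public(Dataset("s3://bucket/daily.csv", extra={"team": "data"}))
#   dm == Dataset("s3://bucket/daily.csv")  # True: __eq__ compares URIs across both types
#   DatasetModel(uri="caf\u00e9://x")       # raises ValueError: URI must be ascii
#   DatasetModel(uri="airflow://x")         # raises ValueError: scheme `airflow` is reserved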
class DagScheduleDatasetReference(Base):
    """References from a DAG to a dataset of which it is a consumer."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship('DatasetModel', back_populates="consuming_dags")
    queue_records = relationship(
        "DatasetDagRunQueue",
        primaryjoin="""and_(
            DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
            DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
        )""",
    )

    __tablename__ = "dag_schedule_dataset_reference"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name='dsdr_dataset_fkey',
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=['dag.dag_id'],
            name='dsdr_dag_id_fkey',
            ondelete='CASCADE',
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
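
# Illustrative query (added; not part of the original module): finding the DAGs
# that consume a given dataset URI. Assumes a SQLAlchemy session obtained via
# airflow.utils.session; the URI is hypothetical.
#
#   from airflow.utils.session import create_session
#
#   with create_session() as session:
#       dag_ids = [
#           ref.dag_id
#           for ref in session.query(DagScheduleDatasetReference)
#           .join(DatasetModel, DagScheduleDatasetReference.dataset_id == DatasetModel.id)
#           .filter(DatasetModel.uri == "s3://bucket/daily.csv")
#       ]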
class TaskOutletDatasetReference(Base):
    """References from a task to a dataset that it updates / produces."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="producing_tasks")

    __tablename__ = "task_outlet_dataset_reference"
    __table_args__ = (
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name='todr_dataset_fkey',
            ondelete="CASCADE",
        ),
        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=['dag.dag_id'],
            name='todr_dag_id_fkey',
            ondelete='CASCADE',
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.dataset_id == other.dataset_id
                and self.dag_id == other.dag_id
                and self.task_id == other.task_id
            )
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
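
# Note (added; not part of the original module): the producing tasks for a
# dataset are reachable from the model side as dataset_model.producing_tasks,
# each row identifying one (dag_id, task_id) pair that lists the dataset as an
# outlet.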
class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name='ddrq_dataset_fkey',
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id,),
            ["dag.dag_id"],
            name='ddrq_dag_fkey',
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
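
# Illustrative sketch (added; not part of the original module): when a task
# updates a dataset, the scheduler inserts one DatasetDagRunQueue row per
# consuming DAG; once every dataset a DAG depends on has a queued row, a run
# is created and the rows are removed. A hedged approximation of the enqueue
# step, assuming an existing dataset_model and session:
#
#   from airflow.utils.session import create_session
#
#   with create_session() as session:
#       for ref in dataset_model.consuming_dags:
#           session.merge(
#               DatasetDagRunQueue(dataset_id=dataset_model.id, target_dag_id=ref.dag_id)
#           )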
association_table = Table(
    "dagrun_dataset_event",
    Base.metadata,
    Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True),
    Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
    Index("idx_dagrun_dataset_events_event_id", "event_id"),
)
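
# Note (added; not part of the original module): this association table backs
# the many-to-many relationship declared below as DatasetEvent.created_dagruns,
# whose backref exposes DagRun.consumed_dataset_events on the other side.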
class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    dataset_id = Column(Integer, nullable=False)
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index('idx_dataset_id_timestamp', dataset_id, timestamp),
        {'sqlite_autoincrement': True},  # ensures PK values not reused
    )

    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )

    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        return self.dataset.uri

    def __repr__(self) -> str:
        args = []
        for attr in [
            'id',
            'dataset_id',
            'extra',
            'source_task_id',
            'source_dag_id',
            'source_run_id',
            'source_map_index',
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
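
# Illustrative sketch (added; not part of the original module): recording a
# dataset event when a task with an outlet finishes. All field values are
# hypothetical; in Airflow itself, event creation is handled internally rather
# than by user code.
#
#   from airflow.utils.session import create_session
#
#   with create_session() as session:
#       session.add(
#           DatasetEvent(
#               dataset_id=dataset_model.id,
#               extra={"rows_written": 1000},
#               source_task_id="produce_data",
#               source_dag_id="producer_dag",
#               source_run_id="manual__2022-01-01T00:00:00+00:00",
#               source_map_index=-1,
#           )
#       )
#
# Downstream runs triggered by the event are then reachable as
# event.created_dagruns, and from a DagRun as dag_run.consumed_dataset_events.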