# Source code for airflow.models.dagcode

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import os
import struct
from datetime import datetime
from typing import Iterable

from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.sql.expression import literal

from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models.base import Base
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime

# Module-level logger for this model module.
log = logging.getLogger(__name__)
class DagCode(Base):
    """A table for DAGs code.

    dag_code table contains code of DAG files synchronized by scheduler.

    For details on dag serialization see SerializedDagModel
    """

    __tablename__ = 'dag_code'

    # Primary key: a 56-bit hash of ``fileloc`` computed by ``dag_fileloc_hash``.
    fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
    # Full path of the DAG file.
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the limit of indexing.
    # Timestamp of the last code refresh for this row.
    last_updated = Column(UtcDateTime, nullable=False)
    # Stored DAG file contents (MEDIUMTEXT on MySQL, Text elsewhere).
    source_code = Column(Text().with_variant(MEDIUMTEXT(), 'mysql'), nullable=False)

    def __init__(self, full_filepath: str, source_code: str | None = None):
        """Create a row for ``full_filepath``.

        :param full_filepath: full path of the DAG file
        :param source_code: file contents; when ``None`` the code already stored
            in the database for this path is used instead
        :raises DagCodeNotFound: if ``source_code`` is ``None`` and nothing is stored
        """
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        # Only fall back to the stored code when no source was supplied at all:
        # an empty string is the legitimate content of an empty file.
        self.source_code = source_code if source_code is not None else DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session=None):
        """Writes code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None):
        """Writes code in bulk into database.

        Files without a row are inserted. Existing rows are refreshed when the
        file's modification time is newer than their ``last_updated``.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        :raises AirflowException: if a stored row shares a hash with one of
            ``filelocs`` but belongs to a different path (hash collision)
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
        # Lock the matching rows so concurrent syncs do not race on them.
        existing_orm_dag_codes = (
            session.query(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
            .all()
        )

        # Every row was selected by hash; a row whose path is not one of ours
        # means two different paths hash to the same primary key.
        conflicting = [orm for orm in existing_orm_dag_codes if orm.fileloc not in filelocs]
        if conflicting:
            hashes_to_filelocs = {fileloc_hash: fileloc for fileloc, fileloc_hash in filelocs_to_hashes.items()}
            raise AirflowException(
                " ".join(
                    f"Filename '{hashes_to_filelocs[orm.fileloc_hash]}' causes a hash collision in the "
                    f"database with '{orm.fileloc}'. Please rename the file."
                    for orm in conflicting
                )
            )

        existing_orm_dag_codes_map = {orm.fileloc: orm for orm in existing_orm_dag_codes}

        for fileloc in filelocs.difference(existing_orm_dag_codes_map):
            session.add(DagCode(fileloc, cls._get_code_from_file(fileloc)))

        for fileloc, orm_dag_code in existing_orm_dag_codes_map.items():
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )
            # Only re-read the file when it changed after the stored version.
            if file_mod_time > orm_dag_code.last_updated:
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(cls, alive_dag_filelocs: list[str], session=None):
        """Deletes code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param session: ORM Session
        """
        alive_fileloc_hashes = [cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table ", cls.__tablename__)

        # A row is removed only when neither its hash nor its path is alive.
        session.query(cls).filter(
            cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs)
        ).delete(synchronize_session='fetch')

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session=None) -> bool:
        """Checks whether a file exists in the dag_code table (looked up by path hash).

        :param fileloc: the file to check
        :param session: ORM Session
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return session.query(literal(True)).filter(cls.fileloc_hash == fileloc_hash).one_or_none() is not None

    @classmethod
    def get_code_by_fileloc(cls, fileloc: str) -> str:
        """Returns source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        :raises DagCodeNotFound: if no code is stored for ``fileloc``
        """
        return cls.code(fileloc)

    @classmethod
    def code(cls, fileloc) -> str:
        """Returns the source code stored in the database for ``fileloc``.

        :param fileloc: file path of a DAG
        :return: source code as string
        :raises DagCodeNotFound: if no code is stored for ``fileloc``
        """
        return cls._get_code_from_db(fileloc)

    @staticmethod
    def _get_code_from_file(fileloc):
        """Read the DAG file's contents (the path may point inside a zip archive)."""
        with open_maybe_zipped(fileloc, 'r') as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session=None):
        """Fetch stored source code for ``fileloc`` by its hash; raise DagCodeNotFound if absent."""
        dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first()
        if not dag_code:
            raise DagCodeNotFound()
        return dag_code.source_code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """Hashing file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib

        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
        return struct.unpack('>Q', hashlib.sha1(full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8

# Was this entry helpful?