Source code for airflow.models.dagcode

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import struct
from datetime import datetime

from sqlalchemy import BigInteger, Column, String, UnicodeText, and_, exists

from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models import Base
from airflow.settings import STORE_DAG_CODE
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.db import provide_session
from airflow.utils.sqlalchemy import UtcDateTime

class DagCode(Base):
    """A table for DAGs code.

    dag_code table contains code of DAG files synchronized by scheduler.
    This feature is controlled by:

    * ``[core] store_serialized_dags = True``: enable this feature
    * ``[core] store_dag_code = True``: enable this feature

    For details on dag serialization see SerializedDagModel
    """

    __tablename__ = 'dag_code'

    # Primary key: a numeric hash of ``fileloc`` (see ``dag_fileloc_hash``).
    fileloc_hash = Column(
        BigInteger, nullable=False, primary_key=True, autoincrement=False)
    # The max length of fileloc exceeds the limit of indexing.
    fileloc = Column(String(2000), nullable=False)
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(UnicodeText, nullable=False)

    def __init__(self, full_filepath, source_code=None):
        """Builds a row for the given DAG file.

        :param full_filepath: full file path of the DAG file
        :param source_code: source of the file; when empty or omitted it is
            looked up through :meth:`code`
        """
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        self.source_code = source_code or DagCode.code(self.fileloc)

    @provide_session
    def sync_to_db(self, session=None):
        """Writes code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs, session=None):
        """Writes code in bulk into database.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        :raises AirflowException: if a stored row shares a hash with one of
            ``filelocs`` but records a different file path (hash collision)
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {
            fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs
        }
        # Lock the matching rows for the duration of the update.
        existing_orm_dag_codes = (
            session
            .query(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
            .all()
        )

        # Both comprehensions yield an empty dict when nothing was found,
        # so no special-casing of the empty result is required.
        existing_orm_dag_codes_map = {
            orm_dag_code.fileloc: orm_dag_code
            for orm_dag_code in existing_orm_dag_codes
        }
        existing_orm_dag_codes_by_fileloc_hashes = {
            orm.fileloc_hash: orm for orm in existing_orm_dag_codes
        }

        existing_orm_filelocs = {
            orm.fileloc
            for orm in existing_orm_dag_codes_by_fileloc_hashes.values()
        }
        # A row matched by hash whose path is not among the requested paths
        # means two different paths hashed to the same value.
        if not existing_orm_filelocs.issubset(filelocs):
            conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
            hashes_to_filelocs = {
                DagCode.dag_fileloc_hash(fileloc): fileloc
                for fileloc in filelocs
            }
            message = "".join(
                ("Filename '{}' causes a hash collision in the "
                 "database with '{}'. Please rename the file.").format(
                    hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)],
                    fileloc)
                for fileloc in conflicting_filelocs
            )
            raise AirflowException(message)

        existing_filelocs = {
            dag_code.fileloc for dag_code in existing_orm_dag_codes
        }
        missing_filelocs = filelocs.difference(existing_filelocs)

        # Files that have no row yet: insert them.
        for fileloc in missing_filelocs:
            orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
            session.add(orm_dag_code)

        # Files that already have a row: refresh it only when the file on
        # disk is newer than what was stored.
        for fileloc in existing_filelocs:
            current_version = existing_orm_dag_codes_by_fileloc_hashes[
                filelocs_to_hashes[fileloc]
            ]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)),
                tz=timezone.utc
            )

            if file_mod_time > current_version.last_updated:
                orm_dag_code = existing_orm_dag_codes_map[fileloc]
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(
                    orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(cls, alive_dag_filelocs, session=None):
        """Deletes code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param session: ORM Session
        """
        alive_fileloc_hashes = [
            cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        # Resolve the module logger here so the method does not rely on a
        # module-level ``log`` name being defined.
        logging.getLogger(__name__).debug(
            "Deleting code from %s table ", cls.__tablename__)

        # A row is removed only when neither its hash nor its path belongs
        # to an alive DAG file.
        session.query(cls).filter(
            and_(cls.fileloc_hash.notin_(alive_fileloc_hashes),
                 cls.fileloc.notin_(alive_dag_filelocs))
        ).delete(synchronize_session='fetch')

    @classmethod
    @provide_session
    def has_dag(cls, fileloc, session=None):
        """Checks a file exist in dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        :return: result of an EXISTS query on the hash of ``fileloc``
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return session.query(
            exists().where(cls.fileloc_hash == fileloc_hash)
        ).scalar()

    @classmethod
    def get_code_by_fileloc(cls, fileloc):
        """Returns source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    def code(cls, fileloc):
        """Returns source code for this DagCode object.

        The code is read from the database when ``STORE_DAG_CODE`` is truthy
        and from the file otherwise.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        if STORE_DAG_CODE:
            return cls._get_code_from_db(fileloc)
        return cls._get_code_from_file(fileloc)

    @staticmethod
    def _get_code_from_file(fileloc):
        """Reads the source code of a DAG file via ``open_maybe_zipped``.

        :param fileloc: file path of a DAG
        :return: file content as string
        """
        with open_maybe_zipped(fileloc, 'r') as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session=None):
        """Reads the stored source code of a DAG file from the database.

        :param fileloc: file path of a DAG
        :param session: ORM Session
        :return: stored source code as string
        :raises DagCodeNotFound: if no row matches the hash of ``fileloc``
        """
        dag_code = session.query(cls) \
            .filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)) \
            .first()
        if not dag_code:
            raise DagCodeNotFound()
        return dag_code.source_code

    @staticmethod
    def dag_fileloc_hash(full_filepath):
        """Hashing file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib
        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
        return struct.unpack('>Q', hashlib.sha1(
            full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8

