# Source code for airflow.models.dagcode
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import struct
from datetime import datetime
from sqlalchemy import BigInteger, Column, String, UnicodeText, and_, exists
from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models import Base
from airflow.settings import STORE_DAG_CODE
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.db import provide_session
from airflow.utils.sqlalchemy import UtcDateTime
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class DagCode(Base):
    """A table for DAGs code.

    dag_code table contains code of DAG files synchronized by scheduler.
    This feature is controlled by:

    * ``[core] store_serialized_dags = True``: enable this feature
    * ``[core] store_dag_code = True``: enable this feature

    For details on dag serialization see SerializedDagModel
    """

    __tablename__ = 'dag_code'

    fileloc_hash = Column(
        BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the limit of indexing, so rows are
    # keyed by ``fileloc_hash`` (see ``dag_fileloc_hash``) instead.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(UnicodeText, nullable=False)

    def __init__(self, full_filepath, source_code=None):
        """Build a row for ``full_filepath``.

        :param full_filepath: full path of the DAG file
        :param source_code: the file's source; when ``None`` it is loaded via
            ``DagCode.code`` (from the DB if ``STORE_DAG_CODE`` is set,
            otherwise from the file). An empty string is kept as-is.
        """
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        # ``is None`` rather than truthiness: an empty file ("") must not
        # trigger a lookup, which would raise DagCodeNotFound for a new row.
        self.source_code = (
            source_code if source_code is not None else DagCode.code(self.fileloc)
        )

    @provide_session
    def sync_to_db(self, session=None):
        """Writes code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs, session=None):
        """Writes code in bulk into database.

        Files without a row get a new one. Existing rows are refreshed only
        when the file's modification time is newer than ``last_updated``.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        :raises AirflowException: if a stored row sharing the hash of one of
            ``filelocs`` belongs to a different path (hash collision)
        """
        filelocs = set(filelocs)
        filelocs_to_hashes = {
            fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs
        }
        # Lock matching rows so concurrent syncs don't race on the same files.
        existing_orm_dag_codes = (
            session
            .query(DagCode)
            .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
            .with_for_update(of=DagCode)
            .all()
        )
        existing_orm_dag_codes_map = {
            orm.fileloc: orm for orm in existing_orm_dag_codes
        }
        existing_filelocs = set(existing_orm_dag_codes_map)

        if not existing_filelocs.issubset(filelocs):
            # A stored path we did not ask for matched one of our hashes:
            # two different files hash to the same primary key.
            hashes_to_filelocs = {
                fileloc_hash: fileloc
                for fileloc, fileloc_hash in filelocs_to_hashes.items()
            }
            messages = [
                ("Filename '{}' causes a hash collision in the " +
                 "database with '{}'. Please rename the file.").format(
                    hashes_to_filelocs[orm.fileloc_hash], orm.fileloc)
                for orm in existing_orm_dag_codes
                if orm.fileloc not in filelocs
            ]
            raise AirflowException("\n".join(messages))

        for fileloc in filelocs.difference(existing_filelocs):
            session.add(DagCode(fileloc, cls._get_code_from_file(fileloc)))

        for fileloc in existing_filelocs:
            orm_dag_code = existing_orm_dag_codes_map[fileloc]
            file_mod_time = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
            )
            if file_mod_time > orm_dag_code.last_updated:
                orm_dag_code.last_updated = file_mod_time
                orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
                session.merge(orm_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(cls, alive_dag_filelocs, session=None):
        """Deletes code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param session: ORM Session
        """
        alive_fileloc_hashes = [
            cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table ", cls.__tablename__)

        session.query(cls).filter(
            and_(cls.fileloc_hash.notin_(alive_fileloc_hashes),
                 cls.fileloc.notin_(alive_dag_filelocs))).delete(synchronize_session='fetch')

    @classmethod
    @provide_session
    def has_dag(cls, fileloc, session=None):
        """Checks a file exist in dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        :return: True if a row with the file's hash exists
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return session.query(
            exists().where(cls.fileloc_hash == fileloc_hash)
        ).scalar()

    @classmethod
    def get_code_by_fileloc(cls, fileloc):
        """Returns source code for a given fileloc.

        :param fileloc: file path of a DAG
        :return: source code as string
        """
        return cls.code(fileloc)

    @classmethod
    def code(cls, fileloc):
        """Returns source code for the DAG file at ``fileloc``.

        Reads from the database when ``STORE_DAG_CODE`` is enabled, otherwise
        directly from the file.

        :param fileloc: file path of a DAG
        :return: source code as string
        :raises DagCodeNotFound: when reading from the DB and no row exists
        """
        if STORE_DAG_CODE:
            return cls._get_code_from_db(fileloc)
        return cls._get_code_from_file(fileloc)

    @staticmethod
    def _get_code_from_file(fileloc):
        """Read the file (possibly inside a zip) and return its contents."""
        with open_maybe_zipped(fileloc, 'r') as f:
            code = f.read()
        return code

    @classmethod
    @provide_session
    def _get_code_from_db(cls, fileloc, session=None):
        """Return the stored source for ``fileloc``.

        :raises DagCodeNotFound: if no row matches the file's hash
        """
        dag_code = session.query(cls) \
            .filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)) \
            .first()
        if not dag_code:
            raise DagCodeNotFound()
        return dag_code.source_code

    @staticmethod
    def dag_fileloc_hash(full_filepath):
        """Hashing file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib

        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed):
        # take the last 8 bytes of the SHA-1 digest, then drop the low byte.
        return struct.unpack('>Q', hashlib.sha1(
            full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8