Source code for airflow.providers.openlineage.extractors.manager

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Iterator

from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import BaseExtractor, OperatorLineage
from airflow.providers.openlineage.extractors.base import DefaultExtractor
from airflow.providers.openlineage.extractors.bash import BashExtractor
from airflow.providers.openlineage.extractors.python import PythonExtractor
from airflow.providers.openlineage.utils.utils import (
    get_unknown_source_attribute_run_facet,
    translate_airflow_dataset,
    try_import_from_string,
)
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
    from openlineage.client.event_v2 import Dataset

    from airflow.lineage.entities import Table
    from airflow.models import Operator


def _iter_extractor_types() -> Iterator[type[BaseExtractor]]:
    """Yield the built-in extractor classes, skipping any that are ``None``.

    ``PythonExtractor`` is yielded before ``BashExtractor``.
    """
    for candidate in (PythonExtractor, BashExtractor):
        if candidate is None:
            continue
        yield candidate


class ExtractorManager(LoggingMixin):
    """
    Class abstracting management of custom extractors.

    Keeps a mapping from operator class name to the extractor class registered
    for it, and uses that mapping to produce ``OperatorLineage`` for a task.
    """

    def __init__(self):
        super().__init__()
        # Maps operator class name -> extractor class handling that operator.
        self.extractors: dict[str, type[BaseExtractor]] = {}
        # Used for operators that expose ``get_openlineage_facets_on_*`` methods
        # but have no dedicated extractor registered (see ``get_extractor_class``).
        self.default_extractor = DefaultExtractor

        # Built-in Extractors like Bash and Python
        for extractor in _iter_extractor_types():
            for operator_class in extractor.get_operator_classnames():
                self.extractors[operator_class] = extractor

        # Custom extractors are registered after the built-in ones, so a custom
        # extractor replaces any earlier registration for the same operator class.
        for extractor_path in conf.custom_extractors():
            custom_extractor: type[BaseExtractor] | None = try_import_from_string(extractor_path)
            if not custom_extractor:
                self.log.warning(
                    "OpenLineage is unable to import custom extractor `%s`; will ignore it.", extractor_path
                )
                continue
            for operator_class in custom_extractor.get_operator_classnames():
                if operator_class in self.extractors:
                    self.log.warning(
                        "Duplicate OpenLineage custom extractor found for `%s`. "
                        "`%s` will be used instead of `%s`",
                        operator_class,
                        extractor_path,
                        self.extractors[operator_class],
                    )
                self.extractors[operator_class] = custom_extractor
                self.log.debug(
                    "Registered custom OpenLineage extractor `%s` for class `%s`",
                    extractor_path,
                    operator_class,
                )

    def add_extractor(self, operator_class: str, extractor: type[BaseExtractor]) -> None:
        """Register ``extractor`` for ``operator_class``, replacing any existing entry."""
        self.extractors[operator_class] = extractor

    def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=None) -> OperatorLineage:
        """
        Build the ``OperatorLineage`` for ``task``.

        :param dagrun: DAG run the task belongs to; only its ``run_id`` is read, for logging.
        :param task: operator to extract lineage for.
        :param complete: when true, ``extract_on_complete(task_instance)`` is called on the
            extractor; otherwise ``extract()`` is called.
        :param task_instance: passed through to ``extract_on_complete``.
        :return: metadata from the extractor when one is found and returns valid metadata;
            hook-collected lineage when there is no extractor but hook lineage exists;
            otherwise lineage built from the task's inlet/outlet definitions together with
            the unknown-source-attribute run facet. An empty ``OperatorLineage`` is returned
            when the extractor raises or returns invalid/empty metadata.
        """
        extractor = self._get_extractor(task)
        task_info = (
            f"task_type={task.task_type} "
            f"airflow_dag_id={task.dag_id} "
            f"task_id={task.task_id} "
            f"airflow_run_id={dagrun.run_id} "
        )

        if extractor:
            # Extracting advanced metadata is only possible when extractor for particular operator
            # is defined. Without it, we can't extract any input or output data.
            try:
                self.log.debug("Using extractor %s %s", extractor.__class__.__name__, task_info)
                if complete:
                    task_metadata = extractor.extract_on_complete(task_instance)
                else:
                    task_metadata = extractor.extract()

                self.log.debug("Found task metadata for operation %s: %s", task.task_id, task_metadata)
                task_metadata = self.validate_task_metadata(task_metadata)
                if task_metadata:
                    # Extractor produced no datasets: prefer hook-collected lineage and
                    # fall back to the task's manually declared inlets/outlets.
                    if (not task_metadata.inputs) and (not task_metadata.outputs):
                        if (hook_lineage := self.get_hook_lineage()) is not None:
                            inputs, outputs = hook_lineage
                            task_metadata.inputs = inputs
                            task_metadata.outputs = outputs
                        else:
                            self.extract_inlets_and_outlets(task_metadata, task.inlets, task.outlets)
                    return task_metadata
            except Exception as e:
                # Deliberately best-effort: an extractor failure is logged and the
                # method falls through to return an empty OperatorLineage.
                self.log.warning(
                    "Failed to extract metadata using found extractor %s - %s %s", extractor, e, task_info
                )
        elif (hook_lineage := self.get_hook_lineage()) is not None:
            inputs, outputs = hook_lineage
            task_metadata = OperatorLineage(inputs=inputs, outputs=outputs)
            return task_metadata
        else:
            self.log.debug("Unable to find an extractor %s", task_info)

            # Only include the unknownSourceAttribute facet if there is no extractor
            task_metadata = OperatorLineage(
                run_facets=get_unknown_source_attribute_run_facet(task=task),
            )
            inlets = task.get_inlet_defs()
            outlets = task.get_outlet_defs()
            self.extract_inlets_and_outlets(task_metadata, inlets, outlets)
            return task_metadata

        return OperatorLineage()

    def get_extractor_class(self, task: Operator) -> type[BaseExtractor] | None:
        """
        Return the extractor class to use for ``task``, or ``None`` if there is none.

        An extractor registered under ``task.task_type`` wins. Otherwise the default
        extractor is returned when the task has a callable
        ``get_openlineage_facets_on_start`` or ``get_openlineage_facets_on_complete``.
        """
        if task.task_type in self.extractors:
            return self.extractors[task.task_type]

        def method_exists(method_name) -> bool:
            method = getattr(task, method_name, None)
            return bool(method) and callable(method)

        if method_exists("get_openlineage_facets_on_start") or method_exists(
            "get_openlineage_facets_on_complete"
        ):
            return self.default_extractor
        return None

    def _get_extractor(self, task: Operator) -> BaseExtractor | None:
        """Instantiate the extractor class chosen for ``task``; ``None`` if no class applies."""
        # TODO: Re-enable in Extractor PR
        # self.instantiate_abstract_extractors(task)
        extractor = self.get_extractor_class(task)
        self.log.debug("extractor for %s is %s", task.task_type, extractor)
        if extractor:
            return extractor(task)
        return None

    def extract_inlets_and_outlets(
        self,
        task_metadata: OperatorLineage,
        inlets: list,
        outlets: list,
    ) -> None:
        """
        Append datasets converted from ``inlets``/``outlets`` to ``task_metadata`` in place.

        Entries that ``convert_to_ol_dataset`` cannot convert are skipped.
        """
        if inlets or outlets:
            self.log.debug("Manually extracting lineage metadata from inlets and outlets")
        for inlet in inlets:
            dataset = self.convert_to_ol_dataset(inlet)
            if dataset:
                task_metadata.inputs.append(dataset)
        for outlet in outlets:
            dataset = self.convert_to_ol_dataset(outlet)
            if dataset:
                task_metadata.outputs.append(dataset)

    def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
        """
        Return ``(inputs, outputs)`` translated from hook-collected datasets.

        Returns ``None`` when ``airflow.lineage.hook`` cannot be imported or the
        collector reports that nothing has been collected. Datasets that
        ``translate_airflow_dataset`` maps to ``None`` are dropped.
        """
        try:
            from airflow.lineage.hook import get_hook_lineage_collector
        except ImportError:
            return None

        collector = get_hook_lineage_collector()
        if not collector.has_collected:
            return None

        collected = collector.collected_datasets
        return (
            [
                dataset
                for dataset_info in collected.inputs
                if (dataset := translate_airflow_dataset(dataset_info.dataset, dataset_info.context))
                is not None
            ],
            [
                dataset
                for dataset_info in collected.outputs
                if (dataset := translate_airflow_dataset(dataset_info.dataset, dataset_info.context))
                is not None
            ],
        )

    @staticmethod
    def convert_to_ol_dataset_from_object_storage_uri(uri: str) -> Dataset | None:
        """
        Convert a URI into an OpenLineage ``Dataset``.

        Returns ``None`` when ``uri`` contains no ``/`` or cannot be parsed. For schemes
        starting with one of the known prefixes (``s3``, ``gs``, ``gcs``, ``hdfs``,
        ``file``) the namespace is ``<normalized scheme>://<netloc>`` and the name is the
        path without leading slashes; ``gcs`` is normalized to ``gs``. For any other
        scheme the namespace is the scheme and the name is ``netloc`` + ``path``.
        """
        from urllib.parse import urlparse

        from openlineage.client.event_v2 import Dataset

        if "/" not in uri:
            return None

        try:
            parsed = urlparse(uri)
        except Exception:
            # Best-effort conversion: an unparsable URI yields no dataset.
            return None

        scheme, netloc, path = parsed.scheme, parsed.netloc, parsed.path

        # Prefix match, so scheme variants sharing a known prefix are normalized too.
        common_schemas = {
            "s3": "s3",
            "gs": "gs",
            "gcs": "gs",
            "hdfs": "hdfs",
            "file": "file",
        }
        for found, final in common_schemas.items():
            if scheme.startswith(found):
                return Dataset(namespace=f"{final}://{netloc}", name=path.lstrip("/"))
        return Dataset(namespace=scheme, name=f"{netloc}{path}")

    @staticmethod
    def convert_to_ol_dataset_from_table(table: Table) -> Dataset:
        """
        Convert an Airflow lineage ``Table`` into an OpenLineage ``Dataset``.

        The namespace is ``table.cluster`` and the name is ``<database>.<name>``.
        Schema, ownership and documentation facets are attached only when the table
        has columns, owners or a description respectively.
        """
        from openlineage.client.event_v2 import Dataset
        from openlineage.client.facet_v2 import (
            DatasetFacet,
            documentation_dataset,
            ownership_dataset,
            schema_dataset,
        )

        facets: dict[str, DatasetFacet] = {}
        if table.columns:
            facets["schema"] = schema_dataset.SchemaDatasetFacet(
                fields=[
                    schema_dataset.SchemaDatasetFacetFields(
                        name=column.name,
                        type=column.data_type,
                        description=column.description,
                    )
                    for column in table.columns
                ]
            )
        if table.owners:
            facets["ownership"] = ownership_dataset.OwnershipDatasetFacet(
                owners=[
                    ownership_dataset.Owner(
                        # f.e. "user:John Doe <jdoe@company.com>" or just "user:<jdoe@company.com>"
                        name="user:"
                        f"{user.first_name + ' ' if user.first_name else ''}"
                        f"{user.last_name + ' ' if user.last_name else ''}"
                        f"<{user.email}>",
                        type="",
                    )
                    for user in table.owners
                ]
            )
        if table.description:
            facets["documentation"] = documentation_dataset.DocumentationDatasetFacet(
                description=table.description
            )
        return Dataset(
            namespace=f"{table.cluster}",
            name=f"{table.database}.{table.name}",
            facets=facets,
        )

    @staticmethod
    def convert_to_ol_dataset(obj) -> Dataset | None:
        """
        Convert ``obj`` into an OpenLineage ``Dataset`` when its type is supported.

        ``Dataset`` instances are returned unchanged, ``Table`` and ``File`` entities
        are converted by the dedicated helpers, and anything else yields ``None``.
        """
        from openlineage.client.event_v2 import Dataset

        from airflow.lineage.entities import File, Table

        if isinstance(obj, Dataset):
            return obj
        elif isinstance(obj, Table):
            return ExtractorManager.convert_to_ol_dataset_from_table(obj)
        elif isinstance(obj, File):
            return ExtractorManager.convert_to_ol_dataset_from_object_storage_uri(obj.url)
        else:
            return None

    def validate_task_metadata(self, task_metadata) -> OperatorLineage | None:
        """
        Rebuild ``task_metadata`` as an ``OperatorLineage``.

        Returns ``None`` (after logging a warning) when ``task_metadata`` lacks any of
        the ``inputs``, ``outputs``, ``run_facets`` or ``job_facets`` attributes.
        """
        try:
            return OperatorLineage(
                inputs=task_metadata.inputs,
                outputs=task_metadata.outputs,
                run_facets=task_metadata.run_facets,
                job_facets=task_metadata.job_facets,
            )
        except AttributeError:
            self.log.warning("Extractor returns non-valid metadata: %s", task_metadata)
            return None

Was this entry helpful?