Source code for airflow.providers.openlineage.utils.utils

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import datetime
import json
import logging
from contextlib import suppress
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable

import attrs
from openlineage.client.utils import RedactMixin  # TODO: move this maybe to Airflow's logic?

from airflow.models import DAG, BaseOperator, MappedOperator
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.plugins.facets import (
    AirflowMappedTaskRunFacet,
    AirflowRunFacet,
    UnknownOperatorAttributeRunFacet,
    UnknownOperatorInstance,
)
from airflow.providers.openlineage.utils.selective_enable import (
    is_dag_lineage_enabled,
    is_task_lineage_enabled,
)
from airflow.utils.context import AirflowContextDeprecationWarning
from airflow.utils.log.secrets_masker import Redactable, Redacted, SecretsMasker, should_hide_value_for_key

if TYPE_CHECKING:
    from airflow.models import DagRun, TaskInstance


log = logging.getLogger(__name__)
_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


def get_operator_class(task: BaseOperator) -> type:
    if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"):
        return task.operator_class
    return task.__class__
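

# A minimal sketch (illustrative, not part of the upstream module): only the
# class *name* of the wrapper is inspected, so a stand-in class named like
# Airflow's mapped wrapper is unwrapped via its ``operator_class`` attribute.
def _example_get_operator_class():  # hypothetical helper, never called here
    class MappedOperator:  # shares the name the function special-cases
        operator_class = dict  # hypothetical wrapped operator class

    assert get_operator_class(MappedOperator()) is dict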


def get_job_name(task: TaskInstance) -> str:
    return f"{task.dag_id}.{task.task_id}"
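

# Quick sketch (illustrative): any object exposing ``dag_id`` and ``task_id``
# yields "<dag_id>.<task_id>"; the attribute values below are hypothetical.
def _example_get_job_name():  # hypothetical helper, never called here
    import types

    ti = types.SimpleNamespace(dag_id="example_dag", task_id="extract")
    assert get_job_name(ti) == "example_dag.extract"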


def get_custom_facets(task_instance: TaskInstance | None = None) -> dict[str, Any]:
    custom_facets = {}
    # The check for map_index != -1 comes from SmartSensor compatibility with
    # dynamic task mapping; this mirrors Airflow core code.
    if hasattr(task_instance, "map_index") and getattr(task_instance, "map_index") != -1:
        custom_facets["airflow_mappedTask"] = AirflowMappedTaskRunFacet.from_task_instance(task_instance)
    return custom_facets
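

# Sketch of the mapped-task branch (illustrative): a non-mapped task instance
# (map_index == -1) produces no custom facets, while a mapped one
# (map_index >= 0) would additionally carry an "airflow_mappedTask" facet.
def _example_get_custom_facets():  # hypothetical helper, never called here
    import types

    plain_ti = types.SimpleNamespace(map_index=-1)
    assert get_custom_facets(plain_ti) == {}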


def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> str:
    return operator.__class__.__module__ + "." + operator.__class__.__name__


def is_operator_disabled(operator: BaseOperator | MappedOperator) -> bool:
    return get_fully_qualified_class_name(operator) in conf.disabled_operators()


def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator) -> bool:
    """If selective enable is active, check whether the DAG or task is enabled to emit events."""
    if not conf.selective_enable():
        return True
    if isinstance(obj, DAG):
        return is_dag_lineage_enabled(obj)
    elif isinstance(obj, (BaseOperator, MappedOperator)):
        return is_task_lineage_enabled(obj)
    else:
        raise TypeError("is_selective_lineage_enabled can only be used on DAG or Operator objects")


class InfoJsonEncodable(dict):
    """
    Airflow objects might not be json-encodable overall.

    The class provides additional attributes to control
    what and how is encoded:

    * renames: a dictionary of attribute name changes
    * casts: a dictionary consisting of attribute names
      and corresponding methods that should change
      object value
    * includes: list of attributes to be included in encoding
    * excludes: list of attributes to be excluded from encoding

    Don't use both includes and excludes.
    """

    renames: dict[str, str] = {}
    casts: dict[str, Any] = {}
    includes: list[str] = []
    excludes: list[str] = []

    def __init__(self, obj):
        self.obj = obj
        self._fields = []

        self._cast_fields()
        self._rename_fields()
        self._include_fields()

        dict.__init__(
            self,
            **{field: InfoJsonEncodable._cast_basic_types(getattr(self, field)) for field in self._fields},
        )

    @staticmethod
    def _cast_basic_types(value):
        if isinstance(value, datetime.datetime):
            return value.isoformat()
        if isinstance(value, (set, list, tuple)):
            return str(list(value))
        return value

    def _rename_fields(self):
        for field, renamed in self.renames.items():
            if hasattr(self.obj, field):
                setattr(self, renamed, getattr(self.obj, field))
                self._fields.append(renamed)

    def _cast_fields(self):
        for field, func in self.casts.items():
            setattr(self, field, func(self.obj))
            self._fields.append(field)

    def _include_fields(self):
        if self.includes and self.excludes:
            raise ValueError("Don't use both includes and excludes.")
        if self.includes:
            for field in self.includes:
                if field not in self._fields and hasattr(self.obj, field):
                    setattr(self, field, getattr(self.obj, field))
                    self._fields.append(field)
        else:
            for field, val in self.obj.__dict__.items():
                if field not in self._fields and field not in self.excludes and field not in self.renames:
                    setattr(self, field, val)
                    self._fields.append(field)
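

# A minimal sketch (illustrative, not part of the upstream module) of how
# ``casts``, ``renames`` and ``includes`` combine; ``ExampleInfo`` and the
# attribute names are made up for the demonstration.
def _example_info_json_encodable():  # hypothetical helper, never called here
    import types

    class ExampleInfo(InfoJsonEncodable):
        casts = {"size": lambda o: len(o.items)}  # derived value, computed first
        renames = {"_id": "id"}  # private attribute exposed under a new key
        includes = ["name"]  # attribute copied verbatim

    obj = types.SimpleNamespace(_id=7, items=[1, 2, 3], name="demo")
    assert ExampleInfo(obj) == {"size": 3, "id": 7, "name": "demo"}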


class DagInfo(InfoJsonEncodable):
    """Defines encoding DAG object to JSON."""

    includes = ["dag_id", "schedule_interval", "tags", "start_date"]
    casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None}
    renames = {"_dag_id": "dag_id"}


class DagRunInfo(InfoJsonEncodable):
    """Defines encoding DagRun object to JSON."""

    includes = [
        "conf",
        "dag_id",
        "data_interval_start",
        "data_interval_end",
        "external_trigger",
        "run_id",
        "run_type",
        "start_date",
    ]


class TaskInstanceInfo(InfoJsonEncodable):
    """Defines encoding TaskInstance object to JSON."""

    includes = ["duration", "try_number", "pool"]
    casts = {
        "map_index": lambda ti: ti.map_index
        if hasattr(ti, "map_index") and getattr(ti, "map_index") != -1
        else None
    }


class TaskInfo(InfoJsonEncodable):
    """Defines encoding BaseOperator/AbstractOperator object to JSON."""

    renames = {
        "_BaseOperator__from_mapped": "mapped",
        "_downstream_task_ids": "downstream_task_ids",
        "_upstream_task_ids": "upstream_task_ids",
        "_is_setup": "is_setup",
        "_is_teardown": "is_teardown",
    }
    includes = [
        "depends_on_past",
        "downstream_task_ids",
        "execution_timeout",
        "executor_config",
        "ignore_first_depends_on_past",
        "max_active_tis_per_dag",
        "max_active_tis_per_dagrun",
        "max_retry_delay",
        "multiple_outputs",
        "owner",
        "priority_weight",
        "queue",
        "retries",
        "retry_exponential_backoff",
        "run_as_user",
        "task_id",
        "trigger_rule",
        "upstream_task_ids",
        "wait_for_downstream",
        "wait_for_past_depends_before_skipping",
        "weight_rule",
    ]
    casts = {
        "operator_class": lambda task: task.task_type,
        "task_group": lambda task: TaskGroupInfo(task.task_group)
        if hasattr(task, "task_group") and getattr(task.task_group, "_group_id", None)
        else None,
    }


class TaskGroupInfo(InfoJsonEncodable):
    """Defines encoding TaskGroup object to JSON."""

    renames = {
        "_group_id": "group_id",
    }
    includes = [
        "downstream_group_ids",
        "downstream_task_ids",
        "prefix_group_id",
        "tooltip",
        "upstream_group_ids",
        "upstream_task_ids",
    ]


def get_airflow_run_facet(
    dag_run: DagRun,
    dag: DAG,
    task_instance: TaskInstance,
    task: BaseOperator,
    task_uuid: str,
):
    return {
        "airflow": attrs.asdict(
            AirflowRunFacet(
                dag=DagInfo(dag),
                dagRun=DagRunInfo(dag_run),
                taskInstance=TaskInstanceInfo(task_instance),
                task=TaskInfo(task),
                taskUuid=task_uuid,
            )
        )
    }


def get_unknown_source_attribute_run_facet(task: BaseOperator, name: str | None = None):
    if not name:
        name = get_operator_class(task).__name__
    return {
        "unknownSourceAttribute": attrs.asdict(
            UnknownOperatorAttributeRunFacet(
                unknownItems=[
                    UnknownOperatorInstance(
                        name=name,
                        properties=TaskInfo(task),
                    )
                ]
            )
        )
    }


class OpenLineageRedactor(SecretsMasker):
    """
    This class redacts sensitive data in a way similar to SecretsMasker in Airflow logs.

    The difference is that our default max recursion depth is way higher - due to
    the structure of OL events we need more depth.
    Additionally, a data structure can exempt fields from redaction by deriving
    from RedactMixin and listing them in its _skip_redact attribute.
    """

    @classmethod
    def from_masker(cls, other: SecretsMasker) -> OpenLineageRedactor:
        instance = cls()
        instance.patterns = other.patterns
        instance.replacer = other.replacer
        return instance

    def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int) -> Redacted:
        if depth > max_depth:
            return item
        try:
            # It's impossible to check the type of a variable in a dict without
            # accessing it, and that access already emits a deprecation warning -
            # so suppress it.
            with suppress(AirflowContextDeprecationWarning):
                if type(item).__name__ == "Proxy":
                    # These are deprecated values in _DEPRECATION_REPLACEMENTS
                    # in airflow.utils.context.Context.
                    return "<<non-redactable: Proxy>>"
                if name and should_hide_value_for_key(name):
                    return self._redact_all(item, depth, max_depth)
                if attrs.has(type(item)):
                    # TODO: FIXME when mypy gets compatible with new attrs
                    for dict_key, subval in attrs.asdict(
                        item,  # type: ignore[arg-type]
                        recurse=False,
                    ).items():
                        if _is_name_redactable(dict_key, item):
                            setattr(
                                item,
                                dict_key,
                                self._redact(subval, name=dict_key, depth=(depth + 1), max_depth=max_depth),
                            )
                    return item
                elif is_json_serializable(item) and hasattr(item, "__dict__"):
                    for dict_key, subval in item.__dict__.items():
                        if type(subval).__name__ == "Proxy":
                            return "<<non-redactable: Proxy>>"
                        if _is_name_redactable(dict_key, item):
                            setattr(
                                item,
                                dict_key,
                                self._redact(subval, name=dict_key, depth=(depth + 1), max_depth=max_depth),
                            )
                    return item
                else:
                    return super()._redact(item, name, depth, max_depth)
        except Exception as exc:
            log.warning("Unable to redact %r. Error was: %s: %s", item, type(exc).__name__, exc)
        return item


def is_json_serializable(item):
    try:
        json.dumps(item)
        return True
    except (TypeError, ValueError):
        return False
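

# Illustrative behavior: built-in containers of primitives serialize, while
# e.g. a set does not (json.dumps raises TypeError for it).
def _example_is_json_serializable():  # hypothetical helper, never called here
    assert is_json_serializable({"a": [1, 2]}) is True
    assert is_json_serializable({1, 2}) is False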


def _is_name_redactable(name, redacted):
    if not issubclass(redacted.__class__, RedactMixin):
        return not name.startswith("_")
    return name not in redacted.skip_redact


def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict:
    not_required_keys = {"dag", "task_group"}
    return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys}
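

# Sketch (illustrative): the "dag" and "task_group" keys are dropped, while
# everything else in the object's ``__dict__`` passes through unchanged.
def _example_get_filtered_unknown_operator_keys():  # hypothetical, never called here
    import types

    op = types.SimpleNamespace(task_id="t1", dag=object(), task_group=object())
    assert get_filtered_unknown_operator_keys(op) == {"task_id": "t1"}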


def normalize_sql(sql: str | Iterable[str]):
    if isinstance(sql, str):
        sql = [stmt for stmt in sql.split(";") if stmt != ""]
    sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
    return ";\n".join(sql)
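

# Illustrative behavior: both a single semicolon-separated string and an
# iterable of statements normalize to one statement per line, joined by ";".
def _example_normalize_sql():  # hypothetical helper, never called here
    assert normalize_sql(["SELECT a FROM t", "SELECT b FROM t"]) == "SELECT a FROM t;\nSELECT b FROM t"
    assert normalize_sql("SELECT 1;SELECT 2;") == "SELECT 1;\nSELECT 2"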


def should_use_external_connection(hook) -> bool:
    # TODO: Add checking overrides
    return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
