Source code for airflow.providers.informatica.lineage.resolver

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any

from airflow.providers.informatica.lineage.sql_parser import TableRef, parse_sql_tables

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# The import is guarded so this module still loads when the common SQL
# operators cannot be imported; the flag records which case applies.
try:
    from airflow.providers.common.sql.operators.sql import BaseSQLOperator as _BaseSQLOperator
except ImportError:
    _BaseSQLOperator = None  # type: ignore[assignment, misc]
    _HAS_BASE_SQL_OPERATOR = False
else:
    _HAS_BASE_SQL_OPERATOR = True

# Attribute names probed, in this order, when looking for a connection ID on
# an operator. The attribute named by ``conn_id_field`` (BaseSQLOperator) is
# consulted before this list; the list is the fallback.
_CONN_ID_ATTRS: tuple[str, ...] = (
    "conn_id",
    "source_conn_id",
    "mysql_conn_id",
    "postgres_conn_id",
    "mssql_conn_id",
    "oracle_conn_id",
    "sqlite_conn_id",
    "snowflake_conn_id",
    "databricks_conn_id",
    "exasol_conn_id",
    "hiveserver2_conn_id",
)

# Keyword fragment (searched for inside a conn_id string) -> sqlglot dialect
# name. Entry order is significant: the lookup returns the first keyword that
# occurs in the connection ID.
_CONN_TYPE_TO_DIALECT: dict[str, str] = {
    "postgres": "postgres",
    "redshift": "redshift",
    "mysql": "mysql",
    "mssql": "tsql",
    "snowflake": "snowflake",
    "bigquery": "bigquery",
    "databricks": "databricks",
    "sqlite": "sqlite",
    "oracle": "oracle",
    "trino": "trino",
    "presto": "presto",
    "hive": "hive",
    "spark": "spark",
}

# Attribute names consulted for an explicit write-target table when SQL
# parsing finds no targets (e.g. GenericTransfer, HiveToMySqlOperator).
_TARGET_TABLE_ATTRS: tuple[str, ...] = (
    "destination_table",
    "mysql_table",
    "hive_table",
    "target_table",
)
class BaseLineageResolver(ABC):
    """Abstract interface shared by all operator lineage resolvers."""

    @abstractmethod
    def resolve(self, task: Any) -> tuple[list[TableRef], list[TableRef]] | None:
        """
        Work out which tables *task* reads from and writes to.

        :param task: the operator/task instance to inspect.
        :return: a ``(source_refs, target_refs)`` pair, or ``None`` when this
            resolver does not apply to the task.
        """
class SQLLineageResolver(BaseLineageResolver):
    """
    Lineage resolver for operators that carry a ``sql`` attribute.

    The SQL dialect is worked out by ``_infer_dialect`` in two tiers:

    - Tier 1: operators inheriting from ``BaseSQLOperator`` expose
      ``conn_id_field``, which names the attribute holding the connection ID.
    - Tier 2: operators that have ``sql`` but no ``BaseSQLOperator`` base
      (e.g. ``GenericTransfer``, ``BaseSQLToGCSOperator``) — the dialect is
      taken from the first recognizable connection ID string found.

    ``resolve`` yields ``None`` when the task has no SQL, or when neither the
    parsed SQL nor the explicit target-table attributes produce any table
    reference. How templated (Jinja) SQL is treated is up to
    ``parse_sql_tables``; this class does not inspect the SQL text itself.
    """

    def resolve(self, task: Any) -> tuple[list[TableRef], list[TableRef]] | None:
        """
        Extract source and target table references from ``task.sql``.

        :param task: operator instance; ``sql``, ``database``, connection ID
            attributes and target-table attributes are read from it when present.
        :return: ``(source_refs, target_refs)``, or ``None`` when the task has
            no SQL or no table reference could be determined.
        """
        statement = getattr(task, "sql", None)
        if not statement:
            return None

        sql_dialect = _infer_dialect(task)
        fallback_db: str | None = getattr(task, "database", None)
        source_refs, target_refs = parse_sql_tables(statement, dialect=sql_dialect)

        if not target_refs:
            # The SQL named no write target: use the first non-empty string
            # found among the well-known target-table attributes, if any.
            explicit_target = next(
                (
                    value
                    for value in (getattr(task, name, None) for name in _TARGET_TABLE_ATTRS)
                    if value and isinstance(value, str)
                ),
                None,
            )
            if explicit_target is not None:
                target_refs.append(TableRef(table=explicit_target))

        if not (source_refs or target_refs):
            return None

        if not fallback_db:
            return source_refs, target_refs

        def with_database(refs: list[TableRef]) -> list[TableRef]:
            # Rebuild each ref, keeping its own database when it has one.
            return [TableRef(ref.table, ref.schema, ref.database or fallback_db) for ref in refs]

        return with_database(source_refs), with_database(target_refs)
def _infer_dialect(task: Any) -> str | None:
    """
    Guess the SQL dialect of *task* from its connection ID attributes.

    The attribute named by ``task.conn_id_field`` (when present) is inspected
    first, followed by every name in ``_CONN_ID_ATTRS`` in order. The first
    non-empty string value that maps to a known dialect wins.

    :param task: operator instance to inspect.
    :return: a dialect name from ``_CONN_TYPE_TO_DIALECT``, or ``None`` when no
        connection ID attribute yields one.
    """
    candidate_attrs: list[str] = []
    conn_id_field = getattr(task, "conn_id_field", None)
    # getattr() raises TypeError for a non-string attribute name, so a
    # non-string ``conn_id_field`` (e.g. on a mock object) is ignored rather
    # than allowed to abort dialect inference.
    if conn_id_field and isinstance(conn_id_field, str):
        candidate_attrs.append(conn_id_field)
    candidate_attrs.extend(_CONN_ID_ATTRS)

    for attr in candidate_attrs:
        conn_id = getattr(task, attr, None)
        if conn_id and isinstance(conn_id, str):
            dialect = _dialect_from_conn_id_str(conn_id)
            if dialect:
                return dialect
    return None


def _dialect_from_conn_id_str(conn_id: str) -> str | None:
    """
    Map a connection ID string to a dialect by case-insensitive substring match.

    Keywords are tried in the insertion order of ``_CONN_TYPE_TO_DIALECT``; the
    first keyword contained in the lower-cased *conn_id* decides the result.

    :param conn_id: connection ID string to examine.
    :return: the matching dialect name, or ``None`` when no keyword occurs.
    """
    conn_id_lower = conn_id.lower()
    for keyword, dialect in _CONN_TYPE_TO_DIALECT.items():
        if keyword in conn_id_lower:
            return dialect
    return None


# Single shared resolver instance handed out by ``get_resolver``.
_SQL_RESOLVER = SQLLineageResolver()
def get_resolver(task: Any) -> BaseLineageResolver | None:
    """
    Pick the lineage resolver that applies to *task*.

    :param task: operator instance to inspect.
    :return: the shared SQL resolver when *task* has a truthy ``sql``
        attribute, otherwise ``None``.
    """
    has_sql = bool(getattr(task, "sql", None))
    return _SQL_RESOLVER if has_sql else None

Was this entry helpful?