# Source code for airflow.providers.informatica.lineage.sql_parser
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from dataclasses import dataclass
import sqlglot
import sqlglot.expressions as exp
log = logging.getLogger(__name__)

# Substrings that indicate a SQL string still contains unrendered Jinja:
# expression ``{{ ... }}``, statement ``{% ... %}`` and comment ``{# ... #}`` blocks.
# Such text is not valid SQL, so lineage extraction skips it.
_JINJA_MARKERS = ("{{", "{%", "{#")
@dataclass
[docs]
class TableRef:
"""Represents a parsed table reference with optional schema and database qualifiers."""
[docs]
schema: str | None = None
[docs]
database: str | None = None
def parse_sql_tables(
    sql: str | list[str] | None,
    dialect: str | None = None,
) -> tuple[list[TableRef], list[TableRef]]:
    """
    Parse SQL and return ``(source_tables, target_tables)``.

    Source tables are those read by FROM/JOIN clauses.
    Target tables are those written by INSERT INTO, CREATE TABLE AS, or MERGE INTO.
    Returns empty lists when SQL cannot be parsed instead of raising.

    :param sql: One SQL string or a list of SQL strings. A single string may
        contain several statements; every expression sqlglot yields is
        processed. ``None`` or an empty value produces ``([], [])``.
        Non-string or blank entries in a list are ignored.
    :param dialect: Optional sqlglot dialect name passed through to the parser.
    :return: ``(sources, targets)`` with duplicate references removed,
        keeping first-seen order.
    """
    # Tolerate a missing SQL value instead of failing on ``enumerate(None)``:
    # this function is documented as never raising on unusable input.
    if not sql:
        return [], []
    statements = [sql] if isinstance(sql, str) else sql
    sources: list[TableRef] = []
    targets: list[TableRef] = []
    for idx, stmt in enumerate(statements):
        if not isinstance(stmt, str) or not stmt.strip():
            continue
        # Unrendered templates are not valid SQL; parsing them would fail or
        # yield bogus table names.
        if any(marker in stmt for marker in _JINJA_MARKERS):
            log.debug(
                "SQL statement %d contains unrendered Jinja templates; skipping lineage extraction.", idx
            )
            continue
        try:
            for parsed in sqlglot.parse(stmt, dialect=dialect, error_level=sqlglot.ErrorLevel.WARN):
                # sqlglot.parse can yield None entries; the isinstance check
                # filters those as well as any non-expression result.
                if not isinstance(parsed, exp.Expression):
                    continue
                stmt_sources, stmt_targets = _extract_tables(parsed)
                sources.extend(stmt_sources)
                targets.extend(stmt_targets)
        except Exception:
            # Lineage is best-effort: log and continue with the next statement
            # rather than propagating the error.
            log.debug("Failed to parse SQL statement %d", idx, exc_info=True)
    return _dedup(sources), _dedup(targets)
def _extract_tables(parsed: exp.Expression) -> tuple[list[TableRef], list[TableRef]]:
    """
    Split the table references of one parsed statement into sources and targets.

    The write target of a top-level INSERT, CREATE or MERGE statement is
    reported as a target. Every other physical table reference is reported as
    a source. References to CTEs defined in the statement are skipped.

    :param parsed: A single statement expression produced by ``sqlglot.parse``.
    :return: ``(sources, targets)``; the lists may contain duplicates.
    """
    # Collect CTE names from every WITH clause, including those in nested
    # subqueries, not just the first one found. Lower-cased for a
    # case-insensitive match.
    cte_names: set[str] = {cte.alias.lower() for cte in parsed.find_all(exp.CTE) if cte.alias}

    # Identify the write target by node identity. A table that is both read
    # and written (INSERT INTO t SELECT ... FROM t) is then reported on both
    # sides, because the read is a different node.
    target_node_id: int | None = None
    if isinstance(parsed, (exp.Insert, exp.Create, exp.Merge)):
        write_target = _get_write_target(parsed)
        if write_target is not None:
            target_node_id = id(write_target)

    sources: list[TableRef] = []
    targets: list[TableRef] = []
    for table in parsed.find_all(exp.Table):
        name = table.name
        # Table nodes without a plain identifier name (in sqlglot, e.g. a
        # table-valued function in FROM) do not name a physical table.
        if not name:
            continue
        ref = TableRef(
            table=name,
            schema=table.db or None,
            database=table.catalog or None,
        )
        # The write target is always a real table: it can never be a CTE.
        if id(table) == target_node_id:
            targets.append(ref)
            continue
        # CTE references are unqualified. ``raw.orders`` is a real table even
        # when a CTE named ``orders`` exists, so only bare names are filtered.
        if ref.schema is None and ref.database is None and name.lower() in cte_names:
            continue
        sources.append(ref)
    return sources, targets
def _get_write_target(node: exp.Expression) -> exp.Table | None:
    """
    Return the table node written by an INSERT, CREATE or MERGE statement.

    :param node: A parsed statement expression.
    :return: The target ``exp.Table`` node, or ``None`` if ``node`` is not one
        of those statement kinds or its target is not a plain table.
    """
    if not isinstance(node, (exp.Insert, exp.Create, exp.Merge)):
        return None
    # All three supported statement kinds keep their write target in ``this``.
    candidate = node.this
    # INSERT INTO target(col1, col2) ... is represented as Schema(Table(...), ...);
    # unwrap the Schema to reach the table itself.
    if isinstance(candidate, exp.Schema):
        candidate = candidate.this
    return candidate if isinstance(candidate, exp.Table) else None
def _dedup(refs: list[TableRef]) -> list[TableRef]:
    """
    Drop repeated table references while keeping first-seen order.

    Two references are duplicates when their ``(database, schema, table)``
    triples are equal. The first occurrence of each triple is kept.

    :param refs: Table references, possibly containing repeats.
    :return: A new list with one entry per distinct triple.
    """
    # Dicts preserve insertion order, and setdefault never overwrites an
    # existing key, so the first reference seen for each triple wins.
    first_by_key: dict[tuple[str | None, str | None, str], TableRef] = {}
    for candidate in refs:
        first_by_key.setdefault((candidate.database, candidate.schema, candidate.table), candidate)
    return list(first_by_key.values())