# Source code for airflow.providers.informatica.hooks.edc
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import base64
import re
from collections.abc import Mapping, MutableMapping
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from urllib.parse import urlencode

from requests.auth import HTTPBasicAuth
from requests.exceptions import RequestException

from airflow.providers.common.compat.sdk import conf
from airflow.providers.http.hooks.http import HttpHook

if TYPE_CHECKING:
    from requests import Response

    from airflow.providers.common.compat.sdk import Connection
class InformaticaEDCError(RuntimeError):
    """
    Raised when the Informatica Enterprise Data Catalog API returns an error.

    Also raised for transport-level failures (connection errors, timeouts) and for
    invalid client-side requests detected before any call is made.
    """
@dataclass(frozen=True)
class InformaticaConnectionConfig:
    """
    Container for Informatica EDC connection settings.

    Instances are immutable and are built from an Airflow connection by
    ``InformaticaEDCHook._build_connection_config``.
    """

    # Scheme + host (+ port) of the EDC service, without a trailing slash.
    base_url: str
    # Login used for Basic authentication; ``None`` disables the Authorization header.
    username: str | None
    # Kept out of ``repr`` so the credential is never written to logs.
    password: str | None = field(repr=False)
    # Optional security domain, sent as ``DOMAIN\user`` in the Basic credential.
    security_domain: str | None
    # Whether TLS certificates are verified.
    verify_ssl: bool
    # Per-request timeout passed to ``requests``.
    request_timeout: int
    # ``providerId`` sent with lineage updates.
    provider_id: str
    # ``modifiedBy`` sent with lineage updates.
    modified_by: str

    @property
    def auth_header(self) -> str | None:
        """
        Return the Basic authorization header for the configured credentials.

        :return: ``"Basic <base64>"`` built from ``[domain\\]username:password``
            (a missing password is treated as empty), or ``None`` when no username is set.
        """
        if not self.username:
            return None
        domain_prefix = f"{self.security_domain}\\" if self.security_domain else ""
        credential = f"{domain_prefix}{self.username}:{self.password or ''}"
        token = base64.b64encode(credential.encode("utf-8")).decode("utf-8")
        return f"Basic {token}"
class InformaticaEDCHook(HttpHook):
    """
    Hook providing a minimal client for the Informatica EDC REST API.

    Connection settings are read lazily from the Airflow connection and cached on first
    access via :attr:`config`.

    :param informatica_edc_conn_id: Airflow connection ID. Defaults to the
        ``[informatica] default_conn_id`` option, falling back to ``informatica_edc_default``.
    :param request_timeout: Timeout for each HTTP request. When not given (or falsy), the
        ``[informatica] request_timeout`` option is used, falling back to 30.
    """

    default_conn_name = conf.get("informatica", "default_conn_id", fallback="informatica_edc_default")
    # Association ID used for lineage links created by ``create_lineage_link``.
    _lineage_association = "core.DataSetDataFlow"
    # Characters outside this set are escaped by ``_encode_id``; compiled once per class.
    _id_special_chars = re.compile(r"([^a-zA-Z0-9_-])")

    def __init__(
        self,
        informatica_edc_conn_id: str = default_conn_name,
        *,
        request_timeout: int | None = None,
        **kwargs,
    ) -> None:
        super().__init__(http_conn_id=informatica_edc_conn_id, method="GET", **kwargs)
        self._config: InformaticaConnectionConfig | None = None
        # ``or`` (not ``is None``): an explicit 0 also falls back to the configured value.
        self._request_timeout = request_timeout or conf.getint("informatica", "request_timeout", fallback=30)

    @property
    def config(self) -> InformaticaConnectionConfig:
        """Return the connection configuration, building and caching it on first access."""
        if self._config is None:
            connection = self.get_connection(self.http_conn_id)
            self._config = self._build_connection_config(connection)
        return self._config

    def _build_connection_config(self, connection: Connection) -> InformaticaConnectionConfig:
        """
        Build a configuration object from an Airflow connection.

        ``host`` may be a bare hostname (combined with ``schema``, default ``https``) or a
        full ``http://``/``https://`` URL; ``port``, when set, is appended either way.
        Recognised extras: ``verify_ssl`` (or ``verify``), ``provider_id`` (default
        ``enrichment``), ``modified_by`` (default: login, then ``airflow``) and
        ``security_domain`` (or ``domain``).
        """
        host = connection.host or ""
        schema = connection.schema or "https"
        base_url = host if host.startswith(("http://", "https://")) else f"{schema}://{host}"
        if connection.port:
            base_url = f"{base_url}:{connection.port}"

        extras: MutableMapping[str, Any] = connection.extra_dejson or {}
        # Accept booleans or string spellings; only 0/false/no (any case) disable verification.
        verify_ssl_raw = extras.get("verify_ssl", extras.get("verify", True))
        verify_ssl = str(verify_ssl_raw).lower() not in {"0", "false", "no"}
        provider_id = str(extras.get("provider_id", "enrichment"))
        modified_by = str(extras.get("modified_by", connection.login or "airflow"))
        security_domain = extras.get("security_domain") or extras.get("domain")

        return InformaticaConnectionConfig(
            base_url=base_url.rstrip("/"),
            username=connection.login,
            password=connection.password,
            security_domain=str(security_domain) if security_domain else None,
            verify_ssl=verify_ssl,
            request_timeout=self._request_timeout,
            provider_id=provider_id,
            modified_by=modified_by,
        )

    def get_conn(
        self,
        headers: dict[str, Any] | None = None,
        extra_options: dict[str, Any] | None = None,
    ) -> Any:
        """
        Return a configured session augmented with Informatica specific headers.

        The session verifies TLS according to :attr:`config`, sends and accepts JSON, and,
        when a username is configured, carries a Basic ``Authorization`` header that
        includes the security domain if one is set.
        """
        session = super().get_conn(headers=headers, extra_options=extra_options)
        config = self.config
        session.verify = config.verify_ssl
        session.headers.update({"Accept": "application/json", "Content-Type": "application/json"})
        auth_header = config.auth_header
        if auth_header:
            session.headers["Authorization"] = auth_header
            # requests applies ``session.auth`` after merging headers, so a plain
            # HTTPBasicAuth (which the parent hook may derive from the connection login)
            # would overwrite the header above and drop the security-domain prefix.
            # Exact type check so subclasses like HTTPProxyAuth and custom auth are kept.
            if type(session.auth) is HTTPBasicAuth:
                session.auth = None
        return session

    def _build_url(self, endpoint: str) -> str:
        """Join ``endpoint`` onto the configured base URL, ensuring a single leading slash."""
        endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}"
        return f"{self.config.base_url}{endpoint}"

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Mapping[str, Any] | None = None,
        json: Mapping[str, Any] | None = None,
    ) -> Response:
        """
        Execute an HTTP request against the EDC API.

        :param method: HTTP verb (case-insensitive).
        :param endpoint: Path (optionally with query string) relative to the base URL.
        :param params: Extra query parameters.
        :param json: JSON request body.
        :return: The successful response.
        :raises InformaticaEDCError: On transport errors or a non-2xx/3xx status.
        """
        url = self._build_url(endpoint)
        session = self.get_conn()
        try:
            response = session.request(
                method=method.upper(),
                url=url,
                params=params,
                json=json,
                timeout=self.config.request_timeout,
            )
        except RequestException as exc:
            raise InformaticaEDCError(f"Failed to call Informatica EDC endpoint {endpoint}: {exc}") from exc
        finally:
            # A session is obtained for every request; close it so its pooled connections
            # are released. The body is already downloaded (non-streaming request).
            session.close()

        if response.ok:
            return response
        message = response.text or response.reason
        raise InformaticaEDCError(
            f"Informatica EDC request to {endpoint} returned {response.status_code}: {message}"
        )

    def _encode_id(self, object_id: str, tilde: bool = False) -> str:
        """
        Encode an object ID for safe use in EDC URLs.

        ``:___`` sequences are first turned back into ``://``. Every character outside
        ``[a-zA-Z0-9_-]`` is then replaced using its UTF-8 bytes in lowercase hex: as
        ``~<hex>~`` when ``tilde`` is true, otherwise as one ``%hh`` triplet per byte.
        """
        object_id = object_id.replace(":___", "://")

        def _escape(match: re.Match[str]) -> str:
            raw = match.group(0).encode("utf-8")
            if tilde:
                return f"~{raw.hex()}~"
            # Percent-encoding needs a '%' before every byte; "%c3a9" would be malformed.
            return "".join(f"%{byte:02x}" for byte in raw)

        return self._id_special_chars.sub(_escape, object_id)

    def get_object(self, object_id: str, include_ref_objects: bool = False) -> dict[str, Any]:
        """
        Retrieve a catalog object by its identifier.

        :param object_id: Catalog object ID; tilde-encoded before being placed in the URL.
        :param include_ref_objects: Value of the ``includeRefObjects`` query parameter.
        :return: The decoded JSON response.
        """
        encoded_object_id = self._encode_id(object_id, tilde=True)
        include_refs = "true" if include_ref_objects else "false"
        endpoint = f"/access/2/catalog/data/objects/{encoded_object_id}?includeRefObjects={include_refs}"
        return self._request("GET", endpoint).json()

    def _search(self, **fq_parts: str) -> dict:
        """
        Execute a catalog data search with the given ``fq`` filter parts.

        Each keyword becomes one ``fq=<key>:<value>`` query parameter, added after a fixed
        set of search options.
        """
        params: list[tuple[str, str]] = [
            ("defaultFacets", "true"),
            ("disableSemanticSearch", "false"),
            ("enableLegacySearch", "false"),
            ("facet", "false"),
            ("fl", "core.name"),
            ("highlight", "false"),
            ("includeRefObjects", "false"),
        ]
        params.extend(("fq", f"{key}:{value}") for key, value in fq_parts.items())
        query_string = urlencode(params)
        response = self._request("GET", f"/access/2/catalog/data/search?{query_string}")
        return response.json()

    def _search_first_match(self, class_types: tuple[str, ...], name: str) -> dict:
        """Search ``class_types`` in order; return the first result with hits, else the last result."""
        result: dict = {}
        for class_type in class_types:
            result = self._search(**{"core.classType": class_type, "core.name": name})
            if result.get("hits"):
                break
        return result

    def search_database(self, database_name: str) -> dict:
        """Search for a relational Database object by name, falling back to DatabaseServer."""
        return self._search_first_match(
            ("com.infa.ldm.relational.Database", "com.infa.ldm.relational.DatabaseServer"),
            database_name,
        )

    def search_schema(self, schema_name: str) -> dict:
        """Search for a relational Schema object by name, falling back to DatabaseSchema."""
        return self._search_first_match(
            ("com.infa.ldm.relational.Schema", "com.infa.ldm.relational.DatabaseSchema"),
            schema_name,
        )

    def search_table(self, table_name: str) -> dict:
        """Search for a relational Table or View object by name."""
        # ``_search`` renders each part as ``key:value``; embedding a second clause in the value
        # yields one filter "core.classType:...Table OR core.classType:...View".
        return self._search(
            **{
                "core.classType": "com.infa.ldm.relational.Table OR core.classType:com.infa.ldm.relational.View",
                "core.name": table_name,
            }
        )

    def create_lineage_link(self, source_object_id: str, target_object_id: str) -> dict[str, Any]:
        """
        Create a lineage relationship between source and target objects.

        Sends a PATCH that adds ``source_object_id`` as a new source link of
        ``target_object_id``.

        :return: The decoded JSON response, or ``{}`` when the response has no body.
        :raises InformaticaEDCError: If both IDs are equal, or the request fails.
        """
        if source_object_id == target_object_id:
            raise InformaticaEDCError(
                "Source and target object identifiers must differ when creating lineage."
            )
        payload = {
            "providerId": self.config.provider_id,
            "modifiedBy": self.config.modified_by,
            "updates": [
                {
                    "id": target_object_id,
                    "newSourceLinks": [
                        {
                            "objectId": source_object_id,
                            "associationId": self._lineage_association,
                            "properties": [
                                {
                                    "attrUuid": "core.targetAttribute",
                                    "value": self._lineage_association,
                                }
                            ],
                        }
                    ],
                    "deleteSourceLinks": [],
                    "newFacts": [],
                    "deleteFacts": [],
                }
            ],
        }
        response = self._request("PATCH", "/access/1/catalog/data/objects", json=payload)
        return response.json() if response.content else {}