# Source code for airflow.providers.databricks.sensors.databricks_partition

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""This module contains Databricks sensors."""

from __future__ import annotations

from datetime import datetime
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, Sequence

from databricks.sql.utils import ParamEscaper

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
    from airflow.utils.context import Context


class DatabricksPartitionSensor(BaseSensorOperator):
    """
    Sensor to detect the presence of table partitions in Databricks.

    :param databricks_conn_id: Reference to
        :ref:`Databricks connection id<howto/connection:databricks>` (templated),
        defaults to DatabricksSqlHook.default_conn_name.
    :param sql_warehouse_name: Optional name of Databricks SQL warehouse. If not specified,
        ``http_path`` must be provided as described below, defaults to None
    :param http_path: Optional string specifying HTTP path of Databricks SQL warehouse or
        All Purpose cluster. If not specified, it should be either specified in the
        Databricks connection's extra parameters, or ``sql_warehouse_name`` must be specified.
    :param session_configuration: An optional dictionary of Spark session parameters.
        If not specified, it could be specified in the Databricks connection's extra
        parameters, defaults to None
    :param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers
        on every request. (templated).
    :param catalog: An optional initial catalog to use.
        Requires Databricks Runtime version 9.0+ (templated), defaults to ""
    :param schema: An optional initial schema to use.
        Requires Databricks Runtime version 9.0+ (templated), defaults to "default"
    :param table_name: Name of the table to check partitions.
    :param partitions: Name of the partitions to check.
        Example: {"date": "2023-01-03", "name": ["abc", "def"]}
    :param partition_operator: Optional comparison operator for partitions, such as >=.
    :param handler: Handler for DbApiHook.run() to return results, defaults to fetch_all_handler
    :param client_parameters: Additional parameters internal to Databricks SQL connector parameters.
    """

    template_fields: Sequence[str] = (
        "databricks_conn_id",
        "catalog",
        "schema",
        "table_name",
        "partitions",
        "http_headers",
    )

    template_ext: Sequence[str] = (".sql",)
    template_fields_renderers = {"sql": "sql"}

    def __init__(
        self,
        *,
        databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
        http_path: str | None = None,
        sql_warehouse_name: str | None = None,
        session_configuration=None,
        http_headers: list[tuple[str, str]] | None = None,
        catalog: str = "",
        schema: str = "default",
        table_name: str,
        partitions: dict,
        partition_operator: str = "=",
        handler: Callable[[Any], Any] = fetch_all_handler,
        client_parameters: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        self.databricks_conn_id = databricks_conn_id
        self._http_path = http_path
        self._sql_warehouse_name = sql_warehouse_name
        self.session_config = session_configuration
        self.http_headers = http_headers
        self.catalog = catalog
        self.schema = schema
        self.caller = "DatabricksPartitionSensor"
        self.partitions = partitions
        self.partition_operator = partition_operator
        self.table_name = table_name
        self.client_parameters = client_parameters or {}
        # ``hook_params`` is consumed here so it is not forwarded to the base operator.
        self.hook_params = kwargs.pop("hook_params", {})
        self.handler = handler
        self.escaper = ParamEscaper()
        super().__init__(**kwargs)

    def _failure(self, message: str) -> AirflowException:
        """
        Build the exception used to report a sensor failure.

        :param message: Human readable failure description.
        :return: ``AirflowSkipException`` when ``soft_fail`` is set, otherwise ``AirflowException``.
        """
        # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
        if self.soft_fail:
            return AirflowSkipException(message)
        return AirflowException(message)

    def _sql_sensor(self, sql):
        """
        Execute the supplied SQL statement using the hook object.

        :param sql: SQL statement to run.
        :return: Whatever ``self.handler`` produces for the statement.
        """
        hook = self._get_hook
        # The handler is always supplied: the sensor itself consumes the result (partition
        # metadata lookup and the final presence check), so it must not depend on whether
        # the task pushes to XCom. Without a handler the hook returns no rows.
        sql_result = hook.run(sql, handler=self.handler)
        self.log.debug("SQL result: %s", sql_result)
        return sql_result

    @cached_property
    def _get_hook(self) -> DatabricksSqlHook:
        """Create and return a DatabricksSqlHook object (built once per sensor instance)."""
        return DatabricksSqlHook(
            self.databricks_conn_id,
            self._http_path,
            self._sql_warehouse_name,
            self.session_config,
            self.http_headers,
            self.catalog,
            self.schema,
            self.caller,
            **self.client_parameters,
            **self.hook_params,
        )

    def _check_table_partitions(self) -> list:
        """
        Generate the fully qualified table name, generate partition, and call the _sql_sensor method.

        :return: Result rows of the partition presence query.
        """
        # A name starting with the "delta." prefix is used verbatim; anything else is
        # qualified with the configured catalog and schema.
        if self.table_name.split(".")[0] == "delta":
            _fully_qualified_table_name = self.table_name
        else:
            _fully_qualified_table_name = f"{self.catalog}.{self.schema}.{self.table_name}"
        self.log.debug("Table name generated from arguments: %s", _fully_qualified_table_name)

        partition_sql = self._generate_partition_query(
            prefix=f"SELECT 1 FROM {_fully_qualified_table_name} WHERE",
            suffix=" LIMIT 1",
            joiner_val=" AND ",
            opts=self.partitions,
            table_name=_fully_qualified_table_name,
            escape_key=False,
        )
        return self._sql_sensor(partition_sql)

    def _generate_partition_query(
        self,
        prefix: str,
        suffix: str,
        joiner_val: str,
        table_name: str,
        opts: dict[str, str] | None = None,
        escape_key: bool = False,
    ) -> str:
        """
        Query the table for available partitions.

        Generates the SQL query based on the partition data types.
            * For a list, it prepares the SQL in the format:
                column_name in (value1, value2,...)
            * For a numeric type, it prepares the format:
                column_name =(or other provided operator such as >=) value
            * For a date type, it prepares the format:
                column_name =(or other provided operator such as >=) value
        Once the filter predicates have been generated like above, the query
        is prepared to be executed using the prefix and suffix supplied, which are:
        "SELECT 1 FROM {_fully_qualified_table_name} WHERE" and "LIMIT 1".

        :param prefix: Text placed before the generated predicates.
        :param suffix: Text placed after the generated predicates.
        :param joiner_val: Separator placed between individual predicates.
        :param table_name: Table whose partition columns are looked up.
        :param opts: Mapping of partition column name to the value(s) to look for.
        :param escape_key: Whether to pass column names through the parameter escaper.
        :return: The assembled SQL statement, stripped of surrounding whitespace.
        :raises AirflowSkipException: On any failure below when ``soft_fail`` is set.
        :raises AirflowException: If the table has no partitions, no partitions were
            requested, a requested column is not a partition column, or a list of
            values is empty.
        """
        # Column index 7 of the first DESCRIBE DETAIL row is read as the list of
        # partition columns -- presumably ``partitionColumns``; confirm against the
        # Databricks DESCRIBE DETAIL schema.
        partition_columns = self._sql_sensor(f"DESCRIBE DETAIL {table_name}")[0][7]
        self.log.debug("Partition columns: %s", partition_columns)
        if not partition_columns:
            raise self._failure(f"Table {table_name} does not have partitions")

        if not opts:
            raise self._failure("No partitions specified to check with the sensor.")

        output_list = []
        for partition_col, partition_value in opts.items():
            if escape_key:
                partition_col = self.escaper.escape_item(partition_col)
            if partition_col not in partition_columns:
                raise self._failure(
                    f"Column {partition_col} not part of table partitions: {partition_columns}"
                )

            if isinstance(partition_value, list):
                if not partition_value:
                    # An empty IN () list is not valid SQL; fail with a clear message instead.
                    raise self._failure(f"No values specified for partition column {partition_col}")
                # Escape every element individually. Formatting via ``tuple(...)`` would emit
                # a trailing comma for single-element lists and would bypass SQL escaping.
                escaped_values = ", ".join(str(self.escaper.escape_item(value)) for value in partition_value)
                output_list.append(f"{partition_col} in ({escaped_values})")
                self.log.debug("List formatting for partitions: %s", output_list)
            elif isinstance(partition_value, (int, float, complex, str, datetime)):
                output_list.append(
                    f"{partition_col}{self.partition_operator}{self.escaper.escape_item(partition_value)}"
                )
            else:
                # Unsupported value types contribute no predicate (unchanged behaviour),
                # but are now reported instead of being dropped silently.
                self.log.warning(
                    "Ignoring partition %s: unsupported value type %s",
                    partition_col,
                    type(partition_value).__name__,
                )

        formatted_opts = f"{prefix} {joiner_val.join(output_list)} {suffix}"
        self.log.debug("Formatted options: %s", formatted_opts)
        return formatted_opts.strip()

    def poke(self, context: Context) -> bool:
        """
        Check the table partitions and return the results.

        :param context: Airflow task context (not used by the check itself).
        :return: ``True`` when the partition query returned at least one row.
        :raises AirflowSkipException: If no matching partition was found and ``soft_fail`` is set.
        :raises AirflowException: If no matching partition was found.
        """
        partition_result = self._check_table_partitions()
        self.log.debug("Partition sensor result: %s", partition_result)
        if partition_result:
            return True
        raise self._failure(f"Specified partition(s): {self.partitions} were not found.")

Was this entry helpful?