#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""This module contains Databricks sensors."""
from __future__ import annotations
from collections.abc import Sequence
from datetime import datetime
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable
from databricks.sql.utils import ParamEscaper
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
    try:
        from airflow.sdk.definitions.context import Context
    except ImportError:
        # TODO: Remove once provider drops support for Airflow 2
        from airflow.utils.context import Context
[docs]
class DatabricksPartitionSensor(BaseSensorOperator):
    """
    Sensor to detect the presence of table partitions in Databricks.
    :param databricks_conn_id: Reference to :ref:`Databricks
        connection id<howto/connection:databricks>` (templated), defaults to
        DatabricksSqlHook.default_conn_name.
    :param sql_warehouse_name: Optional name of Databricks SQL warehouse. If not specified, ``http_path``
        must be provided as described below, defaults to None
    :param http_path: Optional string specifying HTTP path of Databricks SQL warehouse or All Purpose cluster.
        If not specified, it should be either specified in the Databricks connection's
        extra parameters, or ``sql_warehouse_name`` must be specified.
    :param session_configuration: An optional dictionary of Spark session parameters. If not specified,
        it could be specified in the Databricks connection's extra parameters, defaults to None
    :param http_headers: An optional list of (k, v) pairs
        that will be set as HTTP headers on every request. (templated).
    :param catalog: An optional initial catalog to use.
        Requires Databricks Runtime version 9.0+ (templated), defaults to ""
    :param schema: An optional initial schema to use.
        Requires Databricks Runtime version 9.0+ (templated), defaults to "default"
    :param table_name: Name of the table to check partitions.
    :param partitions: Name of the partitions to check.
        Example: {"date": "2023-01-03", "name": ["abc", "def"]}
    :param partition_operator: Optional comparison operator for partitions, such as >=.
    :param handler: Handler for DbApiHook.run() to return results, defaults to fetch_all_handler
    :param client_parameters: Additional parameters internal to Databricks SQL connector parameters.
    """
[docs]
    template_fields: Sequence[str] = (
        "databricks_conn_id",
        "catalog",
        "schema",
        "table_name",
        "partitions",
        "http_headers",
    ) 
[docs]
    template_ext: Sequence[str] = (".sql",) 
[docs]
    template_fields_renderers = {"sql": "sql"} 
    def __init__(
        self,
        *,
        databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
        http_path: str | None = None,
        sql_warehouse_name: str | None = None,
        session_configuration=None,
        http_headers: list[tuple[str, str]] | None = None,
        catalog: str = "",
        schema: str = "default",
        table_name: str,
        partitions: dict,
        partition_operator: str = "=",
        handler: Callable[[Any], Any] = fetch_all_handler,
        client_parameters: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
[docs]
        self.databricks_conn_id = databricks_conn_id 
        self._http_path = http_path
        self._sql_warehouse_name = sql_warehouse_name
[docs]
        self.session_config = session_configuration 
[docs]
        self.caller = "DatabricksPartitionSensor" 
[docs]
        self.partitions = partitions 
[docs]
        self.partition_operator = partition_operator 
[docs]
        self.table_name = table_name 
[docs]
        self.client_parameters = client_parameters or {} 
[docs]
        self.hook_params = kwargs.pop("hook_params", {}) 
[docs]
        self.escaper = ParamEscaper() 
        super().__init__(**kwargs)
    def _sql_sensor(self, sql):
        """Execute the supplied SQL statement using the hook object."""
        hook = self._get_hook
        sql_result = hook.run(
            sql,
            handler=self.handler if self.do_xcom_push else None,
        )
        self.log.debug("SQL result: %s", sql_result)
        return sql_result
    @cached_property
    def _get_hook(self) -> DatabricksSqlHook:
        """Create and return a DatabricksSqlHook object."""
        return DatabricksSqlHook(
            self.databricks_conn_id,
            self._http_path,
            self._sql_warehouse_name,
            self.session_config,
            self.http_headers,
            self.catalog,
            self.schema,
            self.caller,
            **self.client_parameters,
            **self.hook_params,
        )
    def _check_table_partitions(self) -> list:
        """Generate the fully qualified table name, generate partition, and call the _sql_sensor method."""
        if self.table_name.split(".")[0] == "delta":
            _fully_qualified_table_name = self.table_name
        else:
            _fully_qualified_table_name = f"{self.catalog}.{self.schema}.{self.table_name}"
        self.log.debug("Table name generated from arguments: %s", _fully_qualified_table_name)
        _joiner_val = " AND "
        _prefix = f"SELECT 1 FROM {_fully_qualified_table_name} WHERE"
        _suffix = " LIMIT 1"
        partition_sql = self._generate_partition_query(
            prefix=_prefix,
            suffix=_suffix,
            joiner_val=_joiner_val,
            opts=self.partitions,
            table_name=_fully_qualified_table_name,
            escape_key=False,
        )
        return self._sql_sensor(partition_sql)
    def _generate_partition_query(
        self,
        prefix: str,
        suffix: str,
        joiner_val: str,
        table_name: str,
        opts: dict[str, str] | None = None,
        escape_key: bool = False,
    ) -> str:
        """
        Query the table for available partitions.
        Generates the SQL query based on the partition data types.
            * For a list, it prepares the SQL in the format:
                column_name in (value1, value2,...)
            * For a numeric type, it prepares the format:
                column_name =(or other provided operator such as >=) value
            * For a date type, it prepares the format:
                column_name =(or other provided operator such as >=) value
        Once the filter predicates have been generated like above, the query
        is prepared to be executed using the prefix and suffix supplied, which are:
        "SELECT 1 FROM {_fully_qualified_table_name} WHERE" and "LIMIT 1".
        """
        partition_columns = self._sql_sensor(f"DESCRIBE DETAIL {table_name}")[0][7]
        self.log.debug("Partition columns: %s", partition_columns)
        if len(partition_columns) < 1:
            message = f"Table {table_name} does not have partitions"
            raise AirflowException(message)
        formatted_opts = ""
        if opts:
            output_list = []
            for partition_col, partition_value in opts.items():
                if escape_key:
                    partition_col = self.escaper.escape_item(partition_col)
                if partition_col in partition_columns:
                    if isinstance(partition_value, list):
                        output_list.append(f"""{partition_col} in {tuple(partition_value)}""")
                        self.log.debug("List formatting for partitions: %s", output_list)
                    if isinstance(partition_value, (int, float, complex)):
                        output_list.append(
                            f"""{partition_col}{self.partition_operator}{self.escaper.escape_item(partition_value)}"""
                        )
                    if isinstance(partition_value, (str, datetime)):
                        output_list.append(
                            f"""{partition_col}{self.partition_operator}{self.escaper.escape_item(partition_value)}"""
                        )
                else:
                    message = f"Column {partition_col} not part of table partitions: {partition_columns}"
                    raise AirflowException(message)
        else:
            # Raises exception if the table does not have any partitions.
            message = "No partitions specified to check with the sensor."
            raise AirflowException(message)
        formatted_opts = f"{prefix} {joiner_val.join(output_list)} {suffix}"
        self.log.debug("Formatted options: %s", formatted_opts)
        return formatted_opts.strip()
[docs]
    def poke(self, context: Context) -> bool:
        """Check the table partitions and return the results."""
        partition_result = self._check_table_partitions()
        self.log.debug("Partition sensor result: %s", partition_result)
        if partition_result:
            return True
        else:
            message = f"Specified partition(s): {self.partitions} were not found."
            raise AirflowException(message)