Source code for airflow.providers.databricks.hooks.databricks_sql

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import re
from contextlib import closing
from copy import copy
from typing import Any, Dict, List, Optional, Union

from databricks import sql  # type: ignore[attr-defined]
from databricks.sql.client import Connection  # type: ignore[attr-defined]

from airflow.exceptions import AirflowException
from airflow.hooks.dbapi import DbApiHook
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

LIST_SQL_ENDPOINTS_ENDPOINT = ('GET', 'api/2.0/sql/endpoints')

class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
    """
    Interact with Databricks SQL.

    :param databricks_conn_id: Reference to the
        :ref:`Databricks connection <howto/connection:databricks>`.
    :param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster.
        If not specified, it should be either specified in the Databricks connection's extra
        parameters, or ``sql_endpoint_name`` must be specified.
    :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified,
        ``http_path`` must be provided as described above.
    :param session_configuration: An optional dictionary of Spark session parameters. Defaults
        to None. If not specified, it could be specified in the Databricks connection's extra
        parameters.
    """

    hook_name = 'Databricks SQL'
    def __init__(
        self,
        databricks_conn_id: str = BaseDatabricksHook.default_conn_name,
        http_path: Optional[str] = None,
        sql_endpoint_name: Optional[str] = None,
        session_configuration: Optional[Dict[str, str]] = None,
    ) -> None:
        super().__init__(databricks_conn_id)
        self._sql_conn = None
        self._token: Optional[str] = None
        self._http_path = http_path
        self._sql_endpoint_name = sql_endpoint_name
        self.supports_autocommit = True
        self.session_config = session_configuration

    def _get_extra_config(self) -> Dict[str, Optional[Any]]:
        extra_params = copy(self.databricks_conn.extra_dejson)
        for arg in ['http_path', 'session_configuration'] + self.extra_parameters:
            if arg in extra_params:
                del extra_params[arg]

        return extra_params

    def _get_sql_endpoint_by_name(self, endpoint_name) -> Dict[str, Any]:
        result = self._do_api_call(LIST_SQL_ENDPOINTS_ENDPOINT)
        if 'endpoints' not in result:
            raise AirflowException("Can't list Databricks SQL endpoints")
        lst = [endpoint for endpoint in result['endpoints'] if endpoint['name'] == endpoint_name]
        if len(lst) == 0:
            raise AirflowException(f"Can't find Databricks SQL endpoint with name '{endpoint_name}'")
        return lst[0]
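
    # Illustrative note (not part of the original module): any key left in the
    # connection's extra field after the filtering above is forwarded verbatim
    # to databricks.sql.connect() by get_conn(). For a hypothetical extra of
    # {"http_path": "/sql/1.0/endpoints/abc", "http_headers": [["X-Tag", "demo"]]},
    # _get_extra_config() would drop 'http_path' (consumed by the hook itself)
    # and return {"http_headers": [["X-Tag", "demo"]]}.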
    def get_conn(self) -> Connection:
        """Returns a Databricks SQL connection object"""
        if not self._http_path:
            if self._sql_endpoint_name:
                endpoint = self._get_sql_endpoint_by_name(self._sql_endpoint_name)
                self._http_path = endpoint['odbc_params']['path']
            elif 'http_path' in self.databricks_conn.extra_dejson:
                self._http_path = self.databricks_conn.extra_dejson['http_path']
            else:
                raise AirflowException(
                    "http_path should be provided either explicitly, "
                    "or in extra parameter of Databricks connection, "
                    "or sql_endpoint_name should be specified"
                )

        requires_init = True
        if not self._token:
            self._token = self._get_token(raise_error=True)
        else:
            new_token = self._get_token(raise_error=True)
            if new_token != self._token:
                self._token = new_token
            else:
                requires_init = False

        if not self.session_config:
            self.session_config = self.databricks_conn.extra_dejson.get('session_configuration')

        if not self._sql_conn or requires_init:
            if self._sql_conn:  # close already existing connection
                self._sql_conn.close()
            self._sql_conn = sql.connect(
                self.host,
                self._http_path,
                self._token,
                session_configuration=self.session_config,
                **self._get_extra_config(),
            )
        return self._sql_conn
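
    # Example (an illustrative sketch, not part of the original module; the
    # endpoint name is hypothetical and a configured Databricks connection is
    # assumed):
    #
    #     hook = DatabricksSqlHook(
    #         databricks_conn_id="databricks_default",
    #         sql_endpoint_name="my-endpoint",
    #     )
    #     with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cur:
    #         cur.execute("SELECT 1")
    #         print(cur.fetchall())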
    @staticmethod
    def maybe_split_sql_string(sql: str) -> List[str]:
        """
        Splits a string consisting of multiple SQL expressions into a list
        of individual expressions.
        TODO: do we need something more sophisticated?

        :param sql: SQL string potentially consisting of multiple expressions
        :return: list of individual expressions
        """
        splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""]
        return splits
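
    # Example (illustrative, not part of the original module): the regex only
    # splits on a semicolon followed by a line break, so semicolons within a
    # single line are preserved:
    #
    #     DatabricksSqlHook.maybe_split_sql_string("SELECT 1;\nSELECT 2;")
    #     # -> ['SELECT 1', 'SELECT 2;']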
    def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None):
        """
        Runs a command or a list of commands. Pass a list of sql
        statements to the sql parameter to get them to execute
        sequentially.

        :param sql: the sql statement to be executed (str) or a list of sql statements to execute
        :param autocommit: What to set the connection's autocommit setting to
            before executing the query.
        :param parameters: The parameters to render the SQL query with.
        :param handler: The result handler which is called with the result of each statement.
        :return: query results.
        """
        if isinstance(sql, str):
            sql = self.maybe_split_sql_string(sql)
        self.log.debug("Executing %d statements", len(sql))

        conn = None
        for sql_statement in sql:
            # when using AAD tokens, it could expire if previous query run longer than token lifetime
            conn = self.get_conn()
            with closing(conn.cursor()) as cur:
                self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters)
                if parameters:
                    cur.execute(sql_statement, parameters)
                else:
                    cur.execute(sql_statement)
                schema = cur.description
                results = []
                if handler is not None:
                    cur = handler(cur)
                for row in cur:
                    self.log.debug("Statement results: %s", row)
                    results.append(row)

                self.log.info("Rows affected: %s", cur.rowcount)
        if conn:
            conn.close()
            self._sql_conn = None

        # Return only result of the last SQL expression
        return schema, results
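
    # Example (an illustrative sketch, not part of the original module; the
    # table name is hypothetical):
    #
    #     schema, rows = hook.run("INSERT INTO my_table VALUES (1);\nSELECT * FROM my_table")
    #
    # Only the schema and rows of the last statement are returned; the loop
    # above re-acquires the connection before each statement so that an
    # expired AAD token is refreshed between long-running queries.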
    def test_connection(self):
        """Test the Databricks SQL connection by running a simple query."""
        try:
            self.run(sql="select 42")
        except Exception as e:
            return False, str(e)

        return True, "Connection successfully checked"
    def bulk_dump(self, table, tmp_file):
        raise NotImplementedError()

    def bulk_load(self, table, tmp_file):
        raise NotImplementedError()
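
# Usage sketch (illustrative, not part of the original module): inside an
# Airflow task, assuming a configured "databricks_default" connection and a
# SQL endpoint named "my-endpoint" (names hypothetical here):
#
#     from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
#
#     def row_count(**context):
#         hook = DatabricksSqlHook(sql_endpoint_name="my-endpoint")
#         _, rows = hook.run("SELECT COUNT(*) FROM my_table")
#         return rows[0][0]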
