Source code for airflow.providers.databricks.hooks.databricks_sql
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from contextlib import closing
from copy import copy
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload

from databricks import sql  # type: ignore[attr-defined]

from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

if TYPE_CHECKING:
    from databricks.sql.client import Connection

LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

T = TypeVar("T")
class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
    """Hook to interact with Databricks SQL.

    :param databricks_conn_id: Reference to the
        :ref:`Databricks connection <howto/connection:databricks>`.
    :param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster.
        If not specified, it should be either specified in the Databricks connection's
        extra parameters, or ``sql_endpoint_name`` must be specified.
    :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified,
        ``http_path`` must be provided as described above.
    :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
        If not specified, it could be specified in the Databricks connection's extra parameters.
    :param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers
        on every request.
    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+.
    :param schema: An optional initial schema to use. Requires DBR version 9.0+.
    :param kwargs: Additional parameters passed on to the Databricks SQL connector.
    """
    _test_connection_sql = "select 42"

    def __init__(
        self,
        databricks_conn_id: str = BaseDatabricksHook.default_conn_name,
        http_path: str | None = None,
        sql_endpoint_name: str | None = None,
        session_configuration: dict[str, str] | None = None,
        http_headers: list[tuple[str, str]] | None = None,
        catalog: str | None = None,
        schema: str | None = None,
        caller: str = "DatabricksSqlHook",
        **kwargs,
    ) -> None:
        super().__init__(databricks_conn_id, caller=caller)
        self._sql_conn = None
        self._token: str | None = None
        self._http_path = http_path
        self._sql_endpoint_name = sql_endpoint_name
        self.supports_autocommit = True
        self.session_config = session_configuration
        self.http_headers = http_headers
        self.catalog = catalog
        self.schema = schema
        self.additional_params = kwargs

    def _get_extra_config(self) -> dict[str, Any | None]:
        extra_params = copy(self.databricks_conn.extra_dejson)
        for arg in ["http_path", "session_configuration", *self.extra_parameters]:
            if arg in extra_params:
                del extra_params[arg]
        return extra_params

    def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
        result = self._do_api_call(LIST_SQL_ENDPOINTS_ENDPOINT)
        if "endpoints" not in result:
            raise AirflowException("Can't list Databricks SQL endpoints")
        try:
            endpoint = next(
                endpoint for endpoint in result["endpoints"] if endpoint["name"] == endpoint_name
            )
        except StopIteration:
            raise AirflowException(f"Can't find Databricks SQL endpoint with name '{endpoint_name}'")
        else:
            return endpoint
    def get_conn(self) -> Connection:
        """Return a Databricks SQL connection object."""
        if not self._http_path:
            if self._sql_endpoint_name:
                endpoint = self._get_sql_endpoint_by_name(self._sql_endpoint_name)
                self._http_path = endpoint["odbc_params"]["path"]
            elif "http_path" in self.databricks_conn.extra_dejson:
                self._http_path = self.databricks_conn.extra_dejson["http_path"]
            else:
                raise AirflowException(
                    "http_path should be provided either explicitly, "
                    "or in the extra parameters of the Databricks connection, "
                    "or sql_endpoint_name should be specified"
                )

        requires_init = True
        if not self._token:
            self._token = self._get_token(raise_error=True)
        else:
            new_token = self._get_token(raise_error=True)
            if new_token != self._token:
                self._token = new_token
            else:
                requires_init = False

        if not self.session_config:
            self.session_config = self.databricks_conn.extra_dejson.get("session_configuration")

        if not self._sql_conn or requires_init:
            if self._sql_conn:  # close the already existing connection
                self._sql_conn.close()
            self._sql_conn = sql.connect(
                self.host,
                self._http_path,
                self._token,
                schema=self.schema,
                catalog=self.catalog,
                session_configuration=self.session_config,
                http_headers=self.http_headers,
                _user_agent_entry=self.user_agent_value,
                **self._get_extra_config(),
                **self.additional_params,
            )
        return self._sql_conn
    @overload
    def run(
        self,
        sql: str | Iterable[str],
        autocommit: bool = ...,
        parameters: Iterable | Mapping[str, Any] | None = ...,
        handler: None = ...,
        split_statements: bool = ...,
        return_last: bool = ...,
    ) -> None: ...

    @overload
    def run(
        self,
        sql: str | Iterable[str],
        autocommit: bool = ...,
        parameters: Iterable | Mapping[str, Any] | None = ...,
        handler: Callable[[Any], T] = ...,
        split_statements: bool = ...,
        return_last: bool = ...,
    ) -> T | list[T]: ...

    def run(
        self,
        sql: str | Iterable[str],
        autocommit: bool = False,
        parameters: Iterable | Mapping[str, Any] | None = None,
        handler: Callable[[Any], T] | None = None,
        split_statements: bool = True,
        return_last: bool = True,
    ) -> T | list[T] | None:
        """
        Run a command or a list of commands.

        Pass a list of SQL statements to the ``sql`` parameter to have them
        executed sequentially.

        :param sql: the SQL statement to be executed (str) or a list of SQL statements to execute
        :param autocommit: What to set the connection's autocommit setting to
            before executing the query. Note that there is currently no commit functionality
            in Databricks SQL, so this flag has no effect.
        :param parameters: The parameters to render the SQL query with.
        :param handler: The result handler which is called with the result of each statement.
        :param split_statements: Whether to split a single SQL string into statements
            and run them separately.
        :param return_last: Whether to return the result of only the last statement,
            or of all statements after splitting.
        :return: if a handler is provided, the result of the last SQL expression, or a list
            of results of all statements if ``return_last`` is False; ``None`` if no handler
            is provided.
        """
        self.descriptions = []
        if isinstance(sql, str):
            if split_statements:
                sql_list = [self.strip_sql_string(s) for s in self.split_sql_string(sql)]
            else:
                sql_list = [self.strip_sql_string(sql)]
        else:
            sql_list = [self.strip_sql_string(s) for s in sql]

        if sql_list:
            self.log.debug("Executing following statements against Databricks DB: %s", sql_list)
        else:
            raise ValueError("List of SQL statements is empty")

        conn = None
        results = []
        for sql_statement in sql_list:
            # when using AAD tokens, the token could expire if the previous query ran longer
            # than the token lifetime, so the connection is re-acquired per statement
            conn = self.get_conn()
            self.set_autocommit(conn, autocommit)

            with closing(conn.cursor()) as cur:
                self._run_command(cur, sql_statement, parameters)
                if handler is not None:
                    result = handler(cur)
                    if return_single_query_results(sql, return_last, split_statements):
                        results = [result]
                        self.descriptions = [cur.description]
                    else:
                        results.append(result)
                        self.descriptions.append(cur.description)
        if conn:
            conn.close()
            self._sql_conn = None

        if handler is None:
            return None
        if return_single_query_results(sql, return_last, split_statements):
            return results[-1]
        return results
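A minimal usage sketch, not part of the module source above: it constructs the hook against a hypothetical Airflow connection named "databricks_default" (whose extras are assumed to carry the endpoint's ``http_path``) and opens a raw connection via ``get_conn()``.

from contextlib import closing

from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

hook = DatabricksSqlHook(
    databricks_conn_id="databricks_default",  # hypothetical connection id
    catalog="main",    # optional initial catalog; requires DBR 9.0+
    schema="default",  # optional initial schema; requires DBR 9.0+
)

# get_conn() resolves http_path (explicit argument, connection extras, or
# sql_endpoint_name lookup) and reconnects if the auth token has changed.
with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cur:
    cur.execute("select 42")
    print(cur.fetchall())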
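And a sketch of ``run()`` with a result handler, again illustrative rather than authoritative: ``fetch_all_handler`` comes from the common.sql provider, and the table name is hypothetical.

from airflow.providers.common.sql.hooks.sql import fetch_all_handler

# Two statements in one string: split_statements=True runs them one at a time,
# and return_last=True returns only the rows of the final SELECT.
rows = hook.run(
    sql="SELECT 1; SELECT * FROM my_table LIMIT 10",  # my_table is hypothetical
    handler=fetch_all_handler,
    split_statements=True,
    return_last=True,
)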