Source code for airflow.providers.databricks.hooks.databricks_base
## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""Databricks hook.This hook enable the submitting and running of jobs to the Databricks platform. Internally theoperators talk to the ``api/2.0/jobs/runs/submit```endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_."""importcopyimportsysimporttimefromtypingimportAny,Dict,Optional,Tuplefromurllib.parseimporturlparseimportaiohttpimportrequestsfromrequestsimportPreparedRequest,exceptionsasrequests_exceptionsfromrequests.authimportAuthBase,HTTPBasicAuthfromrequests.exceptionsimportJSONDecodeErrorfromtenacityimport(AsyncRetrying,RetryError,Retrying,retry_if_exception,stop_after_attempt,wait_exponential,)fromairflowimport__version__fromairflow.exceptionsimportAirflowExceptionfromairflow.hooks.baseimportBaseHookfromairflow.modelsimportConnectionifsys.version_info>=(3,8):fromfunctoolsimportcached_propertyelse:fromcached_propertyimportcached_property
class BaseDatabricksHook(BaseHook):
    """
    Base for interaction with Databricks.

    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
    :param timeout_seconds: The amount of time in seconds the requests library
        will wait before timing-out.
    :param retry_limit: The number of times to retry the connection in case of
        service outages.
    :param retry_delay: The number of seconds to wait between retries (it
        might be a floating point number).
    :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
    """

    # NOTE(review): `default_conn_name` and the `databricks_conn` cached property are
    # declared elsewhere in this class (not visible in this chunk) — confirm against
    # the full module before relying on them.
    def __init__(
        self,
        databricks_conn_id: str = default_conn_name,
        timeout_seconds: int = 180,
        retry_limit: int = 3,
        retry_delay: float = 1.0,
        retry_args: Optional[Dict[Any, Any]] = None,
    ) -> None:
        super().__init__()
        self.databricks_conn_id = databricks_conn_id
        self.timeout_seconds = timeout_seconds
        if retry_limit < 1:
            raise ValueError('Retry limit must be greater than or equal to 1')
        self.retry_limit = retry_limit
        self.retry_delay = retry_delay
        # Cache of AAD tokens, keyed by the resource they were issued for.
        self.aad_tokens: Dict[str, dict] = {}
        # Dedicated (shorter) timeout for AAD token requests.
        self.aad_timeout_seconds = 10

        def my_after_func(retry_state):
            # Log every failed attempt so operators can see the retry history.
            self._log_request_error(retry_state.attempt_number, retry_state.outcome)

        if retry_args:
            # Caller-supplied tenacity arguments win, but the retry predicate and
            # after-hook are always ours so error classification stays consistent.
            self.retry_args = copy.copy(retry_args)
            self.retry_args['retry'] = retry_if_exception(self._retryable_error)
            self.retry_args['after'] = my_after_func
        else:
            # Default policy: bounded attempts with exponential back-off.
            self.retry_args = {
                'stop': stop_after_attempt(self.retry_limit),
                'wait': wait_exponential(min=self.retry_delay, max=(2 ** retry_limit)),
                'retry': retry_if_exception(self._retryable_error),
                'after': my_after_func,
            }
@staticmethod
def _parse_host(host: str) -> str:
    """
    The purpose of this function is to be robust to improper connections
    settings provided by users, specifically in the host field.

    For example -- when users supply ``https://xx.cloud.databricks.com`` as the
    host, we must strip out the protocol to get the host.::

        h = DatabricksHook()
        assert h._parse_host('https://xx.cloud.databricks.com') == \
            'xx.cloud.databricks.com'

    In the case where users supply the correct ``xx.cloud.databricks.com`` as the
    host, this function is a no-op.::

        assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
    """
    # urlparse().hostname is None/'' when no scheme is present, so fall back
    # to the raw value in that case (the "already a bare host" path).
    return urlparse(host).hostname or host

def _get_retry_object(self) -> Retrying:
    """
    Instantiates a retry object

    :return: instance of Retrying class
    """
    return Retrying(**self.retry_args)

def _a_get_retry_object(self) -> AsyncRetrying:
    """
    Instantiates an async retry object

    :return: instance of AsyncRetrying class
    """
    return AsyncRetrying(**self.retry_args)

def _get_aad_token(self, resource: str) -> str:
    """
    Function to get AAD token for given resource. Supports managed identity or service principal auth

    :param resource: resource to issue token to
    :return: AAD token, or raise an exception
    """
    cached = self.aad_tokens.get(resource)
    if cached and self._is_aad_token_valid(cached):
        return cached['token']

    self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
    try:
        for attempt in self._get_retry_object():
            with attempt:
                if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
                    # Managed identity: token comes from the instance metadata endpoint.
                    params = {
                        "api-version": "2018-02-01",
                        "resource": resource,
                    }
                    resp = requests.get(
                        AZURE_METADATA_SERVICE_TOKEN_URL,
                        params=params,
                        headers={**USER_AGENT_HEADER, "Metadata": "true"},
                        timeout=self.aad_timeout_seconds,
                    )
                else:
                    # Service principal: client-credentials grant against AAD.
                    tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
                    data = {
                        "grant_type": "client_credentials",
                        "client_id": self.databricks_conn.login,
                        "resource": resource,
                        "client_secret": self.databricks_conn.password,
                    }
                    azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
                        "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
                    )
                    resp = requests.post(
                        AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
                        data=data,
                        headers={
                            **USER_AGENT_HEADER,
                            'Content-Type': 'application/x-www-form-urlencoded',
                        },
                        timeout=self.aad_timeout_seconds,
                    )

                resp.raise_for_status()
                jsn = resp.json()
                if (
                    'access_token' not in jsn
                    or jsn.get('token_type') != 'Bearer'
                    or 'expires_on' not in jsn
                ):
                    raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")

                token = jsn['access_token']
                self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
                break
    except RetryError:
        raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
    except requests_exceptions.HTTPError as e:
        raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')

    return token

async def _a_get_aad_token(self, resource: str) -> str:
    """
    Async version of `_get_aad_token()`.

    :param resource: resource to issue token to
    :return: AAD token, or raise an exception
    """
    cached = self.aad_tokens.get(resource)
    if cached and self._is_aad_token_valid(cached):
        return cached['token']

    self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
    try:
        async for attempt in self._a_get_retry_object():
            with attempt:
                if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
                    params = {
                        "api-version": "2018-02-01",
                        "resource": resource,
                    }
                    async with self._session.get(
                        url=AZURE_METADATA_SERVICE_TOKEN_URL,
                        params=params,
                        headers={**USER_AGENT_HEADER, "Metadata": "true"},
                        timeout=self.aad_timeout_seconds,
                    ) as resp:
                        resp.raise_for_status()
                        jsn = await resp.json()
                else:
                    tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
                    data = {
                        "grant_type": "client_credentials",
                        "client_id": self.databricks_conn.login,
                        "resource": resource,
                        "client_secret": self.databricks_conn.password,
                    }
                    azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
                        "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
                    )
                    async with self._session.post(
                        url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
                        data=data,
                        headers={
                            **USER_AGENT_HEADER,
                            'Content-Type': 'application/x-www-form-urlencoded',
                        },
                        timeout=self.aad_timeout_seconds,
                    ) as resp:
                        resp.raise_for_status()
                        jsn = await resp.json()

                if (
                    'access_token' not in jsn
                    or jsn.get('token_type') != 'Bearer'
                    or 'expires_on' not in jsn
                ):
                    raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")

                token = jsn['access_token']
                self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
                break
    except RetryError:
        raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
    except aiohttp.ClientResponseError as err:
        raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')

    return token

def _get_aad_headers(self) -> dict:
    """
    Fills AAD headers if necessary (SPN is outside of the workspace)

    :return: dictionary with filled AAD headers
    """
    headers = {}
    if 'azure_resource_id' in self.databricks_conn.extra_dejson:
        mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
        headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
            'azure_resource_id'
        ]
        headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
    return headers

async def _a_get_aad_headers(self) -> dict:
    """
    Async version of `_get_aad_headers()`.

    :return: dictionary with filled AAD headers
    """
    headers = {}
    if 'azure_resource_id' in self.databricks_conn.extra_dejson:
        mgmt_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
        headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
            'azure_resource_id'
        ]
        headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
    return headers

@staticmethod
def _is_aad_token_valid(aad_token: dict) -> bool:
    """
    Utility function to check AAD token hasn't expired yet

    :param aad_token: dict with properties of AAD token
    :return: true if token is valid, false otherwise
    :rtype: bool
    """
    # Treat a token as invalid once it is within the refresh lead time of expiry.
    now = int(time.time())
    return aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME)

@staticmethod
def _check_azure_metadata_service() -> None:
    """
    Check for Azure Metadata Service
    https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
    """
    try:
        jsn = requests.get(
            AZURE_METADATA_SERVICE_INSTANCE_URL,
            params={"api-version": "2021-02-01"},
            headers={"Metadata": "true"},
            timeout=2,
        ).json()
        if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
            raise AirflowException(
                f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
            )
    except (requests_exceptions.RequestException, ValueError) as e:
        raise AirflowException(f"Can't reach Azure Metadata Service: {e}")

async def _a_check_azure_metadata_service(self):
    """Async version of `_check_azure_metadata_service()`."""
    try:
        async with self._session.get(
            url=AZURE_METADATA_SERVICE_INSTANCE_URL,
            params={"api-version": "2021-02-01"},
            headers={"Metadata": "true"},
            timeout=2,
        ) as resp:
            jsn = await resp.json()
            if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
                raise AirflowException(
                    f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
                )
    except (requests_exceptions.RequestException, ValueError) as e:
        raise AirflowException(f"Can't reach Azure Metadata Service: {e}")

def _get_token(self, raise_error: bool = False) -> Optional[str]:
    """
    Resolve the authentication token from the connection, trying (in order):
    explicit ``token`` extra, password-as-token, SPN credentials, managed identity.

    :param raise_error: when True, raise instead of returning None if nothing matched
    :return: token string, or None if token auth is not configured
    """
    extras = self.databricks_conn.extra_dejson
    if 'token' in extras:
        self.log.info(
            'Using token auth. For security reasons, please set token in Password field instead of extra'
        )
        return extras['token']
    elif not self.databricks_conn.login and self.databricks_conn.password:
        self.log.info('Using token auth.')
        return self.databricks_conn.password
    elif 'azure_tenant_id' in extras:
        if self.databricks_conn.login == "" or self.databricks_conn.password == "":
            raise AirflowException("Azure SPN credentials aren't provided")
        self.log.info('Using AAD Token for SPN.')
        return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
    elif extras.get('use_azure_managed_identity', False):
        self.log.info('Using AAD Token for managed identity.')
        # Fail fast with a clear error if we are not actually running on Azure.
        self._check_azure_metadata_service()
        return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
    elif raise_error:
        raise AirflowException('Token authentication isn\'t configured')
    return None

async def _a_get_token(self, raise_error: bool = False) -> Optional[str]:
    """
    Async version of `_get_token()`.

    :param raise_error: when True, raise instead of returning None if nothing matched
    :return: token string, or None if token auth is not configured
    """
    extras = self.databricks_conn.extra_dejson
    if 'token' in extras:
        self.log.info(
            'Using token auth. For security reasons, please set token in Password field instead of extra'
        )
        return extras["token"]
    elif not self.databricks_conn.login and self.databricks_conn.password:
        self.log.info('Using token auth.')
        return self.databricks_conn.password
    elif 'azure_tenant_id' in extras:
        if self.databricks_conn.login == "" or self.databricks_conn.password == "":
            raise AirflowException("Azure SPN credentials aren't provided")
        self.log.info('Using AAD Token for SPN.')
        return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
    elif extras.get('use_azure_managed_identity', False):
        self.log.info('Using AAD Token for managed identity.')
        await self._a_check_azure_metadata_service()
        return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
    elif raise_error:
        raise AirflowException('Token authentication isn\'t configured')
    return None

def _log_request_error(self, attempt_num: int, error: str) -> None:
    # tenacity after-hook target; lazy %-style args keep formatting off the hot path.
    self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)

def _do_api_call(
    self,
    endpoint_info: Tuple[str, str],
    json: Optional[Dict[str, Any]] = None,
    wrap_http_errors: bool = True,
):
    """
    Utility function to perform an API call with retries

    :param endpoint_info: Tuple of method and endpoint
    :param json: Parameters for this API call.
    :return: If the api call returns a OK status code,
        this function returns the response in JSON. Otherwise,
        we throw an AirflowException.
    :rtype: dict
    """
    method, endpoint = endpoint_info

    # TODO: get rid of explicit 'api/' in the endpoint specification
    url = f'https://{self.host}/{endpoint}'

    aad_headers = self._get_aad_headers()
    headers = {**USER_AGENT_HEADER.copy(), **aad_headers}

    auth: AuthBase
    token = self._get_token()
    if token:
        auth = _TokenAuth(token)
    else:
        self.log.info('Using basic auth.')
        auth = HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password)

    # Dispatch table instead of an if/elif chain; unknown verbs fail loudly.
    request_funcs: Dict[str, Any] = {
        'GET': requests.get,
        'POST': requests.post,
        'PATCH': requests.patch,
        'DELETE': requests.delete,
    }
    try:
        request_func = request_funcs[method]
    except KeyError:
        raise AirflowException('Unexpected HTTP Method: ' + method)

    try:
        for attempt in self._get_retry_object():
            with attempt:
                response = request_func(
                    url,
                    # Body for mutating verbs, query params for GET.
                    json=json if method in ('POST', 'PATCH') else None,
                    params=json if method == 'GET' else None,
                    auth=auth,
                    headers=headers,
                    timeout=self.timeout_seconds,
                )
                response.raise_for_status()
                return response.json()
    except RetryError:
        raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
    except requests_exceptions.HTTPError as e:
        if wrap_http_errors:
            raise AirflowException(
                f'Response: {e.response.content}, Status Code: {e.response.status_code}'
            )
        else:
            raise e

async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, Any]] = None):
    """
    Async version of `_do_api_call()`.

    :param endpoint_info: Tuple of method and endpoint
    :param json: Parameters for this API call.
    :return: If the api call returns a OK status code,
        this function returns the response in JSON. Otherwise, throw an AirflowException.
    """
    method, endpoint = endpoint_info

    url = f'https://{self.host}/{endpoint}'

    aad_headers = await self._a_get_aad_headers()
    headers = {**USER_AGENT_HEADER.copy(), **aad_headers}

    auth: aiohttp.BasicAuth
    token = await self._a_get_token()
    if token:
        auth = BearerAuth(token)
    else:
        self.log.info('Using basic auth.')
        auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)

    # NOTE: unlike the sync version, DELETE is not supported here.
    request_funcs: Dict[str, Any] = {
        'GET': self._session.get,
        'POST': self._session.post,
        'PATCH': self._session.patch,
    }
    try:
        request_func = request_funcs[method]
    except KeyError:
        raise AirflowException('Unexpected HTTP Method: ' + method)

    try:
        async for attempt in self._a_get_retry_object():
            with attempt:
                async with request_func(
                    url,
                    json=json,
                    auth=auth,
                    headers={**headers, **USER_AGENT_HEADER},
                    timeout=self.timeout_seconds,
                ) as response:
                    response.raise_for_status()
                    return await response.json()
    except RetryError:
        raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
    except aiohttp.ClientResponseError as err:
        raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')

@staticmethod
def _get_error_code(exception: BaseException) -> str:
    """Extract the Databricks ``error_code`` field from an HTTPError body, or ''."""
    if isinstance(exception, requests_exceptions.HTTPError):
        try:
            jsn = exception.response.json()
            return jsn.get('error_code', '')
        except JSONDecodeError:
            # Non-JSON error body: fall through to the empty default.
            pass
    return ""

@staticmethod
def _retryable_error(exception: BaseException) -> bool:
    """Classify an exception as retryable: connection problems, timeouts,
    5xx, 429, or the Databricks COULD_NOT_ACQUIRE_LOCK 400."""
    if isinstance(exception, requests_exceptions.RequestException):
        if isinstance(
            exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)
        ) or (
            exception.response is not None
            and (
                exception.response.status_code >= 500
                or exception.response.status_code == 429
                or (
                    exception.response.status_code == 400
                    and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
                )
            )
        ):
            return True
    if isinstance(exception, aiohttp.ClientResponseError):
        if exception.status >= 500 or exception.status == 429:
            return True
    return False
class _TokenAuth(AuthBase):
    """
    Helper class for requests Auth field. AuthBase requires you to implement the __call__
    magic function.
    """

    def __init__(self, token: str) -> None:
        # Bearer token attached to every request this auth object signs.
        self.token = token

    def __call__(self, r: PreparedRequest) -> PreparedRequest:
        r.headers['Authorization'] = f'Bearer {self.token}'
        return r
class BearerAuth(aiohttp.BasicAuth):
    """aiohttp only ships BasicAuth, for Bearer auth we need a subclass of BasicAuth."""

    def __new__(cls, token: str) -> 'BearerAuth':
        # BasicAuth is a NamedTuple subclass, so the value must be fixed in __new__.
        return super().__new__(cls, token)  # type: ignore

    def __init__(self, token: str) -> None:
        self.token = token