Source code for airflow.providers.snowflake.hooks.snowflake_sql_api
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportuuidfromdatetimeimporttimedeltafrompathlibimportPathfromtypingimportAnyimportaiohttpimportrequestsfromcryptography.hazmat.backendsimportdefault_backendfromcryptography.hazmat.primitivesimportserializationfromrequests.authimportHTTPBasicAuthfromairflow.exceptionsimportAirflowExceptionfromairflow.providers.snowflake.hooks.snowflakeimportSnowflakeHookfromairflow.providers.snowflake.utils.sql_api_generate_jwtimportJWTGenerator
class SnowflakeSqlApiHook(SnowflakeHook):
    """
    A client to interact with Snowflake using SQL API and submit multiple SQL statements in a single request.

    SQL statements are submitted for execution with an HTTP POST request (via ``requests``); the
    execution status of each submitted statement can then be polled either synchronously (via
    ``requests``) or asynchronously (via ``aiohttp``).

    This hook requires the snowflake_conn_id connection. This hook mainly uses account,
    schema, database, warehouse, and an authentication mechanism from one of below:
    1. JWT Token generated from private_key_file or private_key_content. Other inputs can be defined
    in the connection or hook instantiation.
    2. OAuth Token generated from the refresh_token, client_id and client_secret specified in the connection

    :param snowflake_conn_id: Reference to
        :ref:`Snowflake connection id<howto/connection:snowflake>`
    :param account: snowflake account name
    :param authenticator: authenticator for Snowflake.
        'snowflake' (default) to use the internal Snowflake authenticator
        'externalbrowser' to authenticate using your web browser and
        Okta, ADFS or any other SAML 2.0-compliant identity provider
        (IdP) that has been defined for your account
        'https://<your_okta_account_name>.okta.com' to authenticate
        through native Okta.
    :param warehouse: name of snowflake warehouse
    :param database: name of snowflake database
    :param region: name of snowflake region
    :param role: name of snowflake role
    :param schema: name of snowflake schema
    :param session_parameters: You can set session-level parameters at
        the time you connect to Snowflake
    :param token_life_time: lifetime of the JWT Token in timedelta
    :param token_renewal_delta: Renewal time of the JWT Token in timedelta
    :param deferrable: Run operator in the deferrable mode.
    """

    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes

    def get_private_key(self) -> None:
        """
        Load the private key configured on the Snowflake connection into ``self.private_key``.

        The PEM key is taken from the ``private_key_file`` extra (a path whose contents are read)
        or the ``private_key_content`` extra (the PEM text itself); both names are also accepted
        with the ``extra__snowflake__`` prefix. The connection password, if set, is used as the
        passphrase of the key. If no (non-empty) key is configured, ``self.private_key`` is left
        unchanged.

        :raises AirflowException: if both ``private_key_file`` and ``private_key_content`` are set.
        """
        conn = self.get_connection(self.snowflake_conn_id)
        extras = conn.extra_dejson
        private_key_file = extras.get("extra__snowflake__private_key_file") or extras.get("private_key_file")
        private_key_content = extras.get("extra__snowflake__private_key_content") or extras.get(
            "private_key_content"
        )

        if private_key_content and private_key_file:
            raise AirflowException(
                "The private_key_file and private_key_content extra fields are mutually exclusive. "
                "Please remove one."
            )

        private_key_pem: bytes | None = None
        if private_key_file:
            private_key_pem = Path(private_key_file).read_bytes()
        elif private_key_content:
            private_key_pem = private_key_content.encode()

        if not private_key_pem:
            # Nothing configured (or an empty key): keep whatever self.private_key already holds.
            return

        # The connection password is the passphrase of the private key. An unencrypted key
        # (not recommended) is used by leaving the connection password empty.
        passphrase = conn.password.strip().encode() if conn.password else None
        self.private_key = serialization.load_pem_private_key(
            private_key_pem, password=passphrase, backend=default_backend()
        )

    def execute_query(
        self, sql: str, statement_count: int, query_tag: str = "", bindings: dict[str, Any] | None = None
    ) -> list[str]:
        """
        Run the query in Snowflake using SnowflakeSQL API by making API request.

        :param sql: the sql string to be executed with possibly multiple statements
        :param statement_count: set the MULTI_STATEMENT_COUNT field to the number of SQL statements
            in the request
        :param query_tag: (Optional) Query tag that you want to associate with the SQL statement.
            For details, see https://docs.snowflake.com/en/sql-reference/parameters.html#label-query-tag
            parameter.
        :param bindings: (Optional) Values of bind variables in the SQL statement.
            When executing the statement, Snowflake replaces placeholders (? and :name) in
            the statement with these specified values.
        :return: the statement handles (query ids) returned by Snowflake for this request; they
            are also stored in ``self.query_ids``.
        :raises AirflowException: if ``bindings`` is not a dict or None, if the API responds with
            an HTTP error, or if the response contains no statement handle.
        """
        # Validate input before doing any work that may hit the network (e.g. OAuth token refresh).
        if bindings is not None and not isinstance(bindings, dict):
            raise AirflowException("Bindings should be a dictionary or None.")

        # A single trailing ";" does not make the SQL multi-statement.
        sql_is_multi_stmt = ";" in sql.strip().rstrip(";")
        if bindings and sql_is_multi_stmt:
            self.log.warning(
                "Bindings are not supported for multi-statement queries. Bindings will be ignored."
            )

        conn_config = self._get_conn_params
        url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements"
        params: dict[str, Any] = {"requestId": str(uuid.uuid4()), "async": True, "pageSize": 10}
        headers = self.get_headers()
        data = {
            "statement": sql,
            "resultSetMetaData": {"format": "json"},
            # If database, schema, warehouse, role parameters have been provided set them accordingly.
            # If any of them has not been (the parent class initializes them to None in that case),
            # fall back to the value from the Airflow connection configuration.
            "database": self.database or conn_config["database"],
            "schema": self.schema or conn_config["schema"],
            "warehouse": self.warehouse or conn_config["warehouse"],
            "role": self.role or conn_config["role"],
            "bindings": bindings or {},
            "parameters": {
                "MULTI_STATEMENT_COUNT": statement_count,
                "query_tag": query_tag,
            },
        }
        response = requests.post(url, json=data, headers=headers, params=params)
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:  # pragma: no cover
            msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
            raise AirflowException(msg) from e

        json_response = response.json()
        self.log.info("Snowflake SQL POST API response: %s", json_response)
        if "statementHandles" in json_response:
            self.query_ids = json_response["statementHandles"]
        elif "statementHandle" in json_response:
            # Replace (not append to) ids from any earlier call, consistent with the branch above.
            self.query_ids = [json_response["statementHandle"]]
        else:
            raise AirflowException("No statementHandle/statementHandles present in response")
        return self.query_ids

    def get_headers(self) -> dict[str, Any]:
        """
        Form auth headers based on either OAuth token or JWT token from private key.

        OAuth is used when the connection provides ``refresh_token``, ``client_id`` and
        ``client_secret``; otherwise a key-pair JWT is generated from the connection's private key.

        :raises AirflowException: if no OAuth credentials and no private key are configured.
        """
        conn_config = self._get_conn_params
        headers: dict[str, Any] = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "User-Agent": "snowflakeSQLAPI/1.0",
        }

        # Use OAuth if refresh_token and client_id and client_secret are provided
        if all(conn_config.get(key) for key in ("refresh_token", "client_id", "client_secret")):
            headers["Authorization"] = f"Bearer {self.get_oauth_token()}"
            headers["X-Snowflake-Authorization-Token-Type"] = "OAUTH"
            return headers

        # Alternatively, get the JWT token from the connection details and the private key
        if not self.private_key:
            self.get_private_key()
        if not self.private_key:
            # Without a key JWTGenerator would fail later with an obscure attribute error.
            raise AirflowException(
                "No authentication configured for the Snowflake SQL API: provide private_key_file or "
                "private_key_content, or refresh_token, client_id and client_secret in the connection."
            )

        token = JWTGenerator(
            conn_config["account"],  # type: ignore[arg-type]
            conn_config["user"],  # type: ignore[arg-type]
            private_key=self.private_key,
            lifetime=self.token_life_time,
            renewal_delay=self.token_renewal_delta,
        ).get_token()
        headers["Authorization"] = f"Bearer {token}"
        headers["X-Snowflake-Authorization-Token-Type"] = "KEYPAIR_JWT"
        return headers

    def get_oauth_token(self) -> str:
        """
        Generate temporary OAuth access token using refresh token in connection details.

        :raises AirflowException: if the token endpoint responds with an HTTP error.
        """
        conn_config = self._get_conn_params
        url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"
        data = {
            "grant_type": "refresh_token",
            "refresh_token": conn_config["refresh_token"],
            "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
        }
        response = requests.post(
            url,
            data=data,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
            },
            auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]),  # type: ignore[arg-type]
        )
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:  # pragma: no cover
            msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
            raise AirflowException(msg) from e
        return response.json()["access_token"]

    def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
        """
        Build the headers, query parameters and URL for a statement-status request.

        :param query_id: statement handle (query id) of an individual statement.
        :return: a ``(headers, params, url)`` tuple; ``params`` carries a fresh ``requestId``.
        """
        header = self.get_headers()
        params = {"requestId": str(uuid.uuid4())}
        url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements/{query_id}"
        return header, params, url

    def check_query_output(self, query_ids: list[str]) -> None:
        """
        Make HTTP request to snowflake SQL API based on the provided query ids and log the response.

        :param query_ids: statement handles query id for the individual statements.
        :raises AirflowException: if any status request responds with an HTTP error.
        """
        for query_id in query_ids:
            header, params, url = self.get_request_url_header_params(query_id)
            try:
                response = requests.get(url, headers=header, params=params)
                response.raise_for_status()
            except requests.exceptions.HTTPError as e:
                msg = f"Response: {e.response.content.decode()}, Status Code: {e.response.status_code}"
                raise AirflowException(msg) from e
            self.log.info(response.json())

    def _process_response(self, status_code: int, resp: dict[str, Any]) -> dict[str, str | list[str]]:
        """
        Translate a statement-status API response into a status dict.

        :param status_code: HTTP status code of the response.
        :param resp: decoded JSON body of the response.
        :return: ``{"status": "running", ...}`` for 202, ``{"status": "success", "message": ...,
            "statement_handles": [...]}`` for 200, and ``{"status": "error", "message": ...}``
            for any other status code.
        """
        self.log.info("Snowflake SQL GET statements status API response: %s", resp)
        if status_code == 202:
            return {"status": "running", "message": "Query statements are still running"}
        if status_code == 200:
            if resp_statement_handles := resp.get("statementHandles"):
                statement_handles = resp_statement_handles
            elif resp_statement_handle := resp.get("statementHandle"):
                statement_handles = [resp_statement_handle]
            else:
                statement_handles = []
            return {
                "status": "success",
                "message": resp["message"],
                "statement_handles": statement_handles,
            }
        # 422 (statement failed) and any other status code are reported as errors; fall back to a
        # descriptive message instead of a KeyError when the body carries no "message".
        return {
            "status": "error",
            "message": resp.get("message", f"Unexpected response with status code {status_code}"),
        }

    def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
        """
        Synchronously request the status of a statement from the snowflake SQL API.

        :param query_id: statement handle id for the individual statements.
        :return: the status dict produced by ``_process_response``.
        """
        self.log.info("Retrieving status for query id %s", query_id)
        header, params, url = self.get_request_url_header_params(query_id)
        response = requests.get(url, params=params, headers=header)
        return self._process_response(response.status_code, response.json())

    async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
        """
        Asynchronously request the status of a statement from the snowflake SQL API.

        :param query_id: statement handle id for the individual statements.
        :return: the status dict produced by ``_process_response``.
        """
        self.log.info("Retrieving status for query id %s", query_id)
        # NOTE(review): building the headers is synchronous and, with OAuth, performs a blocking
        # ``requests.post`` token request inside this coroutine.
        header, params, url = self.get_request_url_header_params(query_id)
        async with (
            aiohttp.ClientSession(headers=header) as session,
            session.get(url, params=params) as response,
        ):
            status_code = response.status
            resp = await response.json()
            return self._process_response(status_code, resp)