Source code for airflow.providers.trino.hooks.trino
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Mapping, TypeVar

import trino
from trino.exceptions import DatabaseError
from trino.transaction import IsolationLevel

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.helpers import exactly_one
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING, DEFAULT_FORMAT_PREFIX

if TYPE_CHECKING:
    from airflow.models import Connection
def generate_trino_client_info() -> str:
    """Return json string with dag_id, task_id, execution_date and try_number."""
    context_var = {
        format_map["default"].replace(DEFAULT_FORMAT_PREFIX, ""): os.environ.get(
            format_map["env_var_format"], ""
        )
        for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
    }
    task_info = {
        "dag_id": context_var["dag_id"],
        "task_id": context_var["task_id"],
        "execution_date": context_var["execution_date"],
        "try_number": context_var["try_number"],
        "dag_run_id": context_var["dag_run_id"],
        "dag_owner": context_var["dag_owner"],
    }
    return json.dumps(task_info, sort_keys=True)
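# Note: the hook methods below reference a module-level ``TrinoException`` class and a
# ``_boolify`` helper that are not shown in this listing. The following is a minimal sketch
# consistent with how they are used (exception wrapping of trino DatabaseError, and coercion
# of "true"/"false" strings coming from connection extras), not necessarily the exact source.


class TrinoException(Exception):
    """Trino exception."""


def _boolify(value):
    """Convert "true"/"false" strings to booleans; return other values unchanged."""
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        if value.lower() == "false":
            return False
        elif value.lower() == "true":
            return True
    return value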
class TrinoHook(DbApiHook):
    """
    Interact with Trino through trino package.

    >>> ph = TrinoHook()
    >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
    >>> ph.get_records(sql)
    [[340698]]
    """
    def get_conn(self) -> Connection:
        """Return a connection object."""
        db = self.get_connection(self.trino_conn_id)  # type: ignore[attr-defined]
        extra = db.extra_dejson
        auth = None
        user = db.login
        if db.password and extra.get("auth") in ("kerberos", "certs"):
            raise AirflowException(f"The {extra.get('auth')!r} authorization type doesn't support password.")
        elif db.password:
            auth = trino.auth.BasicAuthentication(db.login, db.password)  # type: ignore[attr-defined]
        elif extra.get("auth") == "jwt":
            if not exactly_one(jwt_file := "jwt__file" in extra, jwt_token := "jwt__token" in extra):
                msg = (
                    "When auth set to 'jwt' then expected exactly one parameter 'jwt__file' or 'jwt__token'"
                    " in connection extra, but "
                )
                if jwt_file and jwt_token:
                    msg += "provided both."
                else:
                    msg += "none of them provided."
                raise ValueError(msg)
            elif jwt_file:
                token = Path(extra["jwt__file"]).read_text()
            else:
                token = extra["jwt__token"]

            auth = trino.auth.JWTAuthentication(token=token)
        elif extra.get("auth") == "certs":
            auth = trino.auth.CertificateAuthentication(
                extra.get("certs__client_cert_path"),
                extra.get("certs__client_key_path"),
            )
        elif extra.get("auth") == "kerberos":
            auth = trino.auth.KerberosAuthentication(  # type: ignore[attr-defined]
                config=extra.get("kerberos__config", os.environ.get("KRB5_CONFIG")),
                service_name=extra.get("kerberos__service_name"),
                mutual_authentication=_boolify(extra.get("kerberos__mutual_authentication", False)),
                force_preemptive=_boolify(extra.get("kerberos__force_preemptive", False)),
                hostname_override=extra.get("kerberos__hostname_override"),
                sanitize_mutual_error_response=_boolify(
                    extra.get("kerberos__sanitize_mutual_error_response", True)
                ),
                principal=extra.get("kerberos__principal", conf.get("kerberos", "principal")),
                delegate=_boolify(extra.get("kerberos__delegate", False)),
                ca_bundle=extra.get("kerberos__ca_bundle"),
            )

        if _boolify(extra.get("impersonate_as_owner", False)):
            user = os.getenv("AIRFLOW_CTX_DAG_OWNER", None)
            if user is None:
                user = db.login
        http_headers = {"X-Trino-Client-Info": generate_trino_client_info()}
        trino_conn = trino.dbapi.connect(
            host=db.host,
            port=db.port,
            user=user,
            source=extra.get("source", "airflow"),
            http_scheme=extra.get("protocol", "http"),
            http_headers=http_headers,
            catalog=extra.get("catalog", "hive"),
            schema=db.schema,
            auth=auth,  # type: ignore[func-returns-value]
            isolation_level=self.get_isolation_level(),
            verify=_boolify(extra.get("verify", True)),
            session_properties=extra.get("session_properties") or None,
            client_tags=extra.get("client_tags") or None,
            timezone=extra.get("timezone") or None,
        )

        return trino_conn
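    # Example (illustrative, not part of the provider source): get_conn() builds the trino.dbapi
    # connection entirely from the Airflow Connection and its "extra" JSON. A JWT-authenticated
    # connection could be configured roughly as below; the connection id, host, and token value
    # are hypothetical.
    #
    #   from airflow.models import Connection
    #   conn = Connection(
    #       conn_id="my_trino_conn",
    #       conn_type="trino",
    #       host="trino.example.com",
    #       port=443,
    #       login="analyst",
    #       schema="default",
    #       extra=json.dumps(
    #           {
    #               "auth": "jwt",
    #               "jwt__token": "<token>",
    #               "protocol": "https",
    #               "catalog": "hive",
    #               "verify": True,
    #           }
    #       ),
    #   )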
    def get_isolation_level(self) -> Any:
        """Return an isolation level."""
        db = self.get_connection(self.trino_conn_id)  # type: ignore[attr-defined]
        isolation_level = db.extra_dejson.get("isolation_level", "AUTOCOMMIT").upper()
        return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)
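    # Example (illustrative): the isolation level is read from the "isolation_level" key in the
    # connection extras; any missing or unrecognised value falls back to AUTOCOMMIT.
    #
    #   extra = {"isolation_level": "READ_COMMITTED"}   # -> IsolationLevel.READ_COMMITTED
    #   extra = {}                                      # -> IsolationLevel.AUTOCOMMIT (default)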
    def get_records(
        self,
        sql: str | list[str] = "",
        parameters: Iterable | Mapping[str, Any] | None = None,
    ) -> Any:
        if not isinstance(sql, str):
            raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!")
        try:
            return super().get_records(self.strip_sql_string(sql), parameters)
        except DatabaseError as e:
            raise TrinoException(e)
    def get_first(
        self, sql: str | list[str] = "", parameters: Iterable | Mapping[str, Any] | None = None
    ) -> Any:
        if not isinstance(sql, str):
            raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!")
        try:
            return super().get_first(self.strip_sql_string(sql), parameters)
        except DatabaseError as e:
            raise TrinoException(e)
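    # Example (illustrative): both helpers accept a single SQL string only; passing a list raises
    # ValueError, and a trino DatabaseError is re-raised as TrinoException. The connection id and
    # table name are hypothetical.
    #
    #   hook = TrinoHook(trino_conn_id="my_trino_conn")
    #   first = hook.get_first("SELECT count(*) FROM hive.default.my_table")
    #   rows = hook.get_records("SELECT name, value FROM hive.default.my_table LIMIT 10")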
    def insert_rows(
        self,
        table: str,
        rows: Iterable[tuple],
        target_fields: Iterable[str] | None = None,
        commit_every: int = 0,
        replace: bool = False,
        **kwargs,
    ) -> None:
        """
        Insert a set of tuples into a table in a generic way.

        :param table: Name of the target table
        :param rows: The rows to insert into the table
        :param target_fields: The names of the columns to fill in the table
        :param commit_every: The maximum number of rows to insert in one transaction.
            Set to 0 to insert all rows in one transaction.
        :param replace: Whether to replace instead of insert
        """
        if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT:
            self.log.info(
                "Transactions are not enabled in trino connection. "
                "Please use the isolation_level property to enable it. "
                "Falling back to insert all rows in one transaction."
            )
            commit_every = 0

        super().insert_rows(table, rows, target_fields, commit_every, replace)
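    # Example (illustrative): with the default AUTOCOMMIT isolation level, commit_every is reset
    # to 0 and batching is skipped; set the "isolation_level" connection extra to enable
    # transactions and batched commits. The connection id and table name are hypothetical.
    #
    #   hook = TrinoHook(trino_conn_id="my_trino_conn")
    #   hook.insert_rows(
    #       table="hive.default.my_table",
    #       rows=[("a", 1), ("b", 2)],
    #       target_fields=["name", "value"],
    #       commit_every=1000,
    #   )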
    @staticmethod
    def _serialize_cell(cell: Any, conn: Connection | None = None) -> Any:
        """
        Trino will adapt all execute() args internally, hence we return cell without any conversion.

        :param cell: The cell to insert into the table
        :param conn: The database connection
        :return: The cell
        """
        return cell
    def get_openlineage_database_info(self, connection):
        """Return Trino specific information for OpenLineage."""
        from airflow.providers.openlineage.sqlparser import DatabaseInfo

        return DatabaseInfo(
            scheme="trino",
            authority=DbApiHook.get_openlineage_authority_part(
                connection, default_port=trino.constants.DEFAULT_PORT
            ),
            information_schema_columns=[
                "table_schema",
                "table_name",
                "column_name",
                "ordinal_position",
                "data_type",
                "table_catalog",
            ],
            database=connection.extra_dejson.get("catalog", "hive"),
            is_information_schema_cross_db=True,
        )