Source code for airflow.providers.amazon.aws.hooks.athena_sql
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportjsonfromfunctoolsimportcached_propertyfromtypingimportTYPE_CHECKING,Anyimportpyathenafromsqlalchemy.engine.urlimportURLfromairflow.exceptionsimportAirflowException,AirflowNotFoundExceptionfromairflow.providers.amazon.aws.hooks.base_awsimportAwsBaseHookfromairflow.providers.amazon.aws.utils.connection_wrapperimportAwsConnectionWrapperfromairflow.providers.common.sql.hooks.sqlimportDbApiHookifTYPE_CHECKING:frompyathena.connectionimportConnectionasAthenaConnection
class AthenaSQLHook(AwsBaseHook, DbApiHook):
    """
    Interact with Amazon Athena.

    Provide wrapper around PyAthena library.

    :param athena_conn_id: :ref:`Amazon Athena Connection <howto/connection:athena>`.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    You can specify ``driver`` in ``extra`` of your connection in order to use
    a different driver than the default ``rest``.

    Also, ``aws_domain`` could be specified in ``extra`` of your connection.

    PyAthena and AWS Authentication parameters could be passed in the extra field
    of the ``athena_conn_id`` connection.

    When ``aws_conn_id`` resolves to an existing connection, the login, password and
    schema of ``athena_conn_id`` replace those of ``aws_conn_id``. The two ``extra``
    dicts are merged, and on key collisions the ``aws_conn_id`` value is kept.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`

    .. note::
        get_uri() depends on SQLAlchemy and PyAthena.
    """

    # NOTE(review): the methods below read ``self.athena_conn_id`` and ``self.conn``,
    # neither of which is defined in this listing. They presumably come from an
    # ``__init__`` / ``conn`` property (and class attributes such as ``conn_name_attr``)
    # that were elided from this view. Confirm they exist before relying on this block.

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom UI field behaviour for AWS Athena Connection."""
        return {
            "hidden_fields": ["host", "port"],
            "relabeling": {
                "login": "AWS Access Key ID",
                "password": "AWS Secret Access Key",
            },
            "placeholders": {
                "login": "AKIAIOSFODNN7EXAMPLE",
                "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
                "extra": json.dumps(
                    {
                        "aws_domain": "amazonaws.com",
                        "driver": "rest",
                        "s3_staging_dir": "s3://bucket_name/staging/",
                        "work_group": "primary",
                        "region_name": "us-east-1",
                        "session_kwargs": {"profile_name": "default"},
                        "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}},
                        "role_arn": "arn:aws:iam::123456789098:role/role-name",
                        "assume_role_method": "assume_role",
                        "assume_role_kwargs": {"RoleSessionName": "airflow"},
                        "aws_session_token": "AQoDYXdzEJr...EXAMPLETOKEN",
                        "endpoint_url": "http://localhost:4566",
                    },
                    indent=2,
                ),
            },
        }

    @cached_property
    def conn_config(self) -> AwsConnectionWrapper:
        """
        Get the Airflow Connection object and wrap it in helper (cached).

        The Athena connection (``athena_conn_id``) is always fetched.

        * If ``aws_conn_id`` is set and that connection exists, it is used. The Athena
          connection's login, password and schema are copied onto it, even when they
          are unset. The two ``extra`` dicts are merged, with ``aws_conn_id`` values
          winning on key collisions.
        * Otherwise (``aws_conn_id`` unset, or not found) the Athena connection itself
          is used, with its ``conn_type`` set to ``"aws"``.

        :return: the chosen connection wrapped in ``AwsConnectionWrapper`` together with
            the hook's region, botocore config and verify settings.
        """
        athena_conn = self.get_connection(self.athena_conn_id)

        connection = None
        if self.aws_conn_id:
            try:
                connection = self.get_connection(self.aws_conn_id)
            except AirflowNotFoundException:
                self.log.warning(
                    "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
                )

        if connection is None:
            # Covers both "aws_conn_id not found" and "aws_conn_id unset". The unset case
            # previously left ``connection`` unbound and raised UnboundLocalError below.
            # conn_type is relabelled presumably so AwsConnectionWrapper treats it as an
            # AWS connection.
            connection = athena_conn
            connection.conn_type = "aws"
        else:
            connection.login = athena_conn.login
            connection.password = athena_conn.password
            connection.schema = athena_conn.schema
            # NOTE(review): the AWS connection's extras are unpacked last, so they win over
            # the Athena connection's extras. Confirm this precedence is intended.
            connection.set_extra(json.dumps({**athena_conn.extra_dejson, **connection.extra_dejson}))

        return AwsConnectionWrapper(
            conn=connection,
            region_name=self._region_name,
            botocore_config=self._config,
            verify=self._verify,
        )

    def _get_conn_params(self) -> dict[str, str | None]:
        """
        Retrieve connection parameters.

        :raises AirflowException: if the wrapped connection has no ``region_name``.
        :return: ``driver`` (from extra, default ``"rest"``), ``schema_name`` (the
            connection schema), ``region_name``, and ``aws_domain`` (from extra,
            default ``"amazonaws.com"``).
        """
        if not self.conn.region_name:
            raise AirflowException("region_name must be specified in the connection's extra")

        return dict(
            driver=self.conn.extra_dejson.get("driver", "rest"),
            schema_name=self.conn.schema,
            region_name=self.conn.region_name,
            aws_domain=self.conn.extra_dejson.get("aws_domain", "amazonaws.com"),
        )

    def get_uri(self) -> str:
        """
        Overridden to use the Athena dialect as driver name.

        Builds ``awsathena+<driver>://<access_key>:<secret_key>@athena.<region>.<aws_domain>:443/<schema>``.
        The query string holds the session token (when present) and every connection
        ``extra`` entry whose value is not ``None``.

        NOTE(review): despite the ``str`` annotation, this returns the object produced by
        ``sqlalchemy.engine.url.URL.create``.
        """
        conn_params = self._get_conn_params()
        creds = self.get_credentials(region_name=conn_params["region_name"])

        query = {"aws_session_token": creds.token, **self.conn.extra_dejson}
        # URL.create only accepts string (or sequence-of-string) query values and raises
        # TypeError on None. This happens e.g. with no session token when using
        # long-lived access keys, so drop unset entries.
        # NOTE(review): dict-valued extras (e.g. ``session_kwargs``) are still passed
        # through and are presumably rejected the same way. Decide whether to drop or
        # serialize them.
        query = {key: value for key, value in query.items() if value is not None}

        return URL.create(
            f'awsathena+{conn_params["driver"]}',
            username=creds.access_key,
            password=creds.secret_key,
            host=f'athena.{conn_params["region_name"]}.{conn_params["aws_domain"]}',
            port=443,
            database=conn_params["schema_name"],
            query=query,
        )

    def get_conn(self) -> AthenaConnection:
        """
        Get a ``pyathena.Connection`` object.

        ``schema_name``, ``region_name`` and a session created for that region are passed
        to ``pyathena.connect``. Every key of the connection ``extra`` is passed as well.
        Because the extras are unpacked last, they take precedence over those three.
        """
        conn_params = self._get_conn_params()
        conn_kwargs: dict = {
            "schema_name": conn_params["schema_name"],
            "region_name": conn_params["region_name"],
            "session": self.get_session(region_name=conn_params["region_name"]),
            **self.conn.extra_dejson,
        }
        return pyathena.connect(**conn_kwargs)