# Source code for airflow.providers.amazon.aws.hooks.athena

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains AWS Athena hook.

.. spelling:word-list::

    PageIterator
"""
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Collection

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

if TYPE_CHECKING:
    from botocore.paginate import PageIterator

# Prefix inserted before every line of a logged query so multi-line SQL is
# indented under its "QueryString:" key in the task log.
MULTI_LINE_QUERY_LOG_PREFIX = "\n\t\t"


def query_params_to_string(params: dict[str, str | Collection[str]]) -> str:
    """Render Athena query parameters as a readable, tab-indented block for logging.

    Each entry becomes one tab-indented ``<key>: <value>`` line (values are
    formatted with ``str()``). The ``QueryString`` value is moved onto its own
    lines, each prefixed with ``MULTI_LINE_QUERY_LOG_PREFIX``, so multi-line
    SQL stays readable.

    :param params: Parameter names mapped to their values, e.g. the keyword
        arguments passed to ``start_query_execution``.
    :return: The rendered lines joined by newlines, with trailing whitespace
        removed. An empty mapping yields an empty string.
    """
    lines = []
    for key, value in params.items():
        if key == "QueryString":
            # Start the query on a fresh line and indent each of its lines;
            # rstrip drops the dangling prefix left by a trailing newline.
            value = (
                MULTI_LINE_QUERY_LOG_PREFIX
                + str(value).replace("\n", MULTI_LINE_QUERY_LOG_PREFIX).rstrip()
            )
        lines.append(f"\t{key}: {value}")
    # Join once instead of repeated string concatenation.
    return "\n".join(lines).rstrip()
class AthenaHook(AwsBaseHook):
    """Interact with Amazon Athena.

    Provide thick wrapper around
    :external+boto3:py:class:`boto3.client("athena") <Athena.Client>`.

    :param sleep_time: obsolete, please use the parameter of `poll_query_status` method instead
    :param log_query: Whether to log athena query and other execution params
        when it's executed. Defaults to *True*.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    # Non-final query states; results are not fetched while a query is in one of these.
    INTERMEDIATE_STATES = (
        "QUEUED",
        "RUNNING",
    )
    # Final states in which the query did not succeed; results are not fetched.
    FAILURE_STATES = (
        "FAILED",
        "CANCELLED",
    )
    # Final state in which results can be fetched.
    SUCCESS_STATES = ("SUCCEEDED",)
    # All final states (success and failure).
    TERMINAL_STATES = (
        "SUCCEEDED",
        "FAILED",
        "CANCELLED",
    )

    def __init__(
        self, *args: Any, sleep_time: int | None = None, log_query: bool = True, **kwargs: Any
    ) -> None:
        super().__init__(*args, client_type="athena", **kwargs)  # type: ignore
        if sleep_time is not None:
            self.sleep_time = sleep_time
            warnings.warn(
                "The `sleep_time` parameter of the Athena hook is deprecated, "
                "please pass this parameter to the poll_query_status method instead.",
                AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
        else:
            self.sleep_time = 30  # previous default value
        self.log_query = log_query
        # Per-instance cache of get_query_execution responses, keyed by query
        # execution id; only populated/read when a caller passes use_cache=True.
        self.__query_results: dict[str, Any] = {}

    def run_query(
        self,
        query: str,
        query_context: dict[str, str],
        result_configuration: dict[str, Any],
        client_request_token: str | None = None,
        workgroup: str = "primary",
    ) -> str:
        """Run a Trino/Presto query on Athena with provided config.

        .. seealso::
            - :external+boto3:py:meth:`Athena.Client.start_query_execution`

        :param query: Trino/Presto query to run.
        :param query_context: Context in which query need to be run.
        :param result_configuration: Dict with path to store results in and
            config related to encryption.
        :param client_request_token: Unique token created by user to avoid
            multiple executions of same query.
        :param workgroup: Athena workgroup name, when not specified, will be ``'primary'``.
        :return: Submitted query execution ID.
        """
        params: dict[str, Any] = {
            "QueryString": query,
            "QueryExecutionContext": query_context,
            "ResultConfiguration": result_configuration,
            "WorkGroup": workgroup,
        }
        # Only send the idempotency token when one was actually supplied.
        if client_request_token:
            params["ClientRequestToken"] = client_request_token
        if self.log_query:
            self.log.info("Running Query with params:\n%s", query_params_to_string(params))
        response = self.get_conn().start_query_execution(**params)
        query_execution_id = response["QueryExecutionId"]
        self.log.info("Query execution id: %s", query_execution_id)
        return query_execution_id

    def get_query_info(self, query_execution_id: str, use_cache: bool = False) -> dict:
        """Get information about a single execution of a query.

        .. seealso::
            - :external+boto3:py:meth:`Athena.Client.get_query_execution`

        :param query_execution_id: Id of submitted athena query
        :param use_cache: If True, use execution information cache
        :return: The raw ``get_query_execution`` response. With ``use_cache``,
            a previously cached response is returned if present; otherwise the
            fresh response is stored in the cache.
        """
        if use_cache and query_execution_id in self.__query_results:
            return self.__query_results[query_execution_id]
        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
        if use_cache:
            self.__query_results[query_execution_id] = response
        return response

    def check_query_status(self, query_execution_id: str, use_cache: bool = False) -> str | None:
        """Fetch the state of a submitted query.

        .. seealso::
            - :external+boto3:py:meth:`Athena.Client.get_query_execution`

        :param query_execution_id: Id of submitted athena query
        :param use_cache: If True, use execution information cache
        :return: One of valid query states, or *None* if the response is malformed.
        """
        response = self.get_query_info(query_execution_id=query_execution_id, use_cache=use_cache)
        state = None
        try:
            state = response["QueryExecution"]["Status"]["State"]
        except Exception:
            # The error is being absorbed here and is being handled by the caller.
            # The error is being absorbed to implement retries.
            self.log.exception(
                "Exception while getting query state. Query execution id: %s", query_execution_id
            )
        # Returning after (not inside) a `finally` keeps BaseException subclasses
        # such as KeyboardInterrupt/SystemExit from being silently swallowed.
        return state

    def get_state_change_reason(self, query_execution_id: str, use_cache: bool = False) -> str | None:
        """
        Fetch the reason for a state change (e.g. error message). Returns None or reason string.

        .. seealso::
            - :external+boto3:py:meth:`Athena.Client.get_query_execution`

        :param query_execution_id: Id of submitted athena query
        :param use_cache: If True, use execution information cache
        """
        response = self.get_query_info(query_execution_id=query_execution_id, use_cache=use_cache)
        reason = None
        try:
            reason = response["QueryExecution"]["Status"]["StateChangeReason"]
        except Exception:
            # The error is being absorbed here and is being handled by the caller.
            # The error is being absorbed to implement retries.
            self.log.exception(
                "Exception while getting query state change reason. Query execution id: %s",
                query_execution_id,
            )
        # Returning after (not inside) a `finally` keeps BaseException subclasses
        # from being silently swallowed.
        return reason

    def get_query_results(
        self, query_execution_id: str, next_token_id: str | None = None, max_results: int = 1000
    ) -> dict | None:
        """Fetch submitted query results.

        .. seealso::
            - :external+boto3:py:meth:`Athena.Client.get_query_results`

        :param query_execution_id: Id of submitted athena query
        :param next_token_id: The token that specifies where to start pagination.
        :param max_results: The maximum number of results (rows) to return in this request.
        :return: *None* if the query is in intermediate, failed, or cancelled state.
            Otherwise a dict of query outputs.
        """
        query_state = self.check_query_status(query_execution_id)
        if query_state is None:
            self.log.error("Invalid Query state. Query execution id: %s", query_execution_id)
            return None
        elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
            self.log.error(
                'Query is in "%s" state. Cannot fetch results. Query execution id: %s',
                query_state,
                query_execution_id,
            )
            return None
        result_params: dict[str, Any] = {"QueryExecutionId": query_execution_id, "MaxResults": max_results}
        if next_token_id:
            result_params["NextToken"] = next_token_id
        return self.get_conn().get_query_results(**result_params)

    def get_query_results_paginator(
        self,
        query_execution_id: str,
        max_items: int | None = None,
        page_size: int | None = None,
        starting_token: str | None = None,
    ) -> PageIterator | None:
        """Fetch submitted Athena query results.

        .. seealso::
            - :external+boto3:py:class:`Athena.Paginator.GetQueryResults`

        :param query_execution_id: Id of submitted athena query
        :param max_items: The total number of items to return.
        :param page_size: The size of each page.
        :param starting_token: A token to specify where to start paginating.
        :return: *None* if the query is in intermediate, failed, or cancelled state.
            Otherwise a paginator to iterate through pages of results.

        Call :meth:`.build_full_result()` on the returned paginator to get all results at once.
        """
        query_state = self.check_query_status(query_execution_id)
        if query_state is None:
            self.log.error("Invalid Query state (null). Query execution id: %s", query_execution_id)
            return None
        if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
            self.log.error(
                'Query is in "%s" state. Cannot fetch results, Query execution id: %s',
                query_state,
                query_execution_id,
            )
            return None
        result_params = {
            "QueryExecutionId": query_execution_id,
            "PaginationConfig": {
                "MaxItems": max_items,
                "PageSize": page_size,
                "StartingToken": starting_token,
            },
        }
        paginator = self.get_conn().get_paginator("get_query_results")
        return paginator.paginate(**result_params)

    def poll_query_status(
        self, query_execution_id: str, max_polling_attempts: int | None = None, sleep_time: int | None = None
    ) -> str | None:
        """Poll the state of a submitted query until it reaches final state.

        :param query_execution_id: ID of submitted athena query
        :param max_polling_attempts: Number of times to poll for query state before function exits
            (defaults to 120 when not given).
        :param sleep_time: Time (in seconds) to wait between two consecutive query status checks.
            Falls back to the hook's ``sleep_time`` when not given.
        :return: The query state read after waiting. If the waiter stops with an
            ``AirflowException`` (which is logged, not raised) this may not be a
            final state; it is *None* if the state cannot be read.
        """
        try:
            wait(
                waiter=self.get_waiter("query_complete"),
                waiter_delay=self.sleep_time if sleep_time is None else sleep_time,
                waiter_max_attempts=max_polling_attempts or 120,
                args={"QueryExecutionId": query_execution_id},
                failure_message=f"Error while waiting for query {query_execution_id} to complete",
                status_message=f"Query execution id: {query_execution_id}, "
                f"Query is still in non-terminal state",
                status_args=["QueryExecution.Status.State"],
            )
        except AirflowException as error:
            # this function does not raise errors to keep previous behavior.
            self.log.warning(error)
        # Return outside of `finally`: only AirflowException is deliberately
        # absorbed above; interrupts and task-timeout signals must propagate.
        return self.check_query_status(query_execution_id)

    def get_output_location(self, query_execution_id: str) -> str:
        """Get the output location of the query results in S3 URI format.

        .. seealso::
            - :external+boto3:py:meth:`Athena.Client.get_query_execution`

        :param query_execution_id: Id of submitted athena query
        :raises ValueError: If the id is empty or no query information is returned.
        :raises KeyError: If the response has no ``OutputLocation`` (logged, then re-raised).
        """
        if not query_execution_id:
            raise ValueError(f"Invalid Query execution id. Query execution id: {query_execution_id}")

        # The output location does not change, so the cached response is safe to reuse.
        if not (response := self.get_query_info(query_execution_id=query_execution_id, use_cache=True)):
            raise ValueError(f"Unable to get query information for execution id: {query_execution_id}")

        try:
            return response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
        except KeyError:
            self.log.error("Error retrieving OutputLocation. Query execution id: %s", query_execution_id)
            raise

    def stop_query(self, query_execution_id: str) -> dict:
        """Cancel the submitted query.

        .. seealso::
            - :external+boto3:py:meth:`Athena.Client.stop_query_execution`

        :param query_execution_id: Id of submitted athena query
        :return: The raw ``stop_query_execution`` response.
        """
        self.log.info("Stopping Query with executionId - %s", query_execution_id)
        return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)

# Was this entry helpful?