Source code for airflow.contrib.hooks.aws_athena_hook

# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from time import sleep
from airflow.contrib.hooks.aws_hook import AwsHook


[docs]class AWSAthenaHook(AwsHook): """ Interact with AWS Athena to run, poll queries and return query results :param aws_conn_id: aws connection to use. :type aws_conn_id: str :param sleep_time: Time to wait between two consecutive call to check query status on athena :type sleep_time: int """
[docs] INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
[docs] FAILURE_STATES = ('FAILED', 'CANCELLED',)
[docs] SUCCESS_STATES = ('SUCCEEDED',)
def __init__(self, aws_conn_id='aws_default', sleep_time=30, *args, **kwargs): super(AWSAthenaHook, self).__init__(aws_conn_id, *args, **kwargs) self.sleep_time = sleep_time self.conn = None
[docs] def get_conn(self): """ check if aws conn exists already or create one and return it :return: boto3 session """ if not self.conn: self.conn = self.get_client_type('athena') return self.conn
[docs] def run_query(self, query, query_context, result_configuration, client_request_token=None): """ Run Presto query on athena with provided config and return submitted query_execution_id :param query: Presto query to run :type query: str :param query_context: Context in which query need to be run :type query_context: dict :param result_configuration: Dict with path to store results in and config related to encryption :type result_configuration: dict :param client_request_token: Unique token created by user to avoid multiple executions of same query :type client_request_token: str :return: str """ response = self.get_conn().start_query_execution(QueryString=query, ClientRequestToken=client_request_token, QueryExecutionContext=query_context, ResultConfiguration=result_configuration) query_execution_id = response['QueryExecutionId'] return query_execution_id
[docs] def check_query_status(self, query_execution_id): """ Fetch the status of submitted athena query. Returns None or one of valid query states. :param query_execution_id: Id of submitted athena query :type query_execution_id: str :return: str """ response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id) state = None try: state = response['QueryExecution']['Status']['State'] except Exception as ex: self.log.error('Exception while getting query state', ex) finally: return state
[docs] def get_state_change_reason(self, query_execution_id): """ Fetch the reason for a state change (e.g. error message). Returns None or reason string. :param query_execution_id: Id of submitted athena query :type query_execution_id: str :return: str """ response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id) reason = None try: reason = response['QueryExecution']['Status']['StateChangeReason'] except Exception as ex: self.log.error('Exception while getting query state change reason', ex) finally: return reason
[docs] def get_query_results(self, query_execution_id): """ Fetch submitted athena query results. returns none if query is in intermediate state or failed/cancelled state else dict of query output :param query_execution_id: Id of submitted athena query :type query_execution_id: str :return: dict """ query_state = self.check_query_status(query_execution_id) if query_state is None: self.log.error('Invalid Query state') return None elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES: self.log.error('Query is in {state} state. Cannot fetch results'.format(state=query_state)) return None return self.get_conn().get_query_results(QueryExecutionId=query_execution_id)
[docs] def poll_query_status(self, query_execution_id, max_tries=None): """ Poll the status of submitted athena query until query state reaches final state. Returns one of the final states :param query_execution_id: Id of submitted athena query :type query_execution_id: str :param max_tries: Number of times to poll for query state before function exits :type max_tries: int :return: str """ try_number = 1 final_query_state = None # Query state when query reaches final state or max_tries reached while True: query_state = self.check_query_status(query_execution_id) if query_state is None: self.log.info('Trial {try_number}: Invalid query state. Retrying again'.format( try_number=try_number)) elif query_state in self.INTERMEDIATE_STATES: self.log.info('Trial {try_number}: Query is still in an intermediate state - {state}' .format(try_number=try_number, state=query_state)) else: self.log.info('Trial {try_number}: Query execution completed. Final state is {state}' .format(try_number=try_number, state=query_state)) final_query_state = query_state break if max_tries and try_number >= max_tries: # Break loop if max_tries reached final_query_state = query_state break try_number += 1 sleep(self.sleep_time) return final_query_state
[docs] def stop_query(self, query_execution_id): """ Cancel the submitted athena query :param query_execution_id: Id of submitted athena query :type query_execution_id: str :return: dict """ return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)