# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from time import sleep
from airflow.contrib.hooks.aws_hook import AwsHook
class AWSAthenaHook(AwsHook):
    """
    Interact with AWS Athena to run, poll queries and return query results.

    :param aws_conn_id: aws connection to use.
    :type aws_conn_id: str
    :param sleep_time: Time (in seconds) to wait between two consecutive calls
        that check the query status on athena while polling
    :type sleep_time: int
    """

    # States in which the query has not finished yet; polling continues while
    # the query reports one of these.
    INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
    # Terminal states for which no results can be fetched.
    FAILURE_STATES = ('FAILED', 'CANCELLED',)
    # Terminal state for which results are available.
    SUCCESS_STATES = ('SUCCEEDED',)

    def __init__(self, aws_conn_id='aws_default', sleep_time=30, *args, **kwargs):
        super(AWSAthenaHook, self).__init__(aws_conn_id, *args, **kwargs)
        self.sleep_time = sleep_time
        # Athena client, created lazily by get_conn() and then reused.
        self.conn = None

    def get_conn(self):
        """
        Return the cached athena client, creating it on first use.

        :return: the client object returned by ``get_client_type('athena')``
        """
        if not self.conn:
            self.conn = self.get_client_type('athena')
        return self.conn

    def run_query(self, query, query_context, result_configuration, client_request_token=None):
        """
        Run Presto query on athena with provided config and return submitted query_execution_id.

        :param query: Presto query to run
        :type query: str
        :param query_context: Context in which query need to be run
        :type query_context: dict
        :param result_configuration: Dict with path to store results in and config related to encryption
        :type result_configuration: dict
        :param client_request_token: Unique token created by user to avoid multiple executions of same
            query. When left as None the parameter is not sent at all.
        :type client_request_token: str
        :return: id of the submitted query execution
        :rtype: str
        """
        params = {
            'QueryString': query,
            'QueryExecutionContext': query_context,
            'ResultConfiguration': result_configuration,
        }
        # Only forward the token when the caller supplied one: boto3 parameter
        # validation rejects None for a string parameter, so passing the
        # default through unchanged would make every token-less call fail.
        if client_request_token is not None:
            params['ClientRequestToken'] = client_request_token
        response = self.get_conn().start_query_execution(**params)
        return response['QueryExecutionId']

    def check_query_status(self, query_execution_id):
        """
        Fetch the status of submitted athena query. Returns None or one of valid query states.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :return: the query state, or None if it could not be read from the response
        :rtype: str
        """
        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
        try:
            return response['QueryExecution']['Status']['State']
        except Exception as ex:
            # Best effort: a response without the expected keys is reported as
            # an unknown (None) state instead of raising to the caller.
            self.log.error('Exception while getting query state: %s', ex)
            return None

    def get_state_change_reason(self, query_execution_id):
        """
        Fetch the reason for a state change (e.g. error message). Returns None or reason string.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :return: the state change reason, or None if it could not be read from the response
        :rtype: str
        """
        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
        try:
            return response['QueryExecution']['Status']['StateChangeReason']
        except Exception as ex:
            # Best effort: a response without the expected keys yields None
            # instead of raising to the caller.
            self.log.error('Exception while getting query state change reason: %s', ex)
            return None

    def get_query_results(self, query_execution_id):
        """
        Fetch submitted athena query results.

        Returns None if the query state is unknown, intermediate or
        failed/cancelled; otherwise returns the dict of query output.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :return: query results, or None when they cannot be fetched
        :rtype: dict
        """
        query_state = self.check_query_status(query_execution_id)
        if query_state is None:
            self.log.error('Invalid Query state')
            return None
        if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
            self.log.error('Query is in %s state. Cannot fetch results', query_state)
            return None
        return self.get_conn().get_query_results(QueryExecutionId=query_execution_id)

    def poll_query_status(self, query_execution_id, max_tries=None):
        """
        Poll the status of submitted athena query until query state reaches final state.

        Sleeps ``self.sleep_time`` seconds between two consecutive checks.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :param max_tries: Number of times to poll for query state before function exits.
            A falsy value (None or 0) means poll until a final state is reached.
        :type max_tries: int
        :return: the final state, or the last observed state (possibly None or an
            intermediate state) when max_tries is exhausted first
        :rtype: str
        """
        try_number = 1
        while True:
            query_state = self.check_query_status(query_execution_id)
            if query_state is None:
                self.log.info('Trial %s: Invalid query state. Retrying again', try_number)
            elif query_state in self.INTERMEDIATE_STATES:
                self.log.info('Trial %s: Query is still in an intermediate state - %s',
                              try_number, query_state)
            else:
                self.log.info('Trial %s: Query execution completed. Final state is %s',
                              try_number, query_state)
                return query_state
            # Give up once the allowed number of attempts is used, handing back
            # whatever state was seen last.
            if max_tries and try_number >= max_tries:
                return query_state
            try_number += 1
            sleep(self.sleep_time)

    def stop_query(self, query_execution_id):
        """
        Cancel the submitted athena query.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :return: response of the ``stop_query_execution`` call
        :rtype: dict
        """
        return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)