# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from time import sleep
from typing import Any, Dict, Optional

from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


class EMRContainerHook(AwsBaseHook):
    """
    Interact with AWS EMR Virtual Cluster to run, poll jobs and return job status
    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.
    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    :param virtual_cluster_id: Cluster ID of the EMR on EKS virtual cluster
    :type virtual_cluster_id: str
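
    A minimal usage sketch (the connection id and cluster id below are
    hypothetical placeholders)::

        hook = EMRContainerHook(
            aws_conn_id="aws_default",
            virtual_cluster_id="abcdefghijklmnop1234567890",
        )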
    """

    INTERMEDIATE_STATES = (
        "PENDING",
        "SUBMITTED",
        "RUNNING",
    )
    FAILURE_STATES = (
        "FAILED",
        "CANCELLED",
        "CANCEL_PENDING",
    )
    SUCCESS_STATES = ("COMPLETED",)

    def __init__(self, *args: Any, virtual_cluster_id: Optional[str] = None, **kwargs: Any) -> None:
        super().__init__(client_type="emr-containers", *args, **kwargs)  # type: ignore
        self.virtual_cluster_id = virtual_cluster_id

    def submit_job(
        self,
        name: str,
        execution_role_arn: str,
        release_label: str,
        job_driver: dict,
        configuration_overrides: Optional[dict] = None,
        client_request_token: Optional[str] = None,
    ) -> str:
        """
        Submit a job to the EMR Containers API and and return the job ID.
        A job run is a unit of work, such as a Spark jar, PySpark script,
        or SparkSQL query, that you submit to Amazon EMR on EKS.
        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.start_job_run  # noqa: E501
        :param name: The name of the job run.
        :type name: str
        :param execution_role_arn: The IAM role ARN associated with the job run.
        :type execution_role_arn: str
        :param release_label: The Amazon EMR release version to use for the job run.
        :type release_label: str
        :param job_driver: Job configuration details, e.g. the Spark job parameters.
        :type job_driver: dict
        :param configuration_overrides: The configuration overrides for the job run,
            specifically either application configuration or monitoring configuration.
        :type configuration_overrides: dict
        :param client_request_token: The client idempotency token of the job run request.
            Use this if you want to specify a unique ID to prevent the same job from being started twice.
        :type client_request_token: str
        :return: Job ID
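
        A minimal usage sketch (the role ARN, bucket, and script path below are
        hypothetical placeholders; ``job_driver`` follows the boto3
        ``start_job_run`` schema)::

            job_id = hook.submit_job(
                name="example-pi-job",
                execution_role_arn="arn:aws:iam::123456789012:role/emr-eks-job-role",
                release_label="emr-6.3.0-latest",
                job_driver={
                    "sparkSubmitJobDriver": {
                        "entryPoint": "s3://mybucket/scripts/pi.py",
                        "sparkSubmitParameters": "--conf spark.executor.instances=2",
                    }
                },
            )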
        """
        params = {
            "name": name,
            "virtualClusterId": self.virtual_cluster_id,
            "executionRoleArn": execution_role_arn,
            "releaseLabel": release_label,
            "jobDriver": job_driver,
            "configurationOverrides": configuration_overrides or {},
        }
        if client_request_token:
            params["clientToken"] = client_request_token
        response = self.conn.start_job_run(**params)
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Start Job Run failed: {response}')
        else:
            self.log.info(
                "Start Job Run success - Job Id %s and virtual cluster id %s",
                response['id'],
                response['virtualClusterId'],
            )
            return response['id']

    def get_job_failure_reason(self, job_id: str) -> Optional[str]:
        """
        Fetch the reason for a job failure (e.g. error message). Returns None or reason string.
        :param job_id: Id of submitted job run
        :type job_id: str
        :return: str
        """
        # We absorb any errors if we can't retrieve the job status
        reason = None
        try:
            response = self.conn.describe_job_run(
                virtualClusterId=self.virtual_cluster_id,
                id=job_id,
            )
            reason = response['jobRun']['failureReason']
        except KeyError:
            self.log.error('Could not get the failure reason for the EMR on EKS job')
        except ClientError as ex:
            self.log.error('AWS request failed, check logs for more info: %s', ex)
        return reason

    def check_query_status(self, job_id: str) -> Optional[str]:
        """
        Fetch the status of submitted job run. Returns None or one of valid query states.
        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.describe_job_run  # noqa: E501
        :param job_id: Id of submitted job run
        :type job_id: str
        :return: str
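
        A minimal usage sketch (assumes ``job_id`` came from ``submit_job``)::

            state = hook.check_query_status(job_id)
            if state in EMRContainerHook.FAILURE_STATES:
                reason = hook.get_job_failure_reason(job_id)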
        """
        try:
            response = self.conn.describe_job_run(
                virtualClusterId=self.virtual_cluster_id,
                id=job_id,
            )
            return response["jobRun"]["state"]
        except self.conn.exceptions.ResourceNotFoundException:
            # If the job is not found, we raise an exception as something fatal has happened.
            raise AirflowException(f'Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}')
        except ClientError as ex:
            # If we receive a generic ClientError, we swallow the exception and
            # return None so that the caller (e.g. the polling loop) can retry.
            self.log.error('AWS request failed, check logs for more info: %s', ex)
            return None

    def poll_query_status(
        self, job_id: str, max_tries: Optional[int] = None, poll_interval: int = 30
    ) -> Optional[str]:
        """
        Poll the status of submitted job run until query state reaches final state.
        Returns one of the final states.
        :param job_id: Id of submitted job run
        :type job_id: str
        :param max_tries: Number of times to poll for query state before function exits
        :type max_tries: int
        :param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
        :type poll_interval: int
        :return: str
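
        A minimal usage sketch (the try count and interval are arbitrary)::

            final_state = hook.poll_query_status(job_id, max_tries=10, poll_interval=60)
            if final_state not in EMRContainerHook.SUCCESS_STATES:
                raise AirflowException(f"EMR job {job_id} finished in state {final_state}")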
        """
        try_number = 1
        final_query_state = None  # Query state when query reaches final state or max_tries reached
        # TODO: Make this logic a little bit more robust.
        # Currently this polls until the state is *not* one of the INTERMEDIATE_STATES
        # While that should work in most cases...it might not. :)
        while True:
            query_state = self.check_query_status(job_id)
            if query_state is None:
                self.log.info("Try %s: Invalid query state. Retrying again", try_number)
            elif query_state in self.INTERMEDIATE_STATES:
                self.log.info("Try %s: Query is still in an intermediate state - %s", try_number, query_state)
            else:
                self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
                final_query_state = query_state
                break
            if max_tries and try_number >= max_tries:  # Break loop if max_tries reached
                final_query_state = query_state
                break
            try_number += 1
            sleep(poll_interval)
        return final_query_state

    def stop_query(self, job_id: str) -> Dict:
        """
        Cancel the submitted job_run
        :param job_id: Id of submitted job_run
        :type job_id: str
        :return: dict
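
        A minimal usage sketch (per the boto3 docs, ``cancel_job_run`` returns
        the cancelled job id and the virtual cluster id)::

            response = hook.stop_query(job_id)
            cancelled_job_id = response["id"]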
        """
        return self.conn.cancel_job_run(
            virtualClusterId=self.virtual_cluster_id,
            id=job_id,
        )