# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
from math import pow
from random import randint
from time import sleep

import botocore.exceptions

from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils import apply_defaults


class AWSBatchOperator(BaseOperator):
    """
    Execute a job on AWS Batch Service
    .. warning: the queue parameter was renamed to job_queue to segregate the
                internal CeleryExecutor queue from the AWS Batch internal queue.
    :param job_name: the name for the job that will run on AWS Batch (templated)
    :type job_name: str
    :param job_definition: the job definition name on AWS Batch
    :type job_definition: str
    :param job_queue: the queue name on AWS Batch
    :type job_queue: str
    :param overrides: the same parameter that boto3 will receive on
        containerOverrides (templated)
        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
    :type overrides: dict
    :param array_properties: the same parameter that boto3 will receive on
        arrayProperties
        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
    :type array_properties: dict
    :param parameters: the same parameter that boto3 will receive on
        parameters (templated)
        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
    :type parameters: dict
    :param max_retries: exponential backoff retries while waiter is not
        merged, 4200 = 48 hours
    :type max_retries: int
    :param status_retries: number of retries to get job description (status), 10
    :type status_retries: int
    :param aws_conn_id: connection id of AWS credentials / region name. If None,
        credential boto3 strategy will be used
        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
    :type aws_conn_id: str
    :param region_name: region name to use in AWS Hook.
        Override the region_name in connection (if provided)
    :type region_name: str
    """
    # Default retry limits documented above; referenced by the __init__
    # signature below.
    MAX_RETRIES = 4200
    STATUS_RETRIES = 10

    template_fields = (
        "job_name",
        "overrides",
        "parameters",
    )

    @apply_defaults
    def __init__(
        self,
        job_name,
        job_definition,
        job_queue,
        overrides,
        array_properties=None,
        parameters=None,
        max_retries=MAX_RETRIES,
        status_retries=STATUS_RETRIES,
        aws_conn_id=None,
        region_name=None,
        **kwargs
    ):
        super(AWSBatchOperator, self).__init__(**kwargs)
        self.job_name = job_name
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.job_definition = job_definition
        self.job_queue = job_queue
        self.overrides = overrides
        self.array_properties = array_properties or {}
        self.parameters = parameters
        self.max_retries = max_retries
        self.status_retries = status_retries
        self.jobId = None  # pylint: disable=invalid-name
        self.jobName = None  # pylint: disable=invalid-name
        self.client = None

        self.hook = self.get_hook()

    def execute(self, context):
        self.log.info(
            "Running AWS Batch Job - Job definition: %s - on queue %s",
            self.job_definition,
            self.job_queue,
        )
        self.log.info("AWSBatchOperator overrides: %s", self.overrides)
        self.client = self.hook.get_client_type("batch", region_name=self.region_name)
        try:
            response = self.client.submit_job(
                jobName=self.job_name,
                jobQueue=self.job_queue,
                jobDefinition=self.job_definition,
                arrayProperties=self.array_properties,
                parameters=self.parameters,
                containerOverrides=self.overrides,
            )
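            # A SubmitJob response includes at least "jobName" and "jobId";
            # both are stored so the status polling below and on_kill() can
            # reference this job.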
            self.log.info("AWS Batch Job started: %s", response)
            self.jobId = response["jobId"]
            self.jobName = response["jobName"]
            self._wait_for_task_ended()
            self._check_success_task()
            self.log.info("AWS Batch Job has been successfully executed: %s", response)
        except Exception as e:
            self.log.error("AWS Batch Job has failed")
            raise AirflowException(e)

    def _wait_for_task_ended(self):
        """
        Try to use a waiter from the below pull request
            * https://github.com/boto/botocore/pull/1307
        If the waiter is not available apply a exponential backoff
            * docs.aws.amazon.com/general/latest/gr/api-retries.html
        """
        try:
            waiter = self.client.get_waiter("job_execution_complete")
            waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
            waiter.wait(jobs=[self.jobId])
        except ValueError:
            self._poll_for_task_ended()

    def _poll_for_task_ended(self):
        """
        Poll for job status
            * docs.aws.amazon.com/general/latest/gr/api-retries.html
        """
        # Allow a batch job some time to spin up.  A random interval
        # decreases the chances of exceeding an AWS API throttle
        # limit when there are many concurrent tasks.
        pause = randint(5, 30)
        tries = 0
        while tries < self.max_retries:
            tries += 1
            self.log.info(
                "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
                self.jobId,
                tries,
                self.max_retries,
                pause,
            )
            sleep(pause)
            response = self._get_job_description()
            jobs = response.get("jobs")
            status = jobs[-1]["status"]  # check last job status
            self.log.info("AWS Batch job (%s) status: %s", self.jobId, status)
            # status options: 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
            if status in ["SUCCEEDED", "FAILED"]:
                break
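            # The pause below grows quadratically with the retry count, e.g.:
            #   tries=1  -> 1 + (1 * 0.3)^2  =  1.09 seconds
            #   tries=10 -> 1 + (10 * 0.3)^2 = 10.00 seconds
            #   tries=20 -> 1 + (20 * 0.3)^2 = 37.00 seconds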
            pause = 1 + pow(tries * 0.3, 2)

    def _get_job_description(self):
        """
        Get job description
            * https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html
            * https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
        """
        tries = 0
        while tries < self.status_retries:
            tries += 1
            try:
                response = self.client.describe_jobs(jobs=[self.jobId])  # type: ignore
                if response and response.get("jobs"):
                    return response
                else:
                    self.log.error("Job description has no jobs (%s): %s", self.jobId, response)
            except botocore.exceptions.ClientError as err:
                response = err.response
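                # botocore ClientError responses carry details under an "Error"
                # key, shaped roughly like:
                #   {"Error": {"Code": "TooManyRequestsException",
                #              "Message": "..."}}
                # Throttling errors pause for a random interval before the next
                # attempt; other client errors are retried immediately until
                # status_retries is exhausted.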
                self.log.error("Job description error (%s): %s", self.jobId, response)
                if tries < self.status_retries:
                    error = response.get("Error", {})
                    if error.get("Code") == "TooManyRequestsException":
                        pause = randint(1, 10)  # avoid excess requests with a random pause
                        self.log.info(
                            "AWS Batch job (%s) status retry (%d of %d) in the next %.2f seconds",
                            self.jobId,
                            tries,
                            self.status_retries,
                            pause,
                        )
                        sleep(pause)
                        continue
        msg = "Failed to get job description ({})".format(self.jobId)
        raise AirflowException(msg)

    def _check_success_task(self):
        """
        Check the final status of the batch job; the job status options are:
        'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
        """
        response = self._get_job_description()
        jobs = response.get("jobs")
        matching_jobs = [job for job in jobs if job["jobId"] == self.jobId]
        if not matching_jobs:
            raise AirflowException(
                "Job ({}) has no job description {}".format(self.jobId, response)
            )
        job = matching_jobs[0]
        self.log.info("AWS Batch stopped, check status: %s", job)
        job_status = job["status"]
        if job_status == "FAILED":
            reason = job["statusReason"]
            raise AirflowException("Job ({}) failed with status {}".format(self.jobId, reason))
        elif job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]:
            raise AirflowException(
                "Job ({}) is still pending {}".format(self.jobId, job_status) 
            )

    def get_hook(self):
        return AwsHook(aws_conn_id=self.aws_conn_id)

    def on_kill(self):
        response = self.client.terminate_job(jobId=self.jobId, reason="Task killed by the user")
        self.log.info(response)