Source code for airflow.providers.amazon.aws.operators.emr_containers
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Optional
from uuid import uuid4
from airflow.exceptions import AirflowException
try:
from functools import cached_property
except ImportError:
from cached_property import cached_property
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook
class EMRContainerOperator(BaseOperator):
    """
    An operator that submits jobs to EMR on EKS virtual clusters.

    :param name: The name of the job run.
    :type name: str
    :param virtual_cluster_id: The EMR on EKS virtual cluster ID
    :type virtual_cluster_id: str
    :param execution_role_arn: The IAM role ARN associated with the job run.
    :type execution_role_arn: str
    :param release_label: The Amazon EMR release version to use for the job run.
    :type release_label: str
    :param job_driver: Job configuration details, e.g. the Spark job parameters.
    :type job_driver: dict
    :param configuration_overrides: The configuration overrides for the job run,
        specifically either application configuration or monitoring configuration.
    :type configuration_overrides: dict
    :param client_request_token: The client idempotency token of the job run request.
        Use this if you want to specify a unique ID to prevent two jobs from getting started.
        If no token is provided, a UUIDv4 token will be generated for you.
    :type client_request_token: str
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :type aws_conn_id: str
    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check query
        status on EMR. Also used while waiting for a cancelled job in ``on_kill``.
    :type poll_interval: int
    :param max_tries: Maximum number of times to wait for the job run to finish.
        Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
        Also used while waiting for a cancelled job in ``on_kill``.
    :type max_tries: int
    """

    template_fields = ["name", "virtual_cluster_id", "execution_role_arn", "release_label", "job_driver"]

    def __init__(
        self,
        *,
        name: str,
        virtual_cluster_id: str,
        execution_role_arn: str,
        release_label: str,
        job_driver: dict,
        configuration_overrides: Optional[dict] = None,
        client_request_token: Optional[str] = None,
        aws_conn_id: str = "aws_default",
        poll_interval: int = 30,
        max_tries: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.name = name
        self.virtual_cluster_id = virtual_cluster_id
        self.execution_role_arn = execution_role_arn
        self.release_label = release_label
        self.job_driver = job_driver
        self.configuration_overrides = configuration_overrides or {}
        self.aws_conn_id = aws_conn_id
        # A fresh idempotency token is generated when the caller does not supply one.
        self.client_request_token = client_request_token or str(uuid4())
        self.poll_interval = poll_interval
        self.max_tries = max_tries
        # Set by ``execute`` once the job has been submitted; ``on_kill`` only
        # attempts a cancel when this is populated.
        self.job_id: Optional[str] = None

    @cached_property
    def hook(self) -> EMRContainerHook:
        """Create and return an EMRContainerHook bound to this operator's virtual cluster."""
        return EMRContainerHook(
            self.aws_conn_id,
            virtual_cluster_id=self.virtual_cluster_id,
        )

    def execute(self, context: dict) -> Optional[str]:
        """
        Submit the job run to EMR Containers and wait for it to finish.

        :param context: The Airflow task context (not used directly).
        :return: The ID of the submitted job run, as returned by the hook's ``submit_job``.
        :raises AirflowException: If the job ends in one of ``EMRContainerHook.FAILURE_STATES``,
            or if polling returns no state / an intermediate state (polling gave up before
            the job finished).
        """
        self.job_id = self.hook.submit_job(
            self.name,
            self.execution_role_arn,
            self.release_label,
            self.job_driver,
            self.configuration_overrides,
            self.client_request_token,
        )
        query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)

        if query_status in EMRContainerHook.FAILURE_STATES:
            error_message = self.hook.get_job_failure_reason(self.job_id)
            raise AirflowException(
                f"EMR Containers job failed. Final state is {query_status}. "
                f"query_execution_id is {self.job_id}. Error: {error_message}"
            )
        # An empty or still-running state means the poll loop stopped before the
        # job reached a terminal state (e.g. ``max_tries`` was exhausted).
        if not query_status or query_status in EMRContainerHook.INTERMEDIATE_STATES:
            raise AirflowException(
                f"Final state of EMR Containers job is {query_status}. "
                f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
            )
        return self.job_id

    def on_kill(self) -> None:
        """
        Cancel the submitted job run, if any, and wait for it to reach a final state.

        Does nothing if no job has been submitted yet. If the cancel response does not
        report HTTP 200, an error is logged and no waiting is done.
        """
        if not self.job_id:
            return

        self.log.info("Stopping job run with jobId - %s", self.job_id)
        response = self.hook.stop_query(self.job_id)

        http_status_code = None
        try:
            http_status_code = response["ResponseMetadata"]["HTTPStatusCode"]
        except (KeyError, TypeError) as ex:
            # Missing keys or a non-mapping response: treat the cancel as unconfirmed.
            self.log.error("Exception while cancelling query: %s", ex)

        if http_status_code != 200:
            self.log.error("Unable to request query cancel on EMR. Exiting")
            return

        self.log.info(
            "Polling EMR for query with id %s to reach final state",
            self.job_id,
        )
        # Honour the operator's configured polling settings, in the same positional
        # order ``execute`` uses, instead of the hook's defaults.
        self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)