#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import warnings
from time import sleep
from typing import Any
from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import prune_dict
class EmrHook(AwsBaseHook):
"""
Interact with Amazon Elastic MapReduce Service (EMR).
Provide thick wrapper around :external+boto3:py:class:`boto3.client("emr") <EMR.Client>`.
:param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`.
This attribute is only necessary when using
the :meth:`airflow.providers.amazon.aws.hooks.emr.EmrHook.create_job_flow`.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
conn_name_attr = "emr_conn_id"
default_conn_name = "emr_default"
conn_type = "emr"
hook_name = "Amazon Elastic MapReduce"
def __init__(self, emr_conn_id: str | None = default_conn_name, *args, **kwargs) -> None:
self.emr_conn_id = emr_conn_id
kwargs["client_type"] = "emr"
super().__init__(*args, **kwargs)
def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: list[str]) -> str | None:
"""
Fetch the id of the EMR cluster that matches the given name and cluster states.

The id is returned only when exactly one matching cluster is found.
.. seealso::
- :external+boto3:py:meth:`EMR.Client.list_clusters`
:param emr_cluster_name: Name of a cluster to find
:param cluster_states: State(s) of cluster to find
:return: id of the EMR cluster
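
A minimal usage sketch (the cluster name and states below are hypothetical)::

    hook = EmrHook(aws_conn_id="aws_default")
    cluster_id = hook.get_cluster_id_by_name(
        emr_cluster_name="my-analytics-cluster",
        cluster_states=["RUNNING", "WAITING"],
    )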
"""
response = self.get_conn().list_clusters(ClusterStates=cluster_states)
matching_clusters = list(
filter(lambda cluster: cluster["Name"] == emr_cluster_name, response["Clusters"])
)
if len(matching_clusters) == 1:
cluster_id = matching_clusters[0]["Id"]
self.log.info("Found cluster name = %s id = %s", emr_cluster_name, cluster_id)
return cluster_id
elif len(matching_clusters) > 1:
raise AirflowException(f"More than one cluster found for name {emr_cluster_name}")
else:
self.log.info("No cluster found for name %s", emr_cluster_name)
return None
def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]:
"""
Create and start running a new cluster (job flow).
.. seealso::
- :external+boto3:py:meth:`EMR.Client.run_job_flow`
This method uses ``EmrHook.emr_conn_id`` to obtain the initial Amazon EMR cluster configuration.
If ``EmrHook.emr_conn_id`` is empty or the connection does not exist, an empty initial
configuration is used.

:param job_flow_overrides: Used to overwrite the parameters in the initial Amazon EMR cluster
    configuration. The resulting configuration is passed to
    :external+boto3:py:meth:`EMR.Client.run_job_flow`.
.. seealso::
- :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`
- :external+boto3:py:meth:`EMR.Client.run_job_flow`
- `API RunJobFlow <https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html>`_
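
A minimal usage sketch (the connection id and overrides below are hypothetical; keys
omitted from the overrides fall back to the values stored in the connection extras)::

    hook = EmrHook(emr_conn_id="emr_default")
    response = hook.create_job_flow(
        job_flow_overrides={"Name": "my-cluster", "ReleaseLabel": "emr-6.10.0"}
    )
    job_flow_id = response["JobFlowId"]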
"""
config = {}
if self.emr_conn_id:
try:
emr_conn = self.get_connection(self.emr_conn_id)
except AirflowNotFoundException:
warnings.warn(
f"Unable to find {self.hook_name} Connection ID {self.emr_conn_id!r}, "
"using an empty initial configuration. If you want to get rid of this warning "
"message please provide a valid `emr_conn_id` or set it to None.",
UserWarning,
stacklevel=2,
)
else:
if emr_conn.conn_type and emr_conn.conn_type != self.conn_type:
warnings.warn(
f"{self.hook_name} Connection expected connection type {self.conn_type!r}, "
f"Connection {self.emr_conn_id!r} has conn_type={emr_conn.conn_type!r}. "
f"This connection might not work correctly.",
UserWarning,
stacklevel=2,
)
config = emr_conn.extra_dejson.copy()
config.update(job_flow_overrides)
response = self.get_conn().run_job_flow(**config)
return response
def add_job_flow_steps(
self,
job_flow_id: str,
steps: list[dict] | str | None = None,
wait_for_completion: bool = False,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
execution_role_arn: str | None = None,
) -> list[str]:
"""
Add new steps to a running cluster.
.. seealso::
- :external+boto3:py:meth:`EMR.Client.add_job_flow_steps`
:param job_flow_id: The id of the job flow to which the steps are being added
:param steps: A list of the steps to be executed by the job flow
:param wait_for_completion: If True, wait for the steps to be completed. Default is False
:param waiter_delay: The amount of time in seconds to wait between attempts. Default is 5
:param waiter_max_attempts: The maximum number of attempts to be made. Default is 100
:param execution_role_arn: The ARN of the runtime role for a step on the cluster.
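
A minimal usage sketch (the job flow id and step definition below are hypothetical)::

    hook = EmrHook(aws_conn_id="aws_default")
    step = {
        "Name": "calculate_pi",
        "ActionOnFailure": "CONTINUE",
        "HadoopJarStep": {
            "Jar": "command-runner.jar",
            "Args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"],
        },
    }
    step_ids = hook.add_job_flow_steps(
        job_flow_id="j-XXXXXXXXXXXX",
        steps=[step],
        wait_for_completion=True,
    )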
"""
config = {}
if execution_role_arn:
config["ExecutionRoleArn"] = execution_role_arn
response = self.get_conn().add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps, **config)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Adding steps failed: {response}")
self.log.info("Steps %s added to JobFlow", response["StepIds"])
if wait_for_completion:
waiter = self.get_conn().get_waiter("step_complete")
for step_id in response["StepIds"]:
waiter.wait(
ClusterId=job_flow_id,
StepId=step_id,
WaiterConfig=prune_dict(
{
"Delay": waiter_delay,
"MaxAttempts": waiter_max_attempts,
}
),
)
return response["StepIds"]
def test_connection(self):
"""
Return failed state for test Amazon Elastic MapReduce Connection (untestable).
We need to overwrite this method because this hook is based on
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`,
otherwise it will try to test connection to AWS STS by using the default boto3 credential strategy.
"""
msg = (
f"{self.hook_name!r} Airflow Connection cannot be tested, by design it stores "
f"only key/value pairs and does not make a connection to an external resource."
)
return False, msg
@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom UI field behaviour for Amazon Elastic MapReduce Connection."""
return {
"hidden_fields": ["host", "schema", "port", "login", "password"],
"relabeling": {
"extra": "Run Job Flow Configuration",
},
"placeholders": {
"extra": json.dumps(
{
"Name": "MyClusterName",
"ReleaseLabel": "emr-5.36.0",
"Applications": [{"Name": "Spark"}],
"Instances": {
"InstanceGroups": [
{
"Name": "Primary node",
"Market": "SPOT",
"InstanceRole": "MASTER",
"InstanceType": "m5.large",
"InstanceCount": 1,
},
],
"KeepJobFlowAliveWhenNoSteps": False,
"TerminationProtected": False,
},
"StepConcurrencyLevel": 2,
},
indent=2,
),
},
}
class EmrServerlessHook(AwsBaseHook):
"""
Interact with Amazon EMR Serverless.
Provide thin wrapper around :py:class:`boto3.client("emr-serverless") <EMRServerless.Client>`.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
JOB_INTERMEDIATE_STATES = {"PENDING", "RUNNING", "SCHEDULED", "SUBMITTED"}
JOB_FAILURE_STATES = {"FAILED", "CANCELLING", "CANCELLED"}
JOB_SUCCESS_STATES = {"SUCCESS"}
JOB_TERMINAL_STATES = JOB_SUCCESS_STATES.union(JOB_FAILURE_STATES)
APPLICATION_FAILURE_STATES = {"STOPPED", "TERMINATED"}
APPLICATION_SUCCESS_STATES = {"CREATED", "STARTED"}
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "emr-serverless"
super().__init__(*args, **kwargs)
def cancel_running_jobs(self, application_id: str, waiter_config: dict | None = None):
"""
List all jobs in an intermediate state and cancel them.
Then wait for those jobs to reach a terminal state.
Note: if new jobs are triggered while this operation is ongoing,
it's going to time out and return an error.
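
A minimal usage sketch (the application id below is a placeholder)::

    hook = EmrServerlessHook(aws_conn_id="aws_default")
    hook.cancel_running_jobs(
        application_id="00abcdefghij0000",
        waiter_config={"Delay": 30, "MaxAttempts": 20},
    )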
"""
paginator = self.conn.get_paginator("list_job_runs")
results_per_response = 50
iterator = paginator.paginate(
applicationId=application_id,
states=list(self.JOB_INTERMEDIATE_STATES),
PaginationConfig={
"PageSize": results_per_response,
},
)
count = 0
for r in iterator:
job_ids = [jr["id"] for jr in r["jobRuns"]]
count += len(job_ids)
if len(job_ids) > 0:
self.log.info(
"Cancelling %s pending job(s) for the application %s so that it can be stopped",
len(job_ids),
application_id,
)
for job_id in job_ids:
self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id)
if count > 0:
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
self.get_waiter("no_job_running").wait(
applicationId=application_id,
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
WaiterConfig=waiter_config or {},
)
class EmrContainerHook(AwsBaseHook):
"""
Interact with Amazon EMR Containers (Amazon EMR on EKS).
Provide thick wrapper around :py:class:`boto3.client("emr-containers") <EMRContainers.Client>`.
:param virtual_cluster_id: Cluster ID of the EMR on EKS virtual cluster
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
.. seealso::
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
INTERMEDIATE_STATES = (
"PENDING",
"SUBMITTED",
"RUNNING",
)
FAILURE_STATES = (
"FAILED",
"CANCELLED",
"CANCEL_PENDING",
)
SUCCESS_STATES = ("COMPLETED",)
TERMINAL_STATES = (
"COMPLETED",
"FAILED",
"CANCELLED",
"CANCEL_PENDING",
)
def __init__(self, *args: Any, virtual_cluster_id: str | None = None, **kwargs: Any) -> None:
super().__init__(client_type="emr-containers", *args, **kwargs) # type: ignore
self.virtual_cluster_id = virtual_cluster_id
def create_emr_on_eks_cluster(
self,
virtual_cluster_name: str,
eks_cluster_name: str,
eks_namespace: str,
tags: dict | None = None,
) -> str:
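"""
Create an EMR on EKS virtual cluster and return its cluster id.

.. seealso::
    - :external+boto3:py:meth:`EMRContainers.Client.create_virtual_cluster`

:param virtual_cluster_name: The name of the virtual cluster to create.
:param eks_cluster_name: The name of the EKS cluster that backs the virtual cluster.
:param eks_namespace: The EKS namespace the virtual cluster runs in.
:param tags: The tags to assign to the created cluster. Defaults to None.
:return: The ID of the created virtual cluster.
"""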
response = self.conn.create_virtual_cluster(
name=virtual_cluster_name,
containerProvider={
"id": eks_cluster_name,
"type": "EKS",
"info": {"eksInfo": {"namespace": eks_namespace}},
},
tags=tags or {},
)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Create EMR EKS Cluster failed: {response}")
else:
self.log.info(
"Create EMR EKS Cluster success - virtual cluster id %s",
response["id"],
)
return response["id"]
def submit_job(
self,
name: str,
execution_role_arn: str,
release_label: str,
job_driver: dict,
configuration_overrides: dict | None = None,
client_request_token: str | None = None,
tags: dict | None = None,
) -> str:
"""
Submit a job to the EMR Containers API and return the job ID.
A job run is a unit of work, such as a Spark jar, PySpark script,
or SparkSQL query, that you submit to Amazon EMR on EKS.
.. seealso::
- :external+boto3:py:meth:`EMRContainers.Client.start_job_run`
:param name: The name of the job run.
:param execution_role_arn: The IAM role ARN associated with the job run.
:param release_label: The Amazon EMR release version to use for the job run.
:param job_driver: Job configuration details, e.g. the Spark job parameters.
:param configuration_overrides: The configuration overrides for the job run,
specifically either application configuration or monitoring configuration.
:param client_request_token: The client idempotency token of the job run request.
Use this if you want to specify a unique ID to prevent two jobs from getting started.
:param tags: The tags assigned to job runs.
:return: The ID of the job run request.
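
A minimal usage sketch (the cluster id, role ARN, release label, and entry point below
are hypothetical)::

    hook = EmrContainerHook(virtual_cluster_id="vc1234567890abcdefghijklm")
    job_id = hook.submit_job(
        name="pi-spark-job",
        execution_role_arn="arn:aws:iam::123456789012:role/emr-eks-job-role",
        release_label="emr-6.10.0-latest",
        job_driver={
            "sparkSubmitJobDriver": {"entryPoint": "s3://my-bucket/scripts/pi.py"}
        },
    )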
"""
params = {
"name": name,
"virtualClusterId": self.virtual_cluster_id,
"executionRoleArn": execution_role_arn,
"releaseLabel": release_label,
"jobDriver": job_driver,
"configurationOverrides": configuration_overrides or {},
"tags": tags or {},
}
if client_request_token:
params["clientToken"] = client_request_token
response = self.conn.start_job_run(**params)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Start Job Run failed: {response}")
else:
self.log.info(
"Start Job Run success - Job Id %s and virtual cluster id %s",
response["id"],
response["virtualClusterId"],
)
return response["id"]
def get_job_failure_reason(self, job_id: str) -> str | None:
"""
Fetch the reason for a job failure (e.g. error message). Returns None or reason string.
.. seealso::
- :external+boto3:py:meth:`EMRContainers.Client.describe_job_run`
:param job_id: The ID of the job run request.
"""
reason = None # We absorb any errors if we can't retrieve the job status
try:
response = self.conn.describe_job_run(
virtualClusterId=self.virtual_cluster_id,
id=job_id,
)
failure_reason = response["jobRun"]["failureReason"]
state_details = response["jobRun"]["stateDetails"]
reason = f"{failure_reason} - {state_details}"
except KeyError:
self.log.error("Could not get status of the EMR on EKS job")
except ClientError as ex:
self.log.error("AWS request failed, check logs for more info: %s", ex)
return reason
def check_query_status(self, job_id: str) -> str | None:
"""
Fetch the status of submitted job run. Returns None or one of valid query states.
.. seealso::
- :external+boto3:py:meth:`EMRContainers.Client.describe_job_run`
:param job_id: The ID of the job run request.
"""
try:
response = self.conn.describe_job_run(
virtualClusterId=self.virtual_cluster_id,
id=job_id,
)
return response["jobRun"]["state"]
except self.conn.exceptions.ResourceNotFoundException:
# If the job is not found, we raise an exception as something fatal has happened.
raise AirflowException(f"Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}")
except ClientError as ex:
# If we receive a generic ClientError, we swallow the exception so that the
# caller can treat the status as unknown and retry on the next poll.
self.log.error("AWS request failed, check logs for more info: %s", ex)
return None
def poll_query_status(
self,
job_id: str,
poll_interval: int = 30,
max_polling_attempts: int | None = None,
) -> str | None:
"""
Poll the status of submitted job run until query state reaches final state.
Returns one of the final states.
:param job_id: The ID of the job run request.
:param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
:param max_polling_attempts: Number of times to poll for query state before function exits
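
A minimal usage sketch (the ids below are placeholders)::

    hook = EmrContainerHook(virtual_cluster_id="vc1234567890abcdefghijklm")
    final_state = hook.poll_query_status(
        job_id="00000000000000000000",
        poll_interval=60,
        max_polling_attempts=30,
    )
    if final_state not in EmrContainerHook.SUCCESS_STATES:
        raise AirflowException(f"EMR Containers job ended in state {final_state}")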
"""
try_number = 1
final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
while True:
query_state = self.check_query_status(job_id)
if query_state is None:
self.log.info("Try %s: Invalid query state. Retrying again", try_number)
elif query_state in self.TERMINAL_STATES:
self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
final_query_state = query_state
break
else:
self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
if (
max_polling_attempts and try_number >= max_polling_attempts
): # Break loop if max_polling_attempts reached
final_query_state = query_state
break
try_number += 1
sleep(poll_interval)
return final_query_state
def stop_query(self, job_id: str) -> dict:
"""
Cancel the submitted job run.
.. seealso::
- :external+boto3:py:meth:`EMRContainers.Client.cancel_job_run`
:param job_id: The ID of the job run to cancel.
"""
return self.conn.cancel_job_run(
virtualClusterId=self.virtual_cluster_id,
id=job_id,
)