Source code for airflow.providers.amazon.aws.hooks.emr
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import warnings
from time import sleep
from typing import Any, Callable

from botocore.exceptions import ClientError

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.waiter import get_state, waiter


class EmrHook(AwsBaseHook):
    """
    Interact with Amazon Elastic MapReduce Service.

    :param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`.
        This attribute is only necessary when using
        the :meth:`~airflow.providers.amazon.aws.hooks.emr.EmrHook.create_job_flow` method.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: list[str]) -> str | None:
        """
        Fetch id of EMR cluster with given name and (optional) states.

        Will return only if single id is found.

        :param emr_cluster_name: Name of a cluster to find
        :param cluster_states: State(s) of cluster to find
        :return: id of the EMR cluster
        """
        response = self.get_conn().list_clusters(ClusterStates=cluster_states)

        matching_clusters = list(
            filter(lambda cluster: cluster["Name"] == emr_cluster_name, response["Clusters"])
        )

        if len(matching_clusters) == 1:
            cluster_id = matching_clusters[0]["Id"]
            self.log.info("Found cluster name = %s id = %s", emr_cluster_name, cluster_id)
            return cluster_id
        elif len(matching_clusters) > 1:
            raise AirflowException(f"More than one cluster found for name {emr_cluster_name}")
        else:
            self.log.info("No cluster found for name %s", emr_cluster_name)
            return None

    def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]:
        """
        Create and start running a new cluster (job flow).

        This method uses ``EmrHook.emr_conn_id`` to receive the initial Amazon EMR cluster configuration.
        If ``EmrHook.emr_conn_id`` is empty or the connection does not exist, then an empty initial
        configuration is used.

        :param job_flow_overrides: Is used to overwrite the parameters in the initial Amazon EMR
            configuration cluster. The resulting configuration will be used in the
            boto3 emr client run_job_flow method.

        .. seealso::
            - :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`
            - `API RunJobFlow <https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html>`_
            - `boto3 emr client run_job_flow method <https://boto3.amazonaws.com/v1/documentation/\
              api/latest/reference/services/emr.html#EMR.Client.run_job_flow>`_.
        """
        config = {}
        if self.emr_conn_id:
            try:
                emr_conn = self.get_connection(self.emr_conn_id)
            except AirflowNotFoundException:
                warnings.warn(
                    f"Unable to find {self.hook_name} Connection ID {self.emr_conn_id!r}, "
                    "using an empty initial configuration. If you want to get rid of this warning "
                    "message please provide a valid `emr_conn_id` or set it to None.",
                    UserWarning,
                    stacklevel=2,
                )
            else:
                if emr_conn.conn_type and emr_conn.conn_type != self.conn_type:
                    warnings.warn(
                        f"{self.hook_name} Connection expected connection type {self.conn_type!r}, "
                        f"Connection {self.emr_conn_id!r} has conn_type={emr_conn.conn_type!r}. "
                        f"This connection might not work correctly.",
                        UserWarning,
                        stacklevel=2,
                    )
                config = emr_conn.extra_dejson.copy()
        config.update(job_flow_overrides)

        response = self.get_conn().run_job_flow(**config)
        return response

    def add_job_flow_steps(
        self,
        job_flow_id: str,
        steps: list[dict] | str | None = None,
        wait_for_completion: bool = False,
        execution_role_arn: str | None = None,
    ) -> list[str]:
        """
        Add new steps to a running cluster.

        :param job_flow_id: The id of the job flow to which the steps are being added
        :param steps: A list of the steps to be executed by the job flow
        :param wait_for_completion: If True, wait for the steps to be completed. Default is False
        :param execution_role_arn: The ARN of the runtime role for a step on the cluster.
        """
        config = {}
        if execution_role_arn:
            config["ExecutionRoleArn"] = execution_role_arn
        response = self.get_conn().add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps, **config)

        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise AirflowException(f"Adding steps failed: {response}")

        self.log.info("Steps %s added to JobFlow", response["StepIds"])
        if wait_for_completion:
            waiter = self.get_conn().get_waiter("step_complete")
            for step_id in response["StepIds"]:
                waiter.wait(
                    ClusterId=job_flow_id,
                    StepId=step_id,
                    WaiterConfig={
                        "Delay": 5,
                        "MaxAttempts": 100,
                    },
                )
        return response["StepIds"]

    def test_connection(self):
        """
        Return failed state for test Amazon Elastic MapReduce Connection (untestable).

        We need to overwrite this method because this hook is based on
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`,
        otherwise it will try to test connection to AWS STS by using the default boto3 credential strategy.
        """
        msg = (
            f"{self.hook_name!r} Airflow Connection cannot be tested, by design it stores "
            f"only key/value pairs and does not make a connection to an external resource."
        )
        return False, msg

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Returns custom UI field behaviour for Amazon Elastic MapReduce Connection."""
        return {
            "hidden_fields": ["host", "schema", "port", "login", "password"],
            "relabeling": {
                "extra": "Run Job Flow Configuration",
            },
            "placeholders": {
                "extra": json.dumps(
                    {
                        "Name": "MyClusterName",
                        "ReleaseLabel": "emr-5.36.0",
                        "Applications": [{"Name": "Spark"}],
                        "Instances": {
                            "InstanceGroups": [
                                {
                                    "Name": "Primary node",
                                    "Market": "SPOT",
                                    "InstanceRole": "MASTER",
                                    "InstanceType": "m5.large",
                                    "InstanceCount": 1,
                                },
                            ],
                            "KeepJobFlowAliveWhenNoSteps": False,
                            "TerminationProtected": False,
                        },
                        "StepConcurrencyLevel": 2,
                    },
                    indent=2,
                ),
            },
        }
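

# Illustrative usage sketch (not part of the upstream module): create a job flow from the
# ``emr_default`` connection extras and submit a step with ``EmrHook``. The connection ids,
# cluster name, and step definition below are hypothetical placeholder values.
def _example_run_emr_step() -> None:
    hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")
    # ``job_flow_overrides`` is merged over the connection's "Run Job Flow Configuration" extras.
    response = hook.create_job_flow({"Name": "example-cluster"})
    job_flow_id = response["JobFlowId"]
    step_ids = hook.add_job_flow_steps(
        job_flow_id=job_flow_id,
        steps=[
            {
                "Name": "example-step",
                "ActionOnFailure": "CONTINUE",
                "HadoopJarStep": {"Jar": "command-runner.jar", "Args": ["echo", "hello"]},
            }
        ],
        wait_for_completion=True,  # blocks on the boto3 ``step_complete`` waiter used above
    )
    print(step_ids)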


class EmrServerlessHook(AwsBaseHook):
    """
    Interact with EMR Serverless API.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    @cached_property
    def conn(self):
        """Get the underlying boto3 EmrServerlessAPIService client (cached)."""
        return super().conn

    # This method should be replaced with boto waiters, which would implement timeouts and backoff nicely.
    def waiter(
        self,
        get_state_callable: Callable,
        get_state_args: dict,
        parse_response: list,
        desired_state: set,
        failure_states: set,
        object_type: str,
        action: str,
        countdown: int = 25 * 60,
        check_interval_seconds: int = 60,
    ) -> None:
        """
        Poll ``get_state_callable`` until the extracted state reaches ``desired_state``.

        :param get_state_callable: A callable to run until it returns True
        :param get_state_args: Arguments to pass to get_state_callable
        :param parse_response: Dictionary keys to extract state from response of get_state_callable
        :param desired_state: Wait until the getter returns this value
        :param failure_states: A set of states which indicate failure and should throw an
            exception if any are reached before the desired_state
        :param object_type: Used for the reporting string. What are you waiting for? (application, job, etc)
        :param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc)
        :param countdown: Total amount of time the waiter should wait for the desired state
            before timing out (in seconds). Defaults to 25 * 60 seconds.
        :param check_interval_seconds: Number of seconds waiter should wait before attempting
            to retry get_state_callable. Defaults to 60 seconds.
        """
        warnings.warn(
            """This method is deprecated.
            Please use `airflow.providers.amazon.aws.utils.waiter.waiter`.""",
            DeprecationWarning,
            stacklevel=2,
        )
        waiter(
            get_state_callable=get_state_callable,
            get_state_args=get_state_args,
            parse_response=parse_response,
            desired_state=desired_state,
            failure_states=failure_states,
            object_type=object_type,
            action=action,
            countdown=countdown,
            check_interval_seconds=check_interval_seconds,
        )

    def get_state(self, response, keys) -> str:
        warnings.warn(
            """This method is deprecated.
            Please use `airflow.providers.amazon.aws.utils.waiter.get_state`.""",
            DeprecationWarning,
            stacklevel=2,
        )
        return get_state(response=response, keys=keys)
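

# Illustrative usage sketch (not part of the upstream module): wait for an EMR Serverless
# application to start by calling the non-deprecated ``waiter`` helper that the method above
# delegates to. The application id, response keys, and state names are assumptions made for
# this example; check the boto3 ``emr-serverless`` client documentation before relying on them.
def _example_wait_for_serverless_application(application_id: str) -> None:
    hook = EmrServerlessHook(aws_conn_id="aws_default")
    waiter(
        get_state_callable=hook.conn.get_application,  # boto3 emr-serverless GetApplication call
        get_state_args={"applicationId": application_id},
        parse_response=["application", "state"],  # state read from response["application"]["state"]
        desired_state={"STARTED"},
        failure_states={"STOPPED", "TERMINATED"},
        object_type="application",
        action="started",
    )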


class EmrContainerHook(AwsBaseHook):
    """
    Interact with AWS EMR Virtual Cluster to run, poll jobs and return job status.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`

    :param virtual_cluster_id: Cluster ID of the EMR on EKS virtual cluster
    """

    def submit_job(
        self,
        name: str,
        execution_role_arn: str,
        release_label: str,
        job_driver: dict,
        configuration_overrides: dict | None = None,
        client_request_token: str | None = None,
        tags: dict | None = None,
    ) -> str:
        """
        Submit a job to the EMR Containers API and return the job ID.

        A job run is a unit of work, such as a Spark jar, PySpark script, or SparkSQL query,
        that you submit to Amazon EMR on EKS.
        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.start_job_run  # noqa: E501

        :param name: The name of the job run.
        :param execution_role_arn: The IAM role ARN associated with the job run.
        :param release_label: The Amazon EMR release version to use for the job run.
        :param job_driver: Job configuration details, e.g. the Spark job parameters.
        :param configuration_overrides: The configuration overrides for the job run,
            specifically either application configuration or monitoring configuration.
        :param client_request_token: The client idempotency token of the job run request.
            Use this if you want to specify a unique ID to prevent two jobs from getting started.
        :param tags: The tags assigned to job runs.
        :return: Job ID
        """
        params = {
            "name": name,
            "virtualClusterId": self.virtual_cluster_id,
            "executionRoleArn": execution_role_arn,
            "releaseLabel": release_label,
            "jobDriver": job_driver,
            "configurationOverrides": configuration_overrides or {},
            "tags": tags or {},
        }
        if client_request_token:
            params["clientToken"] = client_request_token

        response = self.conn.start_job_run(**params)

        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise AirflowException(f"Start Job Run failed: {response}")
        else:
            self.log.info(
                "Start Job Run success - Job Id %s and virtual cluster id %s",
                response["id"],
                response["virtualClusterId"],
            )
            return response["id"]

    def get_job_failure_reason(self, job_id: str) -> str | None:
        """
        Fetch the reason for a job failure (e.g. error message). Returns None or reason string.

        :param job_id: Id of submitted job run
        :return: str
        """
        # We absorb any errors if we can't retrieve the job status
        reason = None

        try:
            response = self.conn.describe_job_run(
                virtualClusterId=self.virtual_cluster_id,
                id=job_id,
            )
            failure_reason = response["jobRun"]["failureReason"]
            state_details = response["jobRun"]["stateDetails"]
            reason = f"{failure_reason} - {state_details}"
        except KeyError:
            self.log.error("Could not get status of the EMR on EKS job")
        except ClientError as ex:
            self.log.error("AWS request failed, check logs for more info: %s", ex)

        return reason

    def check_query_status(self, job_id: str) -> str | None:
        """
        Fetch the status of submitted job run. Returns None or one of valid query states.

        See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.describe_job_run  # noqa: E501

        :param job_id: Id of submitted job run
        :return: str
        """
        try:
            response = self.conn.describe_job_run(
                virtualClusterId=self.virtual_cluster_id,
                id=job_id,
            )
            return response["jobRun"]["state"]
        except self.conn.exceptions.ResourceNotFoundException:
            # If the job is not found, we raise an exception as something fatal has happened.
            raise AirflowException(f"Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}")
        except ClientError as ex:
            # If we receive a generic ClientError, we swallow the exception so that the
            # status check can simply be retried on the next poll.
            self.log.error("AWS request failed, check logs for more info: %s", ex)
            return None

    def poll_query_status(
        self,
        job_id: str,
        max_tries: int | None = None,
        poll_interval: int = 30,
        max_polling_attempts: int | None = None,
    ) -> str | None:
        """
        Poll the status of submitted job run until query state reaches final state.

        Returns one of the final states.

        :param job_id: Id of submitted job run
        :param max_tries: Deprecated - Use max_polling_attempts instead
        :param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
        :param max_polling_attempts: Number of times to poll for query state before function exits
        :return: str
        """
        if max_tries:
            warnings.warn(
                f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
                "in a future release. Please use method `max_polling_attempts` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if max_polling_attempts and max_polling_attempts != max_tries:
                raise Exception("max_polling_attempts must be the same value as max_tries")
            else:
                max_polling_attempts = max_tries

        try_number = 1
        final_query_state = None  # Query state when query reaches final state or max_polling_attempts reached

        while True:
            query_state = self.check_query_status(job_id)
            if query_state is None:
                self.log.info("Try %s: Invalid query state. Retrying again", try_number)
            elif query_state in self.TERMINAL_STATES:
                self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
                final_query_state = query_state
                break
            else:
                self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
            if (
                max_polling_attempts and try_number >= max_polling_attempts
            ):  # Break loop if max_polling_attempts reached
                final_query_state = query_state
                break
            try_number += 1
            sleep(poll_interval)
        return final_query_state

    def stop_query(self, job_id: str) -> dict:
        """
        Cancel the submitted job_run.

        :param job_id: Id of submitted job_run
        :return: dict
        """
        return self.conn.cancel_job_run(
            virtualClusterId=self.virtual_cluster_id,
            id=job_id,
        )
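

# Illustrative usage sketch (not part of the upstream module): submit an EMR on EKS job run and
# poll it to completion with ``EmrContainerHook``. The virtual cluster id, role ARN, release
# label, job driver, and the "COMPLETED" state check below are hypothetical placeholder values.
def _example_run_emr_containers_job() -> None:
    hook = EmrContainerHook(aws_conn_id="aws_default", virtual_cluster_id="abc123xyz")
    job_id = hook.submit_job(
        name="example-job",
        execution_role_arn="arn:aws:iam::123456789012:role/example-emr-eks-role",
        release_label="emr-6.6.0-latest",
        job_driver={
            "sparkSubmitJobDriver": {
                "entryPoint": "s3://example-bucket/scripts/job.py",
                "sparkSubmitParameters": "--conf spark.executor.instances=2",
            }
        },
    )
    final_state = hook.poll_query_status(job_id, poll_interval=30, max_polling_attempts=20)
    if final_state != "COMPLETED":
        print(hook.get_job_failure_reason(job_id))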