Source code for airflow.providers.amazon.aws.hooks.emr
## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.fromtimeimportsleepfromtypingimportAny,Callable,Dict,List,Optional,Setfrombotocore.exceptionsimportClientErrorfromairflow.compat.functoolsimportcached_propertyfromairflow.exceptionsimportAirflowException,AirflowNotFoundExceptionfromairflow.providers.amazon.aws.hooks.base_awsimportAwsBaseHook
class EmrHook(AwsBaseHook):
    """
    Interact with AWS EMR.

    ``emr_conn_id`` is only necessary for using the ``create_job_flow`` method.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: List[str]) -> Optional[str]:
        """
        Fetch the id of the EMR cluster with the given name and (optional) states.

        An id is returned only if exactly one cluster matches. Every page of the
        ``list_clusters`` result is inspected, so the cluster is found (and duplicate
        names are detected) regardless of which page it is listed on.

        :param emr_cluster_name: Name of a cluster to find
        :param cluster_states: State(s) of cluster to find
        :return: id of the EMR cluster, or None when no cluster has that name
        :raises AirflowException: when more than one cluster has that name
        """
        # ListClusters is a paged API: a single list_clusters() call only returns
        # the first page, so iterate over all pages through the paginator.
        pages = self.get_conn().get_paginator('list_clusters').paginate(ClusterStates=cluster_states)
        matching_clusters = [
            cluster
            for page in pages
            for cluster in page['Clusters']
            if cluster['Name'] == emr_cluster_name
        ]

        if len(matching_clusters) > 1:
            raise AirflowException(f'More than one cluster found for name {emr_cluster_name}')

        if not matching_clusters:
            self.log.info('No cluster found for name %s', emr_cluster_name)
            return None

        cluster_id = matching_clusters[0]['Id']
        self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id)
        return cluster_id

    def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create a job flow using the config from the EMR connection.

        Keys of the connection's json extra hash may have the arguments of the boto3
        ``run_job_flow`` method. When the connection cannot be found, an empty
        base config is used instead.

        :param job_flow_overrides: Values that are applied on top of the connection
            config; on a key clash the override wins.
        :return: the response of ``run_job_flow``
        """
        try:
            emr_conn = self.get_connection(self.emr_conn_id)
            # Copy so the update below never touches the connection's own extras.
            config = emr_conn.extra_dejson.copy()
        except AirflowNotFoundException:
            config = {}
        config.update(job_flow_overrides)

        return self.get_conn().run_job_flow(**config)
class EmrServerlessHook(AwsBaseHook):
    """
    Interact with EMR Serverless API.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The client type is always forced to EMR Serverless, whatever the caller passed.
        kwargs["client_type"] = "emr-serverless"
        super().__init__(*args, **kwargs)

    @cached_property
    def conn(self):
        """Get the underlying boto3 EmrServerlessAPIService client (cached)"""
        return super().conn

    # This method should be replaced with boto waiters which would implement timeouts and backoff nicely.
    def waiter(
        self,
        get_state_callable: Callable,
        get_state_args: Dict,
        parse_response: List,
        desired_state: Set,
        failure_states: Set,
        object_type: str,
        action: str,
        countdown: int = 25 * 60,
        check_interval_seconds: int = 60,
    ) -> None:
        """
        Poll ``get_state_callable`` until the reported state is one of ``desired_state``.

        :param get_state_callable: Callable that is invoked on every poll to fetch the current status
        :param get_state_args: Keyword arguments to pass to get_state_callable
        :param parse_response: Keys handed to ``self.get_state`` to extract the state
            from the response of get_state_callable
        :param desired_state: Set of states; waiting stops once the state is in this set
        :param failure_states: A set of states which indicate failure and should throw an
            exception if any are reached before the desired_state
        :param object_type: Used for the reporting string. What are you waiting for? (application, job, etc)
        :param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc)
        :param countdown: Total amount of time the waiter should wait for the desired state
            before timing out (in seconds). Defaults to 25 * 60 seconds.
        :param check_interval_seconds: Number of seconds waiter should wait before attempting
            to retry get_state_callable. Defaults to 60 seconds.
        :raises AirflowException: when a state listed in ``failure_states`` is reached
        :raises RuntimeError: when the time budget is used up before the desired state is reached
        """
        remaining_seconds = countdown
        current_state = self.get_state(get_state_callable(**get_state_args), parse_response)

        while current_state not in desired_state:
            if current_state in failure_states:
                raise AirflowException(f'{object_type.title()} reached failure state {current_state}.')

            # Not enough budget left for another full interval: give up.
            if remaining_seconds < check_interval_seconds:
                message = f'{object_type.title()} still not {action.lower()} after the allocated time limit.'
                self.log.error(message)
                raise RuntimeError(message)

            remaining_seconds -= check_interval_seconds
            self.log.info('Waiting for %s to be %s.', object_type.lower(), action.lower())
            sleep(check_interval_seconds)
            current_state = self.get_state(get_state_callable(**get_state_args), parse_response)
[docs]classEmrContainerHook(AwsBaseHook):""" Interact with AWS EMR Virtual Cluster to run, poll jobs and return job status Additional arguments (such as ``aws_conn_id``) may be specified and are passed down to the underlying AwsBaseHook. .. seealso:: :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` :param virtual_cluster_id: Cluster ID of the EMR on EKS virtual cluster """
def submit_job(
    self,
    name: str,
    execution_role_arn: str,
    release_label: str,
    job_driver: dict,
    configuration_overrides: Optional[dict] = None,
    client_request_token: Optional[str] = None,
    tags: Optional[dict] = None,
) -> str:
    """
    Submit a job to the EMR Containers API and return the job ID.

    A job run is a unit of work, such as a Spark jar, PySpark script,
    or SparkSQL query, that you submit to Amazon EMR on EKS.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.start_job_run  # noqa: E501

    :param name: The name of the job run.
    :param execution_role_arn: The IAM role ARN associated with the job run.
    :param release_label: The Amazon EMR release version to use for the job run.
    :param job_driver: Job configuration details, e.g. the Spark job parameters.
    :param configuration_overrides: The configuration overrides for the job run,
        specifically either application configuration or monitoring configuration.
        An empty dict is sent when omitted.
    :param client_request_token: The client idempotency token of the job run request.
        Use this if you want to specify a unique ID to prevent two jobs from getting started.
        Only sent when truthy.
    :param tags: The tags assigned to job runs. An empty dict is sent when omitted.
    :return: Job ID
    :raises AirflowException: when the response HTTP status code is not 200
    """
    request = {
        "name": name,
        "virtualClusterId": self.virtual_cluster_id,
        "executionRoleArn": execution_role_arn,
        "releaseLabel": release_label,
        "jobDriver": job_driver,
        "configurationOverrides": configuration_overrides or {},
        "tags": tags or {},
    }
    # The idempotency token is optional and is left out of the request entirely when unset.
    if client_request_token:
        request["clientToken"] = client_request_token

    response = self.conn.start_job_run(**request)

    if response['ResponseMetadata']['HTTPStatusCode'] != 200:
        raise AirflowException(f'Start Job Run failed: {response}')

    job_id = response['id']
    self.log.info(
        "Start Job Run success - Job Id %s and virtual cluster id %s",
        job_id,
        response['virtualClusterId'],
    )
    return job_id
def get_job_failure_reason(self, job_id: str) -> Optional[str]:
    """
    Fetch the reason for a job failure (e.g. error message).

    This is best effort: a missing key in the response or an AWS client error
    is logged and results in ``None`` instead of an exception.

    :param job_id: Id of submitted job run
    :return: ``"<failureReason> - <stateDetails>"`` or None when it cannot be determined
    """
    try:
        job_run = self.conn.describe_job_run(
            virtualClusterId=self.virtual_cluster_id,
            id=job_id,
        )['jobRun']
        return f"{job_run['failureReason']} - {job_run['stateDetails']}"
    except KeyError:
        # The response lacks one of the expected fields.
        self.log.error('Could not get status of the EMR on EKS job')
    except ClientError as ex:
        self.log.error('AWS request failed, check logs for more info: %s', ex)

    return None
def check_query_status(self, job_id: str) -> Optional[str]:
    """
    Fetch the status of submitted job run.

    See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.describe_job_run  # noqa: E501

    :param job_id: Id of submitted job run
    :return: the ``state`` of the job run, or None when the request failed with
        a generic ``ClientError``
    :raises AirflowException: when the job run does not exist on the virtual cluster
    """
    try:
        job_run = self.conn.describe_job_run(
            virtualClusterId=self.virtual_cluster_id,
            id=job_id,
        )["jobRun"]
        return job_run["state"]
    except self.conn.exceptions.ResourceNotFoundException:
        # A missing job is treated as fatal, so it is escalated instead of swallowed.
        not_found_message = f'Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}'
        raise AirflowException(not_found_message)
    except ClientError as ex:
        # Any other client error is only logged; the caller receives None.
        self.log.error('AWS request failed, check logs for more info: %s', ex)
        return None
def poll_query_status(
    self, job_id: str, max_tries: Optional[int] = None, poll_interval: int = 30
) -> Optional[str]:
    """
    Poll the status of submitted job run until query state reaches final state.

    :param job_id: Id of submitted job run
    :param max_tries: Number of times to poll for query state before function exits.
        When unset (or 0) polling continues until a terminal state is seen.
    :param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
    :return: the terminal state once one is reached; if ``max_tries`` is exhausted
        first, the last observed state (possibly non-terminal or None)
    """
    attempt = 1
    while True:
        state = self.check_query_status(job_id)

        if state is not None and state in self.TERMINAL_STATES:
            self.log.info("Try %s: Query execution completed. Final state is %s", attempt, state)
            return state

        if state is None:
            self.log.info("Try %s: Invalid query state. Retrying again", attempt)
        else:
            self.log.info("Try %s: Query is still in non-terminal state - %s", attempt, state)

        # Out of attempts: hand back whatever was observed last.
        if max_tries and attempt >= max_tries:
            return state

        attempt += 1
        sleep(poll_interval)
[docs]defstop_query(self,job_id:str)->Dict:""" Cancel the submitted job_run :param job_id: Id of submitted job_run :return: dict """returnself.conn.cancel_job_run(virtualClusterId=self.virtual_cluster_id,id=job_id,