Source code for airflow.providers.amazon.aws.triggers.emr
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportasyncioimportwarningsfromtypingimportAnyfrombotocore.exceptionsimportWaiterErrorfromairflow.exceptionsimportAirflowProviderDeprecationWarningfromairflow.providers.amazon.aws.hooks.base_awsimportAwsGenericHookfromairflow.providers.amazon.aws.hooks.emrimportEmrContainerHook,EmrHook,EmrServerlessHookfromairflow.providers.amazon.aws.triggers.baseimportAwsBaseWaiterTriggerfromairflow.triggers.baseimportBaseTrigger,TriggerEvent
class EmrAddStepsTrigger(BaseTrigger):
    """
    Asynchronously poll the boto3 API and wait for the steps to finish executing.

    Steps are waited on one after another, in the order given. The trigger emits exactly one
    event:

    * ``{"status": "failure", ...}`` as soon as a step reaches a terminal failure, or a step
      is still not complete after ``max_attempts`` checks;
    * ``{"status": "success", ..., "step_ids": step_ids}`` once every step has completed.

    :param job_flow_id: The id of the job flow.
    :param step_ids: The id of the steps being waited upon.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param max_attempts: The maximum number of checks made for each step. Must be set:
        despite the ``| None`` hint, ``run`` converts it with ``int()``.
    :param poll_interval: The amount of time in seconds to wait between attempts. Must be set:
        despite the ``| None`` hint, ``run`` converts it with ``int()``.
    """

    def __init__(
        self,
        job_flow_id: str,
        step_ids: list[str],
        aws_conn_id: str,
        max_attempts: int | None,
        poll_interval: int | None,
    ):
        super().__init__()
        self.job_flow_id = job_flow_id
        self.step_ids = step_ids
        self.aws_conn_id = aws_conn_id
        self.max_attempts = max_attempts
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the classpath and keyword arguments needed to reconstruct this trigger."""
        return (
            f"{self.__class__.__module__}.{self.__class__.__qualname__}",
            {
                "job_flow_id": self.job_flow_id,
                "step_ids": self.step_ids,
                "aws_conn_id": self.aws_conn_id,
                "max_attempts": self.max_attempts,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self):
        """Wait for every step in turn and yield a single success or failure event."""
        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
        # int() also accepts numeric strings, e.g. from payloads serialized as text.
        max_attempts = int(self.max_attempts)
        poll_interval = int(self.poll_interval)
        async with self.hook.async_conn as client:
            for step_id in self.step_ids:
                waiter = client.get_waiter("step_complete")
                for _attempt in range(max_attempts):
                    try:
                        # MaxAttempts=1: a single check per iteration; the retry loop is ours
                        # so that the current step state can be logged between checks.
                        await waiter.wait(
                            ClusterId=self.job_flow_id,
                            StepId=step_id,
                            WaiterConfig={
                                "Delay": poll_interval,
                                "MaxAttempts": 1,
                            },
                        )
                    except WaiterError as error:
                        if "terminal failure" in str(error):
                            # Stop immediately: remaining steps must not be polled and no
                            # further (e.g. "success") event may follow this one.
                            yield TriggerEvent({"status": "failure", "message": f"Step {step_id} failed: {error}"})
                            return
                        self.log.info(
                            "Status of step is %s - %s",
                            error.last_response["Step"]["Status"]["State"],
                            error.last_response["Step"]["Status"]["StateChangeReason"],
                        )
                        await asyncio.sleep(poll_interval)
                    else:
                        # This step completed; move on to the next one.
                        break
                else:
                    # Attempts for this step ran out without it completing (this also covers
                    # max_attempts <= 0), so the whole trigger fails now.
                    yield TriggerEvent({"status": "failure", "message": "Steps failed: max attempts reached"})
                    return
        yield TriggerEvent({"status": "success", "message": "Steps completed", "step_ids": self.step_ids})
class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger):
    """
    Asynchronously poll the boto3 API and wait for the JobFlow to finish executing.

    :param job_flow_id: The id of the job flow to wait for.
    :param poll_interval: Deprecated, use ``waiter_delay`` instead.
    :param max_attempts: Deprecated, use ``waiter_max_attempts`` instead.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    """

    def __init__(
        self,
        job_flow_id: str,
        poll_interval: int | None = None,  # deprecated
        max_attempts: int | None = None,  # deprecated
        aws_conn_id: str | None = None,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
    ):
        if poll_interval is not None or max_attempts is not None:
            warnings.warn(
                "please use waiter_delay instead of poll_interval "
                "and waiter_max_attempts instead of max_attempts",
                AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
            # An explicitly passed deprecated value wins, even when it is 0.
            if poll_interval is not None:
                waiter_delay = poll_interval
            if max_attempts is not None:
                waiter_max_attempts = max_attempts
        super().__init__(
            serialized_fields={"job_flow_id": job_flow_id},
            waiter_name="job_flow_waiting",
            waiter_args={"ClusterId": job_flow_id},
            failure_message="JobFlow creation failed",
            status_message="JobFlow creation in progress",
            status_queries=[
                "Cluster.Status.State",
                "Cluster.Status.StateChangeReason",
                "Cluster.Status.ErrorDetails",
            ],
            return_key="job_flow_id",
            return_value=job_flow_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR hook whose client and waiters back this trigger."""
        return EmrHook(aws_conn_id=self.aws_conn_id)
class EmrTerminateJobFlowTrigger(AwsBaseWaiterTrigger):
    """
    Asynchronously poll the boto3 API and wait for the JobFlow to finish terminating.

    :param job_flow_id: ID of the EMR Job Flow to terminate
    :param poll_interval: Deprecated, use ``waiter_delay`` instead.
    :param max_attempts: Deprecated, use ``waiter_max_attempts`` instead.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    """

    def __init__(
        self,
        job_flow_id: str,
        poll_interval: int | None = None,  # deprecated
        max_attempts: int | None = None,  # deprecated
        aws_conn_id: str | None = None,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
    ):
        if poll_interval is not None or max_attempts is not None:
            warnings.warn(
                "please use waiter_delay instead of poll_interval "
                "and waiter_max_attempts instead of max_attempts",
                AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
            # An explicitly passed deprecated value wins, even when it is 0.
            if poll_interval is not None:
                waiter_delay = poll_interval
            if max_attempts is not None:
                waiter_max_attempts = max_attempts
        super().__init__(
            serialized_fields={"job_flow_id": job_flow_id},
            waiter_name="job_flow_terminated",
            waiter_args={"ClusterId": job_flow_id},
            failure_message="JobFlow termination failed",
            status_message="JobFlow termination in progress",
            status_queries=[
                "Cluster.Status.State",
                "Cluster.Status.StateChangeReason",
                "Cluster.Status.ErrorDetails",
            ],
            return_value=None,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR hook whose client and waiters back this trigger."""
        return EmrHook(aws_conn_id=self.aws_conn_id)
class EmrContainerTrigger(AwsBaseWaiterTrigger):
    """
    Poll an EMR containers job run until it reaches a terminal state.

    :param virtual_cluster_id: Reference Emr cluster id
    :param job_id: job_id to check the state
    :param aws_conn_id: Reference to AWS connection id
    :param poll_interval: Deprecated, use ``waiter_delay`` instead.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    """

    def __init__(
        self,
        virtual_cluster_id: str,
        job_id: str,
        aws_conn_id: str = "aws_default",
        poll_interval: int | None = None,  # deprecated
        waiter_delay: int = 30,
        waiter_max_attempts: int = 600,
    ):
        if poll_interval is not None:
            warnings.warn(
                "please use waiter_delay instead of poll_interval.",
                AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
            # An explicitly passed deprecated value wins, even when it is 0.
            waiter_delay = poll_interval
        super().__init__(
            serialized_fields={"virtual_cluster_id": virtual_cluster_id, "job_id": job_id},
            waiter_name="container_job_complete",
            waiter_args={"id": job_id, "virtualClusterId": virtual_cluster_id},
            failure_message="Job failed",
            status_message="Job in progress",
            status_queries=["jobRun.state", "jobRun.failureReason"],
            return_key="job_id",
            return_value=job_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR containers hook whose client and waiters back this trigger."""
        return EmrContainerHook(aws_conn_id=self.aws_conn_id)
class EmrStepSensorTrigger(AwsBaseWaiterTrigger):
    """
    Poll an EMR step until it reaches a terminal state.

    :param job_flow_id: job_flow_id which contains the step to check the state of
    :param step_id: step to check the state of
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        job_flow_id: str,
        step_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ):
        super().__init__(
            serialized_fields={"job_flow_id": job_flow_id, "step_id": step_id},
            waiter_name="step_wait_for_terminal",
            waiter_args={"ClusterId": job_flow_id, "StepId": step_id},
            failure_message=f"Error while waiting for step {step_id} to complete",
            status_message=f"Step id: {step_id}, Step is still in non-terminal state",
            status_queries=[
                "Step.Status.State",
                "Step.Status.FailureDetails",
                "Step.Status.StateChangeReason",
            ],
            return_value=None,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR hook whose client and waiters back this trigger."""
        return EmrHook(aws_conn_id=self.aws_conn_id)
class EmrServerlessCreateApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless application and wait for it to be created.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_created",
            waiter_args={"applicationId": application_id},
            failure_message="Application creation failed",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR Serverless hook whose client and waiters back this trigger."""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
class EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless application and wait for it to be started.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_started",
            waiter_args={"applicationId": application_id},
            failure_message="Application failed to start",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR Serverless hook whose client and waiters back this trigger."""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
class EmrServerlessStopApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless application and wait for it to be stopped.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id.
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_stopped",
            waiter_args={"applicationId": application_id},
            # This trigger waits for a stop, so the failure message must say so.
            failure_message="Application failed to stop",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR Serverless hook whose client and waiters back this trigger."""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless job run and wait for it to be completed.

    :param application_id: The ID of the application the job is being run on.
    :param job_id: The ID of the job run.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        application_id: str,
        job_id: str | None,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id, "job_id": job_id},
            waiter_name="serverless_job_completed",
            waiter_args={"applicationId": application_id, "jobRunId": job_id},
            failure_message="Serverless Job failed",
            status_message="Serverless Job status is",
            status_queries=["jobRun.state", "jobRun.stateDetails"],
            return_key="job_id",
            return_value=job_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR Serverless hook whose client and waiters back this trigger."""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
class EmrServerlessDeleteApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless application and wait for it to be deleted.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_terminated",
            waiter_args={"applicationId": application_id},
            # This trigger waits for deletion, so the failure message must say so.
            failure_message="Application deletion failed",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR Serverless hook whose client and waiters back this trigger."""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
class EmrServerlessCancelJobsTrigger(AwsBaseWaiterTrigger):
    """
    Trigger for canceling a list of jobs in an EMR Serverless application.

    It waits (``no_job_running`` waiter) until the application has no job runs left in one of
    the hook's ``JOB_INTERMEDIATE_STATES`` or in ``CANCELLING``.

    :param application_id: EMR Serverless application ID
    :param aws_conn_id: Reference to AWS connection id
    :param waiter_delay: Delay in seconds between each attempt to check the status
    :param waiter_max_attempts: Maximum number of attempts to check the status
    """

    def __init__(
        self,
        application_id: str,
        aws_conn_id: str,
        waiter_delay: int,
        waiter_max_attempts: int,
    ) -> None:
        # Kept as a public attribute for backward compatibility; also returned by hook().
        self.hook_instance = EmrServerlessHook(aws_conn_id)
        # Sorted so the waiter arguments are deterministic (set iteration order is not).
        states = sorted(self.hook_instance.JOB_INTERMEDIATE_STATES.union({"CANCELLING"}))
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="no_job_running",
            waiter_args={"applicationId": application_id, "states": states},
            failure_message="Error while waiting for jobs to cancel",
            status_message="Currently running jobs",
            status_queries=["jobRuns[*].applicationId", "jobRuns[*].state"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return the EMR Serverless hook created in ``__init__`` instead of building another."""
        return self.hook_instance