# Source code for airflow.providers.amazon.aws.triggers.emr
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
from functools import cached_property
from typing import Any, AsyncIterator

from botocore.exceptions import WaiterError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict
class EmrAddStepsTrigger(BaseTrigger):
    """
    AWS Emr Add Steps Trigger.

    The trigger will asynchronously poll the boto3 API and wait for the
    steps to finish executing.

    :param job_flow_id: The id of the job flow.
    :param step_ids: The id of the steps being waited upon.
    :param poll_interval: The amount of time in seconds to wait between attempts.
    :param max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        job_flow_id: str,
        step_ids: list[str],
        aws_conn_id: str,
        max_attempts: int | None,
        poll_interval: int | None,
    ):
        # Fix: BaseTrigger.__init__ was never called.
        super().__init__()
        self.job_flow_id = job_flow_id
        self.step_ids = step_ids
        self.aws_conn_id = aws_conn_id
        self.max_attempts = max_attempts
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize EmrAddStepsTrigger arguments and classpath.

        Fix: BaseTrigger declares ``serialize`` as abstract; without it this
        trigger cannot be instantiated or re-created by the triggerer process.
        """
        return (
            "airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger",
            {
                "job_flow_id": str(self.job_flow_id),
                "step_ids": self.step_ids,
                "poll_interval": str(self.poll_interval),
                "max_attempts": str(self.max_attempts),
                "aws_conn_id": str(self.aws_conn_id),
            },
        )

    async def run(self):
        """Poll the ``step_complete`` waiter for each step.

        Yields a failure event on a terminal waiter failure or when
        ``max_attempts`` is exhausted; otherwise yields a single success event
        once every step has completed.
        """
        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
        async with self.hook.async_conn as client:
            # Fix: initialize before the loop so an empty step_ids list does not
            # raise NameError at the post-loop check below.
            attempt = 0
            for step_id in self.step_ids:
                attempt = 0
                waiter = client.get_waiter("step_complete")
                while attempt < int(self.max_attempts):
                    attempt += 1
                    try:
                        await waiter.wait(
                            ClusterId=self.job_flow_id,
                            StepId=step_id,
                            WaiterConfig={
                                "Delay": int(self.poll_interval),
                                "MaxAttempts": 1,
                            },
                        )
                        break
                    except WaiterError as error:
                        if "terminal failure" in str(error):
                            yield TriggerEvent(
                                {"status": "failure", "message": f"Step {step_id} failed: {error}"}
                            )
                            # Fix: previously only the inner loop was broken here, so a
                            # success event could still be yielded after a terminal failure.
                            return
                        self.log.info(
                            "Status of step is %s - %s",
                            error.last_response["Step"]["Status"]["State"],
                            error.last_response["Step"]["Status"]["StateChangeReason"],
                        )
                        await asyncio.sleep(int(self.poll_interval))
            if attempt >= int(self.max_attempts):
                yield TriggerEvent({"status": "failure", "message": "Steps failed: max attempts reached"})
            else:
                yield TriggerEvent(
                    {"status": "success", "message": "Steps completed", "step_ids": self.step_ids}
                )
class EmrCreateJobFlowTrigger(BaseTrigger):
    """
    Trigger for EmrCreateJobFlowOperator.

    The trigger will asynchronously poll the boto3 API and wait for the
    JobFlow to finish executing.

    :param job_flow_id: The id of the job flow to wait for.
    :param poll_interval: The amount of time in seconds to wait between attempts.
    :param max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        job_flow_id: str,
        poll_interval: int,
        max_attempts: int,
        aws_conn_id: str,
    ):
        # Fix: BaseTrigger.__init__ was never called.
        super().__init__()
        self.job_flow_id = job_flow_id
        self.poll_interval = poll_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize EmrCreateJobFlowTrigger arguments and classpath.

        Fix: BaseTrigger declares ``serialize`` as abstract; without it this
        trigger cannot be instantiated or re-created by the triggerer process.
        """
        return (
            "airflow.providers.amazon.aws.triggers.emr.EmrCreateJobFlowTrigger",
            {
                "job_flow_id": self.job_flow_id,
                "poll_interval": str(self.poll_interval),
                "max_attempts": str(self.max_attempts),
                "aws_conn_id": self.aws_conn_id,
            },
        )

    async def run(self):
        """Poll the ``job_flow_waiting`` waiter until the cluster is ready.

        Raises AirflowException on a terminal waiter failure or when
        ``max_attempts`` is exhausted; otherwise yields a success event.
        """
        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
        async with self.hook.async_conn as client:
            attempt = 0
            waiter = self.hook.get_waiter("job_flow_waiting", deferrable=True, client=client)
            while attempt < int(self.max_attempts):
                attempt = attempt + 1
                try:
                    await waiter.wait(
                        ClusterId=self.job_flow_id,
                        WaiterConfig=prune_dict(
                            {
                                "Delay": self.poll_interval,
                                "MaxAttempts": 1,
                            }
                        ),
                    )
                    break
                except WaiterError as error:
                    if "terminal failure" in str(error):
                        # Chain the waiter error so the root cause is preserved.
                        raise AirflowException(f"JobFlow creation failed: {error}") from error
                    self.log.info(
                        "Status of jobflow is %s - %s",
                        error.last_response["Cluster"]["Status"]["State"],
                        error.last_response["Cluster"]["Status"]["StateChangeReason"],
                    )
                    await asyncio.sleep(int(self.poll_interval))
            if attempt >= int(self.max_attempts):
                raise AirflowException(f"JobFlow creation failed - max attempts reached: {self.max_attempts}")
            else:
                yield TriggerEvent(
                    {
                        "status": "success",
                        "message": "JobFlow completed successfully",
                        "job_flow_id": self.job_flow_id,
                    }
                )
class EmrTerminateJobFlowTrigger(BaseTrigger):
    """
    Trigger that terminates a running EMR Job Flow.

    The trigger will asynchronously poll the boto3 API and wait for the
    JobFlow to finish terminating.

    :param job_flow_id: ID of the EMR Job Flow to terminate
    :param poll_interval: The amount of time in seconds to wait between attempts.
    :param max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        job_flow_id: str,
        poll_interval: int,
        max_attempts: int,
        aws_conn_id: str,
    ):
        # Fix: BaseTrigger.__init__ was never called.
        super().__init__()
        self.job_flow_id = job_flow_id
        self.poll_interval = poll_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize EmrTerminateJobFlowTrigger arguments and classpath.

        Fix: BaseTrigger declares ``serialize`` as abstract; without it this
        trigger cannot be instantiated or re-created by the triggerer process.
        """
        return (
            "airflow.providers.amazon.aws.triggers.emr.EmrTerminateJobFlowTrigger",
            {
                "job_flow_id": self.job_flow_id,
                "poll_interval": str(self.poll_interval),
                "max_attempts": str(self.max_attempts),
                "aws_conn_id": self.aws_conn_id,
            },
        )

    async def run(self):
        """Poll the ``job_flow_terminated`` waiter until the cluster is gone.

        Raises AirflowException on a terminal waiter failure or when
        ``max_attempts`` is exhausted; otherwise yields a success event.
        """
        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
        async with self.hook.async_conn as client:
            attempt = 0
            waiter = self.hook.get_waiter("job_flow_terminated", deferrable=True, client=client)
            while attempt < int(self.max_attempts):
                attempt = attempt + 1
                try:
                    await waiter.wait(
                        ClusterId=self.job_flow_id,
                        WaiterConfig=prune_dict(
                            {
                                "Delay": self.poll_interval,
                                "MaxAttempts": 1,
                            }
                        ),
                    )
                    break
                except WaiterError as error:
                    if "terminal failure" in str(error):
                        # Chain the waiter error so the root cause is preserved.
                        raise AirflowException(f"JobFlow termination failed: {error}") from error
                    self.log.info(
                        "Status of jobflow is %s - %s",
                        error.last_response["Cluster"]["Status"]["State"],
                        error.last_response["Cluster"]["Status"]["StateChangeReason"],
                    )
                    await asyncio.sleep(int(self.poll_interval))
            if attempt >= int(self.max_attempts):
                raise AirflowException(
                    f"JobFlow termination failed - max attempts reached: {self.max_attempts}"
                )
            else:
                yield TriggerEvent(
                    {
                        "status": "success",
                        "message": "JobFlow terminated successfully",
                    }
                )
class EmrContainerSensorTrigger(BaseTrigger):
    """
    Poll for the status of EMR container until reaches terminal state.

    :param virtual_cluster_id: Reference Emr cluster id
    :param job_id: job_id to check the state
    :param aws_conn_id: Reference to AWS connection id
    :param poll_interval: polling period in seconds to check for the status
    """

    def __init__(
        self,
        virtual_cluster_id: str,
        job_id: str,
        aws_conn_id: str = "aws_default",
        poll_interval: int = 30,
        **kwargs: Any,
    ):
        self.virtual_cluster_id = virtual_cluster_id
        self.job_id = job_id
        self.aws_conn_id = aws_conn_id
        self.poll_interval = poll_interval
        super().__init__(**kwargs)

    @cached_property
    def hook(self) -> EmrContainerHook:
        # Fix: the @cached_property decorator was dangling with no method body;
        # `run` below relies on self.hook for async_conn/get_waiter, so the
        # EMR-on-EKS hook is constructed (and cached) here.
        return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes EmrContainerSensorTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
            {
                "virtual_cluster_id": self.virtual_cluster_id,
                "job_id": self.job_id,
                "aws_conn_id": self.aws_conn_id,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Poll the ``container_job_complete`` waiter until the job run reaches a terminal state."""
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter("container_job_complete", deferrable=True, client=client)
            attempt = 0
            while True:
                attempt = attempt + 1
                try:
                    await waiter.wait(
                        id=self.job_id,
                        virtualClusterId=self.virtual_cluster_id,
                        WaiterConfig={
                            "Delay": self.poll_interval,
                            "MaxAttempts": 1,
                        },
                    )
                    break
                except WaiterError as error:
                    if "terminal failure" in str(error):
                        yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"})
                        # Fix: previously this only broke the loop, so the success
                        # event below was also yielded after a terminal failure.
                        return
                    self.log.info(
                        "Job status is %s. Retrying attempt %s",
                        error.last_response["jobRun"]["state"],
                        attempt,
                    )
                    await asyncio.sleep(int(self.poll_interval))
            yield TriggerEvent({"status": "success", "job_id": self.job_id})