# Source code for airflow.providers.amazon.aws.triggers.ecs
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any

from botocore.exceptions import ClientError, WaiterError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
class ClusterActiveTrigger(AwsBaseWaiterTrigger):
    """
    Polls the status of a cluster until it's active.

    :param cluster_arn: ARN of the cluster to watch.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The number of times to ping for status.
        Will fail after that many unsuccessful attempts.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: The AWS region where the cluster is located.
    """

    def __init__(
        self,
        cluster_arn: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str | None,
        region_name: str | None = None,
        **kwargs,
    ):
        # All polling mechanics are delegated to AwsBaseWaiterTrigger; this
        # subclass only selects the boto3 waiter and wires in the messages.
        super().__init__(
            serialized_fields={"cluster_arn": cluster_arn},
            waiter_name="cluster_active",
            waiter_args={"clusters": [cluster_arn]},
            failure_message="Failure while waiting for cluster to be available",
            status_message="Cluster is not ready yet",
            status_queries=["clusters[].status", "failures"],
            # On success the TriggerEvent carries the cluster ARN under "arn".
            return_key="arn",
            return_value=cluster_arn,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
            region_name=region_name,
            **kwargs,
        )
class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
    """
    Polls the status of a cluster until it's inactive.

    :param cluster_arn: ARN of the cluster to watch.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The number of times to ping for status.
        Will fail after that many unsuccessful attempts.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: The AWS region where the cluster is located.
    """

    def __init__(
        self,
        cluster_arn: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str | None,
        region_name: str | None = None,
        **kwargs,
    ):
        # Mirrors ClusterActiveTrigger but waits for deactivation; note it
        # uses the base class's default return key (no return_key override).
        super().__init__(
            serialized_fields={"cluster_arn": cluster_arn},
            waiter_name="cluster_inactive",
            waiter_args={"clusters": [cluster_arn]},
            failure_message="Failure while waiting for cluster to be deactivated",
            status_message="Cluster deactivation is not done yet",
            status_queries=["clusters[].status", "failures"],
            return_value=cluster_arn,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
            region_name=region_name,
            **kwargs,
        )
class TaskDoneTrigger(BaseTrigger):
    """
    Waits for an ECS task to be done, while eventually polling logs.

    :param cluster: short name or full ARN of the cluster where the task is running.
    :param task_arn: ARN of the task to watch.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The number of times to ping for status.
        Will fail after that many unsuccessful attempts.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region: The AWS region where the cluster is located.
    """

    def __init__(
        self,
        cluster: str,
        task_arn: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str | None,
        region: str | None,
        log_group: str | None = None,
        log_stream: str | None = None,
    ):
        self.cluster = cluster
        self.task_arn = task_arn
        self.waiter_delay = waiter_delay
        self.waiter_max_attempts = waiter_max_attempts
        self.aws_conn_id = aws_conn_id
        self.region = region
        # log_group/log_stream are optional; when both are set, run() forwards
        # the task's CloudWatch logs between waiter polls.
        self.log_group = log_group
        self.log_stream = log_stream
[docs]asyncdefrun(self)->AsyncIterator[TriggerEvent]:asyncwith(EcsHook(aws_conn_id=self.aws_conn_id,region_name=self.region).async_connasecs_client,AwsLogsHook(aws_conn_id=self.aws_conn_id,region_name=self.region).async_connaslogs_client,):waiter=ecs_client.get_waiter("tasks_stopped")logs_token=Nonewhileself.waiter_max_attempts:self.waiter_max_attempts-=1try:awaitwaiter.wait(cluster=self.cluster,tasks=[self.task_arn],WaiterConfig={"MaxAttempts":1})# we reach this point only if the waiter met a success criteriayieldTriggerEvent({"status":"success","task_arn":self.task_arn,"cluster":self.cluster})returnexceptWaiterErroraserror:if"terminal failure"instr(error):raiseself.log.info("Status of the task is %s",error.last_response["tasks"][0]["lastStatus"])awaitasyncio.sleep(int(self.waiter_delay))finally:ifself.log_groupandself.log_stream:logs_token=awaitself._forward_logs(logs_client,logs_token)raiseAirflowException("Waiter error: max attempts reached")
asyncdef_forward_logs(self,logs_client,next_token:str|None=None)->str|None:""" Read logs from the cloudwatch stream and print them to the task logs. :return: the token to pass to the next iteration to resume where we started. """whileTrue:ifnext_tokenisnotNone:token_arg:dict[str,str]={"nextToken":next_token}else:token_arg={}try:response=awaitlogs_client.get_log_events(logGroupName=self.log_group,logStreamName=self.log_stream,startFromHead=True,**token_arg,)exceptClientErrorasce:ifce.response["Error"]["Code"]=="ResourceNotFoundException":self.log.info("Tried to get logs from stream %s in group %s but it didn't exist (yet). ""Will try again.",self.log_stream,self.log_group,)returnNoneraiseevents=response["events"]forlog_eventinevents:self.log.info(AwsTaskLogFetcher.event_to_str(log_event))iflen(events)==0ornext_token==response["nextForwardToken"]:returnresponse["nextForwardToken"]next_token=response["nextForwardToken"]