Source code for airflow.providers.amazon.aws.executors.ecs.utils
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
AWS ECS Executor Utilities.

Data classes and utility functions used by the ECS executor.
"""

from __future__ import annotations

import datetime
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable

from inflection import camelize

from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
from airflow.utils.state import State

if TYPE_CHECKING:
    from airflow.models.taskinstance import TaskInstanceKey

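Later blocks reference `CommandType`, `ExecutorConfigType`, and the `EcsTaskInfo` data class, which this page does not show. A minimal sketch of those definitions, assuming the built-in generic aliases that the imports above suggest:

# Assumed reconstruction: these definitions are referenced below but are
# missing from this listing; the names and shapes follow the imports above.
CommandType = list[str]
ExecutorConfigFunctionType = Callable[[CommandType], dict]
ExecutorConfigType = dict[str, Any]


@dataclass
class EcsTaskInfo:
    """Contains information about a currently running ECS task."""

    cmd: CommandType
    queue: str
    config: ExecutorConfigType
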
class EcsExecutorException(Exception):
    """Thrown when something unexpected has occurred within the ECS ecosystem."""

class EcsExecutorTask:
    """Data Transfer Object for an ECS Task."""

    def __init__(
        self,
        task_arn: str,
        last_status: str,
        desired_status: str,
        containers: list[dict[str, Any]],
        started_at: Any | None = None,
        stopped_reason: str | None = None,
        external_executor_id: str | None = None,
    ):
        self.task_arn = task_arn
        self.last_status = last_status
        self.desired_status = desired_status
        self.containers = containers
        self.started_at = started_at
        self.stopped_reason = stopped_reason
        self.external_executor_id = external_executor_id

    def get_task_state(self) -> str:
        """
        Determine the state of an ECS task based on its status and other relevant attributes.

        It can return one of the following statuses:
            QUEUED - Task is being provisioned.
            RUNNING - Task is launched on ECS.
            REMOVED - Task provisioning has failed for some reason. See `stopped_reason`.
            FAILED - Task is completed and at least one container has failed.
            SUCCESS - Task is completed and all containers have succeeded.
        """
        if self.last_status == "RUNNING":
            return State.RUNNING
        elif self.desired_status == "RUNNING":
            return State.QUEUED
        is_finished = self.desired_status == "STOPPED"
        has_exit_codes = all(["exit_code" in x for x in self.containers])
        # Sometimes ECS tasks may time out.
        if not self.started_at and is_finished:
            return State.REMOVED
        if not is_finished or not has_exit_codes:
            return State.RUNNING
        all_containers_succeeded = all([x["exit_code"] == 0 for x in self.containers])
        return State.SUCCESS if all_containers_succeeded else State.FAILED

    def __repr__(self):
        """Return a string representation of the ECS task."""
        return f"({self.task_arn}, {self.last_status}->{self.desired_status}, {self.get_task_state()})"

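A short sketch of the state mapping above, using made-up field values shaped like an ECS `describe_tasks` response:

# Illustrative values only: a stopped task where one container exited non-zero.
task = EcsExecutorTask(
    task_arn="arn:aws:ecs:us-east-1:123456789012:task/example",
    last_status="STOPPED",
    desired_status="STOPPED",
    containers=[{"exit_code": 0}, {"exit_code": 1}],
    started_at=datetime.datetime(2024, 1, 1),
)
assert task.get_task_state() == State.FAILED  # at least one container failed

# A stopped task that never started (e.g. provisioning timed out) maps to REMOVED.
assert EcsExecutorTask("arn:aws:ecs:::task/x", "STOPPED", "STOPPED", []).get_task_state() == State.REMOVED
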
class EcsTaskCollection:
    """Five dictionaries that map between Airflow task instance keys, Airflow commands, ECS task ARNs, and ECS task objects."""

    def __init__(self):
        self.key_to_arn: dict[TaskInstanceKey, str] = {}
        self.arn_to_key: dict[str, TaskInstanceKey] = {}
        self.tasks: dict[str, EcsExecutorTask] = {}
        self.key_to_failure_counts: dict[TaskInstanceKey, int] = defaultdict(int)
        self.key_to_task_info: dict[TaskInstanceKey, EcsTaskInfo] = {}

    def add_task(
        self,
        task: EcsExecutorTask,
        airflow_task_key: TaskInstanceKey,
        queue: str,
        airflow_cmd: CommandType,
        exec_config: ExecutorConfigType,
        attempt_number: int,
    ):
        """Add a task to the collection."""
        arn = task.task_arn
        self.tasks[arn] = task
        self.key_to_arn[airflow_task_key] = arn
        self.arn_to_key[arn] = airflow_task_key
        self.key_to_task_info[airflow_task_key] = EcsTaskInfo(airflow_cmd, queue, exec_config)
        self.key_to_failure_counts[airflow_task_key] = attempt_number

    def update_task(self, task: EcsExecutorTask):
        """Update the state of the given task based on task ARN."""
        self.tasks[task.task_arn] = task

    def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
        """Get a task by Airflow Instance Key."""
        arn = self.key_to_arn[task_key]
        return self.task_by_arn(arn)

    def task_by_arn(self, arn) -> EcsExecutorTask:
        """Get a task by AWS ARN."""
        return self.tasks[arn]

    def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
        """Delete task from collection based off of Airflow Task Instance Key."""
        arn = self.key_to_arn[task_key]
        task = self.tasks[arn]
        del self.key_to_arn[task_key]
        del self.key_to_task_info[task_key]
        del self.arn_to_key[arn]
        del self.tasks[arn]
        if task_key in self.key_to_failure_counts:
            del self.key_to_failure_counts[task_key]
        return task

    def get_all_arns(self) -> list[str]:
        """Get all AWS ARNs in collection."""
        return list(self.key_to_arn.values())

    def get_all_task_keys(self) -> list[TaskInstanceKey]:
        """Get all Airflow Task Keys in collection."""
        return list(self.key_to_arn.keys())

    def failure_count_by_key(self, task_key: TaskInstanceKey) -> int:
        """Get the number of times a task has failed given an Airflow Task Key."""
        return self.key_to_failure_counts[task_key]

    def increment_failure_count(self, task_key: TaskInstanceKey):
        """Increment the failure counter given an Airflow Task Key."""
        self.key_to_failure_counts[task_key] += 1

    def info_by_key(self, task_key: TaskInstanceKey) -> EcsTaskInfo:
        """Get the Airflow Command given an Airflow task key."""
        return self.key_to_task_info[task_key]

    def __getitem__(self, value):
        """Get a task by AWS ARN."""
        return self.task_by_arn(value)

    def __len__(self):
        """Determine the number of tasks in collection."""
        return len(self.tasks)

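A sketch of the collection's round trip; the key and command values below are placeholders:

from airflow.models.taskinstance import TaskInstanceKey

# Placeholder key and command, for illustration only.
ti_key = TaskInstanceKey("example_dag", "example_task", "manual__2024-01-01", try_number=1)
collection = EcsTaskCollection()
collection.add_task(
    task=task,  # the EcsExecutorTask constructed in the earlier sketch
    airflow_task_key=ti_key,
    queue="default",
    airflow_cmd=["airflow", "tasks", "run"],  # truncated placeholder command
    exec_config={},
    attempt_number=1,
)
assert collection[task.task_arn] is collection.task_by_key(ti_key)
collection.pop_by_key(ti_key)  # removes the task from every internal mapping
assert len(collection) == 0
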
def _recursive_flatten_dict(nested_dict):
    """
    Recursively unpack a nested dict and return it as a flat dict.

    For example, _recursive_flatten_dict({'a': 'a', 'b': 'b', 'c': {'d': 'd'}})
    returns {'a': 'a', 'b': 'b', 'd': 'd'}.
    """
    items = []
    for key, value in nested_dict.items():
        if isinstance(value, dict):
            items.extend(_recursive_flatten_dict(value).items())
        else:
            items.append((key, value))
    return dict(items)

def parse_assign_public_ip(assign_public_ip, is_launch_type_ec2=False):
    """Convert "assign_public_ip" from True/False to ENABLED/DISABLED."""
    # If the launch type is EC2, you cannot/should not provide the assignPublicIp
    # parameter (which is specific to Fargate), so nothing is returned.
    if not is_launch_type_ec2:
        return "ENABLED" if assign_public_ip == "True" else "DISABLED"

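For example (the inputs are strings because the value arrives from Airflow's string-valued config):

assert parse_assign_public_ip("True") == "ENABLED"
assert parse_assign_public_ip("False") == "DISABLED"
# With an EC2 launch type the parameter must be omitted entirely, so None comes back.
assert parse_assign_public_ip("True", is_launch_type_ec2=True) is None
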
def camelize_dict_keys(nested_dict) -> dict:
    """Accept a potentially nested dictionary and recursively convert all keys into camelCase."""
    result = {}
    for key, value in nested_dict.items():
        new_key = camelize(key, uppercase_first_letter=False)
        if isinstance(value, dict) and (key.lower() != "tags"):
            # The key name on tags can be whatever the user wants, and we should not mess with them.
            result[new_key] = camelize_dict_keys(value)
        else:
            result[new_key] = nested_dict[key]
    return result

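For example, a snake_case run_task-style configuration is camelized everywhere except inside `tags` (the values here are placeholders):

config = {
    "network_configuration": {
        "awsvpc_configuration": {"subnets": ["subnet-0123"], "assign_public_ip": "ENABLED"},
    },
    "tags": {"my_custom_tag": "value"},
}
assert camelize_dict_keys(config) == {
    "networkConfiguration": {
        "awsvpcConfiguration": {"subnets": ["subnet-0123"], "assignPublicIp": "ENABLED"},
    },
    "tags": {"my_custom_tag": "value"},  # tag keys pass through untouched
}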