# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
AWS ECS Executor Utilities.

Data classes and utility functions used by the ECS executor.
"""

from __future__ import annotations

import datetime
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List

from inflection import camelize

from airflow.utils.state import State

if TYPE_CHECKING:
    from airflow.models.taskinstance import TaskInstanceKey

# An Airflow task command, as a list of argv-style string tokens.
CommandType = List[str]
# A callable that takes a command and returns an executor config dict.
ExecutorConfigFunctionType = Callable[[CommandType], dict]
# Per-task executor configuration.
ExecutorConfigType = Dict[str, Any]

# Name of the Airflow config section holding the ECS executor options.
CONFIG_GROUP_NAME = "aws_ecs_executor"

# Fallback values for options in the CONFIG_GROUP_NAME section. Every value is a string,
# including the numeric and boolean-looking ones.
CONFIG_DEFAULTS = {
    "conn_id": "aws_default",
    "max_run_task_attempts": "3",
    "assign_public_ip": "False",
    "platform_version": "LATEST",
    "check_health_on_startup": "True",
}
@dataclass
class EcsQueuedTask:
    """
    Represents an ECS task that is queued. The task will be run in the next heartbeat.

    :param key: Airflow task instance key.
    :param command: the Airflow command to run.
    :param queue: the Airflow queue name.
    :param executor_config: per-task executor configuration.
    :param attempt_number: the current run attempt number for this task.
    :param next_attempt_time: earliest time at which the next attempt may be made.
    """

    key: TaskInstanceKey
    command: CommandType
    queue: str
    executor_config: ExecutorConfigType
    attempt_number: int
    next_attempt_time: datetime.datetime
@dataclass
class EcsTaskInfo:
    """
    Contains information about a currently running ECS task.

    Fields are positional in this order (cmd, queue, config), matching how the task
    collection constructs it.
    """

    cmd: CommandType
    queue: str
    config: ExecutorConfigType
class BaseConfigKeys:
    """Base Implementation of the Config Keys class. Implements iteration for child classes to inherit."""

    def __iter__(self):
        """
        Return an iterator of values of non dunder attributes of Config Keys.

        Attributes inherited from parent Config Keys classes are included too, so a class that
        extends another keys class yields its parent's keys as well as its own. When a subclass
        redefines an attribute, only the subclass's value is used.
        """
        attributes = {}
        # Walk from the most basic class to the most derived so subclass definitions override
        # parent definitions of the same attribute name.
        for klass in reversed(type(self).__mro__):
            if klass is object:
                continue
            attributes.update(
                {key: value for key, value in vars(klass).items() if not key.startswith("__")}
            )
        return iter(set(attributes.values()))
class RunTaskKwargsConfigKeys(BaseConfigKeys):
    """Keys loaded into the config which are valid ECS run_task kwargs."""

    # Each value is the snake_case option name as it appears in the executor config section.
    ASSIGN_PUBLIC_IP = "assign_public_ip"
    CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy"
    CLUSTER = "cluster"
    LAUNCH_TYPE = "launch_type"
    PLATFORM_VERSION = "platform_version"
    SECURITY_GROUPS = "security_groups"
    SUBNETS = "subnets"
    TASK_DEFINITION = "task_definition"
    CONTAINER_NAME = "container_name"
class AllEcsConfigKeys(RunTaskKwargsConfigKeys):
    """All keys loaded into the config which are related to the ECS Executor."""

    # Executor-level options; the run_task kwarg keys are inherited from the parent class.
    MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
    AWS_CONN_ID = "conn_id"
    RUN_TASK_KWARGS = "run_task_kwargs"
    REGION_NAME = "region_name"
    CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
class EcsExecutorException(Exception):
    """Thrown when something unexpected has occurred within the ECS ecosystem."""
class EcsExecutorTask:
    """
    Data Transfer Object for an ECS Fargate Task.

    :param task_arn: ARN of the ECS task.
    :param last_status: the task's last reported status string.
    :param desired_status: the task's desired status string.
    :param containers: per-container dicts; ``get_task_state`` reads their ``exit_code`` key.
    :param started_at: start marker for the task; a falsy value means it never started.
    :param stopped_reason: why the task stopped, if it did.
    """

    def __init__(
        self,
        task_arn: str,
        last_status: str,
        desired_status: str,
        containers: list[dict[str, Any]],
        started_at: Any | None = None,
        stopped_reason: str | None = None,
    ):
        self.task_arn = task_arn
        self.last_status = last_status
        self.desired_status = desired_status
        self.containers = containers
        self.started_at = started_at
        self.stopped_reason = stopped_reason

    def get_task_state(self) -> str:
        """
        Determine the state of an ECS task based on its status and other relevant attributes.

        It can return one of the following statuses:
            QUEUED - Task is being provisioned.
            RUNNING - Task is launched on ECS.
            REMOVED - Task provisioning has failed for some reason. See `stopped_reason`.
            FAILED - Task is completed and at least one container has failed.
            SUCCESS - Task is completed and all containers have succeeded.
        """
        if self.last_status == "RUNNING":
            return State.RUNNING
        elif self.desired_status == "RUNNING":
            return State.QUEUED
        is_finished = self.desired_status == "STOPPED"
        has_exit_codes = all("exit_code" in container for container in self.containers)
        # Sometimes ECS tasks may time out: a task that is stopping without ever having
        # started is treated as REMOVED rather than FAILED.
        if not self.started_at and is_finished:
            return State.REMOVED
        # Until the task is stopped AND every container reports an exit code, it is still running.
        if not is_finished or not has_exit_codes:
            return State.RUNNING
        all_containers_succeeded = all(container["exit_code"] == 0 for container in self.containers)
        return State.SUCCESS if all_containers_succeeded else State.FAILED

    def __repr__(self):
        """Return a string representation of the ECS task."""
        return f"({self.task_arn}, {self.last_status}->{self.desired_status}, {self.get_task_state()})"
class EcsTaskCollection:
    """A five-way dictionary between Airflow task ids, Airflow cmds, ECS ARNs, and ECS task objects."""

    def __init__(self):
        self.key_to_arn: dict[TaskInstanceKey, str] = {}
        self.arn_to_key: dict[str, TaskInstanceKey] = {}
        self.tasks: dict[str, EcsExecutorTask] = {}
        # defaultdict so increment_failure_count works for keys never seen before.
        self.key_to_failure_counts: dict[TaskInstanceKey, int] = defaultdict(int)
        self.key_to_task_info: dict[TaskInstanceKey, EcsTaskInfo] = {}

    def add_task(
        self,
        task: EcsExecutorTask,
        airflow_task_key: TaskInstanceKey,
        queue: str,
        airflow_cmd: CommandType,
        exec_config: ExecutorConfigType,
        attempt_number: int,
    ):
        """
        Add a task to the collection.

        The failure counter for the key is set to ``attempt_number``.
        """
        arn = task.task_arn
        self.tasks[arn] = task
        self.key_to_arn[airflow_task_key] = arn
        self.arn_to_key[arn] = airflow_task_key
        self.key_to_task_info[airflow_task_key] = EcsTaskInfo(airflow_cmd, queue, exec_config)
        self.key_to_failure_counts[airflow_task_key] = attempt_number

    def update_task(self, task: EcsExecutorTask):
        """Update the state of the given task based on task ARN."""
        self.tasks[task.task_arn] = task

    def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
        """Get a task by Airflow Instance Key. Raises KeyError if the key is unknown."""
        arn = self.key_to_arn[task_key]
        return self.task_by_arn(arn)

    def task_by_arn(self, arn) -> EcsExecutorTask:
        """Get a task by AWS ARN. Raises KeyError if the ARN is unknown."""
        return self.tasks[arn]

    def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
        """
        Delete task from collection based off of Airflow Task Instance Key.

        Removes every mapping for the key and its ARN, and returns the removed task.
        Raises KeyError if the key is unknown.
        """
        arn = self.key_to_arn[task_key]
        task = self.tasks[arn]
        del self.key_to_arn[task_key]
        del self.key_to_task_info[task_key]
        del self.arn_to_key[arn]
        del self.tasks[arn]
        # The failure counter may legitimately be absent, so remove it without raising.
        self.key_to_failure_counts.pop(task_key, None)
        return task

    def get_all_arns(self) -> list[str]:
        """Get all AWS ARNs in collection."""
        return list(self.key_to_arn.values())

    def get_all_task_keys(self) -> list[TaskInstanceKey]:
        """Get all Airflow Task Keys in collection."""
        return list(self.key_to_arn.keys())

    def failure_count_by_key(self, task_key: TaskInstanceKey) -> int:
        """Get the number of times a task has failed given an Airflow Task Key (0 if unknown)."""
        return self.key_to_failure_counts[task_key]

    def increment_failure_count(self, task_key: TaskInstanceKey):
        """Increment the failure counter given an Airflow Task Key."""
        self.key_to_failure_counts[task_key] += 1

    def info_by_key(self, task_key: TaskInstanceKey) -> EcsTaskInfo:
        """Get the stored task info (command, queue and executor config) given an Airflow task key."""
        return self.key_to_task_info[task_key]

    def __getitem__(self, value):
        """Get a task by AWS ARN."""
        return self.task_by_arn(value)

    def __len__(self):
        """Determine the number of tasks in collection."""
        return len(self.tasks)
def _recursive_flatten_dict(nested_dict): """ Recursively unpack a nested dict and return it as a flat dict. For example, _flatten_dict({'a': 'a', 'b': 'b', 'c': {'d': 'd'}}) returns {'a': 'a', 'b': 'b', 'd': 'd'}. """ items = [] for key, value in nested_dict.items(): if isinstance(value, dict): items.extend(_recursive_flatten_dict(value).items()) else: items.append((key, value)) return dict(items)
def parse_assign_public_ip(assign_public_ip):
    """
    Convert "assign_public_ip" from True/False to ENABLED/DISABLED.

    The comparison is case-insensitive and ignores surrounding whitespace. So "True", "true",
    " TRUE " and the boolean ``True`` all map to "ENABLED". Any other value, including
    ``None``, maps to "DISABLED".

    :param assign_public_ip: the configured value, normally a string such as "True" or "False".
    :return: "ENABLED" or "DISABLED".
    """
    # str() lets a real bool be passed as well as the config string.
    return "ENABLED" if str(assign_public_ip).strip().lower() == "true" else "DISABLED"
def camelize_dict_keys(nested_dict) -> dict:
    """
    Accept a potentially nested dictionary and recursively convert all keys into camelCase.

    Nested dictionary values are converted recursively, except the value of a ``tags`` key
    (matched case-insensitively), whose inner keys are left untouched. Non-dict values,
    including lists, are copied as-is without being traversed.

    :param nested_dict: dictionary with snake_case string keys.
    :return: a new dictionary with camelCase keys; the input is not modified.
    """
    result = {}
    for key, value in nested_dict.items():
        new_key = camelize(key, uppercase_first_letter=False)
        # The key name on tags can be whatever the user wants, and we should not mess with them.
        if isinstance(value, dict) and key.lower() != "tags":
            result[new_key] = camelize_dict_keys(value)
        else:
            result[new_key] = value
    return result

