Source code for airflow.providers.amazon.aws.executors.ecs.ecs_executor_config

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
AWS ECS Executor configuration.

This is the configuration for calling the ECS ``run_task`` function. The AWS ECS Executor calls
Boto3's ``run_task(**kwargs)`` function with the kwargs templated by this dictionary. See the URL
below for documentation on the parameters accepted by the Boto3 run_task function.

.. seealso::
    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task

"""

from __future__ import annotations

import json
from json import JSONDecodeError

from airflow.configuration import conf
from airflow.providers.amazon.aws.executors.ecs.utils import (
    CONFIG_GROUP_NAME,
    AllEcsConfigKeys,
    RunTaskKwargsConfigKeys,
    camelize_dict_keys,
    parse_assign_public_ip,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.utils.helpers import prune_dict


def _fetch_templated_kwargs() -> dict[str, str]:
    """
    Read the ``run_task`` kwargs template from the Airflow configuration and decode it.

    The option is looked up under ``CONFIG_GROUP_NAME``. When the option is not set, the
    fallback is an empty dict, whose string form (``"{}"``) decodes to an empty dict.

    :return: The JSON-decoded value of the ``RUN_TASK_KWARGS`` option.
    :raises json.JSONDecodeError: If the configured value is not valid JSON.
    """
    raw_template = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.RUN_TASK_KWARGS, fallback={})
    # The value is stringified first so the (non-string) fallback goes through the same decoder.
    serialized_template = str(raw_template)
    return json.loads(serialized_template)


def _fetch_config_values() -> dict[str, str]:
    """
    Collect the individually configured ``run_task`` options from the Airflow configuration.

    Every key yielded by ``RunTaskKwargsConfigKeys()`` is looked up under ``CONFIG_GROUP_NAME``;
    options that are not set come back as ``None``. The collected mapping is passed through
    ``prune_dict`` before being returned (presumably dropping the unset entries -- confirm
    against ``airflow.utils.helpers.prune_dict``).

    :return: Mapping of config key to its configured value, after pruning.
    """
    collected = {}
    for option_name in RunTaskKwargsConfigKeys():
        collected[option_name] = conf.get(CONFIG_GROUP_NAME, option_name, fallback=None)
    return prune_dict(collected)


def build_task_kwargs() -> dict:
    """
    Build the keyword arguments passed to the ECS ``run_task`` call.

    The individually configured options are merged with the ``run_task`` kwargs template
    (template values win on conflict), then rearranged into the nested layout used for the
    call: the container name and an empty ``command`` are placed on the first container
    override, network options are grouped under ``networkConfiguration``, and finally all
    keys are passed through ``camelize_dict_keys``.

    When a cluster is configured but neither ``launch_type`` nor ``capacity_provider_strategy``
    is, the cluster is looked up via ``describe_clusters`` and ``launch_type`` is set to
    ``FARGATE`` if the cluster has no default capacity provider strategy.

    :return: The assembled ``run_task`` kwargs.
    :raises ValueError: If both ``capacity_provider_strategy`` and ``launch_type`` are given,
        if network options are given without any subnet, or if the resulting kwargs are not
        JSON serializable.
    :raises KeyError: If the container name option is not configured.
    """
    # This will put some kwargs at the root of the dictionary that do NOT belong there. However,
    # the code below expects them to be there and will rearrange them as necessary.
    task_kwargs = _fetch_config_values()
    task_kwargs.update(_fetch_templated_kwargs())

    has_launch_type: bool = "launch_type" in task_kwargs
    has_capacity_provider: bool = "capacity_provider_strategy" in task_kwargs

    if has_capacity_provider and has_launch_type:
        raise ValueError(
            "capacity_provider_strategy and launch_type are mutually exclusive, you can not provide both."
        )
    elif "cluster" in task_kwargs and not (has_capacity_provider or has_launch_type):
        # Default API behavior if neither is provided is to fall back on the default capacity
        # provider if it exists. Since it is not a required value, check if there is one
        # before using it, and if there is not then use the FARGATE launch_type as
        # the final fallback.
        cluster = EcsHook().conn.describe_clusters(clusters=[task_kwargs["cluster"]])["clusters"][0]
        if not cluster.get("defaultCapacityProviderStrategy"):
            task_kwargs["launch_type"] = "FARGATE"

    # There can only be 1 count of these containers
    task_kwargs["count"] = 1  # type: ignore
    # There could be a generic approach to the below, but likely more convoluted than just manually
    # ensuring the one nested config we need to update is present. If we need to override more options
    # in the future we should revisit this.
    if "overrides" not in task_kwargs:
        task_kwargs["overrides"] = {}  # type: ignore
    # A missing, null or empty list of container overrides all mean "no first container yet";
    # create one so the index-0 assignments below cannot fail with an IndexError.
    if not task_kwargs["overrides"].get("containerOverrides"):  # type: ignore
        task_kwargs["overrides"]["containerOverrides"] = [{}]  # type: ignore
    task_kwargs["overrides"]["containerOverrides"][0]["name"] = task_kwargs.pop(  # type: ignore
        AllEcsConfigKeys.CONTAINER_NAME
    )
    # The executor will overwrite the 'command' property during execution. Must always be the first container!
    task_kwargs["overrides"]["containerOverrides"][0]["command"] = []  # type: ignore

    # Always remove all three network options from the root of the dictionary, whether or not
    # a network configuration ends up being built from them.
    subnets = task_kwargs.pop(AllEcsConfigKeys.SUBNETS, None)
    security_groups = task_kwargs.pop(AllEcsConfigKeys.SECURITY_GROUPS, None)
    assign_public_ip = task_kwargs.pop(AllEcsConfigKeys.ASSIGN_PUBLIC_IP, None)

    # assign_public_ip is compared against None explicitly so that any provided value,
    # even a falsy one, still triggers building the network configuration.
    if subnets or security_groups or assign_public_ip is not None:
        network_config = prune_dict(
            {
                "awsvpcConfiguration": {
                    "subnets": str(subnets).split(",") if subnets else None,
                    "securityGroups": str(security_groups).split(",") if security_groups else None,
                    "assignPublicIp": parse_assign_public_ip(assign_public_ip),
                }
            }
        )

        if "subnets" not in network_config["awsvpcConfiguration"]:
            raise ValueError("At least one subnet is required to run a task.")

        task_kwargs["networkConfiguration"] = network_config

    task_kwargs = camelize_dict_keys(task_kwargs)

    try:
        json.loads(json.dumps(task_kwargs))
    except (TypeError, ValueError) as err:
        # json.dumps signals an unserializable object with TypeError (and a circular reference
        # with ValueError); JSONDecodeError is itself a ValueError, so a decoding failure is
        # still covered here.
        raise ValueError("AWS ECS Executor config values must be JSON serializable.") from err

    return task_kwargs

Was this entry helpful?