# Source code for airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import time
import warnings
from collections import deque
from collections.abc import Sequence
from typing import TYPE_CHECKING, TypeAlias

from boto3.session import NoCredentialsError
from botocore.utils import ClientError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.amazon.aws.executors.aws_lambda.utils import (
    CONFIG_GROUP_NAME,
    INVALID_CREDENTIALS_EXCEPTIONS,
    AllLambdaConfigKeys,
    CommandType,
    LambdaQueuedTask,
)
from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
    calculate_next_attempt_delay,
    exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.executors import workloads
    from airflow.models.taskinstance import TaskInstance

    if AIRFLOW_V_3_3_PLUS:
        from airflow.executors.workloads.types import WorkloadKey as _WorkloadKey

[docs] WorkloadKey: TypeAlias = _WorkloadKey
else: WorkloadKey: TypeAlias = TaskInstanceKey # type: ignore[no-redef, misc]
class AwsLambdaExecutor(BaseExecutor):
    """
    An Airflow Executor that submits workloads (tasks and callbacks) to AWS Lambda asynchronously.

    When execute_async() is called, the executor invokes a specified AWS Lambda function (asynchronously)
    with a payload that includes the workload command and a unique workload key.

    The Lambda function writes its result directly to an SQS queue, which is then polled by this executor
    to update workload state in Airflow.
    """

    supports_multi_team: bool = True

    if AIRFLOW_V_3_3_PLUS:
        supports_callbacks: bool = True

    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
        # In the v3 path, we store workloads, not commands as strings.
        # TODO: TaskSDK: move this type change into BaseExecutor.
        queued_tasks: dict[WorkloadKey, workloads.All]  # type: ignore[assignment]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Workloads waiting to be submitted (or re-submitted after a failed invoke) to Lambda.
        self.pending_workloads: deque = deque()
        # Serialized workload key (as sent in the Lambda payload) -> original workload key.
        self.running_workloads: dict[str, WorkloadKey] = {}

        # Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global
        # configuration object. This allows the changes to be backwards compatible with older versions of
        # Airflow.
        # Can be removed when minimum supported provider version is equal to the version of core airflow
        # which introduces multi-team configuration.
        if not hasattr(self, "conf"):
            from airflow.providers.common.compat.sdk import conf

            self.conf = conf

        self.lambda_function_name = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.FUNCTION_NAME)
        self.sqs_queue_url = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUEUE_URL)
        self.dlq_url = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.DLQ_URL)
        self.qualifier = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUALIFIER, fallback=None)
        # Maximum number of attempts to invoke Lambda for one workload before marking it failed.
        self.max_invoke_attempts = self.conf.get(
            CONFIG_GROUP_NAME,
            AllLambdaConfigKeys.MAX_INVOKE_ATTEMPTS,
        )

        self.attempts_since_last_successful_connection = 0
        self.IS_BOTO_CONNECTION_HEALTHY = False
        # Build the boto clients now; health is verified later by start() (if configured) or by the
        # reconnect path in sync(), since the connection starts out marked unhealthy.
        self.load_connections(check_connection=False)

    @property
    def pending_tasks(self) -> deque:
        """Deprecated: use pending_workloads."""
        return self.pending_workloads

    @property
    def running_tasks(self) -> dict[str, WorkloadKey]:
        """Deprecated: use running_workloads."""
        return self.running_workloads

    def start(self):
        """
        Call this when the Executor is run for the first time by the scheduler.

        Runs check_health() only when the CHECK_HEALTH_ON_STARTUP option is enabled.

        :raises AirflowException: if the startup health check fails.
        """
        check_health = self.conf.getboolean(CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP)

        if not check_health:
            return

        self.log.info("Starting Lambda Executor and determining health...")
        try:
            self.check_health()
        except AirflowException:
            self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.")
            raise

    def check_health(self):
        """
        Check the health of the Lambda and SQS connections.

        For lambda: Use get_function to test if the lambda connection works and the function can be
        described.
        For SQS: Use get_queue_attributes as a close analog to describe to test if the SQS connection
        is working, for both the results queue and the dead letter queue.

        ``IS_BOTO_CONNECTION_HEALTHY`` is reset to False first and only set back to True when every
        check passes.

        :raises AirflowException: if any of the checks fail.
        """
        self.IS_BOTO_CONNECTION_HEALTHY = False

        def _check_queue(queue_url):
            sqs_get_queue_attrs_response = self.sqs_client.get_queue_attributes(
                QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages"]
            )
            approx_num_msgs = sqs_get_queue_attrs_response.get("Attributes").get(
                "ApproximateNumberOfMessages"
            )
            self.log.info(
                "SQS connection is healthy and queue %s is present with %s messages.",
                queue_url,
                approx_num_msgs,
            )

        self.log.info("Checking Lambda and SQS connections")
        try:
            # Check Lambda health.
            lambda_get_response = self.lambda_client.get_function(FunctionName=self.lambda_function_name)
            if self.lambda_function_name not in lambda_get_response["Configuration"]["FunctionName"]:
                # Exceptions do not apply %-style formatting to extra args, so format eagerly.
                raise AirflowException(f"Lambda function {self.lambda_function_name} not found.")
            self.log.info(
                "Lambda connection is healthy and function %s is present.", self.lambda_function_name
            )

            # Check SQS results queue.
            _check_queue(self.sqs_queue_url)
            # Check SQS dead letter queue.
            _check_queue(self.dlq_url)

            # If we reach this point, both connections are healthy and all resources are present.
            self.IS_BOTO_CONNECTION_HEALTHY = True
        except Exception as e:
            self.log.exception("Lambda Executor health check failed")
            raise AirflowException(
                "The Lambda executor will not be able to run Airflow tasks until the issue is addressed."
            ) from e

    def load_connections(self, check_connection: bool = True):
        """
        Retrieve the AWS connection via Hooks to leverage the Airflow connection system.

        Increments ``attempts_since_last_successful_connection`` and records ``last_connection_reload``
        (both consumed by the back-off in sync()); the counter is reset only after a passing health check.

        :param check_connection: If True, check the health of the connection after loading it.
        """
        self.log.info("Loading Connections")
        aws_conn_id = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.AWS_CONN_ID)
        region_name = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.REGION_NAME, fallback=None)
        self.sqs_client = SqsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
        self.lambda_client = LambdaHook(aws_conn_id=aws_conn_id, region_name=region_name).conn

        self.attempts_since_last_successful_connection += 1
        self.last_connection_reload = timezone.utcnow()

        if check_connection:
            self.check_health()
            self.attempts_since_last_successful_connection = 0

    def sync(self):
        """
        Sync the executor with the current state of workloads.

        Check in on currently running tasks and callbacks and attempt to run any new workloads that have
        been queued. If the boto connection is marked unhealthy, first try to reconnect (with exponential
        back-off) and skip this sync if it is still unhealthy.
        """
        if not self.IS_BOTO_CONNECTION_HEALTHY:
            exponential_backoff_retry(
                self.last_connection_reload,
                self.attempts_since_last_successful_connection,
                self.load_connections,
            )
            if not self.IS_BOTO_CONNECTION_HEALTHY:
                return
        try:
            self.sync_running_workloads()
            self.attempt_workload_runs()
        except NoCredentialsError as error:
            # NoCredentialsError is a botocore client-side error raised before any request is sent; unlike
            # ClientError it has no ``response`` attribute, so it must not share the ClientError handler.
            self.IS_BOTO_CONNECTION_HEALTHY = False
            self.log.warning(
                "AWS credentials are either missing or expired: %s.\nRetrying connection", error
            )
        except ClientError as error:
            error_code = error.response.get("Error", {}).get("Code")
            if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
                self.IS_BOTO_CONNECTION_HEALTHY = False
                self.log.warning(
                    "AWS credentials are either missing or expired: %s.\nRetrying connection", error
                )
            else:
                # Previously swallowed silently; surface it like any other sync failure.
                self.log.exception("An error occurred while syncing workloads.")
        except Exception:
            self.log.exception("An error occurred while syncing workloads.")

    # TODO: Remove this once the minimum supported version is 3.2+, and defer to BaseExecutor.queue_workload.
    def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
        """
        Store a workload in the matching queue until it is picked up for execution.

        :param workload: An ExecuteTask workload (or, on Airflow 3.3+, an ExecuteCallback workload).
        :param session: Unused; kept for signature compatibility with BaseExecutor.
        :raises RuntimeError: for any other workload type.
        """
        from airflow.executors import workloads

        if isinstance(workload, workloads.ExecuteTask):
            self.queued_tasks[workload.ti.key] = workload
            return

        if AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback):
            self.queued_callbacks[workload.callback.key] = workload
            return

        raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

    def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None:
        """
        Move queued workloads into the pending submission queue and mark them as running.

        :param workload_items: Workloads previously stored by queue_workload().
        :raises RuntimeError: for an unsupported workload type.
        """
        from airflow.executors import workloads

        for workload in workload_items:
            queue: str | None
            key: WorkloadKey
            command: CommandType

            if isinstance(workload, workloads.ExecuteTask):
                command = [workload]
                key = workload.ti.key
                queue = workload.ti.queue
                executor_config = workload.ti.executor_config or {}

                del self.queued_tasks[key]
                self.execute_async(
                    key=key,
                    command=command,
                    queue=queue,
                    executor_config=executor_config,
                )
                self.running.add(key)
                continue

            if AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback):
                command = [workload]
                key = workload.callback.key
                queue = None
                if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data:
                    queue = workload.callback.data["queue"]

                del self.queued_callbacks[key]
                self.execute_async(
                    key=key,
                    command=command,
                    queue=queue,
                )
                self.running.add(key)
                continue

            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

    def execute_async(
        self,
        key: WorkloadKey,
        command: CommandType,
        queue=None,
        executor_config=None,
    ):
        """
        Save the workload to be executed in the next sync by inserting the commands into a queue.

        A single-element command holding a workload object is converted into the
        ``python -m airflow.sdk.execution_time.execute_workload --json-string <json>`` command line.

        :param key: Unique workload key. Task workloads use TaskInstanceKey, callback workloads use a string id.
        :param command: The workload command or serialized shell command to execute.
        :param executor_config: (Unused) to keep the same signature as the base.
        :param queue: (Unused) to keep the same signature as the base.
        :raises RuntimeError: if a single-element command holds an unsupported workload type.
        """
        if len(command) == 1:
            from airflow.executors import workloads

            workload = command[0]
            if AIRFLOW_V_3_3_PLUS:
                if not isinstance(workload, (workloads.ExecuteTask, workloads.ExecuteCallback)):
                    raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
            else:
                if not isinstance(workload, workloads.ExecuteTask):
                    raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

            ser_input = workload.model_dump_json()
            command = [
                "python",
                "-m",
                "airflow.sdk.execution_time.execute_workload",
                "--json-string",
                ser_input,
            ]

        self.pending_workloads.append(
            LambdaQueuedTask(
                key, command, queue if queue else "", executor_config or {}, 1, timezone.utcnow()
            )
        )

    def attempt_workload_runs(self):
        """
        Attempt to run workloads that are queued in the pending_workloads.

        Each workload is submitted to AWS Lambda with a payload containing the workload key and command.
        The workload key is used to track the workload's state in Airflow. Workloads whose invoke fails
        are re-queued with a back-off delay until ``max_invoke_attempts`` is reached, then marked failed.

        :raises NoCredentialsError: re-raised (after re-queueing the workload) so sync() can reconnect.
        :raises ClientError: re-raised for credential errors, after re-queueing the workload.
        """
        # Snapshot the length: workloads that are re-appended during this pass must not be revisited.
        for _ in range(len(self.pending_workloads)):
            workload_to_run = self.pending_workloads.popleft()

            # Still backing off from a previous failed invoke; leave it for a later sync.
            if timezone.utcnow() < workload_to_run.next_attempt_time:
                self.pending_workloads.append(workload_to_run)
                continue

            workload_key = workload_to_run.key
            attempt_number = workload_to_run.attempt_number
            failure_reasons = []
            try:
                ser_workload_key = json.dumps(workload_key._asdict())
            except AttributeError:
                # Callback workloads use string id.
                ser_workload_key = workload_key
            payload = {
                "task_key": ser_workload_key,
                "command": workload_to_run.command,
                "executor_config": workload_to_run.executor_config,
            }

            self.log.info(
                "Submitting workload %s to Lambda function %s", workload_key, self.lambda_function_name
            )

            try:
                invoke_kwargs = {
                    "FunctionName": self.lambda_function_name,
                    "InvocationType": "Event",
                    "Payload": json.dumps(payload),
                }
                if self.qualifier:
                    invoke_kwargs["Qualifier"] = self.qualifier
                response = self.lambda_client.invoke(**invoke_kwargs)
            except NoCredentialsError:
                self.pending_workloads.append(workload_to_run)
                raise
            except ClientError as e:
                error_code = e.response["Error"]["Code"]
                if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
                    self.pending_workloads.append(workload_to_run)
                    raise
                failure_reasons.append(str(e))
            except Exception as e:
                # Failed to even get a response back from the Boto3 API or something else went
                # wrong. For any possible failure we want to add the exception reasons to the
                # failure list so that it is logged to the user and most importantly the task is
                # added back to the pending list to be retried later.
                failure_reasons.append(str(e))

            if failure_reasons:
                # Make sure the number of attempts does not exceed max invoke attempts.
                if int(attempt_number) < int(self.max_invoke_attempts):
                    workload_to_run.attempt_number += 1
                    workload_to_run.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
                        attempt_number
                    )
                    self.pending_workloads.append(workload_to_run)
                else:
                    reasons_str = ", ".join(failure_reasons)
                    self.log.error(
                        "Lambda invoke %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
                        workload_key,
                        attempt_number,
                        reasons_str,
                    )
                    self.log_task_event(
                        event="lambda invoke failure",
                        ti_key=workload_key,
                        extra=(
                            f"Workload could not be queued after {attempt_number} attempts. "
                            f"Marking as failed. Reasons: {reasons_str}"
                        ),
                    )
                    self.fail(workload_key)
            else:
                status_code = response.get("StatusCode")
                self.log.info("Invoked Lambda for workload %s with status %s", workload_key, status_code)
                self.running_workloads[ser_workload_key] = workload_key
                # Add the serialized workload key as the info, this will be assigned on the ti as the
                # external_executor_id.
                self.running_state(workload_key, ser_workload_key)

    def attempt_task_runs(self):
        """Use attempt_workload_runs as attempt_task_runs is deprecated."""
        warnings.warn(
            "attempt_task_runs is deprecated, use attempt_workload_runs instead.",
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )
        return self.attempt_workload_runs()

    def sync_running_workloads(self):
        """
        Poll the SQS queue for messages indicating workload completion.

        Each message is expected to contain a JSON payload with 'task_key' and 'return_code'.

        Based on the return code, update the workload state accordingly. The dead letter queue (if
        configured) is polled only while workloads are still running after the results queue.
        """
        if not self.running_workloads:
            self.log.debug("No running workloads to process.")
            return

        self.process_queue(self.sqs_queue_url)
        if self.dlq_url and self.running_workloads:
            self.process_queue(self.dlq_url)

    def sync_running_tasks(self):
        """Use sync_running_workloads as sync_running_tasks is deprecated."""
        warnings.warn(
            "sync_running_tasks is deprecated, use sync_running_workloads instead.",
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )
        return self.sync_running_workloads()

    def process_queue(self, queue_url: str):
        """
        Poll one SQS queue for messages indicating workload completion.

        Results-queue messages must be JSON objects with 'task_key' and 'return_code'; dead-letter-queue
        messages must be JSON objects with 'task_key' and 'command'. Based on the return code, update the
        workload state accordingly. Malformed messages are deleted; messages for workloads this executor
        does not track are made visible again so another executor can consume them.

        :param queue_url: URL of the queue to poll (the results queue or the dead letter queue).
        """
        response = self.sqs_client.receive_message(
            QueueUrl=queue_url,
            MaxNumberOfMessages=10,
        )

        # Pagination? Maybe we don't need it. But we don't always delete messages after viewing them so we
        # could possibly accumulate a lot of messages in the queue and get stuck if we don't read bigger
        # chunks and paginate.
        messages = response.get("Messages", [])
        # The keys that we validate in the messages below will be different depending on whether or not
        # the message is from the dead letter queue or the main results queue.
        message_keys = ("return_code", "task_key")
        if messages and queue_url == self.dlq_url:
            self.log.warning("%d messages received from the dead letter queue", len(messages))
            message_keys = ("command", "task_key")

        for message in messages:
            delete_message = False
            receipt_handle = message["ReceiptHandle"]
            body = None
            try:
                body = json.loads(message["Body"])
            except json.JSONDecodeError:
                self.log.warning(
                    "Received a message from the queue that could not be parsed as JSON: %s",
                    message["Body"],
                )
                delete_message = True

            # If the message is not already marked for deletion, check that it is a JSON object with the
            # required keys. Valid JSON that is not an object (list, string, number) would otherwise raise
            # on the key checks / ``body.get`` below, abort the whole batch and never be deleted.
            if not delete_message and (
                not isinstance(body, dict) or not all(key in body for key in message_keys)
            ):
                self.log.warning(
                    "Message is not formatted correctly, %s and/or %s are missing: %s", *message_keys, body
                )
                delete_message = True

            if delete_message:
                self.log.warning("Deleting the message to avoid processing it again.")
                self.sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
                continue

            return_code = body.get("return_code")
            ser_task_key = body.get("task_key")
            # Fetch the real task key from the running_workloads dict, using the serialized task key.
            try:
                workload_key = self.running_workloads[ser_task_key]
            except KeyError:
                self.log.debug(
                    "Received workload %s from the queue which is not found in running workloads, it is likely "
                    "from another Lambda Executor sharing this queue or might be a stale message that needs "
                    "deleting manually. Marking the message as visible again.",
                    ser_task_key,
                )
                # Mark workload as visible again in SQS so that another executor can pick it up.
                self.sqs_client.change_message_visibility(
                    QueueUrl=queue_url,
                    ReceiptHandle=receipt_handle,
                    VisibilityTimeout=0,
                )
                continue

            if workload_key:
                if return_code == 0:
                    self.success(workload_key)
                    self.log.info(
                        "Successful Lambda invocation for workload %s received from SQS queue.",
                        workload_key,
                    )
                else:
                    self.fail(workload_key)
                    if queue_url == self.dlq_url and return_code is None:
                        self.log.error(
                            "DLQ message received: Lambda invocation for workload %s was unable to "
                            "successfully execute. This likely indicates a Lambda-level failure "
                            "(timeout, memory limit, crash, etc.).",
                            workload_key,
                        )
                    else:
                        self.log.debug(
                            "Lambda invocation for workload %s returned a non-zero exit code %s",
                            workload_key,
                            return_code,
                        )
                # Remove the workload from the tracking mapping.
                self.running_workloads.pop(ser_task_key)

            # Delete the message from the queue.
            self.sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)

    def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
        """
        Adopt task instances which have an external_executor_id (the serialized workload key).

        The external_executor_id represents the workload identifier. In legacy executors
        (Airflow < 3.3) this is the serialized TaskInstanceKey. In the workload-based executor
        model (Airflow >= 3.3) this corresponds to the WorkloadKey. An id that cannot be parsed
        back into a TaskInstanceKey is tracked as-is (callback workloads use string ids).

        Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.

        :param tis: The task instances to adopt.
        :return: The task instances that were not adopted (those without an external_executor_id).
        """
        with Stats.timer("lambda_executor.adopt_task_instances.duration"):
            adopted_tis: list[TaskInstance] = []

            if serialized_workload_keys := [
                (ti, ti.external_executor_id) for ti in tis if ti.external_executor_id
            ]:
                for ti, ser_workload_key in serialized_workload_keys:
                    try:
                        data = json.loads(ser_workload_key)
                        workload_key = TaskInstanceKey.from_dict(data)
                    except (json.JSONDecodeError, KeyError, TypeError) as e:
                        self.log.warning(
                            "Failed to deserialize workload_key '%s' (%s); "
                            "skipping deserialization and treating as callback id.",
                            ser_workload_key,
                            str(e),
                        )
                        # Callback workloads use string keys.
                        workload_key = ser_workload_key

                    self.running_workloads[ser_workload_key] = workload_key
                    adopted_tis.append(ti)

            if adopted_tis:
                tasks = [f"{task} in state {task.state}" for task in adopted_tis]
                task_instance_str = "\n\t".join(tasks)
                self.log.info(
                    "Adopted the following %d tasks from a dead executor:\n\t%s",
                    len(adopted_tis),
                    task_instance_str,
                )

            not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
            return not_adopted_tis

    def end(self, heartbeat_interval=10):
        """
        End execution. Poll until all outstanding workloads are marked as completed.

        This is a blocking call and async Lambda workloads cannot be cancelled, so this will wait until
        all workloads are either completed or the timeout is reached. A timeout of 0 (the fallback when
        END_WAIT_TIMEOUT is unset) means wait without limit.

        :param heartbeat_interval: The interval in seconds to wait between checks for workload completion.
        """
        self.log.info("Received signal to end, waiting for outstanding workloads to finish.")
        time_to_wait = int(
            self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.END_WAIT_TIMEOUT, fallback="0")
        )
        start_time = timezone.utcnow()
        while True:
            if time_to_wait:
                current_time = timezone.utcnow()
                elapsed_time = (current_time - start_time).total_seconds()
                if elapsed_time > time_to_wait:
                    self.log.warning(
                        "Timed out waiting for workloads to finish. Some workloads may not be handled gracefully"
                        " as the executor is force ending due to timeout."
                    )
                    break
            self.sync()
            if not self.running_workloads:
                self.log.info("All workloads completed; executor ending.")
                break
            self.log.info("Waiting for %d workload(s) to complete.", len(self.running_workloads))
            time.sleep(heartbeat_interval)

    def terminate(self):
        """Get called when the daemon receives a SIGTERM."""
        self.log.warning("Terminating Lambda executor. In-flight workloads cannot be stopped.")

# (End of rendered source page.)