# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import time
from collections import deque
from collections.abc import Sequence
from typing import TYPE_CHECKING
from boto3.session import NoCredentialsError
from botocore.utils import ClientError
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.amazon.aws.executors.aws_lambda.utils import (
CONFIG_GROUP_NAME,
INVALID_CREDENTIALS_EXCEPTIONS,
AllLambdaConfigKeys,
CommandType,
LambdaQueuedTask,
)
from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
calculate_next_attempt_delay,
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance
[docs]
class AwsLambdaExecutor(BaseExecutor):
    """
    Airflow executor that runs task workloads on AWS Lambda.

    ``execute_async()`` only buffers a task in ``pending_tasks``. During ``sync()`` each buffered
    task is sent to the configured Lambda function as an asynchronous (``"Event"``) invocation whose
    payload carries the serialized task key, the command and the executor config.
    Results are read back by polling an SQS results queue and a dead letter queue, and are used to
    mark the matching task as succeeded or failed.
    """

    supports_multi_team: bool = True

    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
        # In the v3 path, we store workloads, not commands as strings.
        # TODO: TaskSDK: move this type change into BaseExecutor
        queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
def __init__(self, *args, **kwargs):
    """
    Read the Lambda executor configuration and create the boto clients.

    All positional and keyword arguments are forwarded unchanged to the base executor.
    """
    super().__init__(*args, **kwargs)

    # Tasks accepted by execute_async() that still have to be submitted to Lambda.
    self.pending_tasks: deque = deque()
    # Submitted tasks that are awaiting a result, keyed by their JSON-serialized task key.
    self.running_tasks: dict[str, TaskInstanceKey] = {}

    # Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global
    # configuration object. This allows the changes to be backwards compatible with older versions of
    # Airflow.
    # Can be removed when minimum supported provider version is equal to the version of core airflow
    # which introduces multi-team configuration.
    if not hasattr(self, "conf"):
        from airflow.providers.common.compat.sdk import conf

        self.conf = conf

    def _option(name, **extra):
        # Every setting of this executor lives in the same configuration section.
        return self.conf.get(CONFIG_GROUP_NAME, name, **extra)

    self.lambda_function_name = _option(AllLambdaConfigKeys.FUNCTION_NAME)
    self.sqs_queue_url = _option(AllLambdaConfigKeys.QUEUE_URL)
    self.dlq_url = _option(AllLambdaConfigKeys.DLQ_URL)
    self.qualifier = _option(AllLambdaConfigKeys.QUALIFIER, fallback=None)
    # Maximum number of retries to invoke Lambda.
    self.max_invoke_attempts = _option(AllLambdaConfigKeys.MAX_INVOKE_ATTEMPTS)

    # Connection bookkeeping; these must exist before load_connections() runs because it
    # increments the attempt counter.
    self.attempts_since_last_successful_connection = 0
    self.IS_BOTO_CONNECTION_HEALTHY = False

    # Create the clients now, but leave the health check to start()/sync().
    self.load_connections(check_connection=False)
[docs]
def start(self):
    """
    Run the optional startup health check.

    Does nothing unless the ``CHECK_HEALTH_ON_STARTUP`` option is enabled.

    :raises AirflowException: Re-raised from ``check_health()`` when the check fails.
    """
    if self.conf.getboolean(CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP):
        self.log.info("Starting Lambda Executor and determining health...")
        try:
            self.check_health()
        except AirflowException:
            self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.")
            raise
[docs]
def check_health(self):
    """
    Check the health of the Lambda and SQS connections.

    For Lambda: ``get_function`` is used to test that the connection works and that the
    configured function can be described.
    For SQS: ``get_queue_attributes`` is used as a close analog of "describe" to test that the
    connection works for both the results queue and the dead letter queue.

    ``IS_BOTO_CONNECTION_HEALTHY`` is reset to False on entry and only set to True once every
    check has passed.

    :raises AirflowException: If any check fails. The underlying error is logged and chained
        as the cause.
    """
    self.IS_BOTO_CONNECTION_HEALTHY = False

    def _check_queue(queue_url):
        sqs_get_queue_attrs_response = self.sqs_client.get_queue_attributes(
            QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages"]
        )
        approx_num_msgs = sqs_get_queue_attrs_response.get("Attributes").get(
            "ApproximateNumberOfMessages"
        )
        self.log.info(
            "SQS connection is healthy and queue %s is present with %s messages.",
            queue_url,
            approx_num_msgs,
        )

    self.log.info("Checking Lambda and SQS connections")
    try:
        # Check Lambda health
        lambda_get_response = self.lambda_client.get_function(FunctionName=self.lambda_function_name)
        if self.lambda_function_name not in lambda_get_response["Configuration"]["FunctionName"]:
            # Exceptions do not interpolate logging-style arguments, so format the message here.
            raise AirflowException(f"Lambda function {self.lambda_function_name} not found.")
        self.log.info(
            "Lambda connection is healthy and function %s is present.", self.lambda_function_name
        )

        # Check SQS results queue
        _check_queue(self.sqs_queue_url)
        # Check SQS dead letter queue
        _check_queue(self.dlq_url)

        # If we reach this point, both connections are healthy and all resources are present
        self.IS_BOTO_CONNECTION_HEALTHY = True
    except Exception as err:
        self.log.exception("Lambda Executor health check failed")
        raise AirflowException(
            "The Lambda executor will not be able to run Airflow tasks until the issue is addressed."
        ) from err
[docs]
def load_connections(self, check_connection: bool = True):
    """
    Build fresh SQS and Lambda clients through the Airflow AWS hooks.

    Each call counts as one connection attempt and records the reload time.

    :param check_connection: If True, run ``check_health()`` on the new clients and, when it
        passes, reset the connection attempt counter to zero.
    """
    self.log.info("Loading Connections")
    hook_kwargs = {
        "aws_conn_id": self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.AWS_CONN_ID),
        "region_name": self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.REGION_NAME, fallback=None),
    }
    self.sqs_client = SqsHook(**hook_kwargs).conn
    self.lambda_client = LambdaHook(**hook_kwargs).conn

    self.attempts_since_last_successful_connection += 1
    self.last_connection_reload = timezone.utcnow()

    if not check_connection:
        return
    # check_health() raises on failure, so the counter is only reset after a successful check.
    self.check_health()
    self.attempts_since_last_successful_connection = 0
[docs]
def sync(self):
    """
    Sync the executor with the current state of tasks.

    Check in on currently running tasks and attempt to run any new tasks that have been queued.
    While the boto connection is flagged unhealthy, ``load_connections`` is first handed to
    ``exponential_backoff_retry``; if the connection is still unhealthy afterwards nothing else
    is done in this call.

    Errors are logged rather than raised. Missing or invalid credentials additionally mark the
    connection unhealthy so that the next call tries to reconnect.
    """
    if not self.IS_BOTO_CONNECTION_HEALTHY:
        exponential_backoff_retry(
            self.last_connection_reload,
            self.attempts_since_last_successful_connection,
            self.load_connections,
        )
    if not self.IS_BOTO_CONNECTION_HEALTHY:
        return

    try:
        self.sync_running_tasks()
        self.attempt_task_runs()
    except NoCredentialsError as error:
        # NoCredentialsError is raised client side and carries no ``response`` payload, so it
        # must not go through the error-code lookup used for ClientError below.
        self.IS_BOTO_CONNECTION_HEALTHY = False
        self.log.warning("AWS credentials are either missing or expired: %s.\nRetrying connection", error)
    except ClientError as error:
        error_code = error.response.get("Error", {}).get("Code")
        if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
            self.IS_BOTO_CONNECTION_HEALTHY = False
            self.log.warning(
                "AWS credentials are either missing or expired: %s.\nRetrying connection", error
            )
        else:
            # Any other service error used to be dropped silently; make it visible.
            self.log.exception("An error occurred while syncing tasks")
    except Exception:
        self.log.exception("An error occurred while syncing tasks")
[docs]
def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
    """
    Store an ``ExecuteTask`` workload in ``queued_tasks`` under its task instance key.

    :param workload: The workload to queue; only ``ExecuteTask`` is accepted.
    :param session: Not used by this method.
    :raises RuntimeError: If the workload is of any other type.
    """
    from airflow.executors import workloads

    if isinstance(workload, workloads.ExecuteTask):
        self.queued_tasks[workload.ti.key] = workload
        return
    raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
    """
    Move the given workloads from ``queued_tasks`` into the executor's pending queue.

    Each workload is removed from ``queued_tasks``, passed to ``execute_async()`` wrapped in a
    one-element list, and its key is added to ``running``.

    :param workloads: The workloads to hand over; only ``ExecuteTask`` is accepted.
    :raises RuntimeError: If a workload is of any other type.
    """
    from airflow.executors.workloads import ExecuteTask

    for workload in workloads:
        if not isinstance(workload, ExecuteTask):
            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
        ti = workload.ti
        key = ti.key
        queue = ti.queue
        executor_config = ti.executor_config or {}

        del self.queued_tasks[key]
        self.execute_async(key=key, command=[workload], queue=queue, executor_config=executor_config)  # type: ignore[arg-type]
        self.running.add(key)
[docs]
def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
    """
    Buffer a task in ``pending_tasks`` so that a later sync submits it to Lambda.

    A one-element ``command`` must hold an ``ExecuteTask`` workload; it is replaced by the
    ``python -m airflow.sdk.execution_time.execute_workload`` command line carrying the workload
    as a JSON string. A command of any other length is queued unchanged.

    :param key: A unique task key (typically a tuple identifying the task instance).
    :param command: The command to run, or a single-element list wrapping an ``ExecuteTask``.
    :param queue: Stored on the queued task; an empty string is stored when falsy.
    :param executor_config: Stored on the queued task; an empty dict is stored when falsy.
    :raises RuntimeError: If ``command`` has exactly one element that is not an ``ExecuteTask``.
    """
    if len(command) == 1:
        from airflow.executors.workloads import ExecuteTask

        only_item = command[0]
        if not isinstance(only_item, ExecuteTask):
            raise RuntimeError(
                f"LambdaExecutor doesn't know how to handle workload of type: {type(only_item)}"
            )
        command = [
            "python",
            "-m",
            "airflow.sdk.execution_time.execute_workload",
            "--json-string",
            only_item.model_dump_json(),
        ]

    # First attempt, eligible to run immediately.
    queued_task = LambdaQueuedTask(
        key, command, queue if queue else "", executor_config or {}, 1, timezone.utcnow()
    )
    self.pending_tasks.append(queued_task)
[docs]
def attempt_task_runs(self):
    """
    Submit the tasks waiting in ``pending_tasks`` to the Lambda function.

    Each task is sent as an asynchronous invocation whose payload holds the serialized task
    key, the command and the executor config. Tasks whose next attempt time has not been
    reached are put back untouched. A failed invocation is retried later until
    ``max_invoke_attempts`` is reached, after which the task is marked as failed.

    :raises NoCredentialsError: Re-raised after re-queueing the task.
    :raises ClientError: Re-raised after re-queueing the task when the error code is one of
        ``INVALID_CREDENTIALS_EXCEPTIONS``.
    """
    # Only visit the tasks present on entry; anything re-queued below waits for the next call.
    for _ in range(len(self.pending_tasks)):
        queued_task = self.pending_tasks.popleft()
        task_key = queued_task.key
        attempt_number = queued_task.attempt_number
        ser_task_key = json.dumps(task_key._asdict())
        payload = {
            "task_key": ser_task_key,
            "command": queued_task.command,
            "executor_config": queued_task.executor_config,
        }

        if timezone.utcnow() < queued_task.next_attempt_time:
            # Still backing off from an earlier failed attempt.
            self.pending_tasks.append(queued_task)
            continue

        self.log.info("Submitting task %s to Lambda function %s", task_key, self.lambda_function_name)
        failure_reason = None
        try:
            invoke_kwargs = {
                "FunctionName": self.lambda_function_name,
                "InvocationType": "Event",
                "Payload": json.dumps(payload),
            }
            if self.qualifier:
                invoke_kwargs["Qualifier"] = self.qualifier
            response = self.lambda_client.invoke(**invoke_kwargs)
        except NoCredentialsError:
            # Credential problems are handled by the caller; keep the task for later.
            self.pending_tasks.append(queued_task)
            raise
        except ClientError as e:
            if e.response["Error"]["Code"] in INVALID_CREDENTIALS_EXCEPTIONS:
                self.pending_tasks.append(queued_task)
                raise
            failure_reason = str(e)
        except Exception as e:
            # Failed to even get a response back from the Boto3 API or something else went
            # wrong. Record the reason so it is reported to the user and the task is retried
            # (or failed) below.
            failure_reason = str(e)

        if failure_reason is None:
            status_code = response.get("StatusCode")
            self.log.info("Invoked Lambda for task %s with status %s", task_key, status_code)
            self.running_tasks[ser_task_key] = task_key
            # Add the serialized task key as the info, this will be assigned on the ti as the external_executor_id
            self.running_state(task_key, ser_task_key)
            continue

        # Make sure the number of attempts does not exceed max invoke attempts
        if int(attempt_number) < int(self.max_invoke_attempts):
            queued_task.attempt_number += 1
            retry_delay = calculate_next_attempt_delay(attempt_number)
            queued_task.next_attempt_time = timezone.utcnow() + retry_delay
            self.pending_tasks.append(queued_task)
            continue

        self.log.error(
            "Lambda invoke %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
            task_key,
            attempt_number,
            failure_reason,
        )
        self.log_task_event(
            event="lambda invoke failure",
            ti_key=task_key,
            extra=(
                f"Task could not be queued after {attempt_number} attempts. "
                f"Marking as failed. Reasons: {failure_reason}"
            ),
        )
        self.fail(task_key)
[docs]
def sync_running_tasks(self):
    """
    Collect results for running tasks from the results queue and the dead letter queue.

    Nothing is polled when no task is running. The dead letter queue is only read when it is
    configured and tasks are still outstanding after the results queue has been processed.
    """
    if self.running_tasks:
        self.process_queue(self.sqs_queue_url)
        if self.dlq_url and self.running_tasks:
            self.process_queue(self.dlq_url)
    else:
        self.log.debug("No running tasks to process.")
[docs]
def process_queue(self, queue_url: str):
    """
    Read one batch of messages from ``queue_url`` and apply them to the running tasks.

    A message from the results queue must be a JSON object with ``task_key`` and
    ``return_code``; a message from the dead letter queue must be a JSON object with
    ``task_key`` and ``command``. Messages that are not valid JSON, are not JSON objects, or
    lack the required keys are deleted. Messages whose task key is not tracked in
    ``running_tasks`` are made visible again and left in the queue. For tracked tasks a return
    code of 0 marks the task successful and anything else marks it failed; the task is then
    dropped from ``running_tasks`` and the message is deleted.

    :param queue_url: URL of the queue to poll (results queue or dead letter queue).
    """
    response = self.sqs_client.receive_message(
        QueueUrl=queue_url,
        MaxNumberOfMessages=10,
    )
    # Pagination? Maybe we don't need it. But we don't always delete messages after viewing them so we
    # could possibly accumulate a lot of messages in the queue and get stuck if we don't read bigger
    # chunks and paginate.
    messages = response.get("Messages", [])

    # The keys that we validate in the messages below will be different depending on whether or not
    # the message is from the dead letter queue or the main results queue.
    from_dlq = queue_url == self.dlq_url
    message_keys = ("return_code", "task_key")
    if messages and from_dlq:
        self.log.warning("%d messages received from the dead letter queue", len(messages))
        message_keys = ("command", "task_key")

    for message in messages:
        delete_message = False
        receipt_handle = message["ReceiptHandle"]
        body = None
        try:
            body = json.loads(message["Body"])
        except json.JSONDecodeError:
            self.log.warning(
                "Received a message from the queue that could not be parsed as JSON: %s",
                message["Body"],
            )
            delete_message = True

        # If the message is not already marked for deletion, check that it is a JSON object with
        # the required keys. Valid JSON that is not an object (null, a number, ...) would
        # otherwise raise on the key lookup and leave a poison message in the queue.
        if not delete_message and not (
            isinstance(body, dict) and all(key in body for key in message_keys)
        ):
            self.log.warning(
                "Message is not formatted correctly, %s and/or %s are missing: %s", *message_keys, body
            )
            delete_message = True

        if delete_message:
            self.log.warning("Deleting the message to avoid processing it again.")
            self.sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
            continue

        return_code = body.get("return_code")
        ser_task_key = body.get("task_key")

        # Fetch the real task key from the running_tasks dict, using the serialized task key.
        try:
            task_key = self.running_tasks[ser_task_key]
        except (KeyError, TypeError):
            # TypeError covers an unhashable task_key value (list/dict), which can never match
            # a tracked task and is therefore treated like any other unknown key.
            self.log.debug(
                "Received task %s from the queue which is not found in running tasks, it is likely "
                "from another Lambda Executor sharing this queue or might be a stale message that needs "
                "deleting manually. Marking the message as visible again.",
                ser_task_key,
            )
            # Mark task as visible again in SQS so that another executor can pick it up.
            self.sqs_client.change_message_visibility(
                QueueUrl=queue_url,
                ReceiptHandle=receipt_handle,
                VisibilityTimeout=0,
            )
            continue

        if return_code == 0:
            self.success(task_key)
            self.log.info("Successful Lambda invocation for task %s received from SQS queue.", task_key)
        else:
            self.fail(task_key)
            if from_dlq and return_code is None:
                # DLQ failure: AWS Lambda service could not complete the invocation after retries.
                # This indicates a Lambda-level failure (timeout, memory limit, crash, etc.)
                # where the function was unable to successfully execute to return a result.
                self.log.error(
                    "DLQ message received: Lambda invocation for task: %s was unable to successfully execute. This likely indicates a Lambda-level failure (timeout, memory limit, crash, etc.).",
                    task_key,
                )
            else:
                # In this case the Lambda likely started but failed at run time since we got a non-zero
                # return code. We could consider retrying these tasks within the executor, because this _likely_
                # means the Airflow task did not run to completion, however we can't be sure (maybe the
                # lambda runtime code has a bug and is returning a non-zero when it actually passed?). So
                # perhaps not retrying is the safest option.
                self.log.debug(
                    "Lambda invocation for task: %s completed but the underlying Airflow task has returned a non-zero exit code %s",
                    task_key,
                    return_code,
                )

        # Remove the task from the tracking mapping.
        self.running_tasks.pop(ser_task_key)
        # Delete the message from the queue.
        self.sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
[docs]
def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
    """
    Adopt task instances which have an external_executor_id (the serialized task key).

    A task instance is adopted when its ``external_executor_id`` is set and can be deserialized
    into a ``TaskInstanceKey``; it is then tracked in ``running_tasks``.
    Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.

    :param tis: The task instances to adopt.
    :return: The task instances that were not adopted, in their original order.
    """
    with Stats.timer("lambda_executor.adopt_task_instances.duration"):
        adopted_tis: list[TaskInstance] = []
        not_adopted_tis: list[TaskInstance] = []

        # Partition in a single pass instead of filtering with a list membership test afterwards.
        for ti in tis:
            ser_task_key = ti.external_executor_id
            if not ser_task_key:
                not_adopted_tis.append(ti)
                continue
            try:
                task_key = TaskInstanceKey.from_dict(json.loads(ser_task_key))
            except Exception:
                # If that task fails to deserialize, we should just skip it.
                self.log.exception("Task failed to be adopted because the key could not be deserialized")
                not_adopted_tis.append(ti)
                continue
            self.running_tasks[ser_task_key] = task_key
            adopted_tis.append(ti)

        if adopted_tis:
            task_instance_str = "\n\t".join(f"{task} in state {task.state}" for task in adopted_tis)
            self.log.info(
                "Adopted the following %d tasks from a dead executor:\n\t%s",
                len(adopted_tis),
                task_instance_str,
            )

        return not_adopted_tis
[docs]
def end(self, heartbeat_interval=10):
    """
    Block until no running tasks remain or the configured timeout expires.

    Asynchronous Lambda invocations cannot be cancelled, so the executor keeps syncing and
    sleeping until ``running_tasks`` is empty. With an ``END_WAIT_TIMEOUT`` of 0 (the fallback)
    there is no time limit.

    :param heartbeat_interval: The interval in seconds to wait between checks for task completion.
    """
    self.log.info("Received signal to end, waiting for outstanding tasks to finish.")
    timeout_seconds = int(
        self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.END_WAIT_TIMEOUT, fallback="0")
    )
    wait_started = timezone.utcnow()

    while True:
        # The clock is only consulted when a timeout is configured.
        if timeout_seconds and (timezone.utcnow() - wait_started).total_seconds() > timeout_seconds:
            self.log.warning(
                "Timed out waiting for tasks to finish. Some tasks may not be handled gracefully"
                " as the executor is force ending due to timeout."
            )
            break

        self.sync()
        if not self.running_tasks:
            self.log.info("All tasks completed; executor ending.")
            break

        self.log.info("Waiting for %d task(s) to complete.", len(self.running_tasks))
        time.sleep(heartbeat_interval)
[docs]
def terminate(self):
    """
    Handle a SIGTERM received by the daemon.

    Only logs a warning; no attempt is made to stop Lambda invocations that are in flight.
    """
    self.log.warning("Terminating Lambda executor. In-flight tasks cannot be stopped.")