Source code for airflow.providers.amazon.aws.triggers.base

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from abc import abstractmethod
from typing import Any, AsyncIterator

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent


[docs]class AwsBaseWaiterTrigger(BaseTrigger): """ Base class for all AWS Triggers that follow the "standard" model of just waiting on a waiter. Subclasses need to implement the hook() method. :param serialized_fields: Fields that are specific to the subclass trigger and need to be serialized to be passed to the __init__ method on deserialization. The conn id, region, and waiter delay & attempts are always serialized. format: {<parameter_name>: <parameter_value>} :param waiter_name: The name of the (possibly custom) boto waiter to use. :param waiter_args: The arguments to pass to the waiter. :param failure_message: The message to log if a failure state is reached. :param status_message: The message logged when printing the status of the service. :param status_queries: A list containing the JMESPath queries to retrieve status information from the waiter response. See https://jmespath.org/tutorial.html :param return_key: The key to use for the return_value in the TriggerEvent this emits on success. Defaults to "value". :param return_value: A value that'll be returned in the return_key field of the TriggerEvent. Set to None if there is nothing to return. :param waiter_delay: The amount of time in seconds to wait between attempts. :param waiter_max_attempts: The maximum number of attempts to be made. :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook. :param region_name: The AWS region where the resources to watch are. To be used to build the hook. """ def __init__( self, *, serialized_fields: dict[str, Any], waiter_name: str, waiter_args: dict[str, Any], failure_message: str, status_message: str, status_queries: list[str], return_key: str = "value", return_value: Any, waiter_delay: int, waiter_max_attempts: int, aws_conn_id: str | None, region_name: str | None = None, ): # parameters that should be hardcoded in the child's implem self.serialized_fields = serialized_fields self.waiter_name = waiter_name self.waiter_args = waiter_args self.failure_message = failure_message self.status_message = status_message self.status_queries = status_queries self.return_key = return_key self.return_value = return_value # parameters that should be passed directly from the child's parameters self.waiter_delay = waiter_delay self.attempts = waiter_max_attempts self.aws_conn_id = aws_conn_id self.region_name = region_name
[docs] def serialize(self) -> tuple[str, dict[str, Any]]: # here we put together the "common" params, # and whatever extras we got from the subclass in serialized_fields params = dict( { "waiter_delay": self.waiter_delay, "waiter_max_attempts": self.attempts, "aws_conn_id": self.aws_conn_id, }, **self.serialized_fields, ) if self.region_name: # if we serialize the None value from this, it breaks subclasses that don't have it in their ctor. params["region_name"] = self.region_name return ( # remember that self is an instance of the subclass here, not of this class. self.__class__.__module__ + "." + self.__class__.__qualname__, params, )
@abstractmethod
[docs] def hook(self) -> AwsGenericHook: """Override in subclasses to return the right hook.""" ...
[docs] async def run(self) -> AsyncIterator[TriggerEvent]: hook = self.hook() async with hook.async_conn as client: waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client) await async_wait( waiter, self.waiter_delay, self.attempts, self.waiter_args, self.failure_message, self.status_message, self.status_queries, ) yield TriggerEvent({"status": "success", self.return_key: self.return_value})

Was this entry helpful?