# Source code for airflow.providers.amazon.aws.triggers.s3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import asyncio
from functools import cached_property
from typing import Any, AsyncIterator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.triggers.base import BaseTrigger, TriggerEvent
class S3KeyTrigger(BaseTrigger):
    """
    Trigger that polls S3 until the requested key(s) are reported as present.

    :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
        is not provided as a full s3:// url.
    :param bucket_key: The key being waited on. Supports full s3:// style url
        or relative path from root level. When it's specified as a full s3://
        url, please leave bucket_name as `None`.
    :param wildcard_match: whether the bucket_key should be interpreted as a
        Unix wildcard pattern
    :param aws_conn_id: reference to the s3 connection
    :param poke_interval: seconds to sleep between two consecutive checks
    :param should_check_fn: when true, a match does not end the trigger with a
        ``success`` event; instead the matching files are fetched and emitted
        in a ``running`` event so the caller can inspect them
    :param hook_params: optional extra keyword arguments for the hook; only
        ``verify`` is read by this class
    """

    def __init__(
        self,
        bucket_name: str | None,
        bucket_key: str | list[str],
        wildcard_match: bool = False,
        aws_conn_id: str = "aws_default",
        poke_interval: float = 5.0,
        should_check_fn: bool = False,
        **hook_params: Any,
    ):
        super().__init__()
        # ``serialize`` stores the extra hook kwargs under a single
        # "hook_params" key. When the trigger is re-created from those
        # serialized kwargs they therefore arrive nested as
        # ``{"hook_params": {...}}``. Flatten them so that lookups such as
        # ``hook_params.get("verify")`` still work after a round trip.
        nested_params = hook_params.pop("hook_params", None)
        if isinstance(nested_params, dict):
            # Explicitly passed keyword arguments win over the nested ones.
            hook_params = {**nested_params, **hook_params}
        elif nested_params is not None:
            # Not the serialized form; keep the caller's value untouched.
            hook_params["hook_params"] = nested_params

        self.bucket_name = bucket_name
        self.bucket_key = bucket_key
        self.wildcard_match = wildcard_match
        self.aws_conn_id = aws_conn_id
        self.hook_params = hook_params
        self.poke_interval = poke_interval
        self.should_check_fn = should_check_fn

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize S3KeyTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger",
            {
                "bucket_name": self.bucket_name,
                "bucket_key": self.bucket_key,
                "wildcard_match": self.wildcard_match,
                "aws_conn_id": self.aws_conn_id,
                "hook_params": self.hook_params,
                "poke_interval": self.poke_interval,
                "should_check_fn": self.should_check_fn,
            },
        )

    @cached_property
    def hook(self) -> S3Hook:
        """Return the ``S3Hook`` for ``aws_conn_id``, created once and cached on the instance."""
        return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll S3 through the hook's async client and emit trigger events.

        Every ``poke_interval`` seconds the key is checked with
        ``check_key_async``. On a match:

        * if ``should_check_fn`` is false, a ``{"status": "success"}`` event is
          emitted and the trigger finishes;
        * otherwise the matching files are fetched with ``get_files_async`` and
          emitted as ``{"status": "running", "files": ...}``, after which
          polling continues.

        Any exception is reported as a ``{"status": "error", "message": ...}``
        event instead of being raised.
        """
        try:
            async with self.hook.async_conn as client:
                while True:
                    if await self.hook.check_key_async(
                        client, self.bucket_name, self.bucket_key, self.wildcard_match
                    ):
                        if self.should_check_fn:
                            s3_objects = await self.hook.get_files_async(
                                client, self.bucket_name, self.bucket_key, self.wildcard_match
                            )
                            await asyncio.sleep(self.poke_interval)
                            yield TriggerEvent({"status": "running", "files": s3_objects})
                        else:
                            yield TriggerEvent({"status": "success"})
                            # The wait is over: stop here instead of polling
                            # again and emitting duplicate success events.
                            return
                    await asyncio.sleep(self.poke_interval)
        except Exception as e:
            # Boundary of the trigger: surface the failure as an event so the
            # waiting task can fail with the message.
            yield TriggerEvent({"status": "error", "message": str(e)})