# Source code for airflow.providers.snowflake.triggers.snowflake_trigger
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, Any, AsyncIterator
from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
if TYPE_CHECKING:
from datetime import timedelta
class SnowflakeSqlApiTrigger(BaseTrigger):
    """
    Fetch the status for the query ids passed.

    :param poll_interval: polling period in seconds to check for the status
    :param query_ids: List of Query ids to run and poll for the status
    :param snowflake_conn_id: Reference to Snowflake connection id
    :param token_life_time: lifetime of the JWT Token in timedelta
    :param token_renewal_delta: Renewal time of the JWT Token in timedelta
    """

    def __init__(
        self,
        poll_interval: float,
        query_ids: list[str],
        snowflake_conn_id: str,
        token_life_time: timedelta,
        token_renewal_delta: timedelta,
    ):
        super().__init__()
        self.poll_interval = poll_interval
        self.query_ids = query_ids
        self.snowflake_conn_id = snowflake_conn_id
        self.token_life_time = token_life_time
        self.token_renewal_delta = token_renewal_delta

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize SnowflakeSqlApiTrigger arguments and classpath."""
        return (
            "airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger",
            {
                "poll_interval": self.poll_interval,
                "query_ids": self.query_ids,
                "snowflake_conn_id": self.snowflake_conn_id,
                "token_life_time": self.token_life_time,
                "token_renewal_delta": self.token_renewal_delta,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll each query id in order until it is no longer ``running``, then report.

        Exactly one event is yielded:

        * the status dict of the first query whose status is ``"error"``;
        * otherwise ``{"status": "success", "statement_query_ids": [...]}`` once all
          queries have finished, where the list concatenates the
          ``statement_handles`` of every query that reported ``"success"``;
        * ``{"status": "error", "message": str(exc)}`` if any exception is raised
          while polling.
        """
        try:
            statement_query_ids: list[str] = []
            for query_id in self.query_ids:
                # Re-check this query every ``poll_interval`` seconds while it runs.
                while True:
                    statement_status = await self.get_query_status(query_id)
                    if statement_status["status"] != "running":
                        break
                    await asyncio.sleep(self.poll_interval)
                if statement_status["status"] == "error":
                    # Stop at the first failing query and forward its status as-is.
                    yield TriggerEvent(statement_status)
                    return
                if statement_status["status"] == "success":
                    statement_query_ids.extend(statement_status["statement_handles"])
                # Any other status is neither collected nor treated as an error.
            yield TriggerEvent(
                {
                    "status": "success",
                    "statement_query_ids": statement_query_ids,
                }
            )
        except Exception as e:
            # Report any failure as an error event rather than letting it
            # propagate out of the trigger.
            yield TriggerEvent({"status": "error", "message": str(e)})

    async def get_query_status(self, query_id: str) -> dict[str, Any]:
        """
        Fetch the current status of ``query_id`` through the Snowflake SQL API hook.

        A new ``SnowflakeSqlApiHook`` is built for every call.

        :param query_id: id of the query to look up
        :return: the dict returned by ``get_sql_api_query_status_async``. ``run``
            reads its ``"status"`` key and, on success, its ``"statement_handles"`` key.
        """
        hook = SnowflakeSqlApiHook(
            self.snowflake_conn_id,
            self.token_life_time,
            self.token_renewal_delta,
        )
        return await hook.get_sql_api_query_status_async(query_id)

    def _set_context(self, context):
        # No-op: this trigger does not use the supplied context.
        pass