# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportasyncioimporttypingfromtypingimportAnyfromasgiref.syncimportsync_to_asyncfromdeprecatedimportdeprecatedfromsqlalchemyimportfuncfromairflow.exceptionsimportRemovedInAirflow3Warningfromairflow.modelsimportDagRun,TaskInstancefromairflow.triggers.baseimportBaseTrigger,TriggerEventfromairflow.utils.sensor_helperimport_get_countfromairflow.utils.sessionimportNEW_SESSION,provide_sessionfromairflow.utils.stateimportTaskInstanceStatefromairflow.utils.timezoneimportutcnowiftyping.TYPE_CHECKING:fromdatetimeimportdatetimefromsqlalchemy.ormimportSessionfromairflow.utils.stateimportDagRunState
class WorkflowTrigger(BaseTrigger):
    """
    A trigger to monitor tasks, task group and dag execution in Apache Airflow.

    :param external_dag_id: The ID of the external DAG.
    :param execution_dates: A list of execution dates for the external DAG.
    :param external_task_ids: A collection of external task IDs to wait for.
    :param external_task_group_id: The ID of the external task group to wait for.
    :param failed_states: States considered as failed for external tasks.
    :param skipped_states: States considered as skipped for external tasks.
    :param allowed_states: States considered as successful for external tasks.
    :param poke_interval: The interval (in seconds) for poking the external tasks.
    :param soft_fail: If True, the trigger will not fail the entire DAG on external task failure.
    """

    def __init__(
        self,
        external_dag_id: str,
        execution_dates: list,
        external_task_ids: typing.Collection[str] | None = None,
        external_task_group_id: str | None = None,
        failed_states: typing.Iterable[str] | None = None,
        skipped_states: typing.Iterable[str] | None = None,
        allowed_states: typing.Iterable[str] | None = None,
        poke_interval: float = 2.0,
        soft_fail: bool = False,
        **kwargs,
    ):
        self.external_dag_id = external_dag_id
        self.external_task_ids = external_task_ids
        self.external_task_group_id = external_task_group_id
        self.failed_states = failed_states
        self.skipped_states = skipped_states
        self.allowed_states = allowed_states
        self.execution_dates = execution_dates
        self.poke_interval = poke_interval
        self.soft_fail = soft_fail
        super().__init__(**kwargs)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize the trigger param and module path.

        :return: A ``(classpath, kwargs)`` pair whose kwargs mirror the
            named parameters of ``__init__``.
        """
        return (
            "airflow.triggers.external_task.WorkflowTrigger",
            {
                "external_dag_id": self.external_dag_id,
                "external_task_ids": self.external_task_ids,
                "external_task_group_id": self.external_task_group_id,
                "failed_states": self.failed_states,
                "skipped_states": self.skipped_states,
                "allowed_states": self.allowed_states,
                "execution_dates": self.execution_dates,
                "poke_interval": self.poke_interval,
                "soft_fail": self.soft_fail,
            },
        )

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """
        Check periodically tasks, task group or dag status.

        Each polling round checks, in order: failed states, skipped states,
        then allowed states. Exactly one event is yielded before returning:

        * ``{"status": "failed"}`` when any record is in ``failed_states``;
        * ``{"status": "skipped"}`` when any record is in ``skipped_states``;
        * ``{"status": "success"}`` when the number of records in
          ``allowed_states`` equals ``len(execution_dates)``.

        Otherwise the trigger sleeps ``poke_interval`` seconds and polls again.
        """
        while True:
            if self.failed_states:
                failed_count = await self._get_count(self.failed_states)
                if failed_count > 0:
                    yield TriggerEvent({"status": "failed"})
                    return
                # No failure seen yet is NOT success: fall through so that
                # success is decided only by the allowed_states check below.
            if self.skipped_states:
                skipped_count = await self._get_count(self.skipped_states)
                if skipped_count > 0:
                    yield TriggerEvent({"status": "skipped"})
                    return
            allowed_count = await self._get_count(self.allowed_states)
            if allowed_count == len(self.execution_dates):
                yield TriggerEvent({"status": "success"})
                return
            self.log.info("Sleeping for %s seconds", self.poke_interval)
            await asyncio.sleep(self.poke_interval)

    @sync_to_async
    def _get_count(self, states: typing.Iterable[str] | None) -> int:
        """
        Get the count of records against dttm filter and states.

        Async wrapper (via ``sync_to_async``) around the module-level
        ``_get_count`` helper.

        :param states: task or dag states
        :return: The count of records.
        """
        # The bare name resolves to the helper imported from
        # airflow.utils.sensor_helper, not to this method.
        return _get_count(
            dttm_filter=self.execution_dates,
            external_task_ids=self.external_task_ids,
            external_task_group_id=self.external_task_group_id,
            external_dag_id=self.external_dag_id,
            states=states,
        )
@deprecated(
    reason="TaskStateTrigger has been deprecated and will be removed in future.",
    category=RemovedInAirflow3Warning,
)
class TaskStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a task in a different DAG to complete for a specific logical date.

    :param dag_id: The dag_id that contains the task you want to wait for
    :param task_id: The task_id that contains the task you want to wait for.
    :param states: allowed states, default is ``['success']``
    :param execution_dates: task execution time interval
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 2.0 sec.
    :param trigger_start_time: time in Datetime format when the trigger was started. Is used
        to control the execution of trigger to prevent infinite loop in case if specified name
        of the dag does not exist in database. It will wait period of time equals _timeout_sec
        parameter from the time, when the trigger was started and if the DAG is still not
        found running after that period, the trigger will terminate with 'timeout' status.
    """

    def __init__(
        self,
        dag_id: str,
        execution_dates: list[datetime],
        trigger_start_time: datetime,
        states: list[str] | None = None,
        task_id: str | None = None,
        poll_interval: float = 2.0,
    ):
        super().__init__()
        self.dag_id = dag_id
        self.task_id = task_id
        self.execution_dates = execution_dates
        self.poll_interval = poll_interval
        self.trigger_start_time = trigger_start_time
        # A missing or empty ``states`` falls back to the single "success" state.
        self.states = states or [TaskInstanceState.SUCCESS.value]
        # Grace period, in seconds, for the DAG to show up as running.
        self._timeout_sec = 60

    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
        """
        Serialize TaskStateTrigger arguments and classpath.

        :return: A ``(classpath, kwargs)`` pair whose kwargs mirror the
            parameters of ``__init__``.
        """
        return (
            "airflow.triggers.external_task.TaskStateTrigger",
            {
                "dag_id": self.dag_id,
                "task_id": self.task_id,
                "states": self.states,
                "execution_dates": self.execution_dates,
                "poll_interval": self.poll_interval,
                "trigger_start_time": self.trigger_start_time,
            },
        )

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """
        Check periodically in the database to see if the dag exists and is in the running state.

        If found, wait until the task specified will reach one of the expected states.
        If dag with specified name was not in the running state after _timeout_sec seconds
        after starting execution process of the trigger, terminate with status 'timeout'.

        Yields exactly one event: ``{"status": "success"}``, ``{"status": "timeout"}``
        or, when any exception is raised while polling, ``{"status": "failed"}``.
        """
        try:
            while True:
                delta = utcnow() - self.trigger_start_time
                # mypy confuses typing here
                if await self.count_running_dags() == 0:  # type: ignore[call-arg]
                    # Time out only while the DAG has not been seen running;
                    # a DAG that did start keeps being polled past the grace
                    # period, so this trigger sets no upper bound on that wait.
                    if delta.total_seconds() >= self._timeout_sec:
                        yield TriggerEvent({"status": "timeout"})
                        return
                    self.log.info("Waiting for DAG to start execution...")
                    await asyncio.sleep(self.poll_interval)
                # mypy confuses typing here
                if await self.count_tasks() == len(self.execution_dates):  # type: ignore[call-arg]
                    yield TriggerEvent({"status": "success"})
                    return
                self.log.info("Task is still running, sleeping for %s seconds...", self.poll_interval)
                await asyncio.sleep(self.poll_interval)
        except Exception:
            # Keep the "failed" event contract, but leave a trace of the cause.
            self.log.exception("Error while polling task state for DAG %s", self.dag_id)
            yield TriggerEvent({"status": "failed"})

    @sync_to_async
    @provide_session
    def count_running_dags(self, session: Session = NEW_SESSION) -> int | None:
        """
        Count how many dag instances in running state in the database.

        The count is taken over task instances of ``dag_id`` at the given
        execution dates whose state is ``running`` or ``success``.

        :param session: Database session; supplied by ``provide_session`` when omitted.
        :return: The scalar result of the count query.
        """
        dags = (
            session.query(func.count("*"))
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.execution_date.in_(self.execution_dates),
                TaskInstance.state.in_(["running", "success"]),
            )
            .scalar()
        )
        return dags

    @sync_to_async
    @provide_session
    def count_tasks(self, *, session: Session = NEW_SESSION) -> int | None:
        """
        Count how many task instances in the database match our criteria.

        :param session: Database session; supplied by ``provide_session`` when omitted.
        :return: Number of task instances of ``dag_id``/``task_id`` at the given
            execution dates whose state is one of ``states``.
        """
        count = (
            session.query(func.count("*"))  # .count() is inefficient
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.state.in_(self.states),
                TaskInstance.execution_date.in_(self.execution_dates),
            )
            .scalar()
        )
        return typing.cast(int, count)
class DagStateTrigger(BaseTrigger):
    """
    Wait asynchronously for DAG runs to reach one of the given states.

    :param dag_id: The dag_id of the DAG whose runs are watched.
    :param states: DAG run states that count as a match.
    :param execution_dates: Logical dates of the DAG runs to watch.
    :param poll_interval: Seconds to sleep between database checks.
        The default value is 5.0 sec.
    """

    def __init__(
        self,
        dag_id: str,
        states: list[DagRunState],
        execution_dates: list[datetime],
        poll_interval: float = 5.0,
    ):
        super().__init__()
        self.dag_id, self.states = dag_id, states
        self.execution_dates, self.poll_interval = execution_dates, poll_interval

    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
        """Return the classpath and the constructor kwargs of this trigger."""
        classpath = "airflow.triggers.external_task.DagStateTrigger"
        init_kwargs: dict[str, typing.Any] = {
            "dag_id": self.dag_id,
            "states": self.states,
            "execution_dates": self.execution_dates,
            "poll_interval": self.poll_interval,
        }
        return (classpath, init_kwargs)

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """
        Poll until every watched DAG run is in one of the wanted states.

        Yields a single event whose payload is the result of ``serialize()``
        once the matching-run count equals ``len(execution_dates)``.
        """
        while True:
            # mypy confuses typing here
            matched = await self.count_dags()  # type: ignore[call-arg]
            if matched != len(self.execution_dates):
                await asyncio.sleep(self.poll_interval)
                continue
            yield TriggerEvent(self.serialize())
            return

    @sync_to_async
    @provide_session
    def count_dags(self, *, session: Session = NEW_SESSION) -> int | None:
        """
        Count the DAG runs in the database that satisfy this trigger.

        :param session: Database session; supplied by ``provide_session`` when omitted.
        :return: Number of runs of ``dag_id`` at the given execution dates
            whose state is one of ``states``.
        """
        criteria = (
            DagRun.dag_id == self.dag_id,
            DagRun.state.in_(self.states),
            DagRun.execution_date.in_(self.execution_dates),
        )
        # func.count("*") is used because Query.count() is inefficient.
        counting_query = session.query(func.count("*")).filter(*criteria)
        return typing.cast(int, counting_query.scalar())