## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportitertoolsimportosimportwarningsfromcollectionsimportdefaultdictfromtypingimportTYPE_CHECKING,Any,Callable,Iterable,Iterator,NamedTuple,Sequence,TypeVar,overloadimportre2fromsqlalchemyimport(Boolean,Column,ForeignKey,ForeignKeyConstraint,Index,Integer,PickleType,PrimaryKeyConstraint,String,Text,UniqueConstraint,and_,func,or_,text,update,)fromsqlalchemy.excimportIntegrityErrorfromsqlalchemy.ext.associationproxyimportassociation_proxyfromsqlalchemy.ormimportdeclared_attr,joinedload,relationship,synonym,validatesfromsqlalchemy.sql.expressionimportcase,false,select,truefromairflowimportsettingsfromairflow.api_internal.internal_api_callimportinternal_api_callfromairflow.callbacks.callback_requestsimportDagCallbackRequestfromairflow.configurationimportconfasairflow_conffromairflow.exceptionsimportAirflowException,RemovedInAirflow3Warning,TaskNotFoundfromairflow.listeners.listenerimportget_listener_managerfromairflow.modelsimportLogfromairflow.models.abstractoperatorimportNotMappedfromairflow.models.baseimportBase,StringIDfromairflow.models.expandinputimportNotFullyPopulatedfromairflow.models.taskinstanceimportTaskInstanceasTIfromairflow.models.tasklogimportLogTemplatefromairflow.statsimp
ortStatsfromairflow.ti_deps.dep_contextimportDepContextfromairflow.ti_deps.dependencies_statesimportSCHEDULEABLE_STATESfromairflow.traces.tracerimportTracefromairflow.utilsimporttimezonefromairflow.utils.datesimportdatetime_to_nanofromairflow.utils.helpersimportchunks,is_container,prune_dictfromairflow.utils.log.logging_mixinimportLoggingMixinfromairflow.utils.sessionimportNEW_SESSION,provide_sessionfromairflow.utils.sqlalchemyimportUtcDateTime,nulls_first,tuple_in_condition,with_row_locksfromairflow.utils.stateimportDagRunState,State,TaskInstanceStatefromairflow.utils.typesimportNOTSET,DagRunTypeifTYPE_CHECKING:fromdatetimeimportdatetimefromsqlalchemy.ormimportQuery,Sessionfromairflow.models.dagimportDAGfromairflow.models.operatorimportOperatorfromairflow.serialization.pydantic.dag_runimportDagRunPydanticfromairflow.serialization.pydantic.taskinstanceimportTaskInstancePydanticfromairflow.serialization.pydantic.tasklogimportLogTemplatePydanticfromairflow.typing_compatimportLiteralfromairflow.utils.typesimportArgNotSet
def _creator_note(val):
    """
    Build a ``DagRunNote`` for the ``note`` association proxy.

    A mapping is expanded into keyword arguments, a plain string becomes the
    note's ``content``, and any other value is unpacked as positional arguments.
    """
    if isinstance(val, dict):
        return DagRunNote(**val)
    if not isinstance(val, str):
        return DagRunNote(*val)
    return DagRunNote(content=val)
[docs]classDagRun(Base,LoggingMixin):""" Invocation instance of a DAG. A DAG run can be created by the scheduler (i.e. scheduled runs), or by an external trigger (i.e. manual runs). """
# Foreign key to LogTemplate. DagRun rows created prior to this column's
# existence have this set to NULL. Later rows automatically populate this on
# insert to point to the latest LogTemplate entry.
# Keeps track of the number of times the dagrun has been cleared.
# This number is incremented only when the DagRun is re-queued,
# i.e. when the DagRun is cleared.
# Table-level schema objects for the dag_run table: indexes plus the uniqueness
# guarantees on (dag_id, execution_date) and (dag_id, run_id).
# NOTE(review): ``dag_id`` and ``_state`` are column attributes defined earlier in
# the class body; those definitions are not visible in this chunk.
__table_args__ = (
    Index("dag_id_state", dag_id, _state),
    # At most one run per DAG per execution_date.
    UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"),
    # At most one run per DAG per run_id.
    UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"),
    Index("idx_dag_run_dag_id", dag_id),
    # Partial index (PostgreSQL / SQLite) covering only rows with state='running'.
    Index(
        "idx_dag_run_running_dags",
        "state",
        "dag_id",
        postgresql_where=text("state='running'"),
        sqlite_where=text("state='running'"),
    ),
    # since mysql lacks filtered/partial indices, this creates a
    # duplicate index on mysql. Not the end of the world
    Index(
        "idx_dag_run_queued_dags",
        "state",
        "dag_id",
        postgresql_where=text("state='queued'"),
        sqlite_where=text("state='queued'"),
    ),
)
def __init__(
    self,
    dag_id: str | None = None,
    run_id: str | None = None,
    queued_at: datetime | None | ArgNotSet = NOTSET,
    execution_date: datetime | None = None,
    start_date: datetime | None = None,
    external_trigger: bool | None = None,
    conf: Any | None = None,
    state: DagRunState | None = None,
    run_type: str | None = None,
    dag_hash: str | None = None,
    creating_job_id: int | None = None,
    data_interval: tuple[datetime, datetime] | None = None,
):
    """
    Create a new (not yet persisted) DAG run.

    :param queued_at: queue timestamp. When left as ``NOTSET`` it becomes "now" for runs
        created in the QUEUED state and ``None`` otherwise; an explicit value (including
        ``None``) is stored as given.
    :param conf: run configuration; any falsy value is stored as an empty dict.
    :param state: initial state; the ``state`` attribute is only assigned when this is not ``None``.
    :param data_interval: ``(start, end)`` pair. ``None`` leaves both bounds as ``None``
        (legacy runs created prior to Airflow 2.2).
    """
    # Legacy: a missing data interval only happens for runs created prior to Airflow 2.2.
    interval_start, interval_end = (None, None) if data_interval is None else data_interval
    self.data_interval_start = interval_start
    self.data_interval_end = interval_end

    self.dag_id = dag_id
    self.run_id = run_id
    self.execution_date = execution_date
    self.start_date = start_date
    self.external_trigger = external_trigger
    self.conf = conf or {}

    if state is not None:
        self.state = state

    resolved_queued_at = queued_at
    if queued_at is NOTSET:
        resolved_queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
    self.queued_at = resolved_queued_at

    self.run_type = run_type
    self.dag_hash = dag_hash
    self.creating_job_id = creating_job_id
    # A brand-new run has never been cleared.
    self.clear_number = 0
    super().__init__()
def validate_run_id(self, key: str, run_id: str) -> str | None:
    """
    Check ``run_id`` against the configured pattern and the built-in ``RUN_ID_REGEX``.

    NOTE(review): the ``(self, key, value)`` signature is that of an SQLAlchemy
    ``@validates`` hook, but no decorator is visible in this chunk; confirm it is registered.

    :param key: name of the attribute being validated (not used)
    :param run_id: the proposed run id
    :return: ``None`` for an empty ``run_id``, otherwise ``run_id`` unchanged
    :raises ValueError: if ``run_id`` matches neither pattern
    """
    if not run_id:
        return None
    regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")
    # The built-in pattern is only consulted when the configured one does not match.
    is_allowed = re2.match(regex, run_id) or re2.match(RUN_ID_REGEX, run_id)
    if is_allowed:
        return run_id
    raise ValueError(
        f"The run_id provided '{run_id}' does not match the pattern '{regex}' or '{RUN_ID_REGEX}'"
    )
def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
    """
    Reload ``id`` and ``state`` of this dag run from the database.

    The row is looked up by ``(dag_id, run_id)``; ``.one()`` raises if it does not
    match exactly one row.

    NOTE(review): ``session`` defaults to ``NEW_SESSION`` but no ``@provide_session``
    decorator is visible in this chunk; confirm one is applied.

    :param session: database session
    """
    lookup = select(DagRun).where(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id)
    persisted = session.scalars(lookup).one()
    self.id = persisted.id
    self.state = persisted.state
@classmethod
@provide_session
def active_runs_of_dags(
    cls,
    dag_ids: Iterable[str] | None = None,
    only_running: bool = False,
    session: Session = NEW_SESSION,
) -> dict[str, int]:
    """
    Get the number of active dag runs for each dag.

    :param dag_ids: restrict the count to these DAGs; ``None`` counts runs of every DAG.
    :param only_running: count only RUNNING runs; otherwise RUNNING and QUEUED runs are counted.
    :param session: database session
    :return: mapping of dag_id to its number of active runs. DAGs without any matching
        run do not appear in the mapping.
    """
    query = select(cls.dag_id, func.count("*"))
    if dag_ids is not None:
        # De-duplicate the ids (keeping first-seen order) and hand SQLAlchemy a list,
        # not a set, as the ``in_`` clause expects.
        query = query.where(cls.dag_id.in_(list(dict.fromkeys(dag_ids))))
    if only_running:
        query = query.where(cls.state == DagRunState.RUNNING)
    else:
        query = query.where(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED)))
    query = query.group_by(cls.dag_id)
    # Each result row is a (dag_id, count) pair.
    return dict(session.execute(query).all())
@classmethod
def next_dagruns_to_examine(
    cls,
    state: DagRunState,
    session: Session,
    max_number: int | None = None,
) -> Iterator[DagRun]:
    """
    Return the next DagRuns that the scheduler should attempt to schedule.

    This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
    query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
    the transaction is committed it will be unlocked.

    Only runs of active, unpaused DAGs are selected, and backfill runs are excluded. For QUEUED runs,
    DAGs whose number of RUNNING runs has already reached ``max_active_runs`` are skipped. Runs are
    ordered by ``last_scheduling_decision`` (NULLs first), then by ``execution_date``. The row lock is
    requested with ``skip_locked=True``.

    :param state: the dag run state to select
    :param session: SQLAlchemy ORM Session
    :param max_number: maximum number of runs returned; defaults to ``DEFAULT_DAGRUNS_TO_EXAMINE``
    :return: an iterator (SQLAlchemy scalar result) over the selected DagRun rows
    """
    from airflow.models.dag import DagModel

    if max_number is None:
        max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

    # TODO: Bake this query, it is run _A lot_
    query = (
        select(cls)
        # Index hint only emitted for the MySQL dialect.
        .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
        .where(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
        .join(DagModel, DagModel.dag_id == cls.dag_id)
        .where(DagModel.is_paused == false(), DagModel.is_active == true())
    )
    if state == DagRunState.QUEUED:
        # For dag runs in the queued state, we check if they have reached the max_active_runs limit
        # and if so we drop them
        running_drs = (
            select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
            .where(DagRun.state == DagRunState.RUNNING)
            .group_by(DagRun.dag_id)
            .subquery()
        )
        # Outer join + coalesce: DAGs with no running run count as 0 running.
        query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).where(
            func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
        )
    query = query.order_by(
        nulls_first(cls.last_scheduling_decision, session=session),
        cls.execution_date,
    )

    if not settings.ALLOW_FUTURE_EXEC_DATES:
        query = query.where(DagRun.execution_date <= func.now())

    return session.scalars(
        with_row_locks(query.limit(max_number), of=cls, session=session, skip_locked=True)
    )
@classmethod
@provide_session
def find(
    cls,
    dag_id: str | list[str] | None = None,
    run_id: str | Iterable[str] | None = None,
    execution_date: datetime | Iterable[datetime] | None = None,
    state: DagRunState | None = None,
    external_trigger: bool | None = None,
    no_backfills: bool = False,
    run_type: DagRunType | None = None,
    session: Session = NEW_SESSION,
    execution_start_date: datetime | None = None,
    execution_end_date: datetime | None = None,
) -> list[DagRun]:
    """
    Return a list of dag runs for the given search criteria, ordered by execution_date.

    :param dag_id: the dag_id or list of dag_id to find dag runs for; ``None`` or an
        empty list applies no dag_id filter
    :param run_id: a single run id, or a collection of run ids, for this dag run
    :param run_type: type of DagRun
    :param execution_date: a single execution date, or a collection of execution dates
    :param state: the state of the dag run
    :param external_trigger: whether this dag run is externally triggered
    :param no_backfills: return no backfills (True), return all (False).
        Defaults to False
    :param session: database session
    :param execution_start_date: dag run that was executed from this date (inclusive)
    :param execution_end_date: dag run that was executed until this date (inclusive)
    """
    qry = select(cls)
    dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
    if dag_ids:
        qry = qry.where(cls.dag_id.in_(dag_ids))

    # Collections become IN filters; a single value becomes an equality filter.
    if is_container(run_id):
        qry = qry.where(cls.run_id.in_(run_id))
    elif run_id is not None:
        qry = qry.where(cls.run_id == run_id)
    if is_container(execution_date):
        qry = qry.where(cls.execution_date.in_(execution_date))
    elif execution_date is not None:
        qry = qry.where(cls.execution_date == execution_date)
    if execution_start_date and execution_end_date:
        qry = qry.where(cls.execution_date.between(execution_start_date, execution_end_date))
    elif execution_start_date:
        qry = qry.where(cls.execution_date >= execution_start_date)
    elif execution_end_date:
        qry = qry.where(cls.execution_date <= execution_end_date)
    if state:
        qry = qry.where(cls.state == state)
    if external_trigger is not None:
        qry = qry.where(cls.external_trigger == external_trigger)
    if run_type:
        qry = qry.where(cls.run_type == run_type)
    if no_backfills:
        qry = qry.where(cls.run_type != DagRunType.BACKFILL_JOB)

    return session.scalars(qry.order_by(cls.execution_date)).all()
@classmethod
@provide_session
def find_duplicate(
    cls,
    dag_id: str,
    run_id: str,
    execution_date: datetime,
    session: Session = NEW_SESSION,
) -> DagRun | None:
    """
    Return an existing run for the DAG with a specific run_id or execution_date.

    *None* is returned if no such DAG run is found.

    The two criteria are OR-ed, so two different runs can match: one by ``run_id``
    and another by ``execution_date``. In that case one of them is returned, rather
    than raising ``MultipleResultsFound``; either one is a duplicate.

    :param dag_id: the dag_id to find duplicates for
    :param run_id: defines the run id for this dag run
    :param execution_date: the execution date
    :param session: database session
    """
    return session.scalars(
        select(cls)
        .where(
            cls.dag_id == dag_id,
            or_(cls.run_id == run_id, cls.execution_date == execution_date),
        )
        .limit(1)
    ).first()
@staticmethod
def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
    """
    Build a run id from the run type and the execution date.

    ``run_type`` is passed through ``DagRunType(...)`` first, so a plain string coming
    from user code is converted to the enum before its ``generate_run_id`` is used.
    """
    normalized_type = DagRunType(run_type)
    return normalized_type.generate_run_id(execution_date)
@staticmethod
@internal_api_call
@provide_session
def fetch_task_instances(
    dag_id: str | None = None,
    run_id: str | None = None,
    task_ids: list[str] | None = None,
    state: Iterable[TaskInstanceState | None] | None = None,
    session: Session = NEW_SESSION,
) -> list[TI]:
    """
    Return the task instances for this dag run.

    :param dag_id: the DAG id
    :param run_id: the DAG run id
    :param task_ids: when given, only task instances of these tasks are returned
    :param state: a single state, or an iterable of states in which ``None`` matches
        task instances without a state. A falsy value applies no state filter.
    :param session: SQLAlchemy ORM Session
    """
    tis = (
        select(TI)
        .options(joinedload(TI.dag_run))
        .where(
            TI.dag_id == dag_id,
            TI.run_id == run_id,
        )
    )

    if state:
        if isinstance(state, str):
            tis = tis.where(TI.state == state)
        else:
            # Materialize once: ``state`` may be a one-shot iterator and is inspected
            # more than once below.
            states = list(state)
            # this is required to deal with NULL values
            if None in states:
                not_none_states = [s for s in states if s is not None]
                if not not_none_states:
                    tis = tis.where(TI.state.is_(None))
                else:
                    tis = tis.where(or_(TI.state.in_(not_none_states), TI.state.is_(None)))
            else:
                tis = tis.where(TI.state.in_(states))

    if task_ids is not None:
        tis = tis.where(TI.task_id.in_(task_ids))
    return session.scalars(tis).all()
@internal_api_call
def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
    """
    Pause the DAG if its latest ``max_consecutive_failed_dag_runs`` runs all failed.

    The most recent runs (by execution_date, descending) are loaded. The DAG is paused only
    when at least ``max_consecutive_failed_dag_runs`` runs exist and every one of them is
    FAILED. Pausing also covers DAG models whose ``root_dag_id`` is this DAG (sub-dags), and
    a "paused" ``Log`` entry is added to the session.

    :param dag_id: id of the DAG whose runs are inspected and which may be paused
    :param max_consecutive_failed_dag_runs: number of trailing failed runs that triggers the pause
    :param session: SQLAlchemy ORM Session
    """
    dag_runs = session.scalars(
        select(DagRun)
        .where(DagRun.dag_id == dag_id)
        .order_by(DagRun.execution_date.desc())
        .limit(max_consecutive_failed_dag_runs)
    ).all()

    # Mark the DAG as paused, if needed.
    to_be_paused = len(dag_runs) >= max_consecutive_failed_dag_runs and all(
        dag_run.state == DagRunState.FAILED for dag_run in dag_runs
    )
    if not to_be_paused:
        self.log.debug(
            "Limit of consecutive DAG failed dag runs is not reached, DAG %s will not be paused.",
            dag_id,
        )
        return

    from airflow.models.dag import DagModel

    self.log.info(
        "Pausing DAG %s because last %s DAG runs failed.",
        dag_id,
        max_consecutive_failed_dag_runs,
    )
    filter_query = [
        DagModel.dag_id == dag_id,
        DagModel.root_dag_id == dag_id,  # for sub-dags
    ]
    session.execute(
        update(DagModel)
        .where(or_(*filter_query))
        .values(is_paused=True)
        .execution_options(synchronize_session="fetch")
    )
    session.add(
        Log(
            event="paused",
            dag_id=dag_id,
            owner="scheduler",
            owner_display_name="Scheduler",
            extra=f"[('dag_id', '{dag_id}'), ('is_paused', True)]",
        )
    )

@provide_session
def get_task_instances(
    self,
    state: Iterable[TaskInstanceState | None] | None = None,
    session: Session = NEW_SESSION,
) -> list[TI]:
    """
    Return the task instances for this dag run.

    Redirect to DagRun.fetch_task_instances method.
    Keep this method because it is widely used across the code.

    NOTE(review): the task-id restriction comes from ``DagRun._get_partial_task_ids(self.dag)``,
    which is not visible in this chunk; its exact semantics should be confirmed there.
    """
    task_ids = DagRun._get_partial_task_ids(self.dag)
    return DagRun.fetch_task_instances(
        dag_id=self.dag_id, run_id=self.run_id, task_ids=task_ids, state=state, session=session
    )
@provide_session
def get_task_instance(
    self,
    task_id: str,
    session: Session = NEW_SESSION,
    *,
    map_index: int = -1,
) -> TI | TaskInstancePydantic | None:
    """
    Look up a single task instance of this dag run.

    Thin wrapper around ``DagRun.fetch_task_instance`` that supplies this run's
    ``dag_id`` and ``run_id``.

    :param task_id: the task id
    :param session: Sqlalchemy ORM Session
    :param map_index: map index of the task instance (``-1`` by default)
    :return: the matching task instance, or ``None``
    """
    lookup = {"dag_id": self.dag_id, "dag_run_id": self.run_id, "task_id": task_id}
    return DagRun.fetch_task_instance(**lookup, session=session, map_index=map_index)
@staticmethod
@internal_api_call
@provide_session
def fetch_task_instance(
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    session: Session = NEW_SESSION,
    map_index: int = -1,
) -> TI | TaskInstancePydantic | None:
    """
    Fetch one task instance identified by DAG id, run id, task id and map index.

    :param dag_id: the DAG id
    :param dag_run_id: the DAG run id
    :param task_id: the task id
    :param session: Sqlalchemy ORM Session
    :param map_index: map index of the task instance (``-1`` by default)
    :return: the matching task instance, or ``None`` when there is none
    """
    stmt = select(TI).where(
        TI.dag_id == dag_id,
        TI.run_id == dag_run_id,
        TI.task_id == task_id,
        TI.map_index == map_index,
    )
    return session.scalars(stmt).one_or_none()
def get_dag(self) -> DAG:
    """
    Return the DAG attached to this run through ``self.dag``.

    :raises AirflowException: if no DAG has been set on this run
    :return: DAG
    """
    attached = self.dag
    if attached:
        return attached
    raise AirflowException(f"The DAG (.dag) for {self} needs to be set")
@staticmethod
@internal_api_call
@provide_session
def get_previous_dagrun(
    dag_run: DagRun | DagRunPydantic, state: DagRunState | None = None, session: Session = NEW_SESSION
) -> DagRun | None:
    """
    Return the latest run of the same DAG with an earlier execution_date, if any.

    :param dag_run: the dag run whose predecessor is looked up
    :param state: when given, only runs in this state are considered
    :param session: SQLAlchemy ORM Session
    """
    stmt = select(DagRun).where(
        DagRun.dag_id == dag_run.dag_id,
        DagRun.execution_date < dag_run.execution_date,
    )
    if state is not None:
        stmt = stmt.where(DagRun.state == state)
    latest_first = stmt.order_by(DagRun.execution_date.desc())
    return session.scalar(latest_first.limit(1))
@staticmethod
@internal_api_call
@provide_session
def get_previous_scheduled_dagrun(
    dag_run_id: int,
    session: Session = NEW_SESSION,
) -> DagRun | None:
    """
    Return the previous DagRun that was not triggered manually, if there is one.

    "Previous" is the latest run of the same DAG with an earlier execution_date whose
    ``run_type`` is not MANUAL. ``None`` is returned when there is no such run, or when
    no DagRun with primary key ``dag_run_id`` exists.

    :param dag_run_id: the DAG run ID (primary key)
    :param session: SQLAlchemy ORM Session
    """
    dag_run = session.get(DagRun, dag_run_id)
    if dag_run is None:
        # Unknown id: previously this crashed with AttributeError on ``None.dag_id``.
        return None
    return session.scalar(
        select(DagRun)
        .where(
            DagRun.dag_id == dag_run.dag_id,
            DagRun.execution_date < dag_run.execution_date,
            DagRun.run_type != DagRunType.MANUAL,
        )
        .order_by(DagRun.execution_date.desc())
        .limit(1)
    )
def _tis_for_dagrun_state(self, *, dag, tis):
    """
    Return the collection of tasks that should be considered for evaluation of terminal dag run state.

    Teardown tasks by default are not considered for the purpose of dag run state. But
    users may enable such consideration with on_failure_fail_dagrun.

    :param dag: the DAG of this run
    :param tis: task instances of this run
    :return: set of task instances of "effective leaf" tasks, excluding REMOVED ones
    """

    def is_effective_leaf(task):
        # A task is an effective leaf if every downstream task is an ignorable teardown
        # (teardown without on_failure_fail_dagrun) and the task itself is not ignorable.
        for down_task_id in task.downstream_task_ids:
            down_task = dag.get_task(down_task_id)
            if not down_task.is_teardown or down_task.on_failure_fail_dagrun:
                # we found a down task that is not ignorable; not a leaf
                return False
        # we found no ignorable downstreams
        # evaluate whether task is itself ignorable
        return not task.is_teardown or task.on_failure_fail_dagrun

    leaf_task_ids = {x.task_id for x in dag.tasks if is_effective_leaf(x)}
    if not leaf_task_ids:
        # can happen if dag is exclusively teardown tasks
        leaf_task_ids = {x.task_id for x in dag.tasks if not x.downstream_list}
    leaf_tis = {ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED}
    return leaf_tis

@provide_session
def update_state(
    self, session: Session = NEW_SESSION, execute_callbacks: bool = True
) -> tuple[list[TI], DagCallbackRequest | None]:
    """
    Determine the overall state of the DagRun based on the state of its TaskInstances.

    :param session: Sqlalchemy ORM Session
    :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked
        directly (default: true) or recorded as a pending request in the ``returned_callback`` property
    :return: Tuple containing tis that can be scheduled in the current loop & `returned_callback` that
        needs to be executed
    """
    # Callback to execute in case of Task Failures
    callback: DagCallbackRequest | None = None

    class _UnfinishedStates(NamedTuple):
        tis: Sequence[TI]

        @classmethod
        def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
            return cls(tis=unfinished_tis)

        @property
        def should_schedule(self) -> bool:
            # True only when there are unfinished TIs and none of them uses depends_on_past,
            # a per-DAG / per-dagrun concurrency limit, or is DEFERRED. Deadlock detection
            # further below only applies when this is True.
            return (
                bool(self.tis)
                and all(not t.task.depends_on_past for t in self.tis)  # type: ignore[union-attr]
                and all(t.task.max_active_tis_per_dag is None for t in self.tis)  # type: ignore[union-attr]
                and all(t.task.max_active_tis_per_dagrun is None for t in self.tis)  # type: ignore[union-attr]
                and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
            )

        def recalculate(self) -> _UnfinishedStates:
            # Drop TIs that are no longer in an unfinished state.
            return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])

    start_dttm = timezone.utcnow()
    self.last_scheduling_decision = start_dttm
    with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"), Stats.timer(
        "dagrun.dependency-check", tags=self.stats_tags
    ):
        dag = self.get_dag()
        info = self.task_instance_scheduling_decisions(session)

        tis = info.tis
        schedulable_tis = info.schedulable_tis
        changed_tis = info.changed_tis
        finished_tis = info.finished_tis
        unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

        if unfinished.should_schedule:
            # small speed up: the premature-TI check below only runs when nothing is
            # schedulable and no TI changed state.
            are_runnable_tasks = schedulable_tis or changed_tis
            if not are_runnable_tasks:
                are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
                    unfinished.tis, finished_tis, session
                )
                if changed_by_upstream:  # Something changed, we need to recalculate!
                    unfinished = unfinished.recalculate()

    tis_for_dagrun_state = self._tis_for_dagrun_state(dag=dag, tis=tis)

    # if all tasks finished and at least one failed, the run failed
    if not unfinished.tis and any(x.state in State.failed_states for x in tis_for_dagrun_state):
        self.log.error("Marking run %s failed", self)
        self.set_state(DagRunState.FAILED)
        self.notify_dagrun_state_changed(msg="task_failure")

        # Either run the DAG callback now, or hand a request back to the caller.
        if execute_callbacks:
            dag.handle_callback(self, success=False, reason="task_failure", session=session)
        elif dag.has_on_failure_callback:
            from airflow.models.dag import DagModel

            dag_model = DagModel.get_dagmodel(dag.dag_id, session)
            callback = DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                run_id=self.run_id,
                is_failure_callback=True,
                processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                msg="task_failure",
            )

        # Check if the max_consecutive_failed_dag_runs has been provided and not 0
        # and last consecutive failures are more
        if dag.max_consecutive_failed_dag_runs > 0:
            self.log.debug(
                "Checking consecutive failed DAG runs for DAG %s, limit is %s",
                self.dag_id,
                dag.max_consecutive_failed_dag_runs,
            )
            self._check_last_n_dagruns_failed(dag.dag_id, dag.max_consecutive_failed_dag_runs, session)

    # if all leaves succeeded and no unfinished tasks, the run succeeded
    elif not unfinished.tis and all(x.state in State.success_states for x in tis_for_dagrun_state):
        self.log.info("Marking run %s successful", self)
        self.set_state(DagRunState.SUCCESS)
        self.notify_dagrun_state_changed(msg="success")

        if execute_callbacks:
            dag.handle_callback(self, success=True, reason="success", session=session)
        elif dag.has_on_success_callback:
            from airflow.models.dag import DagModel

            dag_model = DagModel.get_dagmodel(dag.dag_id, session)
            callback = DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                run_id=self.run_id,
                is_failure_callback=False,
                processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                msg="success",
            )

    # if *all tasks* are deadlocked, the run failed
    # (``are_runnable_tasks`` is only bound when ``should_schedule`` was True; the
    # short-circuit ``and`` guarantees it is not read otherwise)
    elif unfinished.should_schedule and not are_runnable_tasks:
        self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self)
        self.set_state(DagRunState.FAILED)
        self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")

        if execute_callbacks:
            dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
        elif dag.has_on_failure_callback:
            from airflow.models.dag import DagModel

            dag_model = DagModel.get_dagmodel(dag.dag_id, session)
            callback = DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                run_id=self.run_id,
                is_failure_callback=True,
                processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                msg="all_tasks_deadlocked",
            )

    # finally, if the leaves aren't done, the dag is still running
    else:
        self.set_state(DagRunState.RUNNING)

    # Terminal state reached in this call: log a summary and emit a tracing span.
    if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
        msg = (
            "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
            "run_start_date=%s, run_end_date=%s, run_duration=%s, "
            "state=%s, external_trigger=%s, run_type=%s, "
            "data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
        )
        self.log.info(
            msg,
            self.dag_id,
            self.execution_date,
            self.run_id,
            self.start_date,
            self.end_date,
            (
                (self.end_date - self.start_date).total_seconds()
                if self.start_date and self.end_date
                else None
            ),
            self._state,
            self.external_trigger,
            self.run_type,
            self.data_interval_start,
            self.data_interval_end,
            self.dag_hash,
        )

        with Trace.start_span_from_dagrun(dagrun=self) as span:
            if self._state is DagRunState.FAILED:
                span.set_attribute("error", True)
            attributes = {
                "category": "DAG runs",
                "dag_id": str(self.dag_id),
                "execution_date": str(self.execution_date),
                "run_id": str(self.run_id),
                "queued_at": str(self.queued_at),
                "run_start_date": str(self.start_date),
                "run_end_date": str(self.end_date),
                "run_duration": str(
                    (self.end_date - self.start_date).total_seconds()
                    if self.start_date and self.end_date
                    else 0
                ),
                "state": str(self._state),
                "external_trigger": str(self.external_trigger),
                "run_type": str(self.run_type),
                "data_interval_start": str(self.data_interval_start),
                "data_interval_end": str(self.data_interval_end),
                "dag_hash": str(self.dag_hash),
                "conf": str(self.conf),
            }
            if span.is_recording():
                span.add_event(name="queued", timestamp=datetime_to_nano(self.queued_at))
                span.add_event(name="started", timestamp=datetime_to_nano(self.start_date))
                span.add_event(name="ended", timestamp=datetime_to_nano(self.end_date))
            span.set_attributes(attributes)

        session.flush()

    # Both emitters return early on their own when the run is still RUNNING.
    self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
    self._emit_duration_stats_for_finished_state()

    session.merge(self)
    # We do not flush here for performance reasons(It increases queries count by +20)

    return schedulable_tis, callback
@provide_session
def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
    """
    Classify this run's task instances and work out which of them can be scheduled now.

    Task instances whose task can no longer be found in the DAG are left out of the
    result; if they were not already REMOVED they are marked REMOVED and the session
    is flushed.

    :param session: Sqlalchemy ORM Session
    :return: a ``TISchedulingDecision`` holding all, schedulable, unfinished and finished
        task instances, plus whether any task instance changed state.
    """
    candidate_tis = self.get_task_instances(session=session, state=State.task_states)
    self.log.debug("number of tis tasks for %s: %s task(s)", self, len(candidate_tis))

    def _with_tasks_attached(dag: DAG, task_instances: list[TI]) -> Iterable[TI]:
        """Yield TIs whose task exists in ``dag`` (setting ``ti.task``); mark the others REMOVED."""
        for task_instance in task_instances:
            try:
                task_instance.task = dag.get_task(task_instance.task_id)
            except TaskNotFound:
                if task_instance.state == TaskInstanceState.REMOVED:
                    continue
                self.log.error("Failed to get task for ti %s. Marking it as removed.", task_instance)
                task_instance.state = TaskInstanceState.REMOVED
                session.flush()
                continue
            yield task_instance

    tis = list(_with_tasks_attached(self.get_dag(), candidate_tis))

    # Split by state; the two tests are independent, exactly like two separate filters.
    unfinished_tis: list[TI] = []
    finished_tis: list[TI] = []
    for candidate in tis:
        if candidate.state in State.unfinished:
            unfinished_tis.append(candidate)
        if candidate.state in State.finished:
            finished_tis.append(candidate)

    schedulable_tis: list[TI] = []
    changed_tis = False
    if unfinished_tis:
        ready_candidates = [
            candidate for candidate in unfinished_tis if candidate.state in SCHEDULEABLE_STATES
        ]
        self.log.debug(
            "number of scheduleable tasks for %s: %s task(s)", self, len(ready_candidates)
        )
        schedulable_tis, changed_tis, expansion_happened = self._get_ready_tis(
            ready_candidates,
            finished_tis,
            session=session,
        )

        if expansion_happened:
            # Expanding mapped tasks can move TIs into non-schedulable states,
            # so the unfinished/finished split has to be recomputed.
            changed_tis = True
            still_unfinished = [
                candidate for candidate in unfinished_tis if candidate.state in State.unfinished
            ]
            finished_tis.extend(
                candidate for candidate in unfinished_tis if candidate.state in State.finished
            )
            unfinished_tis = still_unfinished

    return TISchedulingDecision(
        tis=tis,
        schedulable_tis=schedulable_tis,
        changed_tis=changed_tis,
        unfinished_tis=unfinished_tis,
        finished_tis=finished_tis,
    )
# deliberately not notifying on QUEUED
# we can't get all the state changes on SchedulerJob, BackfillJob
# or LocalTaskJob, so we don't want to "falsely advertise" we notify about that
# NOTE(review): the comment above appears detached from the code it describes; the
# notification method it belongs to is not visible in this chunk.

def _get_ready_tis(
    self,
    schedulable_tis: list[TI],
    finished_tis: list[TI],
    session: Session,
) -> tuple[list[TI], bool, bool]:
    """
    Evaluate dependencies of schedulable TIs, expanding mapped tasks along the way.

    :param schedulable_tis: candidate task instances (in schedulable states)
    :param finished_tis: finished task instances of this run, passed to the dependency context
    :param session: Sqlalchemy ORM Session
    :return: ``(ready_tis, changed_tis, expansion_happened)``. ``changed_tis`` is True when a
        TI whose dependencies were not met now has a different state in the database.
    """
    old_states = {}
    ready_tis: list[TI] = []
    changed_tis = False

    if not schedulable_tis:
        return ready_tis, changed_tis, False

    # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
    # `schedulable_tis` in place and have the `for` loop pick them up
    additional_tis: list[TI] = []
    dep_context = DepContext(
        flag_upstream_failed=True,
        ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
        finished_tis=finished_tis,
    )

    def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
        """
        Try to expand the ti, if needed.

        If the ti needs expansion, newly created task instances are
        returned as well as the original ti.
        The original ti is also modified in-place and assigned the
        ``map_index`` of 0.

        If the ti does not need expansion, either because the task is not
        mapped, or has already been expanded, *None* is returned.
        """
        if TYPE_CHECKING:
            assert ti.task

        if ti.map_index >= 0:  # Already expanded, we're good.
            return None

        from airflow.models.mappedoperator import MappedOperator

        if isinstance(ti.task, MappedOperator):
            # If we get here, it could be that we are moving from non-mapped to mapped
            # after task instance clearing or this ti is not yet expanded. Safe to clear
            # the db references.
            ti.clear_db_references(session=session)
        try:
            expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
        except NotMapped:  # Not a mapped task, nothing needed.
            return None
        if expanded_tis:
            return expanded_tis
        return ()

    # Check dependencies.
    expansion_happened = False
    # Set of task ids for which was already done _revise_map_indexes_if_mapped
    revised_map_index_task_ids = set()
    # ``additional_tis`` grows inside the loop; chain() picks up the appended items.
    for schedulable in itertools.chain(schedulable_tis, additional_tis):
        if TYPE_CHECKING:
            assert schedulable.task

        old_state = schedulable.state
        if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
            old_states[schedulable.key] = old_state
            continue
        # If schedulable is not yet expanded, try doing it now. This is
        # called in two places: First and ideally in the mini scheduler at
        # the end of LocalTaskJob, and then as an "expansion of last resort"
        # in the scheduler to ensure that the mapped task is correctly
        # expanded before executed. Also see _revise_map_indexes_if_mapped
        # docstring for additional information.
        new_tis = None
        if schedulable.map_index < 0:
            new_tis = _expand_mapped_task_if_needed(schedulable)
            if new_tis is not None:
                additional_tis.extend(new_tis)
                expansion_happened = True
        if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
            # It's enough to revise map index once per task id,
            # checking the map index for each mapped task significantly slows down scheduling
            if schedulable.task.task_id not in revised_map_index_task_ids:
                ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
                revised_map_index_task_ids.add(schedulable.task.task_id)
            ready_tis.append(schedulable)

    # Check if any ti changed state
    tis_filter = TI.filter_for_tis(old_states)
    if tis_filter is not None:
        fresh_tis = session.scalars(select(TI).where(tis_filter)).all()
        changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

    return ready_tis, changed_tis, expansion_happened

def _are_premature_tis(
    self,
    unfinished_tis: Sequence[TI],
    finished_tis: list[TI],
    session: Session,
) -> tuple[bool, bool]:
    """
    Check whether any unfinished TI would be runnable if retry/reschedule waits were ignored.

    :return: ``(any_runnable, changed)`` where ``changed`` is read from
        ``dep_context.have_changed_ti_states`` after the dependency evaluation.
    """
    dep_context = DepContext(
        flag_upstream_failed=True,
        ignore_in_retry_period=True,
        ignore_in_reschedule_period=True,
        finished_tis=finished_tis,
    )
    # there might be runnable tasks that are up for retry and for some reason(retry delay, etc.) are
    # not ready yet, so we set the flags to count them in
    # (the ``any`` must be evaluated before reading ``have_changed_ti_states``)
    return (
        any(ut.are_dependencies_met(dep_context=dep_context, session=session) for ut in unfinished_tis),
        dep_context.have_changed_ti_states,
    )

def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None:
    """
    Emit the true scheduling delay stats.

    The true scheduling delay stats is defined as the time when the first
    task in DAG starts minus the expected DAG run datetime.

    This helper method is used in ``update_state`` when the state of the
    DAG run is updated to a completed status (either success or failure).
    It finds the first started task within the DAG, calculates the run's
    expected start time based on the logical date and timetable, and gets
    the delay from the difference of these two values.

    The emitted data may contain outliers (e.g. when the first task was
    cleared, so the second task's start date will be used), but we can get
    rid of the outliers on the stats side through dashboards tooling.

    Note that the stat will only be emitted for scheduler-triggered DAG runs
    (i.e. when ``external_trigger`` is *False* and ``clear_number`` is equal to 0).
    """
    # Compare against the DagRun state enum (the sibling duration emitter does the same).
    if self.state == DagRunState.RUNNING:
        return
    if self.external_trigger:
        return
    if self.clear_number > 0:
        return
    if not finished_tis:
        return

    try:
        dag = self.get_dag()

        if not dag.timetable.periodic:
            # We can't emit this metric if there is no following schedule to calculate from!
            return

        try:
            first_start_date = min(ti.start_date for ti in finished_tis if ti.start_date)
        except ValueError:  # No start dates at all.
            pass
        else:
            # TODO: Logically, this should be DagRunInfo.run_after, but the
            # information is not stored on a DagRun, only before the actual
            # execution on DagModel.next_dagrun_create_after. We should add
            # a field on DagRun for this instead of relying on the run
            # always happening immediately after the data interval.
            data_interval_end = dag.get_run_data_interval(self).end
            true_delay = first_start_date - data_interval_end
            if true_delay.total_seconds() > 0:
                Stats.timing(
                    f"dagrun.{dag.dag_id}.first_task_scheduling_delay", true_delay, tags=self.stats_tags
                )
                Stats.timing("dagrun.first_task_scheduling_delay", true_delay, tags=self.stats_tags)
    except Exception:
        # Metrics are best-effort: never let a stats failure break state updates.
        self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True)

def _emit_duration_stats_for_finished_state(self):
    """Emit the run duration (end_date - start_date) as timing metrics for a finished run."""
    if self.state == DagRunState.RUNNING:
        return
    if self.start_date is None:
        self.log.warning("Failed to record duration of %s: start_date is not set.", self)
        return
    if self.end_date is None:
        self.log.warning("Failed to record duration of %s: end_date is not set.", self)
        return

    duration = self.end_date - self.start_date
    timer_params = {"dt": duration, "tags": self.stats_tags}
    Stats.timing(f"dagrun.duration.{self.state}.{self.dag_id}", **timer_params)
    Stats.timing(f"dagrun.duration.{self.state}", **timer_params)

@provide_session
def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
    """
    Verify the DagRun by checking for removed tasks or tasks that are not in the database yet.

    It will set state to removed or add the task if required.

    :param session: Sqlalchemy ORM Session
    """
    from airflow.settings import task_instance_mutation_hook

    # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
    # Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload.
    hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, "is_noop", False)

    dag = self.get_dag()
    task_ids = self._check_for_removed_or_restored_tasks(
        dag, task_instance_mutation_hook, session=session
    )

    def task_filter(task: Operator) -> bool:
        # Tasks without a TI yet; for non-backfill runs the execution_date must also lie
        # within the task's [start_date, end_date] window (open-ended when unset).
        return task.task_id not in task_ids and (
            self.is_backfill
            or (
                (task.start_date is None or task.start_date <= self.execution_date)
                and (task.end_date is None or self.execution_date <= task.end_date)
            )
        )

    created_counts: dict[str, int] = defaultdict(int)
    task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

    # Create the missing tasks, including mapped tasks
    tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
    tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session)
    self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session)
def _check_for_removed_or_restored_tasks(
    self, dag: DAG, ti_mutation_hook, *, session: Session
) -> set[str]:
    """
    Check for removed tasks/restored/missing tasks.

    Every existing task instance of this run is passed through ``ti_mutation_hook``. A TI whose
    task can no longer be found in ``dag`` is marked REMOVED (only when this run is not RUNNING
    and the DAG is not partial); a REMOVED TI whose task is present again is reset to ``None``.
    Mapped TIs whose ``map_index`` no longer fits the parse-time or resolved mapping length are
    marked REMOVED as well.

    :param dag: DAG object corresponding to the dagrun
    :param ti_mutation_hook: task_instance_mutation_hook function
    :param session: Sqlalchemy ORM Session
    :return: Task IDs in the DAG run
    """
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = set()
    for ti in tis:
        ti_mutation_hook(ti)
        task_ids.add(ti.task_id)
        try:
            task = dag.get_task(ti.task_id)

            should_restore_task = (task is not None) and ti.state == TaskInstanceState.REMOVED
            if should_restore_task:
                self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", tags=self.stats_tags)
                # Same metric with tagging
                Stats.incr("task_restored_to_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                ti.state = None
        except AirflowException:
            if ti.state == TaskInstanceState.REMOVED:
                pass  # ti has already been removed, just ignore it
            elif self.state != DagRunState.RUNNING and not dag.partial:
                self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                Stats.incr(f"task_removed_from_dag.{dag.dag_id}", tags=self.stats_tags)
                # Same metric with tagging
                Stats.incr("task_removed_from_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                ti.state = TaskInstanceState.REMOVED
            continue

        try:
            num_mapped_tis = task.get_parse_time_mapped_ti_count()
        except NotMapped:
            continue
        except NotFullyPopulated:
            # What if it is _now_ dynamically mapped, but wasn't before?
            try:
                total_length = task.get_mapped_ti_count(self.run_id, session=session)
            except NotFullyPopulated:
                # Not all upstreams finished, so we can't tell what should be here. Remove everything.
                if ti.map_index >= 0:
                    self.log.debug(
                        "Removing the unmapped TI '%s' as the mapping can't be resolved yet", ti
                    )
                    ti.state = TaskInstanceState.REMOVED
                continue
            # Upstreams finished, check there aren't any extras
            if ti.map_index >= total_length:
                self.log.debug(
                    "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
                    ti,
                    total_length,
                )
                ti.state = TaskInstanceState.REMOVED
        else:
            # Check if the number of mapped literals has changed, and we need to mark this TI as removed.
            if ti.map_index >= num_mapped_tis:
                self.log.debug(
                    "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
                    ti,
                    num_mapped_tis,
                )
                ti.state = TaskInstanceState.REMOVED
            elif ti.map_index < 0:
                self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
                ti.state = TaskInstanceState.REMOVED

    return task_ids

@overload
def _get_task_creator(
    self,
    created_counts: dict[str, int],
    ti_mutation_hook: Callable,
    hook_is_noop: Literal[True],
) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]: ...

@overload
def _get_task_creator(
    self,
    created_counts: dict[str, int],
    ti_mutation_hook: Callable,
    hook_is_noop: Literal[False],
) -> Callable[[Operator, Iterable[int]], Iterator[TI]]: ...

def _get_task_creator(
    self,
    created_counts: dict[str, int],
    ti_mutation_hook: Callable,
    hook_is_noop: Literal[True, False],
) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
    """
    Get the task creator function.

    The returned generator function yields one item per map index: insert mappings (plain
    dicts) when the mutation hook is a no-op, otherwise hook-mutated ``TI`` objects. While it
    is consumed it also increments ``created_counts`` once per task instance yielded.

    :param created_counts: Dictionary of task_type -> count of created TIs
    :param ti_mutation_hook: task_instance_mutation_hook function
    :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
    """
    if hook_is_noop:

        def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]:
            for map_index in indexes:
                # Count every TI (not every task) so both creator flavours report the same totals.
                created_counts[task.task_type] += 1
                yield TI.insert_mapping(self.run_id, task, map_index=map_index)

        creator = create_ti_mapping

    else:

        def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
            for map_index in indexes:
                ti = TI(task, run_id=self.run_id, map_index=map_index)
                ti_mutation_hook(ti)
                created_counts[ti.operator] += 1
                yield ti

        creator = create_ti
    return creator

def _create_tasks(
    self,
    tasks: Iterable[Operator],
    task_creator: Callable[[Operator, Iterable[int]], CreatedTasks],
    *,
    session: Session,
) -> CreatedTasks:
    """
    Create missing tasks -- and expand any MappedOperator that _only_ have literals as input.

    :param tasks: Tasks to create jobs for in the DAG run
    :param task_creator: Function to create task instances
    """
    map_indexes: Iterable[int]
    for task in tasks:
        try:
            count = task.get_mapped_ti_count(self.run_id, session=session)
        except (NotMapped, NotFullyPopulated):
            map_indexes = (-1,)
        else:
            if count:
                map_indexes = range(count)
            else:
                # Make sure to always create at least one ti; this will be
                # marked as REMOVED later at runtime.
                map_indexes = (-1,)
        yield from task_creator(task, map_indexes)

def _create_task_instances(
    self,
    dag_id: str,
    tasks: Iterator[dict[str, Any]] | Iterator[TI],
    created_counts: dict[str, int],
    hook_is_noop: bool,
    *,
    session: Session,
) -> None:
    """
    Create the necessary task instances from the given tasks.

    On ``IntegrityError`` the error is logged and the session is rolled back instead of raising.

    :param dag_id: DAG ID associated with the dagrun
    :param tasks: the tasks to create the task instances from
    :param created_counts: task_type -> number of TIs created, filled in by the task creator
        while ``tasks`` is consumed
    :param hook_is_noop: whether the task_instance_mutation_hook is noop
    :param session: the session to use
    """
    # Fetch the information we need before handling the exception to avoid
    # PendingRollbackError due to the session being invalidated on exception
    # see https://github.com/apache/superset/pull/530
    run_id = self.run_id
    try:
        if hook_is_noop:
            session.bulk_insert_mappings(TI, tasks)
        else:
            session.bulk_save_objects(tasks)

        # created_counts is only complete here, after the bulk call has drained ``tasks``.
        for task_type, count in created_counts.items():
            Stats.incr(f"task_instance_created_{task_type}", count, tags=self.stats_tags)
            # Same metric with tagging
            Stats.incr("task_instance_created", count, tags={**self.stats_tags, "task_type": task_type})
        session.flush()
    except IntegrityError:
        self.log.info(
            "Hit IntegrityError while creating the TIs for %s- %s",
            dag_id,
            run_id,
            exc_info=True,
        )
        self.log.info("Doing session rollback.")
        # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
        session.rollback()

def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
    """
    Check if task increased or reduced in length and handle appropriately.

    Task instances that do not already exist are created and returned if possible.
    Expansion only happens if all upstreams are ready; otherwise we delay expansion to
    the "last resort". See comments at the call site for more details.
    """
    from airflow.settings import task_instance_mutation_hook

    try:
        total_length = task.get_mapped_ti_count(self.run_id, session=session)
    except NotMapped:
        return  # Not a mapped task, don't need to do anything.
    except NotFullyPopulated:
        return  # Upstreams not ready, don't need to revise this yet.

    query = session.scalars(
        select(TI.map_index).where(
            TI.dag_id == self.dag_id,
            TI.task_id == task.task_id,
            TI.run_id == self.run_id,
        )
    )
    existing_indexes = set(query)

    # Any existing index outside [0, total_length) -- including the -1 placeholder -- is stale.
    removed_indexes = existing_indexes.difference(range(total_length))
    if removed_indexes:
        session.execute(
            update(TI)
            .where(
                TI.dag_id == self.dag_id,
                TI.task_id == task.task_id,
                TI.run_id == self.run_id,
                TI.map_index.in_(removed_indexes),
            )
            .values(state=TaskInstanceState.REMOVED)
        )
        session.flush()

    for index in range(total_length):
        if index in existing_indexes:
            continue
        ti = TI(task, run_id=self.run_id, map_index=index, state=None)
        self.log.debug("Expanding TIs upserted %s", ti)
        task_instance_mutation_hook(ti)
        ti = session.merge(ti)
        ti.refresh_from_task(task)
        session.flush()
        yield ti

@staticmethod
def get_run(session: Session, dag_id: str, execution_date: datetime) -> DagRun | None:
    """
    Get a single DAG Run.

    :meta private:
    :param session: Sqlalchemy ORM Session
    :param dag_id: DAG ID
    :param execution_date: execution date
    :return: DagRun corresponding to the given dag_id and execution date
        if one exists. None otherwise.
    """
    warnings.warn(
        "This method is deprecated. Please use SQLAlchemy directly",
        RemovedInAirflow3Warning,
        stacklevel=2,
    )
    return session.scalar(
        select(DagRun).where(
            DagRun.dag_id == dag_id,
            DagRun.external_trigger == False,  # noqa: E712
            DagRun.execution_date == execution_date,
        )
    )

# Must be a classmethod: it takes ``cls`` and a session and queries across all DAGs. A
# ``@property`` here made it impossible to call or to pass a session.
@classmethod
@provide_session
def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
    """Return the latest DagRun for each DAG."""
    subquery = (
        select(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
        .group_by(cls.dag_id)
        .subquery()
    )
    return session.scalars(
        select(cls).join(
            subquery,
            and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date),
        )
    ).all()
@provide_session
def schedule_tis(
    self,
    schedulable_tis: Iterable[TI],
    session: Session = NEW_SESSION,
    max_tis_per_query: int | None = None,
) -> int:
    """
    Set the given task instances in to the scheduled state.

    Each element of ``schedulable_tis`` should have its ``task`` attribute already set.

    Any EmptyOperator without callbacks or outlets is instead set straight to the success state.
    A TI whose operator defines ``start_trigger_args`` and whose ``start_from_trigger`` expands
    to a truthy value is deferred right away (``ti.defer_task``) instead of being scheduled.

    All the TIs should belong to this DagRun, but this code is in the hot-path, this is not checked -- it
    is the caller's responsibility to call this function only with TIs from a single dag run.

    :param schedulable_tis: task instances to schedule
    :param session: Sqlalchemy ORM Session
    :param max_tis_per_query: maximum number of TIs per UPDATE statement; a falsy value
        means one statement for all of them
    :return: number of rows changed by the SCHEDULED and SUCCESS updates (deferred TIs
        are not included)
    """
    # Get list of TI IDs that do not need to executed, these are
    # tasks using EmptyOperator and without on_execute_callback / on_success_callback
    dummy_ti_ids = []
    schedulable_ti_ids = []
    for ti in schedulable_tis:
        if TYPE_CHECKING:
            assert ti.task
        if (
            ti.task.inherits_from_empty_operator
            and not ti.task.on_execute_callback
            and not ti.task.on_success_callback
            and not ti.task.outlets
        ):
            dummy_ti_ids.append((ti.task_id, ti.map_index))
        # check "start_trigger_args" to see whether the operator supports start execution from triggerer
        # if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer
        # this task.
        # if not, we'll add this "ti" into "schedulable_ti_ids" and later execute it to run in the worker
        elif ti.task.start_trigger_args is not None:
            context = ti.get_template_context()
            start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session)

            if start_from_trigger:
                ti.start_date = timezone.utcnow()
                if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
                    ti.try_number += 1
                ti.defer_task(exception=None, session=session)
            else:
                schedulable_ti_ids.append((ti.task_id, ti.map_index))
        else:
            schedulable_ti_ids.append((ti.task_id, ti.map_index))

    count = 0

    if schedulable_ti_ids:
        schedulable_ti_ids_chunks = chunks(
            schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
        )
        for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks:
            count += session.execute(
                update(TI)
                .where(
                    TI.dag_id == self.dag_id,
                    TI.run_id == self.run_id,
                    tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
                )
                .values(
                    state=TaskInstanceState.SCHEDULED,
                    # Only a TI coming back from UP_FOR_RESCHEDULE keeps its try_number. The explicit
                    # IS NULL check is needed because in SQL ``NULL != x`` is not true.
                    try_number=case(
                        (
                            or_(TI.state.is_(None), TI.state != TaskInstanceState.UP_FOR_RESCHEDULE),
                            TI.try_number + 1,
                        ),
                        else_=TI.try_number,
                    ),
                )
                .execution_options(synchronize_session=False)
            ).rowcount

    # Tasks using EmptyOperator should not be executed, mark them as success
    if dummy_ti_ids:
        # One timestamp for start and end, so the stored dates agree with duration=0.
        now = timezone.utcnow()
        dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or len(dummy_ti_ids))
        for dummy_ti_ids_chunk in dummy_ti_ids_chunks:
            count += session.execute(
                update(TI)
                .where(
                    TI.dag_id == self.dag_id,
                    TI.run_id == self.run_id,
                    tuple_in_condition((TI.task_id, TI.map_index), dummy_ti_ids_chunk),
                )
                .values(
                    state=TaskInstanceState.SUCCESS,
                    start_date=now,
                    end_date=now,
                    duration=0,
                    try_number=TI.try_number + 1,
                )
                .execution_options(
                    synchronize_session=False,
                )
            ).rowcount

    return count
@staticmethod
@internal_api_call
@provide_session
def _get_log_template(log_template_id: int | None, session: Session = NEW_SESSION) -> LogTemplate | LogTemplatePydantic:
    """
    Load the LogTemplate row used by a DagRun.

    :param log_template_id: primary key of the LogTemplate to load. ``None`` marks a run
        created before LogTemplate existed; the row with the lowest id is used for it.
    :param session: Sqlalchemy ORM Session
    :raises AirflowException: if no matching LogTemplate row exists.
    """
    template: LogTemplate | None
    if log_template_id is not None:
        template = session.get(LogTemplate, log_template_id)
    else:
        # DagRun created before LogTemplate introduction.
        template = session.scalar(select(LogTemplate).order_by(LogTemplate.id).limit(1))

    if template is not None:
        return template
    raise AirflowException(
        f"No log_template entry found for ID {log_template_id!r}. "
        f"Please make sure you set up the metadatabase correctly."
    )

@provide_session
def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
    """Return ``filename`` of ``self.get_log_template(session=session)`` (deprecated)."""
    warnings.warn(
        "This method is deprecated. Please use get_log_template instead.",
        RemovedInAirflow3Warning,
        stacklevel=2,
    )
    log_template = self.get_log_template(session=session)
    return log_template.filename