## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportitertoolsimportosimportwarningsfromcollectionsimportdefaultdictfromdatetimeimportdatetimefromtypingimport(TYPE_CHECKING,Any,Callable,Iterable,Iterator,NamedTuple,Sequence,TypeVar,cast,overload,)fromsqlalchemyimport(Boolean,Column,ForeignKey,Index,Integer,PickleType,String,UniqueConstraint,and_,func,or_,text,)fromsqlalchemy.excimportIntegrityErrorfromsqlalchemy.ext.declarativeimportdeclared_attrfromsqlalchemy.ormimportjoinedload,relationship,synonymfromsqlalchemy.orm.sessionimportSessionfromsqlalchemy.sql.expressionimportfalse,select,truefromairflowimportsettingsfromairflow.callbacks.callback_requestsimportDagCallbackRequestfromairflow.configurationimportconfasairflow_conffromairflow.exceptionsimportAirflowException,RemovedInAirflow3Warning,TaskNotFoundfromairflow.models.baseimportBase,StringIDfromairflow.models.mappedoperatorimportMappedOperatorfromairflow.models.taskinstanceimportTaskInstanceasTIfromairflow.models.tasklogimportLogTemplatefromairflow.statsimportStatsfromairflow.ti_deps.dep_contextimportDepContextfromairflow.ti_deps.dependencies_statesimportSCHEDULEABLE_STATESfromairflow.typing_compatimportLiteralfromairflow.utilsimporttimezonefromairflow.utils.helpersimportis_cont
ainerfromairflow.utils.log.logging_mixinimportLoggingMixinfromairflow.utils.sessionimportNEW_SESSION,provide_sessionfromairflow.utils.sqlalchemyimportUtcDateTime,nulls_first,skip_locked,tuple_in_condition,with_row_locksfromairflow.utils.stateimportDagRunState,State,TaskInstanceStatefromairflow.utils.typesimportNOTSET,ArgNotSet,DagRunTypeifTYPE_CHECKING:fromairflow.models.dagimportDAGfromairflow.models.operatorimportOperator
class DagRun(Base, LoggingMixin):
    """
    DagRun describes an instance of a Dag. It can be created
    by the scheduler (for regular runs) or by an external trigger.
    """

    # NOTE(review): this copy of the class body is incomplete. The column
    # definitions that ``__table_args__`` below refers to (``dag_id``, ``_state``,
    # ``last_scheduling_decision``) and other class attributes used by the methods
    # (e.g. ``DEFAULT_DAGRUNS_TO_EXAMINE``) are not defined anywhere here, and the
    # statement that the lone ``)`` below closes is missing its opening part.
    # Restore them from upstream before relying on this class.

    # Foreign key to LogTemplate. DagRun rows created prior to this column's
    # existence have this set to NULL. Later rows automatically populate this on
    # insert to point to the latest LogTemplate entry.
    )

    # Remove this `if` after upgrading Sphinx-AutoAPI
    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
        dag: DAG | None
    else:
        # In-memory DAG object; get_dag() raises if this was never set.
        dag: DAG | None = None

    __table_args__ = (
        Index('dag_id_state', dag_id, _state),
        UniqueConstraint('dag_id', 'execution_date', name='dag_run_dag_id_execution_date_key'),
        UniqueConstraint('dag_id', 'run_id', name='dag_run_dag_id_run_id_key'),
        Index('idx_last_scheduling_decision', last_scheduling_decision),
        Index('idx_dag_run_dag_id', dag_id),
        # Partial index covering only RUNNING rows (on dialects that support it).
        Index(
            'idx_dag_run_running_dags',
            'state',
            'dag_id',
            postgresql_where=text("state='running'"),
            mssql_where=text("state='running'"),
            sqlite_where=text("state='running'"),
        ),
        # since mysql lacks filtered/partial indices, this creates a
        # duplicate index on mysql. Not the end of the world
        Index(
            'idx_dag_run_queued_dags',
            'state',
            'dag_id',
            postgresql_where=text("state='queued'"),
            mssql_where=text("state='queued'"),
            sqlite_where=text("state='queued'"),
        )
        # NOTE(review): the ``)`` closing the ``__table_args__`` tuple (and whatever
        # class attributes followed it) is not present in this copy.

    def __init__(
        self,
        dag_id: str | None = None,
        run_id: str | None = None,
        queued_at: datetime | None | ArgNotSet = NOTSET,
        execution_date: datetime | None = None,
        start_date: datetime | None = None,
        external_trigger: bool | None = None,
        conf: Any | None = None,
        state: DagRunState | None = None,
        run_type: str | None = None,
        dag_hash: str | None = None,
        creating_job_id: int | None = None,
        data_interval: tuple[datetime, datetime] | None = None,
    ):
        """
        Build a DagRun in memory (nothing is written to the database here).

        :param data_interval: ``(start, end)`` pair; when None both
            ``data_interval_start`` and ``data_interval_end`` are set to None.
        :param queued_at: explicit queue time. When left as ``NOTSET`` it becomes
            "now" if ``state`` is QUEUED and None otherwise.
        :param conf: run configuration; a falsy value is stored as ``{}``.
        :param state: initial state; only assigned when not None.
        """
        if data_interval is None:
            # Legacy: Only happen for runs created prior to Airflow 2.2.
            self.data_interval_start = self.data_interval_end = None
        else:
            self.data_interval_start, self.data_interval_end = data_interval

        self.dag_id = dag_id
        self.run_id = run_id
        self.execution_date = execution_date
        self.start_date = start_date
        self.external_trigger = external_trigger
        self.conf = conf or {}
        if state is not None:
            self.state = state
        # NOTSET (rather than None) lets callers explicitly pass queued_at=None.
        if queued_at is NOTSET:
            self.queued_at = timezone.utcnow() if state == State.QUEUED else None
        else:
            self.queued_at = queued_at
        self.run_type = run_type
        self.dag_hash = dag_hash
        self.creating_job_id = creating_job_id
        super().__init__()
[docs]defrefresh_from_db(self,session:Session=NEW_SESSION)->None:""" Reloads the current dagrun from the database :param session: database session """dr=session.query(DagRun).filter(DagRun.dag_id==self.dag_id,DagRun.run_id==self.run_id).one()self.id=dr.idself.state=dr.state
@classmethod@provide_session
[docs]defactive_runs_of_dags(cls,dag_ids=None,only_running=False,session=None)->dict[str,int]:"""Get the number of active dag runs for each dag."""query=session.query(cls.dag_id,func.count('*'))ifdag_idsisnotNone:# 'set' called to avoid duplicate dag_ids, but converted back to 'list'# because SQLAlchemy doesn't accept a set here.query=query.filter(cls.dag_id.in_(list(set(dag_ids))))ifonly_running:query=query.filter(cls.state==State.RUNNING)else:query=query.filter(cls.state.in_([State.RUNNING,State.QUEUED]))query=query.group_by(cls.dag_id)return{dag_id:countfordag_id,countinquery.all()}
@classmethod
[docs]defnext_dagruns_to_examine(cls,state:DagRunState,session:Session,max_number:int|None=None,):""" Return the next DagRuns that the scheduler should attempt to schedule. This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE" query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as the transaction is committed it will be unlocked. :rtype: list[airflow.models.DagRun] """fromairflow.models.dagimportDagModelifmax_numberisNone:max_number=cls.DEFAULT_DAGRUNS_TO_EXAMINE# TODO: Bake this query, it is run _A lot_query=(session.query(cls).filter(cls.state==state,cls.run_type!=DagRunType.BACKFILL_JOB).join(DagModel,DagModel.dag_id==cls.dag_id).filter(DagModel.is_paused==false(),DagModel.is_active==true()))ifstate==State.QUEUED:# For dag runs in the queued state, we check if they have reached the max_active_runs limit# and if so we drop themrunning_drs=(session.query(DagRun.dag_id,func.count(DagRun.state).label('num_running')).filter(DagRun.state==DagRunState.RUNNING).group_by(DagRun.dag_id).subquery())query=query.outerjoin(running_drs,running_drs.c.dag_id==DagRun.dag_id).filter(func.coalesce(running_drs.c.num_running,0)<DagModel.max_active_runs)query=query.order_by(nulls_first(cls.last_scheduling_decision,session=session),cls.execution_date,)ifnotsettings.ALLOW_FUTURE_EXEC_DATES:query=query.filter(DagRun.execution_date<=func.now())returnwith_row_locks(query.limit(max_number),of=cls,session=session,**skip_locked(session=session)
)@classmethod@provide_session
[docs]deffind(cls,dag_id:str|list[str]|None=None,run_id:Iterable[str]|None=None,execution_date:datetime|Iterable[datetime]|None=None,state:DagRunState|None=None,external_trigger:bool|None=None,no_backfills:bool=False,run_type:DagRunType|None=None,session:Session=NEW_SESSION,execution_start_date:datetime|None=None,execution_end_date:datetime|None=None,)->list[DagRun]:""" Returns a set of dag runs for the given search criteria. :param dag_id: the dag_id or list of dag_id to find dag runs for :param run_id: defines the run id for this dag run :param run_type: type of DagRun :param execution_date: the execution date :param state: the state of the dag run :param external_trigger: whether this dag run is externally triggered :param no_backfills: return no backfills (True), return all (False). Defaults to False :param session: database session :param execution_start_date: dag run that was executed from this date :param execution_end_date: dag run that was executed until this date """qry=session.query(cls)dag_ids=[dag_id]ifisinstance(dag_id,str)elsedag_idifdag_ids:qry=qry.filter(cls.dag_id.in_(dag_ids))ifis_container(run_id):qry=qry.filter(cls.run_id.in_(run_id))elifrun_idisnotNone:qry=qry.filter(cls.run_id==run_id)ifis_container(execution_date):qry=qry.filter(cls.execution_date.in_(execution_date))elifexecution_dateisnotNone:qry=qry.filter(cls.execution_date==execution_date)ifexecution_start_dateandexecution_end_date:qry=qry.filter(cls.execution_date.between(execution_start_date,execution_end_date))elifexecution_start_date:qry=qry.filter(cls.execution_date>=execution_start_date)elifexecution_end_date:qry=qry.filter(cls.execution_date<=execution_end_date)ifstate:qry=qry.filter(cls.state==state)ifexternal_triggerisnotNone:qry=qry.filter(cls.external_trigger==external_trigger)ifrun_type:qry=qry.filter(cls.run_type==run_type)ifno_backfills:qry=qry.filter(cls.run_type!=DagRunType.BACKFILL_JOB)returnqry.order_by(cls.execution_date).all()
@classmethod@provide_session
[docs]deffind_duplicate(cls,dag_id:str,run_id:str,execution_date:datetime,session:Session=NEW_SESSION,)->DagRun|None:""" Return an existing run for the DAG with a specific run_id or execution_date. *None* is returned if no such DAG run is found. :param dag_id: the dag_id to find duplicates for :param run_id: defines the run id for this dag run :param execution_date: the execution date :param session: database session """return(session.query(cls).filter(cls.dag_id==dag_id,or_(cls.run_id==run_id,cls.execution_date==execution_date),
).one_or_none())@staticmethod
[docs]defgenerate_run_id(run_type:DagRunType,execution_date:datetime)->str:"""Generate Run ID based on Run Type and Execution Date"""# _Ensure_ run_type is a DagRunType, not just a string from user codereturnDagRunType(run_type).generate_run_id(execution_date)
@provide_session
[docs]defget_task_instances(self,state:Iterable[TaskInstanceState|None]|None=None,session:Session=NEW_SESSION,)->list[TI]:"""Returns the task instances for this dag run"""tis=(session.query(TI).options(joinedload(TI.dag_run)).filter(TI.dag_id==self.dag_id,TI.run_id==self.run_id,))ifstate:ifisinstance(state,str):tis=tis.filter(TI.state==state)else:# this is required to deal with NULL valuesifState.NONEinstate:ifall(xisNoneforxinstate):tis=tis.filter(TI.state.is_(None))else:not_none_state=[sforsinstateifs]tis=tis.filter(or_(TI.state.in_(not_none_state),TI.state.is_(None)))else:tis=tis.filter(TI.state.in_(state))ifself.dagandself.dag.partial:tis=tis.filter(TI.task_id.in_(self.dag.task_ids))returntis.all()
@provide_session
[docs]defget_task_instance(self,task_id:str,session:Session=NEW_SESSION,*,map_index:int=-1,)->TI|None:""" Returns the task instance specified by task_id for this dag run :param task_id: the task id :param session: Sqlalchemy ORM Session """return(session.query(TI).filter_by(dag_id=self.dag_id,run_id=self.run_id,task_id=task_id,map_index=map_index)
.one_or_none())
[docs]defget_dag(self)->DAG:""" Returns the Dag associated with this DagRun. :return: DAG """ifnotself.dag:raiseAirflowException(f"The DAG (.dag) for {self} needs to be set")returnself.dag
@provide_session
[docs]defget_previous_dagrun(self,state:DagRunState|None=None,session:Session=NEW_SESSION)->DagRun|None:"""The previous DagRun, if there is one"""filters=[DagRun.dag_id==self.dag_id,DagRun.execution_date<self.execution_date,]ifstateisnotNone:filters.append(DagRun.state==state)returnsession.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()
# --- DagRun.get_previous_scheduled_dagrun / DagRun.update_state ---
    @provide_session
    def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> DagRun | None:
        """
        Return the latest earlier run of this DAG that was not triggered manually.

        "Earlier" means a smaller execution_date. Any run_type other than MANUAL
        qualifies (not only scheduled runs). Returns None if there is none.
        """
        return (
            session.query(DagRun)
            .filter(
                DagRun.dag_id == self.dag_id,
                DagRun.execution_date < self.execution_date,
                DagRun.run_type != DagRunType.MANUAL,
            )
            .order_by(DagRun.execution_date.desc())
            .first()
        )

    @provide_session
    def update_state(
        self, session: Session = NEW_SESSION, execute_callbacks: bool = True
    ) -> tuple[list[TI], DagCallbackRequest | None]:
        """
        Determine the overall state of the DagRun based on the state of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :param execute_callbacks: if True, the DAG's success/failure callbacks are
            invoked directly via ``dag.handle_callback``. If False, a
            ``DagCallbackRequest`` is built instead (only when the DAG has the
            matching callback) and returned to the caller.
        :return: tuple of (TIs that can be scheduled in the current loop, callback
            request that still needs to be executed or None)
        """
        # Callback request handed back to the caller (failure *or* success) when
        # callbacks are not executed inline.
        callback: DagCallbackRequest | None = None

        class _UnfinishedStates(NamedTuple):
            tis: Sequence[TI]

            @classmethod
            def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
                return cls(tis=unfinished_tis)

            @property
            def should_schedule(self) -> bool:
                # True only if there are unfinished TIs and none of them uses
                # depends_on_past, a per-dag concurrency limit, or is DEFERRED.
                return (
                    bool(self.tis)
                    and all(not t.task.depends_on_past for t in self.tis)
                    and all(t.task.max_active_tis_per_dag is None for t in self.tis)
                    and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
                )

            def recalculate(self) -> _UnfinishedStates:
                return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm
        with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"):
            dag = self.get_dag()
            info = self.task_instance_scheduling_decisions(session)

            tis = info.tis
            schedulable_tis = info.schedulable_tis
            changed_tis = info.changed_tis
            finished_tis = info.finished_tis
            unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

            # NOTE: are_runnable_tasks is only bound inside this branch. The deadlock
            # check below reads it only after `unfinished.should_schedule` is True,
            # which (since `unfinished` is only recalculated in here) implies this
            # branch ran.
            if unfinished.should_schedule:
                are_runnable_tasks = schedulable_tis or changed_tis
                # small speed up
                if not are_runnable_tasks:
                    are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
                        unfinished.tis, finished_tis, session
                    )
                    if changed_by_upstream:  # Something changed, we need to recalculate!
                        unfinished = unfinished.recalculate()

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED]

        # if all leaves finished and at least one failed, the run failed
        if not unfinished.tis and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
            self.log.error('Marking run %s failed', self)
            self.set_state(DagRunState.FAILED)
            if execute_callbacks:
                dag.handle_callback(self, success=False, reason='task_failure', session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=dag_model.processor_subdir,
                    msg='task_failure',
                )

        # if all leaves succeeded and no unfinished tasks, the run succeeded
        elif not unfinished.tis and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
            self.log.info('Marking run %s successful', self)
            self.set_state(DagRunState.SUCCESS)
            if execute_callbacks:
                dag.handle_callback(self, success=True, reason='success', session=session)
            elif dag.has_on_success_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=False,
                    processor_subdir=dag_model.processor_subdir,
                    msg='success',
                )

        # if *all tasks* are deadlocked, the run failed
        elif unfinished.should_schedule and not are_runnable_tasks:
            self.log.error('Deadlock; marking run %s failed', self)
            self.set_state(DagRunState.FAILED)
            if execute_callbacks:
                dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=dag_model.processor_subdir,
                    msg='all_tasks_deadlocked',
                )

        # finally, if there are still unfinished tasks (or leaves that are neither
        # all successful nor failed), the dag is still running
        else:
            self.set_state(DagRunState.RUNNING)

        if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
            msg = (
                "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
                "run_start_date=%s, run_end_date=%s, run_duration=%s, "
                "state=%s, external_trigger=%s, run_type=%s, "
                "data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
            )
            self.log.info(
                msg,
                self.dag_id,
                self.execution_date,
                self.run_id,
                self.start_date,
                self.end_date,
                (self.end_date - self.start_date).total_seconds()
                if self.start_date and self.end_date
                else None,
                self._state,
                self.external_trigger,
                self.run_type,
                self.data_interval_start,
                self.data_interval_end,
                self.dag_hash,
            )
            session.flush()

        self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
        self._emit_duration_stats_for_finished_state()

        session.merge(self)
        # We do not flush here for performance reasons(It increases queries count by +20)

        return schedulable_tis, callback
@provide_session
[docs]deftask_instance_scheduling_decisions(self,session:Session=NEW_SESSION)->TISchedulingDecision:tis=self.get_task_instances(session=session,state=State.task_states)self.log.debug("number of tis tasks for %s: %s task(s)",self,len(tis))def_filter_tis_and_exclude_removed(dag:DAG,tis:list[TI])->Iterable[TI]:"""Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""fortiintis:try:ti.task=dag.get_task(ti.task_id)exceptTaskNotFound:ifti.state!=State.REMOVED:self.log.error("Failed to get task for ti %s. Marking it as removed.",ti)ti.state=State.REMOVEDsession.flush()else:yieldtitis=list(_filter_tis_and_exclude_removed(self.get_dag(),tis))unfinished_tis=[tfortintisift.stateinState.unfinished]finished_tis=[tfortintisift.stateinState.finished]ifunfinished_tis:schedulable_tis=[utforutinunfinished_tisifut.stateinSCHEDULEABLE_STATES]self.log.debug("number of scheduleable tasks for %s: %s task(s)",self,len(schedulable_tis))schedulable_tis,changed_tis,expansion_happened=self._get_ready_tis(schedulable_tis,finished_tis,session=session,)# During expansion we may change some tis into non-schedulable# states, so we need to re-compute.ifexpansion_happened:new_unfinished_tis=[tfortinunfinished_tisift.stateinState.unfinished]finished_tis.extend(tfortinunfinished_tisift.stateinState.finished)unfinished_tis=new_unfinished_tiselse:schedulable_tis=[]changed_tis=FalsereturnTISchedulingDecision(tis=tis,schedulable_tis=schedulable_tis,changed_tis=changed_tis,unfinished_tis=unfinished_tis,finished_tis=finished_tis,
)def_get_ready_tis(self,schedulable_tis:list[TI],finished_tis:list[TI],session:Session,)->tuple[list[TI],bool,bool]:old_states={}ready_tis:list[TI]=[]changed_tis=Falseifnotschedulable_tis:returnready_tis,changed_tis,False# If we expand TIs, we need a new list so that we iterate over them too. (We can't alter# `schedulable_tis` in place and have the `for` loop pick them upadditional_tis:list[TI]=[]dep_context=DepContext(flag_upstream_failed=True,ignore_unmapped_tasks=True,# Ignore this Dep, as we will expand it if we can.finished_tis=finished_tis,)# Check dependencies.expansion_happened=Falseforschedulableinitertools.chain(schedulable_tis,additional_tis):old_state=schedulable.stateifnotschedulable.are_dependencies_met(session=session,dep_context=dep_context):old_states[schedulable.key]=old_statecontinue# If schedulable is from a mapped task, but not yet expanded, do it# now. This is called in two places: First and ideally in the mini# scheduler at the end of LocalTaskJob, and then as an "expansion of# last resort" in the scheduler to ensure that the mapped task is# correctly expanded before executed.ifschedulable.map_index<0andisinstance(schedulable.task,MappedOperator):expanded_tis,_=schedulable.task.expand_mapped_task(self.run_id,session=session)ifexpanded_tis:assertexpanded_tis[0]isschedulableadditional_tis.extend(expanded_tis[1:])expansion_happened=Trueifschedulable.stateinSCHEDULEABLE_STATES:task=schedulable.taskifisinstance(schedulable.task,MappedOperator):# Ensure the task indexes are completecreated=self._revise_mapped_task_indexes(task,session=session)ready_tis.extend(created)ready_tis.append(schedulable)# Check if any ti changed 
statetis_filter=TI.filter_for_tis(old_states)iftis_filterisnotNone:fresh_tis=session.query(TI).filter(tis_filter).all()changed_tis=any(ti.state!=old_states[ti.key]fortiinfresh_tis)returnready_tis,changed_tis,expansion_happeneddef_are_premature_tis(self,unfinished_tis:Sequence[TI],finished_tis:list[TI],session:Session,)->tuple[bool,bool]:dep_context=DepContext(flag_upstream_failed=True,ignore_in_retry_period=True,ignore_in_reschedule_period=True,finished_tis=finished_tis,)# there might be runnable tasks that are up for retry and for some reason(retry delay, etc) are# not ready yet so we set the flags to count them inreturn(any(ut.are_dependencies_met(dep_context=dep_context,session=session)forutinunfinished_tis),dep_context.have_changed_ti_states,)def_emit_true_scheduling_delay_stats_for_finished_state(self,finished_tis:list[TI])->None:""" This is a helper method to emit the true scheduling delay stats, which is defined as the time when the first task in DAG starts minus the expected DAG run datetime. This method will be used in the update_state method when the state of the DagRun is updated to a completed status (either success or failure). The method will find the first started task within the DAG and calculate the expected DagRun start time (based on dag.execution_date & dag.timetable), and minus these two values to get the delay. The emitted data may contains outlier (e.g. when the first task was cleared, so the second task's start_date will be used), but we can get rid of the outliers on the stats side through the dashboards tooling built. Note, the stat will only be emitted if the DagRun is a scheduler triggered one (i.e. external_trigger is False). 
"""ifself.state==State.RUNNING:returnifself.external_trigger:returnifnotfinished_tis:returntry:dag=self.get_dag()ifnotdag.timetable.periodic:# We can't emit this metric if there is no following schedule to calculate from!returnordered_tis_by_start_date=[tifortiinfinished_tisifti.start_date]ordered_tis_by_start_date.sort(key=lambdati:ti.start_date,reverse=False)first_start_date=ordered_tis_by_start_date[0].start_dateiffirst_start_date:# TODO: Logically, this should be DagRunInfo.run_after, but the# information is not stored on a DagRun, only before the actual# execution on DagModel.next_dagrun_create_after. We should add# a field on DagRun for this instead of relying on the run# always happening immediately after the data interval.data_interval_end=dag.get_run_data_interval(self).endtrue_delay=first_start_date-data_interval_endiftrue_delay.total_seconds()>0:Stats.timing(f'dagrun.{dag.dag_id}.first_task_scheduling_delay',true_delay)exceptException:self.log.warning('Failed to record first_task_scheduling_delay metric:',exc_info=True)def_emit_duration_stats_for_finished_state(self):ifself.state==State.RUNNING:returnifself.start_dateisNone:self.log.warning('Failed to record duration of %s: start_date is not set.',self)returnifself.end_dateisNone:self.log.warning('Failed to record duration of %s: end_date is not set.',self)returnduration=self.end_date-self.start_dateifself.state==State.SUCCESS:Stats.timing(f'dagrun.duration.success.{self.dag_id}',duration)elifself.state==State.FAILED:Stats.timing(f'dagrun.duration.failed.{self.dag_id}',duration)@provide_session
[docs]defverify_integrity(self,*,session:Session=NEW_SESSION,):""" Verifies the DagRun by checking for removed tasks or tasks that are not in the database yet. It will set state to removed or add the task if required. :missing_indexes: A dictionary of task vs indexes that are missing. :param session: Sqlalchemy ORM Session """fromairflow.settingsimporttask_instance_mutation_hook# Set for the empty default in airflow.settings -- if it's not set this means it has been changed# Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload.hook_is_noop:Literal[True,False]=getattr(task_instance_mutation_hook,'is_noop',False)dag=self.get_dag()task_ids:set[str]=set()task_ids=self._check_for_removed_or_restored_tasks(dag,task_instance_mutation_hook,session=session)deftask_filter(task:Operator)->bool:returntask.task_idnotintask_idsand(self.is_backfillortask.start_date<=self.execution_dateand(task.end_dateisNoneorself.execution_date<=task.end_date))created_counts:dict[str,int]=defaultdict(int)# Get task creator functiontask_creator=self._get_task_creator(created_counts,task_instance_mutation_hook,hook_is_noop)# Create the missing tasks, including mapped taskstasks=self._create_tasks(dag,task_creator,task_filter,session=session)self._create_task_instances(dag.dag_id,tasks,created_counts,hook_is_noop,session=session)
def_check_for_removed_or_restored_tasks(self,dag:DAG,ti_mutation_hook,*,session:Session)->set[str]:""" Check for removed tasks/restored/missing tasks. :param dag: DAG object corresponding to the dagrun :param ti_mutation_hook: task_instance_mutation_hook function :param session: Sqlalchemy ORM Session :return: Task IDs in the DAG run """tis=self.get_task_instances(session=session)# check for removed or restored taskstask_ids=set()fortiintis:ti_mutation_hook(ti)task_ids.add(ti.task_id)task=Nonetry:task=dag.get_task(ti.task_id)should_restore_task=(taskisnotNone)andti.state==State.REMOVEDifshould_restore_task:self.log.info("Restoring task '%s' which was previously removed from DAG '%s'",ti,dag)Stats.incr(f"task_restored_to_dag.{dag.dag_id}",1,1)ti.state=State.NONEexceptAirflowException:ifti.state==State.REMOVED:pass# ti has already been removed, just ignore itelifself.state!=State.RUNNINGandnotdag.partial:self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.",ti,dag)Stats.incr(f"task_removed_from_dag.{dag.dag_id}",1,1)ti.state=State.REMOVEDcontinueifnottask.is_mapped:continuetask=cast("MappedOperator",task)num_mapped_tis=task.parse_time_mapped_ti_count# Check if the number of mapped literals has changed and we need to mark this TI as removedifnum_mapped_tisisnotNone:ifti.map_index>=num_mapped_tis:self.log.debug("Removing task '%s' as the map_index is longer than the literal mapping list (%s)",ti,num_mapped_tis,)ti.state=State.REMOVEDelifti.map_index<0:self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed",ti)ti.state=State.REMOVEDelse:# What if it is _now_ dynamically mapped, but wasn't before?task.run_time_mapped_ti_count.cache_clear()# type: ignore[attr-defined]total_length=task.run_time_mapped_ti_count(self.run_id,session=session)iftotal_lengthisNone:# Not all upstreams finished, so we can't tell what should be here. 
Remove everything.ifti.map_index>=0:self.log.debug("Removing the unmapped TI '%s' as the mapping can't be resolved yet",ti)ti.state=State.REMOVEDcontinue# Upstreams finished, check there aren't any extrasifti.map_index>=total_length:self.log.debug("Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",ti,total_length,)ti.state=State.REMOVEDreturntask_ids@overloaddef_get_task_creator(self,created_counts:dict[str,int],ti_mutation_hook:Callable,hook_is_noop:Literal[True],)->Callable[[Operator,tuple[int,...]],Iterator[dict[str,Any]]]:...@overloaddef_get_task_creator(self,created_counts:dict[str,int],ti_mutation_hook:Callable,hook_is_noop:Literal[False],)->Callable[[Operator,tuple[int,...]],Iterator[TI]]:...def_get_task_creator(self,created_counts:dict[str,int],ti_mutation_hook:Callable,hook_is_noop:Literal[True,False],)->Callable[[Operator,tuple[int,...]],Iterator[dict[str,Any]]|Iterator[TI]]:""" Get the task creator function. This function also updates the created_counts dictionary with the number of tasks created. 
:param created_counts: Dictionary of task_type -> count of created TIs :param ti_mutation_hook: task_instance_mutation_hook function :param hook_is_noop: Whether the task_instance_mutation_hook is a noop """ifhook_is_noop:defcreate_ti_mapping(task:Operator,indexes:tuple[int,...])->Iterator[dict[str,Any]]:created_counts[task.task_type]+=1formap_indexinindexes:yieldTI.insert_mapping(self.run_id,task,map_index=map_index)creator=create_ti_mappingelse:defcreate_ti(task:Operator,indexes:tuple[int,...])->Iterator[TI]:formap_indexinindexes:ti=TI(task,run_id=self.run_id,map_index=map_index)ti_mutation_hook(ti)created_counts[ti.operator]+=1yieldticreator=create_tireturncreatordef_create_tasks(self,dag:DAG,task_creator:Callable[[Operator,tuple[int,...]],CreatedTasksType],task_filter:Callable[[Operator],bool],*,session:Session,)->CreatedTasksType:""" Create missing tasks -- and expand any MappedOperator that _only_ have literals as input :param dag: DAG object corresponding to the dagrun :param task_creator: a function that creates tasks :param task_filter: a function that filters tasks to create :param session: the session to use """defexpand_mapped_literals(task:Operator,sequence:Sequence[int]|None=None)->tuple[Operator,Sequence[int]]:ifnottask.is_mapped:return(task,(-1,))task=cast("MappedOperator",task)count=task.parse_time_mapped_ti_countortask.run_time_mapped_ti_count(self.run_id,session=session)ifnotcount:return(task,(-1,))ifsequence:return(task,sequence)return(task,range(count))tasks_and_map_idxs=map(expand_mapped_literals,filter(task_filter,dag.task_dict.values()))tasks:CreatedTasksType=itertools.chain.from_iterable(# type: ignoreitertools.starmap(task_creator,tasks_and_map_idxs)# type: ignore)returntasksdef_create_task_instances(self,dag_id:str,tasks:Iterator[dict[str,Any]]|Iterator[TI],created_counts:dict[str,int],hook_is_noop:bool,*,session:Session,)->None:""" Create the necessary task instances from the given tasks. 
:param dag_id: DAG ID associated with the dagrun :param tasks: the tasks to create the task instances from :param created_counts: a dictionary of number of tasks -> total ti created by the task creator :param hook_is_noop: whether the task_instance_mutation_hook is noop :param session: the session to use """# Fetch the information we need before handling the exception to avoid# PendingRollbackError due to the session being invalidated on exception# see https://github.com/apache/superset/pull/530run_id=self.run_idtry:ifhook_is_noop:session.bulk_insert_mappings(TI,tasks)else:session.bulk_save_objects(tasks)fortask_type,countincreated_counts.items():Stats.incr(f"task_instance_created-{task_type}",count)session.flush()exceptIntegrityError:self.log.info('Hit IntegrityError while creating the TIs for %s- %s',dag_id,run_id,exc_info=True,)self.log.info('Doing session rollback.')# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.session.rollback()def_revise_mapped_task_indexes(self,task,session:Session):"""Check if task increased or reduced in length and handle appropriately"""fromairflow.models.taskinstanceimportTaskInstancefromairflow.settingsimporttask_instance_mutation_hooktask.run_time_mapped_ti_count.cache_clear()total_length=(task.parse_time_mapped_ti_countortask.run_time_mapped_ti_count(self.run_id,session=session)or0)query=session.query(TaskInstance.map_index).filter(TaskInstance.dag_id==self.dag_id,TaskInstance.task_id==task.task_id,TaskInstance.run_id==self.run_id,)existing_indexes={ifor(i,)inquery}missing_indexes=set(range(total_length)).difference(existing_indexes)removed_indexes=existing_indexes.difference(range(total_length))created_tis=[]ifmissing_indexes:forindexinmissing_indexes:ti=TaskInstance(task,run_id=self.run_id,map_index=index,state=None)self.log.debug("Expanding TIs upserted 
%s",ti)task_instance_mutation_hook(ti)ti=session.merge(ti)ti.refresh_from_task(task)session.flush()created_tis.append(ti)elifremoved_indexes:session.query(TaskInstance).filter(TaskInstance.dag_id==self.dag_id,TaskInstance.task_id==task.task_id,TaskInstance.run_id==self.run_id,TaskInstance.map_index.in_(removed_indexes),).update({TaskInstance.state:TaskInstanceState.REMOVED})session.flush()returncreated_tis@staticmethod
[docs]defget_run(session:Session,dag_id:str,execution_date:datetime)->DagRun|None:""" Get a single DAG Run :meta private: :param session: Sqlalchemy ORM Session :param dag_id: DAG ID :param execution_date: execution date :return: DagRun corresponding to the given dag_id and execution date if one exists. None otherwise. :rtype: airflow.models.DagRun """warnings.warn("This method is deprecated. Please use SQLAlchemy directly",RemovedInAirflow3Warning,stacklevel=2,)return(session.query(DagRun).filter(DagRun.dag_id==dag_id,DagRun.external_trigger==False,# noqaDagRun.execution_date==execution_date,
[docs]defget_latest_runs(cls,session=None)->list[DagRun]:"""Returns the latest DagRun for each DAG"""subquery=(session.query(cls.dag_id,func.max(cls.execution_date).label('execution_date')).group_by(cls.dag_id).subquery())return(session.query(cls).join(subquery,and_(cls.dag_id==subquery.c.dag_id,cls.execution_date==subquery.c.execution_date),
).all())@provide_session
[docs]defschedule_tis(self,schedulable_tis:Iterable[TI],session:Session=NEW_SESSION)->int:""" Set the given task instances in to the scheduled state. Each element of ``schedulable_tis`` should have it's ``task`` attribute already set. Any EmptyOperator without callbacks or outlets is instead set straight to the success state. All the TIs should belong to this DagRun, but this code is in the hot-path, this is not checked -- it is the caller's responsibility to call this function only with TIs from a single dag run. """# Get list of TI IDs that do not need to executed, these are# tasks using EmptyOperator and without on_execute_callback / on_success_callbackdummy_ti_ids=[]schedulable_ti_ids=[]fortiinschedulable_tis:if(ti.task.inherits_from_empty_operatorandnotti.task.on_execute_callbackandnotti.task.on_success_callbackandnotti.task.outlets):dummy_ti_ids.append(ti.task_id)else:schedulable_ti_ids.append((ti.task_id,ti.map_index))count=0ifschedulable_ti_ids:count+=(session.query(TI).filter(TI.dag_id==self.dag_id,TI.run_id==self.run_id,tuple_in_condition((TI.task_id,TI.map_index),schedulable_ti_ids),).update({TI.state:State.SCHEDULED},synchronize_session=False))# Tasks using EmptyOperator should not be executed, mark them as successifdummy_ti_ids:count+=(session.query(TI).filter(TI.dag_id==self.dag_id,TI.run_id==self.run_id,TI.task_id.in_(dummy_ti_ids),).update({TI.state:State.SUCCESS,TI.start_date:timezone.utcnow(),TI.end_date:timezone.utcnow(),TI.duration:0,},synchronize_session=False,))returncount
@provide_session
[docs]defget_log_template(self,*,session:Session=NEW_SESSION)->LogTemplate:ifself.log_template_idisNone:# DagRun created before LogTemplate introduction.template=session.query(LogTemplate).order_by(LogTemplate.id).first()else:template=session.query(LogTemplate).get(self.log_template_id)iftemplateisNone:raiseAirflowException(f"No log_template entry found for ID {self.log_template_id!r}. "f"Please make sure you set up the metadatabase correctly.")returntemplate
@provide_session
[docs]defget_log_filename_template(self,*,session:Session=NEW_SESSION)->str:warnings.warn("This method is deprecated. Please use get_log_template instead.",RemovedInAirflow3Warning,stacklevel=2,)returnself.get_log_template(session=session).filename