## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportcontextlibimporthashlibimportitertoolsimportloggingimportmathimportoperatorimportosimportsignalimporttracebackfromcollectionsimportdefaultdictfromcollections.abcimportCollection,Generator,Iterable,Mapping,SequencefromdatetimeimporttimedeltafromenumimportEnumfromfunctoolsimportcachefrompathlibimportPathfromtypingimportTYPE_CHECKING,Any,Callablefromurllib.parseimportquoteimportattrsimportdillimportjinja2importlazy_object_proxyimportuuid6fromjinja2importTemplateAssertionError,UndefinedErrorfromsqlalchemyimport(Column,Float,ForeignKey,ForeignKeyConstraint,Index,Integer,PrimaryKeyConstraint,String,Text,UniqueConstraint,and_,case,delete,extract,false,func,inspect,or_,select,text,tuple_,update,)fromsqlalchemy.dialectsimportpostgresqlfromsqlalchemy.ext.associationproxyimportassociation_proxyfromsqlalchemy.ext.hybridimporthybrid_propertyfromsqlalchemy.ext.mutableimportMutableDictfromsqlalchemy.ormimportlazyload,reconstructor,relationshipfromsqlalchemy.orm.attributesimportNO_VALUE,set_committed_valuefromsqlalchemy_utilsimportUUIDTypefromairflowimportsettingsfromairflow.assets.managerimportasset_managerfromairflow.configurationimportconffromairflow.exceptionsimport(AirflowException,AirflowFa
ilException,AirflowInactiveAssetInInletOrOutletException,AirflowRescheduleException,AirflowSensorTimeout,AirflowSkipException,AirflowTaskTerminated,AirflowTaskTimeout,TaskDeferralError,TaskDeferred,UnmappableXComLengthPushed,UnmappableXComTypePushed,XComForMappingNotPushed,)fromairflow.listeners.listenerimportget_listener_managerfromairflow.models.assetimportAssetActive,AssetEvent,AssetModelfromairflow.models.baseimportBase,StringID,TaskInstanceDependenciesfromairflow.models.dagbagimportDagBagfromairflow.models.logimportLogfromairflow.models.renderedtifieldsimportget_serialized_template_fieldsfromairflow.models.taskinstancekeyimportTaskInstanceKeyfromairflow.models.taskmapimportTaskMapfromairflow.models.taskrescheduleimportTaskReschedulefromairflow.models.xcomimportLazyXComSelectSequence,XComModelfromairflow.plugins_managerimportintegrate_macros_pluginsfromairflow.sdk.execution_time.contextimportcontext_to_airflow_varsfromairflow.sentryimportSentryfromairflow.settingsimporttask_instance_mutation_hookfromairflow.statsimportStatsfromairflow.ti_deps.dep_contextimportDepContextfromairflow.ti_deps.dependencies_depsimportREQUEUEABLE_DEPS,RUNNING_DEPSfromairflow.utilsimporttimezonefromairflow.utils.emailimportsend_emailfromairflow.utils.helpersimportprune_dict,render_template_to_stringfromairflow.utils.log.logging_mixinimportLoggingMixinfromairflow.utils.netimportget_hostnamefromairflow.utils.platformimportgetuserfromairflow.utils.retriesimportrun_with_db_retriesfromairflow.utils.sessionimportNEW_SESSION,create_session,provide_sessionfromairflow.utils.span_statusimportSpanStatusfromairflow.utils.sqlalchemyimportExecutorConfigType,ExtendedJSON,UtcDateTimefromairflow.utils.stateimportDagRunState,State,TaskInstanceStatefromairflow.utils.task_instance_sessionimportset_current_task_instance_sessionfromairflow.utils.timeoutimporttimeoutfromairflow.utils.xcomimportXCOM_RETURN_KEY
class TaskReturnCode(Enum):
    """
    Exit codes used by the task run command to signal how the run ended.

    :meta private:
    """

    DEFERRED = 100
    """When task exits with deferral to trigger."""


@provide_session
def _add_log(
    event,
    task_instance=None,
    owner=None,
    owner_display_name=None,
    extra=None,
    session: Session = NEW_SESSION,
    **kwargs,
):
    """
    Build a ``Log`` entry from the given fields and add it to the session.

    The positional values and any extra keyword arguments are forwarded to
    ``Log`` unchanged.
    """
    entry = Log(
        event,
        task_instance,
        owner,
        owner_display_name,
        extra,
        **kwargs,
    )
    session.add(entry)


@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
    """
    Push ``context`` as the current execution context for the duration of the block.

    This method should be called once per Task execution, before calling operator.execute.
    On exit the top of the context stack is popped; a warning is logged when
    the popped entry is not the context that was pushed here.
    """
    from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT

    _CURRENT_CONTEXT.append(context)
    try:
        yield context
    finally:
        popped = _CURRENT_CONTEXT.pop()
        if popped != context:
            log.warning(
                "Current context is not equal to the state at context stack. Expected=%s, got=%s",
                context,
                popped,
            )
def _stop_remaining_tasks(*, task_instance: TaskInstance, task_teardown_map=None, session: Session):
    """
    Stop non-teardown tasks in dag.

    Every other task instance of the same dag run that is not already
    SUCCESS or FAILED is either forced to fail (when RUNNING) or set to
    SKIPPED. Teardown tasks are left alone.

    :param task_instance: the task instance whose dag run is being stopped;
        it must have ``dag_run`` set
    :param task_teardown_map: optional mapping of task_id to a teardown flag;
        when empty or None the flag is read from the task in the dag
    :param session: SQLAlchemy ORM Session
    :raises ValueError: if ``task_instance.dag_run`` is not set

    :meta private:
    """
    if not task_instance.dag_run:
        raise ValueError("``task_instance`` must have ``dag_run`` set")
    run_tis = task_instance.dag_run.get_task_instances(session=session)
    if TYPE_CHECKING:
        assert task_instance.task
        assert isinstance(task_instance.task.dag, DAG)

    finished_states = (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED)
    for ti in run_tis:
        # Leave the triggering task and anything already finished untouched.
        if ti.task_id == task_instance.task_id or ti.state in finished_states:
            continue

        if task_teardown_map:
            is_teardown = task_teardown_map[ti.task_id]
        else:
            is_teardown = task_instance.task.dag.task_dict[ti.task_id].is_teardown

        if is_teardown:
            log.info("Not skipping teardown task '%s'", ti.task_id)
            continue

        if ti.state == TaskInstanceState.RUNNING:
            log.info("Forcing task %s to fail due to dag's `fail_fast` setting", ti.task_id)
            msg = "Forcing task to fail due to dag's `fail_fast` setting."
            session.add(Log(event="fail task", extra=msg, task_instance=ti.key))
            ti.error(session)
        else:
            log.info("Setting task %s to SKIPPED due to dag's `fail_fast` setting.", ti.task_id)
            msg = "Skipping task due to dag's `fail_fast` setting."
            session.add(Log(event="skip task", extra=msg, task_instance=ti.key))
            ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
def clear_task_instances(
    tis: list[TaskInstance],
    session: Session,
    dag: DAG | None = None,
    dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
) -> None:
    """
    Clear a set of task instances, but make sure the running ones get killed.

    Finished DagRuns (per ``State.finished_dr_states``) owning the cleared
    task instances are moved to ``dag_run_state`` and get a fresh
    ``start_date``. When the target state is QUEUED, ``start_date`` and
    ``last_scheduling_decision`` are reset to None and ``clear_number`` is
    incremented. DagRuns that are not finished are left as they are.

    :param tis: a list of task instances
    :param session: current session
    :param dag: DAG object; used for a task instance only when its dag_id
        matches, otherwise the dag is looked up through a ``DagBag``
    :param dag_run_state: state to set finished DagRuns to. If set to False,
        DagRuns state will not be changed.
    """
    # taskinstance uuids:
    task_instance_ids: list[str] = []
    dag_bag = DagBag(read_dags_from_db=True)

    for ti in tis:
        task_instance_ids.append(ti.id)
        ti.prepare_db_for_next_try(session)

        if ti.state == TaskInstanceState.RUNNING:
            # If a task is cleared when running, set its state to RESTARTING so that
            # the task is terminated and becomes eligible for retry.
            ti.state = TaskInstanceState.RESTARTING
            continue

        if dag and dag.dag_id == ti.dag_id:
            ti_dag = dag
        else:
            ti_dag = dag_bag.get_dag(ti.dag_id, session=session)

        task_id = ti.task_id
        if ti_dag and ti_dag.has_task(task_id):
            task = ti_dag.get_task(task_id)
            ti.refresh_from_task(task)
            if TYPE_CHECKING:
                assert ti.task
            ti.max_tries = ti.try_number + task.retries
        else:
            # Ignore errors when updating max_tries if the DAG or
            # task are not found since database records could be
            # outdated. We make max_tries the maximum value of its
            # original max_tries or the last attempted try number.
            ti.max_tries = max(ti.max_tries, ti.try_number)

        ti.state = None
        ti.external_executor_id = None
        ti.clear_next_method_args()
        session.merge(ti)

    if dag_run_state is not False and tis:
        from airflow.models.dagrun import DagRun  # Avoid circular import

        run_ids_by_dag_id = defaultdict(set)
        for instance in tis:
            run_ids_by_dag_id[instance.dag_id].add(instance.run_id)

        dag_run_filter = or_(
            and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
            for dag_id, run_ids in run_ids_by_dag_id.items()
        )
        drs = session.query(DagRun).filter(dag_run_filter).all()

        dag_run_state = DagRunState(dag_run_state)  # Validate the state value.
        reset_to_queued = dag_run_state == DagRunState.QUEUED
        for dr in drs:
            if dr.state not in State.finished_dr_states:
                continue
            dr.state = dag_run_state
            dr.start_date = timezone.utcnow()
            if reset_to_queued:
                dr.last_scheduling_decision = None
                dr.start_date = None
                dr.clear_number += 1

    session.flush()
def _creator_note(val):
    """
    Create a ``TaskInstanceNote`` for the ``note`` association proxy.

    A string becomes the note content, a dict is expanded as keyword
    arguments, and anything else is expanded as positional arguments.
    """
    if isinstance(val, str):
        return TaskInstanceNote(content=val)
    if isinstance(val, dict):
        return TaskInstanceNote(**val)
    return TaskInstanceNote(*val)


@provide_session
def _record_task_map_for_downstreams(
    *,
    task_instance: TaskInstance,
    task: Operator,
    value: Any,
    session: Session,
) -> None:
    """
    Record the task map for downstream tasks.

    Nothing is recorded when the task has no mapped dependants or is itself
    a ``MappedOperator``.

    :param task_instance: the task instance
    :param task: The task object
    :param value: The value
    :param session: SQLAlchemy ORM Session
    :raises XComForMappingNotPushed: if ``value`` is None
    :raises UnmappableXComTypePushed: if ``value`` is not a mappable value
    :raises UnmappableXComLengthPushed: if the map is longer than
        ``[core] max_map_length``

    :meta private:
    """
    from airflow.sdk.definitions.mappedoperator import MappedOperator, is_mappable_value

    if next(task.iter_mapped_dependants(), None) is None:  # No mapped dependants, no need to validate.
        return
    # TODO: We don't push TaskMap for mapped task instances because it's not
    # currently possible for a downstream to depend on one individual mapped
    # task instance. This will change when we implement task mapping inside
    # a mapped task group, and we'll need to further analyze the case.
    if isinstance(task, MappedOperator):
        return
    if value is None:
        raise XComForMappingNotPushed()
    if not is_mappable_value(value):
        raise UnmappableXComTypePushed(value)
    task_map = TaskMap.from_task_instance_xcom(task_instance, value)
    max_map_length = conf.getint("core", "max_map_length", fallback=1024)
    if task_map.length > max_map_length:
        raise UnmappableXComLengthPushed(value, max_map_length)
    session.merge(task_map)


def _get_email_subject_content(
    *,
    task_instance: TaskInstance | RuntimeTaskInstanceProtocol,
    exception: BaseException,
    task: BaseOperator | None = None,
) -> tuple[str, str, str]:
    """
    Get the email subject content for exceptions.

    :param task_instance: the task instance
    :param exception: the exception sent in the email
    :param task: the task to report on; when None it is read from
        ``task_instance.task``, and if that is also missing the built-in
        default templates are rendered
    :return: a ``(subject, html_content, html_content_err)`` tuple

    :meta private:
    """
    # For a ti from DB (without ti.task), return the default value.
    # The explicit ``None`` default keeps a ti lacking the attribute on the
    # default-template path instead of raising AttributeError.
    if task is None:
        task = getattr(task_instance, "task", None)
    use_default = task is None
    exception_html = str(exception).replace("\n", "<br>")

    default_subject = "Airflow alert: {{ti}}"
    # For reporting purposes, we report based on 1-indexed,
    # not 0-indexed lists (i.e. Try 1 instead of
    # Try 0 for the first attempt).
    default_html_content = (
        "Try {{try_number}} out of {{max_tries + 1}}<br>"
        "Exception:<br>{{exception_html}}<br>"
        'Log: <a href="{{ti.log_url}}">Link</a><br>'
        "Host: {{ti.hostname}}<br>"
        'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
    )

    default_html_content_err = (
        "Try {{try_number}} out of {{max_tries + 1}}<br>"
        "Exception:<br>Failed attempt to attach error logs<br>"
        'Log: <a href="{{ti.log_url}}">Link</a><br>'
        "Host: {{ti.hostname}}<br>"
        'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
    )

    additional_context: dict[str, Any] = {
        "exception": exception,
        "exception_html": exception_html,
        "try_number": task_instance.try_number,
        "max_tries": task_instance.max_tries,
    }

    if use_default:
        default_context = {"ti": task_instance, **additional_context}
        jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
        )
        subject = jinja_env.from_string(default_subject).render(**default_context)
        html_content = jinja_env.from_string(default_html_content).render(**default_context)
        html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)

    else:
        from airflow.sdk.definitions._internal.templater import SandboxedEnvironment
        from airflow.utils.context import context_merge

        if TYPE_CHECKING:
            assert task_instance.task

        # Use the DAG's get_template_env() to set force_sandboxed. Don't add
        # the flag to the function on task object -- that function can be
        # overridden, and adding a flag breaks backward compatibility.
        dag = task_instance.task.get_dag()
        if dag:
            jinja_env = dag.get_template_env(force_sandboxed=True)
        else:
            jinja_env = SandboxedEnvironment(cache_size=0)
        jinja_context = task_instance.get_template_context()
        context_merge(jinja_context, additional_context)

        def render(key: str, content: str) -> str:
            # A template file configured under [email] replaces the default
            # content; if it cannot be read the default is rendered instead.
            if conf.has_option("email", key):
                path = conf.get_mandatory_value("email", key)
                try:
                    with open(path) as f:
                        content = f.read()
                except FileNotFoundError:
                    log.warning("Could not find email template file '%s'. Using defaults...", path)
                except OSError:
                    log.exception("Error while using email template %s. Using defaults...", path)
            return render_template_to_string(jinja_env.from_string(content), jinja_context)

        subject = render("subject_template", default_subject)
        html_content = render("html_content_template", default_html_content)
        html_content_err = render("html_content_template", default_html_content_err)

    return subject, html_content, html_content_err


def _run_finished_callback(
    *,
    callbacks: None | TaskStateChangeCallback | Sequence[TaskStateChangeCallback],
    context: Context,
) -> None:
    """
    Run callback after task finishes.

    Each callback is called with ``context``. An exception raised by a
    callback is logged and does not stop the remaining callbacks.

    :param callbacks: callbacks to run
    :param context: callbacks context

    :meta private:
    """
    if callbacks:
        callbacks = callbacks if isinstance(callbacks, Sequence) else [callbacks]

        def get_callback_representation(callback: TaskStateChangeCallback) -> Any:
            # Prefer the function name, then the class name, then the object itself.
            with contextlib.suppress(AttributeError):
                return callback.__name__
            with contextlib.suppress(AttributeError):
                return callback.__class__.__name__
            return callback

        for idx, callback in enumerate(callbacks):
            callback_repr = get_callback_representation(callback)
            log.info("Executing callback at index %d: %s", idx, callback_repr)
            try:
                callback(context)
            except Exception:
                log.exception("Error in callback at index %d: %s", idx, callback_repr)


def _log_state(*, task_instance: TaskInstance, lead_msg: str = "") -> None:
    """
    Log task state.

    ``map_index`` is included in the message only when it is non-negative.

    :param task_instance: the task instance
    :param lead_msg: lead message

    :meta private:
    """
    params = [
        lead_msg,
        str(task_instance.state).upper(),
        task_instance.dag_id,
        task_instance.task_id,
        task_instance.run_id,
    ]
    message = "%sMarking task as %s. dag_id=%s, task_id=%s, run_id=%s, "
    if task_instance.map_index >= 0:
        params.append(task_instance.map_index)
        message += "map_index=%d, "
    message += "logical_date=%s, start_date=%s, end_date=%s"
    log.info(
        message,
        *params,
        _date_or_empty(task_instance=task_instance, attr="logical_date"),
        _date_or_empty(task_instance=task_instance, attr="start_date"),
        _date_or_empty(task_instance=task_instance, attr="end_date"),
        stacklevel=2,
    )


def _date_or_empty(*, task_instance: TaskInstance, attr: str) -> str:
    """
    Format a date attribute, or return an empty string if it is missing or unset.

    :param task_instance: the task instance
    :param attr: the attribute name
    :return: the value formatted as ``%Y%m%dT%H%M%S``, or ``""``

    :meta private:
    """
    result: datetime | None = getattr(task_instance, attr, None)
    return result.strftime("%Y%m%dT%H%M%S") if result else ""
def uuid7() -> str:
    """Return a freshly generated UUID7 rendered as a string."""
    generated = uuid6.uuid7()
    return str(generated)
[docs]classTaskInstance(Base,LoggingMixin):""" Task instances store the state of a task instance. This table is the authority and single source of truth around what tasks have run and the state they are in. The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or dag model deliberately to have more control over transactions. Database transactions on this table should insure double triggers and any confusion around what task instances are or aren't ready to run even while multiple schedulers may be firing task instances. A value of -1 in map_index represents any of: a TI without mapped tasks; a TI with mapped tasks that has yet to be expanded (state=pending); a TI with mapped tasks that expanded to an empty list (state=skipped). """
raw: bool | None = None
"""Indicate to FileTaskHandler that logging context should be set up for trigger logging.

:meta private:
"""

_logger_name = "airflow.task"


def __init__(
    self,
    task: Operator,
    run_id: str | None = None,
    state: str | None = None,
    map_index: int = -1,
    dag_version_id: UUIDType | None = None,
):
    """
    Create a task instance for ``task``.

    :param task: the operator this instance runs; its dag_id, task_id and
        retries are copied onto the instance
    :param run_id: the run_id to store on the instance
    :param state: initial state; only assigned when truthy
    :param map_index: the map index to store (defaults to -1)
    :param dag_version_id: the dag version id to store
    """
    super().__init__()

    # Identity fields first; refresh_from_task() below is called with these set.
    self.dag_id = task.dag_id
    self.task_id = task.task_id
    self.map_index = map_index
    self.dag_version_id = dag_version_id
    self.refresh_from_task(task)
    if TYPE_CHECKING:
        assert self.task

    # init_on_load will config the log
    self.init_on_load()

    self.run_id = run_id
    self.try_number = 0
    self.max_tries = self.task.retries
    if not self.id:
        self.id = uuid7()
    self.unixname = getuser()
    if state:
        self.state = state
    self.hostname = ""

    # Is this TaskInstance being currently running within `airflow tasks run --raw`.
    # Not persisted to the database so only valid for the current process
    self.raw = False

    # can be changed when calling 'run'
    self.test_mode = False
    self.context_carrier = {}
def init_on_load(self) -> None:
    """
    Initialize the attributes that aren't stored in the DB.

    Currently this only resets ``test_mode`` to False.
    """
    # NOTE(review): the name suggests an ORM load hook — confirm how this is registered.
    # can be changed when calling 'run'
    self.test_mode = False
@property
def operator_name(self) -> str | None:
    """@property: return ``custom_operator_name`` when it is set, otherwise ``operator``."""
    friendly_name = self.custom_operator_name
    if friendly_name:
        return friendly_name
    return self.operator
@staticmethod
def _command_as_list(
    ti: TaskInstance,
    mark_success: bool = False,
    ignore_all_deps: bool = False,
    ignore_task_deps: bool = False,
    ignore_depends_on_past: bool = False,
    wait_for_past_depends_before_skipping: bool = False,
    ignore_ti_state: bool = False,
    local: bool = False,
    raw: bool = False,
    pool: str | None = None,
    cfg_path: str | None = None,
) -> list[str]:
    """
    Build the ``airflow tasks run`` command for ``ti``.

    The dag file path is taken from the dag's ``relative_fileloc``; a
    relative path is prefixed with the literal ``DAGS_FOLDER`` segment.
    All flags are forwarded to ``TaskInstance.generate_command``.

    :raises ValueError: if neither the task's dag nor ``ti.dag_model`` is available
    """
    dag: DAG | DagModel | None
    # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
    task_has_dag = hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None
    if task_has_dag:
        if TYPE_CHECKING:
            assert ti.task
            assert isinstance(ti.task.dag, SchedulerDAG)
        dag = ti.task.dag
    else:
        dag = ti.dag_model

    if dag is None:
        raise ValueError("DagModel is empty")

    path = Path(dag.relative_fileloc) if dag.relative_fileloc else None
    if path and not path.is_absolute():
        path = "DAGS_FOLDER" / path

    return TaskInstance.generate_command(
        ti.dag_id,
        ti.task_id,
        run_id=ti.run_id,
        mark_success=mark_success,
        ignore_all_deps=ignore_all_deps,
        ignore_task_deps=ignore_task_deps,
        ignore_depends_on_past=ignore_depends_on_past,
        wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
        ignore_ti_state=ignore_ti_state,
        local=local,
        file_path=path,
        raw=raw,
        pool=pool,
        cfg_path=cfg_path,
        map_index=ti.map_index,
    )
def command_as_list(
    self,
    mark_success: bool = False,
    ignore_all_deps: bool = False,
    ignore_task_deps: bool = False,
    ignore_depends_on_past: bool = False,
    wait_for_past_depends_before_skipping: bool = False,
    ignore_ti_state: bool = False,
    local: bool = False,
    raw: bool = False,
    pool: str | None = None,
    cfg_path: str | None = None,
) -> list[str]:
    """
    Return a command that can be executed anywhere where airflow is installed.

    This command is part of the message sent to executors by the orchestrator.
    It delegates to ``TaskInstance._command_as_list`` with this instance as ``ti``.
    """
    options = {
        "mark_success": mark_success,
        "ignore_all_deps": ignore_all_deps,
        "ignore_task_deps": ignore_task_deps,
        "ignore_depends_on_past": ignore_depends_on_past,
        "wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping,
        "ignore_ti_state": ignore_ti_state,
        "local": local,
        "raw": raw,
        "pool": pool,
        "cfg_path": cfg_path,
    }
    return TaskInstance._command_as_list(ti=self, **options)
@staticmethod
def generate_command(
    dag_id: str,
    task_id: str,
    run_id: str,
    mark_success: bool = False,
    ignore_all_deps: bool = False,
    ignore_depends_on_past: bool = False,
    wait_for_past_depends_before_skipping: bool = False,
    ignore_task_deps: bool = False,
    ignore_ti_state: bool = False,
    local: bool = False,
    file_path: PurePath | str | None = None,
    raw: bool = False,
    pool: str | None = None,
    cfg_path: str | None = None,
    map_index: int = -1,
) -> list[str]:
    """
    Generate the shell command required to execute this task instance.

    :param dag_id: DAG ID
    :param task_id: Task ID
    :param run_id: The run_id of this task's DagRun
    :param mark_success: Whether to mark the task as successful
    :param ignore_all_deps: Ignore all ignorable dependencies.
        Overrides the other ignore_* parameters.
    :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs
        (e.g. for Backfills)
    :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
    :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past
        and trigger rule
    :param ignore_ti_state: Ignore the task instance's previous failure/success
    :param local: Whether to run the task locally
    :param file_path: path to the file containing the DAG definition
    :param raw: raw mode (needs more details)
    :param pool: the Airflow pool that the task should run in
    :param cfg_path: the Path to the configuration file
    :param map_index: appended as ``--map-index`` unless it is -1
    :return: shell command that can be used to run the task instance
    """
    command = ["airflow", "tasks", "run", dag_id, task_id, run_id]

    if mark_success:
        command.append("--mark-success")
    if ignore_all_deps:
        command.append("--ignore-all-dependencies")
    if ignore_task_deps:
        command.append("--ignore-dependencies")

    # "ignore" wins over "wait" when both flags are set.
    if ignore_depends_on_past:
        command += ["--depends-on-past", "ignore"]
    elif wait_for_past_depends_before_skipping:
        command += ["--depends-on-past", "wait"]

    if ignore_ti_state:
        command.append("--force")
    if local:
        command.append("--local")
    if pool:
        command += ["--pool", pool]
    if raw:
        command.append("--raw")
    if file_path:
        command += ["--subdir", os.fspath(file_path)]
    if cfg_path:
        command += ["--cfg-path", cfg_path]
    if map_index != -1:
        command += ["--map-index", str(map_index)]
    return command
@property
def log_url(self) -> str:
    """
    Log URL for TaskInstance.

    Built from the ``[api] base_url`` setting (falling back to
    ``http://localhost:8080/``), the dag id, the URL-quoted run id and the
    task id. A ``/mapped/<map_index>`` segment is added for non-negative
    map indexes and a ``?try_number=`` query for positive try numbers.
    """
    run_id = quote(self.run_id)
    base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
    # Join with exactly one slash, so a base_url configured without a
    # trailing "/" does not fuse the host with the "dags" path segment.
    base_url = base_url.rstrip("/") + "/"
    map_index = f"/mapped/{self.map_index}" if self.map_index >= 0 else ""
    try_number = f"?try_number={self.try_number}" if self.try_number > 0 else ""
    _log_uri = f"{base_url}dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
    return _log_uri
@property
def mark_success_url(self) -> str:
    """URL to mark TI success; this is the same value as ``log_url``."""
    url = self.log_url
    return url
@provide_session
def error(self, session: Session = NEW_SESSION) -> None:
    """
    Force the task instance's state to FAILED in the database.

    The instance is merged into the session and the session is committed.

    :param session: SQLAlchemy ORM Session
    """
    self.log.error("Recording the task instance as FAILED")
    failed_state = TaskInstanceState.FAILED
    self.state = failed_state
    session.merge(self)
    session.commit()
@classmethod
@provide_session
def get_task_instance(
    cls,
    dag_id: str,
    run_id: str,
    task_id: str,
    map_index: int,
    lock_for_update: bool = False,
    session: Session = NEW_SESSION,
) -> TaskInstance | None:
    """
    Fetch the task instance matching the given dag, run, task and map index.

    :param lock_for_update: when True, the row is selected with
        ``FOR UPDATE`` inside ``run_with_db_retries``
    :param session: SQLAlchemy ORM Session
    :return: the matching task instance, or None if there is none
    """
    base_query = session.query(TaskInstance)
    # lazy load dag run to avoid locking it
    base_query = base_query.options(lazyload(TaskInstance.dag_run))
    query = base_query.filter_by(
        dag_id=dag_id,
        run_id=run_id,
        task_id=task_id,
        map_index=map_index,
    )

    if not lock_for_update:
        return query.one_or_none()

    for attempt in run_with_db_retries(logger=cls.logger()):
        with attempt:
            return query.with_for_update().one_or_none()
    return None
@provide_session
def refresh_from_db(
    self, session: Session = NEW_SESSION, lock_for_update: bool = False, keep_local_changes: bool = False
) -> None:
    """
    Refresh the task instance from the database based on the primary key.

    When no matching row exists, ``state`` is set to None.

    :param session: SQLAlchemy ORM Session
    :param lock_for_update: if True, indicates that the database should
        lock the TaskInstance (issuing a FOR UPDATE clause) until the
        session is committed.
    :param keep_local_changes: Force all attributes to the values from the database if False (the default),
        or if True don't overwrite locally set attributes
    :raises RuntimeError: if the SQLAlchemy state of this object cannot be inspected
    """
    query = select(
        # Select the columns, not the ORM object, to bypass any session/ORM caching layer.
        # The columns are unpacked as positional arguments: handing select() a
        # single iterable is the legacy calling style that SQLAlchemy 2.0 rejects.
        *TaskInstance.__table__.columns
    ).filter_by(
        dag_id=self.dag_id,
        run_id=self.run_id,
        task_id=self.task_id,
        map_index=self.map_index,
    )

    if lock_for_update:
        query = query.with_for_update()

    source = session.execute(query).mappings().one_or_none()
    if source:
        target_state = inspect(self)
        if target_state is None:
            raise RuntimeError(f"Unable to inspect SQLAlchemy state of {type(self)}: {self}")

        # To deal with `@hybrid_property` we need to get the names from `mapper.columns`
        for attr_name, col in target_state.mapper.columns.items():
            if keep_local_changes and target_state.attrs[attr_name].history.has_changes():
                continue

            set_committed_value(self, attr_name, source[col.name])

        # ID may have changed, update SQLAs state and object tracking
        newkey = session.identity_key(type(self), (self.id,))

        # Delete anything under the new key
        if newkey != target_state.key:
            old = session.identity_map.get(newkey)
            if old is not self and old is not None:
                session.expunge(old)
            target_state.key = newkey

        # Re-point an already-loaded dag_run at the object the session tracks under its identity key.
        if target_state.attrs.dag_run.loaded_value is not NO_VALUE:
            dr_key = session.identity_key(type(self.dag_run), (self.dag_run.id,))
            if (dr := session.identity_map.get(dr_key)) is not None:
                set_committed_value(self, "dag_run", dr)

    else:
        self.state = None
def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
    """
    Copy common attributes from the given task.

    After copying, ``task_instance_mutation_hook`` is applied to this instance.

    :param task: The task object to copy from
    :param pool_override: Use the pool_override instead of task's pool
    """
    self.task = task
    self.queue = task.queue
    self.pool = pool_override if pool_override else task.pool
    self.pool_slots = task.pool_slots

    # This method is called from the different places, and sometimes the TI is not fully initialized
    with contextlib.suppress(Exception):
        weight_rule = self.task.weight_rule
        self.priority_weight = weight_rule.get_weight(self)  # type: ignore[arg-type]

    self.run_as_user = task.run_as_user
    # Do not set max_tries to task.retries here because max_tries is a cumulative
    # value that needs to be stored in the db.
    self.executor = task.executor
    self.executor_config = task.executor_config
    self.operator = task.task_type
    self.custom_operator_name = getattr(task, "custom_operator_name", None)

    # Re-apply cluster policy here so that task default do not overload previous data
    task_instance_mutation_hook(self)
@staticmethod
@provide_session
def _clear_xcom_data(ti: TaskInstance, session: Session = NEW_SESSION) -> None:
    """
    Clear all XCom data from the database for the task instance.

    If the task is unmapped, all XComs matching this task ID in the same DAG
    run are removed. If the task is mapped, only the one with matching map
    index is removed.

    :param ti: The TI for which we need to clear xcoms.
    :param session: SQLAlchemy ORM Session
    """
    ti.log.debug("Clearing XCom data")
    # A negative map_index is passed on as None rather than as a concrete index.
    if ti.map_index < 0:
        map_index: int | None = None
    else:
        map_index = ti.map_index

    XComModel.clear(
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=map_index,
        session=session,
    )


# ``key`` takes no ``session`` argument and is read as a plain attribute
# (e.g. ``ti.key``), so it is exposed as a property rather than being wrapped
# in ``provide_session``.
@property
def key(self) -> TaskInstanceKey:
    """Returns a tuple that identifies the task instance uniquely."""
    return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
@provide_session
def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
    """
    Set TaskInstance state.

    ``start_date`` is filled in if it is empty. For a finished state or
    UP_FOR_RETRY, ``end_date`` is filled in if empty and ``duration`` is
    recomputed. The instance is then merged and the session flushed.

    :param state: State to set for the TI
    :param session: SQLAlchemy ORM Session
    :return: Was the state changed
    """
    if self.state == state:
        return False

    now = timezone.utcnow()
    self.log.debug("Setting task state for %s to %s", self, state)
    if self not in session:
        self.refresh_from_db(session)

    self.state = state
    if not self.start_date:
        self.start_date = now

    marks_end = self.state in State.finished or self.state == TaskInstanceState.UP_FOR_RETRY
    if marks_end:
        if not self.end_date:
            self.end_date = now
        elapsed = self.end_date - self.start_date
        self.duration = elapsed.total_seconds()

    session.merge(self)
    session.flush()
    return True
@property
def is_premature(self) -> bool:
    """Returns whether a task is in UP_FOR_RETRY state and is not yet ready for retry."""
    if self.state != TaskInstanceState.UP_FOR_RETRY:
        return False
    # is the task still in the retry waiting period?
    return not self.ready_for_retry()
def prepare_db_for_next_try(self, session: Session):
    """
    Update the metadata with all the records needed to put this TI in queued for the next try.

    The current try is recorded through ``TaskInstanceHistory.record_ti``,
    the ``TaskReschedule`` rows for the current id are deleted, and the
    instance gets a new uuid7 id.
    """
    from airflow.models.taskinstancehistory import TaskInstanceHistory

    TaskInstanceHistory.record_ti(self, session=session)
    remove_reschedules = delete(TaskReschedule).filter_by(ti_id=self.id)
    session.execute(remove_reschedules)
    self.id = uuid7()
@provide_session
def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
    """
    Check whether the immediate dependents of this task instance have succeeded or have been skipped.

    This is meant to be used by wait_for_downstream.

    This is useful when you do not want to start processing the next
    schedule of a task until the dependents are done. For instance,
    if the task DROPs and recreates a table.

    :param session: SQLAlchemy ORM Session
    :return: True when the task has no downstream tasks, or when the number
        of SKIPPED/SUCCESS downstream task instances in this run equals the
        number of downstream task ids
    """
    task = self.task
    if TYPE_CHECKING:
        assert task

    downstream_ids = task.downstream_task_ids
    if not downstream_ids:
        return True

    done_states = (TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)
    done_query = session.query(func.count(TaskInstance.task_id)).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id.in_(downstream_ids),
        TaskInstance.run_id == self.run_id,
        TaskInstance.state.in_(done_states),
    )
    done_count = done_query[0][0]
    return done_count == len(task.downstream_task_ids)
@provide_session
def get_previous_dagrun(
    self,
    state: DagRunState | None = None,
    session: Session | None = None,
) -> DagRun | None:
    """
    Return the DagRun that ran before this task instance's DagRun.

    :param state: If passed, it only take into account instances of a specific state.
    :param session: SQLAlchemy ORM Session.
    :return: the previous DagRun, or None when the task has no dag or no
        previous run is found
    """
    if TYPE_CHECKING:
        assert self.task
    dag = self.task.dag
    if dag is None:
        return None
    if TYPE_CHECKING:
        assert isinstance(dag, SchedulerDAG)

    current_run = self.get_dagrun(session=session)
    current_run.dag = dag

    from airflow.models.dagrun import DagRun  # Avoid circular import

    # We always ignore schedule in dagrun lookup when `state` is given
    # or the DAG is never scheduled. For legacy reasons, when
    # `catchup=True`, we use `get_previous_scheduled_dagrun` unless
    # `ignore_schedule` is `True`.
    ignore_schedule = state is not None or not dag.timetable.can_be_scheduled
    use_scheduled_lookup = dag.catchup is True and not ignore_schedule
    if use_scheduled_lookup:
        previous_run = DagRun.get_previous_scheduled_dagrun(current_run.id, session=session)
    else:
        previous_run = DagRun.get_previous_dagrun(dag_run=current_run, session=session, state=state)

    return previous_run if previous_run else None
@provide_session
def get_previous_ti(
    self,
    state: DagRunState | None = None,
    session: Session = NEW_SESSION,
) -> TaskInstance | None:
    """
    Return the task instance for the task that ran before this task instance.

    :param session: SQLAlchemy ORM Session
    :param state: If passed, it only take into account instances of a specific state.
    :return: the task instance with the same task_id in the previous
        DagRun, or None when there is no previous DagRun
    """
    previous_run = self.get_previous_dagrun(state, session=session)
    if previous_run is not None:
        return previous_run.get_task_instance(self.task_id, session=session)
    return None
@provide_session
def are_dependencies_met(
    self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False
) -> bool:
    """
    Are all conditions met for this task instance to be run given the context for the dependencies.

    (e.g. a task instance being force run from the UI will ignore some dependencies).

    :param dep_context: The execution context that determines the dependencies that should be evaluated.
    :param session: database session
    :param verbose: whether log details on failed dependencies on info or debug log level
    :return: True when no dependency failed, False otherwise
    """
    dep_context = dep_context or DepContext()
    emit = self.log.info if verbose else self.log.debug

    # Walk every failed status so each one is logged before the verdict is returned.
    failure_count = 0
    for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session):
        failure_count += 1
        emit(
            "Dependencies not met for %s, dependency '%s' FAILED: %s",
            self,
            dep_status.dep_name,
            dep_status.reason,
        )

    if failure_count:
        return False

    emit("Dependencies all met for dep_context=%s ti=%s", dep_context.description, self)
    return True
@provide_session
def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
    """
    Get failed Dependencies.

    Yields every dependency status that did not pass, drawn from the union
    of the context's deps and the task's deps.

    :param dep_context: the dependency context; a default ``DepContext`` is
        used when omitted
    :param session: SQLAlchemy ORM Session
    """
    if TYPE_CHECKING:
        assert isinstance(self.task, BaseOperator)

    if not hasattr(self.task, "deps"):
        # These deps are not on BaseOperator since they are only needed and evaluated
        # in the scheduler and not needed at the Runtime.
        from airflow.serialization.serialized_objects import SerializedBaseOperator

        serialized_form = SerializedBaseOperator.serialize_operator(self.task)
        serialized_op = SerializedBaseOperator.deserialize_operator(serialized_form)
        setattr(self.task, "deps", serialized_op.deps)  # type: ignore[union-attr]

    dep_context = dep_context or DepContext()
    all_deps = dep_context.deps | self.task.deps
    for dep in all_deps:
        for dep_status in dep.get_dep_statuses(self, session, dep_context):
            self.log.debug(
                "%s dependency '%s' PASSED: %s, %s",
                self,
                dep_status.dep_name,
                dep_status.passed,
                dep_status.reason,
            )
            if dep_status.passed:
                continue
            yield dep_status
def next_retry_datetime(self):
    """
    Get datetime of the next retry if the task instance fails.

    For exponential backoff, retry_delay is used as base and will be converted to seconds.
    The result is ``end_date`` plus the computed delay.
    """
    from airflow.sdk.definitions._internal.abstractoperator import MAX_RETRY_DELAY

    task = self.task
    delay = task.retry_delay
    if task.retry_exponential_backoff:
        # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus,
        # we must round up prior to converting to an int, otherwise a divide by zero error
        # will occur in the modded_hash calculation.
        # this probably gives unexpected results if a task instance has previously been cleared,
        # because try_number can increase without bound
        growth = 2 ** (self.try_number - 1)
        min_backoff = math.ceil(delay.total_seconds() * growth)

        # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1.
        # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes
        # the ceiling function unnecessary, but the ceiling function was retained to avoid
        # introducing a breaking change.
        min_backoff = max(min_backoff, 1)

        # deterministic per task instance
        hash_input = f"{self.dag_id}#{self.task_id}#{self.logical_date}#{self.try_number}".encode()
        ti_hash = int(hashlib.sha1(hash_input, usedforsecurity=False).hexdigest(), 16)
        # at least min_backoff and below 2 * min_backoff
        modded_hash = min_backoff + ti_hash % min_backoff

        # timedelta has a maximum representable value. The exponentiation
        # here means this value can be exceeded after a certain number
        # of tries (around 50 if the initial delay is 1s, even fewer if
        # the delay is larger). Cap the value here before creating a
        # timedelta object so the operation doesn't fail with "OverflowError".
        capped_seconds = min(modded_hash, MAX_RETRY_DELAY)
        delay = timedelta(seconds=capped_seconds)
        if task.max_retry_delay:
            delay = min(task.max_retry_delay, delay)

    return self.end_date + delay
def ready_for_retry(self) -> bool:
    """Check on whether the task instance is in the right state and timeframe to be retried."""
    if self.state != TaskInstanceState.UP_FOR_RETRY:
        return False
    # Ready once the computed retry time lies in the past.
    return self.next_retry_datetime() < timezone.utcnow()
# ``provide_session`` is required here: the ``NEW_SESSION`` default is only a
# placeholder, and without the decorator it would be forwarded as-is to
# ``_get_dagrun`` whenever the caller omits ``session``.
@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
    """
    Return the DagRun for this TaskInstance.

    An already-loaded ``dag_run`` relationship is returned directly;
    otherwise the DagRun is fetched by dag_id and run_id and stored on the
    instance. In both cases the DagRun's ``dag`` is set from the task when
    this instance has a task.

    :param session: SQLAlchemy ORM Session
    :return: DagRun
    """
    info = inspect(self)

    if info.attrs.dag_run.loaded_value is not NO_VALUE:
        if getattr(self, "task", None) is not None:
            if TYPE_CHECKING:
                assert self.task
            self.dag_run.dag = self.task.dag
        return self.dag_run

    dr = self._get_dagrun(self.dag_id, self.run_id, session)
    if getattr(self, "task", None) is not None:
        if TYPE_CHECKING:
            assert self.task
            assert isinstance(self.task.dag, SchedulerDAG)
        dr.dag = self.task.dag
    # Record it in the instance for next time. This means that `self.logical_date` will work correctly
    set_committed_value(self, "dag_run", dr)

    return dr
@classmethod
@provide_session
def _check_and_change_state_before_execution(
    cls,
    task_instance: TaskInstance,
    verbose: bool = True,
    ignore_all_deps: bool = False,
    ignore_depends_on_past: bool = False,
    wait_for_past_depends_before_skipping: bool = False,
    ignore_task_deps: bool = False,
    ignore_ti_state: bool = False,
    mark_success: bool = False,
    test_mode: bool = False,
    hostname: str = "",
    pool: str | None = None,
    external_executor_id: str | None = None,
    session: Session = NEW_SESSION,
) -> bool:
    """
    Check dependencies and then set the state to RUNNING if they are met.

    Returns True if and only if state is set to RUNNING, which implies that the task should be
    executed, in preparation for _run_raw_task.

    When ``mark_success`` is not set, two dependency checks are made: first the
    non-requeueable ones (failure returns False and leaves the state alone), then the
    requeueable ones (failure resets the state to None and returns False).

    :param task_instance: the task instance to check and update; its ``task`` is used to refresh it
    :param verbose: whether to log the final "Executing"/"Marking success" message
    :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
    :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
    :param wait_for_past_depends_before_skipping: Wait for past depends before mark the ti as skipped
    :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task
    :param ignore_ti_state: Disregards previous task instance state
    :param mark_success: Don't run the task, mark its state as success
    :param test_mode: Doesn't record success or failure in the DB
    :param hostname: The hostname of the worker running the task instance.
    :param pool: specifies the pool to use to run the task instance
    :param external_executor_id: The identifier of the celery executor
    :param session: SQLAlchemy ORM Session
    :return: whether the state was changed to running or not
    """
    if TYPE_CHECKING:
        assert task_instance.task

    ti: TaskInstance = task_instance
    task = task_instance.task
    if TYPE_CHECKING:
        assert task
    # Sync the instance with the task definition, then re-read the row from the DB,
    # requesting a row lock (lock_for_update=True).
    ti.refresh_from_task(task, pool_override=pool)
    ti.test_mode = test_mode
    ti.refresh_from_db(session=session, lock_for_update=True)
    ti.hostname = hostname
    ti.pid = None

    if not ignore_all_deps and not ignore_ti_state and ti.state == TaskInstanceState.SUCCESS:
        Stats.incr("previously_succeeded", tags=ti.stats_tags)

    if not mark_success:
        # Firstly find non-runnable and non-requeueable tis.
        # Since mark_success is not set, we do nothing.
        non_requeueable_dep_context = DepContext(
            deps=RUNNING_DEPS - REQUEUEABLE_DEPS,
            ignore_all_deps=ignore_all_deps,
            ignore_ti_state=ignore_ti_state,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_task_deps=ignore_task_deps,
            description="non-requeueable deps",
        )
        # NOTE: verbose=True is passed unconditionally here; the ``verbose`` argument
        # of this method is not forwarded to the dependency check.
        if not ti.are_dependencies_met(
            dep_context=non_requeueable_dep_context, session=session, verbose=True
        ):
            session.commit()
            return False

        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Attempt 1 instead of
        # Attempt 0 for the first attempt).
        # Set the task start date. In case it was re-scheduled use the initial
        # start date that is recorded in task_reschedule table
        # If the task continues after being deferred (next_method is set), use the original start_date
        ti.start_date = ti.start_date if ti.next_method else timezone.utcnow()
        if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
            tr_start_date = session.scalar(
                TR.stmt_for_task_instance(ti, descending=False).with_only_columns(TR.start_date).limit(1)
            )
            if tr_start_date:
                ti.start_date = tr_start_date

        # Secondly we find non-runnable but requeueable tis. We reset its state.
        # This is because we might have hit concurrency limits,
        # e.g. because of backfilling.
        dep_context = DepContext(
            deps=REQUEUEABLE_DEPS,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_task_deps=ignore_task_deps,
            ignore_ti_state=ignore_ti_state,
            description="requeueable deps",
        )
        if not ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
            ti.state = None
            cls.logger().warning(
                "Rescheduling due to concurrency limits reached "
                "at task runtime. Attempt %s of "
                "%s. State set to NONE.",
                ti.try_number,
                ti.max_tries + 1,
            )
            ti.queued_dttm = timezone.utcnow()
            session.merge(ti)
            session.commit()
            return False

    # next_kwargs being set is treated as "this run resumes a deferred task".
    if ti.next_kwargs is not None:
        cls.logger().info("Resuming after deferral")
    else:
        cls.logger().info("Starting attempt %s of %s", ti.try_number, ti.max_tries + 1)

    if not test_mode:
        session.add(Log(TaskInstanceState.RUNNING.value, ti))

    ti.state = TaskInstanceState.RUNNING
    ti.emit_state_change_metric(TaskInstanceState.RUNNING)

    if external_executor_id:
        ti.external_executor_id = external_executor_id

    ti.end_date = None
    if not test_mode:
        # merge() returns the session-attached instance; re-attach the task to it.
        session.merge(ti).task = task
    session.commit()

    # Closing all pooled connections to prevent
    # "max number of connections reached"
    settings.engine.dispose()  # type: ignore
    if verbose:
        if mark_success:
            cls.logger().info("Marking success for %s on %s", ti.task, ti.logical_date)
        else:
            cls.logger().info("Executing %s on %s", ti.task, ti.logical_date)
    return True

@provide_session
def emit_state_change_metric(self, new_state: TaskInstanceState) -> None:
    """
    Send a time metric representing how much time a given state transition took.

    The previous state and metric name is deduced from the state the task was put in.
    RUNNING reports ``queued_duration`` (time since ``queued_dttm``); QUEUED reports
    ``scheduled_duration`` (time since ``scheduled_dttm``). Nothing is sent when the
    task already has an end date or when the reference timestamp is missing.

    :param new_state: The state that has just been set for this task.
        We do not use `self.state`, because sometimes the state is updated directly in the DB and not in
        the local TaskInstance object.
        Supported states: QUEUED and RUNNING
    :raises NotImplementedError: if ``new_state`` is neither RUNNING nor QUEUED
    """
    if self.end_date:
        # if the task has an end date, it means that this is not its first round.
        # we send the state transition time metric only on the first try, otherwise it gets more complex.
        return

    # switch on state and deduce which metric to send
    if new_state == TaskInstanceState.RUNNING:
        metric_name = "queued_duration"
        if self.queued_dttm is None:
            # this should not really happen except in tests or rare cases,
            # but we don't want to create errors just for a metric, so we just skip it
            self.log.warning(
                "cannot record %s for task %s because previous state change time has not been saved",
                metric_name,
                self.task_id,
            )
            return
        timing = timezone.utcnow() - self.queued_dttm
    elif new_state == TaskInstanceState.QUEUED:
        metric_name = "scheduled_duration"
        if self.scheduled_dttm is None:
            self.log.warning(
                "cannot record %s for task %s because previous state change time has not been saved",
                metric_name,
                self.task_id,
            )
            return
        timing = timezone.utcnow() - self.scheduled_dttm
    else:
        # Exceptions do not apply logging-style %-formatting to their arguments, so the
        # state has to be interpolated explicitly (``!s`` mirrors the former ``%s``).
        raise NotImplementedError(f"no metric emission setup for state {new_state!s}")

    # send metric twice, once (legacy) with tags in the name and once with tags as tags
    Stats.timing(f"dag.{self.dag_id}.{self.task_id}.{metric_name}", timing)
    Stats.timing(
        f"task.{metric_name}",
        timing,
        tags={"task_id": self.task_id, "dag_id": self.dag_id, "queue": self.queue},
    )
def clear_next_method_args(self) -> None:
    """Reset ``next_method`` and ``next_kwargs`` to None so that later retries cannot reuse them."""
    log.debug("Clearing next_method and next_kwargs.")
    self.next_method = self.next_kwargs = None
@provide_session
@Sentry.enrich_errors
def _run_raw_task(
    self,
    mark_success: bool = False,
    test_mode: bool = False,
    pool: str | None = None,
    raise_on_defer: bool = False,
    session: Session = NEW_SESSION,
) -> TaskReturnCode | None:
    """
    Run a task, update the state upon completion, and run any appropriate callbacks.

    Immediately runs the task (without checking or changing db state
    before execution) and then sets the appropriate final state after
    completion and runs any post-execute callbacks. Meant to be called
    only after another function changes the state to running.

    :param mark_success: Don't run the task, mark its state as success
    :param test_mode: Doesn't record success or failure in the DB
    :param pool: specifies the pool to use to run the task instance
    :param raise_on_defer: if True, a ``TaskDeferred`` raised by the task is re-raised
        instead of being turned into a deferral of this task instance
    :param session: SQLAlchemy ORM Session
    :return: ``TaskReturnCode.DEFERRED`` when the task deferred itself, otherwise None
    """
    if TYPE_CHECKING:
        assert self.task

    if TYPE_CHECKING:
        assert isinstance(self.task, BaseOperator)

    self.test_mode = test_mode
    self.refresh_from_task(self.task, pool_override=pool)
    self.refresh_from_db(session=session)
    self.hostname = get_hostname()
    self.pid = os.getpid()
    if not test_mode:
        TaskInstance.save_to_db(ti=self, session=session)
    # Captured before execution; only used if the task asks to be rescheduled.
    actual_start_date = timezone.utcnow()
    Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags)
    # Same metric with tagging
    Stats.incr("ti.start", tags=self.stats_tags)
    # Initialize final state counters at zero
    for state in State.task_states:
        Stats.incr(
            f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}",
            count=0,
            tags=self.stats_tags,
        )
        # Same metric with tagging
        Stats.incr(
            "ti.finish",
            count=0,
            tags={**self.stats_tags, "state": str(state)},
        )
    with set_current_task_instance_session(session=session):
        self.task = self.task.prepare_for_execution()
        context = self.get_template_context(ignore_param_exceptions=False, session=session)

        try:
            if self.task:
                from airflow.sdk.definitions.asset import Asset

                inlets = [asset.asprofile() for asset in self.task.inlets if isinstance(asset, Asset)]
                outlets = [asset.asprofile() for asset in self.task.outlets if isinstance(asset, Asset)]
                TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session)
            if not mark_success:
                TaskInstance._execute_task_with_callbacks(
                    self=self,  # type: ignore[arg-type]
                    context=context,
                    test_mode=test_mode,
                    session=session,
                )
            if not test_mode:
                self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True)
            self.state = TaskInstanceState.SUCCESS
        except TaskDeferred as defer:
            # The task has signalled it wants to defer execution based on
            # a trigger.
            if raise_on_defer:
                raise
            self.defer_task(exception=defer, session=session)
            self.log.info(
                "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, logical_date=%s, start_date=%s",
                self.dag_id,
                self.task_id,
                self.run_id,
                _date_or_empty(task_instance=self, attr="logical_date"),
                _date_or_empty(task_instance=self, attr="start_date"),
            )
            return TaskReturnCode.DEFERRED
        except AirflowSkipException as e:
            # Recording SKIP
            # log only if exception has any arguments to prevent log flooding
            if e.args:
                self.log.info(e)
            if not test_mode:
                self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True)
            self.state = TaskInstanceState.SKIPPED
            _run_finished_callback(callbacks=self.task.on_skipped_callback, context=context)
            TaskInstance.save_to_db(ti=self, session=session)
        except AirflowRescheduleException as reschedule_exception:
            self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
            self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
            return None
        except (AirflowFailException, AirflowSensorTimeout) as e:
            # If AirflowFailException is raised, task should not retry.
            # If a sensor in reschedule mode reaches timeout, task should not retry.
            self.handle_failure(e, test_mode, context, force_fail=True, session=session)  # already saves to db
            raise
        except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e:
            if not test_mode:
                self.refresh_from_db(lock_for_update=True, session=session)
            # for case when task is marked as success/failed externally
            # or dagrun timed out and task is marked as skipped
            # current behavior doesn't hit the callbacks
            if self.state in State.finished:
                self.clear_next_method_args()
                TaskInstance.save_to_db(ti=self, session=session)
                return None
            self.handle_failure(e, test_mode, context, session=session)
            raise
        except SystemExit as e:
            # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
            # Therefore, here we must handle only error codes.
            msg = f"Task failed due to SystemExit({e.code})"
            self.handle_failure(msg, test_mode, context, session=session)
            raise AirflowException(msg)
        except BaseException as e:
            self.handle_failure(e, test_mode, context, session=session)
            raise
        finally:
            # Runs on every exit path above (including the early returns and re-raises).
            # Print a marker post execution for internals of post task processing
            log.info("::group::Post task execution logs")

            Stats.incr(
                f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}",
                tags=self.stats_tags,
            )
            # Same metric with tagging
            Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)})

        # Only reached when the try block ended without returning or raising,
        # i.e. after a plain success or a handled skip.
        # Recording SKIPPED or SUCCESS
        self.clear_next_method_args()
        self.end_date = timezone.utcnow()
        _log_state(task_instance=self)
        self.set_duration()

        # run on_success_callback before db committing
        # otherwise, the LocalTaskJob sees the state is changed to `success`,
        # but the task_runner is still running, LocalTaskJob then treats the state is set externally!
        if self.state == TaskInstanceState.SUCCESS:
            _run_finished_callback(callbacks=self.task.on_success_callback, context=context)

        if not test_mode:
            _add_log(event=self.state, task_instance=self, session=session)
            if self.state == TaskInstanceState.SUCCESS:
                from airflow.sdk.execution_time.task_runner import (
                    _build_asset_profiles,
                    _serialize_outlet_events,
                )

                TaskInstance.register_asset_changes_in_db(
                    self,
                    list(_build_asset_profiles(self.task.outlets)),
                    list(_serialize_outlet_events(context["outlet_events"])),
                    session=session,
                )

            TaskInstance.save_to_db(ti=self, session=session)
            if self.state == TaskInstanceState.SUCCESS:
                try:
                    get_listener_manager().hook.on_task_instance_success(
                        previous_state=TaskInstanceState.RUNNING, task_instance=self
                    )
                except Exception:
                    # A failing listener must not fail the task; just log it.
                    log.exception("error calling listener")
        return None

@staticmethod
@provide_session
def register_asset_changes_in_db(
    ti: TaskInstance,
    task_outlets: list[AssetProfile],
    outlet_events: list[dict[str, Any]],
    session: Session = NEW_SESSION,
) -> None:
    """
    Register asset change events for the outlets of a task instance.

    Outlets are resolved against active ``AssetModel`` rows in three ways: by exact
    (name, uri) key, by name reference and by URI reference. Every resolved asset is
    passed to ``asset_manager.register_asset_change``. Outlet events that carry a
    ``source_alias_name`` belonging to one of the task's alias outlets are grouped by
    destination asset and extra and registered with their alias names; when that
    registration returns None an ``AssetModel`` row is added and it is attempted again.

    :param ti: task instance the asset changes are attributed to
    :param task_outlets: profiles describing the task's outlets
    :param outlet_events: serialized outlet events (dicts with ``dest_asset_key``, ``extra``
        and, for alias events, ``source_alias_name``)
    :param session: SQLAlchemy ORM Session
    :raises AirflowInactiveAssetInInletOrOutletException: after all registrations have been
        attempted, if any outlet could not be resolved to an active asset
    """
    from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef

    def _emit(asset, extra, **more):
        """Forward one change to the asset manager for this task instance and session."""
        return asset_manager.register_asset_change(
            task_instance=ti, asset=asset, extra=extra, session=session, **more
        )

    exact_keys = {
        AssetUniqueKey(outlet.name, outlet.uri)
        for outlet in task_outlets
        if outlet.type == Asset.__name__ and outlet.name and outlet.uri
    }
    name_refs = {
        Asset.ref(name=outlet.name)
        for outlet in task_outlets
        if outlet.type == AssetNameRef.__name__ and outlet.name
    }
    uri_refs = {
        Asset.ref(uri=outlet.uri)
        for outlet in task_outlets
        if outlet.type == AssetUriRef.__name__ and outlet.uri
    }

    # One query for every active asset matching any of the three kinds of outlet.
    active_query = select(AssetModel).where(
        AssetModel.active.has(),
        or_(
            tuple_(AssetModel.name, AssetModel.uri).in_(attrs.astuple(k) for k in exact_keys),
            AssetModel.name.in_(r.name for r in name_refs),
            AssetModel.uri.in_(r.uri for r in uri_refs),
        ),
    )
    active_models: dict[AssetUniqueKey, AssetModel] = {
        AssetUniqueKey.from_asset(model): model for model in session.scalars(active_query)
    }

    # Extras of events that were emitted directly, i.e. not through an alias.
    direct_extras: dict[AssetUniqueKey, dict] = {
        AssetUniqueKey(**event["dest_asset_key"]): event["extra"]
        for event in outlet_events
        if "source_alias_name" not in event
    }

    unresolved: set[AssetUniqueKey | AssetNameRef | AssetUriRef] = set()

    for key in exact_keys:
        model = active_models.get(key)
        if model is None:
            unresolved.add(key)
            continue
        ti.log.debug("register event for asset %s", model)
        _emit(model, direct_extras.get(key))

    if name_refs:
        models_by_name = {key.name: model for key, model in active_models.items()}
        extras_by_name = {key.name: extra for key, extra in direct_extras.items()}
        for name_ref in name_refs:
            model = models_by_name.get(name_ref.name)
            if model is None:
                unresolved.add(name_ref)
                continue
            ti.log.debug("register event for asset name ref %s", model)
            _emit(model, extras_by_name.get(name_ref.name))

    if uri_refs:
        models_by_uri = {key.uri: model for key, model in active_models.items()}
        extras_by_uri = {key.uri: extra for key, extra in direct_extras.items()}
        for uri_ref in uri_refs:
            model = models_by_uri.get(uri_ref.uri)
            if model is None:
                unresolved.add(uri_ref)
                continue
            ti.log.debug("register event for asset uri ref %s", model)
            _emit(model, extras_by_uri.get(uri_ref.uri))

    alias_outlet_names = {
        outlet.name for outlet in task_outlets if outlet.type == AssetAlias.__name__ and outlet.name
    }
    if alias_outlet_names:
        # Group alias names by (destination asset, frozen extra) so that each distinct
        # combination is registered once with all aliases that produced it.
        aliases_by_target: dict[tuple[AssetUniqueKey, frozenset], set[str]] = defaultdict(set)
        for event in outlet_events:
            if "source_alias_name" not in event:
                continue
            alias = event["source_alias_name"]
            if alias not in alias_outlet_names:
                continue
            target = AssetUniqueKey(**event["dest_asset_key"])
            aliases_by_target[target, frozenset(event["extra"].items())].add(alias)

        for (target, frozen_extra), aliases in aliases_by_target.items():
            ti.log.debug("register event for asset %s with aliases %s", target, aliases)
            registered = _emit(target, dict(frozen_extra), source_alias_names=aliases)
            if registered is None:
                ti.log.info("Dynamically creating AssetModel %s", target)
                session.add(AssetModel(name=target.name, uri=target.uri))
                session.flush()  # So event can set up its asset fk.
                _emit(target, dict(frozen_extra), source_alias_names=aliases)

    if unresolved:
        raise AirflowInactiveAssetInInletOrOutletException(unresolved)
def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
    """
    Prepare the task for execution, run it, and run the surrounding hooks and callbacks.

    Installs a SIGTERM handler, renders templates, exports the context as environment
    variables, runs the pre-execute hooks, the ``on_execute`` callbacks and the
    ``on_task_instance_running`` listener, executes the task via ``_execute_task``,
    renders the map index, runs the post-execute hooks and finally emits success metrics.

    :param context: template context of this task instance
    :param test_mode: if True, rendered template fields are not stored
    :param session: SQLAlchemy ORM Session (keyword-only)
    """
    from airflow.sdk.execution_time.callback_runner import create_executable_runner
    from airflow.sdk.execution_time.context import context_get_outlet_events

    if TYPE_CHECKING:
        assert self.task

    parent_pid = os.getpid()

    def signal_handler(signum, frame):
        pid = os.getpid()

        # If a task forks during execution (from DAG code) for whatever
        # reason, we want to make sure that we react to the signal only in
        # the process that we've spawned ourselves (referred to here as the
        # parent process).
        if pid != parent_pid:
            os._exit(1)
            return
        self.log.error("Received SIGTERM. Terminating subprocesses.")
        self.log.error("Stacktrace: \n%s", "".join(traceback.format_stack()))
        self.task.on_kill()
        # ``{expr=}`` renders as ``expr=value`` with no trailing separator, so the
        # fields are separated explicitly to keep the message readable.
        raise AirflowTaskTerminated(
            f"Task received SIGTERM signal {self.task_id=} {self.dag_id=} {self.run_id=} {self.map_index=}"
        )

    signal.signal(signal.SIGTERM, signal_handler)

    # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral.
    if not self.next_method:
        self.clear_xcom_data()

    with (
        Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"),
        Stats.timer("task.duration", tags=self.stats_tags),
    ):
        # Set the validated/merged params on the task object.
        self.task.params = context["params"]

        with set_current_context(context):
            dag = self.task.get_dag()
            if dag is not None:
                jinja_env = dag.get_template_env()
            else:
                jinja_env = None
            task_orig = self.render_templates(context=context, jinja_env=jinja_env)

        # The task is never MappedOperator at this point.
        if TYPE_CHECKING:
            assert isinstance(self.task, BaseOperator)

        if not test_mode:
            rendered_fields = get_serialized_template_fields(task=self.task)
            self.update_rtif(rendered_fields=rendered_fields)
        # Export context to make it available for operators to use.
        airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
        os.environ.update(airflow_context_vars)

        # Log context only for the default execution method, the assumption
        # being that otherwise we're resuming a deferred task (in which
        # case there's no need to log these again).
        if not self.next_method:
            self.log.info(
                "Exporting env vars: %s",
                " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()),
            )

        # Run pre_execute callback
        if self.task._pre_execute_hook:
            create_executable_runner(
                self.task._pre_execute_hook,
                context_get_outlet_events(context),
                logger=self.log,
            ).run(context)
        create_executable_runner(
            self.task.pre_execute,
            context_get_outlet_events(context),
            logger=self.log,
        ).run(context)

        # Run on_execute callback
        self._run_execute_callback(context, self.task)

        # Run on_task_instance_running event
        try:
            get_listener_manager().hook.on_task_instance_running(
                previous_state=TaskInstanceState.QUEUED, task_instance=self
            )
        except Exception:
            # A failing listener must not prevent the task from running.
            log.exception("error calling listener")

        def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
            """Render named map index if the DAG author defined map_index_template at the task level."""
            if jinja_env is None or (template := context.get("map_index_template")) is None:
                return None
            rendered_map_index = jinja_env.from_string(template).render(context)
            log.debug("Map index rendered as %s", rendered_map_index)
            return rendered_map_index

        # Execute the task.
        with set_current_context(context):
            try:
                result = self._execute_task(context, task_orig)
            except Exception:
                # If the task failed, swallow rendering error so it doesn't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    self._rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
                raise
            else:
                # If the task succeeded, render normally to let rendering error bubble up.
                self._rendered_map_index = _render_map_index(context, jinja_env=jinja_env)

        # Run post_execute callback
        if self.task._post_execute_hook:
            create_executable_runner(
                self.task._post_execute_hook,
                context_get_outlet_events(context),
                logger=self.log,
            ).run(context, result)
        create_executable_runner(
            self.task.post_execute,
            context_get_outlet_events(context),
            logger=self.log,
        ).run(context, result)

    Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
    Stats.incr("ti_successes", tags=self.stats_tags)

def _execute_task(self, context: Context, task_orig: Operator):
    """
    Execute Task (optionally with a Timeout) and push Xcom results.

    :param context: Jinja2 context
    :param task_orig: origin task
    :return: whatever the executed callable returned (None if it exited via a
        ``SystemExit`` with code ``None`` or ``0``)
    :raises AirflowException: if the task is a ``MappedOperator``, or if ``multiple_outputs``
        is set and the result is not a mapping with string keys
    :raises AirflowTaskTimeout: if the execution timeout has already elapsed or is exceeded
    """
    from airflow.sdk.bases.operator import ExecutorSafeguard
    from airflow.sdk.definitions.mappedoperator import MappedOperator

    task_to_execute = self.task

    if TYPE_CHECKING:
        # TODO: TaskSDK this function will need 100% re-writing
        # This only works with a "rich" BaseOperator, not the SDK version
        assert isinstance(task_to_execute, BaseOperator)

    if isinstance(task_to_execute, MappedOperator):
        raise AirflowException("MappedOperator cannot be executed.")

    # If the task has been deferred and is being executed due to a trigger,
    # then we need to pick the right method to come back to, otherwise
    # we go for the default execute
    execute_callable_kwargs: dict[str, Any] = {}
    execute_callable: Callable
    if self.next_method:
        execute_callable = task_to_execute.resume_execution
        execute_callable_kwargs["next_method"] = self.next_method
        # We don't want modifications we make here to be tracked by SQLA
        execute_callable_kwargs["next_kwargs"] = {**(self.next_kwargs or {})}
        if self.next_method == "execute":
            execute_callable_kwargs["next_kwargs"][f"{task_to_execute.__class__.__name__}__sentinel"] = (
                ExecutorSafeguard.sentinel_value
            )
    else:
        execute_callable = task_to_execute.execute
        if execute_callable.__name__ == "execute":
            execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = (
                ExecutorSafeguard.sentinel_value
            )

    def _execute_callable(context: Context, **execute_callable_kwargs):
        from airflow.sdk.execution_time.callback_runner import create_executable_runner
        from airflow.sdk.execution_time.context import context_get_outlet_events

        try:
            # Print a marker for log grouping of details before task execution
            log.info("::endgroup::")

            return create_executable_runner(
                execute_callable,
                context_get_outlet_events(context),
                logger=log,
            ).run(context=context, **execute_callable_kwargs)
        except SystemExit as e:
            # Handle only successful cases here. Failure cases will be handled upper
            # in the exception chain.
            if e.code is not None and e.code != 0:
                raise
            return None

    # If a timeout is specified for the task, make it fail
    # if it goes beyond
    if task_to_execute.execution_timeout:
        # If we are coming in with a next_method (i.e. from a deferral),
        # calculate the timeout from our start_date.
        if self.next_method and self.start_date:
            timeout_seconds = (
                task_to_execute.execution_timeout - (timezone.utcnow() - self.start_date)
            ).total_seconds()
        else:
            timeout_seconds = task_to_execute.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = _execute_callable(context=context, **execute_callable_kwargs)
        except AirflowTaskTimeout:
            task_to_execute.on_kill()
            raise
    else:
        result = _execute_callable(context=context, **execute_callable_kwargs)
    cm = create_session()
    with cm as session_or_null:
        if task_to_execute.do_xcom_push:
            xcom_value = result
        else:
            xcom_value = None
        if xcom_value is not None:
            # If the task returns a result, push an XCom containing it.
            if task_to_execute.multiple_outputs:
                if not isinstance(xcom_value, Mapping):
                    raise AirflowException(
                        f"Returned output was type {type(xcom_value)} "
                        "expected dictionary for multiple_outputs"
                    )
                # Validate every key first so nothing is pushed for an invalid mapping.
                for key in xcom_value:
                    if not isinstance(key, str):
                        raise AirflowException(
                            "Returned dictionary keys must be strings when using "
                            f"multiple_outputs, found {key} ({type(key)}) instead"
                        )
                for key, value in xcom_value.items():
                    self.xcom_push(key=key, value=value, session=session_or_null)
            self.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null)
        if TYPE_CHECKING:
            assert task_orig.dag
        _record_task_map_for_downstreams(
            task_instance=self,
            task=task_orig,
            value=xcom_value,
            session=session_or_null,
        )
    return result
def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESSION) -> None:
    """
    Move this task instance to DEFERRED and store the trigger that will resume it.

    The trigger, resume method, resume kwargs and timeout come from ``exception`` when
    one is given, otherwise from the task's expanded ``start_trigger_args``. The
    resulting ``trigger_timeout`` is additionally limited by the task's
    ``execution_timeout`` counted from ``start_date``. The instance is merged and
    committed only when ``exception`` is given.

    :param exception: the ``TaskDeferred`` raised by the task, or None to defer based on
        the task's ``start_trigger_args``
    :param session: SQLAlchemy ORM Session
    :raises TaskDeferralError: if neither ``exception`` nor ``start_trigger_args`` is
        available, or if ``start_trigger_args`` expands to None

    :meta: private
    """
    from airflow.models.trigger import Trigger

    # TODO: TaskSDK add start_trigger_args to SDK definitions
    if TYPE_CHECKING:
        assert self.task is None or isinstance(self.task, BaseOperator)

    if exception is None and (self.task is None or self.task.start_trigger_args is None):
        raise TaskDeferralError("exception and ti.task.start_trigger_args cannot both be None")

    resume_after: timedelta | None
    if exception is not None:
        trigger_row = Trigger.from_object(exception.trigger)
        resume_method = exception.method_name
        resume_kwargs = exception.kwargs
        resume_after = exception.timeout
    else:
        context = self.get_template_context()
        expanded_args = self.task.expand_start_trigger_args(context=context, session=session)
        if expanded_args is None:
            raise TaskDeferralError(
                "A none 'None' start_trigger_args has been change to 'None' during expandion"
            )

        trigger_kwargs = expanded_args.trigger_kwargs or {}
        resume_kwargs = expanded_args.next_kwargs
        resume_method = expanded_args.next_method
        resume_after = expanded_args.timeout
        # The trigger class is read from the task's own (unexpanded) start_trigger_args.
        trigger_row = Trigger(
            classpath=self.task.start_trigger_args.trigger_cls,
            kwargs=trigger_kwargs,
        )

    # First, make the trigger entry (flushed so that its id is available below)
    session.add(trigger_row)
    session.flush()

    if TYPE_CHECKING:
        assert self.task

    # Then, update ourselves so it matches the deferral request
    # Keep an eye on the logic in `check_and_change_state_before_execution()`
    # depending on self.next_method semantics
    self.state = TaskInstanceState.DEFERRED
    self.trigger_id = trigger_row.id
    self.next_method = resume_method
    self.next_kwargs = resume_kwargs or {}

    # Calculate timeout too if it was passed
    self.trigger_timeout = timezone.utcnow() + resume_after if resume_after is not None else None

    # If an execution_timeout is set, set the timeout to the minimum of
    # it and the trigger timeout
    execution_timeout = self.task.execution_timeout
    if execution_timeout:
        if TYPE_CHECKING:
            assert self.start_date
        execution_deadline = self.start_date + execution_timeout
        if self.trigger_timeout:
            self.trigger_timeout = min(execution_deadline, self.trigger_timeout)
        else:
            self.trigger_timeout = execution_deadline

    if self.test_mode:
        _add_log(event=self.state, task_instance=self, session=session)

    if exception is not None:
        session.merge(self)
        session.commit()
def_run_execute_callback(self,context:Context,task:BaseOperator)->None:"""Functions that need to be run before a Task is executed."""ifnot(callbacks:=task.on_execute_callback):returnforcallbackincallbacksifisinstance(callbacks,list)else[callbacks]:try:callback(context)exceptException:self.log.exception("Failed when executing execute callback")@provide_session
def dry_run(self) -> None:
    """Render the templates of this task instance and call the task's own ``dry_run``."""
    if TYPE_CHECKING:
        assert self.task

    prepared_task = self.task.prepare_for_execution()
    self.task = prepared_task
    self.render_templates()

    if TYPE_CHECKING:
        assert isinstance(self.task, BaseOperator)
    # Read self.task again rather than reusing the local: the attribute is the
    # source of truth after rendering.
    self.task.dry_run()
@provide_sessiondef_handle_reschedule(self,actual_start_date:datetime,reschedule_exception:AirflowRescheduleException,test_mode:bool=False,session:Session=NEW_SESSION,):# Don't record reschedule request in test modeiftest_mode:returnself.refresh_from_db(session)ifTYPE_CHECKING:assertself.taskself.end_date=timezone.utcnow()self.set_duration()# set stateself.state=TaskInstanceState.UP_FOR_RESCHEDULEself.clear_next_method_args()session.merge(self)session.commit()# we add this in separate commit to reduce likelihood of deadlock# see https://github.com/apache/airflow/pull/21362 for more infosession.add(TaskReschedule(self.id,actual_start_date,self.end_date,reschedule_exception.reschedule_date,))session.commit()returnself@staticmethoddefget_truncated_error_traceback(error:BaseException,truncate_to:Callable)->TracebackType|None:""" Truncate the traceback of an exception to the first frame called from within a given function. :param error: exception to get traceback from :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute :meta private: """tb=error.__traceback__code=truncate_to.__func__.__code__# type: ignore[attr-defined]whiletbisnotNone:iftb.tb_frame.f_codeiscode:returntb.tb_nexttb=tb.tb_nextreturntborerror.__traceback__@classmethod
def fetch_handle_failure_context(
    cls,
    ti: TaskInstance,
    error: None | str | BaseException,
    test_mode: bool | None = None,
    context: Context | None = None,
    force_fail: bool = False,
    *,
    session: Session,
    fail_fast: bool = False,
):
    """
    Fetch the context needed to handle a failure.

    Besides collecting that context, this logs the error, stamps ``end_date`` and
    duration on ``ti``, emits failure metrics, sets ``ti.state`` to ``FAILED`` or
    ``UP_FOR_RETRY`` and notifies the ``on_task_instance_failed`` listeners.

    :param ti: TaskInstance
    :param error: if specified, log the specific exception if thrown
    :param test_mode: doesn't record success or failure in the DB if True
    :param context: Jinja2 context
    :param force_fail: if True, task does not retry
    :param session: SQLAlchemy ORM Session
    :param fail_fast: if True, fail all downstream tasks
    :return: dict with keys ``ti``, ``email_for_state`` (an ``attrgetter`` for the
        matching ``email_on_*`` flag), ``task`` (the unmapped task, or ``None``),
        ``callbacks`` and ``context``
    """
    if error:
        if isinstance(error, BaseException):
            # Log with a traceback that starts after the ``_execute_task`` frame.
            tb = TaskInstance.get_truncated_error_traceback(error, truncate_to=ti._execute_task)
            cls.logger().error("Task failed with exception", exc_info=(type(error), error, tb))
        else:
            cls.logger().error("%s", error)
    if not test_mode:
        ti.refresh_from_db(session)

    ti.end_date = timezone.utcnow()
    ti.set_duration()

    Stats.incr(f"operator_failures_{ti.operator}", tags=ti.stats_tags)
    # Same metric with tagging
    Stats.incr("operator_failures", tags={**ti.stats_tags, "operator": ti.operator})
    Stats.incr("ti_failures", tags=ti.stats_tags)

    if not test_mode:
        session.add(Log(TaskInstanceState.FAILED.value, ti))

    ti.clear_next_method_args()

    # In extreme cases (task instance heartbeat timeout in case of dag with parse error) we might _not_ have a Task.
    if context is None and getattr(ti, "task", None):
        context = ti.get_template_context(session)

    if context is not None:
        context["exception"] = error

    # Set state correctly and figure out how to log it and decide whether
    # to email

    # Note, callback invocation needs to be handled by caller of
    # _run_raw_task to avoid race conditions which could lead to duplicate
    # invocations or miss invocation.

    # Since this function is called only when the TaskInstance state is running,
    # try_number contains the current try_number (not the next). We
    # only mark task instance as FAILED if the next task instance
    # try_number exceeds the max_tries ... or if force_fail is truthy

    task: BaseOperator | None = None
    try:
        if getattr(ti, "task", None) and context:
            if TYPE_CHECKING:
                assert isinstance(ti.task, BaseOperator)
            task = ti.task.unmap((context, session))
    except Exception:
        # Best effort: on failure ``task`` stays None, so no email or callbacks are selected below.
        cls.logger().error("Unable to unmap task to determine if we need to send an alert email")

    if force_fail or not ti.is_eligible_to_retry():
        ti.state = TaskInstanceState.FAILED
        email_for_state = operator.attrgetter("email_on_failure")
        callbacks = task.on_failure_callback if task else None

        if task and fail_fast:
            _stop_remaining_tasks(task_instance=ti, session=session)
    else:
        if ti.state == TaskInstanceState.RUNNING:
            # If the task instance is in the running state, it means it raised an exception and
            # about to retry so we record the task instance history. For other states, the task
            # instance was cleared and already recorded in the task instance history.
            ti.prepare_db_for_next_try(session)

        ti.state = State.UP_FOR_RETRY
        email_for_state = operator.attrgetter("email_on_retry")
        callbacks = task.on_retry_callback if task else None

    try:
        get_listener_manager().hook.on_task_instance_failed(
            previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
        )
    except Exception:
        # A failing listener must not prevent the failure context from being returned.
        log.exception("error calling listener")

    return {
        "ti": ti,
        "email_for_state": email_for_state,
        "task": task,
        "callbacks": callbacks,
        "context": context,
    }
def handle_failure(
    self,
    error: None | str | BaseException,
    test_mode: bool | None = None,
    context: Context | None = None,
    force_fail: bool = False,
    session: Session = NEW_SESSION,
) -> None:
    """
    Handle the failure of this task instance.

    Gathers the failure context, logs the resulting state, sends the alert
    email when the task asks for one, runs the failure/retry callbacks and,
    outside test mode, saves the task instance.

    :param error: if specified, log the specific exception if thrown
    :param test_mode: doesn't record success or failure in the DB if True;
        defaults to ``self.test_mode`` when ``None``
    :param context: Jinja2 context
    :param force_fail: if True, task does not retry
    :param session: SQLAlchemy ORM Session
    """
    if TYPE_CHECKING:
        assert self.task
        assert self.task.dag

    # The task or its dag may be unavailable; treat that as "no fail-fast".
    try:
        fail_fast = self.task.dag.fail_fast
    except Exception:
        fail_fast = False

    if test_mode is None:
        test_mode = self.test_mode

    failure_context = TaskInstance.fetch_handle_failure_context(
        ti=self,  # type: ignore[arg-type]
        error=error,
        test_mode=test_mode,
        context=context,
        force_fail=force_fail,
        session=session,
        fail_fast=fail_fast,
    )
    failed_task = failure_context["task"]
    finished_callbacks = failure_context["callbacks"]
    callback_context = failure_context["context"]

    lead_msg = "Immediate failure requested. " if force_fail else ""
    _log_state(task_instance=self, lead_msg=lead_msg)

    if failed_task and failure_context["email_for_state"](failed_task) and failed_task.email:
        try:
            self.email_alert(error, failed_task)
        except Exception:
            log.exception("Failed to send email to: %s", failed_task.email)

    if finished_callbacks and callback_context:
        _run_finished_callback(
            callbacks=finished_callbacks,
            context=callback_context,
        )

    if not test_mode:
        TaskInstance.save_to_db(failure_context["ti"], session)
def is_eligible_to_retry(self) -> bool:
    """
    Tell whether this task instance may be retried.

    A ``RESTARTING`` instance is always eligible. Otherwise the instance is
    eligible while ``try_number`` does not exceed ``max_tries`` and, when the
    task is loaded, the task has a truthy ``retries`` value.
    """
    if self.state == TaskInstanceState.RESTARTING:
        # If a task is cleared when running, it goes into RESTARTING state and is always
        # eligible for retry
        return True

    loaded_task = getattr(self, "task", None)
    if loaded_task:
        if TYPE_CHECKING:
            assert self.task
            assert self.task.retries
        if not loaded_task.retries:
            return False
        return bool(self.try_number <= self.max_tries)

    # Couldn't load the task, don't know number of retries, guess:
    return self.try_number <= self.max_tries
def get_template_context(
    self,
    session: Session | None = None,
    ignore_param_exceptions: bool = True,
) -> Context:
    """
    Return TI Context.

    The context produced by the runtime task instance is extended with
    scheduler-side entries (params, macros, previous-success dates, accessors for
    variables, connections and asset events). When the task is mapped, the
    expanded TI count is added and ``self._upstream_map_indexes`` is populated.

    :param session: SQLAlchemy ORM Session; a new ``settings.Session()`` is
        created when none is given
    :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
    :return: the template context for this task instance
    """
    if TYPE_CHECKING:
        assert self.task
        assert isinstance(self.task.dag, SchedulerDAG)

    # Do not use provide_session here -- it expunges everything on exit!
    if not session:
        session = settings.Session()

    from airflow import macros
    from airflow.models.abstractoperator import NotMapped
    from airflow.models.baseoperator import BaseOperator
    from airflow.sdk.api.datamodels._generated import (
        DagRun as DagRunSDK,
        PrevSuccessfulDagRunResponse,
        TIRunContext,
    )
    from airflow.sdk.definitions.param import process_params
    from airflow.sdk.execution_time.context import InletEventsAccessors
    from airflow.utils.context import (
        ConnectionAccessor,
        OutletEventAccessors,
        VariableAccessor,
    )

    integrate_macros_plugins()

    task = self.task
    if TYPE_CHECKING:
        assert self.task
        assert task
        assert task.dag
        assert session

    def _get_dagrun(session: Session) -> DagRun:
        # Return the DagRun in a form whose ``consumed_asset_events`` can be read later.
        dag_run = self.get_dagrun(session)
        if dag_run in session:
            return dag_run
        # The dag_run may not be attached to the session anymore since the
        # code base is over-zealous with use of session.expunge_all().
        # Re-attach it if the relation is not loaded so we can load it when needed.
        info = inspect(dag_run)
        if info.attrs.consumed_asset_events.loaded_value is not NO_VALUE:
            return dag_run
        # If dag_run is not flushed to db at all (e.g. CLI commands using
        # in-memory objects for ad-hoc operations), just set the value manually.
        if not info.has_identity:
            dag_run.consumed_asset_events = []
            return dag_run
        return session.merge(dag_run, load=False)

    dag_run = _get_dagrun(session)

    validated_params = process_params(
        self.task.dag, task, dag_run.conf, suppress_exception=ignore_param_exceptions
    )
    ti_context_from_server = TIRunContext(
        dag_run=DagRunSDK.model_validate(dag_run, from_attributes=True),
        max_tries=self.max_tries,
        should_retry=self.is_eligible_to_retry(),
    )
    runtime_ti = self.to_runtime_ti(context_from_server=ti_context_from_server)

    # Base context; the entries below are layered on top of it.
    context: Context = runtime_ti.get_template_context()

    @cache  # Prevent multiple database access.
    def _get_previous_dagrun_success() -> PrevSuccessfulDagRunResponse:
        dr_from_db = self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
        if dr_from_db:
            return PrevSuccessfulDagRunResponse.model_validate(dr_from_db, from_attributes=True)
        return PrevSuccessfulDagRunResponse()

    def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
        return timezone.coerce_datetime(_get_previous_dagrun_success().data_interval_start)

    def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
        return timezone.coerce_datetime(_get_previous_dagrun_success().data_interval_end)

    def get_prev_start_date_success() -> pendulum.DateTime | None:
        return timezone.coerce_datetime(_get_previous_dagrun_success().start_date)

    def get_prev_end_date_success() -> pendulum.DateTime | None:
        return timezone.coerce_datetime(_get_previous_dagrun_success().end_date)

    def get_triggering_events() -> dict[str, list[AssetEvent]]:
        # Group the consumed asset events by asset URI; events without an asset are skipped.
        asset_events = dag_run.consumed_asset_events
        triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
        for event in asset_events:
            if event.asset:
                triggering_events[event.asset.uri].append(event)

        return triggering_events

    # NOTE: If you add to this dict, make sure to also update the following:
    # * Context in task-sdk/src/airflow/sdk/definitions/context.py
    # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
    # * Table in docs/apache-airflow/templates-ref.rst
    #
    # The ``prev_*_success`` getters are called right here, while building the
    # dict; only ``triggering_asset_events`` is deferred behind a lazy proxy.
    context.update(
        {
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(task.inlets),
            "macros": macros,
            "params": validated_params,
            "prev_data_interval_start_success": get_prev_data_interval_start_success(),
            "prev_data_interval_end_success": get_prev_data_interval_end_success(),
            "prev_start_date_success": get_prev_start_date_success(),
            "prev_end_date_success": get_prev_end_date_success(),
            "test_mode": self.test_mode,
            # ti/task_instance are added here for ti.xcom_{push,pull}
            "task_instance": self,
            "ti": self,
            "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
    )

    try:
        expanded_ti_count: int | None = BaseOperator.get_mapped_ti_count(
            task, self.run_id, session=session
        )
        context["expanded_ti_count"] = expanded_ti_count
        if expanded_ti_count:
            # Cache, per upstream task id, which map indexes are relevant to this TI.
            setattr(
                self,
                "_upstream_map_indexes",
                {
                    upstream.task_id: self.get_relevant_upstream_map_indexes(
                        upstream,
                        expanded_ti_count,
                        session=session,
                    )
                    for upstream in task.upstream_list
                },
            )
    except NotMapped:
        # Not a mapped task: "expanded_ti_count" is not added to the context.
        pass

    return context
@provide_session
def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
    """
    Put rendered template field values on ``self.task`` for presentation in UI.

    Stored values are used when ``RenderedTaskInstanceFields`` has them for this
    task instance; otherwise the templates are rendered here and each template
    field is passed through ``redact`` before being set back on the task.

    :param session: SQLAlchemy ORM Session
    :raises AirflowException: if rendering fails with a Jinja
        ``TemplateAssertionError`` or ``UndefinedError``
    """
    from airflow.models.renderedtifields import RenderedTaskInstanceFields

    if TYPE_CHECKING:
        assert isinstance(self.task, BaseOperator)

    stored_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
    if stored_fields:
        self.task = self.task.unmap(None)
        for name, stored_value in stored_fields.items():
            setattr(self.task, name, stored_value)
        return

    try:
        # If we get here, either the task hasn't run or the RTIF record was purged.
        from airflow.sdk.execution_time.secrets_masker import redact

        self.render_templates()
        for name in self.task.template_fields:
            redacted_value = redact(getattr(self.task, name), name)
            setattr(self.task, name, redacted_value)
    except (TemplateAssertionError, UndefinedError) as e:
        raise AirflowException(
            "Webserver does not have access to User-defined Macros or Filters "
            "when Dag Serialization is enabled. Hence for the task that have not yet "
            "started running, please use 'airflow tasks render' for debugging the "
            "rendering of template_fields."
        ) from e
def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun):
    """
    Merge ``dag_run.conf`` into ``params`` in place.

    Nothing happens when ``dag_run`` is falsy or its ``conf`` is empty.

    :param params: task params, updated in place
    :param dag_run: DagRun whose ``conf`` takes precedence over ``params``
    """
    if not dag_run or not dag_run.conf:
        return
    self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
    params.update(dag_run.conf)
def render_templates(
    self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
) -> Operator:
    """
    Render the templated fields of the operator.

    When the task was originally mapped, ``self.task`` may end up pointing at
    the unmapped, fully rendered BaseOperator.

    :param context: template context; built with ``get_template_context()`` when falsy
    :param jinja_env: Jinja environment handed to ``render_template_fields``
    :return: the task ``self.task`` referred to before any replacement
    """
    from airflow.sdk.definitions.mappedoperator import MappedOperator

    context = context or self.get_template_context()
    task_before_render = self.task
    ti = context["ti"]

    if TYPE_CHECKING:
        assert task_before_render
        assert self.task
        assert ti.task

    # If self.task is mapped, this call replaces self.task to point to the
    # unmapped BaseOperator created by this function! This is because the
    # MappedOperator is useless for template rendering, and we need to be
    # able to access the unmapped task instead.
    task_before_render.render_template_fields(context, jinja_env)
    if isinstance(self.task, MappedOperator):
        self.task = context["ti"].task  # type: ignore[assignment]

    return task_before_render
def get_email_subject_content(
    self, exception: BaseException, task: BaseOperator | None = None
) -> tuple[str, str, str]:
    """
    Get the email subject content for exceptions.

    Delegates to the module-level ``_get_email_subject_content`` helper with this
    task instance.

    :param exception: the exception sent in the email
    :param task: operator forwarded to the helper; may be ``None``
    :return: a 3-tuple of strings; ``email_alert`` unpacks it as
        ``(subject, html_content, html_content_err)``
    """
    return _get_email_subject_content(task_instance=self, exception=exception, task=task)
def email_alert(self, exception, task: BaseOperator) -> None:
    """
    Send an alert email describing ``exception`` to ``task.email``.

    If sending the primary HTML body raises, the email is sent again with the
    alternative error body.

    :param exception: the exception
    :param task: task related to the exception
    """
    email_parts = self.get_email_subject_content(exception, task=task)
    subject, primary_body, fallback_body = email_parts
    if TYPE_CHECKING:
        assert task.email
    try:
        send_email(task.email, subject, primary_body)
    except Exception:
        send_email(task.email, subject, fallback_body)
def set_duration(self) -> None:
    """
    Store the run time in seconds on ``self.duration``.

    The duration is ``None`` unless both ``end_date`` and ``start_date`` are set.
    """
    elapsed_seconds = None
    if self.end_date and self.start_date:
        elapsed_seconds = (self.end_date - self.start_date).total_seconds()
    self.duration = elapsed_seconds
    log.debug("Task Duration set to %s", self.duration)
@provide_session
def xcom_push(
    self,
    key: str,
    value: Any,
    session: Session = NEW_SESSION,
) -> None:
    """
    Make an XCom available for tasks to pull.

    The value is stored through ``XComModel.set`` under this task instance's
    ``task_id``, ``dag_id``, ``run_id`` and ``map_index``.

    :param key: Key to store the value under.
    :param value: Value to store.
        NOTE(review): the previous wording suggested only JSON-serializable
        values may be used — confirm against ``XComModel.set``.
    :param session: SQLAlchemy ORM Session
    """
    XComModel.set(
        key=key,
        value=value,
        task_id=self.task_id,
        dag_id=self.dag_id,
        run_id=self.run_id,
        map_index=self.map_index,
        session=session,
    )
@provide_session
def xcom_pull(
    self,
    task_ids: str | Iterable[str] | None = None,
    dag_id: str | None = None,
    key: str = XCOM_RETURN_KEY,
    include_prior_dates: bool = False,
    session: Session = NEW_SESSION,
    *,
    map_indexes: int | Iterable[int] | None = None,
    default: Any = None,
    run_id: str | None = None,
) -> Any:
    """:meta private:"""  # noqa: D400
    # This is only kept for compatibility in tests for now while AIP-72 is in progress.

    # dag_id and run_id default to those of this task instance.
    if dag_id is None:
        dag_id = self.dag_id
    if run_id is None:
        run_id = self.run_id

    query = XComModel.get_many(
        key=key,
        run_id=run_id,
        dag_ids=dag_id,
        task_ids=task_ids,
        map_indexes=map_indexes,
        include_prior_dates=include_prior_dates,
        session=session,
    )

    # NOTE: Since we're only fetching the value field and not the whole
    # class, the @recreate annotation does not kick in. Therefore we need to
    # call XCom.deserialize_value() manually.

    # We are only pulling one single task.
    if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable):
        first = query.with_entities(
            XComModel.run_id, XComModel.task_id, XComModel.dag_id, XComModel.map_index, XComModel.value
        ).first()
        if first is None:  # No matching XCom at all.
            return default
        if map_indexes is not None or first.map_index < 0:
            return XComModel.deserialize_value(first)

        # Otherwise (no map_indexes given and the first row has map_index >= 0)
        # execution falls through to the multi-value path below.
        # raise RuntimeError("Nothing should hit this anymore")

    # TODO: TaskSDK: We should remove this, but many tests still currently call `ti.run()`. See #45549

    # At this point either task_ids or map_indexes is explicitly multi-value.
    # Order return values to match task_ids and map_indexes ordering.
    ordering = []
    if task_ids is None or isinstance(task_ids, str):
        ordering.append(XComModel.task_id)
    elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}:
        # Map each task id to its position so rows sort in the caller's order.
        ordering.append(case(task_id_whens, value=XComModel.task_id))
    else:
        # Empty iterable of task ids: fall back to ordering by the column itself.
        ordering.append(XComModel.task_id)
    if map_indexes is None or isinstance(map_indexes, int):
        ordering.append(XComModel.map_index)
    elif isinstance(map_indexes, range):
        # A range only needs ascending or descending order, depending on its step.
        order = XComModel.map_index
        if map_indexes.step < 0:
            order = order.desc()
        ordering.append(order)
    elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}:
        ordering.append(case(map_index_whens, value=XComModel.map_index))
    else:
        ordering.append(XComModel.map_index)
    return LazyXComSelectSequence.from_select(
        query.with_entities(XComModel.value).order_by(None).statement,
        order_by=ordering,
        session=session,
    )


@provide_session
def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
    """
    Count the RUNNING task instances that share this TI's ``dag_id`` and ``task_id``.

    :param session: SQLAlchemy ORM Session
    :param same_dagrun: if True, only count instances with this TI's ``run_id``
    :return: number of matching task instances in the DB
    """
    criteria = [
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id == self.task_id,
        TaskInstance.state == TaskInstanceState.RUNNING,
    ]
    if same_dagrun:
        criteria.append(TaskInstance.run_id == self.run_id)
    # .count() is inefficient
    return session.query(func.count()).filter(*criteria).scalar()
@staticmethod
def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
    """
    Return SQLAlchemy filter to query selected task instances.

    :param tis: task instances or task instance keys to select; any iterable,
        it is materialised into a list first
    :return: a filter matching exactly the given (dag_id, run_id, task_id,
        map_index) combinations, or ``None`` when ``tis`` is empty
    """
    # DictKeys type, (what we often pass here from the scheduler) is not directly indexable :(
    # Or it might be a generator, but we need to be able to iterate over it more than once
    tis = list(tis)

    if not tis:
        return None

    first = tis[0]

    dag_id = first.dag_id
    run_id = first.run_id
    map_index = first.map_index
    first_task_id = first.task_id

    # pre-compute the set of dag_id, run_id, map_indices and task_ids
    dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set()
    for t in tis:
        dag_ids.add(t.dag_id)
        run_ids.add(t.run_id)
        map_indices.add(t.map_index)
        task_ids.add(t.task_id)

    # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
    # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
    if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}:
        return and_(
            TaskInstance.dag_id == dag_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == map_index,
            TaskInstance.task_id.in_(task_ids),
        )
    if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}:
        return and_(
            TaskInstance.dag_id == dag_id,
            TaskInstance.run_id.in_(run_ids),
            TaskInstance.map_index == map_index,
            TaskInstance.task_id == first_task_id,
        )
    if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}:
        return and_(
            TaskInstance.dag_id == dag_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index.in_(map_indices),
            TaskInstance.task_id == first_task_id,
        )

    filter_condition = []
    # create 2 nested groups, both primarily grouped by dag_id and run_id,
    # and in the nested group 1 grouped by task_id the other by map_index.
    task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
    map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
    for t in tis:
        task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index)
        map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id)

    # this assumes that most dags have dag_id as the largest grouping, followed by run_id. even
    # if its not, this is still a significant optimization over querying for every single tuple key
    #
    # Only (dag_id, run_id) pairs that actually occur are visited; both group
    # dicts are filled in the same loop above and therefore share their keys.
    for (cur_dag_id, cur_run_id), dag_task_id_groups in task_id_groups.items():
        # we compare the group size between task_id and map_index and use the smaller group
        dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]

        if len(dag_task_id_groups) <= len(dag_map_index_groups):
            for cur_task_id, cur_map_indices in dag_task_id_groups.items():
                filter_condition.append(
                    and_(
                        TaskInstance.dag_id == cur_dag_id,
                        TaskInstance.run_id == cur_run_id,
                        TaskInstance.task_id == cur_task_id,
                        TaskInstance.map_index.in_(cur_map_indices),
                    )
                )
        else:
            for cur_map_index, cur_task_ids in dag_map_index_groups.items():
                filter_condition.append(
                    and_(
                        TaskInstance.dag_id == cur_dag_id,
                        TaskInstance.run_id == cur_run_id,
                        TaskInstance.task_id.in_(cur_task_ids),
                        TaskInstance.map_index == cur_map_index,
                    )
                )

    return or_(*filter_condition)
@classmethod
def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:
    """
    Build an SQLAlchemy filter for a list of task_ids or tuples of (task_id,map_index).

    String items select on ``task_id`` alone, every other item is matched as
    a ``(task_id, map_index)`` pair. An empty input yields ``false()``.

    :meta private:
    """
    # Split the input: each item is either a task_id, or (task_id, map_index)
    plain_task_ids = []
    task_id_map_index_pairs = []
    for val in vals:
        if isinstance(val, str):
            plain_task_ids.append(val)
        else:
            task_id_map_index_pairs.append(val)

    clauses: list[ColumnOperators] = []
    if plain_task_ids:
        clauses.append(cls.task_id.in_(plain_task_ids))
    if task_id_map_index_pairs:
        clauses.append(tuple_(cls.task_id, cls.map_index).in_(task_id_map_index_pairs))

    if len(clauses) > 1:
        return or_(*clauses)
    if clauses:
        return clauses[0]
    return false()
def get_relevant_upstream_map_indexes(
    self,
    upstream: Operator,
    ti_count: int | None,
    *,
    session: Session,
) -> int | range | None:
    """
    Infer the map indexes of an upstream "relevant" to this ti.

    The bulk of the logic mainly exists to solve the problem described by
    the following example, where 'val' must resolve to different values,
    depending on where the reference is being used::

        @task
        def this_task(v):  # This is self.task.
            return v * 2


        @task_group
        def tg1(inp):
            val = upstream(inp)  # This is the upstream task.
            this_task(val)  # When inp is 1, val here should resolve to 2.
            return val


        # This val is the same object returned by tg1.
        val = tg1.expand(inp=[1, 2, 3])


        @task_group
        def tg2(inp):
            another_task(inp, val)  # val here should resolve to [2, 4, 6].


        tg2.expand(inp=["a", "b"])

    The surrounding mapped task groups of ``upstream`` and ``self.task`` are
    inspected to find a common "ancestor". If such an ancestor is found,
    we need to return specific map indexes to pull a partial value from
    upstream XCom.

    :param upstream: The referenced upstream task.
    :param ti_count: The total count of task instance this task was expanded
        by the scheduler, i.e. ``expanded_ti_count`` in the template context.
    :param session: SQLAlchemy ORM Session
    :return: Specific map index or map indexes to pull, or ``None`` if we
        want to "whole" return value (i.e. no mapped task groups involved).
    """
    from airflow.models.baseoperator import BaseOperator

    if TYPE_CHECKING:
        assert self.task

    # This value should never be None since we already know the current task
    # is in a mapped task group, and should have been expanded, despite that,
    # we need to check that it is not None to satisfy Mypy.
    # But this value can be 0 when we expand an empty list, for that it is
    # necessary to check that ti_count is not 0 to avoid dividing by 0.
    if not ti_count:
        return None

    # Find the innermost common mapped task group between the current task
    # and the referenced upstream task.
    # If the current task and the referenced task does not have a common
    # mapped task group, the two are in different task mapping contexts
    # (like another_task above), and we should use the "whole" value.
    common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream)
    if common_ancestor is None:
        return None

    # At this point we know the two tasks share a mapped task group, and we
    # should use a "partial" value. Let's break down the mapped ti count
    # between the ancestor and further expansion happened inside it.
    ancestor_ti_count = BaseOperator.get_mapped_ti_count(common_ancestor, self.run_id, session=session)
    ancestor_map_index = self.map_index * ancestor_ti_count // ti_count

    # If the task is NOT further expanded inside the common ancestor, we
    # only want to reference one single ti. We must walk the actual DAG,
    # and "ti_count == ancestor_ti_count" does not work, since the further
    # expansion may be of length 1.
    if not _is_further_mapped_inside(upstream, common_ancestor):
        return ancestor_map_index

    # Otherwise we need a partial aggregation for values from selected task
    # instances in the ancestor's expansion context.
    further_count = ti_count // ancestor_ti_count
    map_index_start = ancestor_map_index * further_count
    return range(map_index_start, map_index_start + further_count)
defclear_db_references(self,session:Session):""" Clear db tables that have a reference to this instance. :param session: ORM Session :meta private: """fromairflow.models.renderedtifieldsimportRenderedTaskInstanceFieldstables:list[type[TaskInstanceDependencies]]=[XComModel,RenderedTaskInstanceFields,TaskMap,]tables_by_id:list[type[Base]]=[TaskInstanceNote,TaskReschedule]fortableintables:session.execute(delete(table).where(table.dag_id==self.dag_id,table.task_id==self.task_id,table.run_id==self.run_id,table.map_index==self.map_index,))fortableintables_by_id:session.execute(delete(table).where(table.ti_id==self.id))@classmethod
def duration_expression_update(cls, end_date: datetime, query: Update, bind: Engine | SAConnection) -> Update:
    """
    Add ``end_date`` and a dialect-specific ``duration`` expression to an UPDATE.

    :param end_date: value written to ``end_date`` and used as the end of the interval
    :param query: the UPDATE statement to extend
    :param bind: engine or connection whose dialect name selects the expression
    :return: ``query`` with the ``end_date`` and ``duration`` values applied
    """
    # TODO: Compare it with self._set_duration method
    dialect_name = bind.dialect.name

    if dialect_name == "sqlite":
        # Difference of the "%s" values plus the rounded difference of the "%f" values.
        whole_part = func.strftime("%s", end_date) - func.strftime("%s", cls.start_date)
        fractional_part = func.round(
            (func.strftime("%f", end_date) - func.strftime("%f", cls.start_date)), 3
        )
        duration = whole_part + fractional_part
    elif dialect_name == "postgresql":
        duration = extract("EPOCH", end_date - cls.start_date)
    else:
        duration = (
            func.timestampdiff(text("MICROSECOND"), cls.start_date, end_date)
            # Turn microseconds into floating point seconds.
            / 1_000_000
        )

    return query.values({"end_date": end_date, "duration": duration})
def get_first_reschedule_date(self, context: Context) -> datetime | None:
    """
    Return the ``start_date`` of the earliest reschedule recorded for this TI.

    "Earliest" means the ``TaskReschedule`` row with the lowest ``id`` among
    those whose ``ti_id`` matches this task instance.

    :param context: not used by this method
    :return: the stored start date, or ``None`` when no reschedule row exists
    """
    if TYPE_CHECKING:
        assert isinstance(self.task, BaseOperator)

    reschedules_of_this_ti = select(TaskReschedule).where(
        TaskReschedule.ti_id == str(self.id),
    )
    oldest_first = reschedules_of_this_ti.order_by(TaskReschedule.id.asc())
    first_start_date_stmt = oldest_first.with_only_columns(TaskReschedule.start_date).limit(1)

    with create_session() as session:
        return session.scalar(first_start_date_stmt)
def_find_common_ancestor_mapped_group(node1:Operator,node2:Operator)->MappedTaskGroup|None:"""Given two operators, find their innermost common mapped task group."""ifnode1.dagisNoneornode2.dagisNoneornode1.dag_id!=node2.dag_id:returnNoneparent_group_ids={g.group_idforginnode1.iter_mapped_task_groups()}common_groups=(gforginnode2.iter_mapped_task_groups()ifg.group_idinparent_group_ids)returnnext(common_groups,None)def_is_further_mapped_inside(operator:Operator,container:TaskGroup)->bool:"""Whether given operator is *further* mapped inside a task group."""fromairflow.sdk.definitions.mappedoperatorimportMappedOperatorfromairflow.sdk.definitions.taskgroupimportMappedTaskGroupifisinstance(operator,MappedOperator):returnTruetask_group=operator.task_groupwhiletask_groupisnotNoneandtask_group.group_id!=container.group_id:ifisinstance(task_group,MappedTaskGroup):returnTruetask_group=task_group.parent_groupreturnFalse# State of the task instance.# Stores string version of the task state.
[docs]classSimpleTaskInstance:""" Simplified Task Instance. Used to send data between processes via Queues. """def__init__(self,dag_id:str,task_id:str,run_id:str,queued_dttm:datetime|None,start_date:datetime|None,end_date:datetime|None,try_number:int,map_index:int,state:str,executor:str|None,executor_config:Any,pool:str,queue:str,key:TaskInstanceKey,run_as_user:str|None=None,priority_weight:int|None=None,parent_context_carrier:dict|None=None,context_carrier:dict|None=None,span_status:str|None=None,):
def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
    """
    Build an instance of ``cls`` from the fields of a ``TaskInstance``.

    ``run_as_user``, ``priority_weight`` and ``context_carrier`` fall back to
    ``None`` when ``ti`` lacks the attribute. ``parent_context_carrier`` is
    only read from ``ti.dag_run`` when that relationship is already loaded.
    """
    # Inspect the ti, to check if the 'dag_run' relationship is loaded.
    dag_run_is_loaded = "dag_run" not in inspect(ti).unloaded
    parent_carrier = ti.dag_run.context_carrier if dag_run_is_loaded else None

    return cls(
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
        queued_dttm=ti.queued_dttm,
        start_date=ti.start_date,
        end_date=ti.end_date,
        try_number=ti.try_number,
        state=ti.state,
        executor=ti.executor,
        executor_config=ti.executor_config,
        pool=ti.pool,
        queue=ti.queue,
        key=ti.key,
        run_as_user=getattr(ti, "run_as_user", None),
        priority_weight=getattr(ti, "priority_weight", None),
        parent_context_carrier=parent_carrier,
        context_carrier=getattr(ti, "context_carrier", None),
        span_status=ti.span_status,
    )
class TaskInstanceNote(Base):
    """For storage of arbitrary notes concerning the task instance."""

    def __repr__(self):
        """
        Return a debug representation naming the task instance this note belongs to.

        The form is ``<ClassName: dag_id.task_id run_id[ map_index=N] TI ID: ti_id>``;
        the ``map_index`` part is only present when the map index is not ``-1``.
        """
        ti = self.task_instance
        # task_id and run_id are separated by a space so they stay distinguishable.
        prefix = f"<{self.__class__.__name__}: {ti.dag_id}.{ti.task_id} {ti.run_id}"
        if ti.map_index != -1:
            prefix += f" map_index={ti.map_index}"
        return prefix + f" TI ID: {self.ti_id}>"
# The flag is assigned ``True`` literally, then immediately overwritten through
# ``globals()``: the reversed, upper-cased string below evaluates to
# "STATICA_HACK", so at runtime the flag is ``False`` and the guarded block does
# not execute.
# NOTE(review): presumably this exists so static analysis tools, which only see
# the literal ``True``, pick up the ``queued_by_job`` relationship — confirm.
STATICA_HACK = True
globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK:  # pragma: no cover
    from airflow.jobs.job import Job

    TaskInstance.queued_by_job = relationship(Job)