Source code for airflow.providers.celery.executors.celery_executor_utils
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Utilities and classes used by the Celery Executor.

Much of this code is expensive to import/load, be careful where this module is imported.
"""

from __future__ import annotations

import contextlib
import logging
import math
import os
import subprocess
import traceback
import warnings
from collections.abc import Mapping, MutableMapping
from concurrent.futures import ProcessPoolExecutor
from typing import TYPE_CHECKING, Any, Optional, Union

from celery import Celery, Task, states as celery_states
from celery.backends.base import BaseKeyValueStoreBackend
from celery.backends.database import DatabaseBackend, Task as TaskDb, retry, session_cleanup
from celery.signals import import_modules as celery_import_modules
from setproctitle import setproctitle
from sqlalchemy import select

import airflow.settings as settings
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.timeout import timeout

try:
    from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
except ImportError:
    from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
if TYPE_CHECKING:
    from celery.result import AsyncResult

    from airflow.executors import workloads
    from airflow.executors.base_executor import CommandType, EventBufferValueType
    from airflow.models.taskinstance import TaskInstanceKey
    from airflow.typing_compat import TypeAlias

    # We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so unfortunately we just have to define
    # the type as the union of both kinds
    TaskInstanceInCelery: TypeAlias = tuple[TaskInstanceKey, Union[workloads.All, CommandType], Optional[str], Task]

log = logging.getLogger(__name__)

OPERATION_TIMEOUT = conf.getfloat("celery", "operation_timeout")

celery_configuration = None
@providers_configuration_loaded
def _get_celery_app() -> Celery:
    """Init providers before importing the configuration, so the _SECRET and _CMD options work."""
    global celery_configuration

    if conf.has_option("celery", "celery_config_options"):
        celery_configuration = conf.getimport("celery", "celery_config_options")
    else:
        from airflow.providers.celery.executors.default_celery import DEFAULT_CELERY_CONFIG

        celery_configuration = DEFAULT_CELERY_CONFIG

    celery_app_name = conf.get("celery", "CELERY_APP_NAME")
    if celery_app_name == "airflow.executors.celery_executor":
        warnings.warn(
            "The celery.CELERY_APP_NAME configuration uses deprecated package name: "
            "'airflow.executors.celery_executor'. "
            "Change it to `airflow.providers.celery.executors.celery_executor`, and "
            "update the `-app` flag in your Celery Health Checks "
            "to use `airflow.providers.celery.executors.celery_executor.app`.",
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )

    return Celery(celery_app_name, config_source=celery_configuration)


app = _get_celery_app()
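
# Illustration (not part of the original module): `conf.getimport` above lets a
# deployment point `[celery] celery_config_options` at any importable dict. A
# minimal sketch of such a module -- the module path and the override value are
# hypothetical:
#
#     # my_company/celery_config.py
#     from airflow.providers.celery.executors.default_celery import DEFAULT_CELERY_CONFIG
#
#     CELERY_CONFIG = {
#         **DEFAULT_CELERY_CONFIG,
#         "worker_concurrency": 8,  # hypothetical override
#     }
#
# referenced from airflow.cfg as:
#
#     [celery]
#     celery_config_options = my_company.celery_config.CELERY_CONFIG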
@celery_import_modules.connect
def on_celery_import_modules(*args, **kwargs):
    """
    Preload some "expensive" airflow modules once, so other task processes won't have to import them again.

    Loading these for each task adds 0.3-0.5s *per task* before the task can run. For long running tasks
    this doesn't matter, but for short tasks this starts to be a noticeable impact.
    """
    import jinja2.ext  # noqa: F401

    if not AIRFLOW_V_3_0_PLUS:
        import airflow.jobs.local_task_job_runner
    import airflow.macros

    try:
        import airflow.providers.standard.operators.bash
        import airflow.providers.standard.operators.python
    except ImportError:
        import airflow.operators.bash
        import airflow.operators.python  # noqa: F401

    with contextlib.suppress(ImportError):
        import numpy  # noqa: F401

    with contextlib.suppress(ImportError):
        import kubernetes.client  # noqa: F401
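
# Illustration (not part of the original module): the `import_modules` signal
# fires once per worker process at startup, so anything imported in a connected
# handler is paid for once rather than on every task. A minimal sketch of
# registering an extra preload hook (handler name and module are hypothetical):
#
#     from celery.signals import import_modules
#
#     @import_modules.connect
#     def preload_pandas(*args, **kwargs):
#         with contextlib.suppress(ImportError):
#             import pandas  # noqa: F401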
# Once Celery 5.5 is out of beta, we can pass `pydantic=True` to the decorator and it will handle the
# validation and deserialization for us
@app.task(name="execute_workload")
def execute_workload(input: str) -> None:
    from pydantic import TypeAdapter

    from airflow.configuration import conf
    from airflow.executors import workloads
    from airflow.sdk.execution_time.supervisor import supervise

    decoder = TypeAdapter[workloads.All](workloads.All)
    workload = decoder.validate_json(input)

    celery_task_id = app.current_task.request.id

    if not isinstance(workload, workloads.ExecuteTask):
        raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}")

    log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload)

    supervise(
        # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
        ti=workload.ti,  # type: ignore[arg-type]
        dag_rel_path=workload.dag_rel_path,
        bundle_info=workload.bundle_info,
        token=workload.token,
        server=conf.get("core", "execution_api_server_url"),
        log_path=workload.log_path,
    )
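
# Illustration (not part of the original module): the payload handed to this
# task is produced by `send_task_to_executor` below via `model_dump_json`, and
# `TypeAdapter.validate_json` reverses it. A minimal sketch of that round trip,
# assuming an existing `workloads.ExecuteTask` instance named `wl`:
#
#     from pydantic import TypeAdapter
#     from airflow.executors import workloads
#
#     payload = wl.model_dump_json(exclude={"ti": {"executor_config"}})
#     decoded = TypeAdapter[workloads.All](workloads.All).validate_json(payload)
#     assert isinstance(decoded, workloads.ExecuteTask)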
if not AIRFLOW_V_3_0_PLUS:

    @app.task
    def execute_command(command_to_exec: CommandType) -> None:
        """Execute command."""
        dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec)
        celery_task_id = app.current_task.request.id
        log.info("[%s] Executing command in Celery: %s", celery_task_id, command_to_exec)
        with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id):
            try:
                if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
                    _execute_in_subprocess(command_to_exec, celery_task_id)
                else:
                    _execute_in_fork(command_to_exec, celery_task_id)
            except Exception:
                Stats.incr("celery.execute_command.failure")
                raise
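
# Illustration (not part of the original module): which branch runs above is
# controlled by `[core] execute_tasks_new_python_interpreter` in airflow.cfg.
# Forking (the default) is faster; a fresh interpreter gives a clean process
# environment at the cost of start-up time:
#
#     [core]
#     execute_tasks_new_python_interpreter = True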
def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None = None) -> None:
    pid = os.fork()
    if pid:
        # In parent, wait for the child
        pid, ret = os.waitpid(pid, 0)
        if ret == 0:
            return

        msg = (
            f"Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id} "
            f"(PID: {pid}, Return Code: {ret})"
        )
        raise AirflowException(msg)

    from airflow.sentry import Sentry

    ret = 1
    try:
        from airflow.cli.cli_parser import get_parser

        parser = get_parser()
        # [1:] - remove "airflow" from the start of the command
        args = parser.parse_args(command_to_exec[1:])
        args.shut_down_logging = False
        if celery_task_id:
            args.external_executor_id = celery_task_id

        setproctitle(f"airflow task supervisor: {command_to_exec}")
        log.debug("calling func '%s' with args %s", args.func.__name__, args)
        args.func(args)
        ret = 0
    except Exception:
        log.exception("[%s] Failed to execute task.", celery_task_id)
        ret = 1
    finally:
        try:
            Sentry.flush()
            logging.shutdown()
        except Exception:
            log.exception("[%s] Failed to clean up.", celery_task_id)
            ret = 1
        os._exit(ret)


def _execute_in_subprocess(command_to_exec: CommandType, celery_task_id: str | None = None) -> None:
    env = os.environ.copy()
    if celery_task_id:
        env["external_executor_id"] = celery_task_id
    try:
        subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env)
    except subprocess.CalledProcessError as e:
        log.exception("[%s] execute_command encountered a CalledProcessError", celery_task_id)
        log.error(e.output)
        msg = f"Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}"
        raise AirflowException(msg)
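
# Illustration (not part of the original module): a stripped-down sketch of the
# fork/waitpid pattern `_execute_in_fork` relies on. The child must terminate
# with os._exit() so it never returns into the parent's call stack:
#
#     def _run_in_fork(fn) -> int:  # hypothetical helper
#         pid = os.fork()
#         if pid:  # parent: block until the child exits, then report its status
#             _, status = os.waitpid(pid, 0)
#             return status
#         try:
#             fn()
#             os._exit(0)
#         except Exception:
#             os._exit(1)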
class ExceptionWithTraceback:
    """
    Wrapper class used to propagate exceptions to parent processes from subprocesses.

    :param exception: The exception to wrap
    :param exception_traceback: The stacktrace to wrap
    """

    def __init__(self, exception: BaseException, exception_traceback: str):
        self.exception = exception
        self.traceback = exception_traceback
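
# Illustration (not part of the original module): traceback objects do not
# pickle across process boundaries, so the wrapper above carries the exception
# plus a pre-formatted traceback string as plain attributes. A hypothetical
# consumer on the parent side:
#
#     def _handle_send_result(result):  # hypothetical helper
#         if isinstance(result, ExceptionWithTraceback):
#             log.error("%s\n%s", result.exception, result.traceback)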
def send_task_to_executor(
    task_tuple: TaskInstanceInCelery,
) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]:
    """Send task to executor."""
    key, args, queue, task_to_run = task_tuple

    if AIRFLOW_V_3_0_PLUS:
        if TYPE_CHECKING:
            assert isinstance(args, workloads.BaseWorkload)
        args = (args.model_dump_json(exclude={"ti": {"executor_config"}}),)

    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            result = task_to_run.apply_async(args=args, queue=queue)
    except (Exception, AirflowTaskTimeout) as e:
        exception_traceback = f"Celery Task ID: {key}\n{traceback.format_exc()}"
        result = ExceptionWithTraceback(e, exception_traceback)

    # The type is right for the version, but the type cannot be defined correctly for Airflow 2 and 3
    # concurrently
    return key, args, result  # type: ignore[return-value]
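
# Illustration (not part of the original module): `timeout(seconds=...)` above
# is Airflow's signal-based context manager; it raises AirflowTaskTimeout if
# the broker publish hangs. A minimal standalone sketch of the same guard (the
# slow call is hypothetical):
#
#     try:
#         with timeout(seconds=OPERATION_TIMEOUT):
#             result = slow_broker_publish()  # hypothetical call that may hang
#     except (Exception, AirflowTaskTimeout):
#         result = None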
def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]:
    """
    Fetch and return the state of the given celery task.

    The scope of this function is global so that it can be called by subprocesses in the pool.

    :param async_result: a tuple of the Celery task key and the async Celery object used
        to fetch the task's state
    :return: a tuple of the Celery task key and the Celery state and the celery info
        of the task
    """
    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task
            info = async_result.info if hasattr(async_result, "info") else None
            return async_result.task_id, async_result.state, info
    except Exception as e:
        exception_traceback = f"Celery Task ID: {async_result}\n{traceback.format_exc()}"
        return async_result.task_id, ExceptionWithTraceback(e, exception_traceback), None
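
# Illustration (not part of the original module): keeping this function at
# module scope means it pickles by reference, so a ProcessPoolExecutor (already
# imported above) can fan state fetches out across worker processes. A minimal
# sketch, assuming `async_results` is a list of AsyncResult objects:
#
#     with ProcessPoolExecutor(max_workers=4) as pool:
#         states = list(pool.map(fetch_celery_task_state, async_results))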
class BulkStateFetcher(LoggingMixin):
    """
    Gets status for many Celery tasks using the best method available.

    If BaseKeyValueStoreBackend is used as result backend, the mget method is used.
    If DatabaseBackend is used as result backend, the SELECT ... WHERE task_id IN (...) query is used.
    Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
    """

    def __init__(self, sync_parallelism=None):
        super().__init__()
        self._sync_parallelism = sync_parallelism

    def _tasks_list_to_task_ids(self, async_tasks) -> set[str]:
        return {a.task_id for a in async_tasks}
    def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
        """Get status for many Celery tasks using the best method available."""
        if isinstance(app.backend, BaseKeyValueStoreBackend):
            result = self._get_many_from_kv_backend(async_results)
        elif isinstance(app.backend, DatabaseBackend):
            result = self._get_many_from_db_backend(async_results)
        else:
            result = self._get_many_using_multiprocessing(async_results)
        self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results))
        return result
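
# Illustration (not part of the original module): a minimal usage sketch of the
# fetcher, with hypothetical Celery task ids. Each value in the returned
# mapping is a (state, info) pair keyed by task id:
#
#     from celery.result import AsyncResult
#
#     fetcher = BulkStateFetcher(sync_parallelism=2)
#     states = fetcher.get_many([AsyncResult("id-1"), AsyncResult("id-2")])
#     # e.g. {"id-1": ("SUCCESS", None), "id-2": ("PENDING", None)}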