#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""CeleryExecutor

.. seealso::
    For more information on how the CeleryExecutor works, take a look at the guide:
    :ref:`executor:CeleryExecutor`
"""
import datetime
import logging
import math
import operator
import os
import subprocess
import time
import traceback
from collections import OrderedDict
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union

from celery import Celery, Task, states as celery_states
from celery.backends.base import BaseKeyValueStoreBackend
from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cleanup
from celery.result import AsyncResult
from celery.signals import import_modules as celery_import_modules
from setproctitle import setproctitle

import airflow.settings as settings
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.timezone import utcnow

# Module-level setup referenced throughout this module: the logger, the error-message
# header used when publishing fails, the per-operation timeout, and the Celery app that
# execute_command() is registered on.
log = logging.getLogger(__name__)

CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'

OPERATION_TIMEOUT = conf.getfloat('celery', 'operation_timeout', fallback=1.0)

if conf.has_option('celery', 'celery_config_options'):
    celery_configuration = conf.getimport('celery', 'celery_config_options')
else:
    celery_configuration = DEFAULT_CELERY_CONFIG

app = Celery(conf.get('celery', 'CELERY_APP_NAME'), config_source=celery_configuration)
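# Illustrative only: this executor is enabled through Airflow configuration, not in code.
# A minimal airflow.cfg excerpt might look like the following; the broker and result
# backend URLs are deployment-specific placeholders, not defaults defined here.
#
#   [core]
#   executor = CeleryExecutor
#
#   [celery]
#   broker_url = redis://localhost:6379/0
#   result_backend = db+postgresql://airflow:airflow@localhost/airflow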
@app.task
def execute_command(command_to_exec: CommandType) -> None:
    """Executes command."""
    BaseExecutor.validate_command(command_to_exec)
    log.info("Executing command in Celery: %s", command_to_exec)
    celery_task_id = app.current_task.request.id
    log.info(f"Celery task ID: {celery_task_id}")

    if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
        _execute_in_subprocess(command_to_exec, celery_task_id)
    else:
        _execute_in_fork(command_to_exec, celery_task_id)
def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None:
    pid = os.fork()
    if pid:
        # In parent, wait for the child
        pid, ret = os.waitpid(pid, 0)
        if ret == 0:
            return

        raise AirflowException('Celery command failed on host: ' + get_hostname())

    from airflow.sentry import Sentry

    ret = 1
    try:
        from airflow.cli.cli_parser import get_parser

        settings.engine.pool.dispose()
        settings.engine.dispose()

        parser = get_parser()
        # [1:] - remove "airflow" from the start of the command
        args = parser.parse_args(command_to_exec[1:])
        args.shut_down_logging = False
        if celery_task_id:
            args.external_executor_id = celery_task_id

        setproctitle(f"airflow task supervisor: {command_to_exec}")

        args.func(args)
        ret = 0
    except Exception as e:
        log.exception("Failed to execute task %s.", str(e))
        ret = 1
    finally:
        Sentry.flush()
        logging.shutdown()
        os._exit(ret)


def _execute_in_subprocess(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None:
    env = os.environ.copy()
    if celery_task_id:
        env["external_executor_id"] = celery_task_id
    try:
        subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env)
    except subprocess.CalledProcessError as e:
        log.exception('execute_command encountered a CalledProcessError')
        log.error(e.output)
        msg = 'Celery command failed on host: ' + get_hostname()
        raise AirflowException(msg)
class ExceptionWithTraceback:
    """
    Wrapper class used to propagate exceptions to parent processes from subprocesses.

    :param exception: The exception to wrap
    :type exception: Exception
    :param exception_traceback: The stacktrace to wrap
    :type exception_traceback: str
    """

    def __init__(self, exception: Exception, exception_traceback: str):
        self.exception = exception
        self.traceback = exception_traceback
# Task instance that is sent over Celery queues
# TaskInstanceKey, Command, queue_name, CallableTask
TaskInstanceInCelery = Tuple[TaskInstanceKey, CommandType, Optional[str], Task]
def send_task_to_executor(
    task_tuple: TaskInstanceInCelery,
) -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]:
    """Sends task to executor."""
    key, command, queue, task_to_run = task_tuple
    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            result = task_to_run.apply_async(args=[command], queue=queue)
    except Exception as e:
        exception_traceback = f"Celery Task ID: {key}\n{traceback.format_exc()}"
        result = ExceptionWithTraceback(e, exception_traceback)

    return key, command, result
@celery_import_modules.connect
def on_celery_import_modules(*args, **kwargs):
    """
    Preload some "expensive" airflow modules so that every task process doesn't have to import it again
    and again.

    Loading these for each task adds 0.3-0.5s *per task* before the task can run. For long running tasks
    this doesn't matter, but for short tasks this starts to be a noticeable impact.
    """
    import jinja2.ext  # noqa: F401

    import airflow.jobs.local_task_job
    import airflow.macros
    import airflow.operators.bash
    import airflow.operators.python
    import airflow.operators.subdag  # noqa: F401

    try:
        import numpy  # noqa: F401
    except ImportError:
        pass

    try:
        import kubernetes.client  # noqa: F401
    except ImportError:
        pass
class CeleryExecutor(BaseExecutor):
    """
    CeleryExecutor is recommended for production use of Airflow. It allows
    distributing the execution of task instances to multiple worker nodes.

    Celery is a simple, flexible and reliable distributed system to process
    vast amounts of messages, while providing operations with the tools
    required to maintain such a system.
    """

    def __init__(self):
        super().__init__()

        # Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters)
        # so we use a multiprocessing pool to speed this up.
        # How many worker processes are created for checking celery task state.
        self._sync_parallelism = conf.getint('celery', 'SYNC_PARALLELISM')
        if self._sync_parallelism == 0:
            self._sync_parallelism = max(1, cpu_count() - 1)
        self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
        self.tasks = {}
        # Mapping of tasks we've adopted, ordered by the earliest date they timeout
        self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = OrderedDict()
        self.task_adoption_timeout = datetime.timedelta(
            seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
        )
        self.task_publish_retries: Dict[TaskInstanceKey, int] = OrderedDict()
        self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3)
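    # Illustrative airflow.cfg excerpt for the [celery] options read in __init__ above.
    # Values shown are examples only; sync_parallelism = 0 means "use cpu_count() - 1".
    #
    #   [celery]
    #   sync_parallelism = 16            # processes used to send tasks and fetch task state in bulk
    #   task_adoption_timeout = 600      # seconds an adopted task may stay pending before it is failed
    #   task_publish_max_retries = 3     # retries when publishing a task to the broker times out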
    def start(self) -> None:
        self.log.debug('Starting Celery Executor using %s processes for syncing', self._sync_parallelism)
    def _num_tasks_per_send_process(self, to_send_count: int) -> int:
        """
        How many Celery tasks should each worker process send.

        :return: Number of tasks that should be sent per process
        :rtype: int
        """
        return max(1, int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))
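    # Worked example: with 10 task tuples to send and _sync_parallelism = 4, each sending
    # process is handed a chunk of ceil(10 / 4) = 3 tuples, so three processes send three
    # tasks each and a fourth sends the remaining one.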
    def trigger_tasks(self, open_slots: int) -> None:
        """
        Overrides the trigger_tasks method of BaseExecutor.

        :param open_slots: Number of open slots
        """
        sorted_queue = self.order_queued_tasks_by_priority()

        task_tuples_to_send: List[TaskInstanceInCelery] = []

        for _ in range(min(open_slots, len(self.queued_tasks))):
            key, (command, _, queue, _) = sorted_queue.pop(0)
            task_tuple = (key, command, queue, execute_command)
            task_tuples_to_send.append(task_tuple)
            if key not in self.task_publish_retries:
                self.task_publish_retries[key] = 1

        if task_tuples_to_send:
            self._process_tasks(task_tuples_to_send)
    def _process_tasks(self, task_tuples_to_send: List[TaskInstanceInCelery]) -> None:
        first_task = next(t[3] for t in task_tuples_to_send)

        # Celery state queries will get stuck if we do not use the same backend
        # for all tasks.
        cached_celery_backend = first_task.backend

        key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
        self.log.debug('Sent all tasks.')

        for key, _, result in key_and_async_results:
            if isinstance(result, ExceptionWithTraceback) and isinstance(
                result.exception, AirflowTaskTimeout
            ):
                if key in self.task_publish_retries and (
                    self.task_publish_retries.get(key) <= self.task_publish_max_retries
                ):
                    Stats.incr("celery.task_timeout_error")
                    self.log.info(
                        "[Try %s of %s] Task Timeout Error for Task: (%s).",
                        self.task_publish_retries[key],
                        self.task_publish_max_retries,
                        key,
                    )
                    self.task_publish_retries[key] += 1
                    continue
            self.queued_tasks.pop(key)
            self.task_publish_retries.pop(key)
            if isinstance(result, ExceptionWithTraceback):
                self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n", result.exception, result.traceback)
                self.event_buffer[key] = (State.FAILED, None)
            elif result is not None:
                result.backend = cached_celery_backend
                self.running.add(key)
                self.tasks[key] = result

                # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
                # has another event, but that is fine, because the only other events are success/failed at
                # which point we don't need the ID anymore anyway
                self.event_buffer[key] = (State.QUEUED, result.task_id)

                # If the task runs _really quickly_ we may already have a result!
                self.update_task_state(key, result.state, getattr(result, 'info', None))

    def _send_tasks_to_celery(self, task_tuples_to_send: List[TaskInstanceInCelery]):
        if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
            # One tuple, or max one process -> send it in the main thread.
            return list(map(send_task_to_executor, task_tuples_to_send))

        # Use chunks instead of a work queue to reduce context switching
        # since tasks are roughly uniform in size
        chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
        num_processes = min(len(task_tuples_to_send), self._sync_parallelism)

        with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
            key_and_async_results = list(
                send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
            )
        return key_and_async_results
    def sync(self) -> None:
        if not self.tasks:
            self.log.debug("No task to query celery, skipping sync")
            return
        self.update_all_task_states()

        if self.adopted_task_timeouts:
            self._check_for_stalled_adopted_tasks()
    def _check_for_stalled_adopted_tasks(self):
        """
        See if any of the tasks we adopted from another Executor run have not
        progressed after the configured timeout.

        If they haven't, they likely never made it to Celery, and we should just resend them. We do that
        by clearing the state and letting the normal scheduler loop deal with that.
        """
        now = utcnow()

        sorted_adopted_task_timeouts = sorted(self.adopted_task_timeouts.items(), key=lambda k: k[1])

        timedout_keys = []
        for key, stalled_after in sorted_adopted_task_timeouts:
            if stalled_after > now:
                # Since items are stored sorted, if we get to a stalled_after
                # in the future then we can stop
                break

            # If the task gets updated to STARTED (which Celery does) or has
            # already finished, then it will be removed from this list -- so
            # the only time it's still in this list is when it either a) never
            # made it to celery in the first place (i.e. race condition somewhere
            # in the dying executor) or b) is sitting in a really long celery
            # queue and just hasn't started yet -- better cancel it and let the
            # scheduler re-queue rather than have this task risk stalling forever
            timedout_keys.append(key)

        if timedout_keys:
            self.log.error(
                "Adopted tasks were still pending after %s, assuming they never made it to celery and "
                "clearing:\n\t%s",
                self.task_adoption_timeout,
                "\n\t".join(repr(x) for x in timedout_keys),
            )
            for key in timedout_keys:
                self.change_state(key, State.FAILED)
    def debug_dump(self) -> None:
        """Called in response to SIGUSR2 by the scheduler"""
        super().debug_dump()
        self.log.info(
            "executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items()))
        )
        self.log.info(
            "executor.adopted_task_timeouts (%d)\n\t%s",
            len(self.adopted_task_timeouts),
            "\n\t".join(map(repr, self.adopted_task_timeouts.items())),
        )
    def update_all_task_states(self) -> None:
        """Updates states of the tasks."""
        self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
        state_and_info_by_celery_task_id = self.bulk_state_fetcher.get_many(self.tasks.values())

        self.log.debug("Inquiries completed.")
        for key, async_result in list(self.tasks.items()):
            state, info = state_and_info_by_celery_task_id.get(async_result.task_id)
            if state:
                self.update_task_state(key, state, info)
    def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
        """Updates state of a single task."""
        try:
            if state == celery_states.SUCCESS:
                self.success(key, info)
            elif state in (celery_states.FAILURE, celery_states.REVOKED):
                self.fail(key, info)
            elif state == celery_states.STARTED:
                # It's now actually running, so we know it made it to celery okay!
                self.adopted_task_timeouts.pop(key, None)
            elif state == celery_states.PENDING:
                pass
            else:
                self.log.info("Unexpected state for %s: %s", key, state)
        except Exception:
            self.log.exception("Error syncing the Celery executor, ignoring it.")
    def execute_async(
        self,
        key: TaskInstanceKey,
        command: CommandType,
        queue: Optional[str] = None,
        executor_config: Optional[Any] = None,
    ):
        """Do not allow async execution for Celery executor."""
        raise AirflowException("No Async execution for Celery executor.")
    def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
        # See which of the TIs are still alive (or have finished even!)
        #
        # Since Celery doesn't store "SENT" state for queued commands (if we create an AsyncResult with a
        # made up id it just returns PENDING state for it), we have to store Celery's task_id against the
        # TI row to look at in future.
        #
        # This process is not perfect -- we could have sent the task to celery, and crashed before we
        # were able to record the AsyncResult.task_id in the TaskInstance table, in which case we won't
        # adopt the task (it'll either run and update the TI state, or the scheduler will clear and
        # re-queue it. Either way it won't get executed more than once.)
        #
        # (If we swapped it around, and generated a task_id for Celery, stored that in the TI, and then
        # enqueued it, there is also still a race condition where we could generate and store the task_id
        # but die before we managed to enqueue the command. Since neither way is perfect we always have
        # to deal with this process not being perfect.)
        celery_tasks = {}
        not_adopted_tis = []

        for ti in tis:
            if ti.external_executor_id is not None:
                celery_tasks[ti.external_executor_id] = (AsyncResult(ti.external_executor_id), ti)
            else:
                not_adopted_tis.append(ti)

        if not celery_tasks:
            # Nothing to adopt
            return tis

        states_by_celery_task_id = self.bulk_state_fetcher.get_many(
            list(map(operator.itemgetter(0), celery_tasks.values()))
        )

        adopted = []
        cached_celery_backend = next(iter(celery_tasks.values()))[0].backend

        for celery_task_id, (state, info) in states_by_celery_task_id.items():
            result, ti = celery_tasks[celery_task_id]
            result.backend = cached_celery_backend

            # Set the correct elements of the state dicts, then update this
            # like we just queried it.
            self.adopted_task_timeouts[ti.key] = ti.queued_dttm + self.task_adoption_timeout
            self.tasks[ti.key] = result
            self.running.add(ti.key)
            self.update_task_state(ti.key, state, info)
            adopted.append(f"{ti} in state {state}")

        if adopted:
            task_instance_str = '\n\t'.join(adopted)
            self.log.info(
                "Adopted the following %d tasks from a dead executor\n\t%s", len(adopted), task_instance_str
            )

        return not_adopted_tis
def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]:
    """
    Fetch and return the state of the given celery task.

    The scope of this function is global so that it can be called by subprocesses in the pool.

    :param async_result: the async Celery result object used to fetch the task's state
    :type async_result: celery.result.AsyncResult
    :return: a tuple of the Celery task id, the Celery state, and the celery info of the task
    :rtype: tuple[str, str, str]
    """
    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task
            info = async_result.info if hasattr(async_result, 'info') else None
            return async_result.task_id, async_result.state, info
    except Exception as e:
        exception_traceback = f"Celery Task ID: {async_result}\n{traceback.format_exc()}"
        return async_result.task_id, ExceptionWithTraceback(e, exception_traceback), None
class BulkStateFetcher(LoggingMixin):
    """
    Gets status for many Celery tasks using the best method available.

    If BaseKeyValueStoreBackend is used as the result backend, the mget method is used.
    If DatabaseBackend is used as the result backend, a SELECT ... WHERE task_id IN (...) query is used.
    Otherwise, a multiprocessing pool is used and each task status is fetched individually.
    """

    def __init__(self, sync_parallelism=None):
        super().__init__()
        self._sync_parallelism = sync_parallelism

    def _tasks_list_to_task_ids(self, async_tasks) -> Set[str]:
        return {a.task_id for a in async_tasks}
    def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
        """Gets status for many Celery tasks using the best method available."""
        if isinstance(app.backend, BaseKeyValueStoreBackend):
            result = self._get_many_from_kv_backend(async_results)
        elif isinstance(app.backend, DatabaseBackend):
            result = self._get_many_from_db_backend(async_results)
        else:
            result = self._get_many_using_multiprocessing(async_results)
        self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results))
        return result
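    # Illustrative return value (task ids and info values are made up): get_many maps each
    # Celery task id to a (state, info) pair, e.g.
    #   {"5f8a...": (celery_states.SUCCESS, None), "c91d...": (celery_states.FAILURE, "<exception info>")}
    # which callers such as update_all_task_states() look up by AsyncResult.task_id.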