## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""Implements the ``@task_group`` function decorator.When the decorated function is called, a task group will be created to representa collection of closely related tasks on the same DAG that should be groupedtogether when the DAG is displayed graphically."""from__future__importannotationsimportfunctoolsimportinspectimportwarningsfromtypingimportTYPE_CHECKING,Any,Callable,ClassVar,Generic,Mapping,Sequence,TypeVar,overloadimportattrfromairflow.decorators.baseimportExpandableFactoryfromairflow.models.expandinputimport(DictOfListsExpandInput,ListOfDictsExpandInput,MappedArgument,)fromairflow.models.taskmixinimportDAGNodefromairflow.models.xcom_argimportXComArgfromairflow.typing_compatimportParamSpecfromairflow.utils.helpersimportprevent_duplicatesfromairflow.utils.task_groupimportMappedTaskGroup,TaskGroupifTYPE_CHECKING:fromairflow.models.dagimportDAGfromairflow.models.expandinputimport(OperatorExpandArgument,OperatorExpandKwargsArgument,)
@attr.define()
class _TaskGroupFactory(ExpandableFactory, Generic[FParams, FReturn]):
    """
    Callable object produced by the ``@task_group`` decorator.

    Calling an instance runs the wrapped function inside a ``TaskGroup``
    context; ``expand()`` and ``expand_kwargs()`` do the same inside a
    ``MappedTaskGroup``. ``override()`` and ``partial()`` return modified
    copies and leave the instance they are called on untouched.
    """

    # The decorated function; it is invoked inside the task group context.
    function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
    tg_kwargs: dict[str, Any] = attr.ib(factory=dict)  # Parameters forwarded to TaskGroup.
    partial_kwargs: dict[str, Any] = attr.ib(factory=dict)  # Parameters forwarded to 'function'.

    # Set once the wrapped function has been run inside a task group; read by
    # ``__del__`` to decide whether to warn about an unused ``partial()``.
    _task_group_created: bool = attr.ib(False, init=False)

    tg_class: ClassVar[type[TaskGroup]] = TaskGroup

    @tg_kwargs.validator
    def _validate(self, _, kwargs):
        """Raise ``TypeError`` if *kwargs* cannot be bound to ``task_group_sig``."""
        # NOTE(review): ``task_group_sig`` is defined outside this view --
        # presumably the signature of TaskGroup's constructor; confirm.
        task_group_sig.bind_partial(**kwargs)

    def __attrs_post_init__(self):
        """Default ``group_id`` to the wrapped function's name."""
        # Build a fresh dict instead of calling ``setdefault`` on the one passed
        # in. ``task_group(**kw)`` returns a ``functools.partial`` that hands the
        # *same* dict to every function it decorates, so mutating it in place
        # would leak the first function's name into later ones as ``group_id``.
        self.tg_kwargs = {"group_id": self.function.__name__, **self.tg_kwargs}

    def __del__(self):
        """Warn when a ``partial()`` result is discarded without being mapped."""
        # ``__init__`` can raise before the attributes are assigned (e.g. when
        # called with the wrong arguments) and ``__del__`` still runs on that
        # half-built object, so read the state defensively rather than letting
        # an AttributeError surface as an "Exception ignored in ..." message.
        partial_kwargs = getattr(self, "partial_kwargs", None)
        if not partial_kwargs or getattr(self, "_task_group_created", False):
            return
        try:
            group_id = repr(self.tg_kwargs["group_id"])
        except (AttributeError, KeyError):
            # No usable group_id; identify the object by address instead.
            group_id = f"at {hex(id(self))}"
        warnings.warn(f"Partial task group {group_id} was never mapped!", stacklevel=1)

    def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> DAGNode:
        """
        Instantiate the task group.

        This uses the wrapped function to create a task group. Depending on the
        return type of the wrapped function, this either returns the last task
        in the group, or the group itself, to support task chaining.
        """
        return self._create_task_group(TaskGroup, *args, **kwargs)

    def _create_task_group(self, tg_factory: Callable[..., TaskGroup], *args: Any, **kwargs: Any) -> DAGNode:
        """
        Run the wrapped function inside a task group built by *tg_factory*.

        :param tg_factory: Callable returning the task group to use as context
            manager; it is called with ``add_suffix_on_collision=True`` plus
            ``tg_kwargs``.
        :param args: Positional arguments passed to the wrapped function.
        :param kwargs: Keyword arguments passed to the wrapped function.
        :return: The wrapped function's return value if it is not ``None``,
            otherwise the task group itself.
        """
        with tg_factory(add_suffix_on_collision=True, **self.tg_kwargs) as task_group:
            # Fall back to the function's docstring when no tooltip was given.
            if self.function.__doc__ and not task_group.tooltip:
                task_group.tooltip = self.function.__doc__

            # Invoke function to run Tasks inside the TaskGroup
            retval = self.function(*args, **kwargs)

        self._task_group_created = True

        # If the task-creating function returns a task, forward the return value
        # so dependencies bind to it. This is equivalent to
        #   with TaskGroup(...) as tg:
        #       t2 = task_2(task_1())
        #   start >> t2 >> end
        if retval is not None:
            return retval

        # Otherwise return the task group as a whole, equivalent to
        #   with TaskGroup(...) as tg:
        #       task_1()
        #       task_2()
        #   start >> tg >> end
        return task_group

    def override(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
        """Return a copy whose ``tg_kwargs`` are updated with *kwargs*."""
        # TODO: FIXME when mypy gets compatible with new attrs
        return attr.evolve(self, tg_kwargs={**self.tg_kwargs, **kwargs})  # type: ignore[arg-type]

    def partial(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
        """
        Return a copy with *kwargs* added to the arguments preset for the function.

        The names are checked with the inherited ``_validate_arg_names`` and
        against the already-preset arguments via ``prevent_duplicates``.
        """
        self._validate_arg_names("partial", kwargs)
        prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="duplicate partial")
        kwargs.update(self.partial_kwargs)
        # TODO: FIXME when mypy gets compatible with new attrs
        return attr.evolve(self, partial_kwargs=kwargs)  # type: ignore[arg-type]

    def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode:
        """
        Create a mapped task group from keyword arguments to expand against.

        :raises TypeError: If no keyword arguments are given.
        """
        if not kwargs:
            raise TypeError("no arguments to expand against")
        self._validate_arg_names("expand", kwargs)
        prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="mapping already partial")
        expand_input = DictOfListsExpandInput(kwargs)
        # The function receives a MappedArgument stub for every expanded name,
        # alongside the arguments preset through ``partial()``.
        return self._create_task_group(
            functools.partial(MappedTaskGroup, expand_input=expand_input),
            **self.partial_kwargs,
            **{k: MappedArgument(input=expand_input, key=k) for k in kwargs},
        )

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
        """
        Create a mapped task group from an XComArg or a sequence of mappings.

        :param kwargs: Either an ``XComArg``, or a sequence whose items are
            each an ``XComArg`` or a mapping.
        :raises TypeError: If *kwargs* has another shape, or if the wrapped
            function takes ``*`` or ``**`` arguments.
        """
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")

        # It's impossible to build a dict of stubs as keyword arguments if the
        # function uses * or ** wildcard arguments.
        wildcard_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        function_has_vararg = any(
            v.kind in wildcard_kinds for v in self.function_signature.parameters.values()
        )
        if function_has_vararg:
            raise TypeError("calling expand_kwargs() on task group function with * or ** is not supported")

        # We can't be sure how each argument is used in the function (well
        # technically we can with AST but let's not), so we have to create stubs
        # for every argument, including those with default values.
        map_kwargs = (k for k in self.function_signature.parameters if k not in self.partial_kwargs)

        expand_input = ListOfDictsExpandInput(kwargs)
        return self._create_task_group(
            functools.partial(MappedTaskGroup, expand_input=expand_input),
            **self.partial_kwargs,
            **{k: MappedArgument(input=expand_input, key=k) for k in map_kwargs},
        )


# This covers the @task_group() case. Annotations are copied from the TaskGroup
# class, only providing a default to 'group_id' (this is optional for the
# decorator and defaults to the decorated function's name). Please keep them in
# sync with TaskGroup when you can! Note that since this is an overload, these
# argument defaults aren't actually used at runtime--the real implementation
# does not use them, and simply relies on TaskGroup's defaults, so it's not
# disastrous if they go out of sync with TaskGroup.
# NOTE(review): the signature described above is not present here, so this
# decorator stacks onto the overload below -- confirm nothing was lost.
@overload
# This covers the @task_group case (no parentheses).
@overload
def task_group(python_callable: Callable[FParams, FReturn]) -> _TaskGroupFactory[FParams, FReturn]:
    ...


def task_group(python_callable=None, **tg_kwargs):
    """
    Python TaskGroup decorator.

    This wraps a function into an Airflow TaskGroup. When used as the
    ``@task_group()`` form, all arguments are forwarded to the underlying
    TaskGroup class. Can be used to parametrize TaskGroup.

    :param python_callable: Function to decorate.
    :param tg_kwargs: Keyword arguments for the TaskGroup object.
    :return: A ``_TaskGroupFactory`` wrapping *python_callable* when a callable
        is given; otherwise a decorator that builds one from the function it is
        applied to.
    """
    # Wrap whenever a function is supplied, including together with TaskGroup
    # keyword arguments; otherwise the function would be silently dropped and
    # the caller handed a decorator still waiting for one.
    if callable(python_callable):
        return _TaskGroupFactory(function=python_callable, tg_kwargs=tg_kwargs)
    return functools.partial(_TaskGroupFactory, tg_kwargs=tg_kwargs)