# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import inspect
import itertools
import textwrap
import warnings
from functools import cached_property
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Collection,
    Generic,
    Iterator,
    Mapping,
    Sequence,
    TypeVar,
    cast,
    overload,
)

import attr
import re2
import typing_extensions

from airflow.datasets import Dataset
from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
from airflow.models.baseoperator import (
    BaseOperator,
    coerce_resources,
    coerce_timedelta,
    get_merged_defaults,
    parse_retries,
)
from airflow.models.dag import DagContext
from airflow.models.expandinput import (
    EXPAND_INPUT_EMPTY,
    DictOfListsExpandInput,
    ListOfDictsExpandInput,
    is_mappable,
)
from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS
from airflow.utils.decorators import remove_task_decorator
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import TaskGroupContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.models.dag import DAG
    from airflow.models.expandinput import (
        ExpandInput,
        OperatorExpandArgument,
        OperatorExpandKwargsArgument,
    )
    from airflow.models.mappedoperator import ValidationSource
    from airflow.utils.context import Context
    from airflow.utils.task_group import TaskGroup


class ExpandableFactory(Protocol):
    """
    Protocol providing inspection against wrapped function.

    This is used in ``validate_expand_kwargs`` and implemented by function
    decorators like ``@task`` and ``@task_group``.
    :meta private:
    """

    function: Callable

    @cached_property
    def function_signature(self) -> inspect.Signature:
        return inspect.signature(self.function)

    @cached_property
    def _mappable_function_argument_names(self) -> set[str]:
        """Arguments that can be mapped against."""
        return set(self.function_signature.parameters)

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> None:
        """Ensure that all arguments passed to operator-mapping functions are accounted for."""
        parameters = self.function_signature.parameters
        if any(v.kind == inspect.Parameter.VAR_KEYWORD for v in parameters.values()):
            return
        kwargs_left = kwargs.copy()
        for arg_name in self._mappable_function_argument_names:
            value = kwargs_left.pop(arg_name, NOTSET)
            if func == "expand" and value is not NOTSET and not is_mappable(value):
                tname = type(value).__name__
                raise ValueError(
                    f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}"
                )
        if len(kwargs_left) == 1:
            raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}")
        elif kwargs_left:
            names = ", ".join(repr(n) for n in kwargs_left)
            raise TypeError(f"{func}() got unexpected keyword arguments {names}")
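
# Illustrative sketch of the validation above (hypothetical task function, not
# executed on import): for a wrapped function ``def f(a, b): ...``, calling
# ``_validate_arg_names("expand", {"a": [1], "c": [2]})`` raises
# ``TypeError: expand() got an unexpected keyword argument 'c'``, while
# ``{"a": 1}`` raises a ValueError because a plain int is not a mappable
# (list/dict/XComArg) value.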

def get_unique_task_id(
    task_id: str,
    dag: DAG | None = None,
    task_group: TaskGroup | None = None,
) -> str:
    """
    Generate unique task id given a DAG (or if run in a DAG context).

    IDs are generated by appending a unique number to the end of
    the original task id.

    Example:
      task_id
      task_id__1
      task_id__2
      ...
      task_id__20
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id

    if tg_task_id not in dag.task_ids:
        return task_id

    def _find_id_suffixes(dag: DAG) -> Iterator[int]:
        prefix = re2.split(r"__\d+$", tg_task_id)[0]
        for task_id in dag.task_ids:
            match = re2.match(rf"^{prefix}__(\d+)$", task_id)
            if match:
                yield int(match.group(1))
        yield 0  # Default if there's no matching task ID.

    core = re2.split(r"__\d+$", task_id)[0]
    return f"{core}__{max(_find_id_suffixes(dag)) + 1}"
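
# Sketch of the numbering behavior (hypothetical task IDs, not executed on
# import): in a DAG that already contains tasks "extract" and "extract__1",
# get_unique_task_id("extract", dag) returns "extract__2"; in a DAG with no
# clashing ID it returns "extract" unchanged.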

class DecoratedOperator(BaseOperator):
    """
    Wraps a Python callable and captures args/kwargs when called for execution.

    :param python_callable: A reference to an object that is callable
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function (templated)
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable (templated)
    :param multiple_outputs: If set to True, the decorated function's return
        value will be unrolled to multiple XCom values. Dict will unroll to
        XCom values with its keys as XCom keys. Defaults to False.
    :param kwargs_to_upstream: For certain operators, we might need to upstream
        certain arguments that would otherwise be absorbed by the
        DecoratedOperator (for example python_callable for the PythonOperator).
        This gives a user the option to upstream kwargs as needed.
    """

    template_fields: Sequence[str] = ("op_args", "op_kwargs")
    template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

    # Since we won't mutate the arguments, we should just do the shallow copy;
    # there are some cases we can't deepcopy the objects (e.g. protobuf).
    shallow_copy_attrs: Sequence[str] = ("python_callable",)

    def __init__(
        self,
        *,
        python_callable: Callable,
        task_id: str,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        kwargs_to_upstream: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
        self.python_callable = python_callable
        kwargs_to_upstream = kwargs_to_upstream or {}
        op_args = op_args or []
        op_kwargs = op_kwargs or {}

        # Check the decorated function's signature. We go through the argument
        # list and "fill in" defaults to arguments that are known context keys,
        # since values for those will be provided when the task is run. Since
        # we're not actually running the function, None is good enough here.
        signature = inspect.signature(python_callable)

        # Don't allow context argument defaults other than None to avoid ambiguities.
        faulty_parameters = [
            param.name
            for param in signature.parameters.values()
            if param.name in KNOWN_CONTEXT_KEYS and param.default not in (None, inspect.Parameter.empty)
        ]
        if faulty_parameters:
            message = f"Context key parameter {faulty_parameters[0]} can't have a default other than None"
            raise ValueError(message)

        parameters = [
            param.replace(default=None) if param.name in KNOWN_CONTEXT_KEYS else param
            for param in signature.parameters.values()
        ]
        try:
            signature = signature.replace(parameters=parameters)
        except ValueError as err:
            message = textwrap.dedent(
                f"""
                The function signature broke while assigning defaults to context key parameters.

                The decorator is replacing the signature
                > {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())})

                with
                > {python_callable.__name__}({', '.join(str(param) for param in parameters)})

                which isn't valid: {err}
                """
            )
            raise ValueError(message) from err

        # Check that arguments can be bound. There's a slight difference when
        # we do validation for task-mapping: since there's no guarantee we can
        # receive enough arguments at parse time, we use bind_partial to simply
        # check all the arguments we know are valid. Whether these are enough
        # can only be known at execution time, when unmapping happens, and this
        # is called without the _airflow_mapped_validation_only flag.
        if kwargs.get("_airflow_mapped_validation_only"):
            signature.bind_partial(*op_args, **op_kwargs)
        else:
            signature.bind(*op_args, **op_kwargs)

        self.op_args = op_args
        self.op_kwargs = op_kwargs
        super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)
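
    # Illustrative only (hypothetical callables, not executed on import): a
    # context key may appear as a parameter with no default or a None default,
    # e.g. ``def f(x, ti=None)``, but ``def f(x, ti="abc")`` raises
    # ``ValueError: Context key parameter ti can't have a default other than None``.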

    def execute(self, context: Context):
        # TODO: make this more generic (move to prepare_lineage) so it deals
        # with non-TaskFlow operators as well.
        for arg in itertools.chain(self.op_args, self.op_kwargs.values()):
            if isinstance(arg, Dataset):
                self.inlets.append(arg)
        return_value = super().execute(context)
        return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)

    def _handle_output(self, return_value: Any, context: Context, xcom_push: Callable):
        """
        Handle logic for whether a decorator needs to push a single return value or multiple return values.

        It sets outlets if any datasets are found in the returned value(s).

        :param return_value: the value returned by the wrapped callable
        :param context: the task's execution context
        :param xcom_push: callable used to push the return value into XCom
        """
        if isinstance(return_value, Dataset):
            self.outlets.append(return_value)
        if isinstance(return_value, list):
            for item in return_value:
                if isinstance(item, Dataset):
                    self.outlets.append(item)
        return return_value

    def _hook_apply_defaults(self, *args, **kwargs):
        if "python_callable" not in kwargs:
            return args, kwargs

        python_callable = kwargs["python_callable"]
        default_args = kwargs.get("default_args") or {}
        op_kwargs = kwargs.get("op_kwargs") or {}
        f_sig = inspect.signature(python_callable)
        for arg in f_sig.parameters:
            if arg not in op_kwargs and arg in default_args:
                op_kwargs[arg] = default_args[arg]
        kwargs["op_kwargs"] = op_kwargs
        return args, kwargs
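
# Sketch of the default_args fill-in above (hypothetical names, not executed on
# import): with ``default_args={"conn_id": "db"}`` and a callable
# ``def load(conn_id): ...`` invoked without ``conn_id`` in op_kwargs,
# _hook_apply_defaults rewrites op_kwargs to ``{"conn_id": "db"}`` before the
# operator is constructed.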

FParams = ParamSpec("FParams")
FReturn = TypeVar("FReturn")
OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")


@attr.define(slots=False)
class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]):
    """
    Helper class for providing dynamic task mapping to decorated functions.

    ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function.

    :meta private:
    """

    function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
    operator_class: type[OperatorSubclass]
    multiple_outputs: bool = attr.ib()
    kwargs: dict[str, Any] = attr.ib(factory=dict)

    decorator_name: str = attr.ib(repr=False, default="task")

    _airflow_is_task_decorator: ClassVar[bool] = True
    is_setup: bool = False
    is_teardown: bool = False
    on_failure_fail_dagrun: bool = False

    @multiple_outputs.default
    def _infer_multiple_outputs(self):
        if "return" not in self.function.__annotations__:
            # No return type annotation, nothing to infer.
            return False

        try:
            # We only care about the return annotation, not anything about the parameters.
            def fake(): ...

            fake.__annotations__ = {"return": self.function.__annotations__["return"]}

            return_type = typing_extensions.get_type_hints(fake, self.function.__globals__).get("return", Any)
        except NameError as e:
            warnings.warn(
                f"Cannot infer multiple_outputs for TaskFlow function {self.function.__name__!r} with forward"
                f" type references that are not imported. (Error was {e})",
                stacklevel=4,
            )
            return False
        except TypeError:  # Can't evaluate return type.
            return False
        ttype = getattr(return_type, "__origin__", return_type)
        return isinstance(ttype, type) and issubclass(ttype, Mapping)

    def __attrs_post_init__(self):
        if "self" in self.function_signature.parameters:
            raise TypeError(f"@{self.decorator_name} does not support methods")
        self.kwargs.setdefault("task_id", self.function.__name__)

    def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg:
        if self.is_teardown:
            if "trigger_rule" in self.kwargs:
                raise ValueError("Trigger rule not configurable for teardown tasks.")
            self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
        on_failure_fail_dagrun = self.kwargs.pop("on_failure_fail_dagrun", self.on_failure_fail_dagrun)
        op = self.operator_class(
            python_callable=self.function,
            op_args=args,
            op_kwargs=kwargs,
            multiple_outputs=self.multiple_outputs,
            **self.kwargs,
        )
        op.is_setup = self.is_setup
        op.is_teardown = self.is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, op.doc_yaml]
        # Set the task's doc_md to the function's docstring if it exists and no other doc* args are set.
        if self.function.__doc__ and not any(op_doc_attrs):
            op.doc_md = self.function.__doc__
        return XComArg(op)

    @property
    def __wrapped__(self) -> Callable[FParams, FReturn]:
        return self.function

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]):
        # Ensure that context variables are not shadowed.
        context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs)
        if len(context_keys_being_mapped) == 1:
            (name,) = context_keys_being_mapped
            raise ValueError(f"cannot call {func}() on task context variable {name!r}")
        elif context_keys_being_mapped:
            names = ", ".join(repr(n) for n in context_keys_being_mapped)
            raise ValueError(f"cannot call {func}() on task context variables {names}")

        super()._validate_arg_names(func, kwargs)

    def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
        if not map_kwargs:
            raise TypeError("no arguments to expand against")
        self._validate_arg_names("expand", map_kwargs)
        prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        if self.is_teardown:
            if "trigger_rule" in self.kwargs:
                raise ValueError("Trigger rule not configurable for teardown tasks.")
            self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
        return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(item).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
        ensure_xcomarg_return_value(expand_input.value)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
        task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)

        partial_kwargs, partial_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs.update(
            task_kwargs,
            is_setup=self.is_setup,
            is_teardown=self.is_teardown,
            on_failure_fail_dagrun=self.on_failure_fail_dagrun,
        )

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
        if task_group:
            task_id = task_group.child_id(task_id)

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_timedelta(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
            key="retry_delay",
        )
        max_retry_delay = partial_kwargs.get("max_retry_delay")
        partial_kwargs["max_retry_delay"] = (
            max_retry_delay
            if max_retry_delay is None
            else coerce_timedelta(max_retry_delay, key="max_retry_delay")
        )
        partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        operator = _MappedOperator(
            operator_class=self.operator_class,
            expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=partial_params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            op_kwargs_expand_input=expand_input,
            disallow_kwargs_override=strict,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expand_input_attr="op_kwargs_expand_input",
            start_trigger_args=self.operator_class.start_trigger_args,
            start_from_trigger=self.operator_class.start_from_trigger,
        )
        return XComArg(operator=operator)

    def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        self._validate_arg_names("partial", kwargs)
        old_kwargs = self.kwargs.get("op_kwargs", {})
        prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
        kwargs.update(old_kwargs)
        return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})

    def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        result = attr.evolve(self, kwargs={**self.kwargs, **kwargs})
        setattr(result, "is_setup", self.is_setup)
        setattr(result, "is_teardown", self.is_teardown)
        setattr(result, "on_failure_fail_dagrun", self.on_failure_fail_dagrun)
        return result
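
# Minimal usage sketch of the partial/expand/override API implemented above,
# assuming the standard ``@task`` decorator from airflow.decorators
# (illustrative; not executed on import):
#
#     from airflow.decorators import task
#
#     @task
#     def add(x: int, y: int) -> int:
#         return x + y
#
#     # One mapped task instance per element of ``x``; ``y`` stays constant.
#     results = add.partial(y=10).expand(x=[1, 2, 3])
#
#     # override() tweaks operator arguments without re-decorating.
#     total = add.override(task_id="add_once", retries=0)(1, 2)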

@attr.define(kw_only=True, repr=False)
class DecoratedMappedOperator(MappedOperator):
    """MappedOperator implementation for @task-decorated task function."""

    multiple_outputs: bool
    python_callable: Callable

    # We can't save these in expand_input because op_kwargs need to be present
    # in partial_kwargs, and MappedOperator prevents duplication.
    op_kwargs_expand_input: ExpandInput

    def __attrs_post_init__(self):
        # The magic super() doesn't work here, so we use the explicit form.
        # Not using super(..., self) to work around pyupgrade bug.
        super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
        XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

    def _expand_mapped_kwargs(
        self, context: Context, session: Session, *, include_xcom: bool
    ) -> tuple[Mapping[str, Any], set[int]]:
        # We only use op_kwargs_expand_input so this must always be empty.
        if self.expand_input is not EXPAND_INPUT_EMPTY:
            raise AssertionError(f"unexpected expand_input: {self.expand_input}")
        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
        return {"op_kwargs": op_kwargs}, resolved_oids

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        partial_op_kwargs = self.partial_kwargs["op_kwargs"]
        mapped_op_kwargs = mapped_kwargs["op_kwargs"]

        if strict:
            prevent_duplicates(partial_op_kwargs, mapped_op_kwargs, fail_reason="mapping already partial")

        kwargs = {
            "multiple_outputs": self.multiple_outputs,
            "python_callable": self.python_callable,
            "op_kwargs": {**partial_op_kwargs, **mapped_op_kwargs},
        }
        return super()._get_unmap_kwargs(kwargs, strict=False)
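
# Illustrative merge performed by _get_unmap_kwargs (hypothetical values, not
# executed on import): with partial op_kwargs ``{"y": 10}`` and mapped
# op_kwargs ``{"x": 1}``, the unmapped operator receives
# ``op_kwargs={"y": 10, "x": 1}``; with strict=True a key present in both
# mappings would fail via prevent_duplicates instead.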

class Task(Protocol, Generic[FParams, FReturn]):
    """
    Declaration of a @task-decorated callable for type-checking.

    An instance of this type inherits the call signature of the decorated
    function wrapped in it (not *exactly* since it actually returns an XComArg,
    but there's no way to express that right now), and provides two additional
    methods for task-mapping.

    This type is implemented by ``_TaskDecorator`` at runtime.
    """

    __call__: Callable[FParams, XComArg]

    function: Callable[FParams, FReturn]

    @property
    def __wrapped__(self) -> Callable[FParams, FReturn]: ...

    def partial(self, **kwargs: Any) -> Task[FParams, FReturn]: ...

    def expand(self, **kwargs: OperatorExpandArgument) -> XComArg: ...

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: ...

    def override(self, **kwargs: Any) -> Task[FParams, FReturn]: ...


class TaskDecorator(Protocol):
    """Type declaration for ``task_decorator_factory`` return type."""

    @overload
    def __call__(  # type: ignore[misc]
        self,
        python_callable: Callable[FParams, FReturn],
    ) -> Task[FParams, FReturn]:
        """For the "bare decorator" ``@task`` case."""

    @overload
    def __call__(
        self,
        *,
        multiple_outputs: bool | None = None,
        **kwargs: Any,
    ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
        """For the decorator factory ``@task()`` case."""

    def override(self, **kwargs: Any) -> Task[FParams, FReturn]: ...

def task_decorator_factory(
    python_callable: Callable | None = None,
    *,
    multiple_outputs: bool | None = None,
    decorated_operator_class: type[BaseOperator],
    **kwargs,
) -> TaskDecorator:
    """
    Generate a wrapper that wraps a function into an Airflow operator.

    Can be reused in a single DAG.

    :param python_callable: Function to decorate.
    :param multiple_outputs: If set to True, the decorated function's return
        value will be unrolled to multiple XCom values. Dict will unroll to
        XCom values with its keys as XCom keys. If set to False (default),
        only at most one XCom value is pushed.
    :param decorated_operator_class: The operator that executes the logic
        needed to run the python function in the correct environment.

    Other kwargs are directly forwarded to the underlying operator class when
    it's instantiated.
    """
    if multiple_outputs is None:
        multiple_outputs = cast(bool, attr.NOTHING)
    if callable(python_callable):
        decorator = _TaskDecorator(
            function=python_callable,
            multiple_outputs=multiple_outputs,
            operator_class=decorated_operator_class,
            kwargs=kwargs,
        )
        return cast(TaskDecorator, decorator)
    elif python_callable is not None:
        raise TypeError("No args allowed while using @task, use kwargs instead")

    def decorator_factory(python_callable):
        return _TaskDecorator(
            function=python_callable,
            multiple_outputs=multiple_outputs,
            operator_class=decorated_operator_class,
            kwargs=kwargs,
        )

    return cast(TaskDecorator, decorator_factory)
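
# Sketch of how a custom decorator is typically wired up with this factory,
# modeled on the stock Python decorator (illustrative; ``MyOperator`` is a
# hypothetical DecoratedOperator subclass, not executed on import):
#
#     def my_task(python_callable=None, multiple_outputs=None, **kwargs):
#         return task_decorator_factory(
#             python_callable=python_callable,
#             multiple_outputs=multiple_outputs,
#             decorated_operator_class=MyOperator,
#             **kwargs,
#         )
#
# Used bare (``@my_task``) the factory wraps the function immediately; called
# with only kwargs (``@my_task(retries=2)``) it returns ``decorator_factory``
# to be applied to the function afterwards.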