# Source code for airflow.models.xcom_arg

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, List, Optional, Sequence, Union

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator  # pylint: disable=R0401
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.edgemodifier import EdgeModifier

class XComArg(TaskMixin):
    """
    Class that represents an XCom push from a previous operator.
    Defaults to "return_value" as the only key.

    Current implementation supports
        xcomarg >> op
        xcomarg << op
        op >> xcomarg   (by BaseOperator code)
        op << xcomarg   (by BaseOperator code)

    **Example**: The moment you get a result from any operator (decorated or regular) you can ::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    This object can be used in legacy Operators via Jinja.

    **Example**: You can make this result part of any generated string ::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: operator to which the XComArg belongs
    :type operator: airflow.models.baseoperator.BaseOperator
    :param key: key value which is used for xcom_pull (key in the XCom table)
    :type key: str
    """

    def __init__(self, operator: BaseOperator, key: str = XCOM_RETURN_KEY):
        self._operator = operator
        self._key = key

    def __eq__(self, other):
        """Two XComArgs are equal when they share the same operator and key."""
        # Comparing with a non-XComArg previously raised AttributeError on
        # ``other.operator``; returning NotImplemented lets Python fall back to
        # its default comparison (and the reflected operand's __eq__) instead.
        # Note: defining __eq__ without __hash__ leaves instances unhashable,
        # exactly as before.
        if not isinstance(other, XComArg):
            return NotImplemented
        return self.operator == other.operator and self.key == other.key

    def __getitem__(self, item):
        """Implements xcomresult['some_result_key']"""
        return XComArg(operator=self.operator, key=item)

    def __str__(self):
        """
        Backward compatibility for old-style jinja used in Airflow Operators

        **Example**: to use XComArg at BashOperator::

            BashOperator(bash_command=f"... { xcomarg } ...")

        :return: a Jinja expression string that calls ``task_instance.xcom_pull``
            with this arg's task id, dag id and (when not None) key
        """
        xcom_pull_kwargs = [
            f"task_ids='{self.operator.task_id}'",
            f"dag_id='{self.operator.dag.dag_id}'",
        ]
        if self.key is not None:
            xcom_pull_kwargs.append(f"key='{self.key}'")

        joined_kwargs = ", ".join(xcom_pull_kwargs)
        # {{{{ are required for escape {{ in f-string
        xcom_pull = f"{{{{ task_instance.xcom_pull({joined_kwargs}) }}}}"
        return xcom_pull

    @property
    def operator(self) -> BaseOperator:
        """Returns operator of this XComArg."""
        return self._operator

    @property
    def roots(self) -> List[BaseOperator]:
        """Required by TaskMixin"""
        return [self._operator]

    @property
    def leaves(self) -> List[BaseOperator]:
        """Required by TaskMixin"""
        return [self._operator]

    @property
    def key(self) -> str:
        """Returns the key of this XComArg"""
        return self._key

    def set_upstream(
        self,
        task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
        edge_modifier: Optional[EdgeModifier] = None,
    ):
        """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
        self.operator.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
        edge_modifier: Optional[EdgeModifier] = None,
    ):
        """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
        self.operator.set_downstream(task_or_task_list, edge_modifier)

    def resolve(self, context: Dict) -> Any:
        """
        Pull XCom value for the existing arg. This method is run during ``op.execute()``
        in respective context.

        :param context: execution context passed through to ``xcom_pull``
        :return: the first element of the value returned by ``xcom_pull``
        :raises AirflowException: if ``xcom_pull`` returns a falsy result
        """
        resolved_value = self.operator.xcom_pull(
            context=context,
            task_ids=[self.operator.task_id],
            key=str(self.key),  # xcom_pull supports only key as str
            dag_id=self.operator.dag.dag_id,
        )
        # task_ids is passed as a one-element list, so the result is indexed
        # below; presumably one entry per requested task id — verify against
        # xcom_pull's contract.
        if not resolved_value:
            raise AirflowException(
                f'XComArg result from {self.operator.task_id} at {self.operator.dag.dag_id} '
                f'with key="{self.key}" is not found!'
            )
        return resolved_value[0]

