Source code for airflow.providers.cncf.kubernetes.decorators.kubernetes

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import base64
import os
import pickle
import uuid
from shlex import quote
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Callable, Sequence

import dill
from kubernetes.client import models as k8s

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.python_kubernetes_script import (
    write_python_script,
)

if TYPE_CHECKING:
    from airflow.utils.context import Context

_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
_PYTHON_INPUT_ENV = "__PYTHON_INPUT"


def _generate_decoded_command(env_var: str, file: str) -> str:
    return (
        f'python -c "import base64, os;'
        rf"x = base64.b64decode(os.environ[\"{env_var}\"]);"
        rf'f = open(\"{file}\", \"wb\"); f.write(x); f.close()"'
    )


def _read_file_contents(filename: str) -> str:
    with open(filename, "rb") as script_file:
        return base64.b64encode(script_file.read()).decode("utf-8")


class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
    """Run a ``@task.kubernetes``-decorated Python callable inside a Kubernetes pod.

    At execution time the callable's source is rendered into a standalone script, and
    its ``op_args``/``op_kwargs`` are serialized with ``pickle`` (or ``dill`` when
    ``use_dill`` is set). Both files are base64-encoded and handed to the pod as
    environment variables. The pod command decodes them back into files and runs the
    script, passing ``/airflow/xcom/return.json`` as its output-path argument.
    """

    custom_operator_name = "@task.kubernetes"

    # `cmds` and `arguments` are generated internally by the operator (see
    # `_generate_cmds`), so they are excluded from templating.
    # NOTE: the fields are collected with `dict.fromkeys` rather than a set. Iterating a
    # set of strings yields an order that can differ between interpreter processes
    # (string hashing is randomized), which made `template_fields` nondeterministic.
    # `dict.fromkeys` still de-duplicates but keeps a stable, KubernetesPodOperator-
    # derived order.
    template_fields: Sequence[str] = tuple(
        dict.fromkeys(
            field
            for field in ("op_args", "op_kwargs", *KubernetesPodOperator.template_fields)
            if field not in ("cmds", "arguments")
        )
    )

    # Since we won't mutate the arguments, we should just do the shallow copy
    # there are some cases we can't deepcopy the objects (e.g protobuf).
    shallow_copy_attrs: Sequence[str] = ("python_callable",)

    def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None:
        """
        :param namespace: Kubernetes namespace for the pod; defaults to ``"default"``.
        :param use_dill: serialize the callable's arguments with ``dill`` instead of
            ``pickle``.
        :param kwargs: forwarded to the parent operators. If ``name`` is not given, a
            unique ``k8s_airflow_pod_<uuid hex>`` pod name is generated.
        """
        self.use_dill = use_dill
        super().__init__(
            namespace=namespace,
            name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
            # Placeholder only; the real command is built in `execute` via `_generate_cmds`.
            cmds=["placeholder-command"],
            **kwargs,
        )

    def _generate_cmds(self) -> list[str]:
        """Build the container command that materializes and runs the script.

        The command, run through ``bash -cx``, decodes the script and input env vars
        into files under ``/tmp``, creates the XCom directory, and then runs the script
        as ``python <script> <input> <output>``.

        :return: the command as an argv list suitable for ``cmds``.
        """
        script_filename = "/tmp/script.py"
        input_filename = "/tmp/script.in"
        output_filename = "/airflow/xcom/return.json"

        write_local_script_file_cmd = (
            f"{_generate_decoded_command(quote(_PYTHON_SCRIPT_ENV), quote(script_filename))}"
        )
        write_local_input_file_cmd = (
            f"{_generate_decoded_command(quote(_PYTHON_INPUT_ENV), quote(input_filename))}"
        )
        make_xcom_dir_cmd = "mkdir -p /airflow/xcom"
        exec_python_cmd = f"python {script_filename} {input_filename} {output_filename}"
        return [
            "bash",
            "-cx",
            (
                f"{write_local_script_file_cmd} && "
                f"{write_local_input_file_cmd} && "
                f"{make_xcom_dir_cmd} && "
                f"{exec_python_cmd}"
            ),
        ]

    def execute(self, context: Context):
        """Render the script and serialized inputs, then run the pod.

        The script and the pickled/dilled ``{"args": ..., "kwargs": ...}`` payload are
        written to a temporary directory, base64-encoded into the
        ``_PYTHON_SCRIPT_ENV`` / ``_PYTHON_INPUT_ENV`` env vars (appended to
        ``self.env_vars``), and the container command is replaced with the one from
        ``_generate_cmds`` before delegating to the parent ``execute``.
        """
        with TemporaryDirectory(prefix="venv") as tmp_dir:
            pickling_library = dill if self.use_dill else pickle
            script_filename = os.path.join(tmp_dir, "script.py")
            input_filename = os.path.join(tmp_dir, "script.in")

            # NOTE(review): the payload is unpickled inside the pod; it originates from
            # the DAG's own op_args/op_kwargs, not from external input.
            with open(input_filename, "wb") as file:
                pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file)

            py_source = self.get_python_source()
            jinja_context = {
                "op_args": self.op_args,
                "op_kwargs": self.op_kwargs,
                "pickling_library": pickling_library.__name__,
                "python_callable": self.python_callable.__name__,
                "python_callable_source": py_source,
                "string_args_global": False,
            }
            write_python_script(jinja_context=jinja_context, filename=script_filename)

            # The files only live for the duration of this `with` block, so their
            # contents are shipped to the pod through env vars rather than by path.
            self.env_vars = [
                *self.env_vars,
                k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)),
                k8s.V1EnvVar(name=_PYTHON_INPUT_ENV, value=_read_file_contents(input_filename)),
            ]

            self.cmds = self._generate_cmds()
            return super().execute(context)


def kubernetes_task(
    python_callable: Callable | None = None,
    multiple_outputs: bool | None = None,
    **kwargs,
) -> TaskDecorator:
    """Kubernetes operator decorator.

    This wraps a function to be executed in K8s using KubernetesPodOperator.
    Also accepts any argument that KubernetesPodOperator will via ``kwargs``. Can be
    reused in a single DAG.

    :param python_callable: Function to decorate
    :param multiple_outputs: if set, function return value will be
        unrolled to multiple XCom values. Dict will unroll to xcom values with
        keys as XCom keys. Defaults to False.
    :param kwargs: extra operator arguments, forwarded to ``task_decorator_factory``
        and from there to the Kubernetes decorated operator.
    """
    return task_decorator_factory(
        python_callable=python_callable,
        multiple_outputs=multiple_outputs,
        decorated_operator_class=_KubernetesDecoratedOperator,
        **kwargs,
    )

Was this entry helpful?