Source code for airflow.providers.papermill.hooks.kernel

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import typing

from jupyter_client import AsyncKernelManager
from papermill.clientwrap import PapermillNotebookClient
from papermill.engines import NBClientEngine
from papermill.utils import merge_kwargs, remove_args
from traitlets import Unicode

from airflow.hooks.base import BaseHook

# Default ports for each channel of the remote Jupyter kernel. ``KernelHook``
# falls back to these when the connection's ``extra`` field does not provide
# the corresponding ``*_port`` key.
JUPYTER_KERNEL_SHELL_PORT = 60316
JUPYTER_KERNEL_IOPUB_PORT = 60317
JUPYTER_KERNEL_STDIN_PORT = 60318
JUPYTER_KERNEL_CONTROL_PORT = 60319
JUPYTER_KERNEL_HB_PORT = 60320

# Name under which ``RemoteKernelEngine`` is registered with papermill.
REMOTE_KERNEL_ENGINE = "remote_kernel_engine"
class KernelConnection:
    """
    Class to represent kernel connection object.

    The class only declares its fields; instances are created without
    arguments and the attributes are assigned afterwards.
    """

    # Host or IP address of the kernel.
    ip: str
    # Ports of the individual kernel channels.
    shell_port: int
    iopub_port: int
    stdin_port: int
    control_port: int
    hb_port: int
    # Session key used when connecting to the kernel.
    session_key: str
class KernelHook(BaseHook):
    """
    The KernelHook can be used to interact with remote jupyter kernel.

    Takes kernel host/ip from connection and refers to jupyter kernel ports and session_key
    from ``extra`` field.

    :param kernel_conn_id: connection that has kernel host/ip
    """

    conn_name_attr = "kernel_conn_id"
    default_conn_name = "jupyter_kernel_default"
    conn_type = "jupyter_kernel"
    hook_name = "Jupyter Kernel"

    def __init__(self, kernel_conn_id: str = default_conn_name, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.kernel_conn = self.get_connection(kernel_conn_id)
        # Make the remote-kernel papermill engine available as soon as a hook exists.
        register_remote_kernel_engine()

    def get_conn(self) -> KernelConnection:
        """
        Build a ``KernelConnection`` from the Airflow connection.

        The host comes from the connection itself; ports and the session key
        are read from the connection's ``extra`` field, falling back to the
        module-level default ports and an empty session key.

        :return: populated kernel connection object
        """
        # Read the extras once instead of re-evaluating the property per field.
        extra = self.kernel_conn.extra_dejson

        kernel_connection = KernelConnection()
        kernel_connection.ip = self.kernel_conn.host
        kernel_connection.shell_port = extra.get("shell_port", JUPYTER_KERNEL_SHELL_PORT)
        kernel_connection.iopub_port = extra.get("iopub_port", JUPYTER_KERNEL_IOPUB_PORT)
        kernel_connection.stdin_port = extra.get("stdin_port", JUPYTER_KERNEL_STDIN_PORT)
        kernel_connection.control_port = extra.get("control_port", JUPYTER_KERNEL_CONTROL_PORT)
        kernel_connection.hb_port = extra.get("hb_port", JUPYTER_KERNEL_HB_PORT)
        kernel_connection.session_key = extra.get("session_key", "")
        return kernel_connection
def register_remote_kernel_engine() -> None:
    """Register ``RemoteKernelEngine`` papermill engine."""
    # Imported at call time rather than at module import time.
    from papermill.engines import papermill_engines

    papermill_engines.register(REMOTE_KERNEL_ENGINE, RemoteKernelEngine)
class RemoteKernelManager(AsyncKernelManager):
    """Jupyter kernel manager that connects to a remote kernel."""

    session_key = Unicode("", config=True, help="Session key to connect to remote kernel")

    @property
    def has_kernel(self) -> bool:
        """Always report that a kernel is present."""
        # NOTE(review): presumably unconditional because the remote kernel is
        # not started by this manager - confirm.
        return True

    async def _async_is_alive(self) -> bool:
        """Always report the kernel as alive; no liveness check is performed."""
        return True

    def shutdown_kernel(self, now: bool = False, restart: bool = False):
        """Do nothing; this manager never shuts the kernel down."""

    def client(self, **kwargs: typing.Any):
        """Create a client configured to connect to our kernel."""
        kernel_client = super().client(**kwargs)
        # load connection info to set session_key
        config: dict[str, int | str | bytes] = {
            "ip": self.ip,
            "shell_port": self.shell_port,
            "iopub_port": self.iopub_port,
            "stdin_port": self.stdin_port,
            "control_port": self.control_port,
            "hb_port": self.hb_port,
            "key": self.session_key,
            "transport": "tcp",
            "signature_scheme": "hmac-sha256",
        }
        kernel_client.load_connection_info(config)
        return kernel_client
class RemoteKernelEngine(NBClientEngine):
    """Papermill engine to use ``RemoteKernelManager`` to connect to remote kernel and execute notebook."""

    @classmethod
    def execute_managed_notebook(
        cls,
        nb_man,
        kernel_name,
        log_output=False,
        stdout_file=None,
        stderr_file=None,
        start_timeout=60,
        execution_timeout=None,
        **kwargs,
    ):
        """
        Perform the actual execution of the parameterized notebook on the remote kernel.

        The kernel address is taken from ``kwargs``; the keys ``kernel_ip``,
        ``kernel_shell_port``, ``kernel_iopub_port``, ``kernel_stdin_port``,
        ``kernel_control_port``, ``kernel_hb_port`` and ``kernel_session_key``
        are all required.

        :param nb_man: notebook manager handed to ``PapermillNotebookClient``
        :param kernel_name: accepted for interface compatibility; not used here
        :param log_output: accepted for interface compatibility; see note below
        :param stdout_file: forwarded to the notebook client
        :param stderr_file: forwarded to the notebook client
        :param start_timeout: forwarded to the notebook client as ``startup_timeout``
        :param execution_timeout: forwarded as ``timeout``; when falsy, the
            ``timeout`` entry of ``kwargs`` (if any) is used instead
        :raises KeyError: if one of the required ``kernel_*`` keys is missing
        :return: result of ``PapermillNotebookClient.execute()``
        """
        km = RemoteKernelManager()
        km.ip = kwargs["kernel_ip"]
        km.shell_port = kwargs["kernel_shell_port"]
        km.iopub_port = kwargs["kernel_iopub_port"]
        km.stdin_port = kwargs["kernel_stdin_port"]
        km.control_port = kwargs["kernel_control_port"]
        km.hb_port = kwargs["kernel_hb_port"]
        km.session_key = kwargs["kernel_session_key"]

        # Exclude parameters that named differently downstream
        safe_kwargs = remove_args(["timeout", "startup_timeout"], **kwargs)

        # NOTE(review): ``log_output`` is forwarded as a hard-coded ``False``,
        # so the caller's value is ignored - confirm this is intentional.
        final_kwargs = merge_kwargs(
            safe_kwargs,
            timeout=execution_timeout if execution_timeout else kwargs.get("timeout"),
            startup_timeout=start_timeout,
            log_output=False,
            stdout_file=stdout_file,
            stderr_file=stderr_file,
        )

        return PapermillNotebookClient(nb_man, km=km, **final_kwargs).execute()

Was this entry helpful?