Source code for airflow.providers.ssh.hooks.ssh

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Hook for SSH connections."""

from __future__ import annotations

import os
from base64 import decodebytes
from collections.abc import Sequence
from functools import cached_property
from io import StringIO
from select import select
from typing import Any

import paramiko
from paramiko.config import SSH_PORT
from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.ssh.tunnel import AsyncSSHTunnel, SSHTunnel
from airflow.utils.platform import getuser

try:
    from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
except ImportError:
    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
try:
    from airflow.sdk.definitions._internal.types import is_arg_set
except ImportError:

    def is_arg_set(value):  # type: ignore[misc,no-redef]
        """
        Report whether ``value`` is anything other than the ``NOTSET`` sentinel.

        Fallback used only when the helper cannot be imported from the Task SDK.
        """
        return value is not NOTSET
# Default command timeout (seconds). SSHHook falls back to this value when ``cmd_timeout``
# is still NOTSET after the constructor argument and the connection extras were consulted.
CMD_TIMEOUT = 10
class SSHHook(BaseHook):
    """
    Execute remote commands with Paramiko.

    .. seealso:: https://github.com/paramiko/paramiko

    This hook also lets you create ssh tunnel and serve as basis for SFTP file transfer.

    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>` from airflow
        Connections from where all the required parameters can be fetched like
        username, password or key_file, though priority is given to the
        params passed during init.
    :param remote_host: remote host to connect
    :param username: username to connect to the remote_host
    :param password: password of the username to connect to the remote_host
    :param key_file: path to key file to use to connect to the remote_host
    :param port: port of remote host to connect (Default is paramiko SSH_PORT)
    :param conn_timeout: timeout (in seconds) for the attempt to connect to the remote_host.
        The default is 10 seconds. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds.
        Nullable, `None` means no timeout. If provided, it will replace the `cmd_timeout`
        which was predefined in the connection of `ssh_conn_id`.
    :param keepalive_interval: send a keepalive packet to remote host every
        keepalive_interval seconds
    :param banner_timeout: timeout to wait for banner from the server in seconds
    :param disabled_algorithms: dictionary mapping algorithm type to an
        iterable of algorithm identifiers, which will be disabled for the
        lifetime of the transport
    :param ciphers: list of ciphers to use in order of preference
    :param auth_timeout: timeout (in seconds) for the attempt to authenticate with the remote_host
    :param host_proxy_cmd: command handed to ``paramiko.ProxyCommand`` to open the socket used
        for the connection. When not given, the ``proxycommand`` entry for ``remote_host`` in
        ``~/.ssh/config`` is used if one exists.
    :param conn_retry_attempts: number of times to attempt the initial SSH connection before
        giving up (default 3). Raising this helps when many tasks target the same SSH server
        at once and some connections are transiently refused (e.g. ``sshd`` ``MaxStartups``
        throttling).
    """

    # List of classes to try loading private keys as, ordered (roughly) by most common to least common
    _pkey_loaders: Sequence[type[paramiko.PKey]] = (
        paramiko.RSAKey,
        paramiko.ECDSAKey,
        paramiko.Ed25519Key,
        paramiko.DSSKey,
    )

    # Maps the suffix of an ``ssh-<type>`` host key prefix to the paramiko key class.
    _host_key_mappings = {
        "rsa": paramiko.RSAKey,
        "dss": paramiko.DSSKey,
        "ecdsa": paramiko.ECDSAKey,
        "ed25519": paramiko.Ed25519Key,
    }

    conn_name_attr = "ssh_conn_id"
    default_conn_name = "ssh_default"
    conn_type = "ssh"
    hook_name = "SSH"

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom UI field behaviour for SSH connection."""
        return {
            "hidden_fields": ["schema"],
            "relabeling": {
                "login": "Username",
            },
        }

    def __init__(
        self,
        ssh_conn_id: str | None = None,
        remote_host: str = "",
        username: str | None = None,
        password: str | None = None,
        key_file: str | None = None,
        port: int | None = None,
        conn_timeout: int | None = None,
        cmd_timeout: float | ArgNotSet | None = NOTSET,
        keepalive_interval: int = 30,
        banner_timeout: float = 30.0,
        disabled_algorithms: dict | None = None,
        ciphers: list[str] | None = None,
        auth_timeout: int | None = None,
        host_proxy_cmd: str | None = None,
        conn_retry_attempts: int = 3,
    ) -> None:
        super().__init__()
        self.ssh_conn_id = ssh_conn_id
        # At least one connection attempt is always made.
        self.conn_retry_attempts = max(1, conn_retry_attempts)
        self.remote_host = remote_host
        self.username = username
        self.password = password
        self.key_file = key_file
        self.pkey = None
        self.port = port
        self.conn_timeout = conn_timeout
        self.cmd_timeout = cmd_timeout
        self.keepalive_interval = keepalive_interval
        self.banner_timeout = banner_timeout
        self.disabled_algorithms = disabled_algorithms
        self.ciphers = ciphers
        self.host_proxy_cmd = host_proxy_cmd
        self.auth_timeout = auth_timeout

        # Default values, overridable from Connection
        self.compress = True
        self.no_host_key_check = True
        self.allow_host_key_change = False
        self.host_key = None
        self.look_for_keys = True

        # Placeholder for future cached connection
        self.client: paramiko.SSHClient | None = None

        # Use connection to override defaults
        if self.ssh_conn_id is not None:
            conn = self.get_connection(self.ssh_conn_id)
            if self.username is None:
                self.username = conn.login
            if self.password is None:
                self.password = conn.password
            if not self.remote_host:
                if conn.host:
                    self.remote_host = conn.host
            if self.port is None:
                self.port = conn.port

            if conn.extra is not None:
                extra_options = conn.extra_dejson
                if "key_file" in extra_options and self.key_file is None:
                    self.key_file = extra_options.get("key_file")

                private_key = extra_options.get("private_key")
                private_key_passphrase = extra_options.get("private_key_passphrase")
                if private_key:
                    self.pkey = self._pkey_from_private_key(private_key, passphrase=private_key_passphrase)

                if "conn_timeout" in extra_options and self.conn_timeout is None:
                    self.conn_timeout = int(extra_options["conn_timeout"])

                if "cmd_timeout" in extra_options and self.cmd_timeout is NOTSET:
                    # A falsy value in the extras means "no timeout".
                    if extra_options["cmd_timeout"]:
                        self.cmd_timeout = float(extra_options["cmd_timeout"])
                    else:
                        self.cmd_timeout = None

                if "compress" in extra_options and str(extra_options["compress"]).lower() == "false":
                    self.compress = False

                host_key = extra_options.get("host_key")
                no_host_key_check = extra_options.get("no_host_key_check")

                if no_host_key_check is not None:
                    no_host_key_check = str(no_host_key_check).lower() == "true"
                    if host_key is not None and no_host_key_check:
                        raise ValueError("Must check host key when provided")

                    self.no_host_key_check = no_host_key_check

                if (
                    "allow_host_key_change" in extra_options
                    and str(extra_options["allow_host_key_change"]).lower() == "true"
                ):
                    self.allow_host_key_change = True

                if (
                    "look_for_keys" in extra_options
                    and str(extra_options["look_for_keys"]).lower() == "false"
                ):
                    self.look_for_keys = False

                if "disabled_algorithms" in extra_options:
                    self.disabled_algorithms = extra_options.get("disabled_algorithms")

                if "ciphers" in extra_options:
                    self.ciphers = extra_options.get("ciphers")

                if host_key is not None:
                    # NOTE(review): only values starting with "ssh-" have their type looked up in
                    # ``_host_key_mappings``; anything else is decoded as a bare base64 RSA key.
                    # Presumably "ecdsa-sha2-*" entries therefore never reach the "ecdsa"
                    # mapping - confirm whether that is intended.
                    if host_key.startswith("ssh-"):
                        key_type, host_key = host_key.split(None)[:2]
                        key_constructor = self._host_key_mappings[key_type[4:]]
                    else:
                        key_constructor = paramiko.RSAKey
                    decoded_host_key = decodebytes(host_key.encode("utf-8"))
                    self.host_key = key_constructor(data=decoded_host_key)
                    # An explicit host key always switches host key checking on.
                    self.no_host_key_check = False

        if self.cmd_timeout is NOTSET:
            self.cmd_timeout = CMD_TIMEOUT

        if self.pkey and self.key_file:
            raise AirflowException(
                "Params key_file and private_key both provided. Must provide no more than one."
            )

        if not self.remote_host:
            raise AirflowException("Missing required param: remote_host")

        # Auto detecting username values from system
        if not self.username:
            self.log.debug(
                "username to ssh to host: %s is not specified for connection id"
                " %s. Using system's default provided by getpass.getuser()",
                self.remote_host,
                self.ssh_conn_id,
            )
            self.username = getuser()

        # Fill in proxy command / identity file from the user's ssh config when not given explicitly.
        user_ssh_config_filename = os.path.expanduser("~/.ssh/config")
        if os.path.isfile(user_ssh_config_filename):
            ssh_conf = paramiko.SSHConfig()
            with open(user_ssh_config_filename) as config_fd:
                ssh_conf.parse(config_fd)
            host_info = ssh_conf.lookup(self.remote_host)
            if host_info and host_info.get("proxycommand") and not self.host_proxy_cmd:
                self.host_proxy_cmd = host_info["proxycommand"]

            if not (self.password or self.key_file):
                if host_info and host_info.get("identityfile"):
                    self.key_file = host_info["identityfile"][0]

        self.port = self.port or SSH_PORT

    @cached_property
    def host_proxy(self) -> paramiko.ProxyCommand | None:
        """Return a ``paramiko.ProxyCommand`` built from ``host_proxy_cmd``, or None when it is unset."""
        cmd = self.host_proxy_cmd
        return paramiko.ProxyCommand(cmd) if cmd else None

    def get_conn(self) -> paramiko.SSHClient:
        """
        Establish an SSH connection to the remote host.

        The cached client is returned when its transport is still active; otherwise a new
        client is created, connected (with retries) and cached on ``self.client``.

        :return: connected ``paramiko.SSHClient``
        """
        if self.client:
            transport = self.client.get_transport()
            if transport and transport.is_active():
                # Return the existing connection
                return self.client

        self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id)
        client = paramiko.SSHClient()

        if self.allow_host_key_change:
            self.log.warning(
                "Remote Identification Change is not verified. "
                "This won't protect against Man-In-The-Middle attacks"
            )
            # to avoid BadHostKeyException, skip loading host keys
            client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
        else:
            client.load_system_host_keys()

        if self.no_host_key_check:
            self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())  # nosec B507
            # to avoid BadHostKeyException, skip loading and saving host keys
            known_hosts = os.path.expanduser("~/.ssh/known_hosts")
            if not self.allow_host_key_change and os.path.isfile(known_hosts):
                client.load_host_keys(known_hosts)

        elif self.host_key is not None:
            # Get host key from connection extra if it not set or None then we fallback to system host keys
            client_host_keys = client.get_host_keys()
            if self.port == SSH_PORT:
                client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
            else:
                # Non-default ports are registered under the "[host]:port" form.
                client_host_keys.add(
                    f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
                )

        connect_kwargs: dict[str, Any] = {
            "hostname": self.remote_host,
            "username": self.username,
            "timeout": self.conn_timeout,
            "compress": self.compress,
            "port": self.port,
            "sock": self.host_proxy,
            "look_for_keys": self.look_for_keys,
            "banner_timeout": self.banner_timeout,
            "auth_timeout": self.auth_timeout,
        }

        if self.password:
            password = self.password.strip()
            connect_kwargs.update(password=password)

        if self.pkey:
            connect_kwargs.update(pkey=self.pkey)

        if self.key_file:
            connect_kwargs.update(key_filename=self.key_file)

        if self.disabled_algorithms:
            connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)

        def log_before_sleep(retry_state):
            return self.log.info(
                "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
            )

        # Retry the connect call; the last failure is re-raised once attempts are exhausted.
        for attempt in Retrying(
            reraise=True,
            wait=wait_fixed(3) + wait_random(0, 2),
            stop=stop_after_attempt(self.conn_retry_attempts),
            before_sleep=log_before_sleep,
        ):
            with attempt:
                client.connect(**connect_kwargs)

        if self.keepalive_interval:
            # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
            # type "Transport | None" and item "None" has no attribute "set_keepalive".
            client.get_transport().set_keepalive(self.keepalive_interval)  # type: ignore[union-attr]

        if self.ciphers:
            # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
            # type "Transport | None" and item "None" has no method `get_security_options`".
            client.get_transport().get_security_options().ciphers = self.ciphers  # type: ignore[union-attr]

        self.client = client
        return client

    def get_tunnel(
        self, remote_port: int, remote_host: str = "localhost", local_port: int | None = None
    ) -> SSHTunnel:
        """
        Create a local port-forwarding tunnel through the SSH connection.

        This is conceptually similar to ``ssh -L <LOCAL_PORT>:<remote_host>:<REMOTE_PORT>``.

        The returned ``SSHTunnel`` should be used as a context manager::

            with hook.get_tunnel(remote_port=5432) as tunnel:
                connect_to("localhost", tunnel.local_bind_port)

        The ``.start()`` / ``.stop()`` methods still work but are deprecated.

        .. versionchanged:: 4.4.0
            Returns ``SSHTunnel`` instead of ``sshtunnel.SSHTunnelForwarder``.
            The tunnel now reuses the hook's SSH connection (``get_conn()``)
            instead of establishing a separate one.

        :param remote_port: The remote port to create a tunnel to
        :param remote_host: The remote host to create a tunnel to (default localhost)
        :param local_port: The local port to attach the tunnel to (None for ephemeral)

        :return: SSHTunnel instance
        """
        ssh_client = self.get_conn()
        return SSHTunnel(
            ssh_client=ssh_client,
            remote_host=remote_host,
            remote_port=remote_port,
            local_port=local_port,
            logger=self.log,
        )

    def _pkey_from_private_key(self, private_key: str, passphrase: str | None = None) -> paramiko.PKey:
        """
        Create an appropriate Paramiko key for a given private key.

        :param private_key: string containing private key
        :param passphrase: optional passphrase used to decrypt the private key
        :return: ``paramiko.PKey`` appropriate for given key
        :raises AirflowException: if key cannot be read
        """
        if len(private_key.splitlines()) < 2:
            raise AirflowException("Key must have BEGIN and END header/footer on separate lines.")

        for pkey_class in self._pkey_loaders:
            try:
                key = pkey_class.from_private_key(StringIO(private_key), password=passphrase)
                # Test it actually works. If Paramiko loads an openssh generated key, sometimes it will
                # happily load it as the wrong type, only to fail when actually used.
                key.sign_ssh_data(b"")
                return key
            except (paramiko.ssh_exception.SSHException, ValueError):
                continue
        # The literals are concatenated, so each one carries its own separating space.
        raise AirflowException(
            "Private key provided cannot be read by paramiko. "
            "Ensure key provided is valid for one of the following "
            "key formats: RSA, DSS, ECDSA, or Ed25519"
        )

    def exec_ssh_client_command(
        self,
        ssh_client: paramiko.SSHClient,
        command: str,
        get_pty: bool,
        environment: dict | None,
        timeout: float | ArgNotSet | None = NOTSET,
    ) -> tuple[int, bytes, bytes]:
        """
        Run ``command`` through ``ssh_client`` and collect its output.

        Output is logged line by line as it arrives (stdout at info level, stderr at warning
        level) and also accumulated for the return value.

        :param ssh_client: client on which ``exec_command`` is called
        :param command: command to run
        :param get_pty: passed through to ``exec_command``
        :param environment: passed through to ``exec_command``
        :param timeout: timeout in seconds; when not set, the hook's ``cmd_timeout`` is used,
            and ``CMD_TIMEOUT`` if that is not set either
        :return: tuple of (exit status, stdout bytes, stderr bytes)
        :raises AirflowException: if no output became readable within the timeout
        """
        self.log.info("Running command: %s", command)

        cmd_timeout: float | None
        if is_arg_set(timeout):
            cmd_timeout = timeout
        elif is_arg_set(self.cmd_timeout):
            cmd_timeout = self.cmd_timeout
        else:
            cmd_timeout = CMD_TIMEOUT
        del timeout  # Too easy to confuse with "timedout" below.

        # set timeout taken as params
        stdin, stdout, stderr = ssh_client.exec_command(
            command=command,
            get_pty=get_pty,
            timeout=cmd_timeout,
            environment=environment,
        )
        # get channels
        channel = stdout.channel

        # closing stdin
        stdin.close()
        channel.shutdown_write()

        agg_stdout = b""
        agg_stderr = b""

        # capture any initial output in case channel is closed already
        stdout_buffer_length = len(stdout.channel.in_buffer)

        if stdout_buffer_length > 0:
            agg_stdout += stdout.channel.recv(stdout_buffer_length)

        timedout = False

        # read from both stdout and stderr
        while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
            readq, _, _ = select([channel], [], [], cmd_timeout)
            if cmd_timeout is not None:
                # An empty read queue after a bounded select means the wait expired.
                timedout = not readq
            for recv in readq:
                if recv.recv_ready():
                    output = stdout.channel.recv(len(recv.in_buffer))
                    agg_stdout += output
                    for line in output.decode("utf-8", "replace").strip("\n").splitlines():
                        self.log.info(line)
                if recv.recv_stderr_ready():
                    output = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
                    agg_stderr += output
                    for line in output.decode("utf-8", "replace").strip("\n").splitlines():
                        self.log.warning(line)
            if (
                stdout.channel.exit_status_ready()
                and not stderr.channel.recv_stderr_ready()
                and not stdout.channel.recv_ready()
            ) or timedout:
                stdout.channel.shutdown_read()
                try:
                    stdout.channel.close()
                except Exception:
                    # there is a race that when shutdown_read has been called and when
                    # you try to close the connection, the socket is already closed
                    # We should ignore such errors (but we should log them with warning)
                    self.log.warning("Ignoring exception on close", exc_info=True)
                break

        stdout.close()
        stderr.close()

        if timedout:
            raise AirflowException("SSH command timed out")

        exit_status = stdout.channel.recv_exit_status()

        return exit_status, agg_stdout, agg_stderr

    def test_connection(self) -> tuple[bool, str]:
        """
        Test the ssh connection by execute remote bash commands.

        :return: ``(True, message)`` on success, ``(False, error text)`` if any exception is raised
        """
        try:
            with self.get_conn() as conn:
                conn.exec_command("pwd")
            return True, "Connection successfully tested"
        except Exception as e:
            return False, str(e)
class SSHHookAsync(BaseHook):
    """
    Asynchronous SSH hook using asyncssh for use in triggers.

    This hook provides async SSH connectivity for deferrable operators and their triggers.

    :param ssh_conn_id: SSH connection ID from Airflow Connections
    :param host: hostname of the SSH server (overrides connection)
    :param port: port of the SSH server (overrides connection)
    :param username: username for authentication (overrides connection)
    :param password: password for authentication (overrides connection)
    :param known_hosts: path to known_hosts file. Defaults to ``~/.ssh/known_hosts``.
    :param key_file: path to private key file for authentication
    :param passphrase: passphrase for the private key
    :param private_key: private key content as string
    :param keepalive_interval: keepalive interval forwarded to ``asyncssh.connect`` when truthy
    """

    conn_name_attr = "ssh_conn_id"
    default_conn_name = "ssh_default"
    conn_type = "ssh"
    hook_name = "SSH"
    default_known_hosts = "~/.ssh/known_hosts"

    def __init__(
        self,
        ssh_conn_id: str = default_conn_name,
        host: str | None = None,
        port: int | None = None,
        username: str | None = None,
        password: str | None = None,
        known_hosts: str = default_known_hosts,
        key_file: str = "",
        passphrase: str = "",
        private_key: str = "",
        keepalive_interval: int = 30,
    ) -> None:
        super().__init__()
        self.ssh_conn_id = ssh_conn_id

        # Endpoint and credentials; None means "take it from the Airflow connection".
        self.host = host
        self.port = port
        self.username = username
        self.password = password

        # Host verification and key material.
        self.known_hosts: bytes | str = os.path.expanduser(known_hosts)
        self.key_file = key_file
        self.passphrase = passphrase
        self.private_key = private_key

        self.keepalive_interval = keepalive_interval

    def _parse_extras(self, conn: Any) -> None:
        """Apply settings found in the connection's extra fields to this hook instance."""
        extras = conn.extra_dejson

        # The key file from extras is only used when none was passed to the constructor.
        if self.key_file == "" and "key_file" in extras:
            self.key_file = extras["key_file"]

        # known_hosts from extras is only used while the hook still holds the default path.
        if "known_hosts" in extras and self.known_hosts == os.path.expanduser(self.default_known_hosts):
            self.known_hosts = extras["known_hosts"]

        if "passphrase" in extras or "private_key_passphrase" in extras:
            primary = extras.get("passphrase")
            self.passphrase = primary if primary else extras.get("private_key_passphrase", "")

        if "private_key" in extras:
            self.private_key = extras["private_key"]

        host_key = extras.get("host_key")
        raw_flag = extras.get("no_host_key_check")
        # Host key checking is skipped unless the extras explicitly say otherwise.
        if raw_flag is None:
            skip_host_key_check = True
        else:
            skip_host_key_check = str(raw_flag).lower() == "true"

        if skip_host_key_check:
            if host_key is not None:
                raise ValueError("Host key check was skipped, but `host_key` value was given")
            self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
            self.known_hosts = "none"
            return

        if host_key is not None:
            self.known_hosts = f"{conn.host} {host_key}".encode()

    async def _get_conn(self):
        """
        Asynchronously connect to the SSH server.

        Returns an asyncssh SSHClientConnection that can be used to run commands.
        """
        import asyncssh

        airflow_conn = await get_async_connection(self.ssh_conn_id)
        if airflow_conn.extra is not None:
            self._parse_extras(airflow_conn)

        def _first_set(*candidates):
            # First candidate that is not None, or None when every candidate is None.
            return next((item for item in candidates if item is not None), None)

        connect_options: dict = {}
        connect_options["host"] = _first_set(self.host, airflow_conn.host)
        connect_options["port"] = _first_set(self.port, airflow_conn.port, SSH_PORT)
        connect_options["username"] = _first_set(self.username, airflow_conn.login)
        connect_options["password"] = _first_set(self.password, airflow_conn.password)

        if self.key_file:
            connect_options["client_keys"] = self.key_file

        if self.known_hosts:
            # The literal string "none" (any case) disables known_hosts for asyncssh.
            disable_known_hosts = isinstance(self.known_hosts, str) and self.known_hosts.lower() == "none"
            connect_options["known_hosts"] = None if disable_known_hosts else self.known_hosts

        if self.private_key:
            # Inline key content replaces any key file set above.
            imported_key = asyncssh.import_private_key(self.private_key, self.passphrase)
            connect_options["client_keys"] = [imported_key]

        if self.passphrase:
            connect_options["passphrase"] = self.passphrase

        if self.keepalive_interval:
            # The trigger holds one connection for the whole job; a keepalive stops idle
            # NAT/firewall timeouts from silently dropping it between long poll intervals.
            connect_options["keepalive_interval"] = self.keepalive_interval

        return await asyncssh.connect(**connect_options)

    async def get_conn(self):
        """
        Open an asyncssh connection that can be reused for multiple commands.

        Unlike :meth:`run_command`, the returned connection is **not** closed
        automatically; the caller owns its lifecycle (e.g. ``async with await
        hook.get_conn() as conn: ...`` or an explicit ``conn.close()``). Reusing one
        connection avoids a new TCP/SSH handshake per command, which matters when many
        tasks poll the same SSH server.
        """
        connection = await self._get_conn()
        return connection

    async def run_command(self, command: str, timeout: float | None = None) -> tuple[int, str, str]:
        """
        Execute a command on the remote host asynchronously.

        A fresh connection is opened for the call and closed when it finishes.

        :param command: The command to execute
        :param timeout: Optional timeout in seconds
        :return: Tuple of (exit_code, stdout, stderr)
        """
        async with await self._get_conn() as ssh_conn:
            completed = await ssh_conn.run(command, timeout=timeout, check=False)
            # Falsy fields are normalised: exit status to 0, the output streams to "".
            exit_code = completed.exit_status or 0
            out_text = completed.stdout or ""
            err_text = completed.stderr or ""
            return exit_code, out_text, err_text

    async def get_tunnel(
        self, remote_port: int, remote_host: str = "localhost", local_port: int | None = None
    ) -> AsyncSSHTunnel:
        """
        Create an async local port-forwarding tunnel through the SSH connection.

        Usage::

            async with await hook.get_tunnel(remote_port=5432) as tunnel:
                connect_to("localhost", tunnel.local_bind_port)

        :param remote_port: The remote port to create a tunnel to
        :param remote_host: The remote host to create a tunnel to (default localhost)
        :param local_port: The local port to attach the tunnel to (None for ephemeral)
        :return: AsyncSSHTunnel instance
        """
        return AsyncSSHTunnel(
            ssh_conn=await self._get_conn(),
            remote_host=remote_host,
            remote_port=remote_port,
            local_port=local_port,
        )

    async def run_command_output(self, command: str, timeout: float | None = None) -> str:
        """
        Execute a command and return stdout.

        :param command: The command to execute
        :param timeout: Optional timeout in seconds
        :return: stdout as string
        """
        outcome = await self.run_command(command, timeout=timeout)
        return outcome[1]

Was this entry helpful?