Source code for airflow.providers.ssh.hooks.ssh

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Hook for SSH connections."""

from __future__ import annotations

import os
import warnings
from base64 import decodebytes
from functools import cached_property
from io import StringIO
from select import select
from typing import Any, Sequence

import paramiko
from deprecated import deprecated
from paramiko.config import SSH_PORT
from sshtunnel import SSHTunnelForwarder
from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.utils.platform import getuser
from airflow.utils.types import NOTSET, ArgNotSet

[docs]TIMEOUT_DEFAULT = 10
[docs]CMD_TIMEOUT = 10
[docs]class SSHHook(BaseHook): """ Execute remote commands with Paramiko. .. seealso:: https://github.com/paramiko/paramiko This hook also lets you create ssh tunnel and serve as basis for SFTP file transfer. :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>` from airflow Connections from where all the required parameters can be fetched like username, password or key_file, though priority is given to the params passed during init. :param remote_host: remote host to connect :param username: username to connect to the remote_host :param password: password of the username to connect to the remote_host :param key_file: path to key file to use to connect to the remote_host :param port: port of remote host to connect (Default is paramiko SSH_PORT) :param conn_timeout: timeout (in seconds) for the attempt to connect to the remote_host. The default is 10 seconds. If provided, it will replace the `conn_timeout` which was predefined in the connection of `ssh_conn_id`. :param timeout: (Deprecated). timeout for the attempt to connect to the remote_host. Use conn_timeout instead. :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds. Nullable, `None` means no timeout. If provided, it will replace the `cmd_timeout` which was predefined in the connection of `ssh_conn_id`. :param keepalive_interval: send a keepalive packet to remote host every keepalive_interval seconds :param banner_timeout: timeout to wait for banner from the server in seconds :param disabled_algorithms: dictionary mapping algorithm type to an iterable of algorithm identifiers, which will be disabled for the lifetime of the transport :param ciphers: list of ciphers to use in order of preference :param auth_timeout: timeout (in seconds) for the attempt to authenticate with the remote_host """ # List of classes to try loading private keys as, ordered (roughly) by most common to least common _pkey_loaders: Sequence[type[paramiko.PKey]] = ( paramiko.RSAKey, paramiko.ECDSAKey, paramiko.Ed25519Key, paramiko.DSSKey, ) _host_key_mappings = { "rsa": paramiko.RSAKey, "dss": paramiko.DSSKey, "ecdsa": paramiko.ECDSAKey, "ed25519": paramiko.Ed25519Key, }
[docs] conn_name_attr = "ssh_conn_id"
[docs] default_conn_name = "ssh_default"
[docs] conn_type = "ssh"
[docs] hook_name = "SSH"
@classmethod
[docs] def get_ui_field_behaviour(cls) -> dict[str, Any]: """Return custom UI field behaviour for SSH connection.""" return { "hidden_fields": ["schema"], "relabeling": { "login": "Username", }, }
def __init__( self, ssh_conn_id: str | None = None, remote_host: str = "", username: str | None = None, password: str | None = None, key_file: str | None = None, port: int | None = None, timeout: int | None = None, conn_timeout: int | None = None, cmd_timeout: int | ArgNotSet | None = NOTSET, keepalive_interval: int = 30, banner_timeout: float = 30.0, disabled_algorithms: dict | None = None, ciphers: list[str] | None = None, auth_timeout: int | None = None, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id self.remote_host = remote_host self.username = username self.password = password self.key_file = key_file self.pkey = None self.port = port self.timeout = timeout self.conn_timeout = conn_timeout self.cmd_timeout = cmd_timeout self.keepalive_interval = keepalive_interval self.banner_timeout = banner_timeout self.disabled_algorithms = disabled_algorithms self.ciphers = ciphers self.host_proxy_cmd = None self.auth_timeout = auth_timeout # Default values, overridable from Connection self.compress = True self.no_host_key_check = True self.allow_host_key_change = False self.host_key = None self.look_for_keys = True # Placeholder for deprecated __enter__ self.client: paramiko.SSHClient | None = None # Use connection to override defaults if self.ssh_conn_id is not None: conn = self.get_connection(self.ssh_conn_id) if self.username is None: self.username = conn.login if self.password is None: self.password = conn.password if not self.remote_host: self.remote_host = conn.host if self.port is None: self.port = conn.port if conn.extra is not None: extra_options = conn.extra_dejson if "key_file" in extra_options and self.key_file is None: self.key_file = extra_options.get("key_file") private_key = extra_options.get("private_key") private_key_passphrase = extra_options.get("private_key_passphrase") if private_key: self.pkey = self._pkey_from_private_key(private_key, passphrase=private_key_passphrase) if "timeout" in extra_options: warnings.warn( "Extra option `timeout` is deprecated." "Please use `conn_timeout` instead." "The old option `timeout` will be removed in a future version.", category=AirflowProviderDeprecationWarning, stacklevel=2, ) self.timeout = int(extra_options["timeout"]) if "conn_timeout" in extra_options and self.conn_timeout is None: self.conn_timeout = int(extra_options["conn_timeout"]) if "cmd_timeout" in extra_options and self.cmd_timeout is NOTSET: if extra_options["cmd_timeout"]: self.cmd_timeout = int(extra_options["cmd_timeout"]) else: self.cmd_timeout = None if "compress" in extra_options and str(extra_options["compress"]).lower() == "false": self.compress = False host_key = extra_options.get("host_key") no_host_key_check = extra_options.get("no_host_key_check") if no_host_key_check is not None: no_host_key_check = str(no_host_key_check).lower() == "true" if host_key is not None and no_host_key_check: raise ValueError("Must check host key when provided") self.no_host_key_check = no_host_key_check if ( "allow_host_key_change" in extra_options and str(extra_options["allow_host_key_change"]).lower() == "true" ): self.allow_host_key_change = True if ( "look_for_keys" in extra_options and str(extra_options["look_for_keys"]).lower() == "false" ): self.look_for_keys = False if "disabled_algorithms" in extra_options: self.disabled_algorithms = extra_options.get("disabled_algorithms") if "ciphers" in extra_options: self.ciphers = extra_options.get("ciphers") if host_key is not None: if host_key.startswith("ssh-"): key_type, host_key = host_key.split(None)[:2] key_constructor = self._host_key_mappings[key_type[4:]] else: key_constructor = paramiko.RSAKey decoded_host_key = decodebytes(host_key.encode("utf-8")) self.host_key = key_constructor(data=decoded_host_key) self.no_host_key_check = False if self.timeout: warnings.warn( "Parameter `timeout` is deprecated." "Please use `conn_timeout` instead." "The old option `timeout` will be removed in a future version.", category=AirflowProviderDeprecationWarning, stacklevel=2, ) if self.conn_timeout is None: self.conn_timeout = self.timeout if self.timeout else TIMEOUT_DEFAULT if self.cmd_timeout is NOTSET: self.cmd_timeout = CMD_TIMEOUT if self.pkey and self.key_file: raise AirflowException( "Params key_file and private_key both provided. Must provide no more than one." ) if not self.remote_host: raise AirflowException("Missing required param: remote_host") # Auto detecting username values from system if not self.username: self.log.debug( "username to ssh to host: %s is not specified for connection id" " %s. Using system's default provided by getpass.getuser()", self.remote_host, self.ssh_conn_id, ) self.username = getuser() user_ssh_config_filename = os.path.expanduser("~/.ssh/config") if os.path.isfile(user_ssh_config_filename): ssh_conf = paramiko.SSHConfig() with open(user_ssh_config_filename) as config_fd: ssh_conf.parse(config_fd) host_info = ssh_conf.lookup(self.remote_host) if host_info and host_info.get("proxycommand"): self.host_proxy_cmd = host_info["proxycommand"] if not (self.password or self.key_file): if host_info and host_info.get("identityfile"): self.key_file = host_info["identityfile"][0] self.port = self.port or SSH_PORT @cached_property
[docs] def host_proxy(self) -> paramiko.ProxyCommand | None: cmd = self.host_proxy_cmd return paramiko.ProxyCommand(cmd) if cmd else None
[docs] def get_conn(self) -> paramiko.SSHClient: """Establish an SSH connection to the remote host.""" if self.client: transport = self.client.get_transport() if transport and transport.is_active(): # Return the existing connection return self.client self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id) client = paramiko.SSHClient() if self.allow_host_key_change: self.log.warning( "Remote Identification Change is not verified. " "This won't protect against Man-In-The-Middle attacks" ) # to avoid BadHostKeyException, skip loading host keys client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy) else: client.load_system_host_keys() if self.no_host_key_check: self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks") client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) # nosec B507 # to avoid BadHostKeyException, skip loading and saving host keys known_hosts = os.path.expanduser("~/.ssh/known_hosts") if not self.allow_host_key_change and os.path.isfile(known_hosts): client.load_host_keys(known_hosts) elif self.host_key is not None: # Get host key from connection extra if it not set or None then we fallback to system host keys client_host_keys = client.get_host_keys() if self.port == SSH_PORT: client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key) else: client_host_keys.add( f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key ) connect_kwargs: dict[str, Any] = { "hostname": self.remote_host, "username": self.username, "timeout": self.conn_timeout, "compress": self.compress, "port": self.port, "sock": self.host_proxy, "look_for_keys": self.look_for_keys, "banner_timeout": self.banner_timeout, "auth_timeout": self.auth_timeout, } if self.password: password = self.password.strip() connect_kwargs.update(password=password) if self.pkey: connect_kwargs.update(pkey=self.pkey) if self.key_file: connect_kwargs.update(key_filename=self.key_file) if self.disabled_algorithms: connect_kwargs.update(disabled_algorithms=self.disabled_algorithms) def log_before_sleep(retry_state): return self.log.info( "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number ) for attempt in Retrying( reraise=True, wait=wait_fixed(3) + wait_random(0, 2), stop=stop_after_attempt(3), before_sleep=log_before_sleep, ): with attempt: client.connect(**connect_kwargs) if self.keepalive_interval: # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns # type "Transport | None" and item "None" has no attribute "set_keepalive". client.get_transport().set_keepalive(self.keepalive_interval) # type: ignore[union-attr] if self.ciphers: # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns # type "Transport | None" and item "None" has no method `get_security_options`". client.get_transport().get_security_options().ciphers = self.ciphers # type: ignore[union-attr] self.client = client return client
@deprecated( reason=( "The contextmanager of SSHHook is deprecated." "Please use get_conn() as a contextmanager instead." "This method will be removed in Airflow 2.0" ), category=AirflowProviderDeprecationWarning, )
[docs] def __enter__(self) -> SSHHook: """Return an instance of SSHHook when the `with` statement is used.""" return self
[docs] def __exit__(self, exc_type, exc_val, exc_tb) -> None: """Clear ssh client after exiting the `with` statement block.""" if self.client is not None: self.client.close() self.client = None
[docs] def get_tunnel( self, remote_port: int, remote_host: str = "localhost", local_port: int | None = None ) -> SSHTunnelForwarder: """ Create a tunnel between two hosts. This is conceptually similar to ``ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>``. :param remote_port: The remote port to create a tunnel to :param remote_host: The remote host to create a tunnel to (default localhost) :param local_port: The local port to attach the tunnel to :return: sshtunnel.SSHTunnelForwarder object """ if local_port: local_bind_address: tuple[str, int] | tuple[str] = ("localhost", local_port) else: local_bind_address = ("localhost",) tunnel_kwargs = { "ssh_port": self.port, "ssh_username": self.username, "ssh_pkey": self.key_file or self.pkey, "ssh_proxy": self.host_proxy, "local_bind_address": local_bind_address, "remote_bind_address": (remote_host, remote_port), "logger": self.log, } if self.password: password = self.password.strip() tunnel_kwargs.update( ssh_password=password, ) else: tunnel_kwargs.update( host_pkey_directories=None, ) client = SSHTunnelForwarder(self.remote_host, **tunnel_kwargs) return client
@deprecated( reason=( "SSHHook.create_tunnel is deprecated, Please " "use get_tunnel() instead. But please note that the " "order of the parameters have changed. " "This method will be removed in Airflow 2.0" ), category=AirflowProviderDeprecationWarning, )
[docs] def create_tunnel( self, local_port: int, remote_port: int, remote_host: str = "localhost" ) -> SSHTunnelForwarder: """ Create a tunnel for SSH connection [Deprecated]. :param local_port: local port number :param remote_port: remote port number :param remote_host: remote host """ return self.get_tunnel(remote_port, remote_host, local_port)
def _pkey_from_private_key(self, private_key: str, passphrase: str | None = None) -> paramiko.PKey: """ Create an appropriate Paramiko key for a given private key. :param private_key: string containing private key :return: ``paramiko.PKey`` appropriate for given key :raises AirflowException: if key cannot be read """ if len(private_key.splitlines()) < 2: raise AirflowException("Key must have BEGIN and END header/footer on separate lines.") for pkey_class in self._pkey_loaders: try: key = pkey_class.from_private_key(StringIO(private_key), password=passphrase) # Test it actually works. If Paramiko loads an openssh generated key, sometimes it will # happily load it as the wrong type, only to fail when actually used. key.sign_ssh_data(b"") return key except (paramiko.ssh_exception.SSHException, ValueError): continue raise AirflowException( "Private key provided cannot be read by paramiko." "Ensure key provided is valid for one of the following" "key formats: RSA, DSS, ECDSA, or Ed25519" )
[docs] def exec_ssh_client_command( self, ssh_client: paramiko.SSHClient, command: str, get_pty: bool, environment: dict | None, timeout: int | ArgNotSet | None = NOTSET, ) -> tuple[int, bytes, bytes]: self.log.info("Running command: %s", command) cmd_timeout: int | None if not isinstance(timeout, ArgNotSet): cmd_timeout = timeout elif not isinstance(self.cmd_timeout, ArgNotSet): cmd_timeout = self.cmd_timeout else: cmd_timeout = CMD_TIMEOUT del timeout # Too easy to confuse with "timedout" below. # set timeout taken as params stdin, stdout, stderr = ssh_client.exec_command( command=command, get_pty=get_pty, timeout=cmd_timeout, environment=environment, ) # get channels channel = stdout.channel # closing stdin stdin.close() channel.shutdown_write() agg_stdout = b"" agg_stderr = b"" # capture any initial output in case channel is closed already stdout_buffer_length = len(stdout.channel.in_buffer) if stdout_buffer_length > 0: agg_stdout += stdout.channel.recv(stdout_buffer_length) timedout = False # read from both stdout and stderr while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready(): readq, _, _ = select([channel], [], [], cmd_timeout) if cmd_timeout is not None: timedout = not readq for recv in readq: if recv.recv_ready(): output = stdout.channel.recv(len(recv.in_buffer)) agg_stdout += output for line in output.decode("utf-8", "replace").strip("\n").splitlines(): self.log.info(line) if recv.recv_stderr_ready(): output = stderr.channel.recv_stderr(len(recv.in_stderr_buffer)) agg_stderr += output for line in output.decode("utf-8", "replace").strip("\n").splitlines(): self.log.warning(line) if ( stdout.channel.exit_status_ready() and not stderr.channel.recv_stderr_ready() and not stdout.channel.recv_ready() ) or timedout: stdout.channel.shutdown_read() try: stdout.channel.close() except Exception: # there is a race that when shutdown_read has been called and when # you try to close the connection, the socket is already closed # We should ignore such errors (but we should log them with warning) self.log.warning("Ignoring exception on close", exc_info=True) break stdout.close() stderr.close() if timedout: raise AirflowException("SSH command timed out") exit_status = stdout.channel.recv_exit_status() return exit_status, agg_stdout, agg_stderr
[docs] def test_connection(self) -> tuple[bool, str]: """Test the ssh connection by execute remote bash commands.""" try: with self.get_conn() as conn: conn.exec_command("pwd") return True, "Connection successfully tested" except Exception as e: return False, str(e)

Was this entry helpful?