# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import getpass
import os
import warnings
import paramiko
from paramiko.config import SSH_PORT
from sshtunnel import SSHTunnelForwarder
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
class SSHHook(BaseHook, LoggingMixin):
    """
    Hook for ssh remote execution using Paramiko.
    ref: https://github.com/paramiko/paramiko
    This hook also lets you create ssh tunnel and serve as basis for SFTP file transfer

    Values are resolved in this order: arguments passed to ``__init__``, then
    the Airflow connection referenced by ``ssh_conn_id``, then (for proxy
    command and identity file only) the user's ``~/.ssh/config``.

    :param ssh_conn_id: connection id from airflow Connections from where all the required
        parameters can be fetched like username, password or key_file.
        Though the priority is given to the param passed during init
    :type ssh_conn_id: str
    :param remote_host: remote host to connect
    :type remote_host: str
    :param username: username to connect to the remote_host
    :type username: str
    :param password: password of the username to connect to the remote_host
    :type password: str
    :param key_file: key file to use to connect to the remote_host.
    :type key_file: str
    :param port: port of remote host to connect (Default is paramiko SSH_PORT)
    :type port: int
    :param timeout: timeout for the attempt to connect to the remote_host.
        A ``timeout`` entry in the connection extras replaces this value.
    :type timeout: int
    :param keepalive_interval: send a keepalive packet to remote host every
        keepalive_interval seconds
    :type keepalive_interval: int
    :raises AirflowException: if no remote host could be determined
    """

    def __init__(self,
                 ssh_conn_id=None,
                 remote_host=None,
                 username=None,
                 password=None,
                 key_file=None,
                 port=None,
                 timeout=10,
                 keepalive_interval=30
                 ):
        super(SSHHook, self).__init__(ssh_conn_id)
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.username = username
        self.password = password
        self.key_file = key_file
        self.port = port
        self.timeout = timeout
        self.keepalive_interval = keepalive_interval

        # Default values, overridable from Connection
        self.compress = True
        # NOTE(security): host key checking is disabled unless the connection
        # extras set "no_host_key_check" to "false".
        self.no_host_key_check = True
        self.host_proxy = None

        # Placeholder for deprecated __enter__
        self.client = None

        # Use connection to fill in whatever was not passed explicitly
        if self.ssh_conn_id is not None:
            conn = self.get_connection(self.ssh_conn_id)
            if self.username is None:
                self.username = conn.login
            if self.password is None:
                self.password = conn.password
            if self.remote_host is None:
                self.remote_host = conn.host
            if self.port is None:
                self.port = conn.port

            if conn.extra is not None:
                extra_options = conn.extra_dejson

                # An explicitly passed key_file wins over the connection
                # extras, same as the other constructor arguments.
                if self.key_file is None and "key_file" in extra_options:
                    self.key_file = extra_options.get("key_file")

                if "timeout" in extra_options:
                    # No explicit base: int(value, 10) raises TypeError when
                    # the JSON extra holds a number instead of a string.
                    self.timeout = int(extra_options["timeout"])

                # Both flags default to True and are only switched off by an
                # explicit (case-insensitive) "false".
                if str(extra_options.get("compress")).lower() == 'false':
                    self.compress = False
                if str(extra_options.get("no_host_key_check")).lower() == 'false':
                    self.no_host_key_check = False

        if not self.remote_host:
            raise AirflowException("Missing required param: remote_host")

        # Auto detecting username values from system
        if not self.username:
            self.log.debug(
                "username to ssh to host: %s is not specified for connection id"
                " %s. Using system's default provided by getpass.getuser()",
                self.remote_host, self.ssh_conn_id
            )
            self.username = getpass.getuser()

        user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
        if os.path.isfile(user_ssh_config_filename):
            ssh_conf = paramiko.SSHConfig()
            # Context manager so the config file handle is always closed.
            with open(user_ssh_config_filename) as config_file:
                ssh_conf.parse(config_file)
            host_info = ssh_conf.lookup(self.remote_host)

            if host_info and host_info.get('proxycommand'):
                self.host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand'))

            # Only fall back to the ssh config identity file when no other
            # credential (password or key file) is available.
            if not (self.password or self.key_file):
                if host_info and host_info.get('identityfile'):
                    self.key_file = host_info.get('identityfile')[0]

        self.port = self.port or SSH_PORT

    def get_conn(self):
        """
        Opens a ssh connection to the remote host.

        The created client is also stored on ``self.client`` so that the
        deprecated context-manager protocol of the hook can close it.

        :return: paramiko.SSHClient object
        :rtype: paramiko.SSHClient
        """
        self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
        client = paramiko.SSHClient()
        client.load_system_host_keys()
        if self.no_host_key_check:
            # Default is RejectPolicy
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        connect_kwargs = dict(
            hostname=self.remote_host,
            username=self.username,
            key_filename=self.key_file,
            timeout=self.timeout,
            compress=self.compress,
            port=self.port,
            sock=self.host_proxy,
        )
        # A blank / whitespace-only password is treated as "no password".
        if self.password and self.password.strip():
            connect_kwargs['password'] = self.password

        client.connect(**connect_kwargs)

        if self.keepalive_interval:
            client.get_transport().set_keepalive(self.keepalive_interval)

        self.client = client
        return client

    def __enter__(self):
        """
        Deprecated context-manager entry; emits a DeprecationWarning.

        :return: the hook itself
        """
        warnings.warn('The contextmanager of SSHHook is deprecated. '
                      'Please use get_conn() as a contextmanager instead. '
                      'This method will be removed in Airflow 2.0',
                      category=DeprecationWarning)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Closes the client created by the last ``get_conn()`` call, if any.
        """
        if self.client is not None:
            self.client.close()
            self.client = None

    def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
        """
        Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.

        :param remote_port: The remote port to create a tunnel to
        :type remote_port: int
        :param remote_host: The remote host to create a tunnel to (default localhost)
        :type remote_host: str
        :param local_port: The local port to attach the tunnel to
        :type local_port: int
        :return: sshtunnel.SSHTunnelForwarder object
        """
        if local_port:
            local_bind_address = ('localhost', local_port)
        else:
            # No local port requested: bind address carries the host only.
            local_bind_address = ('localhost',)

        tunnel_kwargs = dict(
            ssh_port=self.port,
            ssh_username=self.username,
            ssh_pkey=self.key_file,
            ssh_proxy=self.host_proxy,
            local_bind_address=local_bind_address,
            remote_bind_address=(remote_host, remote_port),
            logger=self.log,
        )
        # A blank / whitespace-only password is treated as "no password".
        if self.password and self.password.strip():
            tunnel_kwargs['ssh_password'] = self.password
        else:
            # Without a password, pass an empty list of key directories.
            tunnel_kwargs['host_pkey_directories'] = []

        client = SSHTunnelForwarder(self.remote_host, **tunnel_kwargs)
        return client

    def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"):
        """
        Deprecated alias of :meth:`get_tunnel` with a different parameter order.

        :param local_port: The local port to attach the tunnel to
        :type local_port: int
        :param remote_port: The remote port to create a tunnel to
        :type remote_port: int
        :param remote_host: The remote host to create a tunnel to (default localhost)
        :type remote_host: str
        :return: sshtunnel.SSHTunnelForwarder object
        """
        warnings.warn('SSHHook.create_tunnel is deprecated, Please '
                      'use get_tunnel() instead. But please note that the '
                      'order of the parameters have changed. '
                      'This method will be removed in Airflow 2.0',
                      category=DeprecationWarning)
        return self.get_tunnel(remote_port, remote_host, local_port)