## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""Hook for SSH connections."""importosimportsysimportwarningsfrombase64importdecodebytesfromioimportStringIOfromselectimportselectfromtypingimportAny,Dict,Optional,Sequence,Tuple,Type,Unionimportparamikofromparamiko.configimportSSH_PORTfromsshtunnelimportSSHTunnelForwarderfromtenacityimportRetrying,stop_after_attempt,wait_fixed,wait_randomifsys.version_info>=(3,8):fromfunctoolsimportcached_propertyelse:fromcached_propertyimportcached_propertyfromairflow.exceptionsimportAirflowExceptionfromairflow.hooks.baseimportBaseHooktry:fromairflow.utils.platformimportgetuserexceptImportError:fromgetpassimportgetuser# type: ignore[misc]
class SSHHook(BaseHook):
    """
    Hook for ssh remote execution using Paramiko.
    ref: https://github.com/paramiko/paramiko
    This hook also lets you create ssh tunnel and serve as basis for SFTP file transfer.

    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>` from airflow
        Connections from where all the required parameters can be fetched like
        username, password or key_file. Though the priority is given to the
        param passed during init
    :param remote_host: remote host to connect
    :param username: username to connect to the remote_host
    :param password: password of the username to connect to the remote_host
    :param key_file: path to key file to use to connect to the remote_host.
        Must not be combined with a ``private_key`` in the connection extra.
    :param port: port of remote host to connect (Default is paramiko SSH_PORT)
    :param conn_timeout: timeout (in seconds) for the attempt to connect to the remote_host.
        The default is 10 seconds. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :param timeout: (Deprecated). timeout for the attempt to connect to the remote_host.
        Use conn_timeout instead.
    :param keepalive_interval: send a keepalive packet to remote host every
        keepalive_interval seconds
    :param banner_timeout: timeout to wait for banner from the server in seconds
    """

    # List of classes to try loading private keys as, ordered (roughly) by most common to least common
    _pkey_loaders: Sequence[Type[paramiko.PKey]] = (
        paramiko.RSAKey,
        paramiko.ECDSAKey,
        paramiko.Ed25519Key,
        paramiko.DSSKey,
    )

    # Host-key type name -> paramiko key class, used when parsing the ``host_key`` connection extra
    _host_key_mappings: Dict[str, Type[paramiko.PKey]] = {
        'rsa': paramiko.RSAKey,
        'dss': paramiko.DSSKey,
        'ecdsa': paramiko.ECDSAKey,
        'ed25519': paramiko.Ed25519Key,
    }

    @staticmethod
    def get_ui_field_behaviour() -> Dict[str, Any]:
        """Returns custom field behaviour"""
        return {
            "hidden_fields": ['schema'],
            "relabeling": {
                'login': 'Username',
            },
        }

    def __init__(
        self,
        ssh_conn_id: Optional[str] = None,
        remote_host: str = '',
        username: Optional[str] = None,
        password: Optional[str] = None,
        key_file: Optional[str] = None,
        port: Optional[int] = None,
        timeout: Optional[int] = None,
        conn_timeout: Optional[int] = None,
        keepalive_interval: int = 30,
        banner_timeout: float = 30.0,
    ) -> None:
        super().__init__()
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.username = username
        self.password = password
        self.key_file = key_file
        self.pkey: Optional[paramiko.PKey] = None
        self.port = port
        self.timeout = timeout
        self.conn_timeout = conn_timeout
        self.keepalive_interval = keepalive_interval
        self.banner_timeout = banner_timeout
        # ProxyCommand for remote_host from ~/.ssh/config, if any (resolved below)
        self.host_proxy_cmd: Optional[str] = None

        # Default values, overridable from Connection
        self.compress = True
        self.no_host_key_check = True
        self.allow_host_key_change = False
        self.host_key: Optional[paramiko.PKey] = None
        self.look_for_keys = True

        # Placeholder for deprecated __enter__
        self.client: Optional[paramiko.SSHClient] = None

        # Use connection to override defaults; explicit constructor arguments win
        if self.ssh_conn_id is not None:
            conn = self.get_connection(self.ssh_conn_id)
            if self.username is None:
                self.username = conn.login
            if self.password is None:
                self.password = conn.password
            if not self.remote_host:
                self.remote_host = conn.host
            if self.port is None:
                self.port = conn.port

            if conn.extra is not None:
                extra_options = conn.extra_dejson
                if "key_file" in extra_options and self.key_file is None:
                    self.key_file = extra_options.get("key_file")

                private_key = extra_options.get('private_key')
                private_key_passphrase = extra_options.get('private_key_passphrase')
                if private_key:
                    self.pkey = self._pkey_from_private_key(private_key, passphrase=private_key_passphrase)

                if "timeout" in extra_options:
                    warnings.warn(
                        'Extra option `timeout` is deprecated. '
                        'Please use `conn_timeout` instead. '
                        'The old option `timeout` will be removed in a future version.',
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    self.timeout = int(extra_options['timeout'])

                if "conn_timeout" in extra_options and self.conn_timeout is None:
                    self.conn_timeout = int(extra_options['conn_timeout'])

                if "compress" in extra_options and str(extra_options["compress"]).lower() == 'false':
                    self.compress = False

                host_key = extra_options.get("host_key")
                no_host_key_check = extra_options.get("no_host_key_check")

                # Only override the default when the extra actually sets the flag
                if no_host_key_check is not None:
                    no_host_key_check = str(no_host_key_check).lower() == "true"
                    if host_key is not None and no_host_key_check:
                        raise ValueError("Must check host key when provided")

                    self.no_host_key_check = no_host_key_check

                if (
                    "allow_host_key_change" in extra_options
                    and str(extra_options["allow_host_key_change"]).lower() == 'true'
                ):
                    self.allow_host_key_change = True

                if (
                    "look_for_keys" in extra_options
                    and str(extra_options["look_for_keys"]).lower() == 'false'
                ):
                    self.look_for_keys = False

                if host_key is not None:
                    if host_key.startswith(("ssh-", "ecdsa-")):
                        # "<type> <base64> [comment]" form. ssh-rsa/ssh-dss/ssh-ed25519 map via the
                        # suffix after "ssh-"; ECDSA types ("ecdsa-sha2-nistp256", ...) have no
                        # "ssh-" prefix and would otherwise be mis-decoded as an RSA key.
                        key_type, host_key = host_key.split(None)[:2]
                        key_name = "ecdsa" if key_type.startswith("ecdsa-") else key_type[4:]
                        key_constructor = self._host_key_mappings[key_name]
                    else:
                        # A bare base64 blob without a type prefix is treated as an RSA key
                        key_constructor = paramiko.RSAKey
                    decoded_host_key = decodebytes(host_key.encode('utf-8'))
                    self.host_key = key_constructor(data=decoded_host_key)
                    self.no_host_key_check = False

        # Check the constructor argument, not self.timeout: a timeout coming from the
        # connection extra has already triggered its own deprecation warning above.
        if timeout:
            warnings.warn(
                'Parameter `timeout` is deprecated. '
                'Please use `conn_timeout` instead. '
                'The old option `timeout` will be removed in a future version.',
                DeprecationWarning,
                stacklevel=2,
            )

        if self.conn_timeout is None:
            # NOTE(review): TIMEOUT_DEFAULT is assumed to be a module-level constant defined elsewhere
            self.conn_timeout = self.timeout if self.timeout else TIMEOUT_DEFAULT

        if self.pkey and self.key_file:
            raise AirflowException(
                "Params key_file and private_key both provided. Must provide no more than one."
            )

        if not self.remote_host:
            raise AirflowException("Missing required param: remote_host")

        # Auto detecting username values from system
        if not self.username:
            self.log.debug(
                "username to ssh to host: %s is not specified for connection id"
                " %s. Using system's default provided by getpass.getuser()",
                self.remote_host,
                self.ssh_conn_id,
            )
            self.username = getuser()

        user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
        if os.path.isfile(user_ssh_config_filename):
            ssh_conf = paramiko.SSHConfig()
            with open(user_ssh_config_filename) as config_fd:
                ssh_conf.parse(config_fd)
            host_info = ssh_conf.lookup(self.remote_host)
            if host_info and host_info.get('proxycommand'):
                self.host_proxy_cmd = host_info['proxycommand']

            # Fall back to the config's IdentityFile only when no credential was configured.
            # pkey is included so an explicit private_key is not shadowed by it (get_tunnel
            # prefers key_file over pkey).
            if not (self.password or self.key_file or self.pkey):
                if host_info and host_info.get('identityfile'):
                    self.key_file = host_info['identityfile'][0]

        self.port = self.port or SSH_PORT

    @cached_property
    def host_proxy(self) -> Optional[paramiko.ProxyCommand]:
        """
        Socket-like proxy built from ``host_proxy_cmd``, or None when no ProxyCommand is configured.

        Cached per hook instance, so get_conn() and get_tunnel() receive the same object.
        NOTE(review): a ProxyCommand wraps a single command stream; reusing it for several
        connections from one hook may not work - confirm if that pattern is needed.
        """
        cmd = self.host_proxy_cmd
        return paramiko.ProxyCommand(cmd) if cmd else None

    def get_conn(self) -> paramiko.SSHClient:
        """
        Opens a ssh connection to the remote host.

        Host-key handling: the system known_hosts are loaded unless ``allow_host_key_change``
        is set. With ``no_host_key_check`` unknown hosts are auto-added. Otherwise an explicit
        ``host_key`` from the connection extra is registered for the remote host. The connect
        attempt is retried up to three times, and the client is closed if all attempts fail.

        :rtype: paramiko.client.SSHClient
        """
        self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
        client = paramiko.SSHClient()

        if self.allow_host_key_change:
            # known_hosts is deliberately not loaded, so a changed remote key goes unnoticed
            self.log.warning(
                "Remote Identification Change is not verified. "
                "This won't protect against Man-In-The-Middle attacks"
            )
        else:
            client.load_system_host_keys()

        if self.no_host_key_check:
            self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
            # Default is RejectPolicy
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        elif self.host_key is not None:
            client_host_keys = client.get_host_keys()
            # Non-default ports are keyed as "[host]:port", as in known_hosts
            if self.port == SSH_PORT:
                client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
            else:
                client_host_keys.add(
                    f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
                )
        # else: will fallback to system host keys if none explicitly specified in conn extra

        connect_kwargs: Dict[str, Any] = dict(
            hostname=self.remote_host,
            username=self.username,
            timeout=self.conn_timeout,
            compress=self.compress,
            port=self.port,
            sock=self.host_proxy,
            look_for_keys=self.look_for_keys,
            banner_timeout=self.banner_timeout,
        )

        if self.password:
            password = self.password.strip()
            connect_kwargs.update(password=password)

        if self.pkey:
            connect_kwargs.update(pkey=self.pkey)

        if self.key_file:
            connect_kwargs.update(key_filename=self.key_file)

        def log_before_sleep(retry_state):
            return self.log.info(
                "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
            )

        try:
            for attempt in Retrying(
                reraise=True,
                wait=wait_fixed(3) + wait_random(0, 2),
                stop=stop_after_attempt(3),
                before_sleep=log_before_sleep,
            ):
                with attempt:
                    client.connect(**connect_kwargs)
        except Exception:
            # Don't leak the half-open client/transport of a failed connection
            client.close()
            raise

        if self.keepalive_interval:
            # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
            # type "Optional[Transport]" and item "None" has no attribute "set_keepalive".
            client.get_transport().set_keepalive(self.keepalive_interval)  # type: ignore[union-attr]

        self.client = client
        return client

    def __enter__(self) -> 'SSHHook':
        """
        Deprecated context-manager entry; emits a DeprecationWarning and returns the hook.

        The client stored by :meth:`get_conn` is closed by :meth:`__exit__`.
        """
        warnings.warn(
            'The contextmanager of SSHHook is deprecated. '
            'Please use get_conn() as a contextmanager instead. '
            'This method will be removed in Airflow 2.0',
            category=DeprecationWarning,
            stacklevel=2,
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the client stored by :meth:`get_conn`, if any. Exceptions are not suppressed."""
        if self.client is not None:
            self.client.close()
            self.client = None

    def get_tunnel(
        self, remote_port: int, remote_host: str = "localhost", local_port: Optional[int] = None
    ) -> SSHTunnelForwarder:
        """
        Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.

        :param remote_port: The remote port to create a tunnel to
        :param remote_host: The remote host to create a tunnel to (default localhost)
        :param local_port: The local port to attach the tunnel to
        :return: sshtunnel.SSHTunnelForwarder object
        """
        if local_port:
            local_bind_address: Union[Tuple[str, int], Tuple[str]] = ('localhost', local_port)
        else:
            # no port given: bind to localhost only
            local_bind_address = ('localhost',)

        tunnel_kwargs = dict(
            ssh_port=self.port,
            ssh_username=self.username,
            ssh_pkey=self.key_file or self.pkey,
            ssh_proxy=self.host_proxy,
            local_bind_address=local_bind_address,
            remote_bind_address=(remote_host, remote_port),
            logger=self.log,
        )

        if self.password:
            password = self.password.strip()
            tunnel_kwargs.update(
                ssh_password=password,
            )
        else:
            tunnel_kwargs.update(
                host_pkey_directories=None,
            )

        client = SSHTunnelForwarder(self.remote_host, **tunnel_kwargs)

        return client

    def create_tunnel(
        self, local_port: int, remote_port: int, remote_host: str = "localhost"
    ) -> SSHTunnelForwarder:
        """
        Creates tunnel for SSH connection [Deprecated].

        Delegates to :meth:`get_tunnel`, which takes the same values in a different order:
        ``get_tunnel(remote_port, remote_host, local_port)``.

        :param local_port: local port number
        :param remote_port: remote port number
        :param remote_host: remote host
        :return: sshtunnel.SSHTunnelForwarder object returned by :meth:`get_tunnel`
        """
        warnings.warn(
            'SSHHook.create_tunnel is deprecated, Please '
            'use get_tunnel() instead. But please note that the '
            'order of the parameters have changed. '
            'This method will be removed in Airflow 2.0',
            category=DeprecationWarning,
            stacklevel=2,
        )

        return self.get_tunnel(remote_port, remote_host, local_port)

    def _pkey_from_private_key(self, private_key: str, passphrase: Optional[str] = None) -> paramiko.PKey:
        """
        Creates appropriate paramiko key for given private key

        Each class in ``_pkey_loaders`` is tried in order. The first one that both parses the
        key and can sign with it is returned.

        :param private_key: string containing private key
        :param passphrase: passphrase used to decrypt ``private_key`` if it is encrypted
        :return: ``paramiko.PKey`` appropriate for given key
        :raises AirflowException: if key cannot be read
        """
        for pkey_class in self._pkey_loaders:
            try:
                key = pkey_class.from_private_key(StringIO(private_key), password=passphrase)
                # Test it actually works. If Paramiko loads an openssh generated key, sometimes it will
                # happily load it as the wrong type, only to fail when actually used.
                key.sign_ssh_data(b'')
                return key
            except (paramiko.ssh_exception.SSHException, ValueError):
                continue
        raise AirflowException(
            'Private key provided cannot be read by paramiko. '
            'Ensure key provided is valid for one of the following '
            'key formats: RSA, DSS, ECDSA, or Ed25519'
        )

    def exec_ssh_client_command(
        self,
        ssh_client: paramiko.SSHClient,
        command: str,
        get_pty: bool,
        environment: Optional[dict],
        timeout: Optional[int],
    ) -> Tuple[int, bytes, bytes]:
        """
        Run ``command`` through ``ssh_client`` and collect its output and exit status.

        Output is logged while it arrives: stdout chunks at INFO, stderr chunks at WARNING.

        :param ssh_client: client to run the command on, e.g. one returned by :meth:`get_conn`
        :param command: command line to execute on the remote host
        :param get_pty: whether to request a pseudo-terminal for the command
        :param environment: environment variables passed to ``exec_command``
        :param timeout: passed to ``exec_command`` and used as the ``select()`` wait of each loop
            iteration. NOTE(review): an expired ``select()`` does not end the loop, so this does
            not bound the total runtime of the command.
        :return: tuple of (exit status, all stdout bytes, all stderr bytes)
        """
        self.log.info("Running command: %s", command)

        # set timeout taken as params
        stdin, stdout, stderr = ssh_client.exec_command(
            command=command,
            get_pty=get_pty,
            timeout=timeout,
            environment=environment,
        )
        # get channels
        channel = stdout.channel

        # closing stdin
        stdin.close()
        channel.shutdown_write()

        agg_stdout = b''
        agg_stderr = b''

        # capture any initial output in case channel is closed already
        stdout_buffer_length = len(stdout.channel.in_buffer)

        if stdout_buffer_length > 0:
            agg_stdout += stdout.channel.recv(stdout_buffer_length)

        # read from both stdout and stderr, while the channel is open or data is still buffered
        while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
            readq, _, _ = select([channel], [], [], timeout)
            for recv in readq:
                if recv.recv_ready():
                    line = stdout.channel.recv(len(recv.in_buffer))
                    agg_stdout += line
                    self.log.info(line.decode('utf-8', 'replace').strip('\n'))
                if recv.recv_stderr_ready():
                    line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
                    agg_stderr += line
                    self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
            # Stop once the remote command has exited and both streams are drained
            if (
                stdout.channel.exit_status_ready()
                and not stderr.channel.recv_stderr_ready()
                and not stdout.channel.recv_ready()
            ):
                stdout.channel.shutdown_read()
                try:
                    stdout.channel.close()
                except Exception:
                    # there is a race that when shutdown_read has been called and when
                    # you try to close the connection, the socket is already closed
                    # We should ignore such errors (but we should log them with warning)
                    self.log.warning("Ignoring exception on close", exc_info=True)
                break

        stdout.close()
        stderr.close()

        exit_status = stdout.channel.recv_exit_status()
        return exit_status, agg_stdout, agg_stderr