## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""Hook for SSH connections."""from__future__importannotationsimportosfrombase64importdecodebytesfromcollections.abcimportSequencefromfunctoolsimportcached_propertyfromioimportStringIOfromselectimportselectfromtypingimportAnyimportparamikofromparamiko.configimportSSH_PORTfromsshtunnelimportSSHTunnelForwarderfromtenacityimportRetrying,stop_after_attempt,wait_fixed,wait_randomfromairflow.exceptionsimportAirflowExceptionfromairflow.hooks.baseimportBaseHookfromairflow.utils.platformimportgetuserfromairflow.utils.typesimportNOTSET,ArgNotSet
class SSHHook(BaseHook):
    """
    Run commands on a remote machine over SSH, using Paramiko.

    .. seealso:: https://github.com/paramiko/paramiko

    Besides executing remote commands, this hook can open SSH tunnels and serves
    as the foundation for SFTP file transfers.

    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>` of the Airflow
        Connection that supplies the required settings (username, password, key_file, ...).
        Values passed to the constructor take priority over those of the connection.
    :param remote_host: remote host to connect to
    :param username: username used to log in to the remote_host
    :param password: password of that username on the remote_host
    :param key_file: path to the key file used to connect to the remote_host
    :param port: port of the remote host to connect to (Default is paramiko SSH_PORT)
    :param conn_timeout: timeout (in seconds) for the attempt to connect to the remote_host.
        The default is 10 seconds. If provided, it will replace the `conn_timeout` which was
        predefined in the connection of `ssh_conn_id`.
    :param cmd_timeout: timeout (in seconds) for executing the command. The default is 10 seconds.
        Nullable, `None` means no timeout. If provided, it will replace the `cmd_timeout`
        which was predefined in the connection of `ssh_conn_id`.
    :param keepalive_interval: send a keepalive packet to the remote host every
        keepalive_interval seconds
    :param banner_timeout: timeout (in seconds) to wait for the banner from the server
    :param disabled_algorithms: dictionary mapping algorithm type to an iterable of algorithm
        identifiers, which will be disabled for the lifetime of the transport
    :param ciphers: list of ciphers to use, in order of preference
    :param auth_timeout: timeout (in seconds) for the attempt to authenticate with the remote_host
    """

    # Key classes tried, in this order, when loading a private key given as text;
    # ordered (roughly) from the most common key type to the least common one.
    _pkey_loaders: Sequence[type[paramiko.PKey]] = (
        paramiko.RSAKey,
        paramiko.ECDSAKey,
        paramiko.Ed25519Key,
        paramiko.DSSKey,
    )
    # paramiko key class per host-key type name; looked up with the "ssh-" prefix
    # of an OpenSSH host-key type removed (e.g. "ssh-rsa" -> "rsa").
    _host_key_mappings = dict(
        rsa=paramiko.RSAKey,
        dss=paramiko.DSSKey,
        ecdsa=paramiko.ECDSAKey,
        ed25519=paramiko.Ed25519Key,
    )
def get_ui_field_behaviour(cls) -> dict[str, Any]:
    """
    Describe how the SSH connection form is customised in the UI.

    The ``schema`` field is hidden and the ``login`` field is relabelled "Username".

    :return: a new dict with ``hidden_fields`` and ``relabeling`` entries
    """
    # NOTE(review): takes ``cls``, so this is presumably a @classmethod whose decorator
    # is not visible in this text — confirm against the source file.
    hidden = ["schema"]
    labels = {"login": "Username"}
    return {"hidden_fields": hidden, "relabeling": labels}
# Use connection to override defaultsifself.ssh_conn_idisnotNone:conn=self.get_connection(self.ssh_conn_id)ifself.usernameisNone:self.username=conn.loginifself.passwordisNone:self.password=conn.passwordifnotself.remote_host:self.remote_host=conn.hostifself.portisNone:self.port=conn.portifconn.extraisnotNone:extra_options=conn.extra_dejsonif"key_file"inextra_optionsandself.key_fileisNone:self.key_file=extra_options.get("key_file")private_key=extra_options.get("private_key")private_key_passphrase=extra_options.get("private_key_passphrase")ifprivate_key:self.pkey=self._pkey_from_private_key(private_key,passphrase=private_key_passphrase)if"conn_timeout"inextra_optionsandself.conn_timeoutisNone:self.conn_timeout=int(extra_options["conn_timeout"])if"cmd_timeout"inextra_optionsandself.cmd_timeoutisNOTSET:ifextra_options["cmd_timeout"]:self.cmd_timeout=float(extra_options["cmd_timeout"])else:self.cmd_timeout=Noneif"compress"inextra_optionsandstr(extra_options["compress"]).lower()=="false":self.compress=Falsehost_key=extra_options.get("host_key")no_host_key_check=extra_options.get("no_host_key_check")ifno_host_key_checkisnotNone:no_host_key_check=str(no_host_key_check).lower()=="true"ifhost_keyisnotNoneandno_host_key_check:raiseValueError("Must check host key when 
provided")self.no_host_key_check=no_host_key_checkif("allow_host_key_change"inextra_optionsandstr(extra_options["allow_host_key_change"]).lower()=="true"):self.allow_host_key_change=Trueif("look_for_keys"inextra_optionsandstr(extra_options["look_for_keys"]).lower()=="false"):self.look_for_keys=Falseif"disabled_algorithms"inextra_options:self.disabled_algorithms=extra_options.get("disabled_algorithms")if"ciphers"inextra_options:self.ciphers=extra_options.get("ciphers")ifhost_keyisnotNone:ifhost_key.startswith("ssh-"):key_type,host_key=host_key.split(None)[:2]key_constructor=self._host_key_mappings[key_type[4:]]else:key_constructor=paramiko.RSAKeydecoded_host_key=decodebytes(host_key.encode("utf-8"))self.host_key=key_constructor(data=decoded_host_key)self.no_host_key_check=Falseifself.cmd_timeoutisNOTSET:self.cmd_timeout=CMD_TIMEOUTifself.pkeyandself.key_file:raiseAirflowException("Params key_file and private_key both provided. Must provide no more than one.")ifnotself.remote_host:raiseAirflowException("Missing required param: remote_host")# Auto detecting username values from systemifnotself.username:self.log.debug("username to ssh to host: %s is not specified for connection id"" %s. Using system's default provided by getpass.getuser()",self.remote_host,self.ssh_conn_id,)self.username=getuser()user_ssh_config_filename=os.path.expanduser("~/.ssh/config")ifos.path.isfile(user_ssh_config_filename):ssh_conf=paramiko.SSHConfig()withopen(user_ssh_config_filename)asconfig_fd:ssh_conf.parse(config_fd)host_info=ssh_conf.lookup(self.remote_host)ifhost_infoandhost_info.get("proxycommand")andnotself.host_proxy_cmd:self.host_proxy_cmd=host_info["proxycommand"]ifnot(self.passwordorself.key_file):ifhost_infoandhost_info.get("identityfile"):self.key_file=host_info["identityfile"][0]self.port=self.portorSSH_PORT@cached_property
def get_conn(self) -> paramiko.SSHClient:
    """
    Establish an SSH connection to the remote host.

    If ``self.client`` already holds a client whose transport is still active, that
    client is returned unchanged. Otherwise a new ``paramiko.SSHClient`` is created,
    its host-key handling is configured, and ``connect()`` is attempted up to 3 times
    (waiting 3-5 seconds between attempts). The connected client is cached on
    ``self.client`` and returned.

    :return: a connected ``paramiko.SSHClient``
    :raises Exception: whatever the last failed ``client.connect()`` attempt raised
        (re-raised as-is once all attempts are exhausted)
    """
    if self.client:
        transport = self.client.get_transport()
        if transport and transport.is_active():
            # Return the existing connection
            return self.client

    self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id)
    client = paramiko.SSHClient()

    if self.allow_host_key_change:
        self.log.warning(
            "Remote Identification Change is not verified. "
            "This won't protect against Man-In-The-Middle attacks"
        )
        # to avoid BadHostKeyException, skip loading host keys
        client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
    else:
        client.load_system_host_keys()

    if self.no_host_key_check:
        self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())  # nosec B507
        # Only load the user's known_hosts when host key changes are not allowed, so a
        # changed key cannot trigger BadHostKeyException.
        # NOTE(review): per the paramiko docs, keys loaded via load_host_keys() are the set
        # AutoAddPolicy adds to *and saves back*, so newly seen hosts are likely written to
        # ~/.ssh/known_hosts here — confirm this is intended.
        known_hosts = os.path.expanduser("~/.ssh/known_hosts")
        if not self.allow_host_key_change and os.path.isfile(known_hosts):
            client.load_host_keys(known_hosts)
    elif self.host_key is not None:
        # Get host key from connection extra if it not set or None then we fallback to system host keys
        client_host_keys = client.get_host_keys()
        if self.port == SSH_PORT:
            client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
        else:
            # Non-default ports are keyed as "[host]:port", as in OpenSSH known_hosts files.
            client_host_keys.add(
                f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
            )

    connect_kwargs: dict[str, Any] = {
        "hostname": self.remote_host,
        "username": self.username,
        "timeout": self.conn_timeout,
        "compress": self.compress,
        "port": self.port,
        "sock": self.host_proxy,
        "look_for_keys": self.look_for_keys,
        "banner_timeout": self.banner_timeout,
        "auth_timeout": self.auth_timeout,
    }

    if self.password:
        password = self.password.strip()
        connect_kwargs.update(password=password)

    if self.pkey:
        connect_kwargs.update(pkey=self.pkey)

    if self.key_file:
        connect_kwargs.update(key_filename=self.key_file)

    if self.disabled_algorithms:
        connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)

    def log_before_sleep(retry_state):
        return self.log.info(
            "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
        )

    for attempt in Retrying(
        reraise=True,
        wait=wait_fixed(3) + wait_random(0, 2),
        stop=stop_after_attempt(3),
        before_sleep=log_before_sleep,
    ):
        with attempt:
            try:
                client.connect(**connect_kwargs)
            except Exception:
                # A failed connect() can leave a started Transport (its reader thread and
                # socket) attached to the client, e.g. when host-key verification or
                # authentication fails. Close it before retrying or re-raising so failed
                # attempts do not leak connections; close() is a no-op without a transport.
                client.close()
                raise

    if self.keepalive_interval:
        # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
        # type "Transport | None" and item "None" has no attribute "set_keepalive".
        client.get_transport().set_keepalive(self.keepalive_interval)  # type: ignore[union-attr]

    if self.ciphers:
        # NOTE(review): paramiko documents that SecurityOptions changes only take effect if
        # made before the session starts; applied after connect() this presumably affects
        # only later key renegotiation, not the current session — confirm intent.
        # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
        # type "Transport | None" and item "None" has no method `get_security_options`".
        client.get_transport().get_security_options().ciphers = self.ciphers  # type: ignore[union-attr]

    self.client = client
    return client
def get_tunnel(
    self, remote_port: int, remote_host: str = "localhost", local_port: int | None = None
) -> SSHTunnelForwarder:
    """
    Create a tunnel between two hosts.

    Conceptually this mirrors ``ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>``. This method
    only constructs the forwarder; it never calls ``start()`` on it.

    :param remote_port: port on ``remote_host`` that the tunnel forwards to
    :param remote_host: host, as seen from the SSH server, to tunnel to (default localhost)
    :param local_port: local port to bind the tunnel to; when falsy, only the
        address "localhost" is given and no port is specified
    :return: sshtunnel.SSHTunnelForwarder object
    """
    bind_local: tuple[str, int] | tuple[str] = (
        ("localhost", local_port) if local_port else ("localhost",)
    )

    forwarder_options = {
        "ssh_port": self.port,
        "ssh_username": self.username,
        "ssh_pkey": self.key_file or self.pkey,
        "ssh_proxy": self.host_proxy,
        "local_bind_address": bind_local,
        "remote_bind_address": (remote_host, remote_port),
        "logger": self.log,
    }

    if not self.password:
        # Without a password, explicitly pass host_pkey_directories=None.
        forwarder_options["host_pkey_directories"] = None
    else:
        forwarder_options["ssh_password"] = self.password.strip()

    return SSHTunnelForwarder(self.remote_host, **forwarder_options)
def _pkey_from_private_key(self, private_key: str, passphrase: str | None = None) -> paramiko.PKey:
    """
    Create an appropriate Paramiko key for a given private key.

    Each class in ``self._pkey_loaders`` is tried in order; the first one that can
    both parse the key and sign with it is returned.

    :param private_key: string containing the private key, with its BEGIN and END
        header/footer on separate lines
    :param passphrase: passphrase used to decrypt the key, or ``None``
    :return: ``paramiko.PKey`` appropriate for given key
    :raises AirflowException: if the key has fewer than two lines, or if no loader
        can read it (the last loader error is attached as the cause)
    """
    if len(private_key.splitlines()) < 2:
        raise AirflowException("Key must have BEGIN and END header/footer on separate lines.")

    last_error: Exception | None = None
    for pkey_class in self._pkey_loaders:
        try:
            key = pkey_class.from_private_key(StringIO(private_key), password=passphrase)
            # Test it actually works. If Paramiko loads an openssh generated key, sometimes it will
            # happily load it as the wrong type, only to fail when actually used.
            key.sign_ssh_data(b"")
            return key
        except (paramiko.ssh_exception.SSHException, ValueError) as err:
            # Remember the failure so it can be chained onto the final error.
            last_error = err

    # The adjacent literals previously lacked separating spaces, producing
    # "paramiko.Ensure ..." and "followingkey formats".
    raise AirflowException(
        "Private key provided cannot be read by paramiko. "
        "Ensure key provided is valid for one of the following "
        "key formats: RSA, DSS, ECDSA, or Ed25519"
    ) from last_error
def exec_ssh_client_command(
    self,
    ssh_client: paramiko.SSHClient,
    command: str,
    get_pty: bool,
    environment: dict | None,
    timeout: float | ArgNotSet | None = NOTSET,
) -> tuple[int, bytes, bytes]:
    """
    Execute ``command`` through ``ssh_client`` and gather its results.

    Output is streamed while the command runs: each decoded stdout line is logged at
    INFO level and each decoded stderr line at WARNING level.

    :param ssh_client: connected Paramiko client on which to run the command
    :param command: command line to execute on the remote host
    :param get_pty: whether to request a pseudo-terminal for the command
    :param environment: environment variables forwarded to ``exec_command``, or ``None``
    :param timeout: seconds passed to ``exec_command`` and used as the wait limit of each
        ``select`` on the channel; ``None`` waits indefinitely. When left as ``NOTSET``
        the hook's ``cmd_timeout`` is used, falling back to ``CMD_TIMEOUT`` if that is
        also unset.
    :return: tuple of ``(exit_status, stdout_bytes, stderr_bytes)``
    :raises AirflowException: if a ``select`` wait with a timeout returns nothing readable
    """
    self.log.info("Running command: %s", command)

    # Effective timeout: explicit argument first, then the hook's setting, then the default.
    wait_limit: float | None
    if isinstance(timeout, ArgNotSet):
        wait_limit = CMD_TIMEOUT if isinstance(self.cmd_timeout, ArgNotSet) else self.cmd_timeout
    else:
        wait_limit = timeout
    del timeout  # Too easy to confuse with "timed_out" below.

    stdin, stdout, stderr = ssh_client.exec_command(
        command=command,
        get_pty=get_pty,
        timeout=wait_limit,
        environment=environment,
    )
    channel = stdout.channel

    # Nothing is fed to the remote command: close stdin and shut down the write side.
    stdin.close()
    channel.shutdown_write()

    out_bytes = b""
    err_bytes = b""

    # Pick up stdout that is already buffered, in case the channel has closed already.
    already_buffered = len(channel.in_buffer)
    if already_buffered > 0:
        out_bytes += channel.recv(already_buffered)

    timed_out = False
    # Keep reading until the channel is closed and both stdout and stderr are drained.
    while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
        readable, _, _ = select([channel], [], [], wait_limit)
        if wait_limit is not None:
            timed_out = not readable
        for ready in readable:
            if ready.recv_ready():
                chunk = channel.recv(len(ready.in_buffer))
                out_bytes += chunk
                for text in chunk.decode("utf-8", "replace").strip("\n").splitlines():
                    self.log.info(text)
            if ready.recv_stderr_ready():
                chunk = stderr.channel.recv_stderr(len(ready.in_stderr_buffer))
                err_bytes += chunk
                for text in chunk.decode("utf-8", "replace").strip("\n").splitlines():
                    self.log.warning(text)

        command_finished = (
            channel.exit_status_ready()
            and not stderr.channel.recv_stderr_ready()
            and not channel.recv_ready()
        )
        if command_finished or timed_out:
            channel.shutdown_read()
            try:
                channel.close()
            except Exception:
                # shutdown_read() can race with the socket being closed, making close()
                # fail; such errors are logged with a warning and otherwise ignored.
                self.log.warning("Ignoring exception on close", exc_info=True)
            break

    stdout.close()
    stderr.close()

    if timed_out:
        raise AirflowException("SSH command timed out")

    return channel.recv_exit_status(), out_bytes, err_bytes
def test_connection(self) -> tuple[bool, str]:
    """
    Check the SSH connection by opening it and running a trivial remote command.

    :return: ``(True, "Connection successfully tested")`` when ``pwd`` could be
        issued, otherwise ``(False, <text of the raised exception>)``
    """
    try:
        with self.get_conn() as ssh_client:
            ssh_client.exec_command("pwd")
    except Exception as exc:
        return False, str(exc)
    return True, "Connection successfully tested"