Source code for airflow.providers.sftp.hooks.sftp

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains SFTP hook."""

from __future__ import annotations

import datetime
import os
import stat
import warnings
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Callable, Sequence

import asyncssh
from asgiref.sync import sync_to_async

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.providers.ssh.hooks.ssh import SSHHook

if TYPE_CHECKING:
    import paramiko

    from airflow.models.connection import Connection


[docs]class SFTPHook(SSHHook): """Interact with SFTP. This hook inherits the SSH hook. Please refer to SSH hook for the input arguments. :Pitfalls:: - In contrast with FTPHook describe_directory only returns size, type and modify. It doesn't return unix.owner, unix.mode, perm, unix.group and unique. - retrieve_file and store_file only take a local full path and not a buffer. - If no mode is passed to create_directory it will be created with 777 permissions. Errors that may occur throughout but should be handled downstream. For consistency reasons with SSHHook, the preferred parameter is "ssh_conn_id". :param ssh_conn_id: The :ref:`sftp connection id<howto/connection:sftp>` :param ssh_hook: Optional SSH hook (included to support passing of an SSH hook to the SFTP operator) """
[docs] conn_name_attr = "ssh_conn_id"
[docs] default_conn_name = "sftp_default"
[docs] conn_type = "sftp"
[docs] hook_name = "SFTP"
@classmethod
[docs] def get_ui_field_behaviour(cls) -> dict[str, Any]: return { "hidden_fields": ["schema"], "relabeling": { "login": "Username", }, }
def __init__( self, ssh_conn_id: str | None = "sftp_default", ssh_hook: SSHHook | None = None, *args, **kwargs, ) -> None: self.conn: paramiko.SFTPClient | None = None # TODO: remove support for ssh_hook when it is removed from SFTPOperator self.ssh_hook = ssh_hook if self.ssh_hook is not None: warnings.warn( "Parameter `ssh_hook` is deprecated and will be removed in a future version.", AirflowProviderDeprecationWarning, stacklevel=2, ) if not isinstance(self.ssh_hook, SSHHook): raise AirflowException( f"ssh_hook must be an instance of SSHHook, but got {type(self.ssh_hook)}" ) self.log.info("ssh_hook is provided. It will be used to generate SFTP connection.") self.ssh_conn_id = self.ssh_hook.ssh_conn_id return ftp_conn_id = kwargs.pop("ftp_conn_id", None) if ftp_conn_id: warnings.warn( "Parameter `ftp_conn_id` is deprecated. Please use `ssh_conn_id` instead.", AirflowProviderDeprecationWarning, stacklevel=2, ) ssh_conn_id = ftp_conn_id kwargs["ssh_conn_id"] = ssh_conn_id self.ssh_conn_id = ssh_conn_id super().__init__(*args, **kwargs)
[docs] def get_conn(self) -> paramiko.SFTPClient: # type: ignore[override] """Open an SFTP connection to the remote host.""" if self.conn is None: # TODO: remove support for ssh_hook when it is removed from SFTPOperator if self.ssh_hook is not None: self.conn = self.ssh_hook.get_conn().open_sftp() else: self.conn = super().get_conn().open_sftp() return self.conn
[docs] def close_conn(self) -> None: """Close the SFTP connection.""" if self.conn is not None: self.conn.close() self.conn = None
[docs] def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None]]: """Get file information in a directory on the remote system. The return format is ``{filename: {attributes}}``. The remote system support the MLSD command. :param path: full path to the remote directory """ conn = self.get_conn() flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) files = {} for f in flist: modify = datetime.datetime.fromtimestamp(f.st_mtime).strftime("%Y%m%d%H%M%S") # type: ignore files[f.filename] = { "size": f.st_size, "type": "dir" if stat.S_ISDIR(f.st_mode) else "file", # type: ignore "modify": modify, } return files
[docs] def list_directory(self, path: str) -> list[str]: """List files in a directory on the remote system. :param path: full path to the remote directory to list """ conn = self.get_conn() files = sorted(conn.listdir(path)) return files
[docs] def mkdir(self, path: str, mode: int = 0o777) -> None: """Create a directory on the remote system. The default mode is ``0o777``, but on some systems, the current umask value may be first masked out. :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ conn = self.get_conn() conn.mkdir(path, mode=mode)
[docs] def isdir(self, path: str) -> bool: """Check if the path provided is a directory. :param path: full path to the remote directory to check """ conn = self.get_conn() try: result = stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore except OSError: result = False return result
[docs] def isfile(self, path: str) -> bool: """Check if the path provided is a file. :param path: full path to the remote file to check """ conn = self.get_conn() try: result = stat.S_ISREG(conn.stat(path).st_mode) # type: ignore except OSError: result = False return result
[docs] def create_directory(self, path: str, mode: int = 0o777) -> None: """Create a directory on the remote system. The default mode is ``0o777``, but on some systems, the current umask value may be first masked out. Different from :func:`.mkdir`, this function attempts to create parent directories if needed, and returns silently if the target directory already exists. :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ conn = self.get_conn() if self.isdir(path): self.log.info("%s already exists", path) return elif self.isfile(path): raise AirflowException(f"{path} already exists and is a file") else: dirname, basename = os.path.split(path) if dirname and not self.isdir(dirname): self.create_directory(dirname, mode) if basename: self.log.info("Creating %s", path) conn.mkdir(path, mode=mode)
[docs] def delete_directory(self, path: str) -> None: """Delete a directory on the remote system. :param path: full path to the remote directory to delete """ conn = self.get_conn() conn.rmdir(path)
[docs] def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: """Transfer the remote file to a local location. If local_full_path is a string path, the file will be put at that location. :param remote_full_path: full path to the remote file :param local_full_path: full path to the local file :param prefetch: controls whether prefetch is performed (default: True) """ conn = self.get_conn() conn.get(remote_full_path, local_full_path, prefetch=prefetch)
[docs] def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None: """Transfer a local file to the remote location. If local_full_path_or_buffer is a string path, the file will be read from that location. :param remote_full_path: full path to the remote file :param local_full_path: full path to the local file """ conn = self.get_conn() conn.put(local_full_path, remote_full_path, confirm=confirm)
[docs] def delete_file(self, path: str) -> None: """Remove a file on the server. :param path: full path to the remote file """ conn = self.get_conn() conn.remove(path)
[docs] def get_mod_time(self, path: str) -> str: """Get an entry's modification time. :param path: full path to the remote file """ conn = self.get_conn() ftp_mdtm = conn.stat(path).st_mtime return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore
[docs] def path_exists(self, path: str) -> bool: """Whether a remote entity exists. :param path: full path to the remote file or directory """ conn = self.get_conn() try: conn.stat(path) except OSError: return False return True
@staticmethod def _is_path_match(path: str, prefix: str | None = None, delimiter: str | None = None) -> bool: """Whether given path starts with ``prefix`` (if set) and ends with ``delimiter`` (if set). :param path: path to be checked :param prefix: if set path will be checked is starting with prefix :param delimiter: if set path will be checked is ending with suffix :return: bool """ if prefix is not None and not path.startswith(prefix): return False if delimiter is not None and not path.endswith(delimiter): return False return True
[docs] def walktree( self, path: str, fcallback: Callable[[str], Any | None], dcallback: Callable[[str], Any | None], ucallback: Callable[[str], Any | None], recurse: bool = True, ) -> None: """Recursively descend, depth first, the directory tree at ``path``. This calls discrete callback functions for each regular file, directory, and unknown file type. :param str path: root of remote directory to descend, use '.' to start at :attr:`.pwd` :param callable fcallback: callback function to invoke for a regular file. (form: ``func(str)``) :param callable dcallback: callback function to invoke for a directory. (form: ``func(str)``) :param callable ucallback: callback function to invoke for an unknown file type. (form: ``func(str)``) :param bool recurse: *Default: True* - should it recurse """ conn = self.get_conn() for entry in self.list_directory(path): pathname = os.path.join(path, entry) mode = conn.stat(pathname).st_mode if stat.S_ISDIR(mode): # type: ignore # It's a directory, call the dcallback function dcallback(pathname) if recurse: # now, recurse into it self.walktree(pathname, fcallback, dcallback, ucallback) elif stat.S_ISREG(mode): # type: ignore # It's a file, call the fcallback function fcallback(pathname) else: # Unknown file type ucallback(pathname)
[docs] def get_tree_map( self, path: str, prefix: str | None = None, delimiter: str | None = None ) -> tuple[list[str], list[str], list[str]]: """Get tuple with recursive lists of files, directories and unknown paths. It is possible to filter results by giving prefix and/or delimiter parameters. :param path: path from which tree will be built :param prefix: if set paths will be added if start with prefix :param delimiter: if set paths will be added if end with delimiter :return: tuple with list of files, dirs and unknown items """ files: list[str] = [] dirs: list[str] = [] unknowns: list[str] = [] def append_matching_path_callback(list_: list[str]) -> Callable: return lambda item: list_.append(item) if self._is_path_match(item, prefix, delimiter) else None self.walktree( path=path, fcallback=append_matching_path_callback(files), dcallback=append_matching_path_callback(dirs), ucallback=append_matching_path_callback(unknowns), recurse=True, ) return files, dirs, unknowns
[docs] def test_connection(self) -> tuple[bool, str]: """Test the SFTP connection by calling path with directory.""" try: conn = self.get_conn() conn.normalize(".") return True, "Connection successfully tested" except Exception as e: return False, str(e)
[docs] def get_file_by_pattern(self, path, fnmatch_pattern) -> str: """Get the first matching file based on the given fnmatch type pattern. :param path: path to be checked :param fnmatch_pattern: The pattern that will be matched with `fnmatch` :return: string containing the first found file, or an empty string if none matched """ for file in self.list_directory(path): if fnmatch(file, fnmatch_pattern): return file return ""
[docs] def get_files_by_pattern(self, path, fnmatch_pattern) -> list[str]: """Get all matching files based on the given fnmatch type pattern. :param path: path to be checked :param fnmatch_pattern: The pattern that will be matched with `fnmatch` :return: list of string containing the found files, or an empty list if none matched """ matched_files = [] for file in self.list_directory(path): if fnmatch(file, fnmatch_pattern): matched_files.append(file) return matched_files
[docs]class SFTPHookAsync(BaseHook): """ Interact with an SFTP server via asyncssh package. :param sftp_conn_id: SFTP connection ID to be used for connecting to SFTP server :param host: hostname of the SFTP server :param port: port of the SFTP server :param username: username used when authenticating to the SFTP server :param password: password used when authenticating to the SFTP server. Can be left blank if using a key file :param known_hosts: path to the known_hosts file on the local file system. Defaults to ``~/.ssh/known_hosts``. :param key_file: path to the client key file used for authentication to SFTP server :param passphrase: passphrase used with the key_file for authentication to SFTP server """
[docs] conn_name_attr = "ssh_conn_id"
[docs] default_conn_name = "sftp_default"
[docs] conn_type = "sftp"
[docs] hook_name = "SFTP"
[docs] default_known_hosts = "~/.ssh/known_hosts"
def __init__( # nosec: B107 self, sftp_conn_id: str = default_conn_name, host: str = "", port: int = 22, username: str = "", password: str = "", known_hosts: str = default_known_hosts, key_file: str = "", passphrase: str = "", private_key: str = "", ) -> None: self.sftp_conn_id = sftp_conn_id self.host = host self.port = port self.username = username self.password = password self.known_hosts: bytes | str = os.path.expanduser(known_hosts) self.key_file = key_file self.passphrase = passphrase self.private_key = private_key def _parse_extras(self, conn: Connection) -> None: """Parse extra fields from the connection into instance fields.""" extra_options = conn.extra_dejson if "key_file" in extra_options and self.key_file == "": self.key_file = extra_options["key_file"] if "known_hosts" in extra_options and self.known_hosts != self.default_known_hosts: self.known_hosts = extra_options["known_hosts"] if ("passphrase" or "private_key_passphrase") in extra_options: self.passphrase = extra_options["passphrase"] if "private_key" in extra_options: self.private_key = extra_options["private_key"] host_key = extra_options.get("host_key") no_host_key_check = extra_options.get("no_host_key_check") if no_host_key_check is not None: no_host_key_check = str(no_host_key_check).lower() == "true" if host_key is not None and no_host_key_check: raise ValueError("Host key check was skipped, but `host_key` value was given") if no_host_key_check: self.log.warning( "No Host Key Verification. This won't protect against Man-In-The-Middle attacks" ) self.known_hosts = "none" if host_key is not None: self.known_hosts = f"{conn.host} {host_key}".encode() async def _get_conn(self) -> asyncssh.SSHClientConnection: """ Asynchronously connect to the SFTP server as an SSH client. The following parameters are provided either in the extra json object in the SFTP connection definition - key_file - known_hosts - passphrase """ conn = await sync_to_async(self.get_connection)(self.sftp_conn_id) if conn.extra is not None: self._parse_extras(conn) conn_config = { "host": conn.host, "port": conn.port, "username": conn.login, "password": conn.password, } if self.key_file: conn_config.update(client_keys=self.key_file) if self.known_hosts: if self.known_hosts.lower() == "none": conn_config.update(known_hosts=None) else: conn_config.update(known_hosts=self.known_hosts) if self.private_key: _private_key = asyncssh.import_private_key(self.private_key, self.passphrase) conn_config.update(client_keys=[_private_key]) if self.passphrase: conn_config.update(passphrase=self.passphrase) ssh_client_conn = await asyncssh.connect(**conn_config) return ssh_client_conn
[docs] async def list_directory(self, path: str = "") -> list[str] | None: # type: ignore[return] """Return a list of files on the SFTP server at the provided path.""" async with await self._get_conn() as ssh_conn: sftp_client = await ssh_conn.start_sftp_client() try: files = await sftp_client.listdir(path) return sorted(files) except asyncssh.SFTPNoSuchFile: return None
[docs] async def read_directory(self, path: str = "") -> Sequence[asyncssh.sftp.SFTPName] | None: # type: ignore[return] """Return a list of files along with their attributes on the SFTP server at the provided path.""" async with await self._get_conn() as ssh_conn: sftp_client = await ssh_conn.start_sftp_client() try: return await sftp_client.readdir(path) except asyncssh.SFTPNoSuchFile: return None
[docs] async def get_files_and_attrs_by_pattern( self, path: str = "", fnmatch_pattern: str = "" ) -> Sequence[asyncssh.sftp.SFTPName]: """ Get the files along with their attributes matching the pattern (e.g. ``*.pdf``) at the provided path. if one exists. Otherwise, raises an AirflowException to be handled upstream for deferring """ files_list = await self.read_directory(path) if files_list is None: raise FileNotFoundError(f"No files at path {path!r} found...") matched_files = [file for file in files_list if fnmatch(str(file.filename), fnmatch_pattern)] return matched_files
[docs] async def get_mod_time(self, path: str) -> str: """ Make SFTP async connection. Looks for last modified time in the specific file path and returns last modification time for the file path. :param path: full path to the remote file """ ssh_conn = await self._get_conn() sftp_client = await ssh_conn.start_sftp_client() try: ftp_mdtm = await sftp_client.stat(path) modified_time = ftp_mdtm.mtime mod_time = datetime.datetime.fromtimestamp(modified_time).strftime("%Y%m%d%H%M%S") # type: ignore[arg-type] self.log.info("Found File %s last modified: %s", str(path), str(mod_time)) return mod_time except asyncssh.SFTPNoSuchFile: raise AirflowException("No files matching")

Was this entry helpful?