Source code for airflow.providers.alibaba.cloud.hooks.oss

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from functools import wraps
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from urllib.parse import urlsplit

import oss2
from oss2.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
    from airflow.models.connection import Connection

# Generic callable type: the decorators in this module return their argument's own
# type so that static type checkers keep the decorated method's signature.
T = TypeVar("T", bound=Callable)


def provide_bucket_name(func: T) -> T:
    """
    Fill in ``bucket_name`` from the hook's connection when the caller omitted it.

    If the wrapped method is called without ``bucket_name`` (or with ``None``) and the
    hook has a truthy ``oss_conn_id``, that connection is fetched with
    ``self.get_connection``; when its ``schema`` is non-empty it is passed as
    ``bucket_name``. An explicitly supplied bucket name is never replaced.

    :param func: a hook method whose first positional argument is ``self`` and which
        accepts a ``bucket_name`` parameter.
    :return: the wrapped method.
    """
    function_signature = signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        # Bind against the real signature so ``bucket_name`` is found whether it was
        # passed positionally or by keyword (defaults are not applied by ``bind``).
        bound_args = function_signature.bind(*args, **kwargs)
        self = args[0]
        if bound_args.arguments.get("bucket_name") is None and self.oss_conn_id:
            connection = self.get_connection(self.oss_conn_id)
            if connection.schema:
                bound_args.arguments["bucket_name"] = connection.schema

        return func(*bound_args.args, **bound_args.kwargs)

    return cast(T, wrapper)
def unify_bucket_name_and_key(func: T) -> T:
    """
    Unify bucket name and key if a key is provided but not a bucket name.

    When the wrapped method is called with ``bucket_name`` omitted or ``None``, the
    ``key`` argument is treated as a full OSS URL and split with
    ``OSSHook.parse_oss_url`` into ``bucket_name`` (the URL's network location) and
    ``key`` (its path without the leading ``/``), e.g. ``oss://my-bucket/a/b.txt``.

    :param func: a hook method accepting ``key`` and ``bucket_name`` parameters.
    :return: the wrapped method.
    :raises ValueError: at call time, if no ``key`` argument was bound.
    :raises AirflowException: at call time, if ``bucket_name`` is missing and ``key``
        has no network location to use as the bucket.
    """
    function_signature = signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        bound_args = function_signature.bind(*args, **kwargs)
        if "key" not in bound_args.arguments:
            raise ValueError("Missing key parameter!")

        if bound_args.arguments.get("bucket_name") is None:
            bound_args.arguments["bucket_name"], bound_args.arguments["key"] = OSSHook.parse_oss_url(
                bound_args.arguments["key"]
            )

        return func(*bound_args.args, **bound_args.kwargs)

    return cast(T, wrapper)
class OSSHook(BaseHook):
    """
    Interact with Alibaba Cloud OSS, using the oss2 library.

    Credentials and the default region are read from the extras of the Airflow
    connection named by ``oss_conn_id``. ``auth_type`` must be ``"AK"``, together with
    ``access_key_id``, ``access_key_secret`` and ``region``. When the connection's
    ``schema`` is set, it is used as the default bucket name.

    :param region: OSS region used to build the ``https://oss-<region>.aliyuncs.com``
        endpoint. Falls back to the ``region`` connection extra when not given.
    :param oss_conn_id: ID of the Airflow connection to use.
    """

    # Airflow passes the connection id under this keyword when it builds the hook
    # from a connection, so it must match the ``__init__`` parameter name below.
    conn_name_attr = "oss_conn_id"
    default_conn_name = "oss_default"
    conn_type = "oss"
    hook_name = "OSS"

    def __init__(self, region: str | None = None, oss_conn_id="oss_default", *args, **kwargs) -> None:
        # Initialise the BaseHook part of the instance before this instance performs
        # its own connection lookup and region resolution.
        super().__init__(*args, **kwargs)
        self.oss_conn_id = oss_conn_id
        self.oss_conn = self.get_connection(oss_conn_id)
        self.region = region or self.get_default_region()

    def get_conn(self) -> Connection:
        """Return connection for the hook."""
        return self.oss_conn

    @staticmethod
    def parse_oss_url(ossurl: str) -> tuple:
        """
        Parse the OSS Url into a bucket name and key.

        The URL's network location becomes the bucket and its path (without the
        leading ``/``) becomes the key. For example, ``oss://my-bucket/dir/file.txt``
        gives ``("my-bucket", "dir/file.txt")``. The scheme is not checked.

        :param ossurl: The OSS Url to parse.
        :return: the parsed bucket name and key
        :raises AirflowException: if the URL has no network location (bucket).
        """
        parsed_url = urlsplit(ossurl)
        if not parsed_url.netloc:
            raise AirflowException(f'Please provide a bucket_name instead of "{ossurl}"')
        bucket_name = parsed_url.netloc
        key = parsed_url.path.lstrip("/")
        return bucket_name, key

    @provide_bucket_name
    @unify_bucket_name_and_key
    def object_exists(self, key: str, bucket_name: str | None = None) -> bool:
        """
        Check if object exists.

        oss2 ``ClientError`` exceptions are logged and reported as ``False``. Any other
        exception propagates to the caller.

        :param key: the path of the object
        :param bucket_name: the name of the bucket
        :return: True if it exists and False if not.
        """
        try:
            return self.get_bucket(bucket_name).object_exists(key)
        except ClientError as e:
            # Log the exception itself: its string form carries the error details,
            # while ``e.message`` may be empty for client-side errors.
            self.log.error(e)
            return False

    @provide_bucket_name
    def get_bucket(self, bucket_name: str | None = None) -> oss2.api.Bucket:
        """
        Return a oss2.Bucket object.

        The bucket is addressed through the regional endpoint
        ``https://oss-<region>.aliyuncs.com``, using this hook's ``region`` and the
        credentials from ``get_credential``.

        :param bucket_name: the name of the bucket
        :return: the bucket object to the bucket name.
        """
        auth = self.get_credential()
        return oss2.Bucket(auth, f"https://oss-{self.region}.aliyuncs.com", bucket_name)

    @provide_bucket_name
    @unify_bucket_name_and_key
    def load_string(self, key: str, content: str, bucket_name: str | None = None) -> None:
        """
        Load a string to OSS.

        :param key: the path of the object
        :param content: str to set as content for the key.
        :param bucket_name: the name of the bucket
        :raises AirflowException: if the upload fails.
        """
        try:
            self.get_bucket(bucket_name).put_object(key, content)
        except Exception as e:
            raise AirflowException(f"Errors: {e}") from e

    @provide_bucket_name
    @unify_bucket_name_and_key
    def upload_local_file(
        self,
        key: str,
        file: str,
        bucket_name: str | None = None,
    ) -> None:
        """
        Upload a local file to OSS.

        :param key: the OSS path of the object
        :param file: local file to upload.
        :param bucket_name: the name of the bucket
        :raises AirflowException: if the upload fails.
        """
        try:
            self.get_bucket(bucket_name).put_object_from_file(key, file)
        except Exception as e:
            raise AirflowException(f"Errors when upload file: {e}") from e

    @provide_bucket_name
    @unify_bucket_name_and_key
    def download_file(
        self,
        key: str,
        local_file: str,
        bucket_name: str | None = None,
    ) -> str | None:
        """
        Download file from OSS.

        This is best-effort: errors are logged and ``None`` is returned instead of
        raising.

        :param key: key of the file-like object to download.
        :param local_file: local path + file name to save.
        :param bucket_name: the name of the bucket
        :return: the file name, or ``None`` if the download failed.
        """
        try:
            self.get_bucket(bucket_name).get_object_to_file(key, local_file)
        except Exception as e:
            self.log.error(e)
            return None
        return local_file

    @provide_bucket_name
    @unify_bucket_name_and_key
    def delete_object(
        self,
        key: str,
        bucket_name: str | None = None,
    ) -> None:
        """
        Delete object from OSS.

        :param key: key of the object to delete.
        :param bucket_name: the name of the bucket
        :raises AirflowException: if the deletion fails.
        """
        try:
            self.get_bucket(bucket_name).delete_object(key)
        except Exception as e:
            self.log.error(e)
            raise AirflowException(f"Errors when deleting: {key}") from e

    @provide_bucket_name
    @unify_bucket_name_and_key
    def delete_objects(
        self,
        key: list,
        bucket_name: str | None = None,
    ) -> None:
        """
        Delete objects from OSS.

        Pass ``bucket_name`` explicitly, or set the connection schema. Otherwise the
        ``key`` argument would be parsed as an ``oss://`` URL, which a list cannot be.

        :param key: keys list of the objects to delete.
        :param bucket_name: the name of the bucket
        :raises AirflowException: if the batch deletion fails.
        """
        try:
            self.get_bucket(bucket_name).batch_delete_objects(key)
        except Exception as e:
            self.log.error(e)
            raise AirflowException(f"Errors when deleting: {key}") from e

    @provide_bucket_name
    def delete_bucket(
        self,
        bucket_name: str | None = None,
    ) -> None:
        """
        Delete bucket from OSS.

        :param bucket_name: the name of the bucket
        :raises AirflowException: if the deletion fails.
        """
        try:
            self.get_bucket(bucket_name).delete_bucket()
        except Exception as e:
            self.log.error(e)
            raise AirflowException(f"Errors when deleting: {bucket_name}") from e

    @provide_bucket_name
    def create_bucket(
        self,
        bucket_name: str | None = None,
    ) -> None:
        """
        Create bucket.

        :param bucket_name: the name of the bucket
        :raises AirflowException: if the creation fails.
        """
        try:
            self.get_bucket(bucket_name).create_bucket()
        except Exception as e:
            self.log.error(e)
            raise AirflowException(f"Errors when create bucket: {bucket_name}") from e

    @provide_bucket_name
    @unify_bucket_name_and_key
    def append_string(self, bucket_name: str | None, content: str, key: str, pos: int) -> None:
        """
        Append string to a remote existing file.

        :param bucket_name: the name of the bucket
        :param content: content to be appended
        :param key: oss bucket key
        :param pos: position of the existing file where the content will be appended
        :raises AirflowException: if the append fails.
        """
        self.log.info("Write oss bucket. key: %s, pos: %s", key, pos)
        try:
            self.get_bucket(bucket_name).append_object(key, pos, content)
        except Exception as e:
            self.log.error(e)
            raise AirflowException(f"Errors when append string for object: {key}") from e

    @provide_bucket_name
    @unify_bucket_name_and_key
    def read_key(self, bucket_name: str | None, key: str) -> str:
        """
        Read oss remote object content with the specified key.

        :param bucket_name: the name of the bucket
        :param key: oss bucket key
        :return: the object's content decoded as UTF-8.
        :raises AirflowException: if the object cannot be read or decoded.
        """
        self.log.info("Read oss key: %s", key)
        try:
            return self.get_bucket(bucket_name).get_object(key).read().decode("utf-8")
        except Exception as e:
            self.log.error(e)
            raise AirflowException(f"Errors when read bucket object: {key}") from e

    @provide_bucket_name
    @unify_bucket_name_and_key
    def head_key(self, bucket_name: str | None, key: str) -> oss2.models.HeadObjectResult:
        """
        Get meta info of the specified remote object.

        :param bucket_name: the name of the bucket
        :param key: oss bucket key
        :return: the result of ``head_object`` for the key.
        :raises AirflowException: if the request fails.
        """
        self.log.info("Head Object oss key: %s", key)
        try:
            return self.get_bucket(bucket_name).head_object(key)
        except Exception as e:
            self.log.error(e)
            raise AirflowException(f"Errors when head bucket object: {key}") from e

    @provide_bucket_name
    @unify_bucket_name_and_key
    def key_exist(self, bucket_name: str | None, key: str) -> bool:
        """
        Find out whether the specified key exists in the oss remote storage.

        Unlike ``object_exists``, any error while checking is raised rather than
        reported as ``False``.

        :param bucket_name: the name of the bucket
        :param key: oss bucket key
        :return: True if the key exists, False otherwise.
        :raises AirflowException: if the existence check fails.
        """
        self.log.info("Looking up oss bucket %s for bucket key %s ...", bucket_name, key)
        try:
            return self.get_bucket(bucket_name).object_exists(key)
        except Exception as e:
            self.log.error(e)
            raise AirflowException(f"Errors when check bucket object existence: {key}") from e

    def _get_validated_extra(self) -> dict:
        """
        Return the connection extras after checking the authentication type.

        :raises ValueError: if ``auth_type`` is missing or is not ``"AK"``.
        """
        extra_config = self.oss_conn.extra_dejson
        auth_type = extra_config.get("auth_type", None)
        if not auth_type:
            raise ValueError("No auth_type specified in extra_config. ")
        if auth_type != "AK":
            raise ValueError(f"Unsupported auth_type: {auth_type}")
        return extra_config

    def get_credential(self) -> oss2.auth.Auth:
        """
        Build oss2 access-key credentials from the connection extras.

        :return: an ``oss2.Auth`` built from ``access_key_id`` and ``access_key_secret``.
        :raises ValueError: if ``auth_type`` is missing or unsupported, or either key is
            missing.
        """
        extra_config = self._get_validated_extra()
        oss_access_key_id = extra_config.get("access_key_id", None)
        oss_access_key_secret = extra_config.get("access_key_secret", None)
        if not oss_access_key_id:
            raise ValueError(f"No access_key_id is specified for connection: {self.oss_conn_id}")
        if not oss_access_key_secret:
            raise ValueError(f"No access_key_secret is specified for connection: {self.oss_conn_id}")
        return oss2.Auth(oss_access_key_id, oss_access_key_secret)

    def get_default_region(self) -> str:
        """
        Return the ``region`` configured in the connection extras.

        :raises ValueError: if ``auth_type`` is missing or unsupported, or no region
            is set.
        """
        extra_config = self._get_validated_extra()
        default_region = extra_config.get("region", None)
        if not default_region:
            raise ValueError(f"No region is specified for connection: {self.oss_conn_id}")
        return default_region

Was this entry helpful?