Source code for airflow.hooks.S3_hook

# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.exceptions import AirflowException
from airflow.contrib.hooks.aws_hook import AwsHook

from six import BytesIO
from urllib.parse import urlparse
import re
import fnmatch

class S3Hook(AwsHook):
    """
    Interact with AWS S3, using the boto3 library.

    Methods that take both ``key`` and ``bucket_name`` accept either a
    bucket-relative key together with ``bucket_name``, or a full url such
    as ``s3://bucket/key`` with ``bucket_name`` omitted; in the latter case
    the bucket and key are extracted with :meth:`parse_s3_url`.
    """

    def get_conn(self):
        """
        Returns the S3 client produced by the inherited
        ``get_client_type('s3')``.
        """
        return self.get_client_type('s3')

    @staticmethod
    def parse_s3_url(s3url):
        """
        Splits an S3 url into its bucket name and key.

        :param s3url: url of the form ``s3://bucket/key``
        :type s3url: str
        :return: tuple ``(bucket_name, key)``; leading and trailing slashes
            are stripped from the key
        :rtype: tuple
        :raises AirflowException: if the url has no network location, so no
            bucket name can be extracted from it
        """
        parsed_url = urlparse(s3url)
        if not parsed_url.netloc:
            raise AirflowException(
                'Please provide a bucket_name instead of "%s"' % s3url)
        return parsed_url.netloc, parsed_url.path.strip('/')

    def check_for_bucket(self, bucket_name):
        """
        Check if bucket_name exists.

        Any error raised by the ``head_bucket`` call is reported as
        ``False``.

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :return: True if ``head_bucket`` succeeded, False otherwise
        :rtype: bool
        """
        try:
            self.get_conn().head_bucket(Bucket=bucket_name)
        except Exception:
            # Deliberately broad (best-effort existence probe), but unlike a
            # bare ``except`` it lets KeyboardInterrupt/SystemExit propagate.
            return False
        return True

    def get_bucket(self, bucket_name):
        """
        Returns a boto3.S3.Bucket object

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        """
        s3 = self.get_resource_type('s3')
        return s3.Bucket(bucket_name)

    def create_bucket(self, bucket_name, region_name=None):
        """
        Creates an Amazon S3 bucket.

        :param bucket_name: The name of the bucket
        :type bucket_name: str
        :param region_name: The name of the aws region in which to create
            the bucket. Defaults to the region of the S3 client.
        :type region_name: str
        """
        # Build the client once and reuse it for the region lookup and the
        # create call.
        s3_conn = self.get_conn()
        if not region_name:
            region_name = s3_conn.meta.region_name
        if region_name == 'us-east-1':
            # us-east-1 is created without a LocationConstraint.
            # NOTE(review): presumably because S3 rejects an explicit
            # 'us-east-1' constraint -- confirm against the S3 API docs.
            s3_conn.create_bucket(Bucket=bucket_name)
        else:
            s3_conn.create_bucket(
                Bucket=bucket_name,
                CreateBucketConfiguration={
                    'LocationConstraint': region_name
                })

    def check_for_prefix(self, bucket_name, prefix, delimiter):
        """
        Checks that a prefix exists in a bucket

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param prefix: a key prefix; the delimiter is appended when the
            prefix does not already end with it
        :type prefix: str
        :param delimiter: the delimiter marks key hierarchy.
        :type delimiter: str
        :return: True if the prefix is listed among the prefixes of its
            parent level, False otherwise
        :rtype: bool
        """
        # endswith() is safe on an empty prefix, unlike indexing prefix[-1].
        if not prefix.endswith(delimiter):
            prefix = prefix + delimiter
        # Strip the last path component to get the parent level. The
        # delimiter is escaped so characters such as ']' or '^' cannot
        # break or alter the character class.
        pattern = r'(\w+[{d}])$'.format(d=re.escape(delimiter))
        previous_level = re.split(pattern, prefix, 1)[0]
        plist = self.list_prefixes(bucket_name, previous_level, delimiter)
        return False if plist is None else prefix in plist

    def _list_entries(self, bucket_name, prefix, delimiter, page_size,
                      max_items, container, field):
        """
        Collects ``field`` from every entry found under ``container`` in
        the pages of a ``list_objects_v2`` listing.

        :return: list of collected values, or None when no page contained
            ``container``
        """
        config = {
            'PageSize': page_size,
            'MaxItems': max_items,
        }
        paginator = self.get_conn().get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name,
                                   Prefix=prefix,
                                   Delimiter=delimiter,
                                   PaginationConfig=config)
        found = False
        entries = []
        for page in pages:
            if container in page:
                found = True
                entries.extend(item[field] for item in page[container])
        return entries if found else None

    def list_prefixes(self, bucket_name, prefix='', delimiter='',
                      page_size=None, max_items=None):
        """
        Lists prefixes in a bucket under prefix

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param prefix: a key prefix
        :type prefix: str
        :param delimiter: the delimiter marks key hierarchy.
        :type delimiter: str
        :param page_size: pagination size
        :type page_size: int
        :param max_items: maximum items to return
        :type max_items: int
        :return: list of prefixes, or None if the listing contained no
            ``CommonPrefixes``
        :rtype: list or None
        """
        return self._list_entries(bucket_name, prefix, delimiter, page_size,
                                  max_items, 'CommonPrefixes', 'Prefix')

    def list_keys(self, bucket_name, prefix='', delimiter='',
                  page_size=None, max_items=None):
        """
        Lists keys in a bucket under prefix and not containing delimiter

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param prefix: a key prefix
        :type prefix: str
        :param delimiter: the delimiter marks key hierarchy.
        :type delimiter: str
        :param page_size: pagination size
        :type page_size: int
        :param max_items: maximum items to return
        :type max_items: int
        :return: list of keys, or None if the listing contained no
            ``Contents``
        :rtype: list or None
        """
        return self._list_entries(bucket_name, prefix, delimiter, page_size,
                                  max_items, 'Contents', 'Key')

    def check_for_key(self, key, bucket_name=None):
        """
        Checks if a key exists in a bucket

        Any error raised by the ``head_object`` call is reported as
        ``False``.

        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which the file is stored
        :type bucket_name: str
        :return: True if ``head_object`` succeeded, False otherwise
        :rtype: bool
        """
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)

        try:
            self.get_conn().head_object(Bucket=bucket_name, Key=key)
        except Exception:
            # Deliberately broad (best-effort existence probe), but unlike a
            # bare ``except`` it lets KeyboardInterrupt/SystemExit propagate.
            return False
        return True

    def get_key(self, key, bucket_name=None):
        """
        Returns a boto3.s3.Object

        :param key: the path to the key
        :type key: str
        :param bucket_name: the name of the bucket
        :type bucket_name: str
        """
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)

        obj = self.get_resource_type('s3').Object(bucket_name, key)
        # Load the object eagerly before handing it back.
        # NOTE(review): presumably this makes a missing key fail here
        # rather than on first use -- confirm against boto3 docs.
        obj.load()
        return obj

    def read_key(self, key, bucket_name=None):
        """
        Reads a key from S3

        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which the file is stored
        :type bucket_name: str
        :return: the content of the object, decoded as UTF-8
        :rtype: str
        """
        obj = self.get_key(key, bucket_name)
        return obj.get()['Body'].read().decode('utf-8')

    def select_key(self, key, bucket_name=None,
                   expression='SELECT * FROM S3Object',
                   expression_type='SQL',
                   input_serialization=None,
                   output_serialization=None):
        """
        Reads a key with S3 Select.

        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which the file is stored
        :type bucket_name: str
        :param expression: S3 Select expression
        :type expression: str
        :param expression_type: S3 Select expression type
        :type expression_type: str
        :param input_serialization: S3 Select input data serialization
            format; defaults to ``{'CSV': {}}``
        :type input_serialization: dict
        :param output_serialization: S3 Select output data serialization
            format; defaults to ``{'CSV': {}}``
        :type output_serialization: dict
        :return: retrieved subset of original data by S3 Select
        :rtype: str

        .. seealso::
            For more details about S3 Select parameters, see the boto3
            documentation of ``S3.Client.select_object_content``.
        """
        # Fresh dicts per call: dict literals as defaults would be shared
        # between calls and could be mutated by one of them.
        if input_serialization is None:
            input_serialization = {'CSV': {}}
        if output_serialization is None:
            output_serialization = {'CSV': {}}
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)

        response = self.get_conn().select_object_content(
            Bucket=bucket_name,
            Key=key,
            Expression=expression,
            ExpressionType=expression_type,
            InputSerialization=input_serialization,
            OutputSerialization=output_serialization)

        payloads = [event['Records']['Payload']
                    for event in response['Payload']
                    if 'Records' in event]
        if payloads and all(isinstance(p, bytes) for p in payloads):
            # Join the raw bytes before decoding so that a multi-byte
            # character split across two events still decodes correctly.
            return b''.join(payloads).decode('utf-8')
        return ''.join(payloads)

    def check_for_wildcard_key(self,
                               wildcard_key, bucket_name=None, delimiter=''):
        """
        Checks that a key matching a wildcard expression exists in a bucket

        :param wildcard_key: the path to the key, may contain wildcards
        :type wildcard_key: str
        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param delimiter: the delimiter marks key hierarchy.
        :type delimiter: str
        :return: True if at least one key matches, False otherwise
        :rtype: bool
        """
        return self.get_wildcard_key(wildcard_key=wildcard_key,
                                     bucket_name=bucket_name,
                                     delimiter=delimiter) is not None

    def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
        """
        Returns a boto3.s3.Object object matching the wildcard expression

        :param wildcard_key: the path to the key
        :type wildcard_key: str
        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param delimiter: the delimiter marks key hierarchy.
        :type delimiter: str
        :return: the object for the first listed key matching the
            expression, or None if no key matches
        """
        if not bucket_name:
            (bucket_name, wildcard_key) = self.parse_s3_url(wildcard_key)

        # The literal part of the pattern ends at the first character that
        # fnmatch treats as a wildcard ('*', '?' or '['), not only at '*'.
        prefix = re.split(r'[*?\[]', wildcard_key, 1)[0]
        klist = self.list_keys(bucket_name, prefix=prefix,
                               delimiter=delimiter)
        if klist:
            first_match = next(
                (k for k in klist if fnmatch.fnmatch(k, wildcard_key)), None)
            if first_match is not None:
                return self.get_key(first_match, bucket_name)
        return None

    def _prepare_upload(self, key, bucket_name, replace, encrypt):
        """
        Resolves the upload target and builds the upload ``ExtraArgs``.

        :return: tuple ``(bucket_name, key, extra_args)``
        :raises ValueError: if ``replace`` is False and the key already
            exists
        """
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)

        if not replace and self.check_for_key(key, bucket_name):
            raise ValueError("The key {key} already exists.".format(key=key))

        extra_args = {}
        if encrypt:
            extra_args['ServerSideEncryption'] = "AES256"
        return bucket_name, key, extra_args

    def load_file(self,
                  filename,
                  key,
                  bucket_name=None,
                  replace=False,
                  encrypt=False):
        """
        Loads a local file to S3

        :param filename: name of the file to load.
        :type filename: str
        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which to store the file
        :type bucket_name: str
        :param replace: A flag to decide whether or not to overwrite the key
            if it already exists. If replace is False and the key exists, an
            error will be raised.
        :type replace: bool
        :param encrypt: If True, the file will be encrypted on the
            server-side by S3 and will be stored in an encrypted form while
            at rest in S3.
        :type encrypt: bool
        :raises ValueError: if ``replace`` is False and the key already
            exists
        """
        bucket_name, key, extra_args = self._prepare_upload(
            key, bucket_name, replace, encrypt)

        client = self.get_conn()
        client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args)

    def load_string(self,
                    string_data,
                    key,
                    bucket_name=None,
                    replace=False,
                    encrypt=False,
                    encoding='utf-8'):
        """
        Loads a string to S3

        This is provided as a convenience to drop a string in S3. The string
        is encoded and handed to :meth:`load_bytes`.

        :param string_data: string to set as content for the key.
        :type string_data: str
        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which to store the file
        :type bucket_name: str
        :param replace: A flag to decide whether or not to overwrite the key
            if it already exists
        :type replace: bool
        :param encrypt: If True, the file will be encrypted on the
            server-side by S3 and will be stored in an encrypted form while
            at rest in S3.
        :type encrypt: bool
        :param encoding: encoding used to turn ``string_data`` into bytes
        :type encoding: str
        :raises ValueError: if ``replace`` is False and the key already
            exists
        """
        self.load_bytes(string_data.encode(encoding),
                        key=key,
                        bucket_name=bucket_name,
                        replace=replace,
                        encrypt=encrypt)

    def load_bytes(self,
                   bytes_data,
                   key,
                   bucket_name=None,
                   replace=False,
                   encrypt=False):
        """
        Loads bytes to S3

        This is provided as a convenience to drop bytes in S3. The data is
        wrapped in an in-memory buffer and uploaded with ``upload_fileobj``.

        :param bytes_data: bytes to set as content for the key.
        :type bytes_data: bytes
        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which to store the file
        :type bucket_name: str
        :param replace: A flag to decide whether or not to overwrite the key
            if it already exists
        :type replace: bool
        :param encrypt: If True, the file will be encrypted on the
            server-side by S3 and will be stored in an encrypted form while
            at rest in S3.
        :type encrypt: bool
        :raises ValueError: if ``replace`` is False and the key already
            exists
        """
        bucket_name, key, extra_args = self._prepare_upload(
            key, bucket_name, replace, encrypt)

        filelike_buffer = BytesIO(bytes_data)

        client = self.get_conn()
        client.upload_fileobj(filelike_buffer, bucket_name, key,
                              ExtraArgs=extra_args)

    def load_file_obj(self,
                      file_obj,
                      key,
                      bucket_name=None,
                      replace=False,
                      encrypt=False):
        """
        Loads a file object to S3

        :param file_obj: The file-like object to set as the content for the
            S3 key.
        :type file_obj: file-like object
        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which to store the file
        :type bucket_name: str
        :param replace: A flag that indicates whether to overwrite the key
            if it already exists.
        :type replace: bool
        :param encrypt: If True, S3 encrypts the file on the server,
            and the file is stored in encrypted form at rest in S3.
        :type encrypt: bool
        :raises ValueError: if ``replace`` is False and the key already
            exists
        """
        bucket_name, key, extra_args = self._prepare_upload(
            key, bucket_name, replace, encrypt)

        client = self.get_conn()
        client.upload_fileobj(file_obj, bucket_name, key,
                              ExtraArgs=extra_args)

    def copy_object(self,
                    source_bucket_key,
                    dest_bucket_key,
                    source_bucket_name=None,
                    dest_bucket_name=None,
                    source_version_id=None):
        """
        Creates a copy of an object that is already stored in S3.

        Note: the S3 connection used here needs to have access to both
        source and destination bucket/key.

        :param source_bucket_key: The key of the source object.

            It can be either full s3:// style url or relative path from root
            level.

            When it's specified as a full s3:// url, please omit
            source_bucket_name.
        :type source_bucket_key: str
        :param dest_bucket_key: The key of the object to copy to.

            The convention to specify `dest_bucket_key` is the same
            as `source_bucket_key`.
        :type dest_bucket_key: str
        :param source_bucket_name: Name of the S3 bucket where the source
            object is in.

            It should be omitted when `source_bucket_key` is provided as a
            full s3:// url.
        :type source_bucket_name: str
        :param dest_bucket_name: Name of the S3 bucket to where the object
            is copied.

            It should be omitted when `dest_bucket_key` is provided as a
            full s3:// url.
        :type dest_bucket_name: str
        :param source_version_id: Version ID of the source object (OPTIONAL)
        :type source_version_id: str
        :return: the response of the client's ``copy_object`` call
        :raises AirflowException: if a bucket name is given together with a
            key that is a full url, or if no bucket name is given and none
            can be parsed from the key
        """
        if dest_bucket_name is None:
            dest_bucket_name, dest_bucket_key = \
                self.parse_s3_url(dest_bucket_key)
        else:
            # An explicit bucket name and a full url would be ambiguous.
            parsed_url = urlparse(dest_bucket_key)
            if parsed_url.scheme != '' or parsed_url.netloc != '':
                raise AirflowException('If dest_bucket_name is provided, ' +
                                       'dest_bucket_key should be relative path ' +
                                       'from root level, rather than a full s3:// url')

        if source_bucket_name is None:
            source_bucket_name, source_bucket_key = \
                self.parse_s3_url(source_bucket_key)
        else:
            # An explicit bucket name and a full url would be ambiguous.
            parsed_url = urlparse(source_bucket_key)
            if parsed_url.scheme != '' or parsed_url.netloc != '':
                raise AirflowException('If source_bucket_name is provided, ' +
                                       'source_bucket_key should be relative path ' +
                                       'from root level, rather than a full s3:// url')

        copy_source = {'Bucket': source_bucket_name,
                       'Key': source_bucket_key,
                       'VersionId': source_version_id}
        response = self.get_conn().copy_object(Bucket=dest_bucket_name,
                                               Key=dest_bucket_key,
                                               CopySource=copy_source)
        return response

    def delete_objects(self,
                       bucket,
                       keys):
        """
        Deletes one or more objects from a bucket with a single
        ``delete_objects`` request.

        :param bucket: Name of the bucket in which you are going to delete
            object(s)
        :type bucket: str
        :param keys: The key(s) to delete from S3 bucket.

            When ``keys`` is a string, it's supposed to be the key name of
            the single object to delete.

            When ``keys`` is a list, it's supposed to be the list of the
            keys to delete.
        :type keys: str or list
        :return: the response of the client's ``delete_objects`` call
        """
        if not isinstance(keys, list):
            keys = [keys]

        delete_dict = {"Objects": [{"Key": k} for k in keys]}
        response = self.get_conn().delete_objects(Bucket=bucket,
                                                  Delete=delete_dict)
        return response