# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException
from airflow.contrib.hooks.aws_hook import AwsHook
from six import BytesIO
from urllib.parse import urlparse
import re
import fnmatch
class S3Hook(AwsHook):
    """
    Interact with AWS S3, using the boto3 library.
    """

    def get_conn(self):
        """
        Returns the S3 client produced by ``self.get_client_type('s3')``.
        """
        return self.get_client_type('s3')

    @staticmethod
    def parse_s3_url(s3url):
        """
        Splits an S3 url into its bucket name and key.

        :param s3url: url of the form ``s3://bucket/path/to/key``
        :type s3url: str
        :return: tuple ``(bucket_name, key)``; leading and trailing
            slashes are stripped from the key
        :rtype: tuple
        :raises AirflowException: if the url carries no network location,
            i.e. no bucket name can be extracted from it
        """
        parsed_url = urlparse(s3url)
        if not parsed_url.netloc:
            raise AirflowException('Please provide a bucket_name instead of "%s"' % s3url)
        bucket_name = parsed_url.netloc
        key = parsed_url.path.strip('/')
        return bucket_name, key

    def check_for_bucket(self, bucket_name):
        """
        Check if bucket_name exists.

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :return: True if the ``head_bucket`` call succeeds, False if it
            raises a ``ClientError`` (the error message is logged)
        :rtype: bool
        """
        try:
            self.get_conn().head_bucket(Bucket=bucket_name)
            return True
        except ClientError as e:
            self.log.info(e.response["Error"]["Message"])
            return False

    def get_bucket(self, bucket_name):
        """
        Returns a boto3.S3.Bucket object

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        """
        s3 = self.get_resource_type('s3')
        return s3.Bucket(bucket_name)

    def create_bucket(self, bucket_name, region_name=None):
        """
        Creates an Amazon S3 bucket.

        :param bucket_name: The name of the bucket
        :type bucket_name: str
        :param region_name: The name of the aws region in which to create the bucket.
            When omitted, the region of the client connection is used.
        :type region_name: str
        """
        s3_conn = self.get_conn()
        if not region_name:
            region_name = s3_conn.meta.region_name
        if region_name == 'us-east-1':
            # us-east-1 is created without a LocationConstraint.
            s3_conn.create_bucket(Bucket=bucket_name)
        else:
            s3_conn.create_bucket(Bucket=bucket_name,
                                  CreateBucketConfiguration={
                                      'LocationConstraint': region_name
                                  })

    def check_for_prefix(self, bucket_name, prefix, delimiter):
        """
        Checks that a prefix exists in a bucket

        The prefix is normalised to end with the delimiter, then the
        prefixes one level above it are listed and searched for it.

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param prefix: a key prefix
        :type prefix: str
        :param delimiter: the delimiter marks key hierarchy.
        :type delimiter: str
        :return: True if the prefix is among the listed prefixes of its
            parent level, False otherwise (including when ``delimiter``
            is empty)
        :rtype: bool
        """
        if not delimiter:
            # Without a delimiter there is no hierarchy to look the prefix
            # up in; S3 is documented to return CommonPrefixes only when a
            # delimiter is supplied.
            return False
        if not prefix.endswith(delimiter):
            prefix += delimiter
        # Drop the trailing delimiter and keep everything up to (and
        # including) the previous delimiter: that is the parent level.
        parent, separator, _ = prefix[:-len(delimiter)].rpartition(delimiter)
        previous_level = parent + separator
        plist = self.list_prefixes(bucket_name, previous_level, delimiter)
        return False if plist is None else prefix in plist

    def _list_page_entries(self, bucket_name, prefix, delimiter,
                           page_size, max_items, page_field, entry_field):
        """
        Collects ``entry[entry_field]`` for every entry found under
        ``page_field`` in the pages of a ``list_objects_v2`` paginator.

        :return: the collected values, or None if no page contained
            ``page_field``
        :rtype: list or None
        """
        paginator = self.get_conn().get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name,
                                   Prefix=prefix,
                                   Delimiter=delimiter,
                                   PaginationConfig={
                                       'PageSize': page_size,
                                       'MaxItems': max_items,
                                   })
        collected = None
        for page in pages:
            if page_field in page:
                if collected is None:
                    collected = []
                collected.extend(entry[entry_field] for entry in page[page_field])
        return collected

    def list_prefixes(self, bucket_name, prefix='', delimiter='',
                      page_size=None, max_items=None):
        """
        Lists prefixes in a bucket under prefix

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param prefix: a key prefix
        :type prefix: str
        :param delimiter: the delimiter marks key hierarchy.
        :type delimiter: str
        :param page_size: pagination size
        :type page_size: int
        :param max_items: maximum items to return
        :type max_items: int
        :return: the ``Prefix`` of every ``CommonPrefixes`` entry, or None
            if no page contained ``CommonPrefixes``
        :rtype: list or None
        """
        return self._list_page_entries(bucket_name, prefix, delimiter,
                                       page_size, max_items,
                                       'CommonPrefixes', 'Prefix')

    def list_keys(self, bucket_name, prefix='', delimiter='',
                  page_size=None, max_items=None):
        """
        Lists keys in a bucket under prefix and not containing delimiter

        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param prefix: a key prefix
        :type prefix: str
        :param delimiter: the delimiter marks key hierarchy.
        :type delimiter: str
        :param page_size: pagination size
        :type page_size: int
        :param max_items: maximum items to return
        :type max_items: int
        :return: the ``Key`` of every ``Contents`` entry, or None if no
            page contained ``Contents``
        :rtype: list or None
        """
        return self._list_page_entries(bucket_name, prefix, delimiter,
                                       page_size, max_items,
                                       'Contents', 'Key')

    def check_for_key(self, key, bucket_name=None):
        """
        Checks if a key exists in a bucket

        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which the file is stored.
            When omitted, ``key`` is parsed as a full s3:// url.
        :type bucket_name: str
        :return: True if the ``head_object`` call succeeds, False if it
            raises a ``ClientError`` (the error message is logged)
        :rtype: bool
        """
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)
        try:
            self.get_conn().head_object(Bucket=bucket_name, Key=key)
            return True
        except ClientError as e:
            self.log.info(e.response["Error"]["Message"])
            return False

    def get_key(self, key, bucket_name=None):
        """
        Returns a boto3.s3.Object

        :param key: the path to the key
        :type key: str
        :param bucket_name: the name of the bucket. When omitted, ``key``
            is parsed as a full s3:// url.
        :type bucket_name: str
        """
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)
        obj = self.get_resource_type('s3').Object(bucket_name, key)
        # load() is invoked before the object is handed back to the caller.
        obj.load()
        return obj

    def read_key(self, key, bucket_name=None):
        """
        Reads a key from S3

        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which the file is stored
        :type bucket_name: str
        :return: the body of the object decoded as UTF-8
        :rtype: str
        """
        obj = self.get_key(key, bucket_name)
        return obj.get()['Body'].read().decode('utf-8')

    def select_key(self, key, bucket_name=None,
                   expression='SELECT * FROM S3Object',
                   expression_type='SQL',
                   input_serialization=None,
                   output_serialization=None):
        """
        Reads a key with S3 Select.

        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which the file is stored
        :type bucket_name: str
        :param expression: S3 Select expression
        :type expression: str
        :param expression_type: S3 Select expression type
        :type expression_type: str
        :param input_serialization: S3 Select input data serialization format
            (defaults to ``{'CSV': {}}``)
        :type input_serialization: dict
        :param output_serialization: S3 Select output data serialization format
            (defaults to ``{'CSV': {}}``)
        :type output_serialization: dict
        :return: retrieved subset of original data by S3 Select
        :rtype: str

        .. seealso::
            For more details about S3 Select parameters:
            http://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.select_object_content
        """
        if input_serialization is None:
            input_serialization = {'CSV': {}}
        if output_serialization is None:
            output_serialization = {'CSV': {}}
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)

        response = self.get_conn().select_object_content(
            Bucket=bucket_name,
            Key=key,
            Expression=expression,
            ExpressionType=expression_type,
            InputSerialization=input_serialization,
            OutputSerialization=output_serialization)

        # Record payloads arrive as bytes chunks: concatenate them first and
        # decode once, so a multi-byte character split across two chunks is
        # still decoded correctly.
        payload = b''.join(event['Records']['Payload']
                           for event in response['Payload']
                           if 'Records' in event)
        return payload.decode('utf-8')

    def check_for_wildcard_key(self,
                               wildcard_key, bucket_name=None, delimiter=''):
        """
        Checks that a key matching a wildcard expression exists in a bucket

        :param wildcard_key: the path to the key
        :type wildcard_key: str
        :param bucket_name: the name of the bucket
        :type bucket_name: str
        :param delimiter: the delimiter marks key hierarchy
        :type delimiter: str
        :return: True if ``get_wildcard_key`` finds a matching key
        :rtype: bool
        """
        return self.get_wildcard_key(wildcard_key=wildcard_key,
                                     bucket_name=bucket_name,
                                     delimiter=delimiter) is not None

    def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
        """
        Returns a boto3.s3.Object object matching the wildcard expression

        :param wildcard_key: the path to the key
        :type wildcard_key: str
        :param bucket_name: the name of the bucket. When omitted,
            ``wildcard_key`` is parsed as a full s3:// url.
        :type bucket_name: str
        :param delimiter: the delimiter marks key hierarchy
        :type delimiter: str
        :return: the object for the first listed key that matches, or None
            if nothing matches
        """
        if not bucket_name:
            (bucket_name, wildcard_key) = self.parse_s3_url(wildcard_key)

        # Only the literal part before the first '*' can narrow the listing.
        prefix = wildcard_key.split('*', 1)[0]
        klist = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter)
        if klist:
            key_matches = [k for k in klist if fnmatch.fnmatch(k, wildcard_key)]
            if key_matches:
                return self.get_key(key_matches[0], bucket_name)
        return None

    def _prepare_upload(self, key, bucket_name, replace, encrypt):
        """
        Resolves the upload target and builds the upload ``ExtraArgs``.

        :return: tuple ``(bucket_name, key, extra_args)``
        :raises ValueError: if ``replace`` is False and the key already exists
        """
        if not bucket_name:
            (bucket_name, key) = self.parse_s3_url(key)
        if not replace and self.check_for_key(key, bucket_name):
            raise ValueError("The key {key} already exists.".format(key=key))
        extra_args = {}
        if encrypt:
            extra_args['ServerSideEncryption'] = "AES256"
        return bucket_name, key, extra_args

    def load_file(self,
                  filename,
                  key,
                  bucket_name=None,
                  replace=False,
                  encrypt=False):
        """
        Loads a local file to S3

        :param filename: name of the file to load.
        :type filename: str
        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which to store the file
        :type bucket_name: str
        :param replace: A flag to decide whether or not to overwrite the key
            if it already exists. If replace is False and the key exists, an
            error will be raised.
        :type replace: bool
        :param encrypt: If True, the file will be encrypted on the server-side
            by S3 and will be stored in an encrypted form while at rest in S3.
        :type encrypt: bool
        :raises ValueError: if ``replace`` is False and the key already exists
        """
        bucket_name, key, extra_args = self._prepare_upload(
            key, bucket_name, replace, encrypt)
        client = self.get_conn()
        client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args)

    def load_string(self,
                    string_data,
                    key,
                    bucket_name=None,
                    replace=False,
                    encrypt=False,
                    encoding='utf-8'):
        """
        Loads a string to S3

        This is provided as a convenience to drop a string in S3. It uses the
        boto infrastructure to ship a file to s3.

        :param string_data: string to set as content for the key.
        :type string_data: str
        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which to store the file
        :type bucket_name: str
        :param replace: A flag to decide whether or not to overwrite the key
            if it already exists
        :type replace: bool
        :param encrypt: If True, the file will be encrypted on the server-side
            by S3 and will be stored in an encrypted form while at rest in S3.
        :type encrypt: bool
        :param encoding: codec used to turn ``string_data`` into bytes
            before it is uploaded
        :type encoding: str
        :raises ValueError: if ``replace`` is False and the key already exists
        """
        self.load_bytes(string_data.encode(encoding),
                        key=key,
                        bucket_name=bucket_name,
                        replace=replace,
                        encrypt=encrypt)

    def load_bytes(self,
                   bytes_data,
                   key,
                   bucket_name=None,
                   replace=False,
                   encrypt=False):
        """
        Loads bytes to S3

        This is provided as a convenience to drop a string in S3. It uses the
        boto infrastructure to ship a file to s3.

        :param bytes_data: bytes to set as content for the key.
        :type bytes_data: bytes
        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which to store the file
        :type bucket_name: str
        :param replace: A flag to decide whether or not to overwrite the key
            if it already exists
        :type replace: bool
        :param encrypt: If True, the file will be encrypted on the server-side
            by S3 and will be stored in an encrypted form while at rest in S3.
        :type encrypt: bool
        :raises ValueError: if ``replace`` is False and the key already exists
        """
        bucket_name, key, extra_args = self._prepare_upload(
            key, bucket_name, replace, encrypt)
        filelike_buffer = BytesIO(bytes_data)
        client = self.get_conn()
        client.upload_fileobj(filelike_buffer, bucket_name, key, ExtraArgs=extra_args)

    def load_file_obj(self,
                      file_obj,
                      key,
                      bucket_name=None,
                      replace=False,
                      encrypt=False):
        """
        Loads a file object to S3

        :param file_obj: The file-like object to set as the content for the S3 key.
        :type file_obj: file-like object
        :param key: S3 key that will point to the file
        :type key: str
        :param bucket_name: Name of the bucket in which to store the file
        :type bucket_name: str
        :param replace: A flag that indicates whether to overwrite the key
            if it already exists.
        :type replace: bool
        :param encrypt: If True, S3 encrypts the file on the server,
            and the file is stored in encrypted form at rest in S3.
        :type encrypt: bool
        :raises ValueError: if ``replace`` is False and the key already exists
        """
        bucket_name, key, extra_args = self._prepare_upload(
            key, bucket_name, replace, encrypt)
        client = self.get_conn()
        client.upload_fileobj(file_obj, bucket_name, key, ExtraArgs=extra_args)

    def copy_object(self,
                    source_bucket_key,
                    dest_bucket_key,
                    source_bucket_name=None,
                    dest_bucket_name=None,
                    source_version_id=None):
        """
        Creates a copy of an object that is already stored in S3.

        Note: the S3 connection used here needs to have access to both
        source and destination bucket/key.

        :param source_bucket_key: The key of the source object.
            It can be either full s3:// style url or relative path from root level.
            When it's specified as a full s3:// url, please omit source_bucket_name.
        :type source_bucket_key: str
        :param dest_bucket_key: The key of the object to copy to.
            The convention to specify `dest_bucket_key` is the same
            as `source_bucket_key`.
        :type dest_bucket_key: str
        :param source_bucket_name: Name of the S3 bucket where the source object is in.
            It should be omitted when `source_bucket_key` is provided as a full s3:// url.
        :type source_bucket_name: str
        :param dest_bucket_name: Name of the S3 bucket to where the object is copied.
            It should be omitted when `dest_bucket_key` is provided as a full s3:// url.
        :type dest_bucket_name: str
        :param source_version_id: Version ID of the source object (OPTIONAL)
        :type source_version_id: str
        :return: the response of the client's ``copy_object`` call
        :raises AirflowException: if a bucket name is given together with
            a key that is a full url rather than a relative path
        """
        if dest_bucket_name is None:
            dest_bucket_name, dest_bucket_key = self.parse_s3_url(dest_bucket_key)
        else:
            parsed_url = urlparse(dest_bucket_key)
            if parsed_url.scheme != '' or parsed_url.netloc != '':
                raise AirflowException('If dest_bucket_name is provided, '
                                       'dest_bucket_key should be relative path '
                                       'from root level, rather than a full s3:// url')

        if source_bucket_name is None:
            source_bucket_name, source_bucket_key = self.parse_s3_url(source_bucket_key)
        else:
            parsed_url = urlparse(source_bucket_key)
            if parsed_url.scheme != '' or parsed_url.netloc != '':
                raise AirflowException('If source_bucket_name is provided, '
                                       'source_bucket_key should be relative path '
                                       'from root level, rather than a full s3:// url')

        copy_source = {'Bucket': source_bucket_name,
                       'Key': source_bucket_key,
                       'VersionId': source_version_id}
        response = self.get_conn().copy_object(Bucket=dest_bucket_name,
                                               Key=dest_bucket_key,
                                               CopySource=copy_source)
        return response

    def delete_objects(self,
                       bucket,
                       keys):
        """
        Deletes one or several objects from a bucket in a single request.

        :param bucket: Name of the bucket in which you are going to delete object(s)
        :type bucket: str
        :param keys: The key(s) to delete from S3 bucket.

            When ``keys`` is a string, it's supposed to be the key name of
            the single object to delete.

            When ``keys`` is a list, it's supposed to be the list of the
            keys to delete.
        :type keys: str or list
        :return: the response of the client's ``delete_objects`` call
        """
        if not isinstance(keys, list):
            keys = [keys]
        delete_dict = {"Objects": [{"Key": k} for k in keys]}
        response = self.get_conn().delete_objects(Bucket=bucket,
                                                  Delete=delete_dict)
        return response