# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tarfile
import tempfile
import time
import os
import collections
import warnings
from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook
from airflow.hooks.S3_hook import S3Hook
from airflow.utils import timezone
# Position records where reading of a log stream stopped: ``timestamp`` is the timestamp of the
# last event read and ``skip`` is the number of events already read at that timestamp. It is
# used to figure out which event to start with on the next read.
Position = collections.namedtuple('Position', ['timestamp', 'skip'])
def argmin(arr, f):
    """
    Return the index, i, in arr that minimizes f(arr[i]).

    Entries of ``arr`` that are ``None`` are skipped. When several entries share the
    minimum, the index of the first one is returned.

    :param arr: sequence of items, possibly containing ``None`` placeholders
    :type arr: list
    :param f: function computing the value to minimize for a non-``None`` item
    :type f: python callable
    :return: index of the minimizing item, or ``None`` if every entry is ``None``
        (or ``arr`` is empty)
    """
    min_value = None
    min_index = None
    for idx, item in enumerate(arr):
        if item is None:
            continue
        # Evaluate f exactly once per item; it may be arbitrarily expensive.
        value = f(item)
        if min_value is None or value < min_value:
            min_value = value
            min_index = idx
    return min_index
def secondary_training_status_changed(current_job_description, prev_job_description):
    """
    Returns true if training job's secondary status message has changed.

    Only the ``StatusMessage`` of the most recent entry in ``SecondaryStatusTransitions``
    is compared. A missing previous description (or one without transitions) is treated
    as having an empty message, and a transition without a ``StatusMessage`` is treated
    the same way.

    :param current_job_description: Current job description, returned from DescribeTrainingJob call.
    :type current_job_description: dict
    :param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
    :type prev_job_description: dict
    :return: Whether the secondary status message of a training job changed or not.
    :rtype: bool
    """
    current_transitions = current_job_description.get('SecondaryStatusTransitions')
    if not current_transitions:
        # Nothing reported yet, so there is nothing that could have changed.
        return False

    prev_transitions = None
    if prev_job_description is not None:
        prev_transitions = prev_job_description.get('SecondaryStatusTransitions')

    # StatusMessage is an optional field of a transition, hence .get() instead of indexing.
    last_message = prev_transitions[-1].get('StatusMessage', '') if prev_transitions else ''
    message = current_transitions[-1].get('StatusMessage', '')

    return message != last_message
def secondary_training_status_message(job_description, prev_description):
    """
    Returns a string contains start time and the secondary training job status message.

    One line is produced for every transition that was added since ``prev_description``;
    if no transition was added, the latest transition is printed again. Every line is
    stamped with the job's ``LastModifiedTime`` converted to UTC.

    :param job_description: Returned response from DescribeTrainingJob call
    :type job_description: dict
    :param prev_description: Previous job description from DescribeTrainingJob call
    :type prev_description: dict
    :return: Job status string to be printed.
    :rtype: str
    """
    if job_description is None:
        return ''
    current_transitions = job_description.get('SecondaryStatusTransitions')
    if not current_transitions:
        return ''

    prev_transitions = None
    if prev_description is not None:
        prev_transitions = prev_description.get('SecondaryStatusTransitions')
    prev_transitions_num = len(prev_transitions) if prev_transitions is not None else 0

    if len(current_transitions) == prev_transitions_num:
        transitions_to_print = current_transitions[-1:]
    else:
        # A negative start index selects only the transitions added since the previous call.
        transitions_to_print = current_transitions[prev_transitions_num - len(current_transitions):]
    if not transitions_to_print:
        return ''

    # The timestamp is the same for every line, so format it once outside the loop.
    time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')

    status_strs = []
    for transition in transitions_to_print:
        # StatusMessage is an optional field of a transition, hence .get() instead of indexing.
        message = transition.get('StatusMessage', '')
        status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))

    return '\n'.join(status_strs)
class SageMakerHook(AwsHook):
    """
    Interact with Amazon SageMaker.

    The hook wraps the SageMaker client returned by :py:meth:`get_conn` and adds helpers
    for validating S3 inputs, polling job status and printing CloudWatch logs of
    training jobs.
    """
    # Job statuses that mean a training / tuning job has not finished yet.
    non_terminal_states = {'InProgress', 'Stopping'}
    # Endpoint statuses that mean an endpoint operation has not finished yet.
    endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating',
                                    'RollingBack', 'Deleting'}
    # Statuses that are reported as a failure.
    failed_states = {'Failed'}

    def __init__(self,
                 *args, **kwargs):
        super(SageMakerHook, self).__init__(*args, **kwargs)
        # Both helper hooks reuse the connection id of this hook.
        self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
        self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)

    def tar_and_s3_upload(self, path, key, bucket):
        """
        Tar the local file or directory and upload to s3

        A directory is archived entry by entry, so the archive contains the directory's
        children at its top level rather than the directory itself.

        :param path: local file or directory
        :type path: str
        :param key: s3 key
        :type key: str
        :param bucket: s3 bucket
        :type bucket: str
        :return: None
        """
        with tempfile.TemporaryFile() as temp_file:
            if os.path.isdir(path):
                files = [os.path.join(path, name) for name in os.listdir(path)]
            else:
                files = [path]
            with tarfile.open(mode='w:gz', fileobj=temp_file) as tar_file:
                for f in files:
                    tar_file.add(f, arcname=os.path.basename(f))
            # Rewind so the upload starts at the beginning of the archive.
            temp_file.seek(0)
            self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)

    def check_s3_url(self, s3url):
        """
        Check if an S3 URL exists

        :param s3url: S3 url
        :type s3url: str
        :rtype: bool
        :raises AirflowException: if the bucket does not exist, or if the URL has a key
            part that matches neither an existing key nor an existing prefix
        """
        bucket, key = S3Hook.parse_s3_url(s3url)
        if not self.s3_hook.check_for_bucket(bucket_name=bucket):
            raise AirflowException(
                "The input S3 Bucket {} does not exist ".format(bucket))
        if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\
                and not self.s3_hook.check_for_prefix(
                    prefix=key, bucket_name=bucket, delimiter='/'):
            # check if s3 key exists in the case user provides a single file
            # or if s3 prefix exists in the case user provides multiple files in
            # a prefix
            raise AirflowException("The input S3 Key "
                                   "or Prefix {} does not exist in the Bucket {}"
                                   .format(s3url, bucket))
        return True

    def _check_input_data_config(self, channels):
        """
        Check the S3 location of every S3-backed channel of an ``InputDataConfig`` list

        :param channels: value of ``InputDataConfig``; ``None`` is treated as empty
        :type channels: list
        :return: None
        """
        for channel in channels or []:
            data_source = channel['DataSource']
            # Only S3-backed channels carry a URI that can be checked here.
            if 'S3DataSource' in data_source:
                self.check_s3_url(data_source['S3DataSource']['S3Uri'])

    def check_training_config(self, training_config):
        """
        Check if a training configuration is valid

        :param training_config: training_config
        :type training_config: dict
        :return: None
        :raises AirflowException: if an S3 input of the configuration does not exist
        """
        self._check_input_data_config(training_config.get('InputDataConfig'))

    def check_tuning_config(self, tuning_config):
        """
        Check if a tuning configuration is valid

        :param tuning_config: tuning_config
        :type tuning_config: dict
        :return: None
        :raises AirflowException: if an S3 input of the configuration does not exist
        """
        # Mirror check_training_config: a missing section means there is nothing to check.
        training_job_definition = tuning_config.get('TrainingJobDefinition') or {}
        self._check_input_data_config(training_job_definition.get('InputDataConfig'))

    def get_conn(self):
        """
        Establish an AWS connection for SageMaker

        :rtype: :py:class:`SageMaker.Client`
        """
        return self.get_client_type('sagemaker')

    def get_log_conn(self):
        """
        This method is deprecated.
        Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_conn` instead.
        """
        warnings.warn("Method `get_log_conn` has been deprecated. "
                      "Please use `airflow.contrib.hooks.AwsLogsHook.get_conn` instead.",
                      category=DeprecationWarning,
                      stacklevel=2)

        return self.logs_hook.get_conn()

    def log_stream(self, log_group, stream_name, start_time=0, skip=0):
        """
        This method is deprecated.
        Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.
        """
        warnings.warn("Method `log_stream` has been deprecated. "
                      "Please use `airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.",
                      category=DeprecationWarning,
                      stacklevel=2)

        return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip)

    def multi_stream_iter(self, log_group, streams, positions=None):
        """
        Iterate over the available events coming from a set of log streams in a single log group
        interleaving the events from each stream so they're yielded in timestamp order.

        :param log_group: The name of the log group.
        :type log_group: str
        :param streams: A list of the log stream names. The position of the stream in this list is
            the stream number.
        :type streams: list
        :param positions: A mapping of stream name to a (timestamp, skip) ``Position`` which
            represents the last record read from that stream. Streams without an entry are
            read from the beginning.
        :type positions: dict
        :return: A tuple of (stream number, cloudwatch log event).
        """
        positions = positions or {}
        start_position = Position(timestamp=0, skip=0)

        event_iters = []
        for stream in streams:
            position = positions.get(stream, start_position)
            event_iters.append(
                self.logs_hook.get_log_events(log_group, stream, position.timestamp, position.skip))

        # Holds the next pending event of every stream; None marks an exhausted stream.
        events = [next(event_iter, None) if event_iter else None for event_iter in event_iters]

        while any(events):
            i = argmin(events, lambda x: x['timestamp'] if x else 9999999999)
            yield (i, events[i])
            events[i] = next(event_iters[i], None)

    def create_training_job(self, config, wait_for_completion=True, print_log=True,
                            check_interval=30, max_ingestion_time=None):
        """
        Create a training job

        :param config: the config for training
        :type config: dict
        :param wait_for_completion: if the program should keep running until job finishes
        :type wait_for_completion: bool
        :param print_log: if the CloudWatch logs of the training job should be printed
            while checking its status
        :type print_log: bool
        :param check_interval: the time interval in seconds which the operator
            will check the status of any SageMaker job
        :type check_interval: int
        :param max_ingestion_time: the maximum ingestion time in seconds. Any
            SageMaker jobs that run longer than this will fail. Setting this to
            None implies no timeout for any SageMaker job.
        :type max_ingestion_time: int
        :return: A response to training job creation
        """
        self.check_training_config(config)

        response = self.get_conn().create_training_job(**config)
        if print_log:
            self.check_training_status_with_log(config['TrainingJobName'],
                                                self.non_terminal_states,
                                                self.failed_states,
                                                wait_for_completion,
                                                check_interval, max_ingestion_time
                                                )
        elif wait_for_completion:
            describe_response = self.check_status(config['TrainingJobName'],
                                                  'TrainingJobStatus',
                                                  self.describe_training_job,
                                                  check_interval, max_ingestion_time
                                                  )

            billable_time = \
                (describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']) * \
                describe_response['ResourceConfig']['InstanceCount']
            self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))

        return response

    def create_tuning_job(self, config, wait_for_completion=True,
                          check_interval=30, max_ingestion_time=None):
        """
        Create a tuning job

        :param config: the config for tuning
        :type config: dict
        :param wait_for_completion: if the program should keep running until job finishes
        :type wait_for_completion: bool
        :param check_interval: the time interval in seconds which the operator
            will check the status of any SageMaker job
        :type check_interval: int
        :param max_ingestion_time: the maximum ingestion time in seconds. Any
            SageMaker jobs that run longer than this will fail. Setting this to
            None implies no timeout for any SageMaker job.
        :type max_ingestion_time: int
        :return: A response to tuning job creation
        """
        self.check_tuning_config(config)

        response = self.get_conn().create_hyper_parameter_tuning_job(**config)
        if wait_for_completion:
            self.check_status(config['HyperParameterTuningJobName'],
                              'HyperParameterTuningJobStatus',
                              self.describe_tuning_job,
                              check_interval, max_ingestion_time
                              )
        return response

    def create_model(self, config):
        """
        Create a model job

        :param config: the config for model
        :type config: dict
        :return: A response to model creation
        """
        return self.get_conn().create_model(**config)

    def create_endpoint_config(self, config):
        """
        Create an endpoint config

        :param config: the config for endpoint-config
        :type config: dict
        :return: A response to endpoint config creation
        """
        return self.get_conn().create_endpoint_config(**config)

    def create_endpoint(self, config, wait_for_completion=True,
                        check_interval=30, max_ingestion_time=None):
        """
        Create an endpoint

        :param config: the config for endpoint
        :type config: dict
        :param wait_for_completion: if the program should keep running until job finishes
        :type wait_for_completion: bool
        :param check_interval: the time interval in seconds which the operator
            will check the status of any SageMaker job
        :type check_interval: int
        :param max_ingestion_time: the maximum ingestion time in seconds. Any
            SageMaker jobs that run longer than this will fail. Setting this to
            None implies no timeout for any SageMaker job.
        :type max_ingestion_time: int
        :return: A response to endpoint creation
        """
        response = self.get_conn().create_endpoint(**config)
        if wait_for_completion:
            self.check_status(config['EndpointName'],
                              'EndpointStatus',
                              self.describe_endpoint,
                              check_interval, max_ingestion_time,
                              non_terminal_states=self.endpoint_non_terminal_states
                              )
        return response

    def update_endpoint(self, config, wait_for_completion=True,
                        check_interval=30, max_ingestion_time=None):
        """
        Update an endpoint

        :param config: the config for endpoint
        :type config: dict
        :param wait_for_completion: if the program should keep running until job finishes
        :type wait_for_completion: bool
        :param check_interval: the time interval in seconds which the operator
            will check the status of any SageMaker job
        :type check_interval: int
        :param max_ingestion_time: the maximum ingestion time in seconds. Any
            SageMaker jobs that run longer than this will fail. Setting this to
            None implies no timeout for any SageMaker job.
        :type max_ingestion_time: int
        :return: A response to endpoint update
        """
        response = self.get_conn().update_endpoint(**config)
        if wait_for_completion:
            self.check_status(config['EndpointName'],
                              'EndpointStatus',
                              self.describe_endpoint,
                              check_interval, max_ingestion_time,
                              non_terminal_states=self.endpoint_non_terminal_states
                              )
        return response

    def describe_training_job(self, name):
        """
        Return the training job info associated with the name

        :param name: the name of the training job
        :type name: str
        :return: A dict contains all the training job info
        """
        return self.get_conn().describe_training_job(TrainingJobName=name)

    def describe_training_job_with_log(self, job_name, positions, stream_names,
                                       instance_count, state, last_description,
                                       last_describe_job_call):
        """
        Return the training job info associated with job_name and print CloudWatch logs

        Performs one step of the log-tailing state machine driven by
        :py:meth:`check_training_status_with_log`: discover log streams, print the events
        that are available, then advance ``state``.

        :param job_name: name of the training job
        :type job_name: str
        :param positions: map of stream name to the ``Position`` read so far; updated in place
        :type positions: dict
        :param stream_names: names of the log streams found so far; updated in place
        :type stream_names: list
        :param instance_count: number of instances, i.e. the number of log streams to expect
        :type instance_count: int
        :param state: current ``LogState`` of the state machine
        :param last_description: the most recent job description that was kept
        :type last_description: dict
        :param last_describe_job_call: time (``time.time()``) of the last describe call
        :type last_describe_job_call: float
        :return: tuple of (state, last_description, last_describe_job_call)
        """
        # NOTE(review): LogState is neither defined nor imported in the visible part of this
        # module; it has to provide TAILING, JOB_COMPLETE and COMPLETE -- confirm where it lives.
        log_group = '/aws/sagemaker/TrainingJobs'

        if len(stream_names) < instance_count:
            # Log streams are created whenever a container starts writing to stdout/err, so this list
            # may be dynamic until we have a stream for every instance.
            logs_conn = self.logs_hook.get_conn()
            try:
                streams = logs_conn.describe_log_streams(
                    logGroupName=log_group,
                    logStreamNamePrefix=job_name + '/',
                    orderBy='LogStreamName',
                    limit=instance_count
                )
                # Update the caller's list in place; rebinding the local name would lose the
                # result and force a describe_log_streams call on every poll.
                stream_names[:] = [s['logStreamName'] for s in streams['logStreams']]
                positions.update([(s, Position(timestamp=0, skip=0))
                                  for s in stream_names if s not in positions])
            except logs_conn.exceptions.ResourceNotFoundException:
                # On the very first training job run on an account, there's no log group until
                # the container starts logging, so ignore any errors thrown about that
                pass

        if len(stream_names) > 0:
            for idx, event in self.multi_stream_iter(log_group, stream_names, positions):
                self.log.info(event['message'])
                ts, count = positions[stream_names[idx]]
                if event['timestamp'] == ts:
                    # Another event with the same timestamp: one more to skip next time.
                    positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
                else:
                    positions[stream_names[idx]] = Position(timestamp=event['timestamp'], skip=1)

        if state == LogState.COMPLETE:
            return state, last_description, last_describe_job_call

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            description = self.describe_training_job(job_name)
            last_describe_job_call = time.time()

            if secondary_training_status_changed(description, last_description):
                self.log.info(secondary_training_status_message(description, last_description))
                last_description = description

            status = description['TrainingJobStatus']

            if status not in self.non_terminal_states:
                state = LogState.JOB_COMPLETE
                # Always keep the final description so that callers see the terminal status
                # and the end time even when the status message did not change.
                last_description = description
        return state, last_description, last_describe_job_call

    def describe_tuning_job(self, name):
        """
        Return the tuning job info associated with the name

        :param name: the name of the tuning job
        :type name: str
        :return: A dict contains all the tuning job info
        """
        return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)

    def describe_model(self, name):
        """
        Return the SageMaker model info associated with the name

        :param name: the name of the SageMaker model
        :type name: str
        :return: A dict contains all the model info
        """
        return self.get_conn().describe_model(ModelName=name)

    def describe_endpoint_config(self, name):
        """
        Return the endpoint config info associated with the name

        :param name: the name of the endpoint config
        :type name: str
        :return: A dict contains all the endpoint config info
        """
        return self.get_conn().describe_endpoint_config(EndpointConfigName=name)

    def describe_endpoint(self, name):
        """
        Return the endpoint info associated with the name

        :param name: the name of the endpoint
        :type name: str
        :return: A dict contains all the endpoint info
        """
        return self.get_conn().describe_endpoint(EndpointName=name)

    def check_status(self, job_name, key,
                     describe_function, check_interval,
                     max_ingestion_time,
                     non_terminal_states=None):
        """
        Check status of a SageMaker job

        :param job_name: name of the job to check status
        :type job_name: str
        :param key: the key of the response dict
            that points to the state
        :type key: str
        :param describe_function: the function used to retrieve the status;
            it is called with ``job_name`` as its only argument
        :type describe_function: python callable
        :param check_interval: the time interval in seconds which the operator
            will check the status of any SageMaker job
        :type check_interval: int
        :param max_ingestion_time: the maximum ingestion time in seconds. Any
            SageMaker jobs that run longer than this will fail. Setting this to
            None implies no timeout for any SageMaker job.
        :type max_ingestion_time: int
        :param non_terminal_states: the set of nonterminal states
        :type non_terminal_states: set
        :return: response of describe call after job is done
        :raises AirflowException: if the status cannot be retrieved, the job reports a
            failed state, or ``max_ingestion_time`` is exceeded
        """
        if not non_terminal_states:
            non_terminal_states = self.non_terminal_states

        # Elapsed time is approximated by the sum of the sleep intervals.
        sec = 0
        running = True

        while running:
            time.sleep(check_interval)
            sec = sec + check_interval

            try:
                response = describe_function(job_name)
                status = response[key]
                self.log.info('Job still running for %s seconds... '
                              'current status is %s', sec, status)
            except KeyError:
                raise AirflowException('Could not get status of the SageMaker job')
            except ClientError:
                raise AirflowException('AWS request failed, check logs for more info')

            if status in self.failed_states and status not in non_terminal_states:
                # Do not let a missing FailureReason mask the failure with a KeyError.
                raise AirflowException('SageMaker job failed because %s'
                                       % response.get('FailureReason', '(No reason provided)'))
            running = status in non_terminal_states

            if max_ingestion_time and sec > max_ingestion_time:
                # ensure that the job gets killed if the max ingestion time is exceeded
                raise AirflowException('SageMaker job took more than %s seconds' % max_ingestion_time)

        self.log.info('SageMaker Job Compeleted')
        response = describe_function(job_name)
        return response

    def check_training_status_with_log(self, job_name, non_terminal_states, failed_states,
                                       wait_for_completion, check_interval, max_ingestion_time):
        """
        Display the logs for a given training job, optionally tailing them until the
        job is complete.

        :param job_name: name of the training job to check status and display logs for
        :type job_name: str
        :param non_terminal_states: the set of non_terminal states
        :type non_terminal_states: set
        :param failed_states: the set of failed states
        :type failed_states: set
        :param wait_for_completion: Whether to keep looking for new log entries
            until the job completes
        :type wait_for_completion: bool
        :param check_interval: The interval in seconds between polling for new log entries and job completion
        :type check_interval: int
        :param max_ingestion_time: the maximum ingestion time in seconds. Any
            SageMaker jobs that run longer than this will fail. Setting this to
            None implies no timeout for any SageMaker job.
        :type max_ingestion_time: int
        :return: None
        :raises AirflowException: if ``max_ingestion_time`` is exceeded, or if
            ``wait_for_completion`` is set and the job ends in a failed state
        """

        sec = 0
        description = self.describe_training_job(job_name)
        self.log.info(secondary_training_status_message(description, None))
        instance_count = description['ResourceConfig']['InstanceCount']
        status = description['TrainingJobStatus']

        stream_names = []  # The list of log streams
        positions = {}  # The current position in each stream, map of stream name -> position

        job_already_completed = status not in non_terminal_states

        state = LogState.TAILING if wait_for_completion and not job_already_completed else LogState.COMPLETE

        # The loop below implements a state machine that alternates between checking the job status and
        # reading whatever is available in the logs at this point. Note, that if we were called with
        # wait_for_completion == False, we never check the job status.
        #
        # If wait_for_completion == TRUE and job is not completed, the initial state is TAILING
        # If wait_for_completion == FALSE, the initial state is COMPLETE
        # (doesn't matter if the job really is complete).
        #
        # The state table:
        #
        # STATE               ACTIONS                        CONDITION             NEW STATE
        # ----------------    ----------------               -----------------     ----------------
        # TAILING             Read logs, Pause, Get status   Job complete          JOB_COMPLETE
        #                                                    Else                  TAILING
        # JOB_COMPLETE        Read logs, Pause               Any                   COMPLETE
        # COMPLETE            Read logs, Exit                                      N/A
        #
        # Notes:
        # - The JOB_COMPLETE state forces us to do an extra pause and read any items that
        #   got to Cloudwatch after the job was marked complete.
        last_describe_job_call = time.time()
        last_description = description

        while True:
            time.sleep(check_interval)
            sec = sec + check_interval

            state, last_description, last_describe_job_call = \
                self.describe_training_job_with_log(job_name, positions, stream_names,
                                                    instance_count, state, last_description,
                                                    last_describe_job_call)
            if state == LogState.COMPLETE:
                break

            if max_ingestion_time and sec > max_ingestion_time:
                # ensure that the job gets killed if the max ingestion time is exceeded
                raise AirflowException('SageMaker job took more than %s seconds' % max_ingestion_time)

        if wait_for_completion:
            status = last_description['TrainingJobStatus']
            if status in failed_states:
                reason = last_description.get('FailureReason', '(No reason provided)')
                raise AirflowException('Error training {}: {} Reason: {}'.format(job_name, status, reason))
            billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \
                * instance_count
            self.log.info('Billable seconds:{}'.format(int(billable_time.total_seconds()) + 1))