# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import random
import time
from googleapiclient.errors import HttpError
from googleapiclient.discovery import build
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
# Module-level logger, used by the module-level polling helper below
# (the hook class itself logs through ``self.log``).
log = logging.getLogger(__name__)
[docs]def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
for i in range(0, max_n):
try:
response = request.execute()
if is_error_func(response):
raise ValueError(
'The response contained an error: {}'.format(response)
)
elif is_done_func(response):
log.info('Operation is done: %s', response)
return response
else:
time.sleep((2**i) + (random.randint(0, 1000) / 1000))
except HttpError as e:
if e.resp.status != 429:
log.info('Something went wrong. Not retrying: %s', format(e))
raise
else:
time.sleep((2**i) + (random.randint(0, 1000) / 1000))
class MLEngineHook(GoogleCloudBaseHook):
    """
    Hook for the Google Cloud ML Engine API (``ml``, version ``v1``).

    The discovery-based service object is built once, at construction time,
    and stored in ``self._mlengine``.

    :param gcp_conn_id: Connection id forwarded to ``GoogleCloudBaseHook``.
    :type gcp_conn_id: str
    :param delegate_to: Forwarded unchanged to ``GoogleCloudBaseHook``.
    :type delegate_to: str
    """

    def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
        super(MLEngineHook, self).__init__(gcp_conn_id, delegate_to)
        self._mlengine = self.get_conn()

    def get_conn(self):
        """
        Returns a Google MLEngine service object.

        The ``ml`` v1 client is built with the HTTP object returned by
        ``self._authorize()``, with discovery-document caching disabled.
        """
        authed_http = self._authorize()
        return build('ml', 'v1', http=authed_http, cache_discovery=False)

    def create_job(self, project_id, job, use_existing_job_fn=None):
        """
        Launches a MLEngine job and waits for it to reach a terminal state.

        :param project_id: The Google Cloud project id within which the
            MLEngine job will be launched.
        :type project_id: str
        :param job: MLEngine Job object that should be provided to the MLEngine
            API, such as: ::

                {
                  'jobId': 'my_job_id',
                  'trainingInput': {
                    'scaleTier': 'STANDARD_1',
                    ...
                  }
                }

        :type job: dict
        :param use_existing_job_fn: If a MLEngine job with the same job_id
            already exists, this callable (if provided) decides whether the
            existing job should be reused, i.e. waited on and returned. It
            receives the existing MLEngine job object and must return a
            boolean indicating whether reusing it is OK. If it is not
            provided, an existing job is always reused.
        :type use_existing_job_fn: function
        :return: The MLEngine job object once the job reaches a terminal
            state (which might be FAILED or CANCELLED).
        :rtype: dict
        :raises googleapiclient.errors.HttpError: if creating the job fails
            for any reason other than an acceptable existing job (HTTP 409).
        """
        request = self._mlengine.projects().jobs().create(
            parent='projects/{}'.format(project_id),
            body=job)
        job_id = job['jobId']

        try:
            request.execute()
        except HttpError as e:
            # 409 means there is an existing job with the same job ID;
            # anything else is a genuine failure.
            if e.resp.status != 409:
                self.log.error('Failed to create MLEngine job: %s', e)
                raise
            if use_existing_job_fn is not None:
                existing_job = self._get_job(project_id, job_id)
                if not use_existing_job_fn(existing_job):
                    self.log.error(
                        'Job with job_id %s already exist, but it does '
                        'not match our expectation: %s',
                        job_id, existing_job
                    )
                    raise
            self.log.info(
                'Job with job_id %s already exist. Will waiting for it to finish',
                job_id
            )

        return self._wait_for_job_done(project_id, job_id)

    def _get_job(self, project_id, job_id):
        """
        Gets a MLEngine job based on its job id.

        HTTP 429 responses are retried every 30 seconds, with no retry limit.

        :return: MLEngine job object if successful.
        :rtype: dict
        :raises googleapiclient.errors.HttpError: if any other HTTP error is
            returned from the server.
        """
        job_name = 'projects/{}/jobs/{}'.format(project_id, job_id)
        request = self._mlengine.projects().jobs().get(name=job_name)
        while True:
            try:
                return request.execute()
            except HttpError as e:
                if e.resp.status != 429:
                    self.log.error('Failed to get MLEngine job: %s', e)
                    raise
                # Quota failure: poll again after 30 seconds.
                time.sleep(30)

    def _wait_for_job_done(self, project_id, job_id, interval=30):
        """
        Waits for the job to reach a terminal state.

        The job state is re-checked every ``interval`` seconds until it is
        one of SUCCEEDED, FAILED or CANCELLED.

        :param interval: Seconds to sleep between checks; must be > 0.
        :type interval: int
        :return: The job object in its terminal state.
        :rtype: dict
        :raises ValueError: if ``interval`` is not positive.
        :raises googleapiclient.errors.HttpError: if an HTTP error is returned
            when getting the job.
        """
        if interval <= 0:
            raise ValueError("Interval must be > 0")
        while True:
            job = self._get_job(project_id, job_id)
            if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
                return job
            time.sleep(interval)

    def _wait_for_operation(self, operation):
        """
        Polls an operation until its ``done`` flag is set.

        :param operation: Operation object returned by a create/delete call;
            only its ``name`` field is used.
        :type operation: dict
        :return: The final operation object.
        :rtype: dict
        :raises ValueError: if the operation reports an ``error`` or does not
            finish within the polling budget.
        """
        get_request = self._mlengine.projects().operations().get(
            name=operation['name'])
        return _poll_with_exponential_delay(
            request=get_request,
            max_n=9,
            is_done_func=lambda resp: resp.get('done', False),
            is_error_func=lambda resp: resp.get('error', None) is not None)

    def create_version(self, project_id, model_name, version_spec):
        """
        Creates the Version on Google Cloud ML Engine.

        Returns the finished operation if the version was created
        successfully and raises an error otherwise.
        """
        parent_name = 'projects/{}/models/{}'.format(project_id, model_name)
        create_request = self._mlengine.projects().models().versions().create(
            parent=parent_name, body=version_spec)
        response = create_request.execute()
        return self._wait_for_operation(response)

    def set_default_version(self, project_id, model_name, version_name):
        """
        Sets a version to be the default. Blocks until finished.

        :return: The API response for the ``setDefault`` call.
        :raises googleapiclient.errors.HttpError: if the request fails.
        """
        full_version_name = 'projects/{}/models/{}/versions/{}'.format(
            project_id, model_name, version_name)
        request = self._mlengine.projects().models().versions().setDefault(
            name=full_version_name, body={})
        try:
            response = request.execute()
            self.log.info('Successfully set version: %s to default', response)
            return response
        except HttpError as e:
            self.log.error('Something went wrong: %s', e)
            raise

    def list_versions(self, project_id, model_name):
        """
        Lists all available versions of a model. Blocks until finished.

        Pages of up to 100 versions are requested until the response carries
        no ``nextPageToken``; successive page requests are spaced 5 seconds
        apart.

        :return: All version objects across every page.
        :rtype: list
        """
        full_parent_name = 'projects/{}/models/{}'.format(
            project_id, model_name)
        result = []
        page_token = None
        while True:
            list_kwargs = {'parent': full_parent_name, 'pageSize': 100}
            if page_token:
                list_kwargs['pageToken'] = page_token
            response = self._mlengine.projects().models().versions().list(
                **list_kwargs).execute()
            result.extend(response.get('versions', []))
            # Treat an absent *or* empty token as the last page, so an empty
            # string can never cause an endless loop.
            page_token = response.get('nextPageToken')
            if not page_token:
                return result
            # Presumably throttles paging to avoid API rate limits; sleeping
            # only between requests avoids a wasted sleep after the last page.
            time.sleep(5)

    def delete_version(self, project_id, model_name, version_name):
        """
        Deletes the given version of a model. Blocks until finished.

        :return: The finished delete operation.
        :rtype: dict
        """
        full_name = 'projects/{}/models/{}/versions/{}'.format(
            project_id, model_name, version_name)
        delete_request = self._mlengine.projects().models().versions().delete(
            name=full_name)
        response = delete_request.execute()
        return self._wait_for_operation(response)

    def create_model(self, project_id, model):
        """
        Creates a Model. Blocks until finished.

        :param model: Model body; must contain a non-empty ``name``.
        :type model: dict
        :return: The API response for the create call.
        :raises ValueError: if ``model`` has no ``name`` or an empty one.
        """
        # .get() so a missing key raises the intended ValueError rather
        # than a bare KeyError.
        if not model.get('name'):
            raise ValueError("Model name must be provided and "
                             "could not be an empty string")
        project = 'projects/{}'.format(project_id)
        request = self._mlengine.projects().models().create(
            parent=project, body=model)
        return request.execute()

    def get_model(self, project_id, model_name):
        """
        Gets a Model. Blocks until finished.

        :return: The model object, or None if the server answered 404.
        :raises ValueError: if ``model_name`` is empty.
        :raises googleapiclient.errors.HttpError: for any non-404 HTTP error.
        """
        if not model_name:
            raise ValueError("Model name must be provided and "
                             "it could not be an empty string")
        full_model_name = 'projects/{}/models/{}'.format(
            project_id, model_name)
        request = self._mlengine.projects().models().get(name=full_model_name)
        try:
            return request.execute()
        except HttpError as e:
            if e.resp.status == 404:
                self.log.error('Model was not found: %s', e)
                return None
            raise