Source code for airflow.providers.google.cloud.hooks.mlengine

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains a Google ML Engine Hook."""
from __future__ import annotations

import logging
import random
import time
from typing import Callable, Dict, List

from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError

from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.version import version as airflow_version

log = logging.getLogger(__name__)

# Airflow version with '.' and '+' replaced by '-' and prefixed with 'v'.
# Used as the value of the 'airflow-version' label attached to request bodies.
# NOTE(review): the character replacement presumably exists to satisfy Google
# Cloud label-value restrictions -- confirm against the labels documentation.
_AIRFLOW_VERSION = 'v' + airflow_version.replace('.', '-').replace('+', '-')


def _poll_with_exponential_delay(request, execute_num_retries, max_n, is_done_func, is_error_func):
    """
    Execute request with exponential delay.

    This method is intended to handle and retry in case of api-specific errors,
    such as 429 "Too Many Requests", unlike the `request.execute` which handles
    lower level errors like `ConnectionError`/`socket.timeout`/`ssl.SSLError`.

    The request is executed up to ``max_n`` times. After attempt ``i`` (0-based)
    that neither finished nor failed, the function sleeps ``2**i`` seconds plus
    up to one second of random jitter before trying again. No sleep is performed
    after the final attempt, since no further request would follow it.

    :param request: request to be executed.
    :param execute_num_retries: num_retries for `request.execute` method.
    :param max_n: maximum number of times the request is executed by this method.
    :param is_done_func: callable taking the response and returning a truthy value
        when the operation is done.
    :param is_error_func: callable taking the response and returning a truthy value
        when the operation has failed.
    :return: the first response for which ``is_done_func`` is truthy, exactly as
        returned by ``request.execute``.
    :raises ValueError: if ``is_error_func`` flags a response, or if the operation
        is still not done after ``max_n`` attempts.
    :raises googleapiclient.errors.HttpError: for any HTTP error whose status is not 429.
    """
    for i in range(max_n):
        try:
            response = request.execute(num_retries=execute_num_retries)
            if is_error_func(response):
                # Not an HttpError, so it propagates straight to the caller.
                raise ValueError(f'The response contained an error: {response}')
            if is_done_func(response):
                log.info('Operation is done: %s', response)
                return response
        except HttpError as e:
            if e.resp.status != 429:
                log.info('Something went wrong. Not retrying: %s', e)
                raise
            # 429 "Too Many Requests": fall through to the back-off below.

        # Reached when the operation is still pending or the API throttled us.
        # Skip the back-off after the last attempt: nothing else will be sent.
        if i < max_n - 1:
            time.sleep((2**i) + (random.randint(0, 1000) / 1000))

    raise ValueError(f'Connection could not be established after {max_n} retries.')
class MLEngineHook(GoogleBaseHook):
    """
    Hook for Google ML Engine APIs.

    All the methods in the hook where project_id is used must be called with
    keyword arguments rather than positional.
    """

    def get_conn(self) -> Resource:
        """
        Retrieves the connection to MLEngine.

        :return: Google MLEngine services object.
        """
        authed_http = self._authorize()
        return build('ml', 'v1', http=authed_http, cache_discovery=False)

    @GoogleBaseHook.fallback_to_default_project_id
    def create_job(self, job: dict, project_id: str, use_existing_job_fn: Callable | None = None) -> dict:
        """
        Launches a MLEngine job and wait for it to reach a terminal state.

        The ``job`` dict is modified in place: an ``airflow-version`` label is added to it.

        :param project_id: The Google Cloud project id within which MLEngine
            job will be launched. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :param job: MLEngine Job object that should be provided to the MLEngine
            API, such as: ::

                {
                  'jobId': 'my_job_id',
                  'trainingInput': {
                    'scaleTier': 'STANDARD_1',
                    ...
                  }
                }

        :param use_existing_job_fn: In case that a MLEngine job with the same
            job_id already exist, this method (if provided) will decide whether
            we should use this existing job, continue waiting for it to finish
            and returning the job object. It should accepts a MLEngine job
            object, and returns a boolean value indicating whether it is OK to
            reuse the existing job. If 'use_existing_job_fn' is not provided,
            we by default reuse the existing MLEngine job.
        :return: The MLEngine job object if the job successfully reach a
            terminal state (which might be FAILED or CANCELLED state).
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        self._append_label(job)
        self.log.info("Creating job.")

        request = hook.projects().jobs().create(parent=f'projects/{project_id}', body=job)
        job_id = job['jobId']

        try:
            request.execute(num_retries=self.num_retries)
        except HttpError as e:
            # 409 means there is an existing job with the same job ID.
            if e.resp.status == 409:
                if use_existing_job_fn is not None:
                    existing_job = self._get_job(project_id, job_id)
                    if not use_existing_job_fn(existing_job):
                        self.log.error(
                            'Job with job_id %s already exist, but it does not match our expectation: %s',
                            job_id,
                            existing_job,
                        )
                        raise
                self.log.info('Job with job_id %s already exist. Will waiting for it to finish', job_id)
            else:
                self.log.error('Failed to create MLEngine job: %s', e)
                raise

        return self._wait_for_job_done(project_id, job_id)

    @GoogleBaseHook.fallback_to_default_project_id
    def cancel_job(
        self,
        job_id: str,
        project_id: str,
    ) -> dict:
        """
        Cancels a MLEngine job.

        :param project_id: The Google Cloud project id within which MLEngine
            job will be cancelled. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :param job_id: A unique id for the want-to-be cancelled Google MLEngine training job.
        :return: The API response if the job was cancelled; an empty dict if the API
            answered with HTTP 400 (treated as "job already complete").
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        request = hook.projects().jobs().cancel(name=f'projects/{project_id}/jobs/{job_id}')

        try:
            return request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status == 404:
                self.log.error('Job with job_id %s does not exist. ', job_id)
                raise
            elif e.resp.status == 400:
                self.log.info('Job with job_id %s is already complete, cancellation aborted.', job_id)
                return {}
            else:
                self.log.error('Failed to cancel MLEngine job: %s', e)
                raise

    def _get_job(self, project_id: str, job_id: str) -> dict:
        """
        Gets a MLEngine job based on the job id.

        HTTP 429 (quota) responses are retried every 30 seconds without an upper bound;
        any other HTTP error is logged and re-raised.

        :param project_id: The project in which the Job is located. This helper does not
            fall back to the connection's default project, so a real id must be passed.
        :param job_id: A unique id for the Google MLEngine job.
        :return: MLEngine job object if succeed.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        job_name = f'projects/{project_id}/jobs/{job_id}'
        request = hook.projects().jobs().get(name=job_name)
        while True:
            try:
                return request.execute(num_retries=self.num_retries)
            except HttpError as e:
                if e.resp.status == 429:
                    # polling after 30 seconds when quota failure occurs
                    time.sleep(30)
                else:
                    self.log.error('Failed to get MLEngine job: %s', e)
                    raise

    def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30) -> dict:
        """
        Waits for the Job to reach a terminal state.

        This method will periodically check the job state until the job reach
        a terminal state (``SUCCEEDED``, ``FAILED`` or ``CANCELLED``).

        :param project_id: The project in which the Job is located. This helper does not
            fall back to the connection's default project, so a real id must be passed.
        :param job_id: A unique id for the Google MLEngine job.
        :param interval: Time expressed in seconds after which the job status is checked again.
        :return: The MLEngine job object as last fetched, in its terminal state.
        :raises ValueError: if ``interval`` is not strictly positive.
        :raises: googleapiclient.errors.HttpError
        """
        self.log.info("Waiting for job. job_id=%s", job_id)

        if interval <= 0:
            raise ValueError("Interval must be > 0")
        while True:
            job = self._get_job(project_id, job_id)
            if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
                return job
            time.sleep(interval)

    @GoogleBaseHook.fallback_to_default_project_id
    def create_version(
        self,
        model_name: str,
        version_spec: dict,
        project_id: str,
    ) -> dict:
        """
        Creates the Version on Google Cloud ML Engine.

        The ``version_spec`` dict is modified in place: an ``airflow-version`` label is added to it.

        :param version_spec: A dictionary containing the information about the version. (templated)
        :param model_name: The name of the Google Cloud ML Engine model that the version belongs to.
            (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs.
            If set to None or missing, the default project_id from the Google Cloud connection is used.
            (templated)
        :return: If the version was created successfully, returns the operation.
            Otherwise raises an error.
        """
        hook = self.get_conn()
        parent_name = f'projects/{project_id}/models/{model_name}'

        self._append_label(version_spec)

        create_request = hook.projects().models().versions().create(parent=parent_name, body=version_spec)
        response = create_request.execute(num_retries=self.num_retries)
        # The create call returns a long-running operation; poll it until it is done or has failed.
        get_request = hook.projects().operations().get(name=response['name'])

        return _poll_with_exponential_delay(
            request=get_request,
            execute_num_retries=self.num_retries,
            max_n=9,
            is_done_func=lambda resp: resp.get('done', False),
            is_error_func=lambda resp: resp.get('error', None) is not None,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def set_default_version(
        self,
        model_name: str,
        version_name: str,
        project_id: str,
    ) -> dict:
        """
        Sets a version to be the default. Blocks until finished.

        :param model_name: The name of the Google Cloud ML Engine model that the version belongs to.
            (templated)
        :param version_name: A name to use for the version being operated upon. (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If successful, return an instance of Version.
            Otherwise raises an error.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        full_version_name = f'projects/{project_id}/models/{model_name}/versions/{version_name}'

        request = hook.projects().models().versions().setDefault(name=full_version_name, body={})

        try:
            response = request.execute(num_retries=self.num_retries)
            self.log.info('Successfully set version: %s to default', response)
            return response
        except HttpError as e:
            self.log.error('Something went wrong: %s', e)
            raise

    @GoogleBaseHook.fallback_to_default_project_id
    def list_versions(
        self,
        model_name: str,
        project_id: str,
    ) -> list[dict]:
        """
        Lists all available versions of a model. Blocks until finished.

        Results are fetched page by page (100 per page) with a 5 second pause
        between consecutive page requests.

        :param model_name: The name of the Google Cloud ML Engine model that the version
            belongs to. (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or
            missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: return an list of instance of Version.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        result: list[dict] = []
        full_parent_name = f'projects/{project_id}/models/{model_name}'

        request = hook.projects().models().versions().list(parent=full_parent_name, pageSize=100)

        while request is not None:
            response = request.execute(num_retries=self.num_retries)
            result.extend(response.get('versions', []))

            request = (
                hook.projects()
                .models()
                .versions()
                .list_next(previous_request=request, previous_response=response)
            )
            # Pause only when another page will actually be requested; sleeping
            # after the final page would just delay the caller.
            if request is not None:
                time.sleep(5)
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_version(
        self,
        model_name: str,
        version_name: str,
        project_id: str,
    ) -> dict:
        """
        Deletes the given version of a model. Blocks until finished.

        :param model_name: The name of the Google Cloud ML Engine model that the version
            belongs to. (templated)
        :param version_name: The name of the version to delete.
        :param project_id: The Google Cloud project name to which MLEngine
            model belongs.
        :return: If the version was deleted successfully, returns the operation.
            Otherwise raises an error.
        """
        hook = self.get_conn()
        full_name = f'projects/{project_id}/models/{model_name}/versions/{version_name}'
        delete_request = hook.projects().models().versions().delete(name=full_name)
        response = delete_request.execute(num_retries=self.num_retries)
        # The delete call returns a long-running operation; poll it until it is done or has failed.
        get_request = hook.projects().operations().get(name=response['name'])

        return _poll_with_exponential_delay(
            request=get_request,
            execute_num_retries=self.num_retries,
            max_n=9,
            is_done_func=lambda resp: resp.get('done', False),
            is_error_func=lambda resp: resp.get('error', None) is not None,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def create_model(
        self,
        model: dict,
        project_id: str,
    ) -> dict:
        """
        Create a Model. Blocks until finished.

        The ``model`` dict is modified in place: an ``airflow-version`` label is added to it.
        If the API reports (HTTP 409) that a model with the same name already exists,
        the existing model is fetched and returned instead of raising.

        :param model: A dictionary containing the information about the model.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or
            missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If the model was created successfully (or already existed), returns the
            instance of Model. Otherwise raises an error.
        :raises ValueError: if the model has no ``name`` or the name is empty.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        if 'name' not in model or not model['name']:
            raise ValueError("Model name must be provided and could not be an empty string")
        project = f'projects/{project_id}'

        self._append_label(model)
        try:
            request = hook.projects().models().create(parent=project, body=model)
            response = request.execute(num_retries=self.num_retries)
        except HttpError as e:
            # Only one very specific 409 is tolerated: a single BadRequest detail with a
            # single field violation saying the model name is taken. Anything else is re-raised.
            if e.resp.status != 409:
                raise
            str(e)  # Fills in the error_details field
            if not e.error_details or len(e.error_details) != 1:
                raise

            error_detail = e.error_details[0]
            if error_detail["@type"] != 'type.googleapis.com/google.rpc.BadRequest':
                raise

            if "fieldViolations" not in error_detail or len(error_detail['fieldViolations']) != 1:
                raise

            field_violation = error_detail['fieldViolations'][0]
            if (
                field_violation["field"] != "model.name"
                or field_violation["description"] != "A model with the same name already exists."
            ):
                raise
            response = self.get_model(model_name=model['name'], project_id=project_id)

        return response

    @GoogleBaseHook.fallback_to_default_project_id
    def get_model(
        self,
        model_name: str,
        project_id: str,
    ) -> dict | None:
        """
        Gets a Model. Blocks until finished.

        :param model_name: The name of the model.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If the model exists, returns the instance of Model.
            Otherwise (HTTP 404) return None.
        :raises ValueError: if ``model_name`` is empty.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        if not model_name:
            raise ValueError("Model name must be provided and it could not be an empty string")
        full_model_name = f'projects/{project_id}/models/{model_name}'
        request = hook.projects().models().get(name=full_model_name)
        try:
            return request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status == 404:
                self.log.error('Model was not found: %s', e)
                return None
            raise

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_model(
        self,
        model_name: str,
        project_id: str,
        delete_contents: bool = False,
    ) -> None:
        """
        Delete a Model. Blocks until finished.

        A missing model (HTTP 404) is logged and otherwise ignored.

        :param model_name: The name of the model.
        :param delete_contents: Whether to force the deletion even if the models is not empty.
            Will delete all version (if any) in the dataset if set to True.
            The default value is False.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :raises ValueError: if ``model_name`` is empty.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        if not model_name:
            raise ValueError("Model name must be provided and it could not be an empty string")
        model_path = f'projects/{project_id}/models/{model_name}'
        if delete_contents:
            self._delete_all_versions(model_name, project_id)
        request = hook.projects().models().delete(name=model_path)
        try:
            request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status == 404:
                self.log.error('Model was not found: %s', e)
                return
            raise

    def _delete_all_versions(self, model_name: str, project_id: str):
        """
        Deletes every version of the given model, default version(s) last.

        :param model_name: The name of the model whose versions are deleted.
        :param project_id: The Google Cloud project to which the model belongs.
        """
        versions = self.list_versions(project_id=project_id, model_name=model_name)
        # The default version can only be deleted when it is the last one in the model.
        # A stable sort on the boolean flag moves default versions to the end while
        # keeping the listing order inside each group.
        ordered = sorted(versions, key=lambda version: bool(version.get('isDefault', False)))
        for version in ordered:
            # The version name is the last path segment of the full resource name.
            _, _, version_name = version['name'].rpartition('/')
            self.delete_version(project_id=project_id, model_name=model_name, version_name=version_name)

    def _append_label(self, model: dict) -> None:
        """
        Adds the ``airflow-version`` label to a request body, in place.

        A missing ``labels`` entry, or one explicitly set to ``None``, is replaced by
        a new dict; an existing labels dict is reused and extended.

        :param model: The request body (job, version or model dict) to label.
        """
        labels = model.get('labels')
        if labels is None:
            labels = model['labels'] = {}
        labels['airflow-version'] = _AIRFLOW_VERSION

Was this entry helpful?