Source code for airflow.providers.google.cloud.hooks.mlengine
## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""This module contains a Google ML Engine Hook."""from__future__importannotationsimportcontextlibimportloggingimportrandomimporttimefromtypingimportTYPE_CHECKING,CallablefromaiohttpimportClientSessionfromgcloud.aio.authimportAioSession,Tokenfromgoogleapiclient.discoveryimportResource,buildfromgoogleapiclient.errorsimportHttpErrorfromairflow.exceptionsimportAirflowExceptionfromairflow.providers.google.common.hooks.base_googleimport(PROVIDE_PROJECT_ID,GoogleBaseAsyncHook,GoogleBaseHook,)fromairflow.versionimportversionasairflow_versionifTYPE_CHECKING:fromhttplib2importResponsefromrequestsimportSession
# Module-level logger used by the free function below (hook methods use ``self.log``).
log = logging.getLogger(__name__)

# Label-safe rendering of the Airflow version: "." and "+" are replaced with "-".
_AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")


def _poll_with_exponential_delay(
    request, execute_num_retries, max_n, is_done_func, is_error_func
) -> Response:
    """
    Execute request with exponential delay.

    This method is intended to handle and retry in case of api-specific errors,
    such as 429 "Too Many Requests", unlike the `request.execute` which handles
    lower level errors like `ConnectionError`/`socket.timeout`/`ssl.SSLError`.

    :param request: request to be executed.
    :param execute_num_retries: num_retries for `request.execute` method.
    :param max_n: number of times to retry request in this method.
    :param is_done_func: callable to determine if operation is done.
    :param is_error_func: callable to determine if operation is failed.
    :return: response
    :raises ValueError: if ``is_error_func`` flags the response, or if the operation is
        still not done after ``max_n`` attempts.
    :raises googleapiclient.errors.HttpError: for any HTTP error other than 429.
    """
    for attempt in range(max_n):
        try:
            response = request.execute(num_retries=execute_num_retries)
        except HttpError as e:
            if e.resp.status != 429:
                log.info("Something went wrong. Not retrying: %s", e)
                raise
            # 429: fall through to the back-off below and try again.
        else:
            if is_error_func(response):
                raise ValueError(f"The response contained an error: {response}")
            if is_done_func(response):
                log.info("Operation is done: %s", response)
                return response

        # Back off before the next attempt. Sleeping after the final attempt would only
        # postpone the failure below, so it is skipped.
        if attempt < max_n - 1:
            time.sleep((2**attempt) + random.random())

    raise ValueError(f"Connection could not be established after {max_n} retries.")
class MLEngineHook(GoogleBaseHook):
    """
    Hook for Google ML Engine APIs.

    All the methods in the hook where project_id is used must be called with
    keyword arguments rather than positional.
    """

    def get_conn(self) -> Resource:
        """
        Retrieve the connection to MLEngine.

        :return: Google MLEngine services object.
        """
        authed_http = self._authorize()
        return build("ml", "v1", http=authed_http, cache_discovery=False)

    @GoogleBaseHook.fallback_to_default_project_id
    def create_job(self, job: dict, project_id: str, use_existing_job_fn: Callable | None = None) -> dict:
        """
        Launch a MLEngine job and wait for it to reach a terminal state.

        :param project_id: The Google Cloud project id within which MLEngine
            job will be launched. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :param job: MLEngine Job object that should be provided to the MLEngine
            API, such as: ::

                {
                  'jobId': 'my_job_id',
                  'trainingInput': {
                    'scaleTier': 'STANDARD_1',
                    ...
                  }
                }

        :param use_existing_job_fn: In case that a MLEngine job with the same
            job_id already exist, this method (if provided) will decide whether
            we should use this existing job, continue waiting for it to finish
            and returning the job object. It should accepts a MLEngine job
            object, and returns a boolean value indicating whether it is OK to
            reuse the existing job. If 'use_existing_job_fn' is not provided,
            we by default reuse the existing MLEngine job.
        :return: The MLEngine job object if the job successfully reach a
            terminal state (which might be FAILED or CANCELLED state).
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        # The job dict is mutated in place: an "airflow-version" label is added.
        self._append_label(job)
        self.log.info("Creating job.")

        request = hook.projects().jobs().create(parent=f"projects/{project_id}", body=job)
        job_id = job["jobId"]

        try:
            request.execute(num_retries=self.num_retries)
        except HttpError as e:
            # Anything other than 409 is a genuine creation failure.
            if e.resp.status != 409:
                self.log.error("Failed to create MLEngine job: %s", e)
                raise

            # 409 means there is an existing job with the same job ID.
            if use_existing_job_fn is not None:
                existing_job = self.get_job(project_id, job_id)
                if not use_existing_job_fn(existing_job):
                    self.log.error(
                        "Job with job_id %s already exist, but it does not match our expectation: %s",
                        job_id,
                        existing_job,
                    )
                    raise
            self.log.info("Job with job_id %s already exist. Will waiting for it to finish", job_id)

        return self._wait_for_job_done(project_id, job_id)

    @GoogleBaseHook.fallback_to_default_project_id
    def create_job_without_waiting_result(
        self,
        body: dict,
        project_id: str,
    ):
        """
        Launch a MLEngine job and return immediately, without waiting for it to finish.

        :param project_id: The Google Cloud project id within which MLEngine
            job will be launched. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :param body: MLEngine Job object that should be provided to the MLEngine
            API, such as: ::

                {
                  'jobId': 'my_job_id',
                  'trainingInput': {
                    'scaleTier': 'STANDARD_1',
                    ...
                  }
                }

        :return: The ``jobId`` taken from ``body``, once the create request has been executed.
        """
        hook = self.get_conn()

        # The body dict is mutated in place: an "airflow-version" label is added.
        self._append_label(body)
        request = hook.projects().jobs().create(parent=f"projects/{project_id}", body=body)
        job_id = body["jobId"]
        request.execute(num_retries=self.num_retries)
        return job_id

    @GoogleBaseHook.fallback_to_default_project_id
    def cancel_job(
        self,
        job_id: str,
        project_id: str,
    ) -> dict:
        """
        Cancel a MLEngine job.

        :param project_id: The Google Cloud project id within which MLEngine
            job will be cancelled. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :param job_id: A unique id for the want-to-be cancelled Google MLEngine training job.

        :return: Empty dict if cancelled successfully
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        request = hook.projects().jobs().cancel(name=f"projects/{project_id}/jobs/{job_id}")

        try:
            return request.execute(num_retries=self.num_retries)
        except HttpError as e:
            status = e.resp.status
            if status == 400:
                # A 400 is treated as "nothing left to cancel" rather than as a failure.
                self.log.info("Job with job_id %s is already complete, cancellation aborted.", job_id)
                return {}
            if status == 404:
                self.log.error("Job with job_id %s does not exist. ", job_id)
            else:
                self.log.error("Failed to cancel MLEngine job: %s", e)
            raise

    def get_job(self, project_id: str, job_id: str) -> dict:
        """
        Get a MLEngine job based on the job id.

        A 429 response is retried every 30 seconds without an upper bound on attempts.

        :param project_id: The project in which the Job is located. This method is not wrapped
            in the default-project fallback, so a value must be passed explicitly. (templated)
        :param job_id: A unique id for the Google MLEngine job. (templated)
        :return: MLEngine job object if succeed.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        job_name = f"projects/{project_id}/jobs/{job_id}"
        request = hook.projects().jobs().get(name=job_name)
        while True:
            try:
                return request.execute(num_retries=self.num_retries)
            except HttpError as e:
                if e.resp.status != 429:
                    self.log.error("Failed to get MLEngine job: %s", e)
                    raise
                # polling after 30 seconds when quota failure occurs
                time.sleep(30)

    def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30) -> dict:
        """
        Wait for the Job to reach a terminal state.

        This method will periodically check the job state until the job reach
        a terminal state.

        :param project_id: The project in which the Job is located. Passed straight through to
            :meth:`get_job`; no default-project fallback is applied here. (templated)
        :param job_id: A unique id for the Google MLEngine job. (templated)
        :param interval: Time expressed in seconds after which the job status is checked again. (templated)
        :return: The job object once its state is SUCCEEDED, FAILED or CANCELLED.
        :raises ValueError: if ``interval`` is not positive.
        :raises: googleapiclient.errors.HttpError
        """
        self.log.info("Waiting for job. job_id=%s", job_id)

        if interval <= 0:
            raise ValueError("Interval must be > 0")
        while True:
            job = self.get_job(project_id, job_id)
            if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]:
                return job
            time.sleep(interval)

    @GoogleBaseHook.fallback_to_default_project_id
    def create_version(
        self,
        model_name: str,
        version_spec: dict,
        project_id: str,
    ) -> dict:
        """
        Create the Version on Google Cloud ML Engine.

        :param version_spec: A dictionary containing the information about the version. (templated)
        :param model_name: The name of the Google Cloud ML Engine model that the version belongs to.
            (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs.
            If set to None or missing, the default project_id from the Google Cloud connection is used.
            (templated)
        :return: If the version was created successfully, returns the operation.
            Otherwise raises an error .
        """
        hook = self.get_conn()
        parent_name = f"projects/{project_id}/models/{model_name}"

        # The version_spec dict is mutated in place: an "airflow-version" label is added.
        self._append_label(version_spec)

        create_request = hook.projects().models().versions().create(parent=parent_name, body=version_spec)
        response = create_request.execute(num_retries=self.num_retries)

        # Creation is a long-running operation: poll it until it reports done or error.
        get_request = hook.projects().operations().get(name=response["name"])

        return _poll_with_exponential_delay(
            request=get_request,
            execute_num_retries=self.num_retries,
            max_n=9,
            is_done_func=lambda resp: resp.get("done", False),
            is_error_func=lambda resp: resp.get("error", None) is not None,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def set_default_version(
        self,
        model_name: str,
        version_name: str,
        project_id: str,
    ) -> dict:
        """
        Set a version to be the default. Blocks until finished.

        :param model_name: The name of the Google Cloud ML Engine model that the version belongs to.
            (templated)
        :param version_name: A name to use for the version being operated upon. (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If successful, return an instance of Version.
            Otherwise raises an error.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        full_version_name = f"projects/{project_id}/models/{model_name}/versions/{version_name}"

        request = hook.projects().models().versions().setDefault(name=full_version_name, body={})

        try:
            response = request.execute(num_retries=self.num_retries)
        except HttpError as e:
            self.log.error("Something went wrong: %s", e)
            raise

        self.log.info("Successfully set version: %s to default", response)
        return response

    @GoogleBaseHook.fallback_to_default_project_id
    def list_versions(
        self,
        model_name: str,
        project_id: str,
    ) -> list[dict]:
        """
        List all available versions of a model. Blocks until finished.

        :param model_name: The name of the Google Cloud ML Engine model that the version
            belongs to. (templated)
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or
            missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: return an list of instance of Version.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        result: list[dict] = []
        full_parent_name = f"projects/{project_id}/models/{model_name}"

        request = hook.projects().models().versions().list(parent=full_parent_name, pageSize=100)

        while request is not None:
            response = request.execute(num_retries=self.num_retries)
            result.extend(response.get("versions", []))

            # list_next yields None once there are no further pages, which ends the loop.
            request = (
                hook.projects()
                .models()
                .versions()
                .list_next(previous_request=request, previous_response=response)
            )
            if request is not None:
                # Pause between page fetches only; after the last page there is nothing to wait for.
                time.sleep(5)
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_version(
        self,
        model_name: str,
        version_name: str,
        project_id: str,
    ) -> dict:
        """
        Delete the given version of a model. Blocks until finished.

        :param model_name: The name of the Google Cloud ML Engine model that the version
            belongs to. (templated)
        :param project_id: The Google Cloud project name to which MLEngine
            model belongs.
        :param version_name: A name to use for the version being operated upon. (templated)
        :return: If the version was deleted successfully, returns the operation.
            Otherwise raises an error.
        """
        hook = self.get_conn()
        full_name = f"projects/{project_id}/models/{model_name}/versions/{version_name}"
        delete_request = hook.projects().models().versions().delete(name=full_name)
        response = delete_request.execute(num_retries=self.num_retries)

        # Deletion is a long-running operation: poll it until it reports done or error.
        get_request = hook.projects().operations().get(name=response["name"])

        return _poll_with_exponential_delay(
            request=get_request,
            execute_num_retries=self.num_retries,
            max_n=9,
            is_done_func=lambda resp: resp.get("done", False),
            is_error_func=lambda resp: resp.get("error", None) is not None,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def create_model(
        self,
        model: dict,
        project_id: str,
    ) -> dict:
        """
        Create a Model. Blocks until finished.

        If creation fails with a 409 whose single error detail says that a model with the same
        name already exists, the existing model is fetched and returned instead.

        :param model: A dictionary containing the information about the model.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or
            missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If the version was created successfully, returns the instance of Model.
            Otherwise raises an error.
        :raises ValueError: if ``model`` has no non-empty ``"name"``.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        if not model.get("name"):
            raise ValueError("Model name must be provided and could not be an empty string")
        project = f"projects/{project_id}"

        # The model dict is mutated in place: an "airflow-version" label is added.
        self._append_label(model)
        try:
            request = hook.projects().models().create(parent=project, body=model)
            response = request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status != 409:
                raise
            str(e)  # Fills in the error_details field
            if not e.error_details or len(e.error_details) != 1:
                raise

            error_detail = e.error_details[0]
            if error_detail["@type"] != "type.googleapis.com/google.rpc.BadRequest":
                raise

            if "fieldViolations" not in error_detail or len(error_detail["fieldViolations"]) != 1:
                raise

            field_violation = error_detail["fieldViolations"][0]
            is_duplicate_name = (
                field_violation["field"] == "model.name"
                and field_violation["description"] == "A model with the same name already exists."
            )
            if not is_duplicate_name:
                raise

            # The only tolerated conflict: the model already exists, so hand back the existing one.
            response = self.get_model(model_name=model["name"], project_id=project_id)

        return response

    @GoogleBaseHook.fallback_to_default_project_id
    def get_model(
        self,
        model_name: str,
        project_id: str,
    ) -> dict | None:
        """
        Get a Model. Blocks until finished.

        :param model_name: The name of the model.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :return: If the model exists, returns the instance of Model.
            Otherwise return None.
        :raises ValueError: if ``model_name`` is empty.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()
        if not model_name:
            raise ValueError("Model name must be provided and it could not be an empty string")
        full_model_name = f"projects/{project_id}/models/{model_name}"
        request = hook.projects().models().get(name=full_model_name)
        try:
            return request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status != 404:
                raise
            self.log.error("Model was not found: %s", e)
            return None

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_model(
        self,
        model_name: str,
        project_id: str,
        delete_contents: bool = False,
    ) -> None:
        """
        Delete a Model. Blocks until finished.

        A 404 from the delete call is logged and swallowed; any other HTTP error propagates.

        :param model_name: The name of the model.
        :param delete_contents: Whether to force the deletion even if the models is not empty.
            Will delete all version (if any) in the dataset if set to True.
            The default value is False.
        :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None
            or missing, the default project_id from the Google Cloud connection is used. (templated)
        :raises ValueError: if ``model_name`` is empty.
        :raises: googleapiclient.errors.HttpError
        """
        hook = self.get_conn()

        if not model_name:
            raise ValueError("Model name must be provided and it could not be an empty string")
        model_path = f"projects/{project_id}/models/{model_name}"
        if delete_contents:
            self._delete_all_versions(model_name, project_id)
        request = hook.projects().models().delete(name=model_path)
        try:
            request.execute(num_retries=self.num_retries)
        except HttpError as e:
            if e.resp.status != 404:
                raise
            self.log.error("Model was not found: %s", e)

    def _delete_all_versions(self, model_name: str, project_id: str) -> None:
        """
        Delete every version of the given model, leaving the default version for last.

        :param model_name: The name of the model whose versions are deleted.
        :param project_id: The Google Cloud project name to which the model belongs.
        """
        versions = self.list_versions(project_id=project_id, model_name=model_name)
        # The default version can only be deleted when it is the last one in the model.
        # sorted() is stable and False < True, so non-default versions keep their listed
        # order and come first, followed by the default one(s).
        ordered = sorted(versions, key=lambda version: bool(version.get("isDefault", False)))
        for version in ordered:
            # version["name"] is a full resource path; only its last segment is the version name.
            _, _, version_name = version["name"].rpartition("/")
            self.delete_version(project_id=project_id, model_name=model_name, version_name=version_name)

    def _append_label(self, model: dict) -> None:
        """
        Add (or overwrite) the ``airflow-version`` label on the given resource dict, in place.

        :param model: Resource body (job, model or version spec) to be labelled.
        """
        model["labels"] = model.get("labels", {})
        model["labels"]["airflow-version"] = _AIRFLOW_VERSION
class MLEngineAsyncHook(GoogleBaseAsyncHook):
    """Class to get asynchronous hook for MLEngine."""

    def _check_fileds(
        self,
        job_id: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ):
        """
        Validate that both identifiers needed to address a job are present.

        :param job_id: A unique id for the Google MLEngine job.
        :param project_id: The Google Cloud project id the job belongs to.
        :raises AirflowException: if ``project_id`` or ``job_id`` is empty.
        """
        if not project_id:
            raise AirflowException("Google Cloud project id is required.")
        if not job_id:
            raise AirflowException("An unique job id is required for Google MLEngine training job.")

    async def _get_link(self, url: str, session: Session):
        """
        Issue an authenticated GET against ``url`` and return the response.

        :param url: Fully qualified URL to request.
        :param session: Session wrapped in ``AioSession`` to perform the request.
        :return: The response object, or None when an ``AirflowException`` was suppressed.
        """
        # NOTE(review): ``self.scopes`` is not defined in the part of this class that is
        # visible here; presumably it comes from a class attribute or the base class - confirm.
        async with Token(scopes=self.scopes) as token:
            session_aio = AioSession(session)
            headers = {
                "Authorization": f"Bearer {await token.get()}",
            }
            # Pre-bind the result so a suppressed exception yields None rather than an
            # UnboundLocalError at the return statement.
            job = None
            with contextlib.suppress(AirflowException):
                # suppress AirflowException because we don't want to raise exception
                job = await session_aio.get(url=url, headers=headers)

            return job

    async def get_job(self, job_id: str, session: Session, project_id: str = PROVIDE_PROJECT_ID):
        """
        Get the specified job resource by job ID and project ID.

        :param job_id: A unique id for the Google MLEngine job.
        :param session: Session used to perform the HTTP request.
        :param project_id: The Google Cloud project id the job belongs to.
        :return: Whatever :meth:`_get_link` returns for the job URL.
        :raises AirflowException: if ``project_id`` or ``job_id`` is empty.
        """
        self._check_fileds(project_id=project_id, job_id=job_id)

        url = f"https://ml.googleapis.com/v1/projects/{project_id}/jobs/{job_id}"
        return await self._get_link(url=url, session=session)

    async def get_job_status(
        self,
        job_id: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str | None:
        """
        Poll for job status asynchronously using gcloud-aio.

        :param job_id: A unique id for the Google MLEngine job.
        :param project_id: The Google Cloud project id the job belongs to.
        :return: ``"success"`` when the job state is SUCCEEDED, FAILED or CANCELLED;
            ``"pending"`` for any other state or when an ``OSError`` occurs while querying;
            otherwise the string form of the exception that was raised.
        :raises AirflowException: if ``project_id`` or ``job_id`` is empty.
        """
        self._check_fileds(project_id=project_id, job_id=job_id)
        async with ClientSession() as session:
            try:
                job = await self.get_job(
                    project_id=project_id,
                    job_id=job_id,
                    session=session,  # type: ignore
                )
                job = await job.json(content_type=None)
                self.log.info("Retrieving json_response: %s", job)

                if job["state"] in ["SUCCEEDED", "FAILED", "CANCELLED"]:
                    job_status = "success"
                else:
                    # PREPARING / RUNNING and every other non-terminal state (e.g. QUEUED,
                    # CANCELLING) mean the job is still in flight. Without this catch-all,
                    # job_status stayed unbound and the return below raised UnboundLocalError.
                    job_status = "pending"
            except OSError:
                job_status = "pending"
            except Exception as e:
                self.log.info("Query execution finished with errors...")
                job_status = str(e)
            return job_status