Source code for airflow.providers.google.cloud.hooks.vertex_ai.model_service
## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""This module contains a Google Cloud Vertex AI hook."""from__future__importannotationsfromtypingimportTYPE_CHECKING,Sequencefromgoogle.api_core.client_optionsimportClientOptionsfromgoogle.api_core.gapic_v1.methodimportDEFAULT,_MethodDefaultfromgoogle.cloud.aiplatform_v1importModelServiceClientfromairflow.exceptionsimportAirflowExceptionfromairflow.providers.google.common.hooks.base_googleimportGoogleBaseHookifTYPE_CHECKING:fromgoogle.api_core.operationimportOperationfromgoogle.api_core.retryimportRetryfromgoogle.cloud.aiplatform_v1.services.model_service.pagersimport(ListModelsPager,ListModelVersionsPager,)fromgoogle.cloud.aiplatform_v1.typesimportModel,model_service
class ModelServiceHook(GoogleBaseHook):
    """Hook for Google Cloud Vertex AI Model Service APIs."""

    def __init__(self, **kwargs) -> None:
        # ``delegate_to`` is no longer supported by the Google provider; fail loudly so
        # callers migrate to ``impersonate_chain`` instead of having the value ignored.
        if kwargs.get("delegate_to") is not None:
            raise RuntimeError(
                "The `delegate_to` parameter has been deprecated before and finally removed in this version"
                " of Google Provider. You MUST convert it to `impersonate_chain`"
            )
        super().__init__(**kwargs)

    @staticmethod
    def extract_model_id(obj: dict) -> str:
        """
        Return unique id of the model.

        :param obj: Mapping whose ``"model"`` key holds the model resource name.
        :return: The last ``/``-separated segment of ``obj["model"]`` (the whole string when it
            contains no ``/``).
        """
        return obj["model"].rpartition("/")[-1]

    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
        """
        Wait for a long-lasting operation to complete.

        :param operation: The long-running operation to wait on.
        :param timeout: The timeout used both for waiting on the result and, on failure,
            for retrieving the operation's exception.
        :return: Whatever ``operation.result()`` returns.
        :raises AirflowException: If waiting for the result raises. The operation's reported
            exception is used as the message and the original error is chained as the cause.
        """
        try:
            return operation.result(timeout=timeout)
        except Exception as exc:
            error = operation.exception(timeout=timeout)
            # Chain the original error so its traceback is not lost.
            raise AirflowException(error) from exc

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_model(
        self,
        project_id: str,
        region: str,
        model: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Operation:
        """
        Delete a Model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model: Required. The name of the Model resource to be deleted.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The operation returned by the client's ``delete_model`` call.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model)
        return client.delete_model(
            request={"name": name},
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def export_model(
        self,
        project_id: str,
        region: str,
        model: str,
        output_config: model_service.ExportModelRequest.OutputConfig | dict,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Operation:
        """
        Export a trained, exportable Model to a location specified by the user.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model: Required. The resource name of the Model to export.
        :param output_config: Required. The desired output location and configuration.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The operation returned by the client's ``export_model`` call.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model)
        return client.export_model(
            request={
                "name": name,
                "output_config": output_config,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def list_models(
        self,
        project_id: str,
        region: str,
        filter: str | None = None,
        page_size: int | None = None,
        page_token: str | None = None,
        read_mask: str | None = None,
        order_by: str | None = None,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> ListModelsPager:
        r"""
        List Models in a Location.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param filter: An expression for filtering the results of the request. For field names both
            snake_case and camelCase are supported.

            - ``model`` supports = and !=. ``model`` represents the Model ID, i.e. the last segment of
              the Model's [resource name][google.cloud.aiplatform.v1.Model.name].
            - ``display_name`` supports = and !=
            - ``labels`` supports general map functions, that is:

              - ``labels.key=value`` - key:value equality
              - ``labels.key:*`` or ``labels:key`` - key existence
              - A key including a space must be quoted, e.g. ``labels."a key"``.

        :param page_size: The standard list page size.
        :param page_token: The standard list page token. Typically obtained via
            [ListModelsResponse.next_page_token][google.cloud.aiplatform.v1.ListModelsResponse.next_page_token]
            of the previous [ModelService.ListModels][google.cloud.aiplatform.v1.ModelService.ListModels]
            call.
        :param read_mask: Mask specifying which fields to read.
        :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use
            "desc" after a field name for descending.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The pager returned by the client's ``list_models`` call.
        """
        client = self.get_model_service_client(region)
        parent = client.common_location_path(project_id, region)
        return client.list_models(
            request={
                "parent": parent,
                "filter": filter,
                "page_size": page_size,
                "page_token": page_token,
                "read_mask": read_mask,
                "order_by": order_by,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def upload_model(
        self,
        project_id: str,
        region: str,
        model: Model | dict,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Operation:
        """
        Upload a Model artifact into Vertex AI.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model: Required. The Model to create.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The operation returned by the client's ``upload_model`` call.
        """
        client = self.get_model_service_client(region)
        parent = client.common_location_path(project_id, region)
        return client.upload_model(
            request={
                "parent": parent,
                "model": model,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def list_model_versions(
        self,
        region: str,
        project_id: str,
        model_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> ListModelVersionsPager:
        """
        List all versions of the existing Model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to output versions for.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The pager returned by the client's ``list_model_versions`` call.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        return client.list_model_versions(
            request={"name": name},
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_model_version(
        self,
        region: str,
        project_id: str,
        model_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Operation:
        """
        Delete a version of the Model.

        A version cannot be deleted while it is the default version.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model in which to delete version.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The operation returned by the client's ``delete_model_version`` call.
        """
        # NOTE(review): presumably model_id carries the version to delete (e.g. "id@version");
        # confirm against callers, since only model_id is sent here.
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        return client.delete_model_version(
            request={"name": name},
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def get_model(
        self,
        region: str,
        project_id: str,
        model_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Model:
        """
        Retrieve the Model of a specific name and version.

        If the version is not specified, the default version is retrieved.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to retrieve.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The result of the client's ``get_model`` call.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        return client.get_model(
            request={"name": name},
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def set_version_as_default(
        self,
        region: str,
        model_id: str,
        project_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Model:
        """
        Set the current version of the Model as default.

        This is done by merging the ``"default"`` alias into the version addressed by ``model_id``.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to set as default.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The result of the client's ``merge_version_aliases`` call.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        return client.merge_version_aliases(
            request={
                "name": name,
                "version_aliases": ["default"],
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def add_version_aliases(
        self,
        region: str,
        model_id: str,
        project_id: str,
        version_aliases: Sequence[str],
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Model:
        """
        Add a list of version aliases to a specific version of the Model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to add aliases to.
        :param version_aliases: Required. List of version aliases to be added for specific version.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The result of the client's ``merge_version_aliases`` call.
        :raises AirflowException: If any alias starts with ``-``.
        """
        # A leading "-" marks an alias for removal in merge_version_aliases
        # (see delete_version_aliases), so it cannot be used to add one.
        # Validate before creating a client, because this check needs no API access.
        if any(alias.startswith("-") for alias in version_aliases):
            raise AirflowException("Name of the alias can't start with '-'")
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        return client.merge_version_aliases(
            request={
                "name": name,
                "version_aliases": version_aliases,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_version_aliases(
        self,
        region: str,
        model_id: str,
        project_id: str,
        version_aliases: Sequence[str],
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Model:
        """
        Delete a list of version aliases from a specific version of the Model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to delete aliases from.
        :param version_aliases: Required. List of version aliases to be deleted from specific version.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The result of the client's ``merge_version_aliases`` call.
        :raises AirflowException: If ``"default"`` is among the aliases to delete.
        """
        # Validate before creating a client, because this check needs no API access.
        if "default" in version_aliases:
            raise AirflowException(
                "Default alias can't be deleted. "
                "Make sure to assign this alias to another version before deletion"
            )
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        # Prefixing an alias with "-" asks merge_version_aliases to remove it.
        aliases_for_delete = ["-" + alias for alias in version_aliases]
        return client.merge_version_aliases(
            request={
                "name": name,
                "version_aliases": aliases_for_delete,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )