Source code for airflow.providers.google.cloud.hooks.vertex_ai.generative_model
## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""This module contains a Google Cloud Vertex AI Generative Model hook."""from__future__importannotationsimporttimefromdatetimeimporttimedeltafromtypingimportTYPE_CHECKINGimportvertexaifromvertexai.generative_modelsimportGenerativeModel,Partfromvertexai.language_modelsimportTextEmbeddingModel,TextGenerationModelfromvertexai.preview.cachingimportCachedContentfromvertexai.preview.evaluationimportEvalResult,EvalTaskfromvertexai.preview.generative_modelsimportGenerativeModelaspreview_generative_modelfromvertexai.preview.tuningimportsftfromairflow.exceptionsimportAirflowProviderDeprecationWarningfromairflow.providers.google.common.deprecatedimportdeprecatedfromairflow.providers.google.common.hooks.base_googleimportPROVIDE_PROJECT_ID,GoogleBaseHookifTYPE_CHECKING:fromgoogle.cloud.aiplatform_v1importtypesastypes_v1fromgoogle.cloud.aiplatform_v1beta1importtypesastypes_v1beta1
class GenerativeModelHook(GoogleBaseHook):
    """Hook for Google Cloud Vertex AI Generative Model APIs."""

    def _init_vertexai(self, location: str, project_id: str | None) -> None:
        """
        Initialize the Vertex AI SDK with this hook's credentials.

        If ``project_id`` is falsy (for example when the caller left the
        ``PROVIDE_PROJECT_ID`` default in place on a method that is not wrapped
        in ``fallback_to_default_project_id``), the project configured on the
        Google Cloud connection is used instead.

        :param location: The ID of the Google Cloud location to target.
        :param project_id: The ID of the Google Cloud project to target, or a
            falsy value to use the connection's project.
        """
        vertexai.init(
            project=project_id or self.project_id,
            location=location,
            credentials=self.get_credentials(),
        )

    def get_text_generation_model(self, pretrained_model: str):
        """Return a Model Garden Model object based on Text Generation."""
        return TextGenerationModel.from_pretrained(pretrained_model)

    def get_text_embedding_model(self, pretrained_model: str):
        """Return a Model Garden Model object based on Text Embedding."""
        return TextEmbeddingModel.from_pretrained(pretrained_model)

    def get_generative_model(
        self,
        pretrained_model: str,
        system_instruction: str | None = None,
        generation_config: dict | None = None,
        safety_settings: dict | None = None,
        tools: list | None = None,
    ) -> GenerativeModel:
        """Return a Generative Model object."""
        return GenerativeModel(
            model_name=pretrained_model,
            system_instruction=system_instruction,
            generation_config=generation_config,
            safety_settings=safety_settings,
            tools=tools,
        )

    def get_eval_task(
        self,
        dataset: dict,
        metrics: list,
        experiment: str,
    ) -> EvalTask:
        """Return an EvalTask object."""
        return EvalTask(
            dataset=dataset,
            metrics=metrics,
            experiment=experiment,
        )

    def get_cached_context_model(
        self,
        cached_content_name: str,
    ) -> preview_generative_model:
        """Return a Generative Model with Cached Context."""
        cached_content = CachedContent(cached_content_name=cached_content_name)
        return preview_generative_model.from_cached_content(cached_content)

    @deprecated(
        planned_removal_date="January 01, 2025",
        use_instead="Part objects included in contents parameter of "
        "airflow.providers.google.cloud.hooks.vertex_ai.generative_model."
        "GenerativeModelHook.generative_model_generate_content",
        category=AirflowProviderDeprecationWarning,
    )
    def get_generative_model_part(self, content_gcs_path: str, content_mime_type: str | None = None) -> Part:
        """Return a Generative Model Part object."""
        return Part.from_uri(content_gcs_path, mime_type=content_mime_type)

    def prompt_language_model(
        self,
        prompt: str,
        pretrained_model: str,
        temperature: float,
        max_output_tokens: int,
        top_p: float,
        top_k: int,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Use the Vertex AI PaLM API to generate natural language text.

        :param project_id: The ID of the Google Cloud project that the service belongs to.
            If not set, the project of the Google Cloud connection is used.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Vertex AI PaLM API, in order to elicit a specific response.
        :param pretrained_model: A pre-trained model optimized for performing natural
            language tasks such as classification, summarization, extraction, content
            creation, and ideation.
        :param temperature: Temperature controls the degree of randomness in token selection.
        :param max_output_tokens: Token limit determines the maximum amount of text output.
        :param top_p: Tokens are selected from most probable to least until the sum
            of their probabilities equals the top_p value.
        :param top_k: A top_k of 1 means the selected token is the most probable among all tokens.
        :return: The text of the model's prediction.
        """
        self._init_vertexai(location=location, project_id=project_id)

        parameters = {
            "temperature": temperature,
            "max_output_tokens": max_output_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }

        model = self.get_text_generation_model(pretrained_model)
        response = model.predict(prompt=prompt, **parameters)
        return response.text

    def generate_text_embeddings(
        self,
        prompt: str,
        pretrained_model: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> list:
        """
        Use the Vertex AI PaLM API to generate text embeddings.

        :param project_id: The ID of the Google Cloud project that the service belongs to.
            If not set, the project of the Google Cloud connection is used.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Vertex AI PaLM API, in order to elicit a specific response.
        :param pretrained_model: A pre-trained model optimized for generating text embeddings.
        :return: The embedding values for ``prompt``.
        """
        self._init_vertexai(location=location, project_id=project_id)
        model = self.get_text_embedding_model(pretrained_model)
        # A single prompt is sent, so only the first embedding is returned.
        response = model.get_embeddings([prompt])[0]
        return response.values

    def prompt_multimodal_model(
        self,
        prompt: str,
        location: str,
        generation_config: dict | None = None,
        safety_settings: dict | None = None,
        pretrained_model: str = "gemini-pro",
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Use the Vertex AI Gemini Pro foundation model to generate natural language text.

        :param project_id: The ID of the Google Cloud project that the service belongs to.
            If not set, the project of the Google Cloud connection is used.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Multi-modal model, in order to elicit a specific response.
        :param generation_config: Optional. Generation configuration settings.
        :param safety_settings: Optional. Per request settings for blocking unsafe content.
        :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
            supporting prompts with text-only input, including natural language
            tasks, multi-turn text and code chat, and code generation. It can
            output text and code.
        :return: The text of the model's response.
        """
        self._init_vertexai(location=location, project_id=project_id)

        model = self.get_generative_model(pretrained_model)
        response = model.generate_content(
            contents=[prompt], generation_config=generation_config, safety_settings=safety_settings
        )
        return response.text

    def prompt_multimodal_model_with_media(
        self,
        prompt: str,
        location: str,
        media_gcs_path: str,
        mime_type: str,
        generation_config: dict | None = None,
        safety_settings: dict | None = None,
        pretrained_model: str = "gemini-pro-vision",
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Use the Vertex AI Gemini Pro foundation model to generate natural language text.

        :param project_id: The ID of the Google Cloud project that the service belongs to.
            If not set, the project of the Google Cloud connection is used.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Multi-modal model, in order to elicit a specific response.
        :param generation_config: Optional. Generation configuration settings.
        :param safety_settings: Optional. Per request settings for blocking unsafe content.
        :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
            a vision model intended for prompts that combine text with media
            content such as an image or a video.
        :param media_gcs_path: A GCS path to a content file such as an image or a video.
            Can be passed to the multi-modal model as part of the prompt. Used with vision models.
        :param mime_type: Validates the media type presented by the file in the media_gcs_path.
        :return: The text of the model's response.
        """
        self._init_vertexai(location=location, project_id=project_id)

        model = self.get_generative_model(pretrained_model)
        # Build the Part directly rather than through the deprecated
        # get_generative_model_part, which would emit a deprecation warning for
        # a method the caller never invoked.
        part = Part.from_uri(media_gcs_path, mime_type=mime_type)
        response = model.generate_content(
            contents=[prompt, part], generation_config=generation_config, safety_settings=safety_settings
        )
        return response.text

    @GoogleBaseHook.fallback_to_default_project_id
    def text_generation_model_predict(
        self,
        prompt: str,
        pretrained_model: str,
        temperature: float,
        max_output_tokens: int,
        top_p: float,
        top_k: int,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Use the Vertex AI PaLM API to generate natural language text.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Vertex AI PaLM API, in order to elicit a specific response.
        :param pretrained_model: A pre-trained model optimized for performing natural
            language tasks such as classification, summarization, extraction, content
            creation, and ideation.
        :param temperature: Temperature controls the degree of randomness in token selection.
        :param max_output_tokens: Token limit determines the maximum amount of text output.
        :param top_p: Tokens are selected from most probable to least until the sum
            of their probabilities equals the top_p value.
        :param top_k: A top_k of 1 means the selected token is the most probable among all tokens.
        :return: The text of the model's prediction.
        """
        self._init_vertexai(location=location, project_id=project_id)

        parameters = {
            "temperature": temperature,
            "max_output_tokens": max_output_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }

        model = self.get_text_generation_model(pretrained_model)
        response = model.predict(prompt=prompt, **parameters)
        return response.text

    @GoogleBaseHook.fallback_to_default_project_id
    def text_embedding_model_get_embeddings(
        self,
        prompt: str,
        pretrained_model: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> list:
        """
        Use the Vertex AI PaLM API to generate text embeddings.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Vertex AI PaLM API, in order to elicit a specific response.
        :param pretrained_model: A pre-trained model optimized for generating text embeddings.
        :return: The embedding values for ``prompt``.
        """
        self._init_vertexai(location=location, project_id=project_id)
        model = self.get_text_embedding_model(pretrained_model)
        # A single prompt is sent, so only the first embedding is returned.
        response = model.get_embeddings([prompt])[0]
        return response.values

    @GoogleBaseHook.fallback_to_default_project_id
    def generative_model_generate_content(
        self,
        contents: list,
        location: str,
        tools: list | None = None,
        generation_config: dict | None = None,
        safety_settings: dict | None = None,
        system_instruction: str | None = None,
        pretrained_model: str = "gemini-pro",
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Use the Vertex AI Gemini Pro foundation model to generate natural language text.

        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param contents: Required. The multi-part content of a message that a user or a program
            gives to the generative model, in order to elicit a specific response.
        :param generation_config: Optional. Generation configuration settings.
        :param safety_settings: Optional. Per request settings for blocking unsafe content.
        :param tools: Optional. A list of tools available to the model during generation,
            such as a data store.
        :param system_instruction: Optional. An instruction given to the model to guide its behavior.
        :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
            supporting prompts with text-only input, including natural language
            tasks, multi-turn text and code chat, and code generation. It can
            output text and code.
        :return: The text of the model's response.
        """
        self._init_vertexai(location=location, project_id=project_id)

        model = self.get_generative_model(
            pretrained_model=pretrained_model, system_instruction=system_instruction
        )
        response = model.generate_content(
            contents=contents,
            tools=tools,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        return response.text

    @GoogleBaseHook.fallback_to_default_project_id
    def supervised_fine_tuning_train(
        self,
        source_model: str,
        train_dataset: str,
        location: str,
        tuned_model_display_name: str | None = None,
        validation_dataset: str | None = None,
        epochs: int | None = None,
        adapter_size: int | None = None,
        learning_rate_multiplier: float | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> types_v1.TuningJob:
        """
        Use the Supervised Fine Tuning API to create a tuning job.

        This call blocks until the tuning job has ended, polling its state every 60 seconds.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param source_model: Required. A pre-trained model optimized for performing natural
            language tasks such as classification, summarization, extraction, content
            creation, and ideation.
        :param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
            must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
        :param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
            to 128 characters long and can consist of any UTF-8 characters.
        :param validation_dataset: Optional. Cloud Storage URI of your validation dataset. The dataset
            must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
        :param epochs: Optional. To optimize performance on a specific dataset, try using a higher
            epoch value. Increasing the number of epochs might improve results. However, be cautious
            about over-fitting, especially when dealing with small datasets. If over-fitting occurs,
            consider lowering the epoch number.
        :param adapter_size: Optional. Adapter size for tuning.
        :param learning_rate_multiplier: Optional. Multiplier for adjusting the default learning rate.
        :return: The tuning job, refreshed after it has ended.
        """
        self._init_vertexai(location=location, project_id=project_id)

        sft_tuning_job = sft.train(
            source_model=source_model,
            train_dataset=train_dataset,
            validation_dataset=validation_dataset,
            epochs=epochs,
            adapter_size=adapter_size,
            learning_rate_multiplier=learning_rate_multiplier,
            tuned_model_display_name=tuned_model_display_name,
        )

        # Poll until the remote job reports completion; refresh() re-reads its state.
        while not sft_tuning_job.has_ended:
            time.sleep(60)
            sft_tuning_job.refresh()

        return sft_tuning_job

    @GoogleBaseHook.fallback_to_default_project_id
    def count_tokens(
        self,
        contents: list,
        location: str,
        pretrained_model: str = "gemini-pro",
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> types_v1beta1.CountTokensResponse:
        """
        Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param contents: Required. The multi-part content of a message that a user or a program
            gives to the generative model, in order to elicit a specific response.
        :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
            supporting prompts with text-only input, including natural language
            tasks, multi-turn text and code chat, and code generation. It can
            output text and code.
        :return: The response returned by the model's ``count_tokens`` call.
        """
        self._init_vertexai(location=location, project_id=project_id)

        model = self.get_generative_model(pretrained_model=pretrained_model)
        return model.count_tokens(contents=contents)

    @GoogleBaseHook.fallback_to_default_project_id
    def run_evaluation(
        self,
        pretrained_model: str,
        eval_dataset: dict,
        metrics: list,
        experiment_name: str,
        experiment_run_name: str,
        prompt_template: str,
        location: str,
        generation_config: dict | None = None,
        safety_settings: dict | None = None,
        system_instruction: str | None = None,
        tools: list | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> EvalResult:
        """
        Use the Rapid Evaluation API to evaluate a model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param pretrained_model: Required. A pre-trained model optimized for performing natural
            language tasks such as classification, summarization, extraction, content
            creation, and ideation.
        :param eval_dataset: Required. A fixed dataset for evaluating a model against.
            Adheres to Rapid Evaluation API.
        :param metrics: Required. A list of evaluation metrics to be used in the experiment.
            Adheres to Rapid Evaluation API.
        :param experiment_name: Required. The name of the evaluation experiment.
        :param experiment_run_name: Required. The specific run name or ID for this experiment.
        :param prompt_template: Required. The template used to format the model's prompts
            during evaluation. Adheres to Rapid Evaluation API.
        :param generation_config: Optional. A dictionary containing generation parameters for the model.
        :param safety_settings: Optional. A dictionary specifying harm category thresholds
            for blocking model outputs.
        :param system_instruction: Optional. An instruction given to the model to guide its behavior.
        :param tools: Optional. A list of tools available to the model during evaluation,
            such as a data store.
        :return: The result of ``EvalTask.evaluate``.
        """
        self._init_vertexai(location=location, project_id=project_id)

        model = self.get_generative_model(
            pretrained_model=pretrained_model,
            system_instruction=system_instruction,
            generation_config=generation_config,
            safety_settings=safety_settings,
            tools=tools,
        )

        eval_task = self.get_eval_task(
            dataset=eval_dataset,
            metrics=metrics,
            experiment=experiment_name,
        )

        return eval_task.evaluate(
            model=model,
            prompt_template=prompt_template,
            experiment_run_name=experiment_run_name,
        )

    def create_cached_content(
        self,
        model_name: str,
        location: str,
        ttl_hours: float = 1,
        system_instruction: str | None = None,
        contents: list | None = None,
        display_name: str | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.

        :param project_id: The ID of the Google Cloud project that the service belongs to.
            If not set, the project of the Google Cloud connection is used.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param model_name: Required. The name of the publisher model to use for cached content.
        :param system_instruction: Developer set system instruction.
        :param contents: The content to cache.
        :param ttl_hours: The TTL for this resource in hours. The expiration time is computed:
            now + TTL. Defaults to one hour.
        :param display_name: The user-generated meaningful display name of the cached content.
        :return: The resource name of the created cached content.
        """
        self._init_vertexai(location=location, project_id=project_id)

        response = CachedContent.create(
            model_name=model_name,
            system_instruction=system_instruction,
            contents=contents,
            ttl=timedelta(hours=ttl_hours),
            display_name=display_name,
        )

        return response.name

    def generate_from_cached_content(
        self,
        location: str,
        cached_content_name: str,
        contents: list,
        generation_config: dict | None = None,
        safety_settings: dict | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Generate a response from CachedContent.

        :param project_id: The ID of the Google Cloud project that the service belongs to.
            If not set, the project of the Google Cloud connection is used.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param cached_content_name: Required. The name of the cached content resource.
        :param contents: Required. The multi-part content of a message that a user or a program
            gives to the generative model, in order to elicit a specific response.
        :param generation_config: Optional. Generation configuration settings.
        :param safety_settings: Optional. Per request settings for blocking unsafe content.
        :return: The text of the model's response.
        """
        self._init_vertexai(location=location, project_id=project_id)

        cached_context_model = self.get_cached_context_model(cached_content_name=cached_content_name)

        response = cached_context_model.generate_content(
            contents=contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )

        return response.text