airflow.providers.google.cloud.hooks.vertex_ai.generative_model
This module contains a Google Cloud Vertex AI Generative Model hook.
Classes
GenerativeModelHook – Hook for Google Cloud Vertex AI Generative Model APIs.
Module Contents
- class airflow.providers.google.cloud.hooks.vertex_ai.generative_model.GenerativeModelHook(gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)
Bases: airflow.providers.google.common.hooks.base_google.GoogleBaseHook
Hook for Google Cloud Vertex AI Generative Model APIs.
- get_text_generation_model(pretrained_model)
Return a Model Garden Model object based on Text Generation.
- get_text_embedding_model(pretrained_model)
Return a Model Garden Model object based on Text Embedding.
- get_generative_model(pretrained_model, system_instruction=None, generation_config=None, safety_settings=None, tools=None)
Return a Generative Model object.
- get_cached_context_model(cached_content_name)
Return a Generative Model with Cached Context.
- text_generation_model_predict(prompt, pretrained_model, temperature, max_output_tokens, top_p, top_k, location, project_id=PROVIDE_PROJECT_ID)
Use the Vertex AI PaLM API to generate natural language text.
- Parameters:
project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.
location (str) – Required. The ID of the Google Cloud location that the service belongs to.
prompt (str) – Required. Inputs or queries that a user or a program gives to the Vertex AI PaLM API, in order to elicit a specific response.
pretrained_model (str) – A pre-trained model optimized for performing natural language tasks such as classification, summarization, extraction, content creation, and ideation.
temperature (float) – Temperature controls the degree of randomness in token selection.
max_output_tokens (int) – Token limit determines the maximum amount of text output.
top_p (float) – Tokens are selected from most probable to least probable until the sum of their probabilities equals the top_p value.
top_k (int) – A top_k of 1 means the selected token is the most probable among all tokens.
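The interaction between top_k and top_p described above can be illustrated with a small pure-Python sketch. It is not how Vertex AI implements sampling; the function name `filter_candidates`, the toy probability table, and the choice to apply top_k before top_p are all assumptions made for illustration.

```python
def filter_candidates(probs, top_k, top_p):
    """Return the tokens that survive top_k filtering, then top_p filtering."""
    # Rank tokens from most probable to least probable.
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    # top_k keeps only the k most probable tokens.
    ranked = ranked[:top_k]
    # top_p keeps tokens until their cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append(token)
        cumulative += prob
        if cumulative >= top_p:
            break
    return kept


# Hypothetical next-token distribution.
probs = {"the": 0.5, "a": 0.3, "an": 0.15, "this": 0.05}
print(filter_candidates(probs, top_k=3, top_p=0.75))  # ['the', 'a']
print(filter_candidates(probs, top_k=1, top_p=0.75))  # ['the']
```

With top_k=1 only the single most probable token remains, which matches the description of top_k above. Temperature would then shape how the model samples among the surviving tokens.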
- text_embedding_model_get_embeddings(prompt, pretrained_model, location, project_id=PROVIDE_PROJECT_ID)
Use the Vertex AI PaLM API to generate text embeddings.
- Parameters:
project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.
location (str) – Required. The ID of the Google Cloud location that the service belongs to.
prompt (str) – Required. Inputs or queries that a user or a program gives to the Vertex AI PaLM API, in order to elicit a specific response.
pretrained_model (str) – A pre-trained model optimized for generating text embeddings.
- generative_model_generate_content(contents, location, tools=None, generation_config=None, safety_settings=None, system_instruction=None, pretrained_model='gemini-pro', project_id=PROVIDE_PROJECT_ID)
Use the Vertex AI Gemini Pro foundation model to generate natural language text.
- Parameters:
location (str) – Required. The ID of the Google Cloud location that the service belongs to.
project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.
contents (list) – Required. The multi-part content of a message that a user or a program gives to the generative model, in order to elicit a specific response.
generation_config (dict | None) – Optional. Generation configuration settings.
safety_settings (dict | None) – Optional. Per request settings for blocking unsafe content.
tools (list | None) – Optional. A list of tools available to the model during generation, such as a data store.
system_instruction (str | None) – Optional. An instruction given to the model to guide its behavior.
pretrained_model (str) – By default uses the pre-trained model gemini-pro, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can output text and code.
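A minimal sketch of calling generative_model_generate_content with the parameters listed above. The helper name `summarize`, the `us-central1` default location, the prompt wording, and the generation_config values are illustrative assumptions. The helper accepts any object exposing the same method, so it can be exercised with a stand-in hook.

```python
def summarize(hook, text, project_id, location="us-central1"):
    """Ask the model for a one-sentence summary of ``text``.

    ``hook`` is expected to be a GenerativeModelHook, or a stand-in that
    exposes the same ``generative_model_generate_content`` method.
    """
    return hook.generative_model_generate_content(
        project_id=project_id,
        location=location,  # assumed region; use the one your project serves from
        contents=[f"Summarize in one sentence: {text}"],
        # Illustrative values only; tune them for your workload.
        generation_config={"temperature": 0.2, "max_output_tokens": 128},
        system_instruction="You are a concise technical writer.",
        pretrained_model="gemini-pro",
    )
```

In a real task the hook would typically be created as GenerativeModelHook(gcp_conn_id='google_cloud_default') and passed in as `hook`.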
- supervised_fine_tuning_train(source_model, train_dataset, location, tuned_model_display_name=None, validation_dataset=None, epochs=None, adapter_size=None, learning_rate_multiplier=None, project_id=PROVIDE_PROJECT_ID)
Use the Supervised Fine Tuning API to create a tuning job.
- Parameters:
project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.
location (str) – Required. The ID of the Google Cloud location that the service belongs to.
source_model (str) – Required. A pre-trained model optimized for performing natural language tasks such as classification, summarization, extraction, content creation, and ideation.
train_dataset (str) – Required. Cloud Storage URI of your training dataset. The dataset must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
tuned_model_display_name (str | None) – Optional. Display name of the TunedModel. The name can be up to 128 characters long and can consist of any UTF-8 characters.
validation_dataset (str | None) – Optional. Cloud Storage URI of your validation dataset. The dataset must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
epochs (int | None) – Optional. To optimize performance on a specific dataset, try using a higher epoch value. Increasing the number of epochs might improve results. However, be cautious about over-fitting, especially when dealing with small datasets. If over-fitting occurs, consider lowering the epoch number.
adapter_size (int | None) – Optional. Adapter size for tuning.
learning_rate_multiplier (float | None) – Optional. Multiplier for adjusting the default learning rate.
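Because the dataset must be a JSONL file with a recommended minimum number of examples, a local sanity check before uploading to Cloud Storage can catch formatting problems early. The sketch below assumes only that each non-blank line must be a JSON object. It does not validate the record schema expected by the tuning API, and the helper name `check_jsonl` is hypothetical.

```python
import json


def check_jsonl(lines, min_examples=100):
    """Return (example_count, has_enough_examples) for an iterable of JSONL lines."""
    count = 0
    for number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        record = json.loads(line)  # raises json.JSONDecodeError on malformed JSON
        if not isinstance(record, dict):
            raise ValueError(f"line {number}: expected a JSON object")
        count += 1
    return count, count >= min_examples
```

Any iterable of strings works, so an open file handle for the local copy of the dataset can be passed directly as `lines`.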
- count_tokens(contents, location, pretrained_model='gemini-pro', project_id=PROVIDE_PROJECT_ID)
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
- Parameters:
project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.
location (str) – Required. The ID of the Google Cloud location that the service belongs to.
contents (list) – Required. The multi-part content of a message that a user or a program gives to the generative model, in order to elicit a specific response.
pretrained_model (str) – By default uses the pre-trained model gemini-pro, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can output text and code.
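One way to use count_tokens is as a guard before sending a large request. The sketch below assumes that the returned object exposes a `total_tokens` attribute, which this page does not document. The helper name `fits_token_budget` and the `max_input_tokens` parameter are hypothetical.

```python
def fits_token_budget(hook, contents, project_id, location, max_input_tokens):
    """Return True if ``contents`` stays within ``max_input_tokens``."""
    response = hook.count_tokens(
        project_id=project_id,
        location=location,
        contents=contents,
        pretrained_model="gemini-pro",
    )
    # Assumption: the response carries the token count as ``total_tokens``.
    return response.total_tokens <= max_input_tokens
```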
- run_evaluation(pretrained_model, eval_dataset, metrics, experiment_name, experiment_run_name, prompt_template, location, generation_config=None, safety_settings=None, system_instruction=None, tools=None, project_id=PROVIDE_PROJECT_ID)
Use the Rapid Evaluation API to evaluate a model.
- Parameters:
project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.
location (str) – Required. The ID of the Google Cloud location that the service belongs to.
pretrained_model (str) – Required. A pre-trained model optimized for performing natural language tasks such as classification, summarization, extraction, content creation, and ideation.
eval_dataset (dict) – Required. A fixed dataset for evaluating a model against. Must adhere to the Rapid Evaluation API format.
metrics (list) – Required. A list of evaluation metrics to be used in the experiment. Must adhere to the Rapid Evaluation API format.
experiment_name (str) – Required. The name of the evaluation experiment.
experiment_run_name (str) – Required. The specific run name or ID for this experiment.
prompt_template (str) – Required. The template used to format the model’s prompts during evaluation. Must adhere to the Rapid Evaluation API format.
generation_config (dict | None) – Optional. A dictionary containing generation parameters for the model.
safety_settings (dict | None) – Optional. A dictionary specifying harm category thresholds for blocking model outputs.
system_instruction (str | None) – Optional. An instruction given to the model to guide its behavior.
tools (list | None) – Optional. A list of tools available to the model during evaluation, such as a data store.
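To make the relationship between eval_dataset and prompt_template concrete, the sketch below fills a template once per dataset row. It assumes a column-oriented dict of equal-length lists and placeholders named after those columns. The column names, the sample rows, and the `render_prompts` helper are illustrative, and the exact way the Rapid Evaluation API assembles prompts may differ.

```python
# Hypothetical dataset: one list per column, one entry per example.
eval_dataset = {
    "instruction": ["Summarize the article", "Summarize the article"],
    "context": ["Airflow schedules workflows", "Vertex AI hosts models"],
}
prompt_template = "{instruction}. Article: {context}. Summary:"


def render_prompts(dataset, template):
    """Fill the template once per row of a column-oriented dataset."""
    columns = list(dataset)
    rows = zip(*(dataset[name] for name in columns))
    return [template.format(**dict(zip(columns, row))) for row in rows]


prompts = render_prompts(eval_dataset, prompt_template)
print(prompts[0])  # Summarize the article. Article: Airflow schedules workflows. Summary:
```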
- create_cached_content(model_name, location, ttl_hours=1, system_instruction=None, contents=None, display_name=None, project_id=PROVIDE_PROJECT_ID)
Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
- Parameters:
project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.
location (str) – Required. The ID of the Google Cloud location that the service belongs to.
model_name (str) – Required. The name of the publisher model to use for cached content.
system_instruction (str | None) – Developer-set system instruction.
contents (list | None) – The content to cache.
ttl_hours (float) – The TTL for this resource in hours. The expiration time is computed as now + TTL. Defaults to one hour.
display_name (str | None) – The user-generated meaningful display name of the cached content.
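The "now + TTL" rule for ttl_hours can be worked through with the standard library. This is a sketch of the arithmetic only; the `expiration_time` helper is hypothetical and the service computes the actual expiration on its side.

```python
from datetime import datetime, timedelta, timezone


def expiration_time(ttl_hours, now=None):
    """Return the moment a cache created at ``now`` would expire."""
    now = now or datetime.now(timezone.utc)
    return now + timedelta(hours=ttl_hours)


created = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
print(expiration_time(1, now=created))    # 2024-01-01 13:00:00+00:00
print(expiration_time(0.5, now=created))  # 2024-01-01 12:30:00+00:00
```

Since ttl_hours is a float, fractional values such as 0.5 express a TTL shorter than one hour.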
- generate_from_cached_content(location, cached_content_name, contents, generation_config=None, safety_settings=None, project_id=PROVIDE_PROJECT_ID)
Generate a response from CachedContent.
- Parameters:
project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.
location (str) – Required. The ID of the Google Cloud location that the service belongs to.
cached_content_name (str) – Required. The name of the cached content resource.
contents (list) – Required. The multi-part content of a message that a user or a program gives to the generative model, in order to elicit a specific response.
generation_config (dict | None) – Optional. Generation configuration settings.
safety_settings (dict | None) – Optional. Per request settings for blocking unsafe content.
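The two caching methods are naturally used together: create the cache once, then query it. The sketch below assumes that create_cached_content returns the cached content resource name, which this page does not state. The helper name `ask_with_cache`, the system instruction, and the display name are illustrative.

```python
def ask_with_cache(hook, document_text, question, project_id, location, model_name):
    """Cache a large document once, then answer a question against it."""
    # Assumption: the return value is the cached content resource name.
    cached_content_name = hook.create_cached_content(
        project_id=project_id,
        location=location,
        model_name=model_name,
        system_instruction="Answer using only the cached document.",
        contents=[document_text],
        ttl_hours=1,
        display_name="example-cache",
    )
    return hook.generate_from_cached_content(
        project_id=project_id,
        location=location,
        cached_content_name=cached_content_name,
        contents=[question],
    )
```

Any further questions asked before the TTL expires can reuse the same cached_content_name instead of creating a new cache.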