airflow.providers.google.cloud.operators.vertex_ai.generative_model

This module contains Google Vertex AI Generative AI operators.

Classes

TextGenerationModelPredictOperator

Uses the Vertex AI PaLM API to generate natural language text.

TextEmbeddingModelGetEmbeddingsOperator

Uses the Vertex AI Embeddings API to generate embeddings based on a prompt.

GenerativeModelGenerateContentOperator

Use the Vertex AI Gemini Pro foundation model to generate content.

SupervisedFineTuningTrainOperator

Use the Supervised Fine Tuning API to create a tuning job.

CountTokensOperator

Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.

RunEvaluationOperator

Use the Rapid Evaluation API to evaluate a model.

CreateCachedContentOperator

Create CachedContent to reduce the cost of requests that contain repeated content with high input token counts.

GenerateFromCachedContentOperator

Generate a response from CachedContent.

Module Contents

class airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator(*, project_id, location, prompt, pretrained_model='text-bison', temperature=0.0, max_output_tokens=256, top_p=0.8, top_k=40, gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)[source]

Bases: airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator

Uses the Vertex AI PaLM API to generate natural language text.

Parameters:
  • project_id (str) – Required. The ID of the Google Cloud project that the service belongs to (templated).

  • location (str) – Required. The ID of the Google Cloud location that the service belongs to (templated).

  • prompt (str) – Required. Inputs or queries that a user or a program gives to the Vertex AI PaLM API, in order to elicit a specific response (templated).

  • pretrained_model (str) – By default uses the pre-trained model text-bison, optimized for performing natural language tasks such as classification, summarization, extraction, content creation, and ideation.

  • temperature (float) – Temperature controls the degree of randomness in token selection. Defaults to 0.0.

  • max_output_tokens (int) – Token limit determines the maximum amount of text output. Defaults to 256.

  • top_p (float) – Tokens are selected from most probable to least until the sum of their probabilities equals the top_p value. Defaults to 0.8.

  • top_k (int) – A top_k of 1 means the selected token is the most probable among all tokens. Defaults to 40.

  • gcp_conn_id (str) – The connection ID to use connecting to Google Cloud.

  • impersonation_chain (str | collections.abc.Sequence[str] | None) – Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated).

template_fields = ('location', 'project_id', 'impersonation_chain', 'prompt')[source]
project_id[source]
location[source]
prompt[source]
pretrained_model = 'text-bison'[source]
temperature = 0.0[source]
max_output_tokens = 256[source]
top_p = 0.8[source]
top_k = 40[source]
gcp_conn_id = 'google_cloud_default'[source]
impersonation_chain = None[source]
execute(context)[source]

Derive when creating an operator.

Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
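As a sketch, this operator might be used as a DAG task like the following. The project ID, location, and prompt are hypothetical placeholders, and a configured `google_cloud_default` connection is assumed:

```python
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    TextGenerationModelPredictOperator,
)

# Hypothetical values for illustration; replace with your own project and prompt.
generate_text = TextGenerationModelPredictOperator(
    task_id="generate_text",
    project_id="my-gcp-project",
    location="us-central1",
    prompt="Summarize the benefits of workflow orchestration in two sentences.",
    temperature=0.2,        # low randomness for a deterministic-leaning summary
    max_output_tokens=256,
)
```

The generated text is returned from `execute`, so downstream tasks can pull it via XCom.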

class airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextEmbeddingModelGetEmbeddingsOperator(*, project_id, location, prompt, pretrained_model='textembedding-gecko', gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)[source]

Bases: airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator

Uses the Vertex AI Embeddings API to generate embeddings based on a prompt.

Parameters:
  • project_id (str) – Required. The ID of the Google Cloud project that the service belongs to (templated).

  • location (str) – Required. The ID of the Google Cloud location that the service belongs to (templated).

  • prompt (str) – Required. Inputs or queries that a user or a program gives to the Vertex AI PaLM API, in order to elicit a specific response (templated).

  • pretrained_model (str) – By default uses the pre-trained model textembedding-gecko, optimized for performing text embeddings.

  • gcp_conn_id (str) – The connection ID to use connecting to Google Cloud.

  • impersonation_chain (str | collections.abc.Sequence[str] | None) – Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated).

template_fields = ('location', 'project_id', 'impersonation_chain', 'prompt')[source]
project_id[source]
location[source]
prompt[source]
pretrained_model = 'textembedding-gecko'[source]
gcp_conn_id = 'google_cloud_default'[source]
impersonation_chain = None[source]
execute(context)[source]

Derive when creating an operator.

Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
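A minimal task sketch, with hypothetical project values and an assumed `google_cloud_default` connection:

```python
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    TextEmbeddingModelGetEmbeddingsOperator,
)

# Hypothetical values for illustration only.
generate_embeddings = TextEmbeddingModelGetEmbeddingsOperator(
    task_id="generate_embeddings",
    project_id="my-gcp-project",
    location="us-central1",
    prompt="Vertex AI supports text embeddings for semantic search.",
)
```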

class airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerativeModelGenerateContentOperator(*, project_id, location, contents, tools=None, generation_config=None, safety_settings=None, system_instruction=None, pretrained_model='gemini-pro', gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)[source]

Bases: airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator

Use the Vertex AI Gemini Pro foundation model to generate content.

Parameters:
  • project_id (str) – Required. The ID of the Google Cloud project that the service belongs to (templated).

  • location (str) – Required. The ID of the Google Cloud location that the service belongs to (templated).

  • contents (list) – Required. The multi-part content of a message that a user or a program gives to the generative model, in order to elicit a specific response.

  • generation_config (dict | None) – Optional. Generation configuration settings.

  • safety_settings (dict | None) – Optional. Per request settings for blocking unsafe content.

  • tools (list | None) – Optional. A list of tools available to the model during generation, such as a data store.

  • system_instruction (str | None) – Optional. An instruction given to the model to guide its behavior.

  • pretrained_model (str) – By default uses the pre-trained model gemini-pro, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can output text and code.

  • gcp_conn_id (str) – The connection ID to use connecting to Google Cloud.

  • impersonation_chain (str | collections.abc.Sequence[str] | None) – Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated).

template_fields = ('location', 'project_id', 'impersonation_chain', 'contents', 'pretrained_model')[source]
project_id[source]
location[source]
contents[source]
tools = None[source]
generation_config = None[source]
safety_settings = None[source]
system_instruction = None[source]
pretrained_model = 'gemini-pro'[source]
gcp_conn_id = 'google_cloud_default'[source]
impersonation_chain = None[source]
execute(context)[source]

Derive when creating an operator.

Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
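A sketch of a content-generation task. The project values and prompt are hypothetical placeholders, and `generation_config` keys follow the Gemini generation settings (e.g. `temperature`, `max_output_tokens`):

```python
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    GenerativeModelGenerateContentOperator,
)

# Hypothetical values for illustration only.
generate_content = GenerativeModelGenerateContentOperator(
    task_id="generate_content",
    project_id="my-gcp-project",
    location="us-central1",
    contents=["Explain the difference between a DAG and a task in Airflow."],
    generation_config={"temperature": 0.1, "max_output_tokens": 256},
    system_instruction="Answer concisely, in plain language.",
    pretrained_model="gemini-pro",
)
```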

class airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator(*, project_id, location, source_model, train_dataset, tuned_model_display_name=None, validation_dataset=None, epochs=None, adapter_size=None, learning_rate_multiplier=None, gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)[source]

Bases: airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator

Use the Supervised Fine Tuning API to create a tuning job.

Parameters:
  • project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.

  • location (str) – Required. The ID of the Google Cloud location that the service belongs to.

  • source_model (str) – Required. A pre-trained model optimized for performing natural language tasks such as classification, summarization, extraction, content creation, and ideation.

  • train_dataset (str) – Required. Cloud Storage URI of your training dataset. The dataset must be formatted as a JSONL file. For best results, provide between 100 and 500 examples.

  • tuned_model_display_name (str | None) – Optional. Display name of the TunedModel. The name can be up to 128 characters long and can consist of any UTF-8 characters.

  • validation_dataset (str | None) – Optional. Cloud Storage URI of your validation dataset. The dataset must be formatted as a JSONL file. For best results, provide between 100 and 500 examples.

  • epochs (int | None) – Optional. To optimize performance on a specific dataset, try using a higher epoch value. Increasing the number of epochs might improve results. However, be cautious about over-fitting, especially when dealing with small datasets. If over-fitting occurs, consider lowering the epoch number.

  • adapter_size (int | None) – Optional. Adapter size for tuning.

  • learning_rate_multiplier (float | None) – Optional. Multiplier for adjusting the default learning rate.

  • gcp_conn_id (str) – The connection ID to use connecting to Google Cloud.

  • impersonation_chain (str | collections.abc.Sequence[str] | None) – Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated).

template_fields = ('location', 'project_id', 'impersonation_chain', 'train_dataset', 'validation_dataset')[source]
project_id[source]
location[source]
source_model[source]
train_dataset[source]
tuned_model_display_name = None[source]
validation_dataset = None[source]
epochs = None[source]
adapter_size = None[source]
learning_rate_multiplier = None[source]
gcp_conn_id = 'google_cloud_default'[source]
impersonation_chain = None[source]
execute(context)[source]

Derive when creating an operator.

Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
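A minimal fine-tuning task sketch. The source model name and Cloud Storage URI are hypothetical placeholders; the training file must be a JSONL dataset in the supervised-tuning format:

```python
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    SupervisedFineTuningTrainOperator,
)

# Hypothetical values for illustration; the GCS URI must point to a real JSONL dataset.
sft_train = SupervisedFineTuningTrainOperator(
    task_id="sft_train",
    project_id="my-gcp-project",
    location="us-central1",
    source_model="gemini-1.0-pro-002",
    train_dataset="gs://my-bucket/datasets/train.jsonl",
    tuned_model_display_name="my-tuned-model",
)
```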

class airflow.providers.google.cloud.operators.vertex_ai.generative_model.CountTokensOperator(*, project_id, location, contents, pretrained_model='gemini-pro', gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)[source]

Bases: airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator

Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.

Parameters:
  • project_id (str) – Required. The ID of the Google Cloud project that the service belongs to (templated).

  • location (str) – Required. The ID of the Google Cloud location that the service belongs to (templated).

  • contents (list) – Required. The multi-part content of a message that a user or a program gives to the generative model, in order to elicit a specific response.

  • pretrained_model (str) – By default uses the pre-trained model gemini-pro, supporting prompts with text-only input, including natural language tasks, multi-turn text and code chat, and code generation. It can output text and code.

  • gcp_conn_id (str) – The connection ID to use connecting to Google Cloud.

  • impersonation_chain (str | collections.abc.Sequence[str] | None) – Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated).

template_fields = ('location', 'project_id', 'impersonation_chain', 'contents', 'pretrained_model')[source]
project_id[source]
location[source]
contents[source]
pretrained_model = 'gemini-pro'[source]
gcp_conn_id = 'google_cloud_default'[source]
impersonation_chain = None[source]
execute(context)[source]

Derive when creating an operator.

Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
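Counting tokens before a generation request lets a DAG guard against oversized prompts. A sketch with hypothetical project values:

```python
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    CountTokensOperator,
)

# Hypothetical values for illustration only.
count_tokens = CountTokensOperator(
    task_id="count_tokens",
    project_id="my-gcp-project",
    location="us-central1",
    contents=["How many tokens does this prompt consume?"],
    pretrained_model="gemini-pro",
)
```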

class airflow.providers.google.cloud.operators.vertex_ai.generative_model.RunEvaluationOperator(*, project_id, location, pretrained_model, eval_dataset, metrics, experiment_name, experiment_run_name, prompt_template, generation_config=None, safety_settings=None, system_instruction=None, tools=None, gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)[source]

Bases: airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator

Use the Rapid Evaluation API to evaluate a model.

Parameters:
  • project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.

  • location (str) – Required. The ID of the Google Cloud location that the service belongs to.

  • pretrained_model (str) – Required. A pre-trained model optimized for performing natural language tasks such as classification, summarization, extraction, content creation, and ideation.

  • eval_dataset (dict) – Required. A fixed dataset for evaluating a model against. Adheres to Rapid Evaluation API.

  • metrics (list) – Required. A list of evaluation metrics to be used in the experiment. Adheres to Rapid Evaluation API.

  • experiment_name (str) – Required. The name of the evaluation experiment.

  • experiment_run_name (str) – Required. The specific run name or ID for this experiment.

  • prompt_template (str) – Required. The template used to format the model’s prompts during evaluation. Adheres to Rapid Evaluation API.

  • generation_config (dict | None) – Optional. A dictionary containing generation parameters for the model.

  • safety_settings (dict | None) – Optional. A dictionary specifying harm category thresholds for blocking model outputs.

  • system_instruction (str | None) – Optional. An instruction given to the model to guide its behavior.

  • tools (list | None) – Optional. A list of tools available to the model during evaluation, such as a data store.

  • gcp_conn_id (str) – The connection ID to use connecting to Google Cloud.

  • impersonation_chain (str | collections.abc.Sequence[str] | None) – Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated).

template_fields = ('location', 'project_id', 'impersonation_chain', 'pretrained_model', 'eval_dataset',...[source]
project_id[source]
location[source]
pretrained_model[source]
eval_dataset[source]
metrics[source]
experiment_name[source]
experiment_run_name[source]
prompt_template[source]
generation_config = None[source]
safety_settings = None[source]
system_instruction = None[source]
tools = None[source]
gcp_conn_id = 'google_cloud_default'[source]
impersonation_chain = None[source]
execute(context)[source]

Derive when creating an operator.

Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
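An evaluation task sketch. The dataset, metric names, and experiment identifiers below are hypothetical placeholders chosen for illustration; `eval_dataset` columns must line up with the placeholders in `prompt_template`:

```python
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    RunEvaluationOperator,
)

# Hypothetical values for illustration only.
run_evaluation = RunEvaluationOperator(
    task_id="run_evaluation",
    project_id="my-gcp-project",
    location="us-central1",
    pretrained_model="gemini-pro",
    eval_dataset={
        "context": ["Airflow schedules and monitors workflows."],
        "instruction": ["Summarize the article."],
    },
    metrics=["exact_match"],  # hypothetical metric name; see the Rapid Evaluation API
    experiment_name="my-eval-experiment",
    experiment_run_name="run-001",
    prompt_template="{instruction} Article: {context} Summary:",
)
```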

class airflow.providers.google.cloud.operators.vertex_ai.generative_model.CreateCachedContentOperator(*, project_id, location, model_name, system_instruction=None, contents=None, ttl_hours=1, display_name=None, gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)[source]

Bases: airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator

Create CachedContent to reduce the cost of requests that contain repeated content with high input token counts.

Parameters:
  • project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.

  • location (str) – Required. The ID of the Google Cloud location that the service belongs to.

  • model_name (str) – Required. The name of the publisher model to use for cached content.

  • system_instruction (str | None) – Developer set system instruction.

  • contents (list | None) – The content to cache.

  • ttl_hours (float) – The TTL for this resource in hours. The expiration time is computed: now + TTL. Defaults to one hour.

  • display_name (str | None) – Optional. The user-generated meaningful display name of the cached content.

  • gcp_conn_id (str) – The connection ID to use connecting to Google Cloud.

  • impersonation_chain (str | collections.abc.Sequence[str] | None) – Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated).

template_fields = ('location', 'project_id', 'impersonation_chain', 'model_name', 'contents', 'system_instruction')[source]
project_id[source]
location[source]
model_name[source]
system_instruction = None[source]
contents = None[source]
ttl_hours = 1[source]
display_name = None[source]
gcp_conn_id = 'google_cloud_default'[source]
impersonation_chain = None[source]
execute(context)[source]

Derive when creating an operator.

Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
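A cache-creation sketch. The model name and cached text are hypothetical placeholders; note that context caching imposes a minimum input size, so the cached contents must be substantial:

```python
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    CreateCachedContentOperator,
)

# Hypothetical values for illustration only.
create_cache = CreateCachedContentOperator(
    task_id="create_cache",
    project_id="my-gcp-project",
    location="us-central1",
    model_name="gemini-1.5-pro-002",
    system_instruction="Answer questions using only the attached corpus.",
    contents=["<large shared context, e.g. a long document, goes here>"],
    ttl_hours=2,  # cache expires two hours after creation
)
```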

class airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateFromCachedContentOperator(*, project_id, location, cached_content_name, contents, generation_config=None, safety_settings=None, gcp_conn_id='google_cloud_default', impersonation_chain=None, **kwargs)[source]

Bases: airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator

Generate a response from CachedContent.

Parameters:
  • project_id (str) – Required. The ID of the Google Cloud project that the service belongs to.

  • location (str) – Required. The ID of the Google Cloud location that the service belongs to.

  • cached_content_name (str) – Required. The name of the cached content resource.

  • contents (list) – Required. The multi-part content of a message that a user or a program gives to the generative model, in order to elicit a specific response.

  • generation_config (dict | None) – Optional. Generation configuration settings.

  • safety_settings (dict | None) – Optional. Per request settings for blocking unsafe content.

  • gcp_conn_id (str) – The connection ID to use connecting to Google Cloud.

  • impersonation_chain (str | collections.abc.Sequence[str] | None) – Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token of the last account in the list, which will be impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated).

template_fields = ('location', 'project_id', 'impersonation_chain', 'cached_content_name', 'contents')[source]
project_id[source]
location[source]
cached_content_name[source]
contents[source]
generation_config = None[source]
safety_settings = None[source]
gcp_conn_id = 'google_cloud_default'[source]
impersonation_chain = None[source]
execute(context)[source]

Derive when creating an operator.

Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
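A sketch pairing this operator with CreateCachedContentOperator, assuming the upstream task's return value (the cached content name) is pushed to XCom under its default key. All identifiers are hypothetical placeholders:

```python
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
    GenerateFromCachedContentOperator,
)

# Hypothetical values; assumes an upstream task with task_id="create_cache"
# returned the cached content resource name.
generate_from_cache = GenerateFromCachedContentOperator(
    task_id="generate_from_cache",
    project_id="my-gcp-project",
    location="us-central1",
    cached_content_name="{{ task_instance.xcom_pull(task_ids='create_cache') }}",
    contents=["What does the cached corpus say about task retries?"],
)
```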
