
Source code for airflow.providers.google.cloud.operators.vertex_ai.auto_ml

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""This module contains Google Vertex AI operators."""

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING

from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline

from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
from airflow.providers.google.cloud.links.vertex_ai import (
    VertexAIModelLink,
    VertexAITrainingLink,
    VertexAITrainingPipelinesLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator

if TYPE_CHECKING:
    from google.api_core.retry import Retry

    from airflow.utils.context import Context


class AutoMLTrainingJobBaseOperator(GoogleCloudBaseOperator):
    """The base class for operators that launch AutoML jobs on VertexAI."""

    def __init__(
        self,
        *,
        project_id: str,
        region: str,
        display_name: str,
        labels: dict[str, str] | None = None,
        parent_model: str | None = None,
        is_default_version: bool | None = None,
        model_version_aliases: list[str] | None = None,
        model_version_description: str | None = None,
        training_encryption_spec_key_name: str | None = None,
        model_encryption_spec_key_name: str | None = None,
        # RUN
        training_fraction_split: float | None = None,
        test_fraction_split: float | None = None,
        model_display_name: str | None = None,
        model_labels: dict[str, str] | None = None,
        sync: bool = True,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.region = region
        self.display_name = display_name
        self.labels = labels
        self.parent_model = parent_model
        self.is_default_version = is_default_version
        self.model_version_aliases = model_version_aliases
        self.model_version_description = model_version_description
        self.training_encryption_spec_key_name = training_encryption_spec_key_name
        self.model_encryption_spec_key_name = model_encryption_spec_key_name
        # START Run param
        self.training_fraction_split = training_fraction_split
        self.test_fraction_split = test_fraction_split
        self.model_display_name = model_display_name
        self.model_labels = model_labels
        self.sync = sync
        # END Run param
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.hook: AutoMLHook | None = None

    def on_kill(self) -> None:
        """Act as a callback called when the operator is killed; cancel any running job."""
        if self.hook:
            self.hook.cancel_auto_ml_job()

class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create AutoML Forecasting Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
        "display_name",
        "model_display_name",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        target_column: str,
        time_column: str,
        time_series_identifier_column: str,
        unavailable_at_forecast_columns: list[str],
        available_at_forecast_columns: list[str],
        forecast_horizon: int,
        data_granularity_unit: str,
        data_granularity_count: int,
        display_name: str,
        model_display_name: str | None = None,
        optimization_objective: str | None = None,
        column_specs: dict[str, str] | None = None,
        column_transformations: list[dict[str, dict[str, str]]] | None = None,
        validation_fraction_split: float | None = None,
        predefined_split_column_name: str | None = None,
        weight_column: str | None = None,
        time_series_attribute_columns: list[str] | None = None,
        context_window: int | None = None,
        export_evaluated_data_items: bool = False,
        export_evaluated_data_items_bigquery_destination_uri: str | None = None,
        export_evaluated_data_items_override_destination: bool = False,
        quantiles: list[float] | None = None,
        validation_options: str | None = None,
        budget_milli_node_hours: int = 1000,
        region: str,
        impersonation_chain: str | Sequence[str] | None = None,
        parent_model: str | None = None,
        window_stride_length: int | None = None,
        window_max_count: int | None = None,
        holiday_regions: list[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            display_name=display_name,
            model_display_name=model_display_name,
            region=region,
            impersonation_chain=impersonation_chain,
            parent_model=parent_model,
            **kwargs,
        )
        self.dataset_id = dataset_id
        self.target_column = target_column
        self.time_column = time_column
        self.time_series_identifier_column = time_series_identifier_column
        self.unavailable_at_forecast_columns = unavailable_at_forecast_columns
        self.available_at_forecast_columns = available_at_forecast_columns
        self.forecast_horizon = forecast_horizon
        self.data_granularity_unit = data_granularity_unit
        self.data_granularity_count = data_granularity_count
        self.optimization_objective = optimization_objective
        self.column_specs = column_specs
        self.column_transformations = column_transformations
        self.validation_fraction_split = validation_fraction_split
        self.predefined_split_column_name = predefined_split_column_name
        self.weight_column = weight_column
        self.time_series_attribute_columns = time_series_attribute_columns
        self.context_window = context_window
        self.export_evaluated_data_items = export_evaluated_data_items
        self.export_evaluated_data_items_bigquery_destination_uri = (
            export_evaluated_data_items_bigquery_destination_uri
        )
        self.export_evaluated_data_items_override_destination = (
            export_evaluated_data_items_override_destination
        )
        self.quantiles = quantiles
        self.validation_options = validation_options
        self.budget_milli_node_hours = budget_milli_node_hours
        self.window_stride_length = window_stride_length
        self.window_max_count = window_max_count
        self.holiday_regions = holiday_regions

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_forecasting_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TimeSeriesDataset(dataset_name=self.dataset_id),
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
            target_column=self.target_column,
            time_column=self.time_column,
            time_series_identifier_column=self.time_series_identifier_column,
            unavailable_at_forecast_columns=self.unavailable_at_forecast_columns,
            available_at_forecast_columns=self.available_at_forecast_columns,
            forecast_horizon=self.forecast_horizon,
            data_granularity_unit=self.data_granularity_unit,
            data_granularity_count=self.data_granularity_count,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            weight_column=self.weight_column,
            time_series_attribute_columns=self.time_series_attribute_columns,
            context_window=self.context_window,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri
            ),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination
            ),
            quantiles=self.quantiles,
            validation_options=self.validation_options,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
            window_stride_length=self.window_stride_length,
            window_max_count=self.window_max_count,
            holiday_regions=self.holiday_regions,
        )
        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result

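# --- Editor's illustrative usage sketch, not part of the provider source. ---
# A minimal DAG wiring for the operator above; uncomment to adapt. The project,
# region, dataset ID, and column names are placeholder assumptions, and only
# the required arguments are shown.
#
# from datetime import datetime
# from airflow import DAG
#
# with DAG(dag_id="example_vertex_ai_auto_ml", start_date=datetime(2024, 1, 1), schedule=None) as dag:
#     create_forecasting_job = CreateAutoMLForecastingTrainingJobOperator(
#         task_id="create_auto_ml_forecasting_training_job",
#         project_id="my-project",  # placeholder
#         region="us-central1",  # placeholder
#         display_name="auto-ml-forecasting-job",  # placeholder
#         dataset_id="1234567890",  # placeholder Vertex AI TimeSeriesDataset ID
#         target_column="sales",
#         time_column="date",
#         time_series_identifier_column="store_id",
#         unavailable_at_forecast_columns=["sales"],
#         available_at_forecast_columns=["date"],
#         forecast_horizon=30,
#         data_granularity_unit="day",
#         data_granularity_count=1,
#     )
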
class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create Auto ML Image Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        prediction_type: str = "classification",
        multi_label: bool = False,
        model_type: str = "CLOUD",
        base_model: Model | None = None,
        validation_fraction_split: float | None = None,
        training_filter_split: str | None = None,
        validation_filter_split: str | None = None,
        test_filter_split: str | None = None,
        budget_milli_node_hours: int | None = None,
        disable_early_stopping: bool = False,
        region: str,
        impersonation_chain: str | Sequence[str] | None = None,
        parent_model: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
        )
        self.dataset_id = dataset_id
        self.prediction_type = prediction_type
        self.multi_label = multi_label
        self.model_type = model_type
        self.base_model = base_model
        self.validation_fraction_split = validation_fraction_split
        self.training_filter_split = training_filter_split
        self.validation_filter_split = validation_filter_split
        self.test_filter_split = test_filter_split
        self.budget_milli_node_hours = budget_milli_node_hours
        self.disable_early_stopping = disable_early_stopping

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_image_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.ImageDataset(dataset_name=self.dataset_id),
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
            prediction_type=self.prediction_type,
            multi_label=self.multi_label,
            model_type=self.model_type,
            base_model=self.base_model,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            validation_filter_split=self.validation_filter_split,
            test_filter_split=self.test_filter_split,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            sync=self.sync,
        )
        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result

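# Editor's sketch, with the same placeholder DAG context and IDs as the
# forecasting example above. Only dataset_id is operator-specific and
# required; prediction_type, multi_label, and model_type are shown at
# their defaults.
#
#     create_image_job = CreateAutoMLImageTrainingJobOperator(
#         task_id="create_auto_ml_image_training_job",
#         project_id="my-project",  # placeholder
#         region="us-central1",  # placeholder
#         display_name="auto-ml-image-job",  # placeholder
#         dataset_id="1234567890",  # placeholder ImageDataset ID
#         prediction_type="classification",
#         multi_label=False,
#         model_type="CLOUD",
#     )
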
class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create Auto ML Tabular Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        target_column: str,
        optimization_prediction_type: str,
        optimization_objective: str | None = None,
        column_specs: dict[str, str] | None = None,
        column_transformations: list[dict[str, dict[str, str]]] | None = None,
        optimization_objective_recall_value: float | None = None,
        optimization_objective_precision_value: float | None = None,
        validation_fraction_split: float | None = None,
        predefined_split_column_name: str | None = None,
        timestamp_split_column_name: str | None = None,
        weight_column: str | None = None,
        budget_milli_node_hours: int = 1000,
        disable_early_stopping: bool = False,
        export_evaluated_data_items: bool = False,
        export_evaluated_data_items_bigquery_destination_uri: str | None = None,
        export_evaluated_data_items_override_destination: bool = False,
        region: str,
        impersonation_chain: str | Sequence[str] | None = None,
        parent_model: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
        )
        self.dataset_id = dataset_id
        self.target_column = target_column
        self.optimization_prediction_type = optimization_prediction_type
        self.optimization_objective = optimization_objective
        self.column_specs = column_specs
        self.column_transformations = column_transformations
        self.optimization_objective_recall_value = optimization_objective_recall_value
        self.optimization_objective_precision_value = optimization_objective_precision_value
        self.validation_fraction_split = validation_fraction_split
        self.predefined_split_column_name = predefined_split_column_name
        self.timestamp_split_column_name = timestamp_split_column_name
        self.weight_column = weight_column
        self.budget_milli_node_hours = budget_milli_node_hours
        self.disable_early_stopping = disable_early_stopping
        self.export_evaluated_data_items = export_evaluated_data_items
        self.export_evaluated_data_items_bigquery_destination_uri = (
            export_evaluated_data_items_bigquery_destination_uri
        )
        self.export_evaluated_data_items_override_destination = (
            export_evaluated_data_items_override_destination
        )

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        credentials, _ = self.hook.get_credentials_and_project_id()
        self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_tabular_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TabularDataset(
                dataset_name=self.dataset_id,
                project=self.project_id,
                credentials=credentials,
            ),
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
            target_column=self.target_column,
            optimization_prediction_type=self.optimization_prediction_type,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            optimization_objective_recall_value=self.optimization_objective_recall_value,
            optimization_objective_precision_value=self.optimization_objective_precision_value,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            timestamp_split_column_name=self.timestamp_split_column_name,
            weight_column=self.weight_column,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri
            ),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination
            ),
            sync=self.sync,
        )
        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result

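# Editor's sketch (placeholder DAG context and IDs as above). For tabular jobs
# the target column and prediction type are required; "classification" and the
# column name below are assumptions.
#
#     create_tabular_job = CreateAutoMLTabularTrainingJobOperator(
#         task_id="create_auto_ml_tabular_training_job",
#         project_id="my-project",  # placeholder
#         region="us-central1",  # placeholder
#         display_name="auto-ml-tabular-job",  # placeholder
#         dataset_id="1234567890",  # placeholder TabularDataset ID
#         target_column="churned",  # placeholder target column
#         optimization_prediction_type="classification",
#     )
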
class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create Auto ML Video Training job."""

    template_fields = (
        "parent_model",
        "dataset_id",
        "region",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        dataset_id: str,
        prediction_type: str = "classification",
        model_type: str = "CLOUD",
        training_filter_split: str | None = None,
        test_filter_split: str | None = None,
        region: str,
        impersonation_chain: str | Sequence[str] | None = None,
        parent_model: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            region=region, impersonation_chain=impersonation_chain, parent_model=parent_model, **kwargs
        )
        self.dataset_id = dataset_id
        self.prediction_type = prediction_type
        self.model_type = model_type
        self.training_filter_split = training_filter_split
        self.test_filter_split = test_filter_split

    def execute(self, context: Context):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.parent_model = self.parent_model.split("@")[0] if self.parent_model else None
        model, training_id = self.hook.create_auto_ml_video_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.VideoDataset(dataset_name=self.dataset_id),
            prediction_type=self.prediction_type,
            model_type=self.model_type,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            test_filter_split=self.test_filter_split,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
            parent_model=self.parent_model,
            is_default_version=self.is_default_version,
            model_version_aliases=self.model_version_aliases,
            model_version_description=self.model_version_description,
        )
        if model:
            result = Model.to_dict(model)
            model_id = self.hook.extract_model_id(result)
            self.xcom_push(context, key="model_id", value=model_id)
            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        else:
            result = model  # type: ignore
        self.xcom_push(context, key="training_id", value=training_id)
        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
        return result

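# Editor's sketch (placeholder DAG context and IDs as above); prediction_type
# and model_type are shown at their defaults.
#
#     create_video_job = CreateAutoMLVideoTrainingJobOperator(
#         task_id="create_auto_ml_video_training_job",
#         project_id="my-project",  # placeholder
#         region="us-central1",  # placeholder
#         display_name="auto-ml-video-job",  # placeholder
#         dataset_id="1234567890",  # placeholder VideoDataset ID
#         prediction_type="classification",
#         model_type="CLOUD",
#     )
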
class DeleteAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
    """
    Delete an AutoML training job.

    Can be used with AutoMLForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTabularTrainingJob,
    AutoMLTextTrainingJob, or AutoMLVideoTrainingJob.
    """

    template_fields = ("training_pipeline_id", "region", "project_id", "impersonation_chain")

    def __init__(
        self,
        *,
        training_pipeline_id: str,
        region: str,
        project_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.training_pipeline_id = training_pipeline_id
        self.region = region
        self.project_id = project_id
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        try:
            self.log.info("Deleting Auto ML training pipeline: %s", self.training_pipeline_id)
            training_pipeline_operation = hook.delete_training_pipeline(
                training_pipeline=self.training_pipeline_id,
                region=self.region,
                project_id=self.project_id,
                retry=self.retry,
                timeout=self.timeout,
                metadata=self.metadata,
            )
            hook.wait_for_operation(timeout=self.timeout, operation=training_pipeline_operation)
            self.log.info("Training pipeline was deleted.")
        except NotFound:
            self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline_id)

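# Editor's sketch: the create operators above push the pipeline ID to XCom
# under the "training_id" key, and "training_pipeline_id" is a templated
# field, so a cleanup task can be chained after a create task. The task IDs
# below are the placeholders used in the sketches above.
#
#     delete_training_job = DeleteAutoMLTrainingJobOperator(
#         task_id="delete_auto_ml_training_job",
#         project_id="my-project",  # placeholder
#         region="us-central1",  # placeholder
#         training_pipeline_id=(
#             "{{ task_instance.xcom_pull("
#             "task_ids='create_auto_ml_forecasting_training_job', key='training_id') }}"
#         ),
#     )
#     create_forecasting_job >> delete_training_job
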
class ListAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
    """
    List an AutoML training job.

    Can be used with AutoMLForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTabularTrainingJob,
    AutoMLTextTrainingJob, or AutoMLVideoTrainingJob in a Location.
    """

    template_fields = (
        "region",
        "project_id",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        region: str,
        project_id: str,
        page_size: int | None = None,
        page_token: str | None = None,
        filter: str | None = None,
        read_mask: str | None = None,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.region = region
        self.project_id = project_id
        self.page_size = page_size
        self.page_token = page_token
        self.filter = filter
        self.read_mask = read_mask
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        results = hook.list_training_pipelines(
            region=self.region,
            project_id=self.project_id,
            page_size=self.page_size,
            page_token=self.page_token,
            filter=self.filter,
            read_mask=self.read_mask,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
        return [TrainingPipeline.to_dict(result) for result in results]

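# Editor's sketch (placeholder project/region as above). The operator returns
# the location's training pipelines as a list of dicts, so the result lands
# in XCom for downstream tasks.
#
#     list_training_jobs = ListAutoMLTrainingJobOperator(
#         task_id="list_auto_ml_training_jobs",
#         project_id="my-project",  # placeholder
#         region="us-central1",  # placeholder
#     )
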
