#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains a Google Cloud Vertex AI hook."""
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform_v1 import ModelServiceClient
from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
if TYPE_CHECKING:
    from google.api_core.operation import Operation
    from google.api_core.retry import Retry
    from google.cloud.aiplatform_v1.services.model_service.pagers import (
        ListModelsPager,
        ListModelVersionsPager,
    )
    from google.cloud.aiplatform_v1.types import Model, model_service
class ModelServiceHook(GoogleBaseHook):
    """Hook for Google Cloud Vertex AI Model Service APIs."""

    def __init__(self, **kwargs):
        """
        Initialize the hook.

        :param kwargs: Keyword arguments forwarded unchanged to :class:`GoogleBaseHook`.
        :raises RuntimeError: If the removed ``delegate_to`` argument is passed with a non-None value.
        """
        if kwargs.get("delegate_to") is not None:
            raise RuntimeError(
                "The `delegate_to` parameter has been deprecated before and finally removed in this version"
                " of Google Provider. You MUST convert it to `impersonate_chain`"
            )
        super().__init__(**kwargs)

    def get_model_service_client(self, region: str | None = None) -> ModelServiceClient:
        """
        Return a ModelServiceClient authenticated with this hook's credentials.

        :param region: The ID of the Google Cloud region. When it is a non-empty value other than
            ``"global"``, the client is pointed at the regional endpoint
            ``{region}-aiplatform.googleapis.com:443``; otherwise default ``ClientOptions`` are used.
        :return: A configured ``ModelServiceClient``.
        """
        if region and region != "global":
            client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
        else:
            client_options = ClientOptions()
        return ModelServiceClient(
            credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
        )

    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
        """
        Wait for a long-running operation to complete.

        :param operation: The long-running operation to wait for.
        :param timeout: Seconds to wait for the result (and, on failure, for the operation's exception).
        :return: Whatever ``operation.result()`` returns.
        :raises AirflowException: If waiting for the result raises; the exception reported by
            ``operation.exception()`` is wrapped.
        """
        # Deliberately a regular method (not a staticmethod): ``self`` is part of the signature, so
        # ``hook.wait_for_operation(op)`` must bind ``op`` to ``operation``, not to ``self``.
        try:
            return operation.result(timeout=timeout)
        except Exception:
            error = operation.exception(timeout=timeout)
            raise AirflowException(error)

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_model(
        self,
        project_id: str,
        region: str,
        model: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Operation:
        """
        Delete a Model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model: Required. The name of the Model resource to be deleted.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The operation returned by ``ModelServiceClient.delete_model``.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model)
        result = client.delete_model(
            request={
                "name": name,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def export_model(
        self,
        project_id: str,
        region: str,
        model: str,
        output_config: model_service.ExportModelRequest.OutputConfig | dict,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Operation:
        """
        Export a trained, exportable Model to a location specified by the user.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model: Required. The resource name of the Model to export.
        :param output_config: Required. The desired output location and configuration.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The operation returned by ``ModelServiceClient.export_model``.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model)
        result = client.export_model(
            request={
                "name": name,
                "output_config": output_config,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def list_models(
        self,
        project_id: str,
        region: str,
        filter: str | None = None,
        page_size: int | None = None,
        page_token: str | None = None,
        read_mask: str | None = None,
        order_by: str | None = None,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> ListModelsPager:
        """
        List Models in a Location.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param filter: An expression for filtering the results of the request. For field names both
            snake_case and camelCase are supported.

            -  ``model`` supports = and !=. ``model`` represents the Model ID, i.e. the last segment of the
               Model's [resource name][google.cloud.aiplatform.v1.Model.name].
            -  ``display_name`` supports = and !=
            -  ``labels`` supports general map functions, that is:

               -  ``labels.key=value`` - key:value equality
               -  ``labels.key:*`` or ``labels:key`` - key existence
               -  A key including a space must be quoted: ``labels."a key"``.

        :param page_size: The standard list page size.
        :param page_token: The standard list page token. Typically obtained via
            [ListModelsResponse.next_page_token][google.cloud.aiplatform.v1.ListModelsResponse.next_page_token]
            of the previous
            [ModelService.ListModels][google.cloud.aiplatform.v1.ModelService.ListModels]
            call.
        :param read_mask: Mask specifying which fields to read.
        :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc"
            after a field name for descending.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The pager returned by ``ModelServiceClient.list_models``.
        """
        client = self.get_model_service_client(region)
        parent = client.common_location_path(project_id, region)
        result = client.list_models(
            request={
                "parent": parent,
                "filter": filter,
                "page_size": page_size,
                "page_token": page_token,
                "read_mask": read_mask,
                "order_by": order_by,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def upload_model(
        self,
        project_id: str,
        region: str,
        model: Model | dict,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Operation:
        """
        Upload a Model artifact into Vertex AI.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model: Required. The Model to create.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The operation returned by ``ModelServiceClient.upload_model``.
        """
        client = self.get_model_service_client(region)
        parent = client.common_location_path(project_id, region)
        result = client.upload_model(
            request={
                "parent": parent,
                "model": model,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def list_model_versions(
        self,
        region: str,
        project_id: str,
        model_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> ListModelVersionsPager:
        """
        List all versions of the existing Model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to output versions for.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The pager returned by ``ModelServiceClient.list_model_versions``.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        result = client.list_model_versions(
            request={
                "name": name,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_model_version(
        self,
        region: str,
        project_id: str,
        model_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Operation:
        """
        Delete a version of the Model. A version cannot be deleted while it is the default.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model in which to delete version.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The operation returned by ``ModelServiceClient.delete_model_version``.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        result = client.delete_model_version(
            request={
                "name": name,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def get_model(
        self,
        region: str,
        project_id: str,
        model_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Model:
        """
        Retrieve a Model of a specific name and version. If no version is specified, the default is retrieved.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to retrieve.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The Model returned by ``ModelServiceClient.get_model``.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        result = client.get_model(
            request={
                "name": name,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def set_version_as_default(
        self,
        region: str,
        model_id: str,
        project_id: str,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Model:
        """
        Set the given version of the Model as default by merging the ``default`` alias into it.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to set as default.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The Model returned by ``ModelServiceClient.merge_version_aliases``.
        """
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        result = client.merge_version_aliases(
            request={
                "name": name,
                "version_aliases": ["default"],
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def add_version_aliases(
        self,
        region: str,
        model_id: str,
        project_id: str,
        version_aliases: Sequence[str],
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Model:
        """
        Add a list of version aliases to a specific version of the Model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to add aliases to.
        :param version_aliases: Required. List of version aliases to be added for specific version.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The Model returned by ``ModelServiceClient.merge_version_aliases``.
        :raises AirflowException: If any alias starts with ``-``.
        """
        # Validate caller input before building an authenticated client. A leading "-" is how
        # delete_version_aliases marks aliases for removal, so it is not allowed here.
        if any(alias.startswith("-") for alias in version_aliases):
            raise AirflowException("Name of the alias can't start with '-'")
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        result = client.merge_version_aliases(
            request={
                "name": name,
                "version_aliases": version_aliases,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_version_aliases(
        self,
        region: str,
        model_id: str,
        project_id: str,
        version_aliases: Sequence[str],
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
    ) -> Model:
        """
        Delete a list of version aliases from a specific version of the Model.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param region: Required. The ID of the Google Cloud region that the service belongs to.
        :param model_id: Required. The ID of the Model to delete aliases from.
        :param version_aliases: Required. List of version aliases to be deleted from specific version.
        :param retry: Designation of what errors, if any, should be retried.
        :param timeout: The timeout for this request.
        :param metadata: Strings which should be sent along with the request as metadata.
        :return: The Model returned by ``ModelServiceClient.merge_version_aliases``.
        :raises AirflowException: If ``default`` is among the aliases to delete.
        """
        # Validate caller input before building an authenticated client.
        if "default" in version_aliases:
            raise AirflowException(
                "Default alias can't be deleted. "
                "Make sure to assign this alias to another version before deletion"
            )
        client = self.get_model_service_client(region)
        name = client.model_path(project_id, region, model_id)
        # merge_version_aliases removes aliases that are sent with a leading "-".
        aliases_for_delete = ["-" + alias for alias in version_aliases]
        result = client.merge_version_aliases(
            request={
                "name": name,
                "version_aliases": aliases_for_delete,
            },
            retry=retry,
            timeout=timeout,
            metadata=metadata,
        )
        return result