#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import time
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any
from azure.batch import BatchServiceClient, batch_auth, models as batch_models
from airflow.exceptions import AirflowException
from airflow.providers.common.compat.sdk import BaseHook
from airflow.providers.microsoft.azure.utils import (
    AzureIdentityCredentialAdapter,
    add_managed_identity_connection_widgets,
    get_field,
)
from airflow.utils import timezone
if TYPE_CHECKING:
    from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter
class AzureBatchHook(BaseHook):
    """
    Hook for Azure Batch APIs.

    :param azure_batch_conn_id: :ref:`Azure Batch connection id<howto/connection:azure_batch>`
        of a service principal which will be used to start the container instance.
    """

    conn_name_attr = "azure_batch_conn_id"
    default_conn_name = "azure_batch_default"
    conn_type = "azure_batch"
    hook_name = "Azure Batch Service"

    # NOTE(review): the previous listing stacked ``@classmethod`` and
    # ``@add_managed_identity_connection_widgets`` on top of this method's own
    # ``@classmethod``. Those two decorators presumably belong to a separate
    # connection-form-widgets classmethod that is not part of this class body
    # as seen here -- confirm, and restore that method alongside them if so.
    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """
        Return custom field behaviour.

        :return: a mapping that hides the ``schema``, ``port``, ``host`` and ``extra``
            fields and relabels ``login`` / ``password`` as the Batch account name /
            access key.
        """
        return {
            "hidden_fields": ["schema", "port", "host", "extra"],
            "relabeling": {
                "login": "Batch Account Name",
                "password": "Batch Account Access Key",
            },
        }

    def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_batch_conn_id

    def _get_field(self, extras, name):
        """
        Look up ``name`` in ``extras`` through the provider's ``get_field`` helper.

        :param extras: the connection's extra dictionary
        :param name: the field name to look up
        :return: whatever ``get_field`` returns for this hook's connection id and type
        """
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    @cached_property
    def connection(self) -> BatchServiceClient:
        """Get the Batch client connection (cached)."""
        return self.get_conn()

    def get_conn(self) -> BatchServiceClient:
        """
        Get the Batch client connection.

        Shared-key credentials are used when the connection has both a login and a
        password; otherwise an ``AzureIdentityCredentialAdapter`` is built from the
        optional ``managed_identity_client_id`` / ``workload_identity_tenant_id`` extras.

        :return: Azure Batch client
        :raises AirflowException: if the connection has no ``account_url`` field
        """
        conn = self.get_connection(self.conn_id)
        extras = conn.extra_dejson

        batch_account_url = self._get_field(extras, "account_url")
        if not batch_account_url:
            raise AirflowException("Batch Account URL parameter is missing.")

        credentials: batch_auth.SharedKeyCredentials | AzureIdentityCredentialAdapter
        if conn.login and conn.password:
            credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password)
        else:
            credentials = AzureIdentityCredentialAdapter(
                None,
                resource_id="https://batch.core.windows.net/.default",
                managed_identity_client_id=extras.get("managed_identity_client_id"),
                workload_identity_tenant_id=extras.get("workload_identity_tenant_id"),
            )
        return BatchServiceClient(credentials, batch_url=batch_account_url)

    def create_pool(self, pool: PoolAddParameter) -> None:
        """
        Create a pool if not already existing.

        A ``PoolExists`` error from the service is logged and swallowed; any other
        ``BatchErrorException`` is re-raised.

        :param pool: the pool object to create
        """
        try:
            self.log.info("Attempting to create a pool: %s", pool.id)
            self.connection.pool.add(pool)
            self.log.info("Created pool: %s", pool.id)
        except batch_models.BatchErrorException as err:
            if not err.error or err.error.code != "PoolExists":
                raise
            self.log.info("Pool %s already exists", pool.id)

    def _get_latest_verified_image_vm_and_sku(
        self,
        publisher: str | None = None,
        offer: str | None = None,
        sku_starts_with: str | None = None,
    ) -> tuple:
        """
        Get latest verified image vm and sku.

        Only images reported by the service with ``verificationType eq 'verified'``
        are considered, and the first one that satisfies every supplied criterion is
        returned, in whatever order the service lists them. A criterion left as
        ``None`` places no constraint on the image.

        :param publisher: The publisher of the Azure Virtual Machines Marketplace Image.
            For example, Canonical or MicrosoftWindowsServer. Compared case-insensitively.
        :param offer: The offer type of the Azure Virtual Machines Marketplace Image.
            For example, UbuntuServer or WindowsServer. Compared case-insensitively.
        :param sku_starts_with: The start name of the sku to search
        :return: a ``(node_agent_sku_id, image_reference)`` tuple
        :raises IndexError: if no verified image satisfies the criteria
        """
        options = batch_models.AccountListSupportedImagesOptions(filter="verificationType eq 'verified'")
        images = self.connection.account.list_supported_images(account_list_supported_images_options=options)

        # The image side is lower-cased for comparison, so the requested values must
        # be lower-cased too; otherwise "Canonical" could never match "canonical".
        wanted_publisher = publisher.lower() if publisher is not None else None
        wanted_offer = offer.lower() if offer is not None else None

        for image in images:
            image_ref = image.image_reference
            if wanted_publisher is not None and image_ref.publisher.lower() != wanted_publisher:
                continue
            if wanted_offer is not None and image_ref.offer.lower() != wanted_offer:
                continue
            if sku_starts_with is not None and not image_ref.sku.startswith(sku_starts_with):
                continue
            # pick first
            return image.node_agent_sku_id, image_ref

        # Kept as IndexError so callers that guarded the former empty-list lookup
        # keep working; the message now says what was being searched for.
        raise IndexError(
            "No verified Azure Batch image found for "
            f"publisher={publisher!r}, offer={offer!r}, sku_starts_with={sku_starts_with!r}"
        )

    def wait_for_all_node_state(self, pool_id: str, node_state: set) -> list:
        """
        Wait for all nodes in a pool to reach given states.

        Polls every 10 seconds and has no timeout of its own.

        :param pool_id: A string that identifies the pool
        :param node_state: A set of batch_models.ComputeNodeState
        :return: the list of compute nodes once there are at least
            ``target_dedicated_nodes`` of them and each is in one of ``node_state``
        :raises RuntimeError: if the pool reports resize errors
        """
        self.log.info("waiting for all nodes in pool %s to reach one of: %s", pool_id, node_state)
        while True:
            # refresh pool to ensure that there is no resize error
            pool = self.connection.pool.get(pool_id)
            if pool.resize_errors is not None:
                resize_errors = "\n".join(repr(e) for e in pool.resize_errors)
                raise RuntimeError(f"resize error encountered for pool {pool.id}:\n{resize_errors}")

            nodes = list(self.connection.compute_node.list(pool.id))
            enough_nodes = len(nodes) >= pool.target_dedicated_nodes
            if enough_nodes and all(node.state in node_state for node in nodes):
                return nodes

            # Allow the timeout to be controlled by the AzureBatchOperator
            # specified timeout. This way we don't interrupt a startTask inside
            # the pool
            time.sleep(10)

    def create_job(self, job: JobAddParameter) -> None:
        """
        Create a job in the pool.

        A ``JobExists`` error from the service is logged and swallowed; any other
        ``BatchErrorException`` is re-raised.

        :param job: The job object to create
        """
        try:
            self.connection.job.add(job)
            self.log.info("Job %s created", job.id)
        except batch_models.BatchErrorException as err:
            if not err.error or err.error.code != "JobExists":
                raise
            self.log.info("Job %s already exists", job.id)

    def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None:
        """
        Add a single task to given job if it doesn't exist.

        A ``TaskExists`` error from the service is logged and swallowed; any other
        ``BatchErrorException`` is re-raised.

        :param job_id: A string that identifies the given job
        :param task: The task to add
        """
        try:
            self.connection.task.add(job_id=job_id, task=task)
        except batch_models.BatchErrorException as err:
            if not err.error or err.error.code != "TaskExists":
                raise
            self.log.info("Task %s already exists", task.id)

    def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batch_models.CloudTask]:
        """
        Wait for tasks in a particular job to complete.

        Polls every 15 seconds until every task of the job is in the completed state.

        :param job_id: A string that identifies the job
        :param timeout: The amount of time to wait before timing out in minutes
        :return: the tasks whose execution result is a failure (empty if none failed)
        :raises TimeoutError: if the tasks have not all completed within ``timeout`` minutes
        """
        timeout_time = timezone.utcnow() + timedelta(minutes=timeout)
        while timezone.utcnow() < timeout_time:
            tasks = list(self.connection.task.list(job_id))
            incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed]
            if not incomplete_tasks:
                # detect if any task in job has failed
                return [
                    task
                    for task in tasks
                    if task.execution_info.result == batch_models.TaskExecutionResult.failure
                ]
            for task in incomplete_tasks:
                self.log.info("Waiting for %s to complete, currently on %s state", task.id, task.state)
            time.sleep(15)
        raise TimeoutError("Timed out waiting for tasks to complete")

    def test_connection(self) -> tuple[bool, str]:
        """
        Test a configured Azure Batch connection.

        :return: ``(True, message)`` on success, or ``(False, error text)`` if anything
            raised while building the client or fetching the first job
        """
        try:
            # Attempt to list existing jobs under the configured Batch account and retrieve
            # the first in the returned iterator. The Azure Batch API does allow for creation of a
            # BatchServiceClient with incorrect values but then will fail properly once items are
            # retrieved using the client. We need to _actually_ try to retrieve an object to properly
            # test the connection.
            next(self.get_conn().job.list(), None)
        except Exception as e:
            return False, str(e)
        return True, "Successfully connected to Azure Batch."