# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any

from azure.core.exceptions import ServiceRequestError
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.synapse.artifacts import ArtifactsClient
from azure.synapse.spark import SparkClient

from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.providers.common.compat.sdk import BaseHook
from airflow.providers.microsoft.azure.utils import (
    add_managed_identity_connection_widgets,
    get_field,
    get_sync_default_azure_credential,
)

if TYPE_CHECKING:
    from azure.synapse.artifacts.models import CreateRunResponse, PipelineRun
    from azure.synapse.spark.models import SparkBatchJobOptions

Credentials = ClientSecretCredential | DefaultAzureCredential


class AzureSynapseSparkBatchRunStatus:
    """Azure Synapse Spark Job operation statuses."""

    NOT_STARTED = "not_started"
    STARTING = "starting"
    RUNNING = "running"
    IDLE = "idle"
    BUSY = "busy"
    SHUTTING_DOWN = "shutting_down"
    ERROR = "error"
    DEAD = "dead"
    KILLED = "killed"
    SUCCESS = "success"

    TERMINAL_STATUSES = {SUCCESS, DEAD, KILLED, ERROR}
 

class AzureSynapseHook(BaseHook):
    """
    A hook to interact with Azure Synapse.

    :param azure_synapse_conn_id: The :ref:`Azure Synapse connection id<howto/connection:synapse>`.
    :param spark_pool: The Apache Spark pool used to submit the job.
    """

    conn_type: str = "azure_synapse"
    conn_name_attr: str = "azure_synapse_conn_id"
    default_conn_name: str = "azure_synapse_default"
    hook_name: str = "Azure Synapse"

    @classmethod
    @add_managed_identity_connection_widgets
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return connection widgets to add to connection form."""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
            "subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()),
        }

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom field behaviour."""
        return {
            "hidden_fields": ["schema", "port", "extra"],
            "relabeling": {
                "login": "Client ID",
                "password": "Secret",
                "host": "Synapse Workspace URL",
            },
        }

    def __init__(self, azure_synapse_conn_id: str = default_conn_name, spark_pool: str = ""):
        self.job_id: int | None = None
        self._conn: SparkClient | None = None
        self.conn_id = azure_synapse_conn_id
        self.spark_pool = spark_pool
        super().__init__()

    def _get_field(self, extras, name):
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    def get_conn(self) -> SparkClient:
        if self._conn is not None:
            return self._conn
        conn = self.get_connection(self.conn_id)
        extras = conn.extra_dejson
        tenant = self._get_field(extras, "tenantId")
        spark_pool = self.spark_pool
        livy_api_version = "2022-02-22-preview"
        subscription_id = self._get_field(extras, "subscriptionId")
        if not subscription_id:
            raise ValueError("A Subscription ID is required to connect to Azure Synapse.")
        credential: Credentials
        if conn.login is not None and conn.password is not None:
            if not tenant:
                raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")
            credential = ClientSecretCredential(
                client_id=conn.login, client_secret=conn.password, tenant_id=tenant
            )
        else:
            managed_identity_client_id = self._get_field(extras, "managed_identity_client_id")
            workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id")
            credential = get_sync_default_azure_credential(
                managed_identity_client_id=managed_identity_client_id,
                workload_identity_tenant_id=workload_identity_tenant_id,
            )
        self._conn = self._create_client(credential, conn.host, spark_pool, livy_api_version, subscription_id)
        return self._conn 
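
    # ``get_conn`` picks its credential from the Airflow connection: when ``login``
    # and ``password`` are set it uses ClientSecretCredential (which also requires a
    # ``tenantId`` extra), otherwise it falls back to DefaultAzureCredential. A hedged
    # sketch of a matching connection, with illustrative values only:
    #
    #   Conn Id:   azure_synapse_default
    #   Conn Type: azure_synapse
    #   Host:      https://myworkspace.dev.azuresynapse.net   (Synapse workspace URL)
    #   Login:     <client id>
    #   Password:  <client secret>
    #   Extra:     {"tenantId": "<tenant id>", "subscriptionId": "<subscription id>"}
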
    @staticmethod
    def _create_client(credential: Credentials, host, spark_pool, livy_api_version, subscription_id: str):
        return SparkClient(
            credential=credential,
            endpoint=host,
            spark_pool_name=spark_pool,
            livy_api_version=livy_api_version,
            subscription_id=subscription_id,
        )

    def run_spark_job(
        self,
        payload: SparkBatchJobOptions,
    ):
        """
        Run a job in an Apache Spark pool.
        :param payload: Livy compatible payload which represents the spark job that a user wants to submit.
        """
        job = self.get_conn().spark_batch.create_spark_batch_job(payload)
        self.job_id = job.id
        return job 
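
    # A minimal usage sketch (hedged: the connection id, pool name, and ABFSS path
    # are illustrative assumptions, not values defined in this module):
    #
    #   from azure.synapse.spark.models import SparkBatchJobOptions
    #
    #   hook = AzureSynapseHook(azure_synapse_conn_id="azure_synapse_default", spark_pool="mypool")
    #   payload = SparkBatchJobOptions(
    #       name="wordcount",
    #       file="abfss://jobs@mystorage.dfs.core.windows.net/wordcount.py",
    #       driver_memory="4g",
    #       driver_cores=2,
    #       executor_memory="4g",
    #       executor_cores=2,
    #       executor_count=2,
    #   )
    #   job = hook.run_spark_job(payload)  # also records job.id on hook.job_id
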

    def get_job_run_status(self):
        """Get the job run status."""
        job_run_status = self.get_conn().spark_batch.get_spark_batch_job(batch_id=self.job_id).state
        return job_run_status 

    def wait_for_job_run_status(
        self,
        job_id: int | None,
        expected_statuses: str | set[str],
        check_interval: int = 60,
        timeout: int = 60 * 60 * 24 * 7,
    ) -> bool:
        """
        Wait for a job run to match an expected status.

        :param job_id: The job run identifier.
        :param expected_statuses: The desired status(es) to check against a job run's current status.
        :param check_interval: Time in seconds to check on a job run's status.
        :param timeout: Time in seconds to wait for a job to reach a terminal status or the expected
            status.
        """
        job_run_status = self.get_job_run_status()
        start_time = time.monotonic()
        while (
            job_run_status not in AzureSynapseSparkBatchRunStatus.TERMINAL_STATUSES
            and job_run_status not in expected_statuses
        ):
            # Check if the job-run duration has exceeded the ``timeout`` configured.
            if start_time + timeout < time.monotonic():
                raise AirflowTaskTimeout(
                    f"Job {job_id} has not reached a terminal status after {timeout} seconds."
                )
            # Wait to check the status of the job run based on the ``check_interval`` configured.
            self.log.info("Sleeping for %s seconds", check_interval)
            time.sleep(check_interval)
            job_run_status = self.get_job_run_status()
            self.log.info("Current spark job run status is %s", job_run_status)
        return job_run_status in expected_statuses 
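
    # For example (hedged sketch; continues the submission sketch above):
    #
    #   succeeded = hook.wait_for_job_run_status(
    #       job_id=hook.job_id,
    #       expected_statuses={AzureSynapseSparkBatchRunStatus.SUCCESS},
    #       check_interval=30,
    #   )
    #   if not succeeded:
    #       raise AirflowException(f"Spark job {hook.job_id} did not succeed.")
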

    def cancel_job_run(
        self,
        job_id: int,
    ) -> None:
        """
        Cancel the Spark job run.

        :param job_id: The Synapse Spark job identifier.
        """
        self.get_conn().spark_batch.cancel_spark_batch_job(job_id) 
 

class AzureSynapsePipelineRunStatus:
    """Azure Synapse pipeline operation statuses."""

    QUEUED = "Queued"
    IN_PROGRESS = "InProgress"
    SUCCEEDED = "Succeeded"
    FAILED = "Failed"
    CANCELING = "Canceling"
    CANCELLED = "Cancelled"

    TERMINAL_STATUSES = {CANCELLED, FAILED, SUCCEEDED}
    FAILURE_STATES = {FAILED, CANCELLED}
 

class AzureSynapsePipelineRunException(AirflowException):
    """An exception that indicates a pipeline run failed to complete."""


class BaseAzureSynapseHook(BaseHook):
    """
    A base hook class to create session and connection to Azure Synapse using connection id.

    :param azure_synapse_conn_id: The :ref:`Azure Synapse connection id<howto/connection:synapse>`.
    """

    conn_type: str = "azure_synapse"
    conn_name_attr: str = "azure_synapse_conn_id"
    default_conn_name: str = "azure_synapse_default"
    hook_name: str = "Azure Synapse"

    @classmethod
    @add_managed_identity_connection_widgets
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return connection widgets to add to connection form."""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
        }

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom field behaviour."""
        return {
            "hidden_fields": ["schema", "port", "extra"],
            "relabeling": {
                "login": "Client ID",
                "password": "Secret",
                "host": "Synapse Workspace URL",
            },
        }

    def __init__(self, azure_synapse_conn_id: str = default_conn_name, **kwargs) -> None:
        super().__init__(**kwargs)
        self.conn_id = azure_synapse_conn_id

    def _get_field(self, extras: dict, field_name: str) -> str:
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=field_name,
        ) 


class AzureSynapsePipelineHook(BaseAzureSynapseHook):
    """
    A hook to interact with Azure Synapse Pipeline.

    :param azure_synapse_conn_id: The :ref:`Azure Synapse connection id<howto/connection:synapse>`.
    :param azure_synapse_workspace_dev_endpoint: The Azure Synapse Workspace development endpoint.
    """

    default_conn_name: str = "azure_synapse_connection"

    def __init__(
        self,
        azure_synapse_workspace_dev_endpoint: str,
        azure_synapse_conn_id: str = default_conn_name,
        **kwargs,
    ):
        self._conn: ArtifactsClient | None = None
        self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint
        super().__init__(azure_synapse_conn_id=azure_synapse_conn_id, **kwargs)

    def _get_field(self, extras, name):
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    def get_conn(self) -> ArtifactsClient:
        if self._conn is not None:
            return self._conn
        conn = self.get_connection(self.conn_id)
        extras = conn.extra_dejson
        tenant = self._get_field(extras, "tenantId")
        credential: Credentials
        if not conn.login or not conn.password:
            managed_identity_client_id = self._get_field(extras, "managed_identity_client_id")
            workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id")
            credential = get_sync_default_azure_credential(
                managed_identity_client_id=managed_identity_client_id,
                workload_identity_tenant_id=workload_identity_tenant_id,
            )
        else:
            if not tenant:
                raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")
            credential = ClientSecretCredential(
                client_id=conn.login, client_secret=conn.password, tenant_id=tenant
            )
        self._conn = self._create_client(credential, self.azure_synapse_workspace_dev_endpoint)
        if self._conn is not None:
            return self._conn
        raise ValueError("Failed to create ArtifactsClient")

    @staticmethod
    def _create_client(credential: Credentials, endpoint: str) -> ArtifactsClient:
        return ArtifactsClient(credential=credential, endpoint=endpoint)

    def run_pipeline(self, pipeline_name: str, **config: Any) -> CreateRunResponse:
        """
        Run a Synapse pipeline.
        :param pipeline_name: The pipeline name.
        :param config: Extra parameters for the Synapse Artifact Client.
        :return: The pipeline run Id.
        """
        return self.get_conn().pipeline.create_pipeline_run(pipeline_name, **config) 
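
    # A hedged usage sketch (the endpoint, connection id, pipeline name, and
    # parameters are illustrative assumptions):
    #
    #   hook = AzureSynapsePipelineHook(
    #       azure_synapse_workspace_dev_endpoint="https://myworkspace.dev.azuresynapse.net",
    #       azure_synapse_conn_id="azure_synapse_connection",
    #   )
    #   run = hook.run_pipeline("my_pipeline", parameters={"env": "dev"})
    #   run_id = run.run_id
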

    def get_pipeline_run(self, run_id: str) -> PipelineRun:
        """
        Get the pipeline run.

        :param run_id: The pipeline run identifier.
        :return: The pipeline run.
        """
        return self.get_conn().pipeline_run.get_pipeline_run(run_id=run_id) 

    def get_pipeline_run_status(self, run_id: str) -> str:
        """
        Get a pipeline run's current status.

        :param run_id: The pipeline run identifier.
        :return: The status of the pipeline run.
        """
        pipeline_run_status = self.get_pipeline_run(
            run_id=run_id,
        ).status
        return str(pipeline_run_status) 

    def refresh_conn(self) -> ArtifactsClient:
        """Discard the cached client and create a fresh connection."""
        self._conn = None
        return self.get_conn() 

    def wait_for_pipeline_run_status(
        self,
        run_id: str,
        expected_statuses: str | set[str],
        check_interval: int = 60,
        timeout: int = 60 * 60 * 24 * 7,
    ) -> bool:
        """
        Wait for a pipeline run to match an expected status.

        :param run_id: The pipeline run identifier.
        :param expected_statuses: The desired status(es) to check against a pipeline run's current status.
        :param check_interval: Time in seconds to check on a pipeline run's status.
        :param timeout: Time in seconds to wait for a pipeline to reach a terminal status or the expected
            status.
        :return: Boolean indicating if the pipeline run has reached the ``expected_status``.
        """
        pipeline_run_status = self.get_pipeline_run_status(run_id=run_id)
        executed_after_token_refresh = True
        start_time = time.monotonic()
        while (
            pipeline_run_status not in AzureSynapsePipelineRunStatus.TERMINAL_STATUSES
            and pipeline_run_status not in expected_statuses
        ):
            if start_time + timeout < time.monotonic():
                raise AzureSynapsePipelineRunException(
                    f"Pipeline run {run_id} has not reached a terminal status after {timeout} seconds."
                )
            # Wait to check the status of the pipeline run based on the ``check_interval`` configured.
            time.sleep(check_interval)
            try:
                pipeline_run_status = self.get_pipeline_run_status(run_id=run_id)
                executed_after_token_refresh = True
            except ServiceRequestError:
                if executed_after_token_refresh:
                    # Treat the failure as a possibly expired token: rebuild the
                    # client once, then re-raise if the next attempt fails too.
                    self.refresh_conn()
                    executed_after_token_refresh = False
                else:
                    raise
        return pipeline_run_status in expected_statuses 
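
    # A hedged sketch tying the pieces together (continues ``run_id`` from the
    # ``run_pipeline`` sketch above):
    #
    #   succeeded = hook.wait_for_pipeline_run_status(
    #       run_id=run_id,
    #       expected_statuses={AzureSynapsePipelineRunStatus.SUCCEEDED},
    #       check_interval=30,
    #   )
    #   if not succeeded:
    #       raise AzureSynapsePipelineRunException(f"Pipeline run {run_id} did not succeed.")
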

    def cancel_run_pipeline(self, run_id: str) -> None:
        """
        Cancel the pipeline run.

        :param run_id: The pipeline run identifier.
        """
        self.get_conn().pipeline_run.cancel_pipeline_run(run_id)