Source code for airflow.providers.microsoft.azure.hooks.synapse

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any, Union

from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.synapse.spark import SparkClient

from airflow.exceptions import AirflowTaskTimeout
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field

if TYPE_CHECKING:
    from azure.synapse.spark.models import SparkBatchJobOptions

# Credential objects accepted when building the Synapse ``SparkClient``: either an
# explicit service principal (client ID + secret + tenant) or ``DefaultAzureCredential``.
Credentials = Union[ClientSecretCredential, DefaultAzureCredential]
class AzureSynapseSparkBatchRunStatus:
    """
    Azure Synapse Spark Job operation statuses.

    Each attribute is the plain string state reported for a Spark batch job.
    ``TERMINAL_STATUSES`` collects the states at which status polling stops.
    """

    # Non-terminal states.
    NOT_STARTED = "not_started"
    STARTING = "starting"
    RUNNING = "running"
    IDLE = "idle"
    BUSY = "busy"
    SHUTTING_DOWN = "shutting_down"

    # Terminal states (members of ``TERMINAL_STATUSES`` below).
    ERROR = "error"
    DEAD = "dead"
    KILLED = "killed"
    SUCCESS = "success"

    TERMINAL_STATUSES = {ERROR, DEAD, KILLED, SUCCESS}
class AzureSynapseHook(BaseHook):
    """
    A hook to interact with Azure Synapse.

    :param azure_synapse_conn_id: The :ref:`Azure Synapse connection id<howto/connection:synapse>`.
    :param spark_pool: The Apache Spark pool used to submit the job
    """

    conn_type: str = "azure_synapse"
    conn_name_attr: str = "azure_synapse_conn_id"
    default_conn_name: str = "azure_synapse_default"
    hook_name: str = "Azure Synapse"

    @staticmethod
    def get_connection_form_widgets() -> dict[str, Any]:
        """Returns connection widgets to add to connection form."""
        # UI-only dependencies are imported lazily so the hook itself does not require them.
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
            "subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()),
        }

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Returns custom field behaviour."""
        return {
            "hidden_fields": ["schema", "port", "extra"],
            "relabeling": {"login": "Client ID", "password": "Secret", "host": "Synapse Workspace URL"},
        }

    def __init__(self, azure_synapse_conn_id: str = default_conn_name, spark_pool: str = ""):
        # ID of the most recently submitted batch job; set by ``run_spark_job``.
        self.job_id: int | None = None
        # Lazily created client, cached by ``get_conn``.
        self._conn: SparkClient | None = None
        self.conn_id = azure_synapse_conn_id
        self.spark_pool = spark_pool
        super().__init__()

    def _get_field(self, extras, name):
        """Read connection extra ``name`` via the provider's shared ``get_field`` helper."""
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    def get_conn(self) -> SparkClient:
        """
        Return a ``SparkClient`` for the configured Synapse workspace, creating it on first use.

        The client is cached on the hook, so subsequent calls return the same instance.

        :raises ValueError: if the connection has no ``subscriptionId``, or if a Client ID and
            Secret are set but no ``tenantId`` is configured.
        """
        if self._conn is not None:
            return self._conn

        conn = self.get_connection(self.conn_id)
        extras = conn.extra_dejson
        tenant = self._get_field(extras, "tenantId")

        spark_pool = self.spark_pool
        livy_api_version = "2022-02-22-preview"

        subscription_id = self._get_field(extras, "subscriptionId")
        if not subscription_id:
            raise ValueError("A Subscription ID is required to connect to Azure Synapse.")

        credential: Credentials
        if conn.login is not None and conn.password is not None:
            # login/password are relabeled "Client ID"/"Secret" in the UI: service-principal auth.
            if not tenant:
                raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")
            credential = ClientSecretCredential(
                client_id=conn.login, client_secret=conn.password, tenant_id=tenant
            )
        else:
            # No explicit client secret: defer to azure-identity's default credential.
            credential = DefaultAzureCredential()

        self._conn = self._create_client(credential, conn.host, spark_pool, livy_api_version, subscription_id)
        return self._conn

    @staticmethod
    def _create_client(credential: Credentials, host, spark_pool, livy_api_version, subscription_id: str):
        """Construct a ``SparkClient`` bound to ``host`` (the workspace URL) and ``spark_pool``."""
        return SparkClient(
            credential=credential,
            endpoint=host,
            spark_pool_name=spark_pool,
            livy_api_version=livy_api_version,
            subscription_id=subscription_id,
        )

    def run_spark_job(
        self,
        payload: SparkBatchJobOptions,
    ):
        """
        Run a job in an Apache Spark pool.

        The created job's ``id`` is stored on ``self.job_id`` so later status calls can use it.

        :param payload: Livy compatible payload which represents the spark job that a user wants to submit.
        :return: The batch job object returned by ``create_spark_batch_job``.
        """
        job = self.get_conn().spark_batch.create_spark_batch_job(payload)
        self.job_id = job.id
        return job

    def get_job_run_status(self, job_id: int | None = None):
        """
        Get the job run status.

        :param job_id: The Spark batch job identifier to look up. When ``None`` (the default),
            the job most recently submitted via ``run_spark_job`` (``self.job_id``) is used.
        :return: The ``state`` of the batch job as reported by the Synapse Spark client.
        """
        # ``is None`` rather than truthiness so that a job id of 0 is still honoured.
        batch_id = self.job_id if job_id is None else job_id
        return self.get_conn().spark_batch.get_spark_batch_job(batch_id=batch_id).state

    def wait_for_job_run_status(
        self,
        job_id: int | None,
        expected_statuses: str | set[str],
        check_interval: int = 60,
        timeout: int = 60 * 60 * 24 * 7,
    ) -> bool:
        """
        Waits for a job run to match an expected status.

        Polling stops as soon as the job reaches one of ``expected_statuses`` or any status in
        ``AzureSynapseSparkBatchRunStatus.TERMINAL_STATUSES``.

        :param job_id: The job run identifier. When ``None``, ``self.job_id`` is polled.
        :param expected_statuses: The desired status(es) to check against a job run's current status.
            A single string is treated as a one-element set (exact match, not substring match).
        :param check_interval: Time in seconds to check on a job run's status.
        :param timeout: Time in seconds to wait for a job to reach a terminal status or the expected status.
        :return: ``True`` if the final observed status is one of ``expected_statuses``, else ``False``.
        :raises AirflowTaskTimeout: if neither an expected nor a terminal status is reached in time.
        """
        if isinstance(expected_statuses, str):
            # ``status in "some_string"`` would be a substring test; compare whole values instead.
            expected_statuses = {expected_statuses}

        # Poll the job the caller asked about, not unconditionally ``self.job_id``.
        polled_job_id = self.job_id if job_id is None else job_id

        job_run_status = self.get_job_run_status(polled_job_id)
        start_time = time.monotonic()

        while (
            job_run_status not in AzureSynapseSparkBatchRunStatus.TERMINAL_STATUSES
            and job_run_status not in expected_statuses
        ):
            # Check if the job-run duration has exceeded the ``timeout`` configured.
            if start_time + timeout < time.monotonic():
                raise AirflowTaskTimeout(
                    f"Job {polled_job_id} has not reached a terminal status after {timeout} seconds."
                )

            # Wait to check the status of the job run based on the ``check_interval`` configured.
            self.log.info("Sleeping for %s seconds", check_interval)
            time.sleep(check_interval)

            job_run_status = self.get_job_run_status(polled_job_id)
            self.log.info("Current spark job run status is %s", job_run_status)

        return job_run_status in expected_statuses

    def cancel_job_run(
        self,
        job_id: int,
    ) -> None:
        """
        Cancel the spark job run.

        :param job_id: The synapse spark job identifier.
        """
        self.get_conn().spark_batch.cancel_spark_batch_job(job_id)

Was this entry helpful?