#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import time
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any

from azure.batch import BatchServiceClient, batch_auth, models as batch_models

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
    AzureIdentityCredentialAdapter,
    add_managed_identity_connection_widgets,
    get_field,
)
from airflow.utils import timezone

if TYPE_CHECKING:
    from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter
class AzureBatchHook(BaseHook):
    """
    Hook for Azure Batch APIs.

    Authentication uses the connection's login/password as the Batch account
    name and shared access key. When either is missing, it falls back to an
    identity-based credential (managed identity / workload identity, configured
    through the connection extras). The Batch account URL is always read from
    the ``account_url`` connection extra.

    :param azure_batch_conn_id: :ref:`Azure Batch connection id<howto/connection:azure_batch>`
        used to authenticate against the Batch account.
    """

    conn_name_attr = "azure_batch_conn_id"
    default_conn_name = "azure_batch_default"
    conn_type = "azure_batch"
    hook_name = "Azure Batch Service"

    @classmethod
    @add_managed_identity_connection_widgets
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """
        Return connection widgets to add to the connection form.

        The hook defines no widgets of its own. The managed-identity decorator is
        applied here (rather than on ``get_ui_field_behaviour``) because it
        extends a widget mapping. Presumably it contributes the
        ``managed_identity_client_id`` / ``workload_identity_tenant_id`` fields
        that :meth:`get_conn` reads (verify against the decorator).

        NOTE(review): ``get_conn`` requires the ``account_url`` extra, but
        ``get_ui_field_behaviour`` hides the raw "extra" field. Consider exposing
        an ``account_url`` widget here so it can be set from the UI.
        """
        return {}

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Returns custom field behaviour."""
        return {
            "hidden_fields": ["schema", "port", "host", "extra"],
            "relabeling": {
                "login": "Batch Account Name",
                "password": "Batch Account Access Key",
            },
        }

    def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_batch_conn_id

    def _get_field(self, extras: dict, name: str) -> Any:
        """Read ``name`` from the connection extras via the shared azure ``get_field`` helper."""
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    @cached_property
    def connection(self) -> BatchServiceClient:
        """Get the Batch client connection (built once per hook instance and cached)."""
        return self.get_conn()

    def get_conn(self) -> BatchServiceClient:
        """
        Get the Batch client connection.

        :return: Azure Batch client
        :raises AirflowException: if the ``account_url`` extra is missing or empty.
        """
        conn = self.get_connection(self.conn_id)

        batch_account_url = self._get_field(conn.extra_dejson, "account_url")
        if not batch_account_url:
            raise AirflowException("Batch Account URL parameter is missing.")

        credentials: batch_auth.SharedKeyCredentials | AzureIdentityCredentialAdapter
        if all([conn.login, conn.password]):
            # Account name + shared access key.
            credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password)
        else:
            # No shared key: fall back to an identity-based token for the Batch resource.
            managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
            workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
            credentials = AzureIdentityCredentialAdapter(
                None,
                resource_id="https://batch.core.windows.net/.default",
                managed_identity_client_id=managed_identity_client_id,
                workload_identity_tenant_id=workload_identity_tenant_id,
            )

        batch_client = BatchServiceClient(credentials, batch_url=batch_account_url)
        return batch_client

    def create_pool(self, pool: PoolAddParameter) -> None:
        """
        Creates a pool if not already existing.

        A ``PoolExists`` error from the service is logged and ignored; any other
        ``BatchErrorException`` is re-raised.

        :param pool: the pool object to create
        """
        try:
            self.log.info("Attempting to create a pool: %s", pool.id)
            self.connection.pool.add(pool)
            self.log.info("Created pool: %s", pool.id)
        except batch_models.BatchErrorException as err:
            if not err.error or err.error.code != "PoolExists":
                raise
            else:
                self.log.info("Pool %s already exists", pool.id)

    def _get_latest_verified_image_vm_and_sku(
        self,
        publisher: str | None = None,
        offer: str | None = None,
        sku_starts_with: str | None = None,
    ) -> tuple:
        """
        Get the node agent SKU id and image reference of a verified marketplace image.

        Only images with ``verificationType eq 'verified'`` are considered.
        Publisher and offer are compared case-insensitively; the SKU is matched
        by a case-sensitive prefix. The first matching image returned by the
        Batch service is used.

        :param publisher: The publisher of the Azure Virtual Machines Marketplace Image.
            For example, Canonical or MicrosoftWindowsServer.
        :param offer: The offer type of the Azure Virtual Machines Marketplace Image.
            For example, UbuntuServer or WindowsServer.
        :param sku_starts_with: The start name of the sku to search
        :return: tuple ``(node_agent_sku_id, image_reference)``
        :raises AirflowException: if a filter is missing or no verified image matches.
        """
        if publisher is None or offer is None or sku_starts_with is None:
            raise AirflowException(
                "publisher, offer and sku_starts_with are required to select a verified image."
            )
        # Normalise the caller's values too, so "Canonical" matches as documented.
        publisher_key = publisher.lower()
        offer_key = offer.lower()

        options = batch_models.AccountListSupportedImagesOptions(filter="verificationType eq 'verified'")
        images = self.connection.account.list_supported_images(account_list_supported_images_options=options)

        # Stop at the first match instead of paging through every supported image.
        match = next(
            (
                (image.node_agent_sku_id, image.image_reference)
                for image in images
                if image.image_reference.publisher.lower() == publisher_key
                and image.image_reference.offer.lower() == offer_key
                and image.image_reference.sku.startswith(sku_starts_with)
            ),
            None,
        )
        if match is None:
            raise AirflowException(
                f"No verified image found for publisher={publisher!r}, offer={offer!r}, "
                f"sku starting with {sku_starts_with!r}."
            )
        return match

    def wait_for_all_node_state(self, pool_id: str, node_state: set) -> list:
        """
        Wait for all nodes in a pool to reach given states.

        Polls every 10 seconds with no internal timeout. The caller (e.g. the
        operator's own timeout) is expected to bound the wait.

        :param pool_id: A string that identifies the pool
        :param node_state: A set of batch_models.ComputeNodeState
        :return: the pool's compute nodes once all of them are in ``node_state``
            and at least ``target_dedicated_nodes`` nodes exist.
        :raises RuntimeError: if the pool reports resize errors.
        """
        self.log.info("waiting for all nodes in pool %s to reach one of: %s", pool_id, node_state)
        while True:
            # refresh pool to ensure that there is no resize error
            pool = self.connection.pool.get(pool_id)
            if pool.resize_errors is not None:
                resize_errors = "\n".join(repr(e) for e in pool.resize_errors)
                raise RuntimeError(f"resize error encountered for pool {pool.id}:\n{resize_errors}")
            nodes = list(self.connection.compute_node.list(pool.id))
            if len(nodes) >= pool.target_dedicated_nodes and all(node.state in node_state for node in nodes):
                return nodes
            # Allow the timeout to be controlled by the AzureBatchOperator
            # specified timeout. This way we don't interrupt a startTask inside
            # the pool
            time.sleep(10)

    def create_job(self, job: JobAddParameter) -> None:
        """
        Creates a job in the pool.

        A ``JobExists`` error from the service is logged and ignored; any other
        ``BatchErrorException`` is re-raised.

        :param job: The job object to create
        """
        try:
            self.connection.job.add(job)
            self.log.info("Job %s created", job.id)
        except batch_models.BatchErrorException as err:
            if not err.error or err.error.code != "JobExists":
                raise
            else:
                self.log.info("Job %s already exists", job.id)

    def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None:
        """
        Add a single task to given job if it doesn't exist.

        A ``TaskExists`` error from the service is logged and ignored; any other
        ``BatchErrorException`` is re-raised.

        :param job_id: A string that identifies the given job
        :param task: The task to add
        """
        try:
            self.connection.task.add(job_id=job_id, task=task)
        except batch_models.BatchErrorException as err:
            if not err.error or err.error.code != "TaskExists":
                raise
            else:
                self.log.info("Task %s already exists", task.id)

    def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batch_models.CloudTask]:
        """
        Wait for tasks in a particular job to complete.

        Polls the job's tasks every 15 seconds until none is left in a
        non-completed state.

        :param job_id: A string that identifies the job
        :param timeout: The amount of time to wait before timing out in minutes
        :return: the completed tasks whose execution result is a failure
            (empty list if every task succeeded).
        :raises TimeoutError: if tasks are still incomplete after ``timeout`` minutes.
        """
        timeout_time = timezone.utcnow() + timedelta(minutes=timeout)
        while timezone.utcnow() < timeout_time:
            # The listing may be a single-pass paged iterator. Materialise it so the
            # failure scan below sees the same tasks as the completion check.
            tasks = list(self.connection.task.list(job_id))

            incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed]
            if not incomplete_tasks:
                # detect if any task in job has failed (SDK attribute is ``execution_info``)
                return [
                    task
                    for task in tasks
                    if task.execution_info is not None
                    and task.execution_info.result == batch_models.TaskExecutionResult.failure
                ]
            for task in incomplete_tasks:
                self.log.info("Waiting for %s to complete, currently on %s state", task.id, task.state)
            time.sleep(15)
        raise TimeoutError("Timed out waiting for tasks to complete")

    def test_connection(self):
        """
        Test a configured Azure Batch connection.

        :return: ``(True, message)`` on success, ``(False, error text)`` on any exception.
        """
        try:
            # Attempt to list existing jobs under the configured Batch account and retrieve
            # the first in the returned iterator. The Azure Batch API does allow for creation of a
            # BatchServiceClient with incorrect values but then will fail properly once items are
            # retrieved using the client. We need to _actually_ try to retrieve an object to properly
            # test the connection.
            next(self.get_conn().job.list(), None)
        except Exception as e:
            return False, str(e)
        return True, "Successfully connected to Azure Batch."