#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains integration with Azure Blob Storage.
It communicate via the Window Azure Storage Blob protocol. Make sure that a
Airflow connection of type `wasb` exists. Authorization can be done by supplying a
login (=Storage account name) and password (=KEY), or login and SAS token in the extra
field (see connection `wasb_default` for an example).
"""
from __future__ import annotations
import logging
import os
from functools import cached_property
from typing import TYPE_CHECKING, Any, Union
from urllib.parse import urlparse
from asgiref.sync import sync_to_async
from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
from azure.identity import ClientSecretCredential
from azure.identity.aio import (
ClientSecretCredential as AsyncClientSecretCredential,
DefaultAzureCredential as AsyncDefaultAzureCredential,
)
from azure.storage.blob import BlobClient, BlobServiceClient, ContainerClient, StorageStreamDownloader
from azure.storage.blob.aio import (
BlobClient as AsyncBlobClient,
BlobServiceClient as AsyncBlobServiceClient,
ContainerClient as AsyncContainerClient,
)
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
get_async_default_azure_credential,
get_sync_default_azure_credential,
)
if TYPE_CHECKING:
from azure.storage.blob._models import BlobProperties
[docs]AsyncCredentials = Union[AsyncClientSecretCredential, AsyncDefaultAzureCredential]
[docs]class WasbHook(BaseHook):
"""
Interact with Azure Blob Storage through the ``wasb://`` protocol.
These parameters have to be passed in Airflow Data Base: account_name and account_key.
Additional options passed in the 'extra' field of the connection will be
passed to the `BlockBlockService()` constructor. For example, authenticate
using a SAS token by adding {"sas_token": "YOUR_TOKEN"}.
If no authentication configuration is provided, DefaultAzureCredential will be used (applicable
when using Azure compute infrastructure).
:param wasb_conn_id: Reference to the :ref:`wasb connection <howto/connection:wasb>`.
:param public_read: Whether an anonymous public read access should be used. default is False
"""
[docs] conn_name_attr = "wasb_conn_id"
[docs] default_conn_name = "wasb_default"
[docs] hook_name = "Azure Blob Storage"
@classmethod
@add_managed_identity_connection_widgets
@classmethod
[docs] def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Return custom field behaviour."""
return {
"hidden_fields": ["schema", "port"],
"relabeling": {
"login": "Blob Storage Login (optional)",
"password": "Blob Storage Key (optional)",
"host": "Account URL (Active Directory Auth)",
},
"placeholders": {
"login": "account name",
"password": "secret",
"host": "account url",
"connection_string": "connection string auth",
"tenant_id": "tenant",
"shared_access_key": "shared access key",
"sas_token": "account url or token",
"extra": "additional options for use with ClientSecretCredential or DefaultAzureCredential",
},
}
def __init__(
self,
wasb_conn_id: str = default_conn_name,
public_read: bool = False,
) -> None:
super().__init__()
self.conn_id = wasb_conn_id
self.public_read = public_read
logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
try:
logger.setLevel(os.environ.get("AZURE_HTTP_LOGGING_LEVEL", logging.WARNING))
except ValueError:
logger.setLevel(logging.WARNING)
def _get_field(self, extra_dict, field_name):
prefix = "extra__wasb__"
if field_name.startswith("extra__"):
raise ValueError(
f"Got prefixed name {field_name}; please remove the '{prefix}' prefix "
f"when using this method."
)
if field_name in extra_dict:
return extra_dict[field_name] or None
return extra_dict.get(f"{prefix}{field_name}") or None
@cached_property
[docs] def blob_service_client(self) -> BlobServiceClient:
"""Return the BlobServiceClient object (cached)."""
return self.get_conn()
[docs] def get_conn(self) -> BlobServiceClient:
"""Return the BlobServiceClient object."""
conn = self.get_connection(self.conn_id)
extra = conn.extra_dejson or {}
client_secret_auth_config = extra.pop("client_secret_auth_config", {})
connection_string = self._get_field(extra, "connection_string")
if connection_string:
# connection_string auth takes priority
return BlobServiceClient.from_connection_string(connection_string, **extra)
account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
parsed_url = urlparse(account_url)
if not parsed_url.netloc:
if "." not in parsed_url.path:
# if there's no netloc and no dots in the path, then user only
# provided the Active Directory ID, not the full URL or DNS name
account_url = f"https://{conn.login}.blob.core.windows.net/"
else:
# if there's no netloc but there are dots in the path, then user
# provided the DNS name without the https:// prefix.
# Azure storage account name can only be 3 to 24 characters in length
# https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name
acc_name = account_url.split(".")[0][:24]
account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:])
tenant = self._get_field(extra, "tenant_id")
if tenant:
# use Active Directory auth
app_id = conn.login
app_secret = conn.password
token_credential = ClientSecretCredential(
tenant_id=tenant, client_id=app_id, client_secret=app_secret, **client_secret_auth_config
)
return BlobServiceClient(account_url=account_url, credential=token_credential, **extra)
if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
return BlobServiceClient(account_url=account_url, **extra)
shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
return BlobServiceClient(account_url=account_url, credential=shared_access_key, **extra)
sas_token = self._get_field(extra, "sas_token")
if sas_token:
if sas_token.startswith("https"):
return BlobServiceClient(account_url=sas_token, **extra)
else:
return BlobServiceClient(account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra)
# Fall back to old auth (password) or use managed identity if not provided.
credential = conn.password
if not credential:
managed_identity_client_id = self._get_field(extra, "managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extra, "workload_identity_tenant_id")
credential = get_sync_default_azure_credential(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)
self.log.info("Using DefaultAzureCredential as credential")
return BlobServiceClient(
account_url=account_url,
credential=credential,
**extra,
)
def _get_container_client(self, container_name: str) -> ContainerClient:
"""
Instantiate a container client.
:param container_name: The name of the container
:return: ContainerClient
"""
return self.blob_service_client.get_container_client(container_name)
def _get_blob_client(self, container_name: str, blob_name: str) -> BlobClient:
"""
Instantiate a blob client.
:param container_name: The name of the blob container
:param blob_name: The name of the blob. This needs not be existing
"""
return self.blob_service_client.get_blob_client(container=container_name, blob=blob_name)
[docs] def check_for_blob(self, container_name: str, blob_name: str, **kwargs) -> bool:
"""
Check if a blob exists on Azure Blob Storage.
:param container_name: Name of the container.
:param blob_name: Name of the blob.
:param kwargs: Optional keyword arguments for ``BlobClient.get_blob_properties`` takes.
:return: True if the blob exists, False otherwise.
"""
try:
self._get_blob_client(container_name, blob_name).get_blob_properties(**kwargs)
except ResourceNotFoundError:
return False
return True
[docs] def check_for_prefix(self, container_name: str, prefix: str, **kwargs) -> bool:
"""
Check if a prefix exists on Azure Blob storage.
:param container_name: Name of the container.
:param prefix: Prefix of the blob.
:param kwargs: Optional keyword arguments that ``ContainerClient.walk_blobs`` takes
:return: True if blobs matching the prefix exist, False otherwise.
"""
blobs = self.get_blobs_list(container_name=container_name, prefix=prefix, **kwargs)
return bool(blobs)
[docs] def get_blobs_list(
self,
container_name: str,
prefix: str | None = None,
include: list[str] | None = None,
delimiter: str = "/",
**kwargs,
) -> list:
"""
List blobs in a given container.
:param container_name: The name of the container
:param prefix: Filters the results to return only blobs whose names
begin with the specified prefix.
:param include: Specifies one or more additional datasets to include in the
response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
``copy`, ``deleted``.
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
"""
container = self._get_container_client(container_name)
blob_list = []
blobs = container.walk_blobs(name_starts_with=prefix, include=include, delimiter=delimiter, **kwargs)
for blob in blobs:
blob_list.append(blob.name)
return blob_list
[docs] def get_blobs_list_recursive(
self,
container_name: str,
prefix: str | None = None,
include: list[str] | None = None,
endswith: str = "",
**kwargs,
) -> list:
"""
List blobs in a given container.
:param container_name: The name of the container
:param prefix: Filters the results to return only blobs whose names
begin with the specified prefix.
:param include: Specifies one or more additional datasets to include in the
response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
``copy`, ``deleted``.
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
"""
container = self._get_container_client(container_name)
blob_list = []
blobs = container.list_blobs(name_starts_with=prefix, include=include, **kwargs)
for blob in blobs:
if blob.name.endswith(endswith):
blob_list.append(blob.name)
return blob_list
[docs] def load_file(
self,
file_path: str,
container_name: str,
blob_name: str,
create_container: bool = False,
**kwargs,
) -> None:
"""
Upload a file to Azure Blob Storage.
:param file_path: Path to the file to load.
:param container_name: Name of the container.
:param blob_name: Name of the blob.
:param create_container: Attempt to create the target container prior to uploading the blob. This is
useful if the target container may not exist yet. Defaults to False.
:param kwargs: Optional keyword arguments that ``BlobClient.upload_blob()`` takes.
"""
with open(file_path, "rb") as data:
self.upload(
container_name=container_name,
blob_name=blob_name,
data=data,
create_container=create_container,
**kwargs,
)
[docs] def load_string(
self,
string_data: str,
container_name: str,
blob_name: str,
create_container: bool = False,
**kwargs,
) -> None:
"""
Upload a string to Azure Blob Storage.
:param string_data: String to load.
:param container_name: Name of the container.
:param blob_name: Name of the blob.
:param create_container: Attempt to create the target container prior to uploading the blob. This is
useful if the target container may not exist yet. Defaults to False.
:param kwargs: Optional keyword arguments that ``BlobClient.upload()`` takes.
"""
# Reorder the argument order from airflow.providers.amazon.aws.hooks.s3.load_string.
self.upload(
container_name=container_name,
blob_name=blob_name,
data=string_data,
create_container=create_container,
**kwargs,
)
[docs] def get_file(self, file_path: str, container_name: str, blob_name: str, **kwargs):
"""
Download a file from Azure Blob Storage.
:param file_path: Path to the file to download.
:param container_name: Name of the container.
:param blob_name: Name of the blob.
:param kwargs: Optional keyword arguments that `BlobClient.download_blob()` takes.
"""
with open(file_path, "wb") as fileblob:
stream = self.download(container_name=container_name, blob_name=blob_name, **kwargs)
fileblob.write(stream.readall())
[docs] def read_file(self, container_name: str, blob_name: str, **kwargs):
"""
Read a file from Azure Blob Storage and return as a string.
:param container_name: Name of the container.
:param blob_name: Name of the blob.
:param kwargs: Optional keyword arguments that `BlobClient.download_blob` takes.
"""
return self.download(container_name, blob_name, **kwargs).content_as_text()
[docs] def upload(
self,
container_name: str,
blob_name: str,
data: Any,
blob_type: str = "BlockBlob",
length: int | None = None,
create_container: bool = False,
**kwargs,
) -> dict[str, Any]:
"""
Create a new blob from a data source with automatic chunking.
:param container_name: The name of the container to upload data
:param blob_name: The name of the blob to upload. This need not exist in the container
:param data: The blob data to upload
:param blob_type: The type of the blob. This can be either ``BlockBlob``,
``PageBlob`` or ``AppendBlob``. The default value is ``BlockBlob``.
:param length: Number of bytes to read from the stream. This is optional,
but should be supplied for optimal performance.
:param create_container: Attempt to create the target container prior to uploading the blob. This is
useful if the target container may not exist yet. Defaults to False.
"""
if create_container:
self.create_container(container_name)
blob_client = self._get_blob_client(container_name, blob_name)
return blob_client.upload_blob(data, blob_type, length=length, **kwargs)
[docs] def download(
self, container_name, blob_name, offset: int | None = None, length: int | None = None, **kwargs
) -> StorageStreamDownloader:
"""
Download a blob to the StorageStreamDownloader.
:param container_name: The name of the container containing the blob
:param blob_name: The name of the blob to download
:param offset: Start of byte range to use for downloading a section of the blob.
Must be set if length is provided.
:param length: Number of bytes to read from the stream.
"""
blob_client = self._get_blob_client(container_name, blob_name)
return blob_client.download_blob(offset=offset, length=length, **kwargs)
[docs] def create_container(self, container_name: str) -> None:
"""
Create container object if not already existing.
:param container_name: The name of the container to create
"""
container_client = self._get_container_client(container_name)
try:
self.log.debug("Attempting to create container: %s", container_name)
container_client.create_container()
self.log.info("Created container: %s", container_name)
except ResourceExistsError:
self.log.info(
"Attempted to create container %r but it already exists. If it is expected that this "
"container will always exist, consider setting create_container to False.",
container_name,
)
except HttpResponseError as e:
self.log.info(
"Received an HTTP response error while attempting to creating container %r: %s"
"\nIf the error is related to missing permissions to create containers, please consider "
"setting create_container to False or supplying connection credentials with the "
"appropriate permission for connection ID %r.",
container_name,
e.response,
self.conn_id,
)
except Exception as e:
self.log.info("Error while attempting to create container %r: %s", container_name, e)
raise
[docs] def delete_container(self, container_name: str) -> None:
"""
Delete a container object.
:param container_name: The name of the container
"""
try:
self.log.debug("Attempting to delete container: %s", container_name)
self._get_container_client(container_name).delete_container()
self.log.info("Deleted container: %s", container_name)
except ResourceNotFoundError:
self.log.warning("Unable to delete container %s (not found)", container_name)
except Exception:
self.log.error("Error deleting container: %s", container_name)
raise
[docs] def delete_blobs(self, container_name: str, *blobs, **kwargs) -> None:
"""
Mark the specified blobs or snapshots for deletion.
:param container_name: The name of the container containing the blobs
:param blobs: The blobs to delete. This can be a single blob, or multiple values
can be supplied, where each value is either the name of the blob (str) or BlobProperties.
"""
self._get_container_client(container_name).delete_blobs(*blobs, **kwargs)
self.log.info("Deleted blobs: %s", blobs)
[docs] def delete_file(
self,
container_name: str,
blob_name: str,
is_prefix: bool = False,
ignore_if_missing: bool = False,
delimiter: str = "",
**kwargs,
) -> None:
"""
Delete a file, or all blobs matching a prefix, from Azure Blob Storage.
:param container_name: Name of the container.
:param blob_name: Name of the blob.
:param is_prefix: If blob_name is a prefix, delete all matching files
:param ignore_if_missing: if True, then return success even if the
blob does not exist.
:param kwargs: Optional keyword arguments that ``ContainerClient.delete_blobs()`` takes.
"""
if is_prefix:
blobs_to_delete = self.get_blobs_list(
container_name, prefix=blob_name, delimiter=delimiter, **kwargs
)
elif self.check_for_blob(container_name, blob_name):
blobs_to_delete = [blob_name]
else:
blobs_to_delete = []
if not ignore_if_missing and not blobs_to_delete:
raise AirflowException(f"Blob(s) not found: {blob_name}")
# The maximum number of blobs that can be deleted in a single request is 256 using the underlying
# `ContainerClient.delete_blobs()` method. Therefore the deletes need to be in batches of <= 256.
num_blobs_to_delete = len(blobs_to_delete)
for i in range(0, num_blobs_to_delete, 256):
self.delete_blobs(container_name, *blobs_to_delete[i : i + 256], **kwargs)
[docs] def test_connection(self):
"""Test Azure Blob Storage connection."""
success = (True, "Successfully connected to Azure Blob Storage.")
try:
# Attempt to retrieve storage account information
self.get_conn().get_account_information()
return success
except Exception as e:
return False, str(e)
[docs]class WasbAsyncHook(WasbHook):
"""
An async hook that connects to Azure WASB to perform operations.
:param wasb_conn_id: reference to the :ref:`wasb connection <howto/connection:wasb>`
:param public_read: whether an anonymous public read access should be used. default is False
"""
def __init__(
self,
wasb_conn_id: str = "wasb_default",
public_read: bool = False,
) -> None:
"""Initialize the hook instance."""
self.conn_id = wasb_conn_id
self.public_read = public_read
self.blob_service_client: AsyncBlobServiceClient = None # type: ignore
[docs] async def get_async_conn(self) -> AsyncBlobServiceClient:
"""Return the Async BlobServiceClient object."""
if self.blob_service_client is not None:
return self.blob_service_client
conn = await sync_to_async(self.get_connection)(self.conn_id)
extra = conn.extra_dejson or {}
client_secret_auth_config = extra.pop("client_secret_auth_config", {})
connection_string = self._get_field(extra, "connection_string")
if connection_string:
# connection_string auth takes priority
self.blob_service_client = AsyncBlobServiceClient.from_connection_string(
connection_string, **extra
)
return self.blob_service_client
account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
parsed_url = urlparse(account_url)
if not parsed_url.netloc:
if "." not in parsed_url.path:
# if there's no netloc and no dots in the path, then user only
# provided the Active Directory ID, not the full URL or DNS name
account_url = f"https://{conn.login}.blob.core.windows.net/"
else:
# if there's no netloc but there are dots in the path, then user
# provided the DNS name without the https:// prefix.
# Azure storage account name can only be 3 to 24 characters in length
# https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name
acc_name = account_url.split(".")[0][:24]
account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:])
tenant = self._get_field(extra, "tenant_id")
if tenant:
# use Active Directory auth
app_id = conn.login
app_secret = conn.password
token_credential = AsyncClientSecretCredential(
tenant, app_id, app_secret, **client_secret_auth_config
)
self.blob_service_client = AsyncBlobServiceClient(
account_url=account_url,
credential=token_credential,
**extra, # type:ignore[arg-type]
)
return self.blob_service_client
if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
self.blob_service_client = AsyncBlobServiceClient(account_url=account_url, **extra)
return self.blob_service_client
shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
self.blob_service_client = AsyncBlobServiceClient(
account_url=account_url, credential=shared_access_key, **extra
)
return self.blob_service_client
sas_token = self._get_field(extra, "sas_token")
if sas_token:
if sas_token.startswith("https"):
self.blob_service_client = AsyncBlobServiceClient(account_url=sas_token, **extra)
else:
self.blob_service_client = AsyncBlobServiceClient(
account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra
)
return self.blob_service_client
# Fall back to old auth (password) or use managed identity if not provided.
credential = conn.password
if not credential:
managed_identity_client_id = self._get_field(extra, "managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extra, "workload_identity_tenant_id")
credential = get_async_default_azure_credential(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)
self.log.info("Using DefaultAzureCredential as credential")
self.blob_service_client = AsyncBlobServiceClient(
account_url=account_url,
credential=credential,
**extra,
)
return self.blob_service_client
def _get_blob_client(self, container_name: str, blob_name: str) -> AsyncBlobClient:
"""
Instantiate a blob client.
:param container_name: the name of the blob container
:param blob_name: the name of the blob. This needs not be existing
"""
return self.blob_service_client.get_blob_client(container=container_name, blob=blob_name)
[docs] async def check_for_blob_async(self, container_name: str, blob_name: str, **kwargs: Any) -> bool:
"""
Check if a blob exists on Azure Blob Storage.
:param container_name: name of the container
:param blob_name: name of the blob
:param kwargs: optional keyword arguments for ``BlobClient.get_blob_properties``
"""
try:
await self._get_blob_client(container_name, blob_name).get_blob_properties(**kwargs)
except ResourceNotFoundError:
return False
return True
def _get_container_client(self, container_name: str) -> AsyncContainerClient:
"""
Instantiate a container client.
:param container_name: the name of the container
"""
return self.blob_service_client.get_container_client(container_name)
[docs] async def get_blobs_list_async(
self,
container_name: str,
prefix: str | None = None,
include: list[str] | None = None,
delimiter: str = "/",
**kwargs: Any,
) -> list[BlobProperties]:
"""
List blobs in a given container.
:param container_name: the name of the container
:param prefix: filters the results to return only blobs whose names
begin with the specified prefix.
:param include: specifies one or more additional datasets to include in the
response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
``copy`, ``deleted``.
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
"""
container = self._get_container_client(container_name)
blob_list: list[BlobProperties] = []
blobs = container.walk_blobs(name_starts_with=prefix, include=include, delimiter=delimiter, **kwargs)
async for blob in blobs:
blob_list.append(blob)
return blob_list
[docs] async def check_for_prefix_async(self, container_name: str, prefix: str, **kwargs: Any) -> bool:
"""
Check if a prefix exists on Azure Blob storage.
:param container_name: Name of the container.
:param prefix: Prefix of the blob.
:param kwargs: Optional keyword arguments for ``ContainerClient.walk_blobs``
"""
blobs = await self.get_blobs_list_async(container_name=container_name, prefix=prefix, **kwargs)
return bool(blobs)