# Source code for airflow.providers.microsoft.azure.hooks.fileshare
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import IO, Any
from azure.storage.fileshare import FileProperties, ShareDirectoryClient, ShareFileClient, ShareServiceClient
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
get_sync_default_azure_credential,
)
class AzureFileShareHook(BaseHook):
    """
    Interacts with Azure FileShare Storage.

    Connection settings are read from the Airflow connection every time one of the
    ``share_*_client`` properties is accessed, so the hook can be constructed without
    touching the connection backend.

    :param share_name: Name of the share handed to the directory and file clients.
    :param file_path: Path of the file handed to the file client.
    :param directory_path: Path of the directory handed to the directory client.
    :param azure_fileshare_conn_id: Reference to the
        :ref:`Azure FileShare connection id<howto/connection:azure_fileshare>`
        of an Azure account of which file share should be used.
    """

    @staticmethod
    @add_managed_identity_connection_widgets
    def get_connection_form_widgets() -> dict[str, Any]:
        """Return connection widgets to add to the connection form."""
        # Imported lazily so the UI-only dependencies are needed only when the form is built.
        from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import PasswordField, StringField

        return {
            "sas_token": PasswordField(lazy_gettext("SAS Token (optional)"), widget=BS3PasswordFieldWidget()),
            "connection_string": StringField(
                lazy_gettext("Connection String (optional)"), widget=BS3TextFieldWidget()
            ),
        }

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour (hidden fields, labels and placeholders) for the form."""
        return {
            "hidden_fields": ["schema", "port", "host", "extra"],
            # NOTE(review): the labels say "Blob Storage" although this hook targets File Share;
            # looks copy-pasted from another hook - confirm before changing the UI text.
            "relabeling": {
                "login": "Blob Storage Login (optional)",
                "password": "Blob Storage Key (optional)",
            },
            "placeholders": {
                "login": "account name or account url",
                "password": "secret",
                "sas_token": "account url or token (optional)",
                "connection_string": "account url or token (optional)",
            },
        }

    def __init__(
        self,
        share_name: str | None = None,
        file_path: str | None = None,
        directory_path: str | None = None,
        azure_fileshare_conn_id: str = "azure_fileshare_default",
    ) -> None:
        super().__init__()
        self._conn_id = azure_fileshare_conn_id
        self.share_name = share_name
        self.file_path = file_path
        self.directory_path = directory_path
        # Populated by get_conn(); None until the connection has been read.
        self._account_url: str | None = None
        self._connection_string: str | None = None
        self._account_access_key: str | None = None
        self._sas_token: str | None = None

    def get_conn(self) -> None:
        """
        Read the Airflow connection and cache its settings on the hook.

        Stores the connection string and SAS token from the connection extras, the
        password as the account access key, and (when a login is set) the account URL
        derived from the login. Nothing is returned; use the ``share_*_client``
        properties to obtain a client.
        """
        conn = self.get_connection(self._conn_id)
        extras = conn.extra_dejson
        self._connection_string = extras.get("connection_string")
        if conn.login:
            self._account_url = self._parse_account_url(conn.login)
        self._sas_token = extras.get("sas_token")
        self._account_access_key = conn.password

    @staticmethod
    def _parse_account_url(account_url: str) -> str:
        """
        Turn a connection login into an account URL.

        :param account_url: Either a bare storage account name or a full account URL.
        :return: The value unchanged when it already starts with an ``http://`` or
            ``https://`` scheme, otherwise ``https://<value>.file.core.windows.net``.
        """
        # Require the full scheme separator: a bare "https" prefix would also match an
        # account that is merely *named* e.g. "httpsdata".
        if account_url.lower().startswith(("https://", "http://")):
            return account_url
        return f"https://{account_url}.file.core.windows.net"

    def _get_sync_default_azure_credential(self):
        """
        Build the fallback credential used when no connection string, SAS token or key is set.

        The optional ``managed_identity_client_id`` and ``workload_identity_tenant_id``
        connection extras are forwarded to ``get_sync_default_azure_credential``.
        """
        conn = self.get_connection(self._conn_id)
        extras = conn.extra_dejson
        managed_identity_client_id = extras.get("managed_identity_client_id")
        workload_identity_tenant_id = extras.get("workload_identity_tenant_id")
        return get_sync_default_azure_credential(
            managed_identity_client_id=managed_identity_client_id,
            workload_identity_tenant_id=workload_identity_tenant_id,
        )

    @property
    def share_service_client(self):
        """
        Return a ``ShareServiceClient`` for the configured account.

        Credential precedence: connection string, then SAS token or access key
        (requires an account URL), then the default Azure credential.
        """
        self.get_conn()
        if self._connection_string:
            return ShareServiceClient.from_connection_string(
                conn_str=self._connection_string,
            )
        elif self._account_url and (self._sas_token or self._account_access_key):
            credential = self._sas_token or self._account_access_key
            return ShareServiceClient(account_url=self._account_url, credential=credential)
        else:
            return ShareServiceClient(
                account_url=self._account_url,
                credential=self._get_sync_default_azure_credential(),
                token_intent="backup",
            )

    @property
    def share_directory_client(self):
        """
        Return a ``ShareDirectoryClient`` for ``share_name`` / ``directory_path``.

        Credential precedence: connection string, then SAS token or access key
        (requires an account URL), then the default Azure credential.
        """
        # Load the connection first; without this the credential fields are still None
        # unless share_service_client happened to be accessed earlier.
        self.get_conn()
        if self._connection_string:
            return ShareDirectoryClient.from_connection_string(
                conn_str=self._connection_string,
                share_name=self.share_name,
                directory_path=self.directory_path,
            )
        elif self._account_url and (self._sas_token or self._account_access_key):
            credential = self._sas_token or self._account_access_key
            return ShareDirectoryClient(
                account_url=self._account_url,
                share_name=self.share_name,
                directory_path=self.directory_path,
                credential=credential,
            )
        else:
            return ShareDirectoryClient(
                account_url=self._account_url,
                share_name=self.share_name,
                directory_path=self.directory_path,
                credential=self._get_sync_default_azure_credential(),
                token_intent="backup",
            )

    @property
    def share_file_client(self):
        """
        Return a ``ShareFileClient`` for ``share_name`` / ``file_path``.

        Credential precedence: connection string, then SAS token or access key
        (requires an account URL), then the default Azure credential.
        """
        # Load the connection first; without this the credential fields are still None
        # unless share_service_client happened to be accessed earlier.
        self.get_conn()
        if self._connection_string:
            return ShareFileClient.from_connection_string(
                conn_str=self._connection_string,
                share_name=self.share_name,
                file_path=self.file_path,
            )
        elif self._account_url and (self._sas_token or self._account_access_key):
            credential = self._sas_token or self._account_access_key
            return ShareFileClient(
                account_url=self._account_url,
                share_name=self.share_name,
                file_path=self.file_path,
                credential=credential,
            )
        else:
            return ShareFileClient(
                account_url=self._account_url,
                share_name=self.share_name,
                file_path=self.file_path,
                credential=self._get_sync_default_azure_credential(),
                token_intent="backup",
            )

    def check_for_directory(self) -> bool:
        """Check if a directory exists on Azure File Share."""
        return self.share_directory_client.exists()

    def list_directories_and_files(self) -> list:
        """Return the list of directories and files stored on a Azure File Share."""
        return list(self.share_directory_client.list_directories_and_files())

    def list_files(self) -> list[str]:
        """Return the names of the files (directories excluded) stored on a Azure File Share."""
        return [obj.name for obj in self.list_directories_and_files() if isinstance(obj, FileProperties)]

    def create_share(self, share_name: str, **kwargs) -> bool:
        """
        Create new Azure File Share.

        Any exception raised by the service call is logged as a warning and reported
        as ``False``; it is not re-raised.

        :param share_name: Name of the share.
        :param kwargs: Extra keyword arguments forwarded to ``create_share``.
        :return: True if share is created, False if the call failed (e.g. share already exists).
        """
        try:
            self.share_service_client.create_share(share_name, **kwargs)
        except Exception as e:
            self.log.warning(e)
            return False
        return True

    def delete_share(self, share_name: str, **kwargs) -> bool:
        """
        Delete existing Azure File Share.

        Any exception raised by the service call is logged as a warning and reported
        as ``False``; it is not re-raised.

        :param share_name: Name of the share.
        :param kwargs: Extra keyword arguments forwarded to ``delete_share``.
        :return: True if share is deleted, False if the call failed (e.g. share does not exist).
        """
        try:
            self.share_service_client.delete_share(share_name, **kwargs)
        except Exception as e:
            self.log.warning(e)
            return False
        return True

    def create_directory(self, **kwargs) -> Any:
        """
        Create a new directory on a Azure File Share.

        :param kwargs: Extra keyword arguments forwarded to ``create_directory``.
        :return: Whatever the directory client's ``create_directory`` returns.
        """
        return self.share_directory_client.create_directory(**kwargs)

    def get_file(self, file_path: str, **kwargs) -> None:
        """
        Download a file from Azure File Share.

        :param file_path: Where to store the file.
        :param kwargs: Extra keyword arguments forwarded to ``download_file``.
        """
        # Start the download before opening the destination: if building the client or
        # starting the download fails, the local file is neither created nor truncated.
        data = self.share_file_client.download_file(**kwargs)
        with open(file_path, "wb") as file_handle:
            data.readinto(file_handle)

    def get_file_to_stream(self, stream: IO, **kwargs) -> None:
        """
        Download a file from Azure File Share.

        :param stream: A filehandle to store the file to.
        :param kwargs: Extra keyword arguments forwarded to ``download_file``.
        """
        data = self.share_file_client.download_file(**kwargs)
        data.readinto(stream)

    def load_file(self, file_path: str, **kwargs) -> None:
        """
        Upload a file to Azure File Share.

        :param file_path: Path to the file to load.
        :param kwargs: Extra keyword arguments forwarded to ``upload_file``.
        """
        with open(file_path, "rb") as source_file:
            self.share_file_client.upload_file(source_file, **kwargs)

    def load_data(self, string_data: bytes | str | IO, **kwargs) -> None:
        """
        Upload a string to Azure File Share.

        :param string_data: String/Stream to load.
        :param kwargs: Extra keyword arguments forwarded to ``upload_file``.
        """
        self.share_file_client.upload_file(string_data, **kwargs)

    def test_connection(self) -> tuple[bool, str]:
        """
        Test Azure FileShare connection.

        :return: ``(True, message)`` when listing shares works (even if there are none),
            otherwise ``(False, error text)``.
        """
        success = (True, "Successfully connected to Azure File Share.")

        try:
            # Attempt to retrieve file share information
            next(iter(self.share_service_client.list_shares()))
            return success
        except StopIteration:
            # If the iterator returned is empty it should still be considered a successful connection since
            # it's possible to create a storage account without any file share and none could
            # legitimately exist yet.
            return success
        except Exception as e:
            return False, str(e)
# Was this entry helpful?