Source code for airflow.contrib.hooks.azure_data_lake_hook

# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

from airflow.hooks.base_hook import BaseHook
from azure.datalake.store import core, lib, multithread


class AzureDataLakeHook(BaseHook):
    """
    Interacts with Azure Data Lake.

    Client ID and client secret should be in user and password parameters.
    Tenant and account name should be extra field as
    {"tenant": "<TENANT>", "account_name": "ACCOUNT_NAME"}.

    :param azure_data_lake_conn_id: Reference to the Azure Data Lake connection.
    :type azure_data_lake_conn_id: str
    """

    def __init__(self, azure_data_lake_conn_id='azure_data_lake_default'):
        self.conn_id = azure_data_lake_conn_id
        # The client is created (and connected) eagerly, so building the hook
        # already requires a resolvable Airflow connection.
        self.connection = self.get_conn()

    def get_conn(self):
        """
        Return a connected AzureDLFileSystem object.

        Reads the Airflow connection referenced by ``self.conn_id``: login and
        password are used as client id and client secret, and the ``tenant``
        and ``account_name`` keys are taken from the connection's extra JSON.
        As a side effect ``self.account_name`` is set.

        :return: A file system client on which ``connect()`` has been called.
        :rtype: azure.datalake.store.core.AzureDLFileSystem
        """
        conn = self.get_connection(self.conn_id)
        service_options = conn.extra_dejson
        self.account_name = service_options.get('account_name')

        credentials = lib.auth(tenant_id=service_options.get('tenant'),
                               client_secret=conn.password,
                               client_id=conn.login)
        file_system = core.AzureDLFileSystem(credentials,
                                             store_name=self.account_name)
        file_system.connect()
        return file_system

    def check_for_file(self, file_path):
        """
        Check if a file exists on Azure Data Lake.

        :param file_path: Path and name of the file.
        :type file_path: str
        :return: True if the path matches exactly one entry, False otherwise
            (including when the lookup raises ``FileNotFoundError``).
        :rtype: bool
        """
        try:
            # invalidate_cache=True is passed so the client is asked not to
            # answer from a previously cached listing.
            matches = self.connection.glob(file_path,
                                           details=False,
                                           invalidate_cache=True)
        except FileNotFoundError:
            return False
        return len(matches) == 1

    def upload_file(self, local_path, remote_path, nthreads=64, overwrite=True,
                    buffersize=4194304, blocksize=4194304):
        """
        Upload a file to Azure Data Lake.

        :param local_path: local path. Can be single file, directory (in which
            case, upload recursively) or glob pattern. Recursive glob patterns
            using `**` are not supported.
        :type local_path: str
        :param remote_path: Remote path to upload to; if multiple files, this
            is the directory root to write within.
        :type remote_path: str
        :param nthreads: Number of threads to use. If None, uses the number of
            cores.
        :type nthreads: int
        :param overwrite: Whether to forcibly overwrite existing
            files/directories. If False and remote path is a directory, will
            quit regardless if any files would be overwritten or not. If True,
            only matching filenames are actually overwritten.
        :type overwrite: bool
        :param buffersize: int [2**22] Number of bytes for internal buffer.
            This block cannot be bigger than a chunk and cannot be smaller
            than a block.
        :type buffersize: int
        :param blocksize: int [2**22] Number of bytes for a block. Within each
            chunk, we write a smaller block for each API call. This block
            cannot be bigger than a chunk.
        :type blocksize: int
        """
        # The uploader object is not kept: with the library's default
        # settings, constructing it is what carries out the transfer.
        multithread.ADLUploader(self.connection,
                                lpath=local_path,
                                rpath=remote_path,
                                nthreads=nthreads,
                                overwrite=overwrite,
                                buffersize=buffersize,
                                blocksize=blocksize)

    def download_file(self, local_path, remote_path, nthreads=64,
                      overwrite=True, buffersize=4194304, blocksize=4194304):
        """
        Download a file from Azure Data Lake.

        :param local_path: local path. If downloading a single file, will
            write to this specific file, unless it is an existing directory,
            in which case a file is created within it. If downloading multiple
            files, this is the root directory to write within. Will create
            directories as required.
        :type local_path: str
        :param remote_path: remote path/globstring to use to find remote
            files. Recursive glob patterns using `**` are not supported.
        :type remote_path: str
        :param nthreads: Number of threads to use. If None, uses the number of
            cores.
        :type nthreads: int
        :param overwrite: Whether to forcibly overwrite existing
            files/directories. If False and remote path is a directory, will
            quit regardless if any files would be overwritten or not. If True,
            only matching filenames are actually overwritten.
        :type overwrite: bool
        :param buffersize: int [2**22] Number of bytes for internal buffer.
            This block cannot be bigger than a chunk and cannot be smaller
            than a block.
        :type buffersize: int
        :param blocksize: int [2**22] Number of bytes for a block. Within each
            chunk, we write a smaller block for each API call. This block
            cannot be bigger than a chunk.
        :type blocksize: int
        """
        # The downloader object is not kept: with the library's default
        # settings, constructing it is what carries out the transfer.
        multithread.ADLDownloader(self.connection,
                                  lpath=local_path,
                                  rpath=remote_path,
                                  nthreads=nthreads,
                                  overwrite=overwrite,
                                  buffersize=buffersize,
                                  blocksize=blocksize)

    def list(self, path):
        """
        List files in Azure Data Lake Storage.

        A path containing ``*`` is treated as a glob pattern and resolved with
        the client's ``glob``; any other path is handed to the client's
        ``walk``.

        :param path: full path/globstring to use to list files in ADLS
        :type path: str
        :return: Whatever the underlying ``glob`` / ``walk`` call returns.
        """
        if "*" in path:
            return self.connection.glob(path)
        return self.connection.walk(path)

Was this entry helpful?