# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
from airflow.hooks.base_hook import BaseHook
from azure.datalake.store import core, lib, multithread
[docs]class AzureDataLakeHook(BaseHook):
    """
    Interacts with Azure Data Lake.
    Client ID and client secret should be in user and password parameters.
    Tenant and account name should be extra field as
    {"tenant": "<TENANT>", "account_name": "ACCOUNT_NAME"}.
    :param azure_data_lake_conn_id: Reference to the Azure Data Lake connection.
    :type azure_data_lake_conn_id: str
    """
    def __init__(self, azure_data_lake_conn_id='azure_data_lake_default'):
        self.conn_id = azure_data_lake_conn_id
        self.connection = self.get_conn()
[docs]    def get_conn(self):
        """Return a AzureDLFileSystem object."""
        conn = self.get_connection(self.conn_id)
        service_options = conn.extra_dejson
        self.account_name = service_options.get('account_name')
        adlCreds = lib.auth(tenant_id=service_options.get('tenant'),
                            client_secret=conn.password,
                            client_id=conn.login)
        adlsFileSystemClient = core.AzureDLFileSystem(adlCreds,
                                                      store_name=self.account_name)
        adlsFileSystemClient.connect()
        return adlsFileSystemClient 
[docs]    def check_for_file(self, file_path):
        """
        Check if a file exists on Azure Data Lake.
        :param file_path: Path and name of the file.
        :type file_path: str
        :return: True if the file exists, False otherwise.
        :rtype: bool
        """
        try:
            files = self.connection.glob(file_path, details=False, invalidate_cache=True)
            return len(files) == 1
        except FileNotFoundError:
            return False 
[docs]    def upload_file(self, local_path, remote_path, nthreads=64, overwrite=True,
                    buffersize=4194304, blocksize=4194304):
        """
        Upload a file to Azure Data Lake.
        :param local_path: local path. Can be single file, directory (in which case,
            upload recursively) or glob pattern. Recursive glob patterns using `**`
            are not supported.
        :type local_path: str
        :param remote_path: Remote path to upload to; if multiple files, this is the
            directory root to write within.
        :type remote_path: str
        :param nthreads: Number of threads to use. If None, uses the number of cores.
        :type nthreads: int
        :param overwrite: Whether to forcibly overwrite existing files/directories.
            If False and remote path is a directory, will quit regardless if any files
            would be overwritten or not. If True, only matching filenames are actually
            overwritten.
        :type overwrite: bool
        :param buffersize: int [2**22]
            Number of bytes for internal buffer. This block cannot be bigger than
            a chunk and cannot be smaller than a block.
        :type buffersize: int
        :param blocksize: int [2**22]
            Number of bytes for a block. Within each chunk, we write a smaller
            block for each API call. This block cannot be bigger than a chunk.
        :type blocksize: int
        """
        multithread.ADLUploader(self.connection,
                                lpath=local_path,
                                rpath=remote_path,
                                nthreads=nthreads,
                                overwrite=overwrite,
                                buffersize=buffersize,
                                blocksize=blocksize) 
[docs]    def download_file(self, local_path, remote_path, nthreads=64, overwrite=True,
                      buffersize=4194304, blocksize=4194304):
        """
        Download a file from Azure Blob Storage.
        :param local_path: local path. If downloading a single file, will write to this
            specific file, unless it is an existing directory, in which case a file is
            created within it. If downloading multiple files, this is the root
            directory to write within. Will create directories as required.
        :type local_path: str
        :param remote_path: remote path/globstring to use to find remote files.
            Recursive glob patterns using `**` are not supported.
        :type remote_path: str
        :param nthreads: Number of threads to use. If None, uses the number of cores.
        :type nthreads: int
        :param overwrite: Whether to forcibly overwrite existing files/directories.
            If False and remote path is a directory, will quit regardless if any files
            would be overwritten or not. If True, only matching filenames are actually
            overwritten.
        :type overwrite: bool
        :param buffersize: int [2**22]
            Number of bytes for internal buffer. This block cannot be bigger than
            a chunk and cannot be smaller than a block.
        :type buffersize: int
        :param blocksize: int [2**22]
            Number of bytes for a block. Within each chunk, we write a smaller
            block for each API call. This block cannot be bigger than a chunk.
        :type blocksize: int
        """
        multithread.ADLDownloader(self.connection,
                                  lpath=local_path,
                                  rpath=remote_path,
                                  nthreads=nthreads,
                                  overwrite=overwrite,
                                  buffersize=buffersize,
                                  blocksize=blocksize)