# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import re
import select
import subprocess
import time
import uuid
from googleapiclient.discovery import build
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
# This is the default location
# https://cloud.google.com/dataflow/pipelines/specifying-exec-params
# Region used for a job when the caller's ``variables`` do not set 'region'.
DEFAULT_DATAFLOW_LOCATION = 'us-central1'


class _DataflowJob(LoggingMixin):
    """Tracks one Google Cloud Dataflow job through the Dataflow REST API.

    The job is identified by ``job_id`` when one is given.  Otherwise it is
    looked up by a case-insensitive match on ``name`` among the jobs listed
    for the project/location.
    """

    def __init__(self, dataflow, project_number, name, location, poll_sleep=10,
                 job_id=None, num_retries=None):
        """
        :param dataflow: Dataflow service object exposing
            ``projects().locations().jobs()``.
        :param project_number: value passed as ``projectId`` to the API.
        :param name: job name, used to find the job when ``job_id`` is None.
        :param location: value passed as ``location`` (the region) to the API.
        :param poll_sleep: seconds to sleep between polls while the job runs.
        :param job_id: Dataflow job id, if already known.
        :param num_retries: forwarded to ``execute(num_retries=...)``.
        """
        self._dataflow = dataflow
        self._project_number = project_number
        self._job_name = name
        self._job_location = location
        self._job_id = job_id
        self._num_retries = num_retries
        self._poll_sleep = poll_sleep
        # Fetch the job eagerly; this may be None if it is not listed yet.
        self._job = self._get_job()

    def _get_job_id_from_name(self):
        """Find the job whose name matches ``self._job_name``.

        On a match, the job's id is stored in ``self._job_id``.

        :return: the matching job dict, or None if no listed job matches.
        """
        response = self._dataflow.projects().locations().jobs().list(
            projectId=self._project_number,
            location=self._job_location
        ).execute(num_retries=self._num_retries)
        wanted = self._job_name.lower()
        # The response may omit 'jobs' entirely when there is nothing to list;
        # indexing it directly would raise KeyError.
        # NOTE(review): only the first page of results is inspected. If the API
        # paginates (nextPageToken), a matching job on a later page is missed.
        for job in response.get('jobs', []):
            if job['name'].lower() == wanted:
                self._job_id = job['id']
                return job
        return None

    def _get_job(self):
        """Fetch the job, by id if one is known and otherwise by name.

        :return: the job dict from the API, or None when a name lookup
            found no match.
        :raises Exception: if neither a job id nor a job name is set.
        """
        if self._job_id:
            job = self._dataflow.projects().locations().jobs().get(
                projectId=self._project_number,
                location=self._job_location,
                jobId=self._job_id).execute(num_retries=self._num_retries)
        elif self._job_name:
            job = self._get_job_id_from_name()
        else:
            raise Exception('Missing both dataflow job ID and name.')

        if job and 'currentState' in job:
            self.log.info(
                'Google Cloud DataFlow job %s is %s',
                job['name'], job['currentState']
            )
        elif job:
            self.log.info(
                'Google Cloud DataFlow with job_id %s has name %s',
                self._job_id, job['name']
            )
        else:
            self.log.info(
                'Google Cloud DataFlow job not available yet..'
            )
        return job

    def wait_for_done(self):
        """Poll the job until it finishes, or until a streaming job is running.

        Polling sleeps ``poll_sleep`` seconds while the job is running.  It
        sleeps 15 seconds while the job is pending or not yet visible.

        :return: True when the job is JOB_STATE_DONE, or when it is
            JOB_STATE_RUNNING with type JOB_TYPE_STREAMING.
        :raises Exception: if the job failed, was cancelled, or reports a
            state this method does not handle.
        """
        while True:
            if self._job and 'currentState' in self._job:
                state = self._job['currentState']
                if state == 'JOB_STATE_DONE':
                    return True
                elif (state == 'JOB_STATE_RUNNING' and
                      self._job.get('type') == 'JOB_TYPE_STREAMING'):
                    # A streaming job never reaches DONE; running is success.
                    return True
                elif state == 'JOB_STATE_FAILED':
                    raise Exception("Google Cloud Dataflow job {} has failed.".format(
                        self._job['name']))
                elif state == 'JOB_STATE_CANCELLED':
                    raise Exception("Google Cloud Dataflow job {} was cancelled.".format(
                        self._job['name']))
                elif state == 'JOB_STATE_RUNNING':
                    time.sleep(self._poll_sleep)
                elif state == 'JOB_STATE_PENDING':
                    time.sleep(15)
                else:
                    self.log.debug(str(self._job))
                    raise Exception(
                        "Google Cloud Dataflow job {} was unknown state: {}".format(
                            self._job['name'], state))
            else:
                # The job (or its state) is not visible yet; wait and retry.
                time.sleep(15)
            self._job = self._get_job()

    def get(self):
        """Return the most recently fetched job dict (may be None)."""
        return self._job


class _Dataflow(LoggingMixin):
    """Runs a Dataflow launcher command locally and streams its output to the log."""

    def __init__(self, cmd):
        """
        :param cmd: the command and its arguments as a list (run without a shell).
        """
        self.log.info("Running command: %s", ' '.join(cmd))
        self._proc = subprocess.Popen(
            cmd,
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            close_fds=True)

    def _log_line(self, line, log):
        """Log one raw output line (bytes) via ``log``, without its line ending."""
        text = line.rstrip(b'\r\n')
        if text:
            log(text.decode('utf-8', 'replace'))

    def _line(self, fd):
        """Read and log one line from the subprocess pipe identified by ``fd``.

        stderr output is logged at WARNING level and stdout at INFO level.

        :param fd: a file descriptor that ``select`` reported readable.
        :return: the raw line (bytes), ``b''`` at end of stream, or None if
            ``fd`` belongs to neither pipe.
        """
        if fd == self._proc.stderr.fileno():
            stream, log = self._proc.stderr, self.log.warning
        elif fd == self._proc.stdout.fileno():
            stream, log = self._proc.stdout, self.log.info
        else:
            return None
        # readline() returns as soon as one line (or EOF) is available.
        # readlines() would block until the pipe is closed, which stalls log
        # streaming. It can also deadlock if the child fills the other pipe
        # while we wait.
        line = stream.readline()
        if line:
            self._log_line(line, log)
        return line

    @staticmethod
    def _extract_job(line):
        """Return the Dataflow job id from a console URL in ``line``, or None.

        :param line: raw output (bytes) from the launcher process.
        """
        # NOTE(review): assumes the launcher prints a console URL of the form
        # ".../dataflow/.../jobs/<job_id>..."; confirm against the runner's
        # current output format.
        job_id_pattern = re.compile(
            br'.*console.cloud.google.com/dataflow.*/jobs/([a-zA-Z0-9_-]+).*')
        matched_job = job_id_pattern.search(line or b'')
        if matched_job:
            return matched_job.group(1).decode()
        return None

    def wait_for_done(self):
        """Stream the launcher's output until it exits, then return the job id.

        :return: the first Dataflow job id found in the output, or None.
        :raises Exception: if the launcher exits with a non-zero return code.
        """
        reads = [self._proc.stderr.fileno(), self._proc.stdout.fileno()]
        self.log.info("Start waiting for DataFlow process to complete.")
        job_id = None
        while self._proc.poll() is None:
            # select() always returns three lists (never None); an empty
            # readable list means the 5 second timeout expired.
            readable, _, _ = select.select(reads, [], [], 5)
            if not readable:
                self.log.info("Waiting for DataFlow process to complete.")
            for fd in readable:
                line = self._line(fd)
                if line:
                    job_id = job_id or self._extract_job(line)
                elif line == b'':
                    # EOF: stop watching this pipe, or select() would keep
                    # reporting it readable and the loop would busy-spin.
                    reads.remove(fd)
        # The process has exited. Drain whatever is still buffered in either
        # pipe (including data already pulled into the reader's buffer) so
        # no log lines, and no job id, are lost.
        for stream, log in ((self._proc.stderr, self.log.warning),
                            (self._proc.stdout, self.log.info)):
            for line in stream.read().splitlines():
                self._log_line(line, log)
                job_id = job_id or self._extract_job(line)
            stream.close()
        if self._proc.returncode != 0:
            raise Exception("DataFlow failed with return code {}".format(
                self._proc.returncode))
        return job_id


class DataFlowHook(GoogleCloudBaseHook):
    """Hook for launching Google Cloud Dataflow jobs and waiting on them."""

    def __init__(self,
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None,
                 poll_sleep=10):
        """
        :param gcp_conn_id: connection id, forwarded to ``GoogleCloudBaseHook``.
        :param delegate_to: forwarded to ``GoogleCloudBaseHook``.
        :param poll_sleep: seconds between job-status polls while a job runs.
        """
        self.poll_sleep = poll_sleep
        super(DataFlowHook, self).__init__(gcp_conn_id, delegate_to)

    def get_conn(self):
        """
        Returns a Google Cloud Dataflow service object.
        """
        http_authorized = self._authorize()
        return build(
            'dataflow', 'v1b3', http=http_authorized, cache_discovery=False)

    @GoogleCloudBaseHook._Decorators.provide_gcp_credential_file
    def _start_dataflow(self, variables, name, command_prefix, label_formatter):
        """Run a local launcher command, then wait for the Dataflow job it created.

        :param variables: pipeline options; 'project' is required and
            'region' defaults to ``DEFAULT_DATAFLOW_LOCATION``.
        :param name: the Dataflow job name.
        :param command_prefix: the launcher command, e.g. ``["java", "-jar", path]``.
        :param label_formatter: callable turning the 'labels' dict into
            command-line arguments.
        """
        variables = self._set_variables(variables)
        cmd = command_prefix + self._build_cmd(variables, label_formatter)
        job_id = _Dataflow(cmd).wait_for_done()
        _DataflowJob(self.get_conn(), variables['project'], name,
                     variables['region'],
                     self.poll_sleep, job_id,
                     self.num_retries).wait_for_done()

    @staticmethod
    def _set_variables(variables):
        """Validate ``variables`` in place and fill in the default region.

        :return: the same ``variables`` dict.
        :raises Exception: if 'project' is missing or None.
        """
        # .get() so a missing key raises the intended error, not a KeyError.
        if variables.get('project') is None:
            raise Exception('Project not specified')
        if 'region' not in variables:
            variables['region'] = DEFAULT_DATAFLOW_LOCATION
        return variables

    def start_java_dataflow(self, job_name, variables, dataflow, job_class=None,
                            append_job_name=True):
        """Launch a Java Dataflow pipeline and wait for it to finish.

        :param job_name: base job name; '_' is replaced with '-'.
        :param variables: pipeline options; mutated to include 'jobName'.
        :param dataflow: path to the pipeline jar.
        :param job_class: main class to run with ``java -cp``; when None the
            jar is run with ``java -jar``.
        :param append_job_name: append a short random suffix to the job name.
        """
        name = self._build_dataflow_job_name(job_name, append_job_name)
        variables['jobName'] = name

        def label_formatter(labels_dict):
            # Java pipelines take labels as one compact JSON argument.
            return ['--labels={}'.format(
                json.dumps(labels_dict).replace(' ', ''))]

        command_prefix = (["java", "-cp", dataflow, job_class] if job_class
                          else ["java", "-jar", dataflow])
        self._start_dataflow(variables, name, command_prefix, label_formatter)

    def start_template_dataflow(self, job_name, variables, parameters, dataflow_template,
                                append_job_name=True):
        """Launch a Dataflow job from a template and wait for it to finish.

        :param job_name: base job name; '_' is replaced with '-'.
        :param variables: 'project', optional 'region', and runtime
            environment options.
        :param parameters: template parameters passed in the launch body.
        :param dataflow_template: GCS path of the template.
        :param append_job_name: append a short random suffix to the job name.
        """
        variables = self._set_variables(variables)
        name = self._build_dataflow_job_name(job_name, append_job_name)
        self._start_template_dataflow(
            name, variables, parameters, dataflow_template)

    def start_python_dataflow(self, job_name, variables, dataflow, py_options,
                              append_job_name=True, py_interpreter="python2"):
        """Launch a Python Dataflow pipeline and wait for it to finish.

        :param job_name: base job name; '_' is replaced with '-'.
        :param variables: pipeline options; mutated to include 'job_name'.
        :param dataflow: path to the pipeline's Python file.
        :param py_options: extra interpreter options placed before the file
            (None is treated as no options).
        :param append_job_name: append a short random suffix to the job name.
        :param py_interpreter: interpreter executable to run; defaults to
            "python2", which was previously hard-coded.
        """
        name = self._build_dataflow_job_name(job_name, append_job_name)
        variables['job_name'] = name

        def label_formatter(labels_dict):
            # Python pipelines take one --labels=key=value argument per label.
            return ['--labels={}={}'.format(key, value)
                    for key, value in labels_dict.items()]

        command_prefix = [py_interpreter] + list(py_options or []) + [dataflow]
        self._start_dataflow(variables, name, command_prefix, label_formatter)

    @staticmethod
    def _build_dataflow_job_name(job_name, append_job_name=True):
        """Return a valid Dataflow job name derived from ``job_name``.

        Underscores are replaced with hyphens.  When ``append_job_name`` is
        true, a '-' and 8 characters of a random UUID are appended.

        :raises ValueError: if the name does not match ``[a-z]([-a-z0-9]*[a-z0-9])?``.
        """
        base_job_name = str(job_name).replace('_', '-')

        if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", base_job_name):
            raise ValueError(
                'Invalid job_name ({}); the name must consist of '
                'only the characters [-a-z0-9], starting with a '
                'letter and ending with a letter or number'.format(base_job_name))

        if append_job_name:
            return base_job_name + "-" + str(uuid.uuid4())[:8]
        return base_job_name

    @staticmethod
    def _build_cmd(variables, label_formatter):
        """Translate ``variables`` into Dataflow command-line arguments.

        An empty or None value becomes a bare ``--name`` flag.  Other values
        become ``--name=value``.  The 'labels' entry is rendered by
        ``label_formatter``.
        """
        command = ["--runner=DataflowRunner"]
        if variables is not None:
            for attr, value in variables.items():
                if attr == 'labels':
                    command += label_formatter(value)
                    continue
                if isinstance(value, (int, float)):
                    # Numeric options (e.g. a worker count) used to crash here
                    # because numbers have no length. Render them as text.
                    value = str(value)
                if value is None or len(value) == 0:
                    command.append("--" + attr)
                else:
                    command.append("--" + attr + "=" + value)
        return command

    def _start_template_dataflow(self, name, variables, parameters,
                                 dataflow_template):
        """Launch ``dataflow_template`` as job ``name`` and wait for it.

        :return: the response of the ``templates().launch`` request.
        """
        # Validate and default the region before it is used in the launch
        # request, so direct callers without 'region' do not hit a KeyError.
        variables = self._set_variables(variables)
        # Builds RuntimeEnvironment from variables dictionary
        # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
        environment_keys = ('numWorkers', 'maxWorkers', 'zone', 'serviceAccountEmail',
                            'tempLocation', 'bypassTempDirValidation', 'machineType',
                            'additionalExperiments', 'network', 'subnetwork',
                            'additionalUserLabels')
        environment = {key: variables[key] for key in environment_keys
                       if key in variables}
        body = {"jobName": name,
                "parameters": parameters,
                "environment": environment}
        service = self.get_conn()
        request = service.projects().locations().templates().launch(
            projectId=variables['project'],
            location=variables['region'],
            gcsPath=dataflow_template,
            body=body
        )
        response = request.execute(num_retries=self.num_retries)
        _DataflowJob(service, variables['project'], name, variables['region'],
                     self.poll_sleep, num_retries=self.num_retries).wait_for_done()
        return response