# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import time
import uuid
from googleapiclient.discovery import build
from zope.deprecation import deprecation
from airflow.version import version
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
class _DataProcJob(LoggingMixin):
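    """Wrapper around a submitted Dataproc job.

    On construction this either reattaches to a previously-submitted job
    with the same task id (when that job is in a recoverable state) or
    submits the given job description to the cluster.
    """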
def __init__(self, dataproc_api, project_id, job, region='global',
job_error_states=None, num_retries=5):
self.dataproc_api = dataproc_api
self.project_id = project_id
self.region = region
self.num_retries = num_retries
self.job_error_states = job_error_states
# Check if the job to submit is already running on the cluster.
# If so, don't resubmit the job.
try:
cluster_name = job['job']['placement']['clusterName']
except KeyError:
self.log.error('Job to submit is incorrectly configured.')
raise
jobs_on_cluster_response = dataproc_api.projects().regions().jobs().list(
projectId=self.project_id,
region=self.region,
clusterName=cluster_name).execute()
        # Job ids produced by _DataProcJobBuilder end in "_" plus an
        # 8-character UUID fragment; strip those 9 characters to recover
        # the task id.
        UUID_LENGTH = 9
jobs_on_cluster = jobs_on_cluster_response.get('jobs', [])
try:
task_id_to_submit = job['job']['reference']['jobId'][:-UUID_LENGTH]
except KeyError:
self.log.error('Job to submit is incorrectly configured.')
raise
# There is a small set of states that we will accept as sufficient
# for attaching the new task instance to the old Dataproc job. We
# generally err on the side of _not_ attaching, unless the prior
# job is in a known-good state. For example, we don't attach to an
# ERRORed job because we want Airflow to be able to retry the job.
# The full set of possible states is here:
# https://cloud.google.com/dataproc/docs/reference/rest/v1beta2/projects.regions.jobs#State
recoverable_states = frozenset([
'PENDING',
'SETUP_DONE',
'RUNNING',
'DONE',
])
found_match = False
for job_on_cluster in jobs_on_cluster:
job_on_cluster_id = job_on_cluster['reference']['jobId']
job_on_cluster_task_id = job_on_cluster_id[:-UUID_LENGTH]
if task_id_to_submit == job_on_cluster_task_id:
self.job = job_on_cluster
self.job_id = self.job['reference']['jobId']
found_match = True
# We can stop looking once we find a matching job in a recoverable state.
if self.job['status']['state'] in recoverable_states:
break
if found_match and self.job['status']['state'] in recoverable_states:
message = """
Reattaching to previously-started DataProc job %s (in state %s).
If this is not the desired behavior (ie if you would like to re-run this job),
please delete the previous instance of the job by running:
gcloud --project %s dataproc jobs delete %s --region %s
"""
self.log.info(
message,
self.job_id,
str(self.job['status']['state']),
self.project_id,
self.job_id,
self.region,
)
return
self.job = dataproc_api.projects().regions().jobs().submit(
projectId=self.project_id,
region=self.region,
body=job).execute(num_retries=self.num_retries)
self.job_id = self.job['reference']['jobId']
self.log.info(
'DataProc job %s is %s',
self.job_id, str(self.job['status']['state'])
)
    def wait_for_done(self):
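        """Polls the job until it reaches a terminal state.

        :return: True if the job finished successfully (DONE); False if it
            ended in ERROR or CANCELLED.
        :rtype: bool
        """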
while True:
self.job = self.dataproc_api.projects().regions().jobs().get(
projectId=self.project_id,
region=self.region,
jobId=self.job_id).execute(num_retries=self.num_retries)
            if self.job['status']['state'] == 'ERROR':
                self.log.error('DataProc job %s has errors', self.job_id)
                if 'details' in self.job['status']:
                    self.log.error(self.job['status']['details'])
                self.log.debug(str(self.job))
                return False
            if self.job['status']['state'] == 'CANCELLED':
                self.log.warning('DataProc job %s is cancelled', self.job_id)
                if 'details' in self.job['status']:
                    self.log.warning(self.job['status']['details'])
                self.log.debug(str(self.job))
                return False
            if self.job['status']['state'] == 'DONE':
                return True
self.log.debug(
'DataProc job %s is %s',
self.job_id, str(self.job['status']['state'])
)
time.sleep(5)
    def raise_error(self, message=None):
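        """Raises an Exception if the job ended in ERROR or in one of the
        caller-supplied ``job_error_states``; otherwise does nothing.
        """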
job_state = self.job['status']['state']
# We always consider ERROR to be an error state.
        if ((self.job_error_states and job_state in self.job_error_states)
                or job_state == 'ERROR'):
ex_message = message or ("Google DataProc job has state: %s" % job_state)
ex_details = (str(self.job['status']['details'])
if 'details' in self.job['status']
else "No details available")
raise Exception(ex_message + ": " + ex_details)
    def get(self):
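        """Returns the raw job resource as returned by the Dataproc API."""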
return self.job
class _DataProcJobBuilder:
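    """Builds the job description dictionary expected by the Dataproc
    ``jobs.submit`` API, one optional field at a time.
    """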
def __init__(self, project_id, task_id, cluster_name, job_type, properties):
name = task_id + "_" + str(uuid.uuid4())[:8]
self.job_type = job_type
self.job = {
"job": {
"reference": {
"projectId": project_id,
"jobId": name,
},
"placement": {
"clusterName": cluster_name
},
"labels": {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')},
job_type: {
}
}
}
if properties is not None:
self.job["job"][job_type]["properties"] = properties
    def add_labels(self, labels):
        """
        Set labels for the Dataproc job.

        :param labels: Labels for the job query.
        :type labels: dict
        """
if labels:
self.job["job"]["labels"].update(labels)
    def add_variables(self, variables):
if variables is not None:
self.job["job"][self.job_type]["scriptVariables"] = variables
    def add_args(self, args):
if args is not None:
self.job["job"][self.job_type]["args"] = args
    def add_query(self, query):
self.job["job"][self.job_type]["queryList"] = {'queries': [query]}
    def add_query_uri(self, query_uri):
self.job["job"][self.job_type]["queryFileUri"] = query_uri
    def add_jar_file_uris(self, jars):
if jars is not None:
self.job["job"][self.job_type]["jarFileUris"] = jars
    def add_archive_uris(self, archives):
if archives is not None:
self.job["job"][self.job_type]["archiveUris"] = archives
    def add_file_uris(self, files):
if files is not None:
self.job["job"][self.job_type]["fileUris"] = files
    def add_python_file_uris(self, pyfiles):
if pyfiles is not None:
self.job["job"][self.job_type]["pythonFileUris"] = pyfiles
    def set_main(self, main_jar, main_class):
        if main_class is not None and main_jar is not None:
            raise Exception("Set either main_jar or main_class, not both")
if main_jar:
self.job["job"][self.job_type]["mainJarFileUri"] = main_jar
else:
self.job["job"][self.job_type]["mainClass"] = main_class
    def set_python_main(self, main):
self.job["job"][self.job_type]["mainPythonFileUri"] = main
    def set_job_name(self, name):
self.job["job"]["reference"]["jobId"] = name + "_" + str(uuid.uuid4())[:8]
    def build(self):
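        """Returns the assembled job description dictionary."""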
return self.job
class _DataProcOperation(LoggingMixin):
"""Continuously polls Dataproc Operation until it completes."""
def __init__(self, dataproc_api, operation, num_retries):
self.dataproc_api = dataproc_api
self.operation = operation
self.operation_name = self.operation['name']
self.num_retries = num_retries
    def wait_for_done(self):
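        """Polls the operation until it completes; returns True when done.

        Raises an Exception (via ``_raise_error``) if the operation finished
        with an error.
        """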
if self._check_done():
return True
self.log.info(
'Waiting for Dataproc Operation %s to finish', self.operation_name)
while True:
time.sleep(10)
self.operation = (
self.dataproc_api.projects()
.regions()
.operations()
.get(name=self.operation_name)
.execute(num_retries=self.num_retries))
if self._check_done():
return True
    def get(self):
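        """Returns the raw operation resource as returned by the API."""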
return self.operation
    def _check_done(self):
if 'done' in self.operation:
if 'error' in self.operation:
self.log.warning(
'Dataproc Operation %s failed with error: %s',
self.operation_name, self.operation['error']['message'])
self._raise_error()
else:
self.log.info(
'Dataproc Operation %s done', self.operation['name'])
return True
return False
    def _raise_error(self):
raise Exception('Google Dataproc Operation %s failed: %s' %
(self.operation_name, self.operation['error']['message']))
class DataProcHook(GoogleCloudBaseHook):
"""Hook for Google Cloud Dataproc APIs."""
def __init__(self,
gcp_conn_id='google_cloud_default',
delegate_to=None,
api_version='v1beta2'):
super(DataProcHook, self).__init__(gcp_conn_id, delegate_to)
self.api_version = api_version
    def get_conn(self):
"""Returns a Google Cloud Dataproc service object."""
http_authorized = self._authorize()
return build(
'dataproc', self.api_version, http=http_authorized,
cache_discovery=False)
    def get_cluster(self, project_id, region, cluster_name):
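        """Returns the cluster resource for the given project, region and
        cluster name.
        """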
return self.get_conn().projects().regions().clusters().get(
projectId=project_id,
region=region,
clusterName=cluster_name
).execute(num_retries=self.num_retries)
    def submit(self, project_id, job, region='global', job_error_states=None):
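        """Submits (or reattaches to) a Dataproc job and blocks until it
        reaches a terminal state, raising an error if it did not succeed.
        """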
submitted = _DataProcJob(self.get_conn(), project_id, job, region,
job_error_states=job_error_states,
num_retries=self.num_retries)
if not submitted.wait_for_done():
submitted.raise_error()
    def create_job_template(self, task_id, cluster_name, job_type, properties):
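        """Returns a _DataProcJobBuilder pre-populated with the hook's
        project id and the given cluster name, job type and properties.
        """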
return _DataProcJobBuilder(self.project_id, task_id, cluster_name,
job_type, properties)
    def wait(self, operation):
        """Waits for a Google Cloud Dataproc Operation to complete."""
submitted = _DataProcOperation(self.get_conn(), operation,
self.num_retries)
submitted.wait_for_done()
    def cancel(self, project_id, job_id, region='global'):
        """
        Cancels a Google Cloud DataProc job.

        :param project_id: Name of the project the job belongs to
        :type project_id: str
        :param job_id: Identifier of the job to cancel
        :type job_id: str
        :param region: Region used for the job
        :type region: str
        :return: A Job json dictionary representing the canceled job
        """
        return self.get_conn().projects().regions().jobs().cancel(
            projectId=project_id,
            region=region,
            jobId=job_id
        ).execute(num_retries=self.num_retries)
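# 'await' became a reserved keyword in Python 3.7, so the original method name
# survives only as a deprecated alias of 'wait'.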
setattr(
DataProcHook,
"await",
deprecation.deprecated(
DataProcHook.wait, "renamed to 'wait' for Python3.7 compatibility"
),
)
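# Illustrative usage sketch (not part of the module): how the hook and the
# job builder fit together for a PySpark job. The connection id, cluster
# name, bucket and file paths below are placeholders, not values this module
# defines.
#
#   hook = DataProcHook(gcp_conn_id='google_cloud_default')
#   builder = hook.create_job_template(
#       task_id='example_task',
#       cluster_name='example-cluster',
#       job_type='pysparkJob',
#       properties={'spark.executor.memory': '2g'})
#   builder.set_python_main('gs://example-bucket/jobs/main.py')
#   hook.submit(hook.project_id, builder.build(), region='global')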