#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains Google Dataprep hook."""
from __future__ import annotations
import json
from enum import Enum
from typing import Any
from urllib.parse import urljoin
import requests
from requests import HTTPError
from tenacity import retry, stop_after_attempt, wait_exponential
from airflow.hooks.base import BaseHook
def _get_field(extras: dict, field_name: str):
"""Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
backcompat_prefix = "extra__dataprep__"
if field_name.startswith("extra__"):
raise ValueError(
f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix "
"when using this method."
)
if field_name in extras:
return extras[field_name] or None
prefixed_name = f"{backcompat_prefix}{field_name}"
return extras.get(prefixed_name) or None
class JobGroupStatuses(str, Enum):
    """
    Types of job group run statuses.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    UNDEFINED = "undefined"
    IN_PROGRESS = "InProgress"
class GoogleDataprepHook(BaseHook):
    """
    Hook for connection with Dataprep API.

    To get connection Dataprep with Airflow you need Dataprep token.

    https://clouddataprep.com/documentation/api#section/Authentication

    It should be added to the Connection in Airflow in JSON format.

    Every public request method is wrapped in a ``tenacity`` retry policy: at most
    five attempts with an exponential wait capped at 10.

    :param dataprep_conn_id: ID of the Airflow connection whose extras hold the
        ``token`` and, optionally, the ``base_url``.
    :param api_version: API version segment used when building endpoint paths.
    """

    # NOTE(review): some endpoints below are absolute ("/v4/...") while others are
    # relative ("v4/..."). ``urljoin`` resolves the two forms differently when
    # ``base_url`` carries a path ending in "/", so a custom ``base_url`` may not be
    # honoured uniformly across methods. The strings are left untouched here to
    # avoid changing the resolved URLs.

    conn_name_attr = "dataprep_conn_id"
    default_conn_name = "google_cloud_dataprep_default"
    hook_name = "Google Dataprep"

    def __init__(self, dataprep_conn_id: str = default_conn_name, api_version: str = "v4") -> None:
        super().__init__()
        self.dataprep_conn_id = dataprep_conn_id
        self.api_version = api_version
        conn = self.get_connection(self.dataprep_conn_id)
        extras = conn.extra_dejson
        self._token = _get_field(extras, "token")
        # Fall back to the public endpoint when the connection does not override it.
        self._base_url = _get_field(extras, "base_url") or "https://api.clouddataprep.com"

    @property
    def _headers(self) -> dict[str, str]:
        """Return the HTTP headers (JSON content type and bearer token) sent with every request."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._token}",
        }

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def get_jobs_for_job_group(self, job_id: int) -> dict[str, Any]:
        """
        Get information about the batch jobs within a Cloud Dataprep job.

        :param job_id: The ID of the job that will be fetched
        :return: Decoded JSON body of the response.
        """
        endpoint_path = f"{self.api_version}/jobGroups/{job_id}/jobs"
        url: str = urljoin(self._base_url, endpoint_path)
        response = requests.get(url, headers=self._headers)
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def get_job_group(self, job_group_id: int, embed: str, include_deleted: bool) -> dict[str, Any]:
        """
        Get the specified job group.

        A job group is a job that is executed from a specific node in a flow.

        :param job_group_id: The ID of the job that will be fetched
        :param embed: Comma-separated list of objects to pull in as part of the response
        :param include_deleted: Whether deleted objects should be included; sent as the
            ``includeDeleted`` query parameter.
        :return: Decoded JSON body of the response.
        """
        params: dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted}
        endpoint_path = f"{self.api_version}/jobGroups/{job_group_id}"
        url: str = urljoin(self._base_url, endpoint_path)
        response = requests.get(url, headers=self._headers, params=params)
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def run_job_group(self, body_request: dict) -> dict[str, Any]:
        """
        Create a ``jobGroup``, which launches the specified job as the authenticated user.

        This performs the same action as clicking on the Run Job button in the application.

        To get recipe_id please follow the Dataprep API documentation
        https://clouddataprep.com/documentation/api#operation/runJobGroup.

        :param body_request: Body of the POST request; it identifies the recipe you
            would like to run.
        :return: Decoded JSON body of the response.
        """
        endpoint_path = f"{self.api_version}/jobGroups"
        url: str = urljoin(self._base_url, endpoint_path)
        response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def create_flow(self, *, body_request: dict) -> dict:
        """
        Create a flow.

        :param body_request: Body of the POST request to be sent.
            For more details check https://clouddataprep.com/documentation/api#operation/createFlow
        :return: Decoded JSON body of the response.
        """
        endpoint = f"/{self.api_version}/flows"
        url: str = urljoin(self._base_url, endpoint)
        response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def copy_flow(
        self, *, flow_id: int, name: str = "", description: str = "", copy_datasources: bool = False
    ) -> dict:
        """
        Create a copy of the provided flow id, as well as all contained recipes.

        :param flow_id: ID of the flow to be copied
        :param name: Name for the copy of the flow
        :param description: Description of the copy of the flow
        :param copy_datasources: Bool value to define should copies of data inputs be made or not.
        :return: Decoded JSON body of the response.
        """
        endpoint_path = f"{self.api_version}/flows/{flow_id}/copy"
        url: str = urljoin(self._base_url, endpoint_path)
        body_request = {
            "name": name,
            "description": description,
            "copyDatasources": copy_datasources,
        }
        response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def delete_flow(self, *, flow_id: int) -> None:
        """
        Delete the flow with the provided id.

        :param flow_id: ID of the flow to be deleted
        """
        endpoint_path = f"{self.api_version}/flows/{flow_id}"
        url: str = urljoin(self._base_url, endpoint_path)
        response = requests.delete(url, headers=self._headers)
        self._raise_for_status(response)

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def run_flow(self, *, flow_id: int, body_request: dict) -> dict:
        """
        Run the flow with the provided id.

        :param flow_id: ID of the flow to be run
        :param body_request: Body of the POST request to be sent.
        :return: Decoded JSON body of the response.
        """
        endpoint = f"{self.api_version}/flows/{flow_id}/run"
        url: str = urljoin(self._base_url, endpoint)
        response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def get_job_group_status(self, *, job_group_id: int) -> JobGroupStatuses:
        """
        Fetch the current status of a Dataprep job group.

        :param job_group_id: ID of the job group to check
        :return: The decoded JSON body of the status endpoint, returned as-is. It is
            not converted into a ``JobGroupStatuses`` member here.
        """
        endpoint = f"/{self.api_version}/jobGroups/{job_group_id}/status"
        url: str = urljoin(self._base_url, endpoint)
        response = requests.get(url, headers=self._headers)
        self._raise_for_status(response)
        return response.json()

    @staticmethod
    def _describe_error(response: requests.models.Response) -> Any:
        """
        Extract a loggable description from an error response.

        Prefers the ``exception`` entry of a JSON-object body and falls back to the raw
        response text when the body is not JSON, is not an object, or lacks that entry.

        :param response: The failed HTTP response.
        :return: The value to log for this failure.
        """
        try:
            payload = response.json()
        except ValueError:
            # Error bodies are not guaranteed to be JSON (e.g. a proxy's HTML error page).
            return response.text
        if isinstance(payload, dict) and payload.get("exception") is not None:
            return payload["exception"]
        return response.text

    def _raise_for_status(self, response: requests.models.Response) -> None:
        """
        Log the error details and re-raise when the response has an error status.

        :param response: The HTTP response to check.
        :raises HTTPError: If the response status indicates a failure.
        """
        try:
            response.raise_for_status()
        except HTTPError:
            # Building the log message must never raise, otherwise a decoding error
            # would replace the HTTPError that callers expect to see.
            self.log.error(self._describe_error(response))
            raise

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def create_imported_dataset(self, *, body_request: dict) -> dict:
        """
        Create an imported dataset.

        :param body_request: Body of the POST request to be sent.
            For more details check https://clouddataprep.com/documentation/api#operation/createImportedDataset
        :return: Decoded JSON body of the response.
        """
        endpoint = f"/{self.api_version}/importedDatasets"
        url: str = urljoin(self._base_url, endpoint)
        response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def create_wrangled_dataset(self, *, body_request: dict) -> dict:
        """
        Create a wrangled dataset.

        :param body_request: Body of the POST request to be sent.
            For more details check
            https://clouddataprep.com/documentation/api#operation/createWrangledDataset
        :return: Decoded JSON body of the response.
        """
        endpoint = f"/{self.api_version}/wrangledDatasets"
        url: str = urljoin(self._base_url, endpoint)
        response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def create_output_object(self, *, body_request: dict) -> dict:
        """
        Create an output object.

        :param body_request: Body of the POST request to be sent.
            For more details check
            https://clouddataprep.com/documentation/api#operation/createOutputObject
        :return: Decoded JSON body of the response.
        """
        endpoint = f"/{self.api_version}/outputObjects"
        url: str = urljoin(self._base_url, endpoint)
        response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def create_write_settings(self, *, body_request: dict) -> dict:
        """
        Create write settings.

        :param body_request: Body of the POST request to be sent.
            For more details check
            https://clouddataprep.com/documentation/api#tag/createWriteSetting
        :return: Decoded JSON body of the response.
        """
        endpoint = f"/{self.api_version}/writeSettings"
        url: str = urljoin(self._base_url, endpoint)
        response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
        self._raise_for_status(response)
        return response.json()

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def delete_imported_dataset(self, *, dataset_id: int) -> None:
        """
        Delete an imported dataset.

        :param dataset_id: ID of the imported dataset for removal.
        """
        endpoint = f"/{self.api_version}/importedDatasets/{dataset_id}"
        url: str = urljoin(self._base_url, endpoint)
        response = requests.delete(url, headers=self._headers)
        self._raise_for_status(response)