Source code for airflow.providers.databricks.hooks.databricks

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Databricks hook.

This hook enable the submitting and running of jobs to the Databricks platform. Internally the
operators talk to the
``api/2.1/jobs/run-now``
`endpoint <https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow>_`
or the ``api/2.1/jobs/runs/submit``
`endpoint <https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit>`_.
"""

from __future__ import annotations

import json
from typing import Any

from requests import exceptions as requests_exceptions

from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

[docs]GET_CLUSTER_ENDPOINT = ("GET", "api/2.0/clusters/get")
[docs]RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
[docs]START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
[docs]TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
[docs]CREATE_ENDPOINT = ("POST", "api/2.1/jobs/create")
[docs]RESET_ENDPOINT = ("POST", "api/2.1/jobs/reset")
[docs]RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now")
[docs]SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit")
[docs]GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
[docs]CANCEL_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel")
[docs]DELETE_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/delete")
[docs]REPAIR_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/repair")
[docs]OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "api/2.1/jobs/runs/get-output")
[docs]CANCEL_ALL_RUNS_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel-all")
[docs]INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
[docs]UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")
[docs]LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list")
[docs]LIST_PIPELINES_ENDPOINT = ("GET", "api/2.0/pipelines")
[docs]WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
[docs]SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
[docs]class RunState: """Utility class for the run state concept of Databricks runs."""
[docs] RUN_LIFE_CYCLE_STATES = [ "PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR", "QUEUED", ]
def __init__( self, life_cycle_state: str, result_state: str = "", state_message: str = "", *args, **kwargs ) -> None: if life_cycle_state not in self.RUN_LIFE_CYCLE_STATES: raise AirflowException( f"Unexpected life cycle state: {life_cycle_state}: If the state has " "been introduced recently, please check the Databricks user " "guide for troubleshooting information" ) self.life_cycle_state = life_cycle_state self.result_state = result_state self.state_message = state_message @property
[docs] def is_terminal(self) -> bool: """True if the current state is a terminal state.""" return self.life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR")
@property
[docs] def is_successful(self) -> bool: """True if the result state is SUCCESS.""" return self.result_state == "SUCCESS"
[docs] def __eq__(self, other: object) -> bool: if not isinstance(other, RunState): return NotImplemented return ( self.life_cycle_state == other.life_cycle_state and self.result_state == other.result_state and self.state_message == other.state_message )
[docs] def __repr__(self) -> str: return str(self.__dict__)
[docs] def to_json(self) -> str: return json.dumps(self.__dict__)
@classmethod
[docs] def from_json(cls, data: str) -> RunState: return RunState(**json.loads(data))
[docs]class ClusterState: """Utility class for the cluster state concept of Databricks cluster."""
[docs] CLUSTER_LIFE_CYCLE_STATES = [ "PENDING", "RUNNING", "RESTARTING", "RESIZING", "TERMINATING", "TERMINATED", "ERROR", "UNKNOWN", ]
def __init__(self, state: str = "", state_message: str = "", *args, **kwargs) -> None: if state not in self.CLUSTER_LIFE_CYCLE_STATES: raise AirflowException( f"Unexpected cluster life cycle state: {state}: If the state has " "been introduced recently, please check the Databricks user " "guide for troubleshooting information" ) self.state = state self.state_message = state_message @property
[docs] def is_terminal(self) -> bool: """True if the current state is a terminal state.""" return self.state in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN")
@property
[docs] def is_running(self) -> bool: """True if the current state is running.""" return self.state in ("RUNNING", "RESIZING")
[docs] def __eq__(self, other) -> bool: return self.state == other.state and self.state_message == other.state_message
[docs] def __repr__(self) -> str: return str(self.__dict__)
[docs] def to_json(self) -> str: return json.dumps(self.__dict__)
@classmethod
[docs] def from_json(cls, data: str) -> ClusterState: return ClusterState(**json.loads(data))
[docs]class DatabricksHook(BaseDatabricksHook): """ Interact with Databricks. :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`. :param timeout_seconds: The amount of time in seconds the requests library will wait before timing-out. :param retry_limit: The number of times to retry the connection in case of service outages. :param retry_delay: The number of seconds to wait between retries (it might be a floating point number). :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. """
[docs] hook_name = "Databricks"
def __init__( self, databricks_conn_id: str = BaseDatabricksHook.default_conn_name, timeout_seconds: int = 180, retry_limit: int = 3, retry_delay: float = 1.0, retry_args: dict[Any, Any] | None = None, caller: str = "DatabricksHook", ) -> None: super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller)
[docs] def create_job(self, json: dict) -> int: """Call the ``api/2.1/jobs/create`` endpoint. :param json: The data used in the body of the request to the ``create`` endpoint. :return: the job_id as an int """ response = self._do_api_call(CREATE_ENDPOINT, json) return response["job_id"]
[docs] def reset_job(self, job_id: str, json: dict) -> None: """Call the ``api/2.1/jobs/reset`` endpoint. :param json: The data used in the new_settings of the request to the ``reset`` endpoint. """ self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json})
[docs] def run_now(self, json: dict) -> int: """ Call the ``api/2.1/jobs/run-now`` endpoint. :param json: The data used in the body of the request to the ``run-now`` endpoint. :return: the run_id as an int """ response = self._do_api_call(RUN_NOW_ENDPOINT, json) return response["run_id"]
[docs] def submit_run(self, json: dict) -> int: """ Call the ``api/2.1/jobs/runs/submit`` endpoint. :param json: The data used in the body of the request to the ``submit`` endpoint. :return: the run_id as an int """ response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json) return response["run_id"]
[docs] def list_jobs( self, limit: int = 25, expand_tasks: bool = False, job_name: str | None = None, page_token: str | None = None, ) -> list[dict[str, Any]]: """ List the jobs in the Databricks Job Service. :param limit: The limit/batch size used to retrieve jobs. :param expand_tasks: Whether to include task and cluster details in the response. :param job_name: Optional name of a job to search. :param page_token: The optional page token pointing at the first first job to return. :return: A list of jobs. """ has_more = True all_jobs = [] if page_token is None: page_token = "" while has_more: payload: dict[str, Any] = { "limit": limit, "expand_tasks": expand_tasks, } payload["page_token"] = page_token if job_name: payload["name"] = job_name response = self._do_api_call(LIST_JOBS_ENDPOINT, payload) jobs = response.get("jobs", []) if job_name: all_jobs += [j for j in jobs if j["settings"]["name"] == job_name] else: all_jobs += jobs has_more = response.get("has_more", False) if has_more: page_token = response.get("next_page_token", "") return all_jobs
[docs] def find_job_id_by_name(self, job_name: str) -> int | None: """ Find job id by its name; if there are multiple jobs with the same name, raise AirflowException. :param job_name: The name of the job to look up. :return: The job_id as an int or None if no job was found. """ matching_jobs = self.list_jobs(job_name=job_name) if len(matching_jobs) > 1: raise AirflowException( f"There are more than one job with name {job_name}. Please delete duplicated jobs first" ) if not matching_jobs: return None else: return matching_jobs[0]["job_id"]
[docs] def list_pipelines( self, batch_size: int = 25, pipeline_name: str | None = None, notebook_path: str | None = None ) -> list[dict[str, Any]]: """ List the pipelines in Databricks Delta Live Tables. :param batch_size: The limit/batch size used to retrieve pipelines. :param pipeline_name: Optional name of a pipeline to search. Cannot be combined with path. :param notebook_path: Optional notebook of a pipeline to search. Cannot be combined with name. :return: A list of pipelines. """ has_more = True next_token = None all_pipelines = [] filter = None if pipeline_name and notebook_path: raise AirflowException("Cannot combine pipeline_name and notebook_path in one request") if notebook_path: filter = f"notebook='{notebook_path}'" elif pipeline_name: filter = f"name LIKE '{pipeline_name}'" payload: dict[str, Any] = { "max_results": batch_size, } if filter: payload["filter"] = filter while has_more: if next_token is not None: payload = {**payload, "page_token": next_token} response = self._do_api_call(LIST_PIPELINES_ENDPOINT, payload) pipelines = response.get("statuses", []) all_pipelines += pipelines if "next_page_token" in response: next_token = response["next_page_token"] else: has_more = False return all_pipelines
[docs] def find_pipeline_id_by_name(self, pipeline_name: str) -> str | None: """ Find pipeline id by its name; if multiple pipelines with the same name, raise AirflowException. :param pipeline_name: The name of the pipeline to look up. :return: The pipeline_id as a GUID string or None if no pipeline was found. """ matching_pipelines = self.list_pipelines(pipeline_name=pipeline_name) if len(matching_pipelines) > 1: raise AirflowException( f"There are more than one pipelines with name {pipeline_name}. " "Please delete duplicated pipelines first" ) if not pipeline_name or len(matching_pipelines) == 0: return None else: return matching_pipelines[0]["pipeline_id"]
[docs] def get_run_page_url(self, run_id: int) -> str: """ Retrieve run_page_url. :param run_id: id of the run :return: URL of the run page """ json = {"run_id": run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) return response["run_page_url"]
[docs] async def a_get_run_page_url(self, run_id: int) -> str: """ Async version of `get_run_page_url()`. :param run_id: id of the run :return: URL of the run page """ json = {"run_id": run_id} response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) return response["run_page_url"]
[docs] def get_job_id(self, run_id: int) -> int: """ Retrieve job_id from run_id. :param run_id: id of the run :return: Job id for given Databricks run """ json = {"run_id": run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) return response["job_id"]
[docs] def get_run_state(self, run_id: int) -> RunState: """ Retrieve run state of the run. Please note that any Airflow tasks that call the ``get_run_state`` method will result in failure unless you have enabled xcom pickling. This can be done using the following environment variable: ``AIRFLOW__CORE__ENABLE_XCOM_PICKLING`` If you do not want to enable xcom pickling, use the ``get_run_state_str`` method to get a string describing state, or ``get_run_state_lifecycle``, ``get_run_state_result``, or ``get_run_state_message`` to get individual components of the run state. :param run_id: id of the run :return: state of the run """ json = {"run_id": run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) state = response["state"] return RunState(**state)
[docs] async def a_get_run_state(self, run_id: int) -> RunState: """ Async version of `get_run_state()`. :param run_id: id of the run :return: state of the run """ json = {"run_id": run_id} response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) state = response["state"] return RunState(**state)
[docs] def get_run(self, run_id: int) -> dict[str, Any]: """ Retrieve run information. :param run_id: id of the run :return: state of the run """ json = {"run_id": run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) return response
[docs] async def a_get_run(self, run_id: int) -> dict[str, Any]: """ Async version of `get_run`. :param run_id: id of the run :return: state of the run """ json = {"run_id": run_id} response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) return response
[docs] def get_run_state_str(self, run_id: int) -> str: """ Return the string representation of RunState. :param run_id: id of the run :return: string describing run state """ state = self.get_run_state(run_id) run_state_str = ( f"State: {state.life_cycle_state}. Result: {state.result_state}. {state.state_message}" ) return run_state_str
[docs] def get_run_state_lifecycle(self, run_id: int) -> str: """ Return the lifecycle state of the run. :param run_id: id of the run :return: string with lifecycle state """ return self.get_run_state(run_id).life_cycle_state
[docs] def get_run_state_result(self, run_id: int) -> str: """ Return the resulting state of the run. :param run_id: id of the run :return: string with resulting state """ return self.get_run_state(run_id).result_state
[docs] def get_run_state_message(self, run_id: int) -> str: """ Return the state message for the run. :param run_id: id of the run :return: string with state message """ return self.get_run_state(run_id).state_message
[docs] def get_run_output(self, run_id: int) -> dict: """ Retrieve run output of the run. :param run_id: id of the run :return: output of the run """ json = {"run_id": run_id} run_output = self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json) return run_output
[docs] def cancel_run(self, run_id: int) -> None: """ Cancel the run. :param run_id: id of the run """ json = {"run_id": run_id} self._do_api_call(CANCEL_RUN_ENDPOINT, json)
[docs] def cancel_all_runs(self, job_id: int) -> None: """ Cancel all active runs of a job asynchronously. :param job_id: The canonical identifier of the job to cancel all runs of """ json = {"job_id": job_id} self._do_api_call(CANCEL_ALL_RUNS_ENDPOINT, json)
[docs] def delete_run(self, run_id: int) -> None: """ Delete a non-active run. :param run_id: id of the run """ json = {"run_id": run_id} self._do_api_call(DELETE_RUN_ENDPOINT, json)
[docs] def repair_run(self, json: dict) -> int: """ Re-run one or more tasks. :param json: repair a job run. """ response = self._do_api_call(REPAIR_RUN_ENDPOINT, json) return response["repair_id"]
[docs] def get_latest_repair_id(self, run_id: int) -> int | None: """Get latest repair id if any exist for run_id else None.""" json = {"run_id": run_id, "include_history": "true"} response = self._do_api_call(GET_RUN_ENDPOINT, json) repair_history = response["repair_history"] if len(repair_history) == 1: return None else: return repair_history[-1]["id"]
[docs] def get_cluster_state(self, cluster_id: str) -> ClusterState: """ Retrieve run state of the cluster. :param cluster_id: id of the cluster :return: state of the cluster """ json = {"cluster_id": cluster_id} response = self._do_api_call(GET_CLUSTER_ENDPOINT, json) state = response["state"] state_message = response["state_message"] return ClusterState(state, state_message)
[docs] async def a_get_cluster_state(self, cluster_id: str) -> ClusterState: """ Async version of `get_cluster_state`. :param cluster_id: id of the cluster :return: state of the cluster """ json = {"cluster_id": cluster_id} response = await self._a_do_api_call(GET_CLUSTER_ENDPOINT, json) state = response["state"] state_message = response["state_message"] return ClusterState(state, state_message)
[docs] def restart_cluster(self, json: dict) -> None: """ Restarts the cluster. :param json: json dictionary containing cluster specification. """ self._do_api_call(RESTART_CLUSTER_ENDPOINT, json)
[docs] def start_cluster(self, json: dict) -> None: """ Start the cluster. :param json: json dictionary containing cluster specification. """ self._do_api_call(START_CLUSTER_ENDPOINT, json)
[docs] def terminate_cluster(self, json: dict) -> None: """ Terminate the cluster. :param json: json dictionary containing cluster specification. """ self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json)
[docs] def install(self, json: dict) -> None: """ Install libraries on the cluster. Utility function to call the ``2.0/libraries/install`` endpoint. :param json: json dictionary containing cluster_id and an array of library """ self._do_api_call(INSTALL_LIBS_ENDPOINT, json)
[docs] def uninstall(self, json: dict) -> None: """ Uninstall libraries on the cluster. Utility function to call the ``2.0/libraries/uninstall`` endpoint. :param json: json dictionary containing cluster_id and an array of library """ self._do_api_call(UNINSTALL_LIBS_ENDPOINT, json)
[docs] def update_repo(self, repo_id: str, json: dict[str, Any]) -> dict: """ Update given Databricks Repos. :param repo_id: ID of Databricks Repos :param json: payload :return: metadata from update """ repos_endpoint = ("PATCH", f"api/2.0/repos/{repo_id}") return self._do_api_call(repos_endpoint, json)
[docs] def delete_repo(self, repo_id: str): """ Delete given Databricks Repos. :param repo_id: ID of Databricks Repos :return: """ repos_endpoint = ("DELETE", f"api/2.0/repos/{repo_id}") self._do_api_call(repos_endpoint)
[docs] def create_repo(self, json: dict[str, Any]) -> dict: """ Create a Databricks Repos. :param json: payload :return: """ repos_endpoint = ("POST", "api/2.0/repos") return self._do_api_call(repos_endpoint, json)
[docs] def get_repo_by_path(self, path: str) -> str | None: """ Obtain Repos ID by path. :param path: path to a repository :return: Repos ID if it exists, None if doesn't. """ try: result = self._do_api_call(WORKSPACE_GET_STATUS_ENDPOINT, {"path": path}, wrap_http_errors=False) if result.get("object_type", "") == "REPO": return str(result["object_id"]) except requests_exceptions.HTTPError as e: if e.response.status_code != 404: raise e return None
[docs] def update_job_permission(self, job_id: int, json: dict[str, Any]) -> dict: """ Update databricks job permission. :param job_id: job id :param json: payload :return: json containing permission specification """ return self._do_api_call(("PATCH", f"api/2.0/permissions/jobs/{job_id}"), json)
[docs] def test_connection(self) -> tuple[bool, str]: """Test the Databricks connectivity from UI.""" hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) try: hook._do_api_call(endpoint_info=SPARK_VERSIONS_ENDPOINT).get("versions") status = True message = "Connection successfully tested" except Exception as e: status = False message = str(e) return status, message

Was this entry helpful?