Source code for airflow.providers.databricks.operators.databricks_workflow
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import time
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Any

from mergedeep import merge

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState
from airflow.providers.databricks.plugins.databricks_workflow import (
    WorkflowJobRepairAllFailedLink,
    WorkflowJobRunLink,
)
from airflow.utils.task_group import TaskGroup

if TYPE_CHECKING:
    from types import TracebackType

    from airflow.models.taskmixin import DAGNode
    from airflow.utils.context import Context


@dataclass
class WorkflowRunMetadata:
    """
    Metadata for a Databricks workflow run.

    :param run_id: The ID of the Databricks workflow run.
    :param job_id: The ID of the Databricks workflow job.
    :param conn_id: The connection ID used to connect to Databricks.
    """

    conn_id: str
    job_id: str
    run_id: int
def _flatten_node(
    node: TaskGroup | BaseOperator | DAGNode, tasks: list[BaseOperator] | None = None
) -> list[BaseOperator]:
    """Flatten a node (either a TaskGroup or Operator) to a list of nodes."""
    if tasks is None:
        tasks = []
    if isinstance(node, BaseOperator):
        return [node]

    if isinstance(node, TaskGroup):
        new_tasks = []
        for _, child in node.children.items():
            new_tasks += _flatten_node(child, tasks)

        return tasks + new_tasks

    return tasks


class _CreateDatabricksWorkflowOperator(BaseOperator):
    """
    Creates a Databricks workflow from a DatabricksWorkflowTaskGroup specified in a DAG.

    :param task_id: The task_id of the operator
    :param databricks_conn_id: The connection ID to use when connecting to Databricks.
    :param existing_clusters: A list of existing clusters to use for the workflow.
    :param extra_job_params: A dictionary of extra properties which will override the default Databricks
        Workflow Job definitions.
    :param job_clusters: A list of job clusters to use for the workflow.
    :param max_concurrent_runs: The maximum number of concurrent runs for the workflow.
    :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters
        will be passed to all notebooks in the workflow.
    :param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be
        populated after instantiation using the `add_task` method.
    """

    operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink())
    template_fields = ("notebook_params",)
    caller = "_CreateDatabricksWorkflowOperator"

    def __init__(
        self,
        task_id: str,
        databricks_conn_id: str,
        existing_clusters: list[str] | None = None,
        extra_job_params: dict[str, Any] | None = None,
        job_clusters: list[dict[str, object]] | None = None,
        max_concurrent_runs: int = 1,
        notebook_params: dict | None = None,
        tasks_to_convert: list[BaseOperator] | None = None,
        **kwargs,
    ):
        self.databricks_conn_id = databricks_conn_id
        self.existing_clusters = existing_clusters or []
        self.extra_job_params = extra_job_params or {}
        self.job_clusters = job_clusters or []
        self.max_concurrent_runs = max_concurrent_runs
        self.notebook_params = notebook_params or {}
        self.tasks_to_convert = tasks_to_convert or []
        self.relevant_upstreams = [task_id]
        self.workflow_run_metadata: WorkflowRunMetadata | None = None
        super().__init__(task_id=task_id, **kwargs)

    def _get_hook(self, caller: str) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            caller=caller,
        )

    @cached_property
    def _hook(self) -> DatabricksHook:
        return self._get_hook(caller=self.caller)

    def add_task(self, task: BaseOperator) -> None:
        """Add a task to the list of tasks to convert to a Databricks workflow."""
        self.tasks_to_convert.append(task)

    @property
    def job_name(self) -> str:
        if not self.task_group:
            raise AirflowException("Task group must be set before accessing job_name")
        return f"{self.dag_id}.{self.task_group.group_id}"

    def create_workflow_json(self, context: Context | None = None) -> dict[str, object]:
        """Create a workflow json to be used in the Databricks API."""
        task_json = [
            task._convert_to_databricks_workflow_task(  # type: ignore[attr-defined]
                relevant_upstreams=self.relevant_upstreams, context=context
            )
            for task in self.tasks_to_convert
        ]

        default_json = {
            "name": self.job_name,
            "email_notifications": {"no_alert_for_skipped_runs": False},
            "timeout_seconds": 0,
            "tasks": task_json,
            "format": "MULTI_TASK",
            "job_clusters": self.job_clusters,
            "max_concurrent_runs": self.max_concurrent_runs,
        }
        return merge(default_json, self.extra_job_params)

    def _create_or_reset_job(self, context: Context) -> int:
        job_spec = self.create_workflow_json(context=context)
        existing_jobs = self._hook.list_jobs(job_name=self.job_name)
        job_id = existing_jobs[0]["job_id"] if existing_jobs else None
        if job_id:
            self.log.info(
                "Updating existing Databricks workflow job %s with spec %s",
                self.job_name,
                json.dumps(job_spec, indent=2),
            )
            self._hook.reset_job(job_id, job_spec)
        else:
            self.log.info(
                "Creating new Databricks workflow job %s with spec %s",
                self.job_name,
                json.dumps(job_spec, indent=2),
            )
            job_id = self._hook.create_job(job_spec)

        return job_id

    def _wait_for_job_to_start(self, run_id: int) -> None:
        run_url = self._hook.get_run_page_url(run_id)
        self.log.info("Check the progress of the Databricks job at %s", run_url)
        life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state
        if life_cycle_state not in (
            RunLifeCycleState.PENDING.value,
            RunLifeCycleState.RUNNING.value,
            RunLifeCycleState.BLOCKED.value,
        ):
            raise AirflowException(f"Could not start the workflow job. State: {life_cycle_state}")

        while life_cycle_state in (RunLifeCycleState.PENDING.value, RunLifeCycleState.BLOCKED.value):
            self.log.info("Waiting for the Databricks job to start running")
            time.sleep(5)
            life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state

        self.log.info("Databricks job started. State: %s", life_cycle_state)

    def execute(self, context: Context) -> Any:
        if not isinstance(self.task_group, DatabricksWorkflowTaskGroup):
            raise AirflowException("Task group must be a DatabricksWorkflowTaskGroup")

        job_id = self._create_or_reset_job(context)

        run_id = self._hook.run_now(
            {
                "job_id": job_id,
                "jar_params": self.task_group.jar_params,
                "notebook_params": self.notebook_params,
                "python_params": self.task_group.python_params,
                "spark_submit_params": self.task_group.spark_submit_params,
            }
        )

        self._wait_for_job_to_start(run_id)

        self.workflow_run_metadata = WorkflowRunMetadata(
            self.databricks_conn_id,
            job_id,
            run_id,
        )

        return {
            "conn_id": self.databricks_conn_id,
            "job_id": job_id,
            "run_id": run_id,
        }

    def on_kill(self) -> None:
        if self.workflow_run_metadata:
            run_id = self.workflow_run_metadata.run_id
            job_id = self.workflow_run_metadata.job_id

            self._hook.cancel_run(run_id)
            self.log.info(
                "Run: %(run_id)s of job_id: %(job_id)s was requested to be cancelled.",
                {"run_id": run_id, "job_id": job_id},
            )
        else:
            self.log.error(
                """
                Error: Workflow Run metadata is not populated, so the run was not canceled. This could be due
                to the workflow not being started or an error in the workflow creation process.
                """
            )
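
# Note on ``extra_job_params`` (illustrative, not part of the upstream module):
# ``create_workflow_json`` builds the default job spec and then applies
# ``mergedeep.merge(default_json, self.extra_job_params)``, which merges nested
# dictionaries recursively instead of replacing them wholesale. For example, an
# override such as
#
#     extra_job_params = {"email_notifications": {"on_failure": ["owner@example.com"]}}
#
# produces a spec whose ``email_notifications`` block keeps the default
# ``no_alert_for_skipped_runs`` key and gains the ``on_failure`` list; the e-mail
# address here is hypothetical.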
class DatabricksWorkflowTaskGroup(TaskGroup):
    """
    A task group that takes a list of tasks and creates a Databricks workflow.

    The DatabricksWorkflowTaskGroup takes a list of tasks and creates a Databricks workflow based
    on the metadata produced by those tasks. For a task to be eligible for this TaskGroup, it must
    contain the ``_convert_to_databricks_workflow_task`` method. If any task does not contain this
    method, the TaskGroup will raise an error at parse time.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:DatabricksWorkflowTaskGroup`

    :param databricks_conn_id: The name of the databricks connection to use.
    :param existing_clusters: A list of existing clusters to use for this workflow.
    :param extra_job_params: A dictionary containing properties which will override the default
        Databricks Workflow Job definitions.
    :param jar_params: A list of jar parameters to pass to the workflow. These parameters will be
        passed to all jar tasks in the workflow.
    :param job_clusters: A list of job clusters to use for this workflow.
    :param max_concurrent_runs: The maximum number of concurrent runs for this workflow.
    :param notebook_packages: A list of dictionaries describing Python packages to be installed.
        Packages defined at the workflow task group level are installed for each of the notebook
        tasks under it, while packages defined at the notebook task level are installed only for
        that notebook task.
    :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These
        parameters will be passed to all notebook tasks in the workflow.
    :param python_params: A list of python parameters to pass to the workflow. These parameters
        will be passed to all python tasks in the workflow.
    :param spark_submit_params: A list of spark submit parameters to pass to the workflow. These
        parameters will be passed to all spark submit tasks.
    """

    def __init__(
        self,
        databricks_conn_id: str,
        existing_clusters: list[str] | None = None,
        extra_job_params: dict[str, Any] | None = None,
        jar_params: list[str] | None = None,
        job_clusters: list[dict] | None = None,
        max_concurrent_runs: int = 1,
        notebook_packages: list[dict[str, Any]] | None = None,
        notebook_params: dict | None = None,
        python_params: list | None = None,
        spark_submit_params: list | None = None,
        **kwargs,
    ):
        self.databricks_conn_id = databricks_conn_id
        self.existing_clusters = existing_clusters or []
        self.extra_job_params = extra_job_params or {}
        self.jar_params = jar_params or []
        self.job_clusters = job_clusters or []
        self.max_concurrent_runs = max_concurrent_runs
        self.notebook_packages = notebook_packages or []
        self.notebook_params = notebook_params or {}
        self.python_params = python_params or []
        self.spark_submit_params = spark_submit_params or []
        super().__init__(**kwargs)
    def __exit__(
        self, _type: type[BaseException] | None, _value: BaseException | None, _tb: TracebackType | None
    ) -> None:
        """Exit the context manager and add tasks to a single ``_CreateDatabricksWorkflowOperator``."""
        roots = list(self.get_roots())
        tasks = _flatten_node(self)

        create_databricks_workflow_task = _CreateDatabricksWorkflowOperator(
            dag=self.dag,
            task_group=self,
            task_id="launch",
            databricks_conn_id=self.databricks_conn_id,
            existing_clusters=self.existing_clusters,
            extra_job_params=self.extra_job_params,
            job_clusters=self.job_clusters,
            max_concurrent_runs=self.max_concurrent_runs,
            notebook_params=self.notebook_params,
        )

        for task in tasks:
            if not (
                hasattr(task, "_convert_to_databricks_workflow_task")
                and callable(task._convert_to_databricks_workflow_task)
            ):
                raise AirflowException(
                    f"Task {task.task_id} does not support conversion to databricks workflow task."
                )

            task.workflow_run_metadata = create_databricks_workflow_task.output
            create_databricks_workflow_task.relevant_upstreams.append(task.task_id)
            create_databricks_workflow_task.add_task(task)

        for root_task in roots:
            root_task.set_upstream(create_databricks_workflow_task)

        super().__exit__(_type, _value, _tb)
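
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module). A
# DatabricksWorkflowTaskGroup is used as a context manager inside a DAG; every
# task defined inside it that implements ``_convert_to_databricks_workflow_task``
# (for example ``DatabricksNotebookOperator`` from this provider) is collected
# into a single Databricks workflow job by the ``launch`` task created on exit.
# The connection id, notebook path, and cluster spec below are hypothetical.
#
#     from airflow import DAG
#     from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
#     from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup
#
#     with DAG(dag_id="example_databricks_workflow", schedule=None) as dag:
#         with DatabricksWorkflowTaskGroup(
#             group_id="my_workflow",
#             databricks_conn_id="databricks_default",
#             job_clusters=[
#                 {
#                     "job_cluster_key": "workflow_cluster",
#                     "new_cluster": {"spark_version": "13.3.x-scala2.12", "num_workers": 1},
#                 }
#             ],
#         ):
#             DatabricksNotebookOperator(
#                 task_id="notebook_task",
#                 databricks_conn_id="databricks_default",
#                 notebook_path="/Shared/example_notebook",
#                 source="WORKSPACE",
#                 job_cluster_key="workflow_cluster",
#             )
# ---------------------------------------------------------------------------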