Source code for

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.workflows.executions_v1beta import Execution

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
    from google.api_core.retry import Retry

    from airflow.utils.context import Context

class WorkflowExecutionSensor(BaseSensorOperator):
    """
    Checks state of an execution for the given ``workflow_id`` and ``execution_id``.

    :param workflow_id: Required. The ID of the workflow.
    :param execution_id: Required. The ID of the execution.
    :param project_id: The ID of the Google Cloud project the workflow belongs to.
        Defaults to ``PROVIDE_PROJECT_ID`` (presumably resolved by the hook from the
        connection when not given explicitly).
    :param location: Required. The Google Cloud location (region) of the workflow.
    :param success_states: Execution states to be considered as successful, by default
        it's only ``SUCCEEDED`` state
    :param failure_states: Execution states to be considered as failures, by default
        they are ``FAILED`` and ``CANCELLED`` states.
    :param retry: A retry object used to retry requests. If ``None`` is specified,
        requests will not be retried.
    :param request_timeout: The amount of time, in seconds, to wait for the request to
        complete. Note that if ``retry`` is specified, the timeout applies to each
        individual attempt.
    :param metadata: Additional metadata that is provided to the method.
    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
    :param impersonation_chain: Optional service account, or sequence of accounts, to
        impersonate; passed through unchanged to ``WorkflowsHook``.
    """

    template_fields: Sequence[str] = ("location", "workflow_id", "execution_id")

    def __init__(
        self,
        *,
        workflow_id: str,
        execution_id: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
        success_states: set[Execution.State] | None = None,
        failure_states: set[Execution.State] | None = None,
        retry: Retry | _MethodDefault = DEFAULT,
        request_timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # An empty/None set falls back to the defaults (``or``), so callers cannot
        # pass an empty set to disable success/failure detection.
        self.success_states = success_states or {Execution.State(Execution.State.SUCCEEDED)}
        self.failure_states = failure_states or {
            Execution.State(Execution.State.FAILED),
            Execution.State(Execution.State.CANCELLED),
        }
        self.workflow_id = workflow_id
        self.execution_id = execution_id
        self.location = location
        self.project_id = project_id
        self.retry = retry
        self.request_timeout = request_timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def poke(self, context: Context):
        """
        Fetch the execution once and report whether it has finished successfully.

        Failure states are checked before success states, so a state present in both
        sets is treated as a failure.

        :return: ``True`` if the execution state is in ``success_states``, ``False``
            if it is in neither configured set (still running).
        :raises AirflowSkipException: if the state is in ``failure_states`` and
            ``soft_fail`` is enabled.
        :raises AirflowException: if the state is in ``failure_states`` otherwise.
        """
        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

        self.log.info("Checking state of execution %s for workflow %s", self.execution_id, self.workflow_id)
        execution: Execution = hook.get_execution(
            workflow_id=self.workflow_id,
            execution_id=self.execution_id,
            location=self.location,
            project_id=self.project_id,
            retry=self.retry,
            timeout=self.request_timeout,
            metadata=self.metadata,
        )

        state = execution.state
        if state in self.failure_states:
            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
            # Report the workflow ID here, not the execution ID a second time.
            message = (
                f"Execution {self.execution_id} for workflow {self.workflow_id} "
                f"failed and is in `{state}` state"
            )
            if self.soft_fail:
                raise AirflowSkipException(message)
            raise AirflowException(message)

        if state in self.success_states:
            self.log.info(
                "Execution %s for workflow %s completed with state: %s",
                self.execution_id,
                self.workflow_id,
                state,
            )
            return True

        self.log.info(
            "Execution %s for workflow %s does not completed yet, current state: %s",
            self.execution_id,
            self.workflow_id,
            state,
        )
        return False

Was this entry helpful?