Complete the Airflow survey and get a free Airflow 3 certification!

Source code for airflow.providers.apache.beam.triggers.beam

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import contextlib
from collections.abc import AsyncIterator
from typing import IO, Any

from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class BeamPipelineBaseTrigger(BaseTrigger):
    """Base class for Beam Pipeline Triggers."""

    @staticmethod
    def _get_async_hook(*args, **kwargs) -> BeamAsyncHook:
        """Create the ``BeamAsyncHook`` that starts the pipeline, forwarding all arguments unchanged."""
        return BeamAsyncHook(*args, **kwargs)

    @staticmethod
    def file_has_gcs_path(file_path: str) -> bool:
        """Return whether ``file_path`` starts with the ``gs://`` scheme (case-insensitive)."""
        return file_path.lower().startswith("gs://")

    @staticmethod
    async def provide_gcs_tempfile(
        gcs_file: str,
        gcp_conn_id: str,
        *,
        exit_stack: contextlib.ExitStack | None = None,
    ):
        """
        Provide a local temporary copy of a GCS object and return its file object.

        :param gcs_file: ``gs://`` URL of the object, passed to ``GCSHook.provide_file``.
        :param gcp_conn_id: The connection ID used to create the ``GCSHook``.
        :param exit_stack: Optional. Stack on which the ``provide_file()`` context is entered.
            The temporary file remains managed by that context until the caller closes the
            stack, so pass one to control how long the file lives. If ``None``, a private
            stack is created and never closed (legacy behaviour). Cleanup is then left to
            garbage collection, which may happen before the caller is done with the file.
        :return: The file object produced by ``GCSHook.provide_file``.
        :raises AirflowOptionalProviderFeatureException: If the Google provider is not installed.
        """
        try:
            from airflow.providers.google.cloud.hooks.gcs import GCSHook
        except ImportError as err:
            from airflow.exceptions import AirflowOptionalProviderFeatureException

            raise AirflowOptionalProviderFeatureException(
                "Failed to import GCSHook. To use the GCSHook functionality, please install the "
                "apache-airflow-providers-google package."
            ) from err

        if exit_stack is None:
            # Legacy path: nobody holds a reference to this stack, so it is never closed explicitly.
            exit_stack = contextlib.ExitStack()

        gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id)
        loop = asyncio.get_running_loop()
        create_tmp_file_call = gcs_hook.provide_file(object_url=gcs_file)
        # Entering the provide_file() context is synchronous work that creates the local copy,
        # so run it in the default thread-pool executor (``None``) instead of blocking the loop.
        tmp_gcs_file: IO[str] = await loop.run_in_executor(
            None,
            exit_stack.enter_context,
            create_tmp_file_call,
        )
        return tmp_gcs_file


class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
    """
    Trigger that runs a Python Beam pipeline and reports the outcome once it terminates.

    :param variables: Variables passed to the pipeline.
    :param py_file: Path to the python file to execute. A ``gs://`` path is first copied to a
        local temporary file, which is released once the pipeline call returns.
    :param py_options: Additional options.
    :param py_interpreter: Python version of the Apache Beam pipeline. Defaults to ``python3``.
        To track python versions supported by beam and related issues
        check: https://issues.apache.org/jira/browse/BEAM-1251
    :param py_requirements: Additional python package(s) to install.
        If a value is passed to this parameter, a new virtual environment is created with
        additional packages installed.

        You could also install the apache-beam package if it is not installed on your system, or you want
        to use a different version.
    :param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
        See virtualenv documentation for more information.

        This option is only relevant if the ``py_requirements`` parameter is not None.
    :param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used.
        Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner.
        See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
        See: https://beam.apache.org/documentation/runners/capability-matrix/
    :param gcp_conn_id: Optional. The connection ID used to fetch ``py_file`` from Google Cloud
        Storage when it is a ``gs://`` path.
    """

    def __init__(
        self,
        variables: dict,
        py_file: str,
        py_options: list[str] | None = None,
        py_interpreter: str = "python3",
        py_requirements: list[str] | None = None,
        py_system_site_packages: bool = False,
        runner: str = "DirectRunner",
        gcp_conn_id: str = "google_cloud_default",
    ):
        super().__init__()
        self.variables = variables
        self.py_file = py_file
        self.py_options = py_options
        self.py_interpreter = py_interpreter
        self.py_requirements = py_requirements
        self.py_system_site_packages = py_system_site_packages
        self.runner = runner
        self.gcp_conn_id = gcp_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize BeamPythonPipelineTrigger arguments and classpath."""
        return (
            "airflow.providers.apache.beam.triggers.beam.BeamPythonPipelineTrigger",
            {
                "variables": self.variables,
                "py_file": self.py_file,
                "py_options": self.py_options,
                "py_interpreter": self.py_interpreter,
                "py_requirements": self.py_requirements,
                "py_system_site_packages": self.py_system_site_packages,
                "runner": self.runner,
                "gcp_conn_id": self.gcp_conn_id,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Run the Python pipeline and yield a single TriggerEvent describing its outcome."""
        hook = self._get_async_hook(runner=self.runner)
        try:
            # The stack owns the temporary copy of a gs:// file for exactly the duration of the
            # pipeline call. It is closed before any event is yielded, including on cancellation.
            with contextlib.ExitStack() as stack:
                # Use a local variable so self.py_file (and therefore serialize()) keeps the
                # configured path instead of a temporary one that is gone after this block.
                py_file = self.py_file
                if self.file_has_gcs_path(py_file):
                    tmp_gcs_file = await self.provide_gcs_tempfile(
                        py_file, self.gcp_conn_id, exit_stack=stack
                    )
                    py_file = tmp_gcs_file.name

                return_code = await hook.start_python_pipeline_async(
                    variables=self.variables,
                    py_file=py_file,
                    py_options=self.py_options,
                    py_interpreter=self.py_interpreter,
                    py_requirements=self.py_requirements,
                    py_system_site_packages=self.py_system_site_packages,
                )
        except Exception as e:
            self.log.exception("Exception occurred while checking for pipeline state")
            yield TriggerEvent({"status": "error", "message": str(e)})
        else:
            if return_code == 0:
                yield TriggerEvent(
                    {
                        "status": "success",
                        "message": "Pipeline has finished SUCCESSFULLY",
                    }
                )
            else:
                yield TriggerEvent({"status": "error", "message": "Operation failed"})


class BeamJavaPipelineTrigger(BeamPipelineBaseTrigger):
    """
    Trigger that runs a Java Beam pipeline and reports the outcome once it terminates.

    :param variables: Variables passed to the job.
    :param jar: Name of the jar for the pipeline. A ``gs://`` path is first copied to a local
        temporary file, which is released once the pipeline call returns.
    :param job_class: Optional. Name of the java class for the pipeline.
    :param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used.
        Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner.
        See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
        See: https://beam.apache.org/documentation/runners/capability-matrix/
    :param gcp_conn_id: Optional. The connection ID used to fetch ``jar`` from Google Cloud
        Storage when it is a ``gs://`` path.
    """

    def __init__(
        self,
        variables: dict,
        jar: str,
        job_class: str | None = None,
        runner: str = "DirectRunner",
        gcp_conn_id: str = "google_cloud_default",
    ):
        super().__init__()
        self.variables = variables
        self.jar = jar
        self.job_class = job_class
        self.runner = runner
        self.gcp_conn_id = gcp_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize BeamJavaPipelineTrigger arguments and classpath."""
        return (
            "airflow.providers.apache.beam.triggers.beam.BeamJavaPipelineTrigger",
            {
                "variables": self.variables,
                "jar": self.jar,
                "job_class": self.job_class,
                "runner": self.runner,
                "gcp_conn_id": self.gcp_conn_id,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Run the Java pipeline and yield a single TriggerEvent describing its outcome."""
        hook = self._get_async_hook(runner=self.runner)
        try:
            # The stack owns the temporary copy of a gs:// jar for exactly the duration of the
            # pipeline call. It is closed before any event is yielded, including on cancellation.
            with contextlib.ExitStack() as stack:
                # Use a local variable so self.jar (and therefore serialize()) keeps the
                # configured path instead of a temporary one that is gone after this block.
                jar = self.jar
                if self.file_has_gcs_path(jar):
                    tmp_gcs_file = await self.provide_gcs_tempfile(
                        jar, self.gcp_conn_id, exit_stack=stack
                    )
                    jar = tmp_gcs_file.name

                return_code = await hook.start_java_pipeline_async(
                    variables=self.variables,
                    jar=jar,
                    job_class=self.job_class,
                )
        except Exception as e:
            self.log.exception("Exception occurred while starting the Java pipeline")
            yield TriggerEvent({"status": "error", "message": str(e)})
        else:
            # Only reached when the pipeline call returned. The previous fall-through also emitted
            # a spurious "success" event right after an error event.
            if return_code == 0:
                yield TriggerEvent(
                    {
                        "status": "success",
                        "message": "Pipeline has finished SUCCESSFULLY",
                    }
                )
            else:
                yield TriggerEvent({"status": "error", "message": "Operation failed"})

Was this entry helpful?