Source code for

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains a Google Storage Transfer Service Hook.

.. spelling::

    ListTransferJobsAsyncPager
    StorageTransferServiceAsyncClient
"""

from __future__ import annotations

import json
import logging
import time
import warnings
from copy import deepcopy
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence

from import (
from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from import CLIENT_INFO
from import (

    from import (
    from proto import Message

log = logging.getLogger(__name__)

# Time to sleep between active checks of the operation results
# (used by CloudDataTransferServiceHook.wait_for_transfer_job between polls).
TIME_TO_SLEEP_IN_SECONDS = 10
class GcpTransferJobsStatus:
    """Google Cloud Transfer job status.

    Values mirror the ``TransferJob.status`` enum of the Storage Transfer API.
    """

    ENABLED = "ENABLED"
    DISABLED = "DISABLED"
    DELETED = "DELETED"
class GcpTransferOperationStatus:
    """Google Cloud Transfer operation status.

    Values mirror the ``TransferOperation.status`` enum of the Storage Transfer API.
    """

    IN_PROGRESS = "IN_PROGRESS"
    PAUSED = "PAUSED"
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"
    ABORTED = "ABORTED"
# A list of keywords used to build a request or response
ACCESS_KEY_ID = "accessKeyId"
ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink"
AWS_ACCESS_KEY = "awsAccessKey"
AWS_SECRET_ACCESS_KEY = "secretAccessKey"
AWS_S3_DATA_SOURCE = "awsS3DataSource"
AWS_ROLE_ARN = "roleArn"
BODY = "body"
BUCKET_NAME = "bucketName"
COUNTERS = "counters"
DAY = "day"
DESCRIPTION = "description"
FILTER = "filter"
FILTER_JOB_NAMES = "job_names"
FILTER_PROJECT_ID = "project_id"
GCS_DATA_SINK = "gcsDataSink"
GCS_DATA_SOURCE = "gcsDataSource"
HOURS = "hours"
HTTP_DATA_SOURCE = "httpDataSource"
INCLUDE_PREFIXES = "includePrefixes"
JOB_NAME = "name"
LIST_URL = "list_url"
METADATA = "metadata"
MINUTES = "minutes"
MONTH = "month"
NAME = "name"
OBJECT_CONDITIONS = "object_conditions"
OPERATIONS = "operations"
OVERWRITE_OBJECTS_ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink"
PATH = "path"
PROJECT_ID = "projectId"
SCHEDULE = "schedule"
SCHEDULE_END_DATE = "scheduleEndDate"
SCHEDULE_START_DATE = "scheduleStartDate"
SECONDS = "seconds"
SECRET_ACCESS_KEY = "secretAccessKey"
START_TIME_OF_DAY = "startTimeOfDay"
STATUS = "status"
STATUS1 = "status"
TRANSFER_JOB = "transfer_job"
TRANSFER_JOBS = "transferJobs"
TRANSFER_JOB_FIELD_MASK = "update_transfer_job_field_mask"
TRANSFER_OPERATIONS = "transferOperations"
TRANSFER_OPTIONS = "transfer_options"
TRANSFER_SPEC = "transferSpec"
YEAR = "year"

# HTTP 409 "Conflict": returned by transferJobs.create when a job with the
# requested name already exists (handled in create_transfer_job).
ALREADY_EXIST_CODE = 409
# Operation statuses that make operations_contain_expected_statuses raise
# instead of continuing to wait.
NEGATIVE_STATUSES = {
    GcpTransferOperationStatus.ABORTED,
    GcpTransferOperationStatus.FAILED,
}
def gen_job_name(job_name: str) -> str:
    """Append a timestamp suffix to a transfer job name.

    The suffix is the current Unix time truncated to whole seconds, joined to
    the base name with an underscore.

    :param job_name: Base name of the transfer job.
    :return: ``"<job_name>_<unix seconds>"``.
    """
    suffix = str(int(time.time()))
    return "_".join((job_name, suffix))
class CloudDataTransferServiceHook(GoogleBaseHook):
    """Google Storage Transfer Service functionalities.

    All methods in the hook with *project_id* in the signature must be called
    with keyword arguments rather than positional.
    """

    def __init__(
        self,
        api_version: str = "v1",
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        """Create the hook.

        :param api_version: Version of the ``storagetransfer`` discovery API to build.
        :param gcp_conn_id: Airflow connection ID used to obtain credentials.
        :param impersonation_chain: Optional account (or chain of accounts) to
            impersonate; forwarded to the base hook.
        :raises RuntimeError: If the removed ``delegate_to`` argument is passed.
        """
        if "delegate_to" in kwargs:
            raise RuntimeError(
                "The `delegate_to` parameter has been deprecated before and "
                "finally removed in this version of Google Provider. You MUST "
                "convert it to `impersonate_chain`."
            )
        super().__init__(
            gcp_conn_id=gcp_conn_id,
            impersonation_chain=impersonation_chain,
        )
        self.api_version = api_version
        # Discovery client; built lazily by get_conn() and cached afterwards.
        self._conn = None

    def get_conn(self) -> Resource:
        """Retrieve connection to Google Storage Transfer service.

        The client is built on first use and cached on the hook instance.

        :return: Google Storage Transfer service object
        """
        if not self._conn:
            http_authorized = self._authorize()
            self._conn = build(
                "storagetransfer", self.api_version, http=http_authorized, cache_discovery=False
            )
        return self._conn

    def create_transfer_job(self, body: dict) -> dict:
        """Create a transfer job that runs periodically.

        If the API answers 409 (Conflict) because a job with the requested name
        already exists, the existing job is looked up:

        * if it was soft-deleted, the job is created again under a new name
          produced by :func:`gen_job_name`;
        * if it is disabled, it is re-enabled and the patched job is returned;
        * otherwise the existing job is returned unchanged.

        :param body: (Required) The request body, as described in
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/create
        :return: The transfer job. See:
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs#TransferJob
        :raises HttpError: For any API error other than a 409 on a named job.
        """
        body = self._inject_project_id(body, BODY, PROJECT_ID)
        try:
            transfer_job = (
                self.get_conn().transferJobs().create(body=body).execute(num_retries=self.num_retries)
            )
        except HttpError as e:
            # If status code "Conflict", try to find the already existing job.
            job_name = body.get(JOB_NAME, "")
            if int(e.resp.status) != ALREADY_EXIST_CODE or not job_name:
                raise
            transfer_job = self.get_transfer_job(job_name=job_name, project_id=body.get(PROJECT_ID))
            job_status = transfer_job.get(STATUS)
            if job_status == GcpTransferJobsStatus.DELETED:
                # The existing job was soft-deleted: retry the create under a fresh name.
                body[JOB_NAME] = gen_job_name(job_name)
                self.log.info(
                    "Job `%s` has been soft deleted. Creating job with new name `%s`",
                    job_name,
                    body[JOB_NAME],
                )
                return self.get_conn().transferJobs().create(body=body).execute(num_retries=self.num_retries)
            if job_status == GcpTransferJobsStatus.DISABLED:
                return self.enable_transfer_job(job_name=job_name, project_id=body.get(PROJECT_ID))
            # Any other status: fall through and return the existing job as-is.
        self.log.info("Created job %s", transfer_job[NAME])
        return transfer_job

    def get_transfer_job(self, job_name: str, project_id: str | None = None) -> dict:
        """Get latest state of a long-running Google Storage Transfer Service job.

        :param job_name: (Required) Name of the job to be fetched
        :param project_id: (Optional) the ID of the project that owns the Transfer Job.
            If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :return: Transfer Job
        """
        return (
            self.get_conn()
            .transferJobs()
            .get(jobName=job_name, projectId=project_id or self.project_id)
            .execute(num_retries=self.num_retries)
        )

    def list_transfer_job(self, request_filter: dict | None = None, **kwargs) -> list[dict]:
        """List long-running transfer jobs in Google Storage Transfer Service.

        A filter can be specified to match only certain entries.

        :param request_filter: (Required) A request filter, as described in
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list
            The ``project_id`` key may be omitted if the connection defines a project.
            The deprecated ``filter`` keyword is still accepted in its place.
        :return: List of Transfer Jobs (across all result pages).
        """
        request_filter = self._resolve_request_filter(request_filter, kwargs, "list_transfer_job")
        conn = self.get_conn()
        request_filter = self._inject_project_id(request_filter, FILTER, FILTER_PROJECT_ID)
        request = conn.transferJobs().list(filter=json.dumps(request_filter))
        jobs: list[dict] = []
        while request is not None:
            response = request.execute(num_retries=self.num_retries)
            # Empty repeated fields may be omitted from the JSON response, so a page
            # (or the whole result) without jobs must not raise KeyError.
            jobs.extend(response.get(TRANSFER_JOBS, []))
            request = conn.transferJobs().list_next(previous_request=request, previous_response=response)
        return jobs

    def enable_transfer_job(self, job_name: str, project_id: str | None = None) -> dict:
        """Make new transfers be performed based on the schedule.

        :param job_name: (Required) Name of the job to be updated
        :param project_id: (Optional) the ID of the project that owns the Transfer Job.
            If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :return: If successful, TransferJob.
        """
        return self._patch_transfer_job_status(job_name, project_id, GcpTransferJobsStatus.ENABLED)

    def update_transfer_job(self, job_name: str, body: dict) -> dict:
        """Update a transfer job that runs periodically.

        :param job_name: (Required) Name of the job to be updated
        :param body: A request body, as described in
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/patch
        :return: If successful, TransferJob.
        """
        body = self._inject_project_id(body, BODY, PROJECT_ID)
        return (
            self.get_conn()
            .transferJobs()
            .patch(jobName=job_name, body=body)
            .execute(num_retries=self.num_retries)
        )

    def delete_transfer_job(self, job_name: str, project_id: str | None = None) -> None:
        """Delete a transfer job.

        This is a soft delete. After a transfer job is deleted, the job and all the
        transfer executions are subject to garbage collection. Transfer jobs become
        eligible for garbage collection 30 days after soft delete.

        :param job_name: (Required) Name of the job to be deleted
        :param project_id: (Optional) the ID of the project that owns the Transfer Job.
            If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        """
        self._patch_transfer_job_status(job_name, project_id, GcpTransferJobsStatus.DELETED)

    def _patch_transfer_job_status(self, job_name: str, project_id: str | None, status: str) -> dict:
        """Patch only the ``status`` field of a transfer job and return the API response."""
        return (
            self.get_conn()
            .transferJobs()
            .patch(
                jobName=job_name,
                body={
                    PROJECT_ID: project_id or self.project_id,
                    TRANSFER_JOB: {STATUS1: status},
                    TRANSFER_JOB_FIELD_MASK: STATUS1,
                },
            )
            .execute(num_retries=self.num_retries)
        )

    def cancel_transfer_operation(self, operation_name: str) -> None:
        """Cancel a transfer operation in Google Storage Transfer Service.

        :param operation_name: Name of the transfer operation.
        """
        self.get_conn().transferOperations().cancel(name=operation_name).execute(num_retries=self.num_retries)

    def get_transfer_operation(self, operation_name: str) -> dict:
        """Get a transfer operation in Google Storage Transfer Service.

        :param operation_name: (Required) Name of the transfer operation.
        :return: transfer operation

        .. seealso:: https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations/get
        """
        return (
            self.get_conn()
            .transferOperations()
            .get(name=operation_name)
            .execute(num_retries=self.num_retries)
        )

    def list_transfer_operations(self, request_filter: dict | None = None, **kwargs) -> list[dict]:
        """List transfer operations in Google Storage Transfer Service.

        :param request_filter: (Required) A request filter, as described in
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations/list
            with one additional improvement: the ``project_id`` key is optional if
            you have a project ID defined in the connection. See: :doc:`/connections/gcp`
            The deprecated ``filter`` keyword is still accepted in its place.
        :return: List of transfer operations (across all result pages).
        """
        request_filter = self._resolve_request_filter(request_filter, kwargs, "list_transfer_operations")
        conn = self.get_conn()
        request_filter = self._inject_project_id(request_filter, FILTER, FILTER_PROJECT_ID)
        operations: list[dict] = []
        request = conn.transferOperations().list(name=TRANSFER_OPERATIONS, filter=json.dumps(request_filter))
        while request is not None:
            response = request.execute(num_retries=self.num_retries)
            operations.extend(response.get(OPERATIONS, []))
            request = conn.transferOperations().list_next(
                previous_request=request, previous_response=response
            )
        return operations

    @staticmethod
    def _resolve_request_filter(request_filter: dict | None, kwargs: dict, method_name: str) -> dict:
        """Return ``request_filter``, falling back to the deprecated ``filter`` keyword.

        Kept to preserve backward compatibility. TODO: remove one day.

        :raises TypeError: If neither ``request_filter`` nor ``filter`` was given.
        :raises ValueError: If the deprecated ``filter`` value is not a dict.
        """
        if request_filter is not None:
            return request_filter
        if "filter" not in kwargs:
            raise TypeError(f"{method_name} missing 1 required positional argument: 'request_filter'")
        request_filter = kwargs["filter"]
        if not isinstance(request_filter, dict):
            raise ValueError(f"The request_filter should be dict and is {type(request_filter)}")
        warnings.warn(
            "Use 'request_filter' instead of 'filter'",
            AirflowProviderDeprecationWarning,
            # Skip this helper and the public list_* method so the warning points at the caller.
            stacklevel=3,
        )
        return request_filter

    def pause_transfer_operation(self, operation_name: str) -> None:
        """Pause a transfer operation in Google Storage Transfer Service.

        :param operation_name: (Required) Name of the transfer operation.
        """
        self.get_conn().transferOperations().pause(name=operation_name).execute(num_retries=self.num_retries)

    def resume_transfer_operation(self, operation_name: str) -> None:
        """Resume a transfer operation in Google Storage Transfer Service.

        :param operation_name: (Required) Name of the transfer operation.
        """
        self.get_conn().transferOperations().resume(name=operation_name).execute(num_retries=self.num_retries)

    def wait_for_transfer_job(
        self,
        job: dict,
        expected_statuses: set[str] | None = None,
        timeout: float | timedelta | None = None,
    ) -> None:
        """Wait until the job reaches the expected state.

        Polls the job's operations every ``TIME_TO_SLEEP_IN_SECONDS`` seconds.

        :param job: The transfer job to wait for; must contain ``projectId`` and ``name``. See:
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs#TransferJob
        :param expected_statuses: The expected operation states (default: ``{SUCCESS}``). See:
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status
        :param timeout: Time in which the operation must end in seconds. If not
            specified, defaults to 60 seconds.
        :raises AirflowException: On timeout, or when an operation FAILED/ABORTED.
        """
        expected_statuses = expected_statuses or {GcpTransferOperationStatus.SUCCESS}
        if timeout is None:
            timeout = 60
        elif isinstance(timeout, timedelta):
            timeout = timeout.total_seconds()

        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            request_filter = {FILTER_PROJECT_ID: job[PROJECT_ID], FILTER_JOB_NAMES: [job[NAME]]}
            operations = self.list_transfer_operations(request_filter=request_filter)
            for operation in operations:
                # NOTE(review): counters may be unset on a freshly started operation;
                # use .get so progress logging never aborts the wait.
                self.log.info(
                    "Progress for operation %s: %s", operation[NAME], operation[METADATA].get(COUNTERS)
                )
            if self.operations_contain_expected_statuses(operations, expected_statuses):
                return
            time.sleep(TIME_TO_SLEEP_IN_SECONDS)
        raise AirflowException("Timeout. The operation could not be completed within the allotted time.")

    def _inject_project_id(self, body: dict, param_name: str, target_key: str) -> dict:
        """Return a deep copy of ``body`` with ``target_key`` defaulting to the connection project.

        :raises AirflowException: If no project id is available from either source.
        """
        body = deepcopy(body)
        body[target_key] = body.get(target_key, self.project_id)
        if not body.get(target_key):
            raise AirflowException(
                f"The project id must be passed either as `{target_key}` key in `{param_name}` "
                f"parameter or as project_id extra in Google Cloud connection definition. Both are not set!"
            )
        return body

    @staticmethod
    def operations_contain_expected_statuses(
        operations: list[dict], expected_statuses: set[str] | str
    ) -> bool:
        """Check whether an operation exists with the expected status.

        :param operations: (Required) List of transfer operations to check.
        :param expected_statuses: (Required) The expected status, or a set of them. See:
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status
        :return: True if at least one operation is in an expected state, otherwise False.
        :raises AirflowException: If no operation matched and at least one has state
            FAILED or ABORTED.
        """
        expected_statuses_set = (
            {expected_statuses} if isinstance(expected_statuses, str) else set(expected_statuses)
        )
        if not operations:
            return False

        current_statuses = {operation[METADATA][STATUS] for operation in operations}
        if current_statuses & expected_statuses_set:
            return True
        if current_statuses & NEGATIVE_STATUSES:
            raise AirflowException(
                f"An unexpected operation status was encountered. "
                f"Expected: {', '.join(sorted(expected_statuses_set))}"
            )
        return False
class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
    """Asynchronous hook for Google Storage Transfer Service."""

    def __init__(self, project_id: str = PROVIDE_PROJECT_ID, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        # Created on the first get_conn() call and reused afterwards.
        self._client: StorageTransferServiceAsyncClient | None = None

    async def get_conn(self) -> StorageTransferServiceAsyncClient:
        """
        Return the cached async Storage Transfer client, creating it on first use.

        :return: Google Storage Transfer asynchronous client.
        """
        if self._client:
            return self._client
        sync_hook = await self.get_sync_hook()
        self._client = StorageTransferServiceAsyncClient(
            credentials=sync_hook.get_credentials(),
            client_info=CLIENT_INFO,
        )
        return self._client

    async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
        """
        Fetch the current state of the named transfer jobs in this hook's project.

        :param job_names: (Required) List of names of the jobs to be fetched.
        :return: Object that yields Transfer jobs.
        """
        client = await self.get_conn()
        filter_json = json.dumps({"project_id": self.project_id, "job_names": job_names})
        request = ListTransferJobsRequest(filter=filter_json)
        return await client.list_transfer_jobs(request=request)

    async def get_latest_operation(self, job: TransferJob) -> Message | None:
        """
        Fetch the most recent operation of a transfer job.

        :param job: Transfer job instance.
        :return: The deserialized latest operation, or None when the job has none.
        """
        operation_name = job.latest_operation_name
        if not operation_name:
            return None
        client = await self.get_conn()
        raw_operation = await client.transport.operations_client.get_operation(operation_name)
        return TransferOperation.deserialize(raw_operation.metadata.value)

