# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import warnings
from datetime import datetime, timedelta
from functools import cached_property
from time import sleep
from typing import TYPE_CHECKING, cast
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.operators.python import ShortCircuitOperator
from airflow.providers.amazon.aws.hooks.appflow import AppflowHook
from airflow.providers.amazon.aws.utils import datetime_to_epoch_ms
if TYPE_CHECKING:
from mypy_boto3_appflow.type_defs import (
DescribeFlowExecutionRecordsResponseTypeDef,
ExecutionRecordTypeDef,
TaskTypeDef,
)
from airflow.utils.context import Context
# Source connector names accepted by the Appflow operators defined in this module.
SUPPORTED_SOURCES = {"zendesk", "salesforce"}

# Error-message templates shared by the operators below; filled in with ``str.format``.
MANDATORY_FILTER_DATE_MSG = "The filter_date argument is mandatory for {entity}!"
NOT_SUPPORTED_SOURCE_MSG = "Source {source} is not supported for {entity}!"
class AppflowBaseOperator(BaseOperator):
    """
    Amazon Appflow Base Operator class (not supposed to be used directly in DAGs).

    :param source: The source name (Supported: salesforce, zendesk)
    :param flow_name: The flow name
    :param flow_update: A boolean to enable/disable a flow update before the run
    :param source_field: The field name to apply filters
    :param filter_date: The date value (or template) to be used in filters.
    :param poll_interval: how often in seconds to check the query status
    :param max_attempts: how many times to check for status before timing out
    :param aws_conn_id: aws connection to use
    :param region: aws region to use
    :param wait_for_completion: whether to wait for the run to end to return
    """

    template_fields = ("flow_name", "source", "source_field", "filter_date")

    # Seconds to sleep after a flow update before triggering the run (see ``execute``).
    UPDATE_PROPAGATION_TIME: int = 15

    def __init__(
        self,
        flow_name: str,
        flow_update: bool,
        source: str | None = None,
        source_field: str | None = None,
        filter_date: str | None = None,
        poll_interval: int = 20,
        max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
        region: str | None = None,
        wait_for_completion: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if source is not None and source not in SUPPORTED_SOURCES:
            raise ValueError(f"{source} is not a supported source (options: {SUPPORTED_SOURCES})!")
        self.filter_date = filter_date
        self.flow_name = flow_name
        self.source = source
        self.source_field = source_field
        self.poll_interval = poll_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id
        self.region = region
        self.flow_update = flow_update
        self.wait_for_completion = wait_for_completion

    @cached_property
    def hook(self) -> AppflowHook:
        """Create and return an AppflowHook."""
        return AppflowHook(aws_conn_id=self.aws_conn_id, region_name=self.region)

    def execute(self, context: Context) -> None:
        """
        Optionally update the flow, then run it.

        ``filter_date`` is listed in ``template_fields``, so it is parsed here (with
        ``datetime.fromisoformat``) rather than in ``__init__``. When ``source`` is set, the
        flow's connector type is fetched and checked against it and then stored on
        ``self.connector_type``, where the filtering subclasses read it.
        """
        self.filter_date_parsed: datetime | None = (
            datetime.fromisoformat(self.filter_date) if self.filter_date else None
        )
        if self.source is not None:
            self.connector_type = self._get_connector_type()
        if self.flow_update:
            self._update_flow()
            # while schedule flows will pick up the update right away, on-demand flows might use out of date
            # info if triggered right after an update, so we need to wait a bit for the DB to be consistent.
            sleep(AppflowBaseOperator.UPDATE_PROPAGATION_TIME)
        self._run_flow(context)

    def _get_connector_type(self) -> str:
        """
        Return the flow's source connector type after checking it matches ``source``.

        :raises ValueError: if the lower-cased connector type differs from ``source``.
        """
        response = self.hook.conn.describe_flow(flowName=self.flow_name)
        connector_type = response["sourceFlowConfig"]["connectorType"]
        if self.source != connector_type.lower():
            raise ValueError(f"Incompatible source ({self.source}) and connector type ({connector_type})!")
        return connector_type

    def _update_flow(self) -> None:
        """Update the flow with an empty filter-task list and on-demand triggering (overridden by subclasses)."""
        self.hook.update_flow_filter(flow_name=self.flow_name, filter_tasks=[], set_trigger_ondemand=True)

    def _run_flow(self, context: Context) -> str:
        """Run the flow, push its execution id to XCom under ``execution_id`` and return it."""
        execution_id = self.hook.run_flow(
            flow_name=self.flow_name,
            poll_interval=self.poll_interval,
            max_attempts=self.max_attempts,
            wait_for_completion=self.wait_for_completion,
        )
        task_instance = context["task_instance"]
        task_instance.xcom_push("execution_id", execution_id)
        return execution_id
class AppflowRunOperator(AppflowBaseOperator):
    """
    Execute an Appflow run as is (no flow update, no filters).

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AppflowRunOperator`

    :param source: Obsolete, unnecessary for this operator
    :param flow_name: The flow name
    :param poll_interval: how often in seconds to check the query status
    :param aws_conn_id: aws connection to use
    :param region: aws region to use
    :param wait_for_completion: whether to wait for the run to end to return

    Any extra keyword arguments (for example ``max_attempts``) are forwarded to
    ``AppflowBaseOperator``.
    """

    def __init__(
        self,
        flow_name: str,
        source: str | None = None,
        poll_interval: int = 20,
        aws_conn_id: str = "aws_default",
        region: str | None = None,
        wait_for_completion: bool = True,
        **kwargs,
    ) -> None:
        # ``source`` is never forwarded to the base class, so supplying it has no effect.
        if source is not None:
            warnings.warn(
                "The `source` parameter is unused when simply running a flow, please remove it.",
                AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
        super().__init__(
            flow_name=flow_name,
            wait_for_completion=wait_for_completion,
            poll_interval=poll_interval,
            aws_conn_id=aws_conn_id,
            region=region,
            # A plain run never touches the flow's definition or filters.
            flow_update=False,
            filter_date=None,
            source_field=None,
            **kwargs,
        )
class AppflowRunFullOperator(AppflowBaseOperator):
    """
    Execute an Appflow full run removing any filter.

    The flow is first updated (``flow_update=True``) with the base operator's update, which
    sends an empty filter-task list, and then run.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AppflowRunFullOperator`

    :param source: The source name (Supported: salesforce, zendesk)
    :param flow_name: The flow name
    :param poll_interval: how often in seconds to check the query status
    :param aws_conn_id: aws connection to use
    :param region: aws region to use
    :param wait_for_completion: whether to wait for the run to end to return
    :raises ValueError: if ``source`` is not one of ``SUPPORTED_SOURCES``.
    """

    def __init__(
        self,
        source: str,
        flow_name: str,
        poll_interval: int = 20,
        aws_conn_id: str = "aws_default",
        region: str | None = None,
        wait_for_completion: bool = True,
        **kwargs,
    ) -> None:
        # Validate against the module-wide constant instead of a duplicated literal set.
        if source not in SUPPORTED_SOURCES:
            raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunFullOperator"))
        super().__init__(
            source=source,
            flow_name=flow_name,
            flow_update=True,
            source_field=None,
            filter_date=None,
            poll_interval=poll_interval,
            aws_conn_id=aws_conn_id,
            region=region,
            wait_for_completion=wait_for_completion,
            **kwargs,
        )
class AppflowRunBeforeOperator(AppflowBaseOperator):
    """
    Execute an Appflow run after updating the filters to select only previous data.

    The installed filter keeps records whose ``source_field`` is strictly before
    ``filter_date``.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AppflowRunBeforeOperator`

    :param source: The source name (Supported: salesforce)
    :param flow_name: The flow name
    :param source_field: The field name to apply filters
    :param filter_date: The date value (or template) to be used in filters.
    :param poll_interval: how often in seconds to check the query status
    :param aws_conn_id: aws connection to use
    :param region: aws region to use
    :param wait_for_completion: whether to wait for the run to end to return
    :raises ValueError: if ``filter_date`` is empty, or if ``source`` is not ``salesforce``.
    """

    def __init__(
        self,
        source: str,
        flow_name: str,
        source_field: str,
        filter_date: str,
        poll_interval: int = 20,
        aws_conn_id: str = "aws_default",
        region: str | None = None,
        wait_for_completion: bool = True,
        **kwargs,
    ) -> None:
        operator_name = "AppflowRunBeforeOperator"
        # The date check runs first, so it wins when both arguments are invalid.
        if not filter_date:
            raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity=operator_name))
        if source != "salesforce":
            raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity=operator_name))
        super().__init__(
            flow_name=flow_name,
            source=source,
            source_field=source_field,
            filter_date=filter_date,
            flow_update=True,
            wait_for_completion=wait_for_completion,
            poll_interval=poll_interval,
            aws_conn_id=aws_conn_id,
            region=region,
            **kwargs,
        )

    def _update_flow(self) -> None:
        """
        Replace the flow's filters with a single ``LESS_THAN`` filter on ``source_field``.

        :raises ValueError: if the parsed filter date or the source field is missing.
        """
        if not self.filter_date_parsed:
            raise ValueError(f"Invalid filter_date argument parser value: {self.filter_date_parsed}")
        if not self.source_field:
            raise ValueError(f"Invalid source_field argument value: {self.source_field}")
        cutoff_ms = str(datetime_to_epoch_ms(self.filter_date_parsed))
        before_task: TaskTypeDef = {
            "taskType": "Filter",
            "connectorOperator": {self.connector_type: "LESS_THAN"},  # type: ignore
            "sourceFields": [self.source_field],
            # The cutoff itself is NOT included.
            "taskProperties": {"DATA_TYPE": "datetime", "VALUE": cutoff_ms},
        }
        self.hook.update_flow_filter(
            flow_name=self.flow_name,
            filter_tasks=[before_task],
            set_trigger_ondemand=True,
        )
class AppflowRunAfterOperator(AppflowBaseOperator):
    """
    Execute an Appflow run after updating the filters to select only future data.

    The installed filter keeps records whose ``source_field`` is strictly after
    ``filter_date``.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AppflowRunAfterOperator`

    :param source: The source name (Supported: salesforce, zendesk)
    :param flow_name: The flow name
    :param source_field: The field name to apply filters
    :param filter_date: The date value (or template) to be used in filters.
    :param poll_interval: how often in seconds to check the query status
    :param aws_conn_id: aws connection to use
    :param region: aws region to use
    :param wait_for_completion: whether to wait for the run to end to return
    :raises ValueError: if ``filter_date`` is empty, or if ``source`` is not in ``SUPPORTED_SOURCES``.
    """

    def __init__(
        self,
        source: str,
        flow_name: str,
        source_field: str,
        filter_date: str,
        poll_interval: int = 20,
        aws_conn_id: str = "aws_default",
        region: str | None = None,
        wait_for_completion: bool = True,
        **kwargs,
    ) -> None:
        if not filter_date:
            raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity="AppflowRunAfterOperator"))
        # Validate against the module-wide constant instead of a duplicated literal set.
        if source not in SUPPORTED_SOURCES:
            raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunAfterOperator"))
        super().__init__(
            source=source,
            flow_name=flow_name,
            flow_update=True,
            source_field=source_field,
            filter_date=filter_date,
            poll_interval=poll_interval,
            aws_conn_id=aws_conn_id,
            region=region,
            wait_for_completion=wait_for_completion,
            **kwargs,
        )

    def _update_flow(self) -> None:
        """
        Replace the flow's filters with a single ``GREATER_THAN`` filter on ``source_field``.

        :raises ValueError: if the parsed filter date or the source field is missing.
        """
        if not self.filter_date_parsed:
            raise ValueError(f"Invalid filter_date argument parser value: {self.filter_date_parsed}")
        if not self.source_field:
            raise ValueError(f"Invalid source_field argument value: {self.source_field}")
        filter_task: TaskTypeDef = {
            "taskType": "Filter",
            "connectorOperator": {self.connector_type: "GREATER_THAN"},  # type: ignore
            "sourceFields": [self.source_field],
            "taskProperties": {
                "DATA_TYPE": "datetime",
                "VALUE": str(datetime_to_epoch_ms(self.filter_date_parsed)),
            },  # NOT inclusive
        }
        self.hook.update_flow_filter(
            flow_name=self.flow_name, filter_tasks=[filter_task], set_trigger_ondemand=True
        )
class AppflowRunDailyOperator(AppflowBaseOperator):
    """
    Execute an Appflow run after updating the filters to select only a single day.

    The installed filter keeps records whose ``source_field`` falls within the day-long
    window that starts at ``filter_date``.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AppflowRunDailyOperator`

    :param source: The source name (Supported: salesforce)
    :param flow_name: The flow name
    :param source_field: The field name to apply filters
    :param filter_date: The date value (or template) to be used in filters.
    :param poll_interval: how often in seconds to check the query status
    :param aws_conn_id: aws connection to use
    :param region: aws region to use
    :param wait_for_completion: whether to wait for the run to end to return
    :raises ValueError: if ``filter_date`` is empty, or if ``source`` is not ``salesforce``.
    """

    def __init__(
        self,
        source: str,
        flow_name: str,
        source_field: str,
        filter_date: str,
        poll_interval: int = 20,
        aws_conn_id: str = "aws_default",
        region: str | None = None,
        wait_for_completion: bool = True,
        **kwargs,
    ) -> None:
        operator_name = "AppflowRunDailyOperator"
        # The date check runs first, so it wins when both arguments are invalid.
        if not filter_date:
            raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity=operator_name))
        if source != "salesforce":
            raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity=operator_name))
        super().__init__(
            flow_name=flow_name,
            source=source,
            source_field=source_field,
            filter_date=filter_date,
            flow_update=True,
            wait_for_completion=wait_for_completion,
            poll_interval=poll_interval,
            aws_conn_id=aws_conn_id,
            region=region,
            **kwargs,
        )

    def _update_flow(self) -> None:
        """
        Replace the flow's filters with a single ``BETWEEN`` filter covering one day.

        :raises ValueError: if the parsed filter date or the source field is missing.
        """
        if not self.filter_date_parsed:
            raise ValueError(f"Invalid filter_date argument parser value: {self.filter_date_parsed}")
        if not self.source_field:
            raise ValueError(f"Invalid source_field argument value: {self.source_field}")
        day_start = self.filter_date_parsed
        # Both BETWEEN bounds are exclusive. The lower bound is moved back 1 ms so that
        # day_start itself is kept, and the upper bound sits exactly one day later.
        lower_ms = datetime_to_epoch_ms(day_start - timedelta(milliseconds=1))
        upper_ms = datetime_to_epoch_ms(day_start + timedelta(days=1))
        daily_task: TaskTypeDef = {
            "taskType": "Filter",
            "connectorOperator": {self.connector_type: "BETWEEN"},  # type: ignore
            "sourceFields": [self.source_field],
            "taskProperties": {
                "DATA_TYPE": "datetime",
                "LOWER_BOUND": str(lower_ms),
                "UPPER_BOUND": str(upper_ms),
            },
        }
        self.hook.update_flow_filter(
            flow_name=self.flow_name,
            filter_tasks=[daily_task],
            set_trigger_ondemand=True,
        )
class AppflowRecordsShortCircuitOperator(ShortCircuitOperator):
    """
    Short-circuit in case of an empty Appflow run.

    The execution id pushed to XCom by an Appflow run task is looked up in the flow's
    execution records. Downstream tasks continue only if that execution processed at
    least one record.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:AppflowRecordsShortCircuitOperator`

    :param flow_name: The flow name
    :param appflow_run_task_id: Run task ID from where this operator should extract the execution ID
    :param ignore_downstream_trigger_rules: Ignore downstream trigger rules
    :param aws_conn_id: aws connection to use
    :param region: aws region to use
    """

    ui_color = "#33ffec"  # Light blue

    def __init__(
        self,
        *,
        flow_name: str,
        appflow_run_task_id: str,
        ignore_downstream_trigger_rules: bool = True,
        aws_conn_id: str = "aws_default",
        region: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            python_callable=self._has_new_records_func,
            op_kwargs={
                "flow_name": flow_name,
                "appflow_run_task_id": appflow_run_task_id,
            },
            ignore_downstream_trigger_rules=ignore_downstream_trigger_rules,
            **kwargs,
        )
        self.aws_conn_id = aws_conn_id
        self.region = region

    @staticmethod
    def _get_target_execution_id(
        records: list[ExecutionRecordTypeDef], execution_id: str
    ) -> ExecutionRecordTypeDef | None:
        """Return the first record whose ``executionId`` equals ``execution_id``, or None."""
        return next((record for record in records if record.get("executionId") == execution_id), None)

    @cached_property
    def hook(self) -> AppflowHook:
        """Create and return an AppflowHook."""
        return AppflowHook(aws_conn_id=self.aws_conn_id, region_name=self.region)

    def _find_execution_record(self, flow_name: str, execution_id: str) -> ExecutionRecordTypeDef:
        """
        Page through the flow's execution records until the one for ``execution_id`` is found.

        :raises AirflowException: if every page has been read without finding the execution.
        """
        af_client = self.hook.conn
        args = {"flowName": flow_name, "maxResults": 100}
        next_token: str | None = None
        while True:
            response: DescribeFlowExecutionRecordsResponseTypeDef
            # The first request has no token; later ones continue from the previous page.
            if next_token:
                response = af_client.describe_flow_execution_records(nextToken=next_token, **args)
            else:
                response = af_client.describe_flow_execution_records(**args)
            record = self._get_target_execution_id(response.get("flowExecutions", []), execution_id)
            if record is not None:
                return record
            next_token = response.get("nextToken")
            if not next_token:
                raise AirflowException(
                    f"Flow execution ({execution_id}) not found in the execution records of {flow_name}."
                )

    def _has_new_records_func(self, **kwargs) -> bool:
        """
        Return True if the upstream Appflow run processed at least one record.

        ``kwargs`` is expected to hold the ``op_kwargs`` set in ``__init__`` (``flow_name`` and
        ``appflow_run_task_id``) plus ``task_instance`` from the task context. The processed
        count is pushed to XCom under ``records_processed``.

        :raises AirflowException: if no execution id is found in XCom, if the execution is not
            found in the flow's records, or if the record has no ``recordsProcessed`` entry.
        """
        appflow_task_id = kwargs["appflow_run_task_id"]
        self.log.info("appflow_task_id: %s", appflow_task_id)
        flow_name = kwargs["flow_name"]
        self.log.info("flow_name: %s", flow_name)
        task_instance = kwargs["task_instance"]
        execution_id = task_instance.xcom_pull(task_ids=appflow_task_id, key="execution_id")  # type: ignore
        if not execution_id:
            raise AirflowException(f"No execution_id found from task_id {appflow_task_id}!")
        self.log.info("execution_id: %s", execution_id)
        record = self._find_execution_record(flow_name, execution_id)
        execution = record.get("executionResult", {})
        if "recordsProcessed" not in execution:
            raise AirflowException(f"Flow ({execution_id}) without recordsProcessed info!")
        records_processed = execution["recordsProcessed"]
        self.log.info("records_processed: %d", records_processed)
        task_instance.xcom_push("records_processed", records_processed)  # type: ignore
        return records_processed > 0