#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Note: DMS requires you to configure specific IAM roles/permissions. For more information, see
https://docs.aws.amazon.com/dms/latest/userguide/CHAP_Security.html#CHAP_Security.APIRole
"""
from __future__ import annotations
import json
from datetime import datetime
from typing import cast
import boto3
from sqlalchemy import Column, MetaData, String, Table, create_engine
from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.dms import (
DmsCreateTaskOperator,
DmsDeleteTaskOperator,
DmsDescribeTasksOperator,
DmsStartTaskOperator,
DmsStopTaskOperator,
)
from airflow.providers.amazon.aws.operators.rds import (
RdsCreateDbInstanceOperator,
RdsDeleteDbInstanceOperator,
)
from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator
from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
DAG_ID = 'example_dms'
ROLE_ARN_KEY = 'ROLE_ARN'
VPC_ID_KEY = 'VPC_ID'
sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).add_variable(VPC_ID_KEY).build()
# Config values for setting up the "Source" database.
RDS_ENGINE = 'postgres'
RDS_PROTOCOL = 'postgresql'
RDS_USERNAME = 'username'
# NEVER store your production password in plaintext in a DAG like this.
# Use Airflow Secrets or a secret manager for this in production.
RDS_PASSWORD = 'rds_password'
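# A minimal sketch of the alternative suggested above, assuming a Variable named
# 'rds_password' (hypothetical name) is stored in an Airflow secrets backend:
#
#     from airflow.models import Variable
#     RDS_PASSWORD = Variable.get('rds_password')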
# Column names used for both the sample table and the DMS table definition below.
TABLE_HEADERS = ['apache_project', 'release_year']
SAMPLE_DATA = [
('Airflow', '2015'),
('OpenOffice', '2012'),
('Subversion', '2000'),
('NiFi', '2006'),
]
SG_IP_PERMISSION = {
'FromPort': 5432,
'IpProtocol': 'All',
'IpRanges': [{'CidrIp': '0.0.0.0/0'}],
}
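# Note: this ingress rule allows inbound traffic to the database from any IPv4 address
# (0.0.0.0/0). That is only acceptable here because the security group exists solely for the
# duration of a short-lived system test.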
def _get_rds_instance_endpoint(instance_name: str):
print('Retrieving RDS instance endpoint.')
rds_client = boto3.client('rds')
response = rds_client.describe_db_instances(DBInstanceIdentifier=instance_name)
rds_instance_endpoint = response['DBInstances'][0]['Endpoint']
return rds_instance_endpoint
@task
def create_security_group(security_group_name: str, vpc_id: str):
client = boto3.client('ec2')
security_group = client.create_security_group(
GroupName=security_group_name,
Description='Created for DMS system test',
VpcId=vpc_id,
)
client.get_waiter('security_group_exists').wait(
GroupIds=[security_group['GroupId']],
GroupNames=[security_group_name],
WaiterConfig={'Delay': 15, 'MaxAttempts': 4},
)
client.authorize_security_group_ingress(
GroupId=security_group['GroupId'],
GroupName=security_group_name,
IpPermissions=[SG_IP_PERMISSION],
)
return security_group['GroupId']
@task
def create_sample_table(instance_name: str, db_name: str, table_name: str):
print('Creating sample table.')
rds_endpoint = _get_rds_instance_endpoint(instance_name)
hostname = rds_endpoint['Address']
port = rds_endpoint['Port']
rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{db_name}'
engine = create_engine(rds_url)
table = Table(
table_name,
MetaData(engine),
Column(TABLE_HEADERS[0], String, primary_key=True),
Column(TABLE_HEADERS[1], String),
)
with engine.connect() as connection:
# Create the Table.
table.create()
load_data = table.insert().values(SAMPLE_DATA)
connection.execute(load_data)
# Read the data back to verify everything is working.
connection.execute(table.select())
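# The dictionary returned by the next task is exposed as individual XComs
# (multiple_outputs=True), which is what lets later tasks subscript the result,
# e.g. create_assets['source_endpoint_arn'].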
@task(multiple_outputs=True)
def create_dms_assets(
db_name: str,
instance_name: str,
replication_instance_name: str,
bucket_name: str,
role_arn,
source_endpoint_identifier: str,
target_endpoint_identifier: str,
table_definition: dict,
):
print('Creating DMS assets.')
dms_client = boto3.client('dms')
rds_instance_endpoint = _get_rds_instance_endpoint(instance_name)
print('Creating replication instance.')
instance_arn = dms_client.create_replication_instance(
ReplicationInstanceIdentifier=replication_instance_name,
ReplicationInstanceClass='dms.t3.micro',
)['ReplicationInstance']['ReplicationInstanceArn']
print('Creating DMS source endpoint.')
source_endpoint_arn = dms_client.create_endpoint(
EndpointIdentifier=source_endpoint_identifier,
EndpointType='source',
EngineName=RDS_ENGINE,
Username=RDS_USERNAME,
Password=RDS_PASSWORD,
ServerName=rds_instance_endpoint['Address'],
Port=rds_instance_endpoint['Port'],
DatabaseName=db_name,
)['Endpoint']['EndpointArn']
print('Creating DMS target endpoint.')
target_endpoint_arn = dms_client.create_endpoint(
EndpointIdentifier=target_endpoint_identifier,
EndpointType='target',
EngineName='s3',
S3Settings={
'BucketName': bucket_name,
'BucketFolder': 'folder',
'ServiceAccessRoleArn': role_arn,
'ExternalTableDefinition': json.dumps(table_definition),
},
)['Endpoint']['EndpointArn']
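# The endpoints can be created while the replication instance is still provisioning, but the
# instance must be 'available' before a replication task can be created against its ARN.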
print("Awaiting replication instance provisioning.")
dms_client.get_waiter('replication_instance_available').wait(
Filters=[{'Name': 'replication-instance-arn', 'Values': [instance_arn]}]
)
return {
'replication_instance_arn': instance_arn,
'source_endpoint_arn': source_endpoint_arn,
'target_endpoint_arn': target_endpoint_arn,
}
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_dms_assets(
replication_instance_arn: str,
source_endpoint_arn: str,
target_endpoint_arn: str,
source_endpoint_identifier: str,
target_endpoint_identifier: str,
replication_instance_name: str,
):
dms_client = boto3.client('dms')
print('Deleting DMS assets.')
dms_client.delete_replication_instance(ReplicationInstanceArn=replication_instance_arn)
dms_client.delete_endpoint(EndpointArn=source_endpoint_arn)
dms_client.delete_endpoint(EndpointArn=target_endpoint_arn)
print('Awaiting DMS assets tear-down.')
dms_client.get_waiter('replication_instance_deleted').wait(
Filters=[{'Name': 'replication-instance-id', 'Values': [replication_instance_name]}]
)
dms_client.get_waiter('endpoint_deleted').wait(
Filters=[
{
'Name': 'endpoint-id',
'Values': [source_endpoint_identifier, target_endpoint_identifier],
}
]
)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_security_group(security_group_id: str, security_group_name: str):
boto3.client('ec2').delete_security_group(GroupId=security_group_id, GroupName=security_group_name)
with DAG(
DAG_ID,
schedule='@once',
start_date=datetime(2021, 1, 1),
tags=['example'],
catchup=False,
) as dag:
test_context = sys_test_context_task()
env_id = test_context[ENV_ID_KEY]
role_arn = test_context[ROLE_ARN_KEY]
vpc_id = test_context[VPC_ID_KEY]
bucket_name = f'{env_id}-dms-bucket'
rds_instance_name = f'{env_id}-instance'
rds_db_name = f'{env_id}_source_database' # dashes are not allowed in db name
rds_table_name = f'{env_id}-table'
dms_replication_instance_name = f'{env_id}-replication-instance'
dms_replication_task_id = f'{env_id}-replication-task'
source_endpoint_identifier = f'{env_id}-source-endpoint'
target_endpoint_identifier = f'{env_id}-target-endpoint'
security_group_name = f'{env_id}-dms-security-group'
# External table definition describing the columns DMS writes to the S3 target.
table_definition = {
'TableCount': '1',
'Tables': [
{
'TableName': rds_table_name,
'TableColumns': [
{
'ColumnName': TABLE_HEADERS[0],
'ColumnType': 'STRING',
'ColumnNullable': 'false',
'ColumnIsPk': 'true',
},
{"ColumnName": TABLE_HEADERS[1], "ColumnType": 'STRING', "ColumnLength": "4"},
],
'TableColumnsTotal': '2',
}
],
}
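# Table mappings tell DMS which source tables to replicate; this selection rule includes only
# the single sample table in the 'public' schema.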
table_mappings = {
'rules': [
{
'rule-type': 'selection',
'rule-id': '1',
'rule-name': '1',
'object-locator': {
'schema-name': 'public',
'table-name': rds_table_name,
},
'rule-action': 'include',
}
]
}
create_s3_bucket = S3CreateBucketOperator(task_id='create_s3_bucket', bucket_name=bucket_name)
create_sg = create_security_group(security_group_name, vpc_id)
create_db_instance = RdsCreateDbInstanceOperator(
task_id="create_db_instance",
db_instance_identifier=rds_instance_name,
db_instance_class='db.t3.micro',
engine=RDS_ENGINE,
rds_kwargs={
"DBName": rds_db_name,
"AllocatedStorage": 20,
"MasterUsername": RDS_USERNAME,
"MasterUserPassword": RDS_PASSWORD,
"PubliclyAccessible": True,
"VpcSecurityGroupIds": [
create_sg,
],
},
)
create_assets = create_dms_assets(
db_name=rds_db_name,
instance_name=rds_instance_name,
replication_instance_name=dms_replication_instance_name,
bucket_name=bucket_name,
role_arn=role_arn,
source_endpoint_identifier=source_endpoint_identifier,
target_endpoint_identifier=target_endpoint_identifier,
table_definition=table_definition,
)
# [START howto_operator_dms_create_task]
create_task = DmsCreateTaskOperator(
task_id='create_task',
replication_task_id=dms_replication_task_id,
source_endpoint_arn=create_assets['source_endpoint_arn'],
target_endpoint_arn=create_assets['target_endpoint_arn'],
replication_instance_arn=create_assets['replication_instance_arn'],
table_mappings=table_mappings,
)
# [END howto_operator_dms_create_task]
task_arn = cast(str, create_task.output)
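# create_task.output is an XComArg that resolves to the replication task ARN at runtime;
# cast() has no runtime effect and only exists to satisfy static type checkers.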
# [START howto_operator_dms_start_task]
start_task = DmsStartTaskOperator(
task_id='start_task',
replication_task_arn=task_arn,
)
# [END howto_operator_dms_start_task]
# [START howto_operator_dms_describe_tasks]
describe_tasks = DmsDescribeTasksOperator(
task_id='describe_tasks',
describe_tasks_kwargs={
'Filters': [
{
'Name': 'replication-instance-arn',
'Values': [create_assets['replication_instance_arn']],
}
]
},
do_xcom_push=False,
)
# [END howto_operator_dms_describe_tasks]
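# Wait for the replication task to reach the 'running' state; the sensor fails fast if the
# task instead lands in one of the termination statuses.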
await_task_start = DmsTaskBaseSensor(
task_id='await_task_start',
replication_task_arn=task_arn,
target_statuses=['running'],
termination_statuses=['stopped', 'deleting', 'failed'],
)
# [START howto_operator_dms_stop_task]
stop_task = DmsStopTaskOperator(
task_id='stop_task',
replication_task_arn=task_arn,
)
# [END howto_operator_dms_stop_task]
# TaskCompletedSensor actually waits until task reaches the "Stopped" state, so it will work here.
# [START howto_sensor_dms_task_completed]
await_task_stop = DmsTaskCompletedSensor(
task_id='await_task_stop',
replication_task_arn=task_arn,
)
# [END howto_sensor_dms_task_completed]
# [START howto_operator_dms_delete_task]
delete_task = DmsDeleteTaskOperator(
task_id='delete_task',
replication_task_arn=task_arn,
)
# [END howto_operator_dms_delete_task]
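# Teardown tasks run with TriggerRule.ALL_DONE so cleanup still happens if the test body fails.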
delete_task.trigger_rule = TriggerRule.ALL_DONE
delete_assets = delete_dms_assets(
replication_instance_arn=create_assets['replication_instance_arn'],
source_endpoint_arn=create_assets['source_endpoint_arn'],
target_endpoint_arn=create_assets['target_endpoint_arn'],
source_endpoint_identifier=source_endpoint_identifier,
target_endpoint_identifier=target_endpoint_identifier,
replication_instance_name=dms_replication_instance_name,
)
delete_db_instance = RdsDeleteDbInstanceOperator(
task_id='delete_db_instance',
db_instance_identifier=rds_instance_name,
rds_kwargs={
"SkipFinalSnapshot": True,
},
trigger_rule=TriggerRule.ALL_DONE,
)
delete_s3_bucket = S3DeleteBucketOperator(
task_id='delete_s3_bucket',
bucket_name=bucket_name,
force_delete=True,
trigger_rule=TriggerRule.ALL_DONE,
)
chain(
# TEST SETUP
test_context,
create_s3_bucket,
create_sg,
create_db_instance,
create_sample_table(rds_instance_name, rds_db_name, rds_table_name),
create_assets,
# TEST BODY
create_task,
start_task,
describe_tasks,
await_task_start,
stop_task,
await_task_stop,
# TEST TEARDOWN
delete_task,
delete_assets,
delete_db_instance,
delete_security_group(create_sg, security_group_name),
delete_s3_bucket,
)
from tests.system.utils.watcher import watcher
# This test needs the watcher in order to properly mark success/failure
# when a "tearDown" task with a trigger rule is part of the DAG.
list(dag.tasks) >> watcher()
from tests.system.utils import get_test_run # noqa: E402
# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)