"""
Note: DMS requires you to configure specific IAM roles/permissions. For more information, see
https://docs.aws.amazon.com/dms/latest/userguide/security-iam.html#CHAP_Security.APIRole
"""
import json
import os
from datetime import datetime

import boto3
from sqlalchemy import Column, MetaData, String, Table, create_engine

from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.operators.python import get_current_context
from airflow.providers.amazon.aws.operators.dms import (
    DmsCreateTaskOperator,
    DmsDeleteTaskOperator,
    DmsDescribeTasksOperator,
    DmsStartTaskOperator,
    DmsStopTaskOperator,
)
from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor
S3_BUCKET = os.getenv('S3_BUCKET', 's3_bucket_name')
ROLE_ARN = os.getenv('ROLE_ARN', 'arn:aws:iam::1234567890:role/s3_target_endpoint_role')

# The project name will be used as a prefix for various entity names.
# Use either PascalCase or camelCase. While some names require kebab-case
# and others require snake_case, they all accept mixedCase strings.
PROJECT_NAME = 'DmsDemo'

# Config values for setting up the "Source" database.
# RDS_ENGINE is the engine name passed to the RDS and DMS APIs; RDS_PROTOCOL is
# the SQLAlchemy URL scheme used to connect to that same engine.
RDS_ENGINE = 'postgres'
RDS_PROTOCOL = 'postgresql'
RDS_USERNAME = 'username'
# NEVER store your production password in plaintext in a DAG like this.
# Use Airflow Secrets or a secret manager for this in production.
RDS_PASSWORD = 'rds_password'

# Config values for RDS.
RDS_INSTANCE_NAME = f'{PROJECT_NAME}-instance'
RDS_DB_NAME = f'{PROJECT_NAME}_source_database'

# Config values for DMS.
DMS_REPLICATION_INSTANCE_NAME = f'{PROJECT_NAME}-replication-instance'
DMS_REPLICATION_TASK_ID = f'{PROJECT_NAME}-replication-task'
SOURCE_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-source-endpoint'
TARGET_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-target-endpoint'

# Sample data.
TABLE_NAME = f'{PROJECT_NAME}-table'
# Column names for the sample table; each SAMPLE_DATA row supplies one value per header.
TABLE_HEADERS = ['apache_project', 'release_year']
SAMPLE_DATA = [
    ('Airflow', '2015'),
    ('OpenOffice', '2012'),
    ('Subversion', '2000'),
    ('NiFi', '2006'),
]

# Table layout passed (JSON-encoded) as the S3 endpoint's ExternalTableDefinition.
TABLE_DEFINITION = {
    'TableCount': '1',
    'Tables': [
        {
            'TableName': TABLE_NAME,
            'TableColumns': [
                {
                    'ColumnName': TABLE_HEADERS[0],
                    'ColumnType': 'STRING',
                    'ColumnNullable': 'false',
                    'ColumnIsPk': 'true',
                },
                {'ColumnName': TABLE_HEADERS[1], 'ColumnType': 'STRING', 'ColumnLength': '4'},
            ],
            'TableColumnsTotal': '2',
        }
    ],
}

# DMS table-mapping rules: replicate only TABLE_NAME from the 'public' schema.
TABLE_MAPPINGS = {
    'rules': [
        {
            'rule-type': 'selection',
            'rule-id': '1',
            'rule-name': '1',
            'object-locator': {
                'schema-name': 'public',
                'table-name': TABLE_NAME,
            },
            'rule-action': 'include',
        }
    ]
}
def _create_rds_instance():
    """Create the source RDS instance and block until it is available.

    :return: The ``Endpoint`` mapping reported by ``describe_db_instances`` for the
        new instance; callers read its ``'Address'`` and ``'Port'`` keys.
    """
    print('Creating RDS Instance.')

    rds_client = boto3.client('rds')
    rds_client.create_db_instance(
        # Without DBName, RDS does not create RDS_DB_NAME, yet both the table
        # set-up and the DMS source endpoint connect to that database.
        DBName=RDS_DB_NAME,
        DBInstanceIdentifier=RDS_INSTANCE_NAME,
        AllocatedStorage=20,
        DBInstanceClass='db.t3.micro',
        Engine=RDS_ENGINE,
        MasterUsername=RDS_USERNAME,
        MasterUserPassword=RDS_PASSWORD,
    )
    rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)

    response = rds_client.describe_db_instances(DBInstanceIdentifier=RDS_INSTANCE_NAME)
    return response['DBInstances'][0]['Endpoint']
def _create_rds_table(rds_endpoint):
    """Create the sample table in the source database and load SAMPLE_DATA into it.

    :param rds_endpoint: RDS ``Endpoint`` mapping with ``'Address'`` and ``'Port'``
        keys, as returned by ``_create_rds_instance``.
    """
    print('Creating table.')

    hostname = rds_endpoint['Address']
    port = rds_endpoint['Port']
    rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{RDS_DB_NAME}'
    engine = create_engine(rds_url)

    # Unbound MetaData: the table is created explicitly on the connection below
    # instead of relying on MetaData(engine) binding, which SQLAlchemy
    # deprecated in 1.4 and removed in 2.0.
    table = Table(
        TABLE_NAME,
        MetaData(),
        Column(TABLE_HEADERS[0], String, primary_key=True),
        Column(TABLE_HEADERS[1], String),
    )

    try:
        # engine.begin() commits on successful exit, so the DDL and the inserted
        # rows persist regardless of SQLAlchemy's autocommit behaviour.
        with engine.begin() as connection:
            # Create the Table.
            table.create(connection)
            connection.execute(table.insert().values(SAMPLE_DATA))

            # Read the data back to verify everything is working.
            rows = connection.execute(table.select()).fetchall()
            print(f'Loaded {len(rows)} rows into {TABLE_NAME}.')
    finally:
        # Release pooled connections; this engine is not reused.
        engine.dispose()
def _create_dms_replication_instance(ti, dms_client):
    """Request a DMS replication instance and publish its ARN to XCom.

    This only issues the create request; it does not wait for the instance
    to become available.

    :param ti: Task instance used to push the ``replication_instance_arn`` XCom.
    :param dms_client: boto3 DMS client.
    :return: ARN of the requested replication instance.
    """
    print('Creating replication instance.')
    response = dms_client.create_replication_instance(
        ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME,
        ReplicationInstanceClass='dms.t3.micro',
    )
    replication_instance_arn = response['ReplicationInstance']['ReplicationInstanceArn']

    ti.xcom_push(key='replication_instance_arn', value=replication_instance_arn)
    return replication_instance_arn
def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
    """Create the DMS source (RDS) and target (S3) endpoints and push their ARNs to XCom.

    :param ti: Task instance used to push ``source_endpoint_arn`` and
        ``target_endpoint_arn``.
    :param dms_client: boto3 DMS client.
    :param rds_instance_endpoint: RDS ``Endpoint`` mapping with ``'Address'`` and
        ``'Port'`` keys.
    """
    print('Creating DMS source endpoint.')
    source_endpoint_arn = dms_client.create_endpoint(
        # EndpointIdentifier is a required CreateEndpoint parameter.
        EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER,
        EndpointType='source',
        EngineName=RDS_ENGINE,
        Username=RDS_USERNAME,
        Password=RDS_PASSWORD,
        ServerName=rds_instance_endpoint['Address'],
        Port=rds_instance_endpoint['Port'],
        DatabaseName=RDS_DB_NAME,
    )['Endpoint']['EndpointArn']
    # Push immediately so clean-up can still find (and delete) the source
    # endpoint if creating the target endpoint fails.
    ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn)

    print('Creating DMS target endpoint.')
    target_endpoint_arn = dms_client.create_endpoint(
        EndpointIdentifier=TARGET_ENDPOINT_IDENTIFIER,
        EndpointType='target',
        EngineName='s3',
        S3Settings={
            'BucketName': S3_BUCKET,
            'BucketFolder': PROJECT_NAME,
            'ServiceAccessRoleArn': ROLE_ARN,
            # NOTE(review): AWS documents ExternalTableDefinition as describing S3
            # *source* files; it is likely ignored on this target endpoint — confirm.
            'ExternalTableDefinition': json.dumps(TABLE_DEFINITION),
        },
    )['Endpoint']['EndpointArn']
    ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn)
def _await_setup_assets ( dms_client , instance_arn ):
print ( "Awaiting asset provisioning." )
dms_client . get_waiter ( 'replication_instance_available' ) . wait (
Filters = [{ 'Name' : 'replication-instance-arn' , 'Values' : [ instance_arn ]}]
def _delete_rds_instance():
    """Delete the source RDS instance without a final snapshot and wait until it is gone.

    If the instance does not exist (e.g. set-up failed before creating it), this
    logs and returns instead of raising, so the rest of the teardown can proceed.
    """
    print('Deleting RDS Instance.')

    rds_client = boto3.client('rds')
    try:
        rds_client.delete_db_instance(
            DBInstanceIdentifier=RDS_INSTANCE_NAME,
            SkipFinalSnapshot=True,
        )
    except rds_client.exceptions.DBInstanceNotFoundFault:
        print('RDS Instance not found; nothing to delete.')
        return

    rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)
def _delete_dms_assets(dms_client):
    """Delete the DMS replication instance and endpoints whose ARNs were pushed to XCom.

    Must be called from inside a running task (uses ``get_current_context``).
    An ARN that was never pushed (``xcom_pull`` yields None, e.g. because set-up
    failed early) is skipped so the remaining assets are still deleted.

    :param dms_client: boto3 DMS client.
    """
    ti = get_current_context()['ti']
    replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
    source_arn = ti.xcom_pull(key='source_endpoint_arn')
    target_arn = ti.xcom_pull(key='target_endpoint_arn')

    print('Deleting DMS assets.')
    if replication_instance_arn:
        dms_client.delete_replication_instance(ReplicationInstanceArn=replication_instance_arn)
    for endpoint_arn in (source_arn, target_arn):
        if endpoint_arn:
            dms_client.delete_endpoint(EndpointArn=endpoint_arn)
def _await_all_teardowns(dms_client):
    """Block until the DMS replication instance and both DMS endpoints are deleted.

    Waits are keyed by the configured identifiers rather than ARNs.

    :param dms_client: boto3 DMS client.
    """
    print('Awaiting tear-down.')
    dms_client.get_waiter('replication_instance_deleted').wait(
        Filters=[{'Name': 'replication-instance-id', 'Values': [DMS_REPLICATION_INSTANCE_NAME]}]
    )

    dms_client.get_waiter('endpoint_deleted').wait(
        Filters=[
            {
                'Name': 'endpoint-id',
                'Values': [SOURCE_ENDPOINT_IDENTIFIER, TARGET_ENDPOINT_IDENTIFIER],
            }
        ]
    )
@task
def set_up():
    """Provision the RDS source database, its sample data, and the DMS instance/endpoints.

    The helpers publish ``replication_instance_arn``, ``source_endpoint_arn`` and
    ``target_endpoint_arn`` to XCom for the downstream DMS operators and ``clean_up``.
    Must run as an Airflow task (``@task``): calling it undecorated would create
    AWS resources at DAG-parse time.
    """
    ti = get_current_context()['ti']
    dms_client = boto3.client('dms')

    rds_instance_endpoint = _create_rds_instance()
    _create_rds_table(rds_instance_endpoint)

    instance_arn = _create_dms_replication_instance(ti, dms_client)
    _create_dms_endpoints(ti, dms_client, rds_instance_endpoint)
    _await_setup_assets(dms_client, instance_arn)
@task(trigger_rule='all_done')
def clean_up():
    """Tear down the RDS instance and DMS assets, whatever the upstream task states.

    DMS teardown runs in ``finally`` so an error while deleting the RDS instance
    does not leave the replication instance and endpoints behind; the original
    error is still raised afterwards.
    """
    dms_client = boto3.client('dms')
    try:
        _delete_rds_instance()
    finally:
        _delete_dms_assets(dms_client)
        _await_all_teardowns(dms_client)
with DAG(
    dag_id='example_dms',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    tags=['example'],
    catchup=False,
) as dag:

    # [START howto_operator_dms_create_task]
    create_task = DmsCreateTaskOperator(
        task_id='create_task',
        replication_task_id=DMS_REPLICATION_TASK_ID,
        # ARNs are resolved at runtime from the XComs pushed by set_up().
        source_endpoint_arn='{{ ti.xcom_pull(key="source_endpoint_arn") }}',
        target_endpoint_arn='{{ ti.xcom_pull(key="target_endpoint_arn") }}',
        replication_instance_arn='{{ ti.xcom_pull(key="replication_instance_arn") }}',
        table_mappings=TABLE_MAPPINGS,
    )
    # [END howto_operator_dms_create_task]

    # [START howto_operator_dms_start_task]
    start_task = DmsStartTaskOperator(
        task_id='start_task',
        replication_task_arn=create_task.output,
    )
    # [END howto_operator_dms_start_task]

    # [START howto_operator_dms_describe_tasks]
    describe_tasks = DmsDescribeTasksOperator(
        task_id='describe_tasks',
        describe_tasks_kwargs={
            'Filters': [
                {
                    'Name': 'replication-instance-arn',
                    'Values': ['{{ ti.xcom_pull(key="replication_instance_arn") }}'],
                }
            ]
        },
        do_xcom_push=False,
    )
    # [END howto_operator_dms_describe_tasks]

    await_task_start = DmsTaskBaseSensor(
        task_id='await_task_start',
        replication_task_arn=create_task.output,
        target_statuses=['running'],
        termination_statuses=['stopped', 'deleting', 'failed'],
    )

    # [START howto_operator_dms_stop_task]
    stop_task = DmsStopTaskOperator(
        task_id='stop_task',
        replication_task_arn=create_task.output,
    )
    # [END howto_operator_dms_stop_task]

    # TaskCompletedSensor actually waits until task reaches the "Stopped" state, so it will work here.
    # [START howto_sensor_dms_task_completed]
    await_task_stop = DmsTaskCompletedSensor(
        task_id='await_task_stop',
        replication_task_arn=create_task.output,
    )
    # [END howto_sensor_dms_task_completed]

    # [START howto_operator_dms_delete_task]
    delete_task = DmsDeleteTaskOperator(
        task_id='delete_task',
        replication_task_arn=create_task.output,
        trigger_rule='all_done',
    )
    # [END howto_operator_dms_delete_task]

    # chain() wires each task downstream of the previous one.
    chain(
        set_up(),
        create_task,
        start_task,
        describe_tasks,
        await_task_start,
        stop_task,
        await_task_stop,
        delete_task,
        clean_up(),
    )
