Source code for airflow.providers.amazon.aws.example_dags.example_dms

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Note:  DMS requires you to configure specific IAM roles/permissions.  For more information, see
https://docs.aws.amazon.com/dms/latest/userguide/CHAP_Security.html#CHAP_Security.APIRole
"""

import json
import os
from datetime import datetime

import boto3
from sqlalchemy import Column, MetaData, String, Table, create_engine

from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.operators.python import get_current_context
from airflow.providers.amazon.aws.operators.dms import (
    DmsCreateTaskOperator,
    DmsDeleteTaskOperator,
    DmsDescribeTasksOperator,
    DmsStartTaskOperator,
    DmsStopTaskOperator,
)
from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor

# Target S3 bucket and the IAM role used by the S3 target endpoint; both can be
# overridden through environment variables of the same name.
S3_BUCKET = os.getenv('S3_BUCKET', 's3_bucket_name')
ROLE_ARN = os.getenv('ROLE_ARN', 'arn:aws:iam::1234567890:role/s3_target_endpoint_role')

# The project name will be used as a prefix for various entity names.
# Use either PascalCase or camelCase.  While some names require kebab-case
# and others require snake_case, they all accept mixedCase strings.
PROJECT_NAME = 'DmsDemo'

# Config values for setting up the "Source" database.
RDS_ENGINE = 'postgres'
RDS_PROTOCOL = 'postgresql'
RDS_USERNAME = 'username'
# NEVER store your production password in plaintext in a DAG like this.
# Use Airflow Secrets or a secret manager for this in production.
RDS_PASSWORD = 'rds_password'

# Config values for RDS.
RDS_INSTANCE_NAME = f'{PROJECT_NAME}-instance'
RDS_DB_NAME = f'{PROJECT_NAME}_source_database'

# Config values for DMS.
DMS_REPLICATION_INSTANCE_NAME = f'{PROJECT_NAME}-replication-instance'
DMS_REPLICATION_TASK_ID = f'{PROJECT_NAME}-replication-task'
SOURCE_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-source-endpoint'
TARGET_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-target-endpoint'

# Sample data.
TABLE_NAME = f'{PROJECT_NAME}-table'
TABLE_HEADERS = ['apache_project', 'release_year']
SAMPLE_DATA = [
    ('Airflow', '2015'),
    ('OpenOffice', '2012'),
    ('Subversion', '2000'),
    ('NiFi', '2006'),
]

# Description of the sample table; it is serialized to JSON and passed as the
# 'ExternalTableDefinition' of the S3 target endpoint.
TABLE_DEFINITION = {
    'TableCount': '1',
    'Tables': [
        {
            'TableName': TABLE_NAME,
            'TableColumns': [
                {
                    'ColumnName': TABLE_HEADERS[0],
                    'ColumnType': 'STRING',
                    'ColumnNullable': 'false',
                    'ColumnIsPk': 'true',
                },
                {'ColumnName': TABLE_HEADERS[1], 'ColumnType': 'STRING', 'ColumnLength': '4'},
            ],
            'TableColumnsTotal': '2',
        }
    ],
}
# Table-mapping rules handed to the replication task: a single selection rule
# that includes TABLE_NAME from the 'public' schema.
TABLE_MAPPINGS = {
    'rules': [
        {
            'rule-type': 'selection',
            'rule-id': '1',
            'rule-name': '1',
            'object-locator': {
                'schema-name': 'public',
                'table-name': TABLE_NAME,
            },
            'rule-action': 'include',
        }
    ]
}


def _create_rds_instance():
    """
    Create the source RDS instance and block until it is available.

    :return: the ``Endpoint`` entry of the ``describe_db_instances`` response
        for the new instance (callers read its ``Address`` and ``Port`` keys).
    """
    print('Creating RDS Instance.')

    rds_client = boto3.client('rds')
    rds_client.create_db_instance(
        DBName=RDS_DB_NAME,
        DBInstanceIdentifier=RDS_INSTANCE_NAME,
        AllocatedStorage=20,
        DBInstanceClass='db.t3.micro',
        Engine=RDS_ENGINE,
        MasterUsername=RDS_USERNAME,
        MasterUserPassword=RDS_PASSWORD,
    )

    rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)

    response = rds_client.describe_db_instances(DBInstanceIdentifier=RDS_INSTANCE_NAME)
    return response['DBInstances'][0]['Endpoint']


def _create_rds_table(rds_endpoint):
    """
    Create the sample table in the source database and load SAMPLE_DATA into it.

    :param rds_endpoint: mapping with the ``Address`` and ``Port`` of the RDS instance.
    """
    print('Creating table.')

    hostname = rds_endpoint['Address']
    port = rds_endpoint['Port']
    rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{RDS_DB_NAME}'
    engine = create_engine(rds_url)

    # The MetaData is deliberately left unbound: binding an engine to MetaData and
    # relying on implicit execution is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    table = Table(
        TABLE_NAME,
        MetaData(),
        Column(TABLE_HEADERS[0], String, primary_key=True),
        Column(TABLE_HEADERS[1], String),
    )

    # engine.begin() commits when the block exits without an error, so the rows are
    # persisted without depending on the library's autocommit behaviour.
    with engine.begin() as connection:
        # Create the Table.
        table.create(connection)
        load_data = table.insert().values(SAMPLE_DATA)
        connection.execute(load_data)

        # Read the data back to verify everything is working.
        connection.execute(table.select())


def _create_dms_replication_instance(ti, dms_client):
    """
    Create the DMS replication instance and publish its ARN through XCom.

    :param ti: task instance used to push the ``replication_instance_arn`` XCom.
    :param dms_client: boto3 DMS client.
    :return: ARN of the new replication instance.
    """
    print('Creating replication instance.')
    instance_arn = dms_client.create_replication_instance(
        ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME,
        ReplicationInstanceClass='dms.t3.micro',
    )['ReplicationInstance']['ReplicationInstanceArn']

    ti.xcom_push(key='replication_instance_arn', value=instance_arn)
    return instance_arn


def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
    """
    Create the DMS source (RDS) and target (S3) endpoints and publish their ARNs through XCom.

    :param ti: task instance used to push ``source_endpoint_arn`` and ``target_endpoint_arn``.
    :param dms_client: boto3 DMS client.
    :param rds_instance_endpoint: mapping with the ``Address`` and ``Port`` of the RDS instance.
    """
    print('Creating DMS source endpoint.')
    source_endpoint_arn = dms_client.create_endpoint(
        EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER,
        EndpointType='source',
        EngineName=RDS_ENGINE,
        Username=RDS_USERNAME,
        Password=RDS_PASSWORD,
        ServerName=rds_instance_endpoint['Address'],
        Port=rds_instance_endpoint['Port'],
        DatabaseName=RDS_DB_NAME,
    )['Endpoint']['EndpointArn']

    print('Creating DMS target endpoint.')
    target_endpoint_arn = dms_client.create_endpoint(
        EndpointIdentifier=TARGET_ENDPOINT_IDENTIFIER,
        EndpointType='target',
        EngineName='s3',
        S3Settings={
            'BucketName': S3_BUCKET,
            'BucketFolder': PROJECT_NAME,
            'ServiceAccessRoleArn': ROLE_ARN,
            'ExternalTableDefinition': json.dumps(TABLE_DEFINITION),
        },
    )['Endpoint']['EndpointArn']

    ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn)
    ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn)


def _await_setup_assets(dms_client, instance_arn):
    """
    Block until the replication instance identified by ``instance_arn`` is available.

    :param dms_client: boto3 DMS client.
    :param instance_arn: ARN of the replication instance to wait for.
    """
    print("Awaiting asset provisioning.")
    dms_client.get_waiter('replication_instance_available').wait(
        Filters=[{'Name': 'replication-instance-arn', 'Values': [instance_arn]}]
    )


def _delete_rds_instance():
    """Delete the source RDS instance (without a final snapshot) and wait for the deletion."""
    print('Deleting RDS Instance.')

    rds_client = boto3.client('rds')
    rds_client.delete_db_instance(
        DBInstanceIdentifier=RDS_INSTANCE_NAME,
        SkipFinalSnapshot=True,
    )

    rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)


def _delete_dms_assets(dms_client):
    """
    Delete the replication instance and both endpoints whose ARNs were pushed to XCom.

    An ARN that was never pushed (for example because set-up failed before creating
    that asset) is skipped rather than being sent to the API as ``None``.

    :param dms_client: boto3 DMS client.
    """
    ti = get_current_context()['ti']
    replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
    source_arn = ti.xcom_pull(key='source_endpoint_arn')
    target_arn = ti.xcom_pull(key='target_endpoint_arn')

    print('Deleting DMS assets.')
    if replication_instance_arn:
        dms_client.delete_replication_instance(ReplicationInstanceArn=replication_instance_arn)
    else:
        print('No replication instance ARN found in XCom; skipping its deletion.')

    for xcom_key, endpoint_arn in (
        ('source_endpoint_arn', source_arn),
        ('target_endpoint_arn', target_arn),
    ):
        if endpoint_arn:
            dms_client.delete_endpoint(EndpointArn=endpoint_arn)
        else:
            print(f'No value found in XCom for {xcom_key}; skipping its deletion.')


def _await_all_teardowns(dms_client):
    """
    Block until the replication instance and both endpoints are reported as deleted.

    :param dms_client: boto3 DMS client.
    """
    print('Awaiting tear-down.')
    dms_client.get_waiter('replication_instance_deleted').wait(
        Filters=[{'Name': 'replication-instance-id', 'Values': [DMS_REPLICATION_INSTANCE_NAME]}]
    )

    dms_client.get_waiter('endpoint_deleted').wait(
        Filters=[
            {
                'Name': 'endpoint-id',
                'Values': [SOURCE_ENDPOINT_IDENTIFIER, TARGET_ENDPOINT_IDENTIFIER],
            }
        ]
    )


@task
def set_up():
    """
    Provision everything the example needs before the DMS task can be created.

    Creates the RDS instance and sample table, then the DMS replication instance
    and endpoints (their ARNs are pushed to XCom), and finally waits for the
    replication instance to become available.
    """
    ti = get_current_context()['ti']
    dms_client = boto3.client('dms')
    rds_instance_endpoint = _create_rds_instance()

    _create_rds_table(rds_instance_endpoint)
    instance_arn = _create_dms_replication_instance(ti, dms_client)
    _create_dms_endpoints(ti, dms_client, rds_instance_endpoint)
    _await_setup_assets(dms_client, instance_arn)
@task(trigger_rule='all_done')
def clean_up():
    """
    Tear down the assets created for the example.

    Deletes the RDS instance first, then the DMS replication instance and
    endpoints, and finally waits for the DMS deletions to complete.  The task
    is declared with ``trigger_rule='all_done'``.
    """
    dms_client = boto3.client('dms')

    _delete_rds_instance()
    _delete_dms_assets(dms_client)
    _await_all_teardowns(dms_client)
with DAG(
    dag_id='example_dms',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    tags=['example'],
    catchup=False,
) as dag:

    # The ARNs consumed below are read from XCom with the keys
    # "source_endpoint_arn", "target_endpoint_arn" and "replication_instance_arn".
    # [START howto_operator_dms_create_task]
    create_task = DmsCreateTaskOperator(
        task_id='create_task',
        replication_task_id=DMS_REPLICATION_TASK_ID,
        source_endpoint_arn='{{ ti.xcom_pull(key="source_endpoint_arn") }}',
        target_endpoint_arn='{{ ti.xcom_pull(key="target_endpoint_arn") }}',
        replication_instance_arn='{{ ti.xcom_pull(key="replication_instance_arn") }}',
        table_mappings=TABLE_MAPPINGS,
    )
    # [END howto_operator_dms_create_task]

    # [START howto_operator_dms_start_task]
    start_task = DmsStartTaskOperator(
        task_id='start_task',
        replication_task_arn=create_task.output,
    )
    # [END howto_operator_dms_start_task]

    # [START howto_operator_dms_describe_tasks]
    describe_tasks = DmsDescribeTasksOperator(
        task_id='describe_tasks',
        describe_tasks_kwargs={
            'Filters': [
                {
                    'Name': 'replication-instance-arn',
                    'Values': ['{{ ti.xcom_pull(key="replication_instance_arn") }}'],
                }
            ]
        },
        do_xcom_push=False,
    )
    # [END howto_operator_dms_describe_tasks]

    await_task_start = DmsTaskBaseSensor(
        task_id='await_task_start',
        replication_task_arn=create_task.output,
        target_statuses=['running'],
        termination_statuses=['stopped', 'deleting', 'failed'],
    )

    # [START howto_operator_dms_stop_task]
    stop_task = DmsStopTaskOperator(
        task_id='stop_task',
        replication_task_arn=create_task.output,
    )
    # [END howto_operator_dms_stop_task]

    # TaskCompletedSensor actually waits until task reaches the "Stopped" state, so it will work here.
    # [START howto_sensor_dms_task_completed]
    await_task_stop = DmsTaskCompletedSensor(
        task_id='await_task_stop',
        replication_task_arn=create_task.output,
    )
    # [END howto_sensor_dms_task_completed]

    # [START howto_operator_dms_delete_task]
    delete_task = DmsDeleteTaskOperator(
        task_id='delete_task',
        replication_task_arn=create_task.output,
        trigger_rule='all_done',
    )
    # [END howto_operator_dms_delete_task]

    # NOTE(review): the tasks are linked with `>>`, so chain() is called with one
    # argument only (the value of the whole `>>` expression).  Presumably chain()
    # then has nothing left to link; consider passing the tasks as separate
    # arguments instead -- confirm against the targeted Airflow version.
    chain(
        set_up()
        >> create_task
        >> start_task
        >> describe_tasks
        >> await_task_start
        >> stop_task
        >> await_task_stop
        >> delete_task
        >> clean_up()
    )

Was this entry helpful?