Source code for airflow.providers.amazon.aws.example_dags.example_sagemaker_endpoint

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
from datetime import datetime

import boto3

from airflow import DAG
from airflow.decorators import task
from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerDeleteModelOperator,
    SageMakerEndpointConfigOperator,
    SageMakerEndpointOperator,
    SageMakerModelOperator,
    SageMakerTrainingOperator,
)
from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor

# Project name will be used in naming the S3 buckets and various tasks.
# The dataset used in this example is identifying varieties of the Iris flower.
PROJECT_NAME = 'iris'
TIMESTAMP = '{{ ts_nodash }}'

S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket')
ROLE_ARN = os.getenv(
    'SAGEMAKER_ROLE_ARN',
    'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole',
)
INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data'
TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/training-results'

MODEL_NAME = f'{PROJECT_NAME}-KNN-model'
ENDPOINT_NAME = f'{PROJECT_NAME}-endpoint'

# Job names can't be reused, so appending a timestamp ensures it is unique.
ENDPOINT_CONFIG_JOB_NAME = f'{PROJECT_NAME}-endpoint-config-{TIMESTAMP}'
TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}'
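# Note: '{{ ts_nodash }}' is an Airflow Jinja macro, not a literal value; the SageMaker
# operators template their 'config' argument, so at run time the names above render to
# something like 'iris-train-20210101T000000', keeping each run's job names unique.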

# For an example of how to obtain the following train and test data, please see
# https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
TRAIN_DATA = '0,4.9,2.5,4.5,1.7\n1,7.0,3.2,4.7,1.4\n0,7.3,2.9,6.3,1.8\n2,5.1,3.5,1.4,0.2\n'
SAMPLE_TEST_DATA = '6.4,3.2,4.5,1.5'
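# Each TRAIN_DATA row is label-first CSV (the class, then the four Iris measurements),
# the layout SageMaker's built-in algorithms expect for 'text/csv' training input, while
# SAMPLE_TEST_DATA carries only the four feature values used at inference time.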

# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR
# repo cited in public SageMaker documentation, so the account number does not need to be redacted.
# For more info see:
# https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title
KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn'

# Define configs for the training, model creation, endpoint configuration, and endpoint deployment jobs
TRAINING_CONFIG = {
    'TrainingJobName': TRAINING_JOB_NAME,
    'RoleArn': ROLE_ARN,
    'AlgorithmSpecification': {
        "TrainingImage": KNN_IMAGE_URI,
        "TrainingInputMode": "File",
    },
    'HyperParameters': {
        'predictor_type': 'classifier',
        'feature_dim': '4',
        'k': '3',
        'sample_size': '6',
    },
    'InputDataConfig': [
        {
            'ChannelName': 'train',
            'CompressionType': 'None',
            'ContentType': 'text/csv',
            'DataSource': {
                'S3DataSource': {
                    'S3DataDistributionType': 'FullyReplicated',
                    'S3DataType': 'S3Prefix',
                    'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv',
                }
            },
        }
    ],
    'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'},
    'ResourceConfig': {
        'InstanceCount': 1,
        'InstanceType': 'ml.m5.large',
        'VolumeSizeInGB': 1,
    },
    'StoppingCondition': {'MaxRuntimeInSeconds': 6 * 60},
}
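
# The dict above follows the request syntax of the SageMaker CreateTrainingJob API; the
# SageMakerTrainingOperator hands it (after templating) to boto3's create_training_job,
# and the *_CONFIG dicts below map to CreateModel, CreateEndpointConfig, and
# CreateEndpoint in the same way.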

MODEL_CONFIG = {
    'ModelName': MODEL_NAME,
    'ExecutionRoleArn': ROLE_ARN,
    'PrimaryContainer': {
        'Mode': 'SingleModel',
        'Image': KNN_IMAGE_URI,
        'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz',
    },
}

ENDPOINT_CONFIG_CONFIG = {
    'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME,
    'ProductionVariants': [
        {
            'VariantName': f'{PROJECT_NAME}-demo',
            'ModelName': MODEL_NAME,
            'InstanceType': 'ml.t2.medium',
            'InitialInstanceCount': 1,
        },
    ],
}

DEPLOY_ENDPOINT_CONFIG = {
    'EndpointName': ENDPOINT_NAME,
    'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME,
}


@task
def call_endpoint():
    runtime = boto3.Session().client('sagemaker-runtime')

    response = runtime.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        ContentType='text/csv',
        Body=SAMPLE_TEST_DATA,
    )

    return json.loads(response["Body"].read().decode())['predictions']
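
# For a KNN classifier endpoint, the JSON body returned by invoke_endpoint is shaped
# roughly like {"predictions": [{"predicted_label": ...}]} (per the public SageMaker KNN
# docs, not something this example asserts), so call_endpoint() returns that list.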


@task(trigger_rule='all_done')
def cleanup():
    # Delete Endpoint and Endpoint Config
    client = boto3.client('sagemaker')
    # The config created by this DAG run is assumed to be the first one returned
    # by list_endpoint_configs().
    endpoint_config_name = client.list_endpoint_configs()['EndpointConfigs'][0]['EndpointConfigName']
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    client.delete_endpoint(EndpointName=ENDPOINT_NAME)

    # Delete S3 Artifacts
    client = boto3.client('s3')
    object_keys = [
        key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents']
    ]
    for key in object_keys:
        client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]})


with DAG(
    dag_id='example_sagemaker_endpoint',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    tags=['example'],
    catchup=False,
) as dag:
    upload_data = S3CreateObjectOperator(
        task_id='upload_data',
        s3_bucket=S3_BUCKET,
        s3_key=f'{INPUT_DATA_S3_KEY}/train.csv',
        data=TRAIN_DATA,
        replace=True,
    )

    train_model = SageMakerTrainingOperator(
        task_id='train_model',
        config=TRAINING_CONFIG,
        do_xcom_push=False,
    )

    create_model = SageMakerModelOperator(
        task_id='create_model',
        config=MODEL_CONFIG,
        do_xcom_push=False,
    )

    # [START howto_operator_sagemaker_endpoint_config]
    configure_endpoint = SageMakerEndpointConfigOperator(
        task_id='configure_endpoint',
        config=ENDPOINT_CONFIG_CONFIG,
        do_xcom_push=False,
    )
    # [END howto_operator_sagemaker_endpoint_config]

    # [START howto_operator_sagemaker_endpoint]
    deploy_endpoint = SageMakerEndpointOperator(
        task_id='deploy_endpoint',
        config=DEPLOY_ENDPOINT_CONFIG,
        # Waits by default, setting as False to demonstrate the Sensor below.
        wait_for_completion=False,
        do_xcom_push=False,
    )
    # [END howto_operator_sagemaker_endpoint]

    # [START howto_sensor_sagemaker_endpoint]
    await_endpoint = SageMakerEndpointSensor(
        task_id='await_endpoint',
        endpoint_name=ENDPOINT_NAME,
        do_xcom_push=False,
    )
    # [END howto_sensor_sagemaker_endpoint]

    # Trigger rule set to "all_done" so clean up will run regardless of success on other tasks.
    delete_model = SageMakerDeleteModelOperator(
        task_id='delete_model',
        config={'ModelName': MODEL_NAME},
        trigger_rule='all_done',
    )

    (
        upload_data
        >> train_model
        >> create_model
        >> configure_endpoint
        >> deploy_endpoint
        >> await_endpoint
        >> call_endpoint()
        >> cleanup()
        >> delete_model
    )
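
# A quick way to exercise this DAG end to end (a suggestion, not part of the example
# itself) is the Airflow CLI, assuming S3_BUCKET and SAGEMAKER_ROLE_ARN point at real,
# writable AWS resources and the default AWS connection has credentials:
#
#   airflow dags test example_sagemaker_endpoint 2021-01-01
#
# "dags test" runs the whole graph for a single logical date, which also exercises the
# cleanup() and delete_model teardown tasks.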
