# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
from datetime import datetime
import boto3
from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.s3 import (
S3CreateBucketOperator,
S3CreateObjectOperator,
S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerDeleteModelOperator,
SageMakerEndpointConfigOperator,
SageMakerEndpointOperator,
SageMakerModelOperator,
SageMakerTrainingOperator,
)
from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, purge_logs
# Identifier passed as ``dag_id`` to the DAG defined in this module.
DAG_ID = "example_sagemaker_endpoint"
# Externally fetched variables:
ROLE_ARN_KEY = "ROLE_ARN"

# Callable invoked inside the DAG to produce the test context; the result is
# later indexed with ENV_ID_KEY and ROLE_ARN_KEY.
sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
# The URI of a Docker image for handling KNN model training.
# To find the URI of a free Amazon-provided image that can be used, substitute your
# desired region in the following link and find the URI under "Registry Path".
# https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-east-1.html#knn-us-east-1.title
# This URI should be in the format of {12-digits}.dkr.ecr.{region}.amazonaws.com/knn
#
# Maps an AWS region name to the KNN image URI used for both training and serving.
KNN_IMAGES_BY_REGION = {
    "us-east-1": "382416733822.dkr.ecr.us-east-1.amazonaws.com/knn:1",
    "us-west-2": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1",
}
# For an example of how to obtain the following train and test data, please see
# https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
#
# Newline-terminated CSV rows uploaded to S3 as the training input.
# NOTE(review): the first column is presumably the class label followed by four
# features (the training config uses feature_dim "4") -- confirm against the linked example.
TRAIN_DATA = "0,4.9,2.5,4.5,1.7\n1,7.0,3.2,4.7,1.4\n0,7.3,2.9,6.3,1.8\n2,5.1,3.5,1.4,0.2\n"

# Single CSV record (four values) sent as the request body when invoking the endpoint.
SAMPLE_TEST_DATA = "6.4,3.2,4.5,1.5"
@task
def call_endpoint(endpoint_name):
    """
    Invoke a SageMaker endpoint with ``SAMPLE_TEST_DATA`` and return its predictions.

    :param endpoint_name: Name of the endpoint to invoke.
    :return: The ``predictions`` entry of the JSON-decoded response body.
    """
    response = (
        boto3.Session()
        .client("sagemaker-runtime")
        .invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="text/csv",
            Body=SAMPLE_TEST_DATA,
        )
    )

    # The response body is a stream; read it fully and decode before parsing the JSON.
    return json.loads(response["Body"].read().decode())["predictions"]
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_endpoint_config(endpoint_config_job_name):
    """
    Delete the SageMaker endpoint configuration with the given name.

    Declared with ``TriggerRule.ALL_DONE`` so it is not gated on upstream success.

    :param endpoint_config_job_name: Name of the endpoint configuration to delete.
    """
    boto3.client("sagemaker").delete_endpoint_config(EndpointConfigName=endpoint_config_job_name)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_endpoint(endpoint_name):
    """
    Delete the SageMaker endpoint with the given name.

    Declared with ``TriggerRule.ALL_DONE`` so it is not gated on upstream success.

    :param endpoint_name: Name of the endpoint to delete.
    """
    boto3.client("sagemaker").delete_endpoint(EndpointName=endpoint_name)
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_logs(env_id, endpoint_name):
    """
    Purge the log entries associated with this test run.

    Two ``purge_logs`` calls are made: one for the training-jobs log group, using
    ``env_id`` as the log stream prefix, and one for the endpoint's own log group,
    with no stream prefix and ``force_delete=True``.

    :param env_id: Environment id, used as the training log stream prefix.
    :param endpoint_name: Endpoint name, used to build the endpoint log group name.
    """
    purge_logs(
        [
            # Format: ('log group name', 'log stream prefix')
            ("/aws/sagemaker/TrainingJobs", env_id),
        ]
    )
    purge_logs(test_logs=[(f"/aws/sagemaker/Endpoints/{endpoint_name}", None)], force_delete=True)
@task
def set_up(env_id, role_arn, ti=None):
    """
    Derive all resource names and SageMaker configs for the test and push them to XCom.

    Every name is prefixed with ``env_id``. The KNN image URI is looked up in
    ``KNN_IMAGES_BY_REGION`` using the region of the default boto3 session.

    :param env_id: Environment id used as a prefix for all resource names and S3 keys.
    :param role_arn: ARN used as the training ``RoleArn`` and the model ``ExecutionRoleArn``.
    :param ti: Object whose ``xcom_push`` is used to publish the values; defaults to
        ``None`` and is presumably injected by Airflow at run time -- confirm.
    :raises KeyError: If the session's region has no entry in ``KNN_IMAGES_BY_REGION``.
    """
    bucket_name = f"{env_id}-sagemaker"
    input_data_s3_key = f"{env_id}/input-data"
    training_output_s3_key = f"{env_id}/results"

    endpoint_config_job_name = f"{env_id}-endpoint-config"
    endpoint_name = f"{env_id}-endpoint"
    model_name = f"{env_id}-KNN-model"
    training_job_name = f"{env_id}-train"

    region = boto3.session.Session().region_name
    try:
        knn_image_uri = KNN_IMAGES_BY_REGION[region]
    except KeyError:
        # The replacement message already names the region, so the bare lookup
        # error is suppressed rather than shown as an accidental second failure.
        raise KeyError(
            f"Region name {region} does not have a known KNN "
            f"Image URI.  Please add the region and URI following "
            f"the directions at the top of the system testfile "
        ) from None

    training_config = {
        "TrainingJobName": training_job_name,
        "RoleArn": role_arn,
        "AlgorithmSpecification": {
            "TrainingImage": knn_image_uri,
            "TrainingInputMode": "File",
        },
        "HyperParameters": {
            "predictor_type": "classifier",
            "feature_dim": "4",
            "k": "3",
            # NOTE(review): evaluates to the newline count minus one (3 for the
            # 4-row TRAIN_DATA) -- confirm the "- 1" is intended.
            "sample_size": str(TRAIN_DATA.count("\n") - 1),
        },
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "CompressionType": "None",
                "ContentType": "text/csv",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataDistributionType": "FullyReplicated",
                        "S3DataType": "S3Prefix",
                        "S3Uri": f"s3://{bucket_name}/{input_data_s3_key}/train.csv",
                    }
                },
            }
        ],
        "OutputDataConfig": {"S3OutputPath": f"s3://{bucket_name}/{training_output_s3_key}/"},
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.m5.large",
            "VolumeSizeInGB": 1,
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 6 * 60},
    }

    model_config = {
        "ModelName": model_name,
        "ExecutionRoleArn": role_arn,
        "PrimaryContainer": {
            "Mode": "SingleModel",
            "Image": knn_image_uri,
            # Points at the artifact path under the training job's output prefix.
            "ModelDataUrl": f"s3://{bucket_name}/{training_output_s3_key}/{training_job_name}/output/model.tar.gz",  # noqa: E501
        },
    }

    endpoint_config_config = {
        "EndpointConfigName": endpoint_config_job_name,
        "ProductionVariants": [
            {
                "VariantName": f"{env_id}-demo",
                "ModelName": model_name,
                "InstanceType": "ml.t2.medium",
                "InitialInstanceCount": 1,
            },
        ],
    }

    deploy_endpoint_config = {
        "EndpointName": endpoint_name,
        "EndpointConfigName": endpoint_config_job_name,
    }

    # Publish each value under its own key; the DAG reads them via test_setup["<key>"].
    ti.xcom_push(key="bucket_name", value=bucket_name)
    ti.xcom_push(key="input_data_s3_key", value=input_data_s3_key)
    ti.xcom_push(key="model_name", value=model_name)
    ti.xcom_push(key="endpoint_name", value=endpoint_name)
    ti.xcom_push(key="endpoint_config_job_name", value=endpoint_config_job_name)
    ti.xcom_push(key="training_config", value=training_config)
    ti.xcom_push(key="model_config", value=model_config)
    ti.xcom_push(key="endpoint_config_config", value=endpoint_config_config)
    ti.xcom_push(key="deploy_endpoint_config", value=deploy_endpoint_config)
# DAG wiring: set-up tasks, the SageMaker train/model/endpoint body, then teardown
# tasks declared with TriggerRule.ALL_DONE.
with DAG(
    dag_id=DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    tags=["example"],
    catchup=False,
) as dag:
    test_context = sys_test_context_task()

    # set_up pushes its values to XCom; they are read below as test_setup["<key>"].
    test_setup = set_up(
        env_id=test_context[ENV_ID_KEY],
        role_arn=test_context[ROLE_ARN_KEY],
    )

    create_bucket = S3CreateBucketOperator(
        task_id="create_bucket",
        bucket_name=test_setup["bucket_name"],
    )

    upload_data = S3CreateObjectOperator(
        task_id="upload_data",
        s3_bucket=test_setup["bucket_name"],
        s3_key=f'{test_setup["input_data_s3_key"]}/train.csv',
        data=TRAIN_DATA,
    )

    train_model = SageMakerTrainingOperator(
        task_id="train_model",
        config=test_setup["training_config"],
    )

    create_model = SageMakerModelOperator(
        task_id="create_model",
        config=test_setup["model_config"],
    )

    # [START howto_operator_sagemaker_endpoint_config]
    configure_endpoint = SageMakerEndpointConfigOperator(
        task_id="configure_endpoint",
        config=test_setup["endpoint_config_config"],
    )
    # [END howto_operator_sagemaker_endpoint_config]

    # [START howto_operator_sagemaker_endpoint]
    deploy_endpoint = SageMakerEndpointOperator(
        task_id="deploy_endpoint",
        config=test_setup["deploy_endpoint_config"],
    )
    # [END howto_operator_sagemaker_endpoint]

    # SageMakerEndpointOperator waits by default, setting as False to test the Sensor below.
    deploy_endpoint.wait_for_completion = False

    # [START howto_sensor_sagemaker_endpoint]
    await_endpoint = SageMakerEndpointSensor(
        task_id="await_endpoint",
        endpoint_name=test_setup["endpoint_name"],
    )
    # [END howto_sensor_sagemaker_endpoint]

    delete_model = SageMakerDeleteModelOperator(
        task_id="delete_model",
        trigger_rule=TriggerRule.ALL_DONE,
        config={"ModelName": test_setup["model_name"]},
    )

    delete_bucket = S3DeleteBucketOperator(
        task_id="delete_bucket",
        trigger_rule=TriggerRule.ALL_DONE,
        bucket_name=test_setup["bucket_name"],
        force_delete=True,
    )

    chain(
        # TEST SETUP
        test_context,
        test_setup,
        create_bucket,
        upload_data,
        # TEST BODY
        train_model,
        create_model,
        configure_endpoint,
        deploy_endpoint,
        await_endpoint,
        call_endpoint(test_setup["endpoint_name"]),
        # TEST TEARDOWN
        delete_endpoint_config(test_setup["endpoint_config_job_name"]),
        delete_endpoint(test_setup["endpoint_name"]),
        delete_model,
        delete_bucket,
        delete_logs(test_context[ENV_ID_KEY], test_setup["endpoint_name"]),
    )

    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()
from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)