Source code for tests.system.providers.amazon.aws.example_sagemaker_endpoint
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.importjsonfromdatetimeimportdatetimeimportboto3fromairflowimportDAGfromairflow.decoratorsimporttaskfromairflow.models.baseoperatorimportchainfromairflow.providers.amazon.aws.operators.s3import(S3CreateBucketOperator,S3CreateObjectOperator,S3DeleteBucketOperator,)fromairflow.providers.amazon.aws.operators.sagemakerimport(SageMakerDeleteModelOperator,SageMakerEndpointConfigOperator,SageMakerEndpointOperator,SageMakerModelOperator,SageMakerTrainingOperator,)fromairflow.providers.amazon.aws.sensors.sagemakerimportSageMakerEndpointSensorfromairflow.utils.trigger_ruleimportTriggerRulefromtests.system.providers.amazon.aws.utilsimportENV_ID_KEY,SystemTestContextBuilder,purge_logs
# The URI of a Docker image for handling KNN model training.
# To find the URI of a free Amazon-provided image that can be used, substitute your
# desired region in the following link and find the URI under "Registry Path".
# https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-east-1.html#knn-us-east-1.title
# This URI should be in the format of {12-digits}.dkr.ecr.{region}.amazonaws.com/knn
)
# For an example of how to obtain the following train and test data, please see
# https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
# NOTE(review): these statements reference `dag`, `test_context`, `set_up`,
# `TRAIN_DATA` and several task helpers that are defined outside this span;
# presumably they sit inside a `with DAG(...) as dag:` block -- confirm the
# enclosing indentation against the full file before merging.

# Build the per-run configuration consumed by every task below.
test_setup = set_up(
    env_id=test_context[ENV_ID_KEY],
    knn_image_uri=test_context[KNN_IMAGE_URI_KEY],
    role_arn=test_context[ROLE_ARN_KEY],
)

create_bucket = S3CreateBucketOperator(
    task_id='create_bucket',
    bucket_name=test_setup['bucket_name'],
)

# Store the training data as train.csv under the configured input-data key.
upload_data = S3CreateObjectOperator(
    task_id='upload_data',
    s3_bucket=test_setup['bucket_name'],
    s3_key=f'{test_setup["input_data_s3_key"]}/train.csv',
    data=TRAIN_DATA,
)

train_model = SageMakerTrainingOperator(
    task_id='train_model',
    config=test_setup['training_config'],
    do_xcom_push=False,
)

create_model = SageMakerModelOperator(
    task_id='create_model',
    config=test_setup['model_config'],
    do_xcom_push=False,
)

# [START howto_operator_sagemaker_endpoint_config]
configure_endpoint = SageMakerEndpointConfigOperator(
    task_id='configure_endpoint',
    config=test_setup['endpoint_config_config'],
    do_xcom_push=False,
)
# [END howto_operator_sagemaker_endpoint_config]

# [START howto_operator_sagemaker_endpoint]
deploy_endpoint = SageMakerEndpointOperator(
    task_id='deploy_endpoint',
    config=test_setup['deploy_endpoint_config'],
    # Waits by default, setting as False to demonstrate the Sensor below.
    wait_for_completion=False,
    do_xcom_push=False,
)
# [END howto_operator_sagemaker_endpoint]

# [START howto_sensor_sagemaker_endpoint]
await_endpoint = SageMakerEndpointSensor(
    task_id='await_endpoint',
    endpoint_name=test_setup['endpoint_name'],
)
# [END howto_sensor_sagemaker_endpoint]

# Teardown tasks use ALL_DONE so they are scheduled once upstream tasks have
# finished, whether or not those tasks succeeded.
delete_model = SageMakerDeleteModelOperator(
    task_id='delete_model',
    trigger_rule=TriggerRule.ALL_DONE,
    config={'ModelName': test_setup['model_name']},
)

delete_bucket = S3DeleteBucketOperator(
    task_id='delete_bucket',
    trigger_rule=TriggerRule.ALL_DONE,
    bucket_name=test_setup['bucket_name'],
    force_delete=True,
)

# Wire every task into one linear sequence: setup, test body, teardown.
chain(
    # TEST SETUP
    test_context,
    test_setup,
    create_bucket,
    upload_data,
    # TEST BODY
    train_model,
    create_model,
    configure_endpoint,
    deploy_endpoint,
    await_endpoint,
    call_endpoint(test_setup['endpoint_name']),
    # TEST TEARDOWN
    delete_endpoint_config(test_setup['endpoint_config_job_name']),
    delete_endpoint(test_setup['endpoint_name']),
    delete_model,
    delete_bucket,
    delete_logs(test_context[ENV_ID_KEY], test_setup['endpoint_name']),
)

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)