#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for Google ML Engine service.
"""
import os
from datetime import datetime
from typing import Any , Dict
from airflow import models
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.mlengine import (
MLEngineCreateModelOperator ,
MLEngineCreateVersionOperator ,
MLEngineDeleteModelOperator ,
MLEngineDeleteVersionOperator ,
MLEngineGetModelOperator ,
MLEngineListVersionsOperator ,
MLEngineSetDefaultVersionOperator ,
MLEngineStartBatchPredictionJobOperator ,
MLEngineStartTrainingJobOperator ,
)
from airflow.providers.google.cloud.utils import mlengine_operator_utils
# Configuration for the example DAG. Each value is read from the environment
# variable named in its ``os.environ.get`` call, falling back to a placeholder
# default (the "INVALID BUCKET NAME" paths are placeholders to be overridden).
PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
MODEL_NAME = os.environ.get("GCP_MLENGINE_MODEL_NAME", "model_name")
SAVED_MODEL_PATH = os.environ.get("GCP_MLENGINE_SAVED_MODEL_PATH", "gs://INVALID BUCKET NAME/saved-model/")
JOB_DIR = os.environ.get("GCP_MLENGINE_JOB_DIR", "gs://INVALID BUCKET NAME/keras-job-dir")
# Input data for batch prediction; used by the ``prediction`` task and by
# ``create_evaluate_ops`` in the DAG below, so it must be defined here.
PREDICTION_INPUT = os.environ.get(
    "GCP_MLENGINE_PREDICTION_INPUT", "gs://INVALID BUCKET NAME/prediction_input.json"
)
PREDICTION_OUTPUT = os.environ.get(
    "GCP_MLENGINE_PREDICTION_OUTPUT", "gs://INVALID BUCKET NAME/prediction_output"
)
TRAINER_URI = os.environ.get("GCP_MLENGINE_TRAINER_URI", "gs://INVALID BUCKET NAME/trainer.tar.gz")
TRAINER_PY_MODULE = os.environ.get("GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task")
# Dataflow temp/staging locations used by the evaluation summary job.
SUMMARY_TMP = os.environ.get("GCP_MLENGINE_DATAFLOW_TMP", "gs://INVALID BUCKET NAME/tmp/")
SUMMARY_STAGING = os.environ.get("GCP_MLENGINE_DATAFLOW_STAGING", "gs://INVALID BUCKET NAME/staging/")
with models.DAG(
    "example_gcp_mlengine",
    schedule_interval='@once',  # Override to match your needs
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=['example'],
    params={"model_name": MODEL_NAME},
) as dag:
    # Hyperparameter-tuning spec passed as ``hyperparameters`` to the training
    # job below. Built as a single literal; the contents are identical to the
    # previous incremental ``append`` construction.
    hyperparams: Dict[str, Any] = {
        'goal': 'MAXIMIZE',
        'hyperparameterMetricTag': 'metric1',
        'maxTrials': 30,
        'maxParallelTrials': 1,
        'enableTrialEarlyStopping': True,
        'params': [
            {
                'parameterName': 'hidden1',
                'type': 'INTEGER',
                'minValue': 40,
                'maxValue': 400,
                'scaleType': 'UNIT_LINEAR_SCALE',
            },
            {'parameterName': 'numRnnCells', 'type': 'DISCRETE', 'discreteValues': [1, 2, 3, 4]},
            {
                'parameterName': 'rnnCellType',
                'type': 'CATEGORICAL',
                'categoricalValues': [
                    'BasicLSTMCell',
                    'BasicRNNCell',
                    'GRUCell',
                    'LSTMCell',
                    'LayerNormBasicLSTMCell',
                ],
            },
        ],
    }

    # [START howto_operator_gcp_mlengine_training]
    training = MLEngineStartTrainingJobOperator(
        task_id="training",
        project_id=PROJECT_ID,
        region="us-central1",
        job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
        runtime_version="1.15",
        python_version="3.7",
        job_dir=JOB_DIR,
        package_uris=[TRAINER_URI],
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
        labels={"job_type": "training"},
        hyperparameters=hyperparams,
    )
    # [END howto_operator_gcp_mlengine_training]

    # [START howto_operator_gcp_mlengine_create_model]
    create_model = MLEngineCreateModelOperator(
        task_id="create-model",
        project_id=PROJECT_ID,
        model={
            "name": MODEL_NAME,
        },
    )
    # [END howto_operator_gcp_mlengine_create_model]

    # [START howto_operator_gcp_mlengine_get_model]
    get_model = MLEngineGetModelOperator(
        task_id="get-model",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_get_model]

    # [START howto_operator_gcp_mlengine_print_model]
    get_model_result = BashOperator(
        bash_command=f"echo {get_model.output}",
        task_id="get-model-result",
    )
    # [END howto_operator_gcp_mlengine_print_model]

    # [START howto_operator_gcp_mlengine_create_version1]
    create_version = MLEngineCreateVersionOperator(
        task_id="create-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v1",
            "description": "First-version",
            "deployment_uri": f'{JOB_DIR}/keras_export/',
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version1]

    # [START howto_operator_gcp_mlengine_create_version2]
    create_version_2 = MLEngineCreateVersionOperator(
        task_id="create-version-2",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v2",
            "description": "Second version",
            "deployment_uri": SAVED_MODEL_PATH,
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version2]

    # [START howto_operator_gcp_mlengine_default_version]
    set_defaults_version = MLEngineSetDefaultVersionOperator(
        task_id="set-default-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="v2",
    )
    # [END howto_operator_gcp_mlengine_default_version]

    # [START howto_operator_gcp_mlengine_list_versions]
    list_version = MLEngineListVersionsOperator(
        task_id="list-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_list_versions]

    # [START howto_operator_gcp_mlengine_print_versions]
    list_version_result = BashOperator(
        bash_command=f"echo {list_version.output}",
        task_id="list-version-result",
    )
    # [END howto_operator_gcp_mlengine_print_versions]

    # [START howto_operator_gcp_mlengine_get_prediction]
    prediction = MLEngineStartBatchPredictionJobOperator(
        task_id="prediction",
        project_id=PROJECT_ID,
        job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}",
        region="us-central1",
        model_name=MODEL_NAME,
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        output_path=PREDICTION_OUTPUT,
        labels={"job_type": "prediction"},
    )
    # [END howto_operator_gcp_mlengine_get_prediction]

    # [START howto_operator_gcp_mlengine_delete_version]
    delete_version = MLEngineDeleteVersionOperator(
        task_id="delete-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v1"
    )
    # [END howto_operator_gcp_mlengine_delete_version]

    # [START howto_operator_gcp_mlengine_delete_model]
    delete_model = MLEngineDeleteModelOperator(
        task_id="delete-model", project_id=PROJECT_ID, model_name=MODEL_NAME, delete_contents=True
    )
    # [END howto_operator_gcp_mlengine_delete_model]

    # Task ordering. Each edge is declared once: re-declaring an existing edge
    # (e.g. ``create_model >> get_model >> delete_model`` after the fan-out
    # below) only makes Airflow log a "dependency already registered" warning.
    training >> create_version
    training >> create_version_2
    create_model >> get_model >> [get_model_result, delete_model]
    create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version
    create_version >> prediction
    create_version_2 >> prediction
    prediction >> delete_version
    list_version >> list_version_result
    list_version >> delete_version
    delete_version >> delete_model

    # [START howto_operator_gcp_mlengine_get_metric]
    def get_metric_fn_and_keys():
        """
        Gets metric function and keys used to generate summary

        :return: a ``(metric_fn, keys)`` pair. ``metric_fn`` maps one prediction
            instance to a tuple of metric values whose order matches ``keys``.
        """

        def normalize_value(inst: Dict):
            # NOTE(review): 'dense_4' is presumably the output name of the
            # exported model's final layer — confirm against the saved model.
            val = float(inst['dense_4'][0])
            return (val,)  # one-element tuple, positionally matching ['val']

        return normalize_value, ['val']  # key order must match.

    # [END howto_operator_gcp_mlengine_get_metric]

    # [START howto_operator_gcp_mlengine_validate_error]
    def validate_err_and_count(summary: Dict) -> Dict:
        """
        Validate summary result

        :param summary: the evaluation summary; must contain ``'val'`` and
            ``'count'`` entries.
        :return: ``summary`` unchanged when every check passes.
        :raises ValueError: if ``val`` is outside ``[0, 1]`` or ``count`` is not 20.
        """
        if summary['val'] > 1:
            raise ValueError(f'Too high val>1; summary={summary}')
        if summary['val'] < 0:
            raise ValueError(f'Too low val<0; summary={summary}')
        if summary['count'] != 20:
            # The check is on 'count', so the message names 'count'.
            raise ValueError(f'Invalid value count != 20; summary={summary}')
        return summary

    # [END howto_operator_gcp_mlengine_validate_error]

    # [START howto_operator_gcp_mlengine_evaluate]
    evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
        task_prefix="evaluate-ops",
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        prediction_path=PREDICTION_OUTPUT,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
        project_id=PROJECT_ID,
        region="us-central1",
        dataflow_options={
            'project': PROJECT_ID,
            'tempLocation': SUMMARY_TMP,
            'stagingLocation': SUMMARY_STAGING,
        },
        model_name=MODEL_NAME,
        version_name="v1",
        py_interpreter="python3",
    )
    # [END howto_operator_gcp_mlengine_evaluate]

    # create_model >> create_version is already declared above; only add the
    # new edges for the evaluation ops.
    create_version >> evaluate_prediction
    evaluate_validation >> delete_version