#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for Google ML Engine service.
"""
import os
from typing import Dict
from airflow import models
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.mlengine import (
MLEngineCreateModelOperator,
MLEngineCreateVersionOperator,
MLEngineDeleteModelOperator,
MLEngineDeleteVersionOperator,
MLEngineGetModelOperator,
MLEngineListVersionsOperator,
MLEngineSetDefaultVersionOperator,
MLEngineStartBatchPredictionJobOperator,
MLEngineStartTrainingJobOperator,
)
from airflow.providers.google.cloud.utils import mlengine_operator_utils
from airflow.utils.dates import days_ago
# Configuration for the example DAG. Every value can be overridden via the
# environment variable named in the first argument; otherwise the second
# argument is used. The "gs://INVALID BUCKET NAME/..." defaults are placeholders
# (a bucket name containing spaces is not valid) and must be overridden with
# real Cloud Storage locations for the DAG to run.
PROJECT_ID = os.getenv("GCP_PROJECT_ID", "example-project")
MODEL_NAME = os.getenv("GCP_MLENGINE_MODEL_NAME", "model_name")

# Model artifacts and training package.
SAVED_MODEL_PATH = os.getenv("GCP_MLENGINE_SAVED_MODEL_PATH", "gs://INVALID BUCKET NAME/saved-model/")
JOB_DIR = os.getenv("GCP_MLENGINE_JOB_DIR", "gs://INVALID BUCKET NAME/keras-job-dir")
TRAINER_URI = os.getenv("GCP_MLENGINE_TRAINER_URI", "gs://INVALID BUCKET NAME/trainer.tar.gz")
# NOTE: the variable name repeats "TRAINER"; kept unchanged so existing deployments keep working.
TRAINER_PY_MODULE = os.getenv("GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task")

# Batch prediction input file and output prefix.
PREDICTION_INPUT = os.getenv(
    "GCP_MLENGINE_PREDICTION_INPUT",
    "gs://INVALID BUCKET NAME/prediction_input.json",
)
PREDICTION_OUTPUT = os.getenv(
    "GCP_MLENGINE_PREDICTION_OUTPUT",
    "gs://INVALID BUCKET NAME/prediction_output",
)

# Temp/staging locations handed to Dataflow by the evaluation summary step.
SUMMARY_TMP = os.getenv("GCP_MLENGINE_DATAFLOW_TMP", "gs://INVALID BUCKET NAME/tmp/")
SUMMARY_STAGING = os.getenv("GCP_MLENGINE_DATAFLOW_STAGING", "gs://INVALID BUCKET NAME/staging/")
# Example DAG exercising the ML Engine operators: train, create a model and two
# versions, set the default version, list versions, run batch prediction,
# evaluate the prediction output, then clean up versions and the model.
with models.DAG(
    "example_gcp_mlengine",
    schedule_interval=None,  # Override to match your needs
    start_date=days_ago(1),
    tags=['example'],
    params={"model_name": MODEL_NAME},
) as dag:
    # Hyperparameter tuning spec, passed unchanged as the training job's
    # ``hyperparameters`` argument. Built as a single literal rather than by
    # appending to ``params`` after construction.
    hyperparams = {
        'goal': 'MAXIMIZE',
        'hyperparameterMetricTag': 'metric1',
        'maxTrials': 30,
        'maxParallelTrials': 1,
        'enableTrialEarlyStopping': True,
        'params': [
            {
                'parameterName': 'hidden1',
                'type': 'INTEGER',
                'minValue': 40,
                'maxValue': 400,
                'scaleType': 'UNIT_LINEAR_SCALE',
            },
            {'parameterName': 'numRnnCells', 'type': 'DISCRETE', 'discreteValues': [1, 2, 3, 4]},
            {
                'parameterName': 'rnnCellType',
                'type': 'CATEGORICAL',
                'categoricalValues': [
                    'BasicLSTMCell',
                    'BasicRNNCell',
                    'GRUCell',
                    'LSTMCell',
                    'LayerNormBasicLSTMCell',
                ],
            },
        ],
    }

    # [START howto_operator_gcp_mlengine_training]
    training = MLEngineStartTrainingJobOperator(
        task_id="training",
        project_id=PROJECT_ID,
        region="us-central1",
        job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
        runtime_version="1.15",
        python_version="3.7",
        job_dir=JOB_DIR,
        package_uris=[TRAINER_URI],
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
        labels={"job_type": "training"},
        hyperparameters=hyperparams,
    )
    # [END howto_operator_gcp_mlengine_training]

    # [START howto_operator_gcp_mlengine_create_model]
    create_model = MLEngineCreateModelOperator(
        task_id="create-model",
        project_id=PROJECT_ID,
        model={
            "name": MODEL_NAME,
        },
    )
    # [END howto_operator_gcp_mlengine_create_model]

    # [START howto_operator_gcp_mlengine_get_model]
    get_model = MLEngineGetModelOperator(
        task_id="get-model",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_get_model]

    # [START howto_operator_gcp_mlengine_print_model]
    # ``get_model.output`` is an XComArg; the f-string embeds its templated
    # reference, so the echoed value is resolved at run time.
    get_model_result = BashOperator(
        bash_command=f"echo {get_model.output}",
        task_id="get-model-result",
    )
    # [END howto_operator_gcp_mlengine_print_model]

    # [START howto_operator_gcp_mlengine_create_version1]
    create_version = MLEngineCreateVersionOperator(
        task_id="create-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v1",
            "description": "First-version",
            "deployment_uri": f'{JOB_DIR}/keras_export/',
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version1]

    # [START howto_operator_gcp_mlengine_create_version2]
    create_version_2 = MLEngineCreateVersionOperator(
        task_id="create-version-2",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v2",
            "description": "Second version",
            "deployment_uri": SAVED_MODEL_PATH,
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version2]

    # [START howto_operator_gcp_mlengine_default_version]
    set_defaults_version = MLEngineSetDefaultVersionOperator(
        task_id="set-default-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="v2",
    )
    # [END howto_operator_gcp_mlengine_default_version]

    # [START howto_operator_gcp_mlengine_list_versions]
    list_version = MLEngineListVersionsOperator(
        task_id="list-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_list_versions]

    # [START howto_operator_gcp_mlengine_print_versions]
    list_version_result = BashOperator(
        bash_command=f"echo {list_version.output}",
        task_id="list-version-result",
    )
    # [END howto_operator_gcp_mlengine_print_versions]

    # [START howto_operator_gcp_mlengine_get_prediction]
    prediction = MLEngineStartBatchPredictionJobOperator(
        task_id="prediction",
        project_id=PROJECT_ID,
        job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}",
        region="us-central1",
        model_name=MODEL_NAME,
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        output_path=PREDICTION_OUTPUT,
        labels={"job_type": "prediction"},
    )
    # [END howto_operator_gcp_mlengine_get_prediction]

    # [START howto_operator_gcp_mlengine_delete_version]
    delete_version = MLEngineDeleteVersionOperator(
        task_id="delete-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v1"
    )
    # [END howto_operator_gcp_mlengine_delete_version]

    # [START howto_operator_gcp_mlengine_delete_model]
    delete_model = MLEngineDeleteModelOperator(
        task_id="delete-model", project_id=PROJECT_ID, model_name=MODEL_NAME, delete_contents=True
    )
    # [END howto_operator_gcp_mlengine_delete_model]

    # Each edge is declared exactly once. Re-declaring an existing edge is
    # redundant and makes some Airflow versions log
    # "Dependency ... already registered".
    # v1 deploys from ``{JOB_DIR}/keras_export/`` (presumably written by the
    # trainer), so both version creations wait for training.
    training >> create_version
    training >> create_version_2
    create_model >> get_model >> [get_model_result, delete_model]
    create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version
    create_version >> prediction
    create_version_2 >> prediction
    prediction >> delete_version
    list_version >> list_version_result
    list_version >> delete_version
    delete_version >> delete_model

    # [START howto_operator_gcp_mlengine_get_metric]
    def get_metric_fn_and_keys():
        """
        Gets metric function and keys used to generate summary.

        :return: a ``(metric_fn, keys)`` pair as expected by
            ``create_evaluate_ops(metric_fn_and_keys=...)``. ``metric_fn`` maps one
            prediction instance to a tuple of values; ``keys`` names those values
            positionally, so both must list them in the same order.
        """

        def normalize_value(inst: Dict) -> tuple:
            # Assumes each prediction instance has a ``dense_4`` output whose first
            # element is the score -- TODO confirm against the exported model.
            val = float(inst['dense_4'][0])
            return (val,)  # One element per entry in the key list below.

        return normalize_value, ['val']  # key order must match.

    # [END howto_operator_gcp_mlengine_get_metric]

    # [START howto_operator_gcp_mlengine_validate_error]
    def validate_err_and_count(summary: Dict) -> Dict:
        """
        Validate summary result.

        :param summary: evaluation summary; this check reads its ``val`` and ``count``
            entries (``count`` is presumably the number of evaluated instances).
        :return: ``summary`` unchanged when every check passes.
        :raises ValueError: if ``val`` is above 1 or below 0, or ``count`` is not 20.
        """
        if summary['val'] > 1:
            raise ValueError(f'Too high val>1; summary={summary}')
        if summary['val'] < 0:
            raise ValueError(f'Too low val<0; summary={summary}')
        if summary['count'] != 20:
            # The message names the field actually being checked (``count``).
            raise ValueError(f'Invalid value count != 20; summary={summary}')
        return summary

    # [END howto_operator_gcp_mlengine_validate_error]

    # [START howto_operator_gcp_mlengine_evaluate]
    evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
        task_prefix="evaluate-ops",
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        prediction_path=PREDICTION_OUTPUT,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
        project_id=PROJECT_ID,
        region="us-central1",
        dataflow_options={
            'project': PROJECT_ID,
            'tempLocation': SUMMARY_TMP,
            'stagingLocation': SUMMARY_STAGING,
        },
        model_name=MODEL_NAME,
        version_name="v1",
        py_interpreter="python3",
    )
    # [END howto_operator_gcp_mlengine_evaluate]

    # create_model >> create_version is already declared above; only the new edge is added.
    create_version >> evaluate_prediction
    evaluate_validation >> delete_version