#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for Google ML Engine service.
"""
import os
from typing import Dict
from airflow import models
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.mlengine import (
MLEngineCreateModelOperator,
MLEngineCreateVersionOperator,
MLEngineDeleteModelOperator,
MLEngineDeleteVersionOperator,
MLEngineGetModelOperator,
MLEngineListVersionsOperator,
MLEngineSetDefaultVersionOperator,
MLEngineStartBatchPredictionJobOperator,
MLEngineStartTrainingJobOperator,
)
from airflow.providers.google.cloud.utils import mlengine_operator_utils
from airflow.utils.dates import days_ago
# Every setting below can be overridden through an environment variable;
# the second argument is the fallback used when the variable is unset.
PROJECT_ID = os.getenv("GCP_PROJECT_ID", "example-project")
MODEL_NAME = os.getenv("GCP_MLENGINE_MODEL_NAME", "model_name")

# Cloud Storage locations used by the training / version / prediction tasks.
SAVED_MODEL_PATH = os.getenv(
    "GCP_MLENGINE_SAVED_MODEL_PATH",
    "gs://test-airflow-mlengine/saved-model/",
)
JOB_DIR = os.getenv(
    "GCP_MLENGINE_JOB_DIR",
    "gs://test-airflow-mlengine/keras-job-dir",
)
PREDICTION_INPUT = os.getenv(
    "GCP_MLENGINE_PREDICTION_INPUT",
    "gs://test-airflow-mlengine/prediction_input.json",
)
PREDICTION_OUTPUT = os.getenv(
    "GCP_MLENGINE_PREDICTION_OUTPUT",
    "gs://test-airflow-mlengine/prediction_output",
)

# Trainer package and the Python module used by the training job.
TRAINER_URI = os.getenv(
    "GCP_MLENGINE_TRAINER_URI",
    "gs://test-airflow-mlengine/trainer.tar.gz",
)
TRAINER_PY_MODULE = os.getenv(
    "GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE",
    "trainer.task",
)

# Temp and staging locations handed to the evaluation step's dataflow options.
SUMMARY_TMP = os.getenv(
    "GCP_MLENGINE_DATAFLOW_TMP",
    "gs://test-airflow-mlengine/tmp/",
)
SUMMARY_STAGING = os.getenv(
    "GCP_MLENGINE_DATAFLOW_STAGING",
    "gs://test-airflow-mlengine/staging/",
)

# Exposes the model name to templates as ``params.model_name``.
default_args = {"params": {"model_name": MODEL_NAME}}
with models.DAG(
    "example_gcp_mlengine",
    # default_args supplies the ``params.model_name`` value that the templated
    # job ids below reference; it must be passed for those templates to resolve.
    default_args=default_args,
    schedule_interval=None,  # Override to match your needs
    start_date=days_ago(1),
    tags=['example'],
) as dag:
    # [START howto_operator_gcp_mlengine_training]
    training = MLEngineStartTrainingJobOperator(
        task_id="training",
        project_id=PROJECT_ID,
        region="us-central1",
        job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
        runtime_version="1.15",
        python_version="3.7",
        job_dir=JOB_DIR,
        package_uris=[TRAINER_URI],
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
        labels={"job_type": "training"},
    )
    # [END howto_operator_gcp_mlengine_training]

    # [START howto_operator_gcp_mlengine_create_model]
    create_model = MLEngineCreateModelOperator(
        task_id="create-model",
        project_id=PROJECT_ID,
        model={
            "name": MODEL_NAME,
        },
    )
    # [END howto_operator_gcp_mlengine_create_model]

    # [START howto_operator_gcp_mlengine_get_model]
    get_model = MLEngineGetModelOperator(
        task_id="get-model",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_get_model]

    # [START howto_operator_gcp_mlengine_print_model]
    # Echoes whatever the "get-model" task pushed to XCom.
    get_model_result = BashOperator(
        bash_command="echo \"{{ task_instance.xcom_pull('get-model') }}\"",
        task_id="get-model-result",
    )
    # [END howto_operator_gcp_mlengine_print_model]

    # [START howto_operator_gcp_mlengine_create_version1]
    # Version "v1" is deployed from the export directory under the training job dir.
    create_version = MLEngineCreateVersionOperator(
        task_id="create-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v1",
            "description": "First-version",
            "deployment_uri": f'{JOB_DIR}/keras_export/',
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version1]

    # [START howto_operator_gcp_mlengine_create_version2]
    # Version "v2" is deployed from the pre-existing saved model path.
    create_version_2 = MLEngineCreateVersionOperator(
        task_id="create-version-2",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v2",
            "description": "Second version",
            "deployment_uri": SAVED_MODEL_PATH,
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version2]

    # [START howto_operator_gcp_mlengine_default_version]
    set_defaults_version = MLEngineSetDefaultVersionOperator(
        task_id="set-default-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="v2",
    )
    # [END howto_operator_gcp_mlengine_default_version]

    # [START howto_operator_gcp_mlengine_list_versions]
    list_version = MLEngineListVersionsOperator(
        task_id="list-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_list_versions]

    # [START howto_operator_gcp_mlengine_print_versions]
    # Echoes whatever the "list-version" task pushed to XCom.
    list_version_result = BashOperator(
        bash_command="echo \"{{ task_instance.xcom_pull('list-version') }}\"",
        task_id="list-version-result",
    )
    # [END howto_operator_gcp_mlengine_print_versions]

    # [START howto_operator_gcp_mlengine_get_prediction]
    prediction = MLEngineStartBatchPredictionJobOperator(
        task_id="prediction",
        project_id=PROJECT_ID,
        job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}",
        region="us-central1",
        model_name=MODEL_NAME,
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        output_path=PREDICTION_OUTPUT,
        labels={"job_type": "prediction"},
    )
    # [END howto_operator_gcp_mlengine_get_prediction]

    # [START howto_operator_gcp_mlengine_delete_version]
    delete_version = MLEngineDeleteVersionOperator(
        task_id="delete-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v1"
    )
    # [END howto_operator_gcp_mlengine_delete_version]

    # [START howto_operator_gcp_mlengine_delete_model]
    delete_model = MLEngineDeleteModelOperator(
        task_id="delete-model", project_id=PROJECT_ID, model_name=MODEL_NAME, delete_contents=True
    )
    # [END howto_operator_gcp_mlengine_delete_model]

    # Both versions are created only after the training job.
    training >> create_version
    training >> create_version_2

    # The model is created first, then inspected; its deletion also waits on get_model.
    create_model >> get_model >> [get_model_result, delete_model]
    create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version

    # The batch prediction waits for both versions.
    create_version >> prediction
    create_version_2 >> prediction

    # Clean-up: v1 is deleted after prediction and listing, the model after v1.
    prediction >> delete_version
    list_version >> list_version_result
    list_version >> delete_version
    delete_version >> delete_model

    # [START howto_operator_gcp_mlengine_get_metric]
    def get_metric_fn_and_keys():
        """
        Gets metric function and keys used to generate summary

        :return: a pair ``(metric_fn, metric_keys)`` where ``metric_fn`` maps one
            prediction instance to a tuple of values and ``metric_keys`` names
            those values, in the same order.
        """

        def normalize_value(inst: Dict):
            # Reads the first element stored under the 'dense_4' key of the instance.
            val = float(inst['dense_4'][0])
            return (val,)  # returns a tuple.

        return normalize_value, ['val']  # key order must match.

    # [END howto_operator_gcp_mlengine_get_metric]

    # [START howto_operator_gcp_mlengine_validate_error]
    def validate_err_and_count(summary: Dict) -> Dict:
        """
        Validate summary result

        :param summary: mapping that must contain the keys ``'val'`` and ``'count'``.
        :return: the same ``summary`` object, unchanged, when every check passes.
        :raises ValueError: if ``val`` is outside ``[0, 1]`` or ``count`` is not 20.
        """
        if summary['val'] > 1:
            raise ValueError(f'Too high val>1; summary={summary}')
        if summary['val'] < 0:
            raise ValueError(f'Too low val<0; summary={summary}')
        # NOTE(review): 20 is hard-coded; presumably the number of instances in the
        # sample prediction input - confirm when using different data.
        if summary['count'] != 20:
            raise ValueError(f'Invalid value count != 20; summary={summary}')
        return summary

    # [END howto_operator_gcp_mlengine_validate_error]

    # [START howto_operator_gcp_mlengine_evaluate]
    evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
        task_prefix="evaluate-ops",
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        prediction_path=PREDICTION_OUTPUT,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
        project_id=PROJECT_ID,
        region="us-central1",
        dataflow_options={
            'project': PROJECT_ID,
            'tempLocation': SUMMARY_TMP,
            'stagingLocation': SUMMARY_STAGING,
        },
        model_name=MODEL_NAME,
        version_name="v1",
        py_interpreter="python3",
    )
    # [END howto_operator_gcp_mlengine_evaluate]

    # The evaluation targets version "v1", so it runs after that version exists.
    # (create_model >> create_version is already declared above.)
    create_version >> evaluate_prediction
    # v1 must not be deleted before the evaluation has been validated.
    evaluate_validation >> delete_version