This module contains helper functions for MLEngine operators.

Module Contents


create_evaluate_ops(task_prefix, data_format, ...[, ...])

Create the operators needed for model evaluation and return them.


create_evaluate_ops(task_prefix, data_format, input_paths, prediction_path, metric_fn_and_keys, validate_fn, batch_prediction_job_id=None, region=None, project_id=None, dataflow_options=None, model_uri=None, model_name=None, version_name=None, dag=None, py_interpreter='python3')

Create the operators needed for model evaluation and return them.

This function is deprecated. All the functionality of legacy MLEngine and new features are available on the Vertex AI platform.

To create and view Model Evaluation, please check the documentation:

It gets predictions over the inputs via the Cloud ML Engine BatchPrediction API by calling MLEngineBatchPredictionOperator, then summarizes and validates the result via Cloud Dataflow using DataFlowPythonOperator.

For details and pricing about Batch prediction, please refer to the website and for Cloud Dataflow,

It returns three chained operators for prediction, summary, and validation, named <prefix>-prediction, <prefix>-summary, and <prefix>-validation, respectively. (<prefix> should contain only alphanumeric characters or hyphens.)
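
The naming rule above can be checked before a DAG is built. The following is a minimal sketch and not part of this module: derive_task_ids is a hypothetical helper, and the regular expression is an assumption that encodes only the "alphanumeric characters or hyphens" constraint stated here.

```python
import re

# Assumption: encodes only the documented constraint (alphanumerics and hyphens).
_PREFIX_RE = re.compile(r"[A-Za-z0-9-]+")


def derive_task_ids(task_prefix):
    """Hypothetical helper: return the three task names derived from a prefix."""
    if not _PREFIX_RE.fullmatch(task_prefix):
        raise ValueError(
            f"Malformed task_prefix {task_prefix!r}: only alphanumeric "
            "characters and hyphens are allowed"
        )
    suffixes = ("prediction", "summary", "validation")
    return tuple(f"{task_prefix}-{suffix}" for suffix in suffixes)
```

With this sketch, a prefix such as "eval-v1" yields the three documented task names, while "eval_v1" is rejected because of the underscore.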

The upstream and downstream dependencies can then be set accordingly, for example:

pred, _, val = create_evaluate_ops(...)

Callers provide two Python callables, metric_fn and validate_fn, to customize the evaluation behavior.

  • metric_fn receives a dictionary per instance, derived from the JSON in the batch prediction result. The keys may vary depending on the model. It should return a tuple of metrics.

  • validate_fn receives a dictionary of the averaged metrics that metric_fn generated over all instances. The keys of the dictionary match those given in the metric_fn_and_keys argument. The dictionary also contains an additional metric, ‘count’, representing the total number of instances received for evaluation. The function should raise an exception to mark the task as failed when the validation result is not good enough to proceed (e.g. to set the trained version as default).

A typical example looks like this:

def get_metric_fn_and_keys():
    import math  # imports should be outside of the metric_fn below.

    def error_and_squared_error(inst):
        label = float(inst["input_label"])
        classes = float(inst["classes"])  # 0 or 1
        err = abs(classes - label)
        squared_err = math.pow(classes - label, 2)
        return (err, squared_err)  # returns a tuple.

    return error_and_squared_error, ["err", "mse"]  # key order must match.

def validate_err_and_count(summary):
    if summary["err"] > 0.2:
        raise ValueError("Too high err>0.2; summary=%s" % summary)
    if summary["mse"] > 0.05:
        raise ValueError("Too high mse>0.05; summary=%s" % summary)
    if summary["count"] < 1000:
        raise ValueError("Too few instances<1000; summary=%s" % summary)
    return summary
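
To make the contract between metric_fn and validate_fn concrete, the sketch below imitates the summary step locally. It is an approximation under stated assumptions, not the Dataflow pipeline itself: summarize is a hypothetical stand-in that takes a plain arithmetic mean of each metric, the two callables are shortened variants of the ones above, and the instance records are made up.

```python
def summarize(instances, metric_fn, metric_keys):
    """Hypothetical stand-in for the summary step: average each metric."""
    totals = [0.0] * len(metric_keys)
    count = 0
    for inst in instances:
        # metric_fn returns one value per key, in the same order as metric_keys.
        for i, value in enumerate(metric_fn(inst)):
            totals[i] += value
        count += 1
    if count == 0:
        raise ValueError("no instances to summarize")
    summary = {key: total / count for key, total in zip(metric_keys, totals)}
    summary["count"] = count  # the additional 'count' metric
    return summary


def abs_and_squared_error(inst):
    diff = float(inst["classes"]) - float(inst["input_label"])
    return (abs(diff), diff * diff)


def validate(summary):
    if summary["err"] > 0.2:
        raise ValueError(f"Too high err>0.2; summary={summary}")
    return summary


# Made-up prediction records: one wrong prediction out of four.
instances = [
    {"input_label": 1, "classes": 1},
    {"input_label": 0, "classes": 0},
    {"input_label": 1, "classes": 0},
    {"input_label": 0, "classes": 0},
]
summary = summarize(instances, abs_and_squared_error, ["err", "mse"])
```

Here one of the four made-up predictions is wrong, so the averaged err exceeds the 0.2 threshold and calling validate(summary) raises ValueError, which is how a validation task would be marked as failed.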

For details on the other BatchPrediction-related arguments (project_id, job_id, region, data_format, input_paths, prediction_path, model_uri), refer to MLEngineBatchPredictionOperator as well.

Parameters

  • task_prefix (str) – a prefix for the tasks. Only alphanumeric characters and hyphens are allowed (no underscores), since this is used as the Dataflow job name, which doesn’t allow other characters.

  • data_format (str) – one of ‘TEXT’, ‘TF_RECORD’, or ‘TF_RECORD_GZIP’.

  • input_paths (list[str]) – a list of input paths to be sent to BatchPrediction.

  • prediction_path (str) – GCS path to put the prediction results in.

  • metric_fn_and_keys (tuple[T, Iterable[str]]) –

    a tuple of metric_fn and metric_keys:

    • metric_fn is a function that accepts a dictionary (for an instance), and returns a tuple of metric(s) that it calculates.

    • metric_keys is a list of strings to denote the key of each metric.

  • validate_fn (T) – a function to validate whether the averaged metric(s) are good enough to push the model.

  • batch_prediction_job_id (str | None) – the id to use for the Cloud ML Batch prediction job. Passed directly to the MLEngineBatchPredictionOperator as the job_id argument.

  • project_id (str | None) – the Google Cloud project id in which to execute Cloud ML Batch Prediction and Dataflow jobs. If None, then the dag’s default_args['project_id'] will be used.

  • region (str | None) – the Google Cloud region in which to execute Cloud ML Batch Prediction and Dataflow jobs. If None, then the dag’s default_args['region'] will be used.

  • dataflow_options (dict | None) – options to run Dataflow jobs. If None, then the dag’s default_args['dataflow_default_options'] will be used.

  • model_uri (str | None) – GCS path of the model exported by TensorFlow using tensorflow.estimator.export_savedmodel(). It cannot be used with model_name or version_name below. See MLEngineBatchPredictionOperator for more detail.

  • model_name (str | None) – Used to indicate a model to use for prediction. Can be used in combination with version_name, but cannot be used together with model_uri. See MLEngineBatchPredictionOperator for more detail. If None, then the dag’s default_args['model_name'] will be used.

  • version_name (str | None) – Used to indicate a model version to use for prediction, in combination with model_name. Cannot be used together with model_uri. See MLEngineBatchPredictionOperator for more detail. If None, then the dag’s default_args['version_name'] will be used.

  • dag (airflow.DAG | None) – The DAG to use for all Operators.

  • py_interpreter – Python version of the Beam pipeline. If None, this defaults to python3. To track Python versions supported by Beam and related issues check:
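
The None fallbacks and the model argument rules described in the parameter list can be sketched as follows. This is an illustration only and does not reproduce the module's implementation: resolve_arg and check_model_args are hypothetical helpers, the default_args values are made up, and treating version_name without model_name as an error is one reading of "in combination with model_name".

```python
def resolve_arg(value, default_args, key):
    """Hypothetical helper: prefer the explicit value, else default_args[key]."""
    if value is not None:
        return value
    return default_args.get(key)


def check_model_args(model_uri=None, model_name=None, version_name=None):
    """Hypothetical check of the documented model argument rules."""
    if model_uri is not None and (model_name is not None or version_name is not None):
        raise ValueError("model_uri cannot be combined with model_name or version_name")
    # Assumption: version_name is only meaningful together with model_name.
    if version_name is not None and model_name is None:
        raise ValueError("version_name requires model_name")


default_args = {"project_id": "example-project", "region": "us-east1"}  # made-up values
project_id = resolve_arg(None, default_args, "project_id")  # falls back to default_args
region = resolve_arg("europe-west1", default_args, "region")  # explicit value wins
```

In this sketch an explicitly passed argument always takes precedence, and default_args is consulted only when the argument is None.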


Returns

a tuple of three operators: (prediction, summary, validation)

Return type

tuple[, airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator, airflow.operators.python.PythonOperator]
