airflow.providers.google.cloud.utils.mlengine_prediction_summary

A template called by DataFlowPythonOperator to summarize BatchPrediction results.

It accepts a user function that calculates the metric(s) for each instance in the prediction results, then aggregates them into a summary.

It accepts the following arguments:

  • --prediction_path: The GCS folder that contains the BatchPrediction results, as prediction.results-NNNNN-of-NNNNN files in JSON format. The output is also stored in this folder, as ‘prediction.summary.json’.

  • --metric_fn_encoded: An encoded function that takes a single instance (as a dictionary) and returns a tuple of metric(s) for it. It should be encoded via base64.b64encode(dill.dumps(fn, recurse=True)).

  • --metric_keys: A comma-separated list of the key(s) for the aggregated metric(s) in the summary output. The number and order of the keys must match the output of metric_fn. The summary has an additional key, ‘count’, holding the total number of instances, so the keys must not include ‘count’.
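
The path and key rules above can be sketched in plain Python. This is an illustration under assumptions, not the module's own code: the helper names summary_paths and parse_metric_keys are hypothetical, the prediction.results-*-of-* glob is an assumed way of matching the prediction.results-NNNNN-of-NNNNN shards, and gs://my-bucket/predictions is a placeholder.

```python
import posixpath
from typing import List, Tuple


def summary_paths(prediction_path: str) -> Tuple[str, str]:
    """Return (input glob, output file) for a BatchPrediction folder.

    Hypothetical helper. The glob is an assumed pattern for the
    prediction.results-NNNNN-of-NNNNN shards. posixpath is used so that
    gs:// paths always get forward slashes, whatever the local OS.
    """
    results_glob = posixpath.join(prediction_path, "prediction.results-*-of-*")
    summary_file = posixpath.join(prediction_path, "prediction.summary.json")
    return results_glob, summary_file


def parse_metric_keys(metric_keys: str) -> List[str]:
    """Split a --metric_keys value and reject the reserved 'count' key.

    Hypothetical helper; stripping whitespace around keys is an assumption.
    """
    keys = [key.strip() for key in metric_keys.split(",")]
    if "count" in keys:
        raise ValueError("'count' is added to the summary automatically; do not list it")
    return keys


print(summary_paths("gs://my-bucket/predictions"))
print(parse_metric_keys("log_loss,mse"))
```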

Usage example:

When the input file looks like the following:

{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}

The output file will be:

{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
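
A per-instance metric function for records shaped like the ones above might look like the sketch below. It assumes, hypothetically, that the true label is the first comma-separated field of "inputs". The metric_fn behind the documented numbers is not shown here, so this sketch covers only the squared error, whose mean matches the documented mse of 0.25; it does not try to reproduce the log_loss value.

```python
import json
from typing import Any, Dict, Tuple

# The four example records above, one JSON object per line.
EXAMPLE_LINES = [
    '{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}',
    '{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}',
    '{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}',
    '{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}',
]


def metric_fn(inst: Dict[str, Any]) -> Tuple[float]:
    """Return the squared error between the predicted class and the label.

    Assumption (hypothetical): the label is the first comma-separated
    field of the "inputs" string.
    """
    label = int(inst["inputs"].split(",")[0])
    squared_err = float((int(inst["classes"]) - label) ** 2)
    return (squared_err,)  # a 1-tuple, so --metric_keys would be just "mse"


per_instance = [metric_fn(json.loads(line)) for line in EXAMPLE_LINES]
print(per_instance)
```

Because this function returns a single metric, the matching --metric_keys value would be mse; the function would then be encoded as described under --metric_fn_encoded.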

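To test outside of a DAG:

import base64
import subprocess

import dill  # third-party package used to serialize metric_fn

# metric_fn is your per-instance metric function. base64.b64encode returns
# bytes, so decode to str before concatenating it into the argument list.
metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode("ascii")

subprocess.check_call(["python",
                       "-m",
                       "airflow.providers.google.cloud.utils.mlengine_prediction_summary",
                       "--prediction_path=gs://...",
                       "--metric_fn_encoded=" + metric_fn_encoded,
                       "--metric_keys=log_loss,mse",
                       "--runner=DataflowRunner",
                       "--staging_location=gs://...",
                       "--temp_location=gs://...",
                       ])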

Module Contents

class airflow.providers.google.cloud.utils.mlengine_prediction_summary.JsonCoder

JSON encoder/decoder.

static encode(x)

JSON encoder.

static decode(x)

JSON decoder.
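
A coder with this static encode/decode interface can be sketched with the standard json module. The sketch assumes, without confirmation from the text above, that encode turns a Python object into UTF-8 JSON bytes and decode parses one JSON line back into a Python object, which is how a line-oriented text source and sink would use it. The class name JsonCoderSketch is hypothetical.

```python
import json
from typing import Any, Union


class JsonCoderSketch:
    """Hypothetical stand-in for JsonCoder: JSON text <-> Python objects."""

    @staticmethod
    def encode(x: Any) -> bytes:
        # Serialize one object (e.g. the summary dict) to UTF-8 JSON bytes.
        return json.dumps(x).encode("utf-8")

    @staticmethod
    def decode(x: Union[bytes, str]) -> Any:
        # Parse one JSON line (e.g. a prediction result record);
        # json.loads accepts both str and UTF-8 encoded bytes.
        return json.loads(x)


line = b'{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}'
record = JsonCoderSketch.decode(line)
print(record["scores"][1])
print(JsonCoderSketch.encode({"count": 4, "mse": 0.25}))
```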

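airflow.providers.google.cloud.utils.mlengine_prediction_summary.MakeSummary(pcoll, metric_fn, metric_keys)

Summary PTransform used in Dataflow: it applies metric_fn to each instance and aggregates the results into a single summary keyed by metric_keys plus ‘count’.
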
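
The aggregation step can be illustrated in plain Python, without Apache Beam. The sketch assumes each summary value is the mean of that metric over all instances, which is consistent with the documented mse of 0.25 and count of 4 in the example above; the function name make_summary is hypothetical. It sums each metric across instances, counts the instances, and divides each sum by the count.

```python
from typing import Dict, Iterable, List, Sequence, Union


def make_summary(
    metric_tuples: Iterable[Sequence[float]], metric_keys: List[str]
) -> Dict[str, Union[float, int]]:
    """Average per-instance metric tuples into a summary dict with a 'count' key."""
    totals = [0.0] * len(metric_keys)
    count = 0
    for metrics in metric_tuples:
        # The size of each metric tuple must match the number of keys.
        if len(metrics) != len(metric_keys):
            raise ValueError("metric_fn output size must match metric_keys")
        for i, value in enumerate(metrics):
            totals[i] += value
        count += 1
    if count == 0:
        # This sketch refuses to divide by zero on empty input.
        raise ValueError("no prediction instances to summarize")
    summary: Dict[str, Union[float, int]] = {
        key: totals[i] / count for i, key in enumerate(metric_keys)
    }
    summary["count"] = count
    return summary


# Squared errors for the four example records, with metric_keys = ["mse"].
print(make_summary([(0.0,), (0.0,), (1.0,), (0.0,)], ["mse"]))
```

Summing first and dividing once at the end keeps the combining step associative, which is what allows a distributed runner to merge partial sums computed on different workers.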
airflow.providers.google.cloud.utils.mlengine_prediction_summary.run(argv=None)

Helper for obtaining the prediction summary: it parses the command-line arguments described above and runs the summary pipeline.
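
An entry point like this has to separate its own three flags from pipeline options such as --runner, --staging_location and --temp_location, which the test command above passes alongside them. A minimal sketch of that split using the standard library's argparse.ArgumentParser.parse_known_args follows. It is an assumption about the approach, not the module's code; split_args is a hypothetical name, and the bucket and encoded-function values are placeholders.

```python
import argparse
from typing import List, Optional, Tuple


def split_args(argv: Optional[List[str]] = None) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the summary flags; return them plus the unrecognized pipeline flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prediction_path", required=True)
    parser.add_argument("--metric_fn_encoded", required=True)
    parser.add_argument("--metric_keys", required=True)
    # parse_known_args collects flags it does not recognize (e.g. --runner)
    # into a list instead of failing, so they can be handed to the pipeline.
    return parser.parse_known_args(argv)


known, pipeline_args = split_args([
    "--prediction_path=gs://my-bucket/predictions",  # placeholder bucket
    "--metric_fn_encoded=ZW5jb2RlZA==",  # placeholder, not a real encoded function
    "--metric_keys=log_loss,mse",
    "--runner=DataflowRunner",
    "--temp_location=gs://my-bucket/tmp",
])
print(known.metric_keys.split(","))
print(pipeline_args)
```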
