Skip to content

Commit 6bd101a

Browse files
committed
Fix some review comments
1 parent d5a0e18 commit 6bd101a

1 file changed

Lines changed: 38 additions & 29 deletions

File tree

ml_engine/online_prediction/predict.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache
1+
#!/bin/python
2+
# Copyright 2017 Google Inc. All Rights Reserved. Licensed under the Apache
23
# License, Version 2.0 (the "License"); you may not use this file except in
34
# compliance with the License. You may obtain a copy of the License at
45
# http://www.apache.org/licenses/LICENSE-2.0
@@ -8,37 +9,35 @@
89
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
910
# License for the specific language governing permissions and limitations under
1011
# the License.
12+
1113
"""Examples of using the Cloud ML Engine's online prediction service."""
1214
from __future__ import print_function
1315
# [START import_libraries]
1416
import googleapiclient.discovery
1517
# [END import_libraries]
1618

1719

18-
# [START authenticating]
19-
def get_ml_engine_service():
20-
"""Create the ML Engine service object.
21-
To authenticate set the environment variable
22-
GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
23-
"""
24-
return googleapiclient.discovery.build('ml', 'v1beta1')
25-
# [END authenticating]
26-
27-
2820
# [START predict_json]
2921
def predict_json(project, model, instances, version=None):
30-
"""Send data instances to a deployed model for prediction
22+
"""Send json data to a deployed model for prediction.
23+
3124
Args:
32-
project: str, project where the Cloud ML Engine Model is deployed.
33-
model: str, model name.
34-
instances: [dict], dictionaries from string keys defined by the model
35-
to data.
36-
version: [optional] str, version of the model to target.
25+
project (str): project where the Cloud ML Engine Model is deployed.
26+
model (str): model name.
27+
instances ([Mapping[str: any]]): dictionaries from string keys
28+
defined by the model deployment, to data with types that match
29+
expected tensors
30+
version: str, version of the model to target.
3731
Returns:
38-
A dictionary of prediction results defined by the model.
32+
Mapping[str: any]: dictionary of prediction results defined by the
33+
model.
3934
"""
40-
service = get_ml_engine_service()
35+
# Create the ML Engine service object.
36+
# To authenticate set the environment variable
37+
# GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
38+
service = googleapiclient.discovery.build('ml', 'v1beta1')
4139
name = 'projects/{}/models/{}'.format(project, model)
40+
4241
if version is not None:
4342
name += '/versions/{}'.format(version)
4443

@@ -58,30 +57,36 @@ def predict_json(project, model, instances, version=None):
5857
def predict_tf_records(project,
5958
model,
6059
example_bytes_list,
61-
key='tfrecord',
6260
version=None):
63-
"""Send data instances to a deployed model for prediction
61+
"""Send protocol buffer data to a deployed model for prediction.
62+
6463
Args:
65-
project: str, project where the Cloud ML Engine Model is deployed
66-
model: str, model name.
67-
example_bytes_list: [str], Serialized tf.train.Example protos.
64+
project (str): project where the Cloud ML Engine Model is deployed.
65+
model (str): model name.
66+
example_bytes_list ([str]): A list of bytestrings representing
67+
serialized tf.train.Example protocol buffers. The contents of this
68+
protocol buffer will change depending on the signature of your
69+
deployed model.
6870
version: str, version of the model to target.
6971
Returns:
70-
A dictionary of prediction results defined by the model.
72+
Mapping[str: any]: dictionary of prediction results defined by the
73+
model.
7174
"""
7275
import base64
73-
service = get_ml_engine_service()
76+
service = googleapiclient.discovery.build('ml', 'v1beta1')
7477
name = 'projects/{}/models/{}'.format(project, model)
78+
7579
if version is not None:
7680
name += '/versions/{}'.format(version)
7781

7882
response = service.projects().predict(
7983
name=name,
8084
body={'instances': [
81-
{key: {'b64': base64.b64encode(example_bytes)}}
85+
{'b64': base64.b64encode(example_bytes)}
8286
for example_bytes in example_bytes_list
8387
]}
8488
).execute()
89+
8590
if 'error' in response:
8691
raise RuntimeError(response['error'])
8792

@@ -93,10 +98,14 @@ def predict_tf_records(project,
9398
def census_to_example_bytes(json_instance):
9499
"""Serialize a JSON example to the bytes of a tf.train.Example.
95100
This method is specific to the signature of the Census example.
101+
See: https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
102+
for details.
103+
96104
Args:
97-
json_instance: dict, representing data to be serialized.
105+
json_instance (Mapping[str: any]): representing data to be serialized.
98106
Returns:
99-
A string (as a container for bytes).
107+
str: A string as a container for the serialized bytes of
108+
tf.train.Example protocol buffer.
100109
"""
101110
import tensorflow as tf
102111
feature_dict = {}

0 commit comments

Comments
 (0)