Skip to content

Commit 8c3da38

Browse files
committed
Fix review comments
1 parent 5ca6773 commit 8c3da38

File tree

4 files changed

+30
-58
lines changed

4 files changed

+30
-58
lines changed
Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1 @@
1-
# Online Prediction with the Cloud Machine Learning Engine
2-
3-
This sample assumes that you have already run the Cloud ML Engine end to end walkthrough using the UCI Census dataset, including training a wide & deep model and deploying it into production, ready to take prediction requests.
4-
5-
This sample then shows you how to create a client that takes user JSON inputs, and sends those inputs as online prediction requests to a given deployed model.
6-
7-
In order to use this client, first obtain the following information, and store it into the given environment variables:
8-
9-
```
10-
PROJECT=<your project name>
11-
MODEL=<your model name>
12-
VERSION=<version of model you are using>
13-
```
14-
15-
Next, launch this client as follows:
16-
17-
```
18-
python predict.py --project=$PROJECT --model=$MODEL --version=$VERSION
19-
```
20-
21-
After having done that, the client will ask you for ‘Valid JSON’ input as follows:
22-
23-
```
24-
Valid JSON >>>
25-
```
26-
27-
Now you can input a JSON example that corresponds to the schema of the given model. For instance if you are sending prediction requests to the census-based model created in the Cloud ML Engine walkthrough, you can send a JSON example like the following:
28-
29-
```
30-
{"age": 25, "workclass": " Private", "education": " 11th", "education_num": 7, "marital_status": " Never-married", "occupation": " Machine-op-inspct", "relationship": " Own-child", "race": " Black", "gender": " Male", "capital_gain": 0, "capital_loss": 0, "hours_per_week": 40, "native_country": " United-States"}
31-
```
32-
The result should be something along the following lines (depending on how you trained the model/what parameters you used, the results may vary):
33-
34-
```
35-
[{u'probabilities': [0.992774486541748, 0.007225471083074808], u'logits': [-4.922891139984131], u'classes': 0, u'logistic': [0.007225471083074808]}]
36-
```
37-
38-
Now that you have a working client, you can adapt this to your own use cases!
1+
https://cloud.google.com/ml-engine/docs/concepts/prediction-overview

ml_engine/online_prediction/predict.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
# the License.
1212

1313
"""Examples of using the Cloud ML Engine's online prediction service."""
14-
from __future__ import print_function
15-
1614
import argparse
1715
import base64
1816
import json
@@ -21,6 +19,8 @@
2119
import googleapiclient.discovery
2220
# [END import_libraries]
2321

22+
import six
23+
2424

2525
# [START predict_json]
2626
def predict_json(project, model, instances, version=None):
@@ -29,9 +29,10 @@ def predict_json(project, model, instances, version=None):
2929
Args:
3030
project (str): project where the Cloud ML Engine Model is deployed.
3131
model (str): model name.
32-
instances ([Mapping[str: any]]): dictionaries from string keys
33-
defined by the model deployment, to data with types that match
34-
expected tensors
32+
instances ([Mapping[str: Any]]): Keys should be the names of Tensors
33+
your deployed model expects as inputs. Values should be datatypes
34+
convertible to Tensors, or (potentially nested) lists of datatypes
35+
convertible to tensors.
3536
version: str, version of the model to target.
3637
Returns:
3738
Mapping[str: any]: dictionary of prediction results defined by the
@@ -106,20 +107,26 @@ def census_to_example_bytes(json_instance):
106107
for details.
107108
108109
Args:
109-
json_instance (Mapping[str: any]): representing data to be serialized.
110+
json_instance (Mapping[str: Any]): Keys should be the names of Tensors
111+
your deployed model expects to parse using it's tf.FeatureSpec.
112+
Values should be datatypes convertible to Tensors, or (potentially
113+
nested) lists of datatypes convertible to tensors.
110114
Returns:
111115
str: A string as a container for the serialized bytes of
112116
tf.train.Example protocol buffer.
113117
"""
114118
import tensorflow as tf
115119
feature_dict = {}
116120
for key, data in json_instance.iteritems():
117-
if isinstance(data, str) or isinstance(data, unicode):
121+
if isinstance(data, six.string_types):
118122
feature_dict[key] = tf.train.Feature(
119123
bytes_list=tf.train.BytesList(value=[str(data)]))
120-
elif isinstance(data, int) or isinstance(data, float):
124+
elif isinstance(data, float):
121125
feature_dict[key] = tf.train.Feature(
122126
float_list=tf.train.FloatList(value=[data]))
127+
elif isinstance(data, int):
128+
feature_dict[key] = tf.train.Feature(
129+
int64_list=tf.train.Int64List(value=[data]))
123130
return tf.train.Example(
124131
features=tf.train.Features(
125132
feature=feature_dict
@@ -181,4 +188,9 @@ def main(project, model, version=None, force_tfrecord=False):
181188
default=False
182189
)
183190
args = parser.parse_args()
184-
main(**args.__dict__)
191+
main(
192+
args.project,
193+
args.model,
194+
version=args.version,
195+
force_tfrecord=args.force_tfrecord
196+
)

ml_engine/online_prediction/predict_test.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
99
# License for the specific language governing permissions and limitations under
1010
# the License.
11+
1112
"""Tests for predict.py ."""
13+
1214
import base64
1315

1416
import pytest
1517

16-
from predict import census_to_example_bytes, predict_json
18+
import predict
1719

1820

1921
MODEL = 'census'
@@ -43,26 +45,21 @@
4345

4446

4547
def test_predict_json():
46-
result = predict_json(PROJECT, MODEL, [JSON, JSON], version=VERSION)
48+
result = predict.predict_json(
49+
PROJECT, MODEL, [JSON, JSON], version=VERSION)
4750
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
4851

4952

5053
def test_predict_json_error():
5154
with pytest.raises(RuntimeError):
52-
predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)
55+
predict.predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)
5356

5457

55-
# TODO(elibixby) Run on Travis when TensorFlow PyPi package supports
56-
# Ubuntu 12.04 See:
57-
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/get_started/os_setup.md#import-error
5858
@pytest.mark.slow
5959
def test_census_example_to_bytes():
60-
b = census_to_example_bytes(JSON)
60+
b = predict.census_to_example_bytes(JSON)
6161
assert base64.b64encode(b) is not None
6262

6363

6464
def test_predict_tfrecord():
65-
# Using the same model for TFRecords and
66-
# JSON is currently broken.
67-
# TODO(elibixby) when b/35742966 is fixed add
6865
pass
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
tensorflow>=1.0.0
1+
tensorflow==1.0.0

0 commit comments

Comments
 (0)