1- import argparse
2- import json
1+ # Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache
2+ # License, Version 2.0 (the "License"); you may not use this file except in
3+ # compliance with the License. You may obtain a copy of the License at
4+ # http://www.apache.org/licenses/LICENSE-2.0
5+
6+ # Unless required by applicable law or agreed to in writing, software
7+ # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
8+ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
9+ # License for the specific language governing permissions and limitations under
10+ # the License.
11+ """Examples of using the Cloud ML Engine's online prediction service."""
12+
313# [START import_libraries]
414import googleapiclient .discovery
515# [END import_libraries]
616
17+
718# [START authenticating]
def get_ml_engine_service():
    """Build a client for the Cloud ML Engine API.

    Returns:
      The object returned by googleapiclient.discovery.build for the 'ml'
      API at version 'v1beta1'.
    """
    # The previous implementation loaded a local 'staging_ml.json' discovery
    # document (and never closed the file); the client is now built directly
    # from the API name and version.
    return googleapiclient.discovery.build('ml', 'v1beta1')
1121# [END authenticating]
1222
23+
1324# [START predict_json]
1425def predict_json (project , model , instances , version = None ):
1526 """Send data instances to a deployed model for prediction
1627 Args:
17- project: str, project where the Cloud ML Engine Model is deployed
18- model: str, model name
28+ project: str, project where the Cloud ML Engine Model is deployed.
29+ model: str, model name.
1930 instances: [dict], dictionaries from string keys defined by the model
2031 to data.
2132 version: [optional] str, version of the model to target.
@@ -26,10 +37,10 @@ def predict_json(project, model, instances, version=None):
2637 name = 'projects/{}/models/{}' .format (project , model )
2738 if version is not None :
2839 name += '/versions/{}' .format (version )
29-
40+
3041 response = service .projects ().predict (
3142 name = name ,
32- body = {" instances" : instances }
43+ body = {' instances' : instances }
3344 ).execute ()
3445
3546 if 'error' in response :
@@ -38,15 +49,19 @@ def predict_json(project, model, instances, version=None):
3849 return response ['predictions' ]
3950# [END predict_json]
4051
52+
4153# [START predict_tf_records]
42- def predict_tf_records (project , model , example_bytes_list , key = 'tfrecord' , version = None ):
54+ def predict_tf_records (project ,
55+ model ,
56+ example_bytes_list ,
57+ key = 'tfrecord' ,
58+ version = None ):
4359 """Send data instances to a deployed model for prediction
4460 Args:
4561 project: str, project where the Cloud ML Engine Model is deployed
46- model: str, model name
47- feature_dict_list: [dict], dictionaries from string keys to
48- tf.train.Feature protos.
49- version: [optional] str, version of the model to target.
62+ model: str, model name.
63+ example_bytes_list: [str], Serialized tf.train.Example protos.
64+ version: str, version of the model to target.
5065 Returns:
5166 A dictionary of prediction results defined by the model.
5267 """
@@ -58,7 +73,7 @@ def predict_tf_records(project, model, example_bytes_list, key='tfrecord', versi
5873
5974 response = service .projects ().predict (
6075 name = name ,
61- body = {" instances" : [
76+ body = {' instances' : [
6277 {key : {'b64' : base64 .b64encode (example_bytes )}}
6378 for example_bytes in example_bytes_list
6479 ]}
@@ -67,8 +82,18 @@ def predict_tf_records(project, model, example_bytes_list, key='tfrecord', versi
6782 raise RuntimeError (response ['error' ])
6883
6984 return response ['predictions' ]
85+ # [END predict_tf_records]
86+
7087
88+ # [START census_to_example_bytes]
7189def census_to_example_bytes (json_instance ):
90+ """Serialize a JSON example to the bytes of a tf.train.Example.
91+ This method is specific to the signature of the Census example.
92+ Args:
93+ json_instance: dict, representing data to be serialized.
94+ Returns:
95+ A string (as a container for bytes).
96+ """
7297 import tensorflow as tf
7398 feature_dict = {}
7499 for key , data in json_instance .iteritems ():
@@ -83,18 +108,82 @@ def census_to_example_bytes(json_instance):
83108 feature = feature_dict
84109 )
85110 ).SerializeToString ()
86- # [END predict_tf_records ]
111+ # [END census_to_example_bytes ]
87112
88- if __name__ == '__main__' :
89- import sys
90- import base64
113+
114+ # [START predict_from_files]
def predict_from_files(project,
                       model,
                       files,
                       version=None,
                       force_tfrecord=False):
    """Send newline-delimited JSON instances from files for online prediction.

    Instances are read lazily, grouped into batches of at most 100 and each
    batch is sent as one prediction request.

    Args:
      project: str, project where the Cloud ML Engine Model is deployed.
      model: str, model name.
      files: iterable of open file objects or file paths; every non-blank
        line must hold one JSON-encoded instance.
      version: [optional] str, version of the model to target.
      force_tfrecord: bool, if True each instance is serialized with
        census_to_example_bytes and sent through predict_tf_records instead
        of being sent as raw JSON through predict_json.
    Returns:
      A list with one entry per request, each entry being the value returned
      by predict_json / predict_tf_records for that batch.
    """
    import itertools
    import json

    # Requests to online prediction
    # can have at most 100 instances
    max_instances_per_request = 100

    def iter_lines(source):
        # Accept already-open file objects as well as paths, so callers that
        # only have file names (e.g. the command line) work too.
        if hasattr(source, 'readline'):
            for line in source:
                yield line
        else:
            with open(source) as handle:
                for line in handle:
                    yield line

    instances = (
        json.loads(line)
        for source in files
        for line in iter_lines(source)
        if line.strip()  # tolerate blank lines between instances
    )

    results = []
    while True:
        # islice keeps the final, shorter batch; zipping the same iterator
        # 100 times would silently discard it.
        batch = list(itertools.islice(instances, max_instances_per_request))
        if not batch:
            break
        if force_tfrecord:
            example_bytes_list = [
                census_to_example_bytes(instance)
                for instance in batch
            ]
            results.append(predict_tf_records(
                project,
                model,
                example_bytes_list,
                version=version
            ))
        else:
            results.append(predict_json(
                project,
                model,
                batch,
                version=version
            ))
    return results
152+ # [END predict_from_files]
153+
154+
if __name__ == '__main__':
    # Command-line entry point: parse the target model and the input file
    # paths, then send every instance found in those files for prediction.
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'input_files',
        help='File paths with examples to predict',
        nargs='+',
        type=os.path.abspath
    )
    parser.add_argument(
        '--project',
        help='Project in which the model is deployed',
        type=str,
        required=True
    )
    parser.add_argument(
        '--model',
        help='Model name',
        type=str,
        required=True
    )
    parser.add_argument(
        '--version',
        help='Name of the version.',
        type=str
    )
    parser.add_argument(
        '--force-tfrecord',
        help='Send predictions as TFRecords rather than raw JSON',
        action='store_true',
        default=False
    )
    args = parser.parse_args()

    # argparse stores the positional under `input_files`, while
    # predict_from_files names its parameter `files`; splatting
    # args.__dict__ would raise TypeError, so pass arguments explicitly.
    # The paths are opened here because predict_from_files reads lines from
    # file objects.
    input_handles = []
    try:
        for path in args.input_files:
            input_handles.append(open(path))
        # NOTE(review): the returned predictions are currently discarded, as
        # in the original script -- confirm whether they should be printed.
        predict_from_files(
            args.project,
            args.model,
            input_handles,
            version=args.version,
            force_tfrecord=args.force_tfrecord
        )
    finally:
        for handle in input_handles:
            handle.close()
0 commit comments