99# License for the specific language governing permissions and limitations under
1010# the License.
1111"""Examples of using the Cloud ML Engine's online prediction service."""
12-
1312# [START import_libraries]
1413import googleapiclient .discovery
1514# [END import_libraries]
@@ -111,57 +110,23 @@ def census_to_example_bytes(json_instance):
111110# [END census_to_example_bytes]
112111
113112
114- # [START predict_from_files]
115- def predict_from_files (project ,
116- model ,
117- files ,
118- version = None ,
119- force_tfrecord = False ):
113+ def main (project , model , version = None , force_tfrecord = False ):
120114 import json
121- import itertools
122- instances = (json .loads (line )
123- for f in files
124- for line in f .readlines ())
125-
126- # Requests to online prediction
127- # can have at most 100 instances
128- args = [instances ] * 100
129- instance_batches = itertools .izip (* args )
130-
131- results = []
132- for batch in instance_batches :
115+ while True :
116+ user_input = json .loads (raw_input ())
133117 if force_tfrecord :
134- example_bytes_list = [
135- census_to_example_bytes (instance )
136- for instance in batch
137- ]
138- results .append (predict_tf_records (
139- project ,
140- model ,
141- example_bytes_list ,
142- version = version
143- ))
118+ example_bytes = census_to_example_bytes (user_input )
119+ result = predict_tf_records (
120+ project , model , [example_bytes ], version = version )
144121 else :
145- results .append (predict_json (
146- project ,
147- model ,
148- batch ,
149- version = version
150- ))
151- return results
152- # [END predict_from_files]
122+ result = predict_json (
123+ project , model , [user_input ], version = version )
124+ print (result )
153125
154126
155127if __name__ == '__main__' :
156128 import argparse
157- import os
158129 parser = argparse .ArgumentParser ()
159- parser .add_argument (
160- 'input_files' ,
161- help = 'File paths with examples to predict' ,
162- nargs = '+' ,
163- type = os .path .abspath
164- )
165130 parser .add_argument (
166131 '--project' ,
167132 help = 'Project in which the model is deployed' ,
@@ -186,4 +151,4 @@ def predict_from_files(project,
186151 default = False
187152 )
188153 args = parser .parse_args ()
189- predict_from_files (** args .__dict__ )
154+ main (** args .__dict__ )
0 commit comments