1- # Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache
1+ #!/bin/python
2+ # Copyright 2017 Google Inc. All Rights Reserved. Licensed under the Apache
23# License, Version 2.0 (the "License"); you may not use this file except in
34# compliance with the License. You may obtain a copy of the License at
45# http://www.apache.org/licenses/LICENSE-2.0
89# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
910# License for the specific language governing permissions and limitations under
1011# the License.
12+
1113"""Examples of using the Cloud ML Engine's online prediction service."""
1214from __future__ import print_function
1315# [START import_libraries]
1416import googleapiclient .discovery
1517# [END import_libraries]
1618
1719
18- # [START authenticating]
19- def get_ml_engine_service ():
20- """Create the ML Engine service object.
21- To authenticate set the environment variable
22- GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
23- """
24- return googleapiclient .discovery .build ('ml' , 'v1beta1' )
25- # [END authenticating]
26-
27-
2820# [START predict_json]
2921def predict_json (project , model , instances , version = None ):
30- """Send data instances to a deployed model for prediction
22+ """Send json data to a deployed model for prediction.
23+
3124 Args:
32- project: str, project where the Cloud ML Engine Model is deployed.
33- model: str, model name.
34- instances: [dict], dictionaries from string keys defined by the model
35- to data.
36- version: [optional] str, version of the model to target.
25+ project (str): project where the Cloud ML Engine Model is deployed.
26+ model (str): model name.
27+ instances ([Mapping[str: any]]): dictionaries from string keys
28+ defined by the model deployment, to data with types that match
29+ expected tensors
30+ version: str, version of the model to target.
3731 Returns:
38- A dictionary of prediction results defined by the model.
32+ Mapping[str: any]: dictionary of prediction results defined by the
33+ model.
3934 """
40- service = get_ml_engine_service ()
35+ # Create the ML Engine service object.
36+ # To authenticate set the environment variable
37+ # GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
38+ service = googleapiclient .discovery .build ('ml' , 'v1beta1' )
4139 name = 'projects/{}/models/{}' .format (project , model )
40+
4241 if version is not None :
4342 name += '/versions/{}' .format (version )
4443
@@ -58,30 +57,36 @@ def predict_json(project, model, instances, version=None):
5857def predict_tf_records (project ,
5958 model ,
6059 example_bytes_list ,
61- key = 'tfrecord' ,
6260 version = None ):
63- """Send data instances to a deployed model for prediction
61+ """Send protocol buffer data to a deployed model for prediction.
62+
6463 Args:
65- project: str, project where the Cloud ML Engine Model is deployed
66- model: str, model name.
67- example_bytes_list: [str], Serialized tf.train.Example protos.
64+ project (str): project where the Cloud ML Engine Model is deployed.
65+ model (str): model name.
66+ example_bytes_list ([str]): A list of bytestrings representing
67+ serialized tf.train.Example protocol buffers. The contents of this
68+ protocol buffer will change depending on the signature of your
69+ deployed model.
6870 version: str, version of the model to target.
6971 Returns:
70- A dictionary of prediction results defined by the model.
72+ Mapping[str: any]: dictionary of prediction results defined by the
73+ model.
7174 """
7275 import base64
73- service = get_ml_engine_service ( )
76+ service = googleapiclient . discovery . build ( 'ml' , 'v1beta1' )
7477 name = 'projects/{}/models/{}' .format (project , model )
78+
7579 if version is not None :
7680 name += '/versions/{}' .format (version )
7781
7882 response = service .projects ().predict (
7983 name = name ,
8084 body = {'instances' : [
81- {key : { 'b64' : base64 .b64encode (example_bytes )} }
85+ {'b64' : base64 .b64encode (example_bytes )}
8286 for example_bytes in example_bytes_list
8387 ]}
8488 ).execute ()
89+
8590 if 'error' in response :
8691 raise RuntimeError (response ['error' ])
8792
@@ -93,10 +98,14 @@ def predict_tf_records(project,
9398def census_to_example_bytes (json_instance ):
9499 """Serialize a JSON example to the bytes of a tf.train.Example.
95100 This method is specific to the signature of the Census example.
101+ See: https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
102+ for details.
103+
96104 Args:
97- json_instance: dict, representing data to be serialized.
105+ json_instance (Mapping[str: any]): representing data to be serialized.
98106 Returns:
99- A string (as a container for bytes).
107+ str: A string as a container for the serialized bytes of
108+ tf.train.Example protocol buffer.
100109 """
101110 import tensorflow as tf
102111 feature_dict = {}
0 commit comments