|
9 | 9 | # License for the specific language governing permissions and limitations under |
10 | 10 | # the License. |
11 | 11 | """Examples of using the Cloud ML Engine's online prediction service.""" |
| 12 | +from __future__ import print_function |
12 | 13 | # [START import_libraries] |
13 | 14 | import googleapiclient.discovery |
14 | 15 | # [END import_libraries] |
@@ -113,15 +114,28 @@ def census_to_example_bytes(json_instance): |
113 | 114 | def main(project, model, version=None, force_tfrecord=False): |
114 | 115 | import json |
115 | 116 | while True: |
116 | | - user_input = json.loads(raw_input()) |
117 | | - if force_tfrecord: |
118 | | - example_bytes = census_to_example_bytes(user_input) |
119 | | - result = predict_tf_records( |
120 | | - project, model, [example_bytes], version=version) |
| 117 | + try: |
| 118 | + user_input = json.loads(raw_input("Valid JSON >>>")) |
| 119 | + except KeyboardInterrupt: |
| 120 | + return |
| 121 | + |
| 122 | + if not isinstance(user_input, list): |
| 123 | + user_input = [user_input] |
| 124 | + try: |
| 125 | + if force_tfrecord: |
| 126 | + example_bytes_list = [ |
| 127 | + census_to_example_bytes(e) |
| 128 | + for e in user_input |
| 129 | + ] |
| 130 | + result = predict_tf_records( |
| 131 | + project, model, example_bytes_list, version=version) |
| 132 | + else: |
| 133 | + result = predict_json( |
| 134 | + project, model, user_input, version=version) |
| 135 | + except RuntimeError as err: |
| 136 | + print(str(err)) |
121 | 137 | else: |
122 | | - result = predict_json( |
123 | | - project, model, [user_input], version=version) |
124 | | - print(result) |
| 138 | + print(result) |
125 | 139 |
|
126 | 140 |
|
127 | 141 | if __name__ == '__main__': |
|
0 commit comments