|
11 | 11 | FLAGS = tf.app.flags.FLAGS |
12 | 12 |
|
13 | 13 |
|
14 | | -class Inception: |
15 | | - |
16 | | - def __init__(self, graph_path, label_path): |
17 | | - self.labels = [line.rstrip() for line in tf.gfile.GFile(label_path)] |
18 | | - |
19 | | - with tf.gfile.FastGFile(graph_path, 'rb') as fp: |
20 | | - graph_def = tf.GraphDef() |
21 | | - graph_def.ParseFromString(fp.read()) |
22 | | - tf.import_graph_def(graph_def, name='') |
23 | | - |
24 | | - self.sess = tf.Session() |
25 | | - self.logits = self.sess.graph.get_tensor_by_name('final_result:0') |
26 | | - |
27 | | - def predict(self, image_path, top=5): |
28 | | - image = tf.gfile.FastGFile(image_path, 'rb').read() |
29 | | - |
30 | | - prediction = self.sess.run(self.logits, {'DecodeJpeg/contents:0': image}) |
31 | | - |
32 | | - top_indices = prediction[0].argsort()[::-1][:top] |
33 | | - |
34 | | - for i in top_indices: |
35 | | - name = self.labels[i] |
36 | | - score = prediction[0][i] |
37 | | - print('%s (%.2f%%)' % (name, score * 100)) |
38 | | - |
39 | | - |
40 | 14 | def main(_): |
41 | | - if len(sys.argv) < 2: |
42 | | - print('Usage: predict.py image_path') |
43 | | - |
44 | | - else: |
45 | | - inception = Inception(FLAGS.output_graph, FLAGS.output_labels) |
46 | | - |
47 | | - inception.predict(sys.argv[1]) |
48 | | - |
49 | | - if FLAGS.show_image: |
50 | | - img = mpimg.imread(sys.argv[1]) |
51 | | - plt.imshow(img) |
52 | | - plt.show() |
| 15 | + labels = [line.rstrip() for line in tf.gfile.GFile(FLAGS.output_labels)] |
| 16 | + |
| 17 | + with tf.gfile.FastGFile(FLAGS.output_graph, 'rb') as fp: |
| 18 | + graph_def = tf.GraphDef() |
| 19 | + graph_def.ParseFromString(fp.read()) |
| 20 | + tf.import_graph_def(graph_def, name='') |
| 21 | + |
| 22 | + with tf.Session() as sess: |
| 23 | + logits = sess.graph.get_tensor_by_name('final_result:0') |
| 24 | + image = tf.gfile.FastGFile(sys.argv[1], 'rb').read() |
| 25 | + prediction = sess.run(logits, {'DecodeJpeg/contents:0': image}) |
| 26 | + top_results = prediction[0].argsort()[::-1][:5] |
| 27 | + |
| 28 | + for i in top_results: |
| 29 | + name = labels[i] |
| 30 | + score = prediction[0][i] |
| 31 | + print('%s (%.2f%%)' % (name, score * 100)) |
| 32 | + |
| 33 | + if FLAGS.show_image: |
| 34 | + img = mpimg.imread(sys.argv[1]) |
| 35 | + plt.imshow(img) |
| 36 | + plt.show() |
53 | 37 |
|
54 | 38 |
|
55 | 39 | if __name__ == "__main__": |
|
0 commit comments