Skip to content

Commit 1272ab0

Browse files
committed
Inception 예제 더 간략하게 변경
1 parent 0e40383 commit 1272ab0

1 file changed

Lines changed: 22 additions & 38 deletions

File tree

09 - Inception/predict.py

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,45 +11,29 @@
1111
FLAGS = tf.app.flags.FLAGS
1212

1313

14-
class Inception:
15-
16-
def __init__(self, graph_path, label_path):
17-
self.labels = [line.rstrip() for line in tf.gfile.GFile(label_path)]
18-
19-
with tf.gfile.FastGFile(graph_path, 'rb') as fp:
20-
graph_def = tf.GraphDef()
21-
graph_def.ParseFromString(fp.read())
22-
tf.import_graph_def(graph_def, name='')
23-
24-
self.sess = tf.Session()
25-
self.logits = self.sess.graph.get_tensor_by_name('final_result:0')
26-
27-
def predict(self, image_path, top=5):
28-
image = tf.gfile.FastGFile(image_path, 'rb').read()
29-
30-
prediction = self.sess.run(self.logits, {'DecodeJpeg/contents:0': image})
31-
32-
top_indices = prediction[0].argsort()[::-1][:top]
33-
34-
for i in top_indices:
35-
name = self.labels[i]
36-
score = prediction[0][i]
37-
print('%s (%.2f%%)' % (name, score * 100))
38-
39-
4014
def main(_):
41-
if len(sys.argv) < 2:
42-
print('Usage: predict.py image_path')
43-
44-
else:
45-
inception = Inception(FLAGS.output_graph, FLAGS.output_labels)
46-
47-
inception.predict(sys.argv[1])
48-
49-
if FLAGS.show_image:
50-
img = mpimg.imread(sys.argv[1])
51-
plt.imshow(img)
52-
plt.show()
15+
labels = [line.rstrip() for line in tf.gfile.GFile(FLAGS.output_labels)]
16+
17+
with tf.gfile.FastGFile(FLAGS.output_graph, 'rb') as fp:
18+
graph_def = tf.GraphDef()
19+
graph_def.ParseFromString(fp.read())
20+
tf.import_graph_def(graph_def, name='')
21+
22+
with tf.Session() as sess:
23+
logits = sess.graph.get_tensor_by_name('final_result:0')
24+
image = tf.gfile.FastGFile(sys.argv[1], 'rb').read()
25+
prediction = sess.run(logits, {'DecodeJpeg/contents:0': image})
26+
top_results = prediction[0].argsort()[::-1][:5]
27+
28+
for i in top_results:
29+
name = labels[i]
30+
score = prediction[0][i]
31+
print('%s (%.2f%%)' % (name, score * 100))
32+
33+
if FLAGS.show_image:
34+
img = mpimg.imread(sys.argv[1])
35+
plt.imshow(img)
36+
plt.show()
5337

5438

5539
if __name__ == "__main__":

0 commit comments

Comments
 (0)