Skip to content

Commit aad9437

Browse files
committed
visualizing output in autoencoder
1 parent 4193605 commit aad9437

1 file changed

Lines changed: 25 additions & 2 deletions

File tree

06_autoencoder.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
11
import tensorflow as tf
22
import numpy as np
33
import input_data
4+
import matplotlib.pyplot as plt
5+
import matplotlib.gridspec as gridspec
6+
import pdb
7+
8+
## Visualizing reconstructions
9+
def vis(images, save_name):
10+
dim = images.shape[0]
11+
n_image_rows = int(np.ceil(np.sqrt(dim)))
12+
n_image_cols = int(np.ceil(dim * 1.0/n_image_rows))
13+
gs = gridspec.GridSpec(n_image_rows,n_image_cols,top=1., bottom=0., right=1., left=0., hspace=0., wspace=0.)
14+
for g,count in zip(gs,range(int(dim))):
15+
ax = plt.subplot(g)
16+
ax.imshow(images[count,:].reshape((28,28)))
17+
ax.set_xticks([])
18+
ax.set_yticks([])
19+
plt.savefig(save_name + '_vis.png')
420

521
mnist_width = 28
622
n_visible = mnist_width * mnist_width
@@ -39,7 +55,7 @@ def model(X, mask, W, b, W_prime, b_prime):
3955
# create cost function
4056
cost = tf.reduce_sum(tf.pow(X - Z, 2)) # minimize squared error
4157
train_op = tf.train.GradientDescentOptimizer(0.02).minimize(cost) # construct an optimizer
42-
58+
predict_op = Z
4359
# load MNIST data
4460
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
4561
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
@@ -57,4 +73,11 @@ def model(X, mask, W, b, W_prime, b_prime):
5773

5874
mask_np = np.random.binomial(1, 1 - corruption_level, teX.shape)
5975
print(i, sess.run(cost, feed_dict={X: teX, mask: mask_np}))
60-
76+
# save the predictions for 100 images
77+
mask_np = np.random.binomial(1, 1 - corruption_level, teX[:100].shape)
78+
predicted_imgs = sess.run(predict_op, feed_dict={X: teX[:100], mask: mask_np})
79+
input_imgs = teX[:100]
80+
# plot the reconstructed images
81+
vis(predicted_imgs,'pred')
82+
vis(input_imgs,'in')
83+
print('Done')

0 commit comments

Comments
 (0)