11import tensorflow as tf
22import numpy as np
33import input_data
4+ import matplotlib .pyplot as plt
5+ import matplotlib .gridspec as gridspec
6+ import pdb
7+
8+ ## Visualizing reconstructions
9+ def vis (images , save_name ):
10+ dim = images .shape [0 ]
11+ n_image_rows = int (np .ceil (np .sqrt (dim )))
12+ n_image_cols = int (np .ceil (dim * 1.0 / n_image_rows ))
13+ gs = gridspec .GridSpec (n_image_rows ,n_image_cols ,top = 1. , bottom = 0. , right = 1. , left = 0. , hspace = 0. , wspace = 0. )
14+ for g ,count in zip (gs ,range (int (dim ))):
15+ ax = plt .subplot (g )
16+ ax .imshow (images [count ,:].reshape ((28 ,28 )))
17+ ax .set_xticks ([])
18+ ax .set_yticks ([])
19+ plt .savefig (save_name + '_vis.png' )
420
521mnist_width = 28
622n_visible = mnist_width * mnist_width
@@ -39,7 +55,7 @@ def model(X, mask, W, b, W_prime, b_prime):
3955# create cost function
4056cost = tf .reduce_sum (tf .pow (X - Z , 2 )) # minimize squared error
4157train_op = tf .train .GradientDescentOptimizer (0.02 ).minimize (cost ) # construct an optimizer
42-
58+ predict_op = Z
4359# load MNIST data
4460mnist = input_data .read_data_sets ("MNIST_data/" , one_hot = True )
4561trX , trY , teX , teY = mnist .train .images , mnist .train .labels , mnist .test .images , mnist .test .labels
@@ -57,4 +73,11 @@ def model(X, mask, W, b, W_prime, b_prime):
5773
5874 mask_np = np .random .binomial (1 , 1 - corruption_level , teX .shape )
5975 print (i , sess .run (cost , feed_dict = {X : teX , mask : mask_np }))
60-
76+ # save the predictions for 100 images
77+ mask_np = np .random .binomial (1 , 1 - corruption_level , teX [:100 ].shape )
78+ predicted_imgs = sess .run (predict_op , feed_dict = {X : teX [:100 ], mask : mask_np })
79+ input_imgs = teX [:100 ]
80+ # plot the reconstructed images
81+ vis (predicted_imgs ,'pred' )
82+ vis (input_imgs ,'in' )
83+ print ('Done' )
0 commit comments