Skip to content

Commit 7b6b66e

Browse files
author
Daniel McDuff
committed
DL updates.
1 parent ba6b961 commit 7b6b66e

3 files changed

Lines changed: 21 additions & 5 deletions

File tree

code/DBN_face.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def test_DBN(finetune_lr=0.1, pretraining_epochs=3,
342342
if not os.path.isdir(output_dir+'/dbn_plots'):
343343
os.makedirs(output_dir+'/dbn_plots')
344344

345-
# compute number of minibatches for training, validation and testing
345+
# compute number of minibatches for pretraining:
346346
n_train_batches = pre_train_set_x.get_value(borrow=True).shape[0] / batch_size
347347

348348
# numpy random generator
@@ -409,6 +409,9 @@ def test_DBN(finetune_lr=0.1, pretraining_epochs=3,
409409
# FINETUNING THE MODEL #
410410
########################
411411

412+
# compute number of minibatches for training, validation and testing:
413+
n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size
414+
412415
# get the training, validation and testing function for the model
413416
print '... getting the finetuning functions'
414417
train_fn, validate_model, test_model = dbn.build_finetune_functions(

code/load_faces.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import theano.tensor as T
1717
from theano.tensor.shared_randomstreams import RandomStreams
1818
import pandas
19+
import matplotlib.pyplot as plt
20+
import matplotlib.cm as cm
1921

2022
def logistic_transform(A, mu, sigma):
2123
A[numpy.where(A == 0)] = 0.1
@@ -63,14 +65,21 @@ def import_data(label, data_dir, image_dim):
6365
#if neutral.iloc[0]:
6466
if test_target>=-1:
6567
test_image = numpy.array(scipy.misc.imread(f))
68+
6669
if (len(test_image.flatten())!=(image_dim*image_dim)):
6770
continue
6871
#for i, row in enumerate(test_targets.iloc[0].values):
6972
# print i + str(test_targets.iloc[0][i])
7073

74+
#test_image2 = test_image.astype(float)
75+
#temp = logistic_transform(test_image2.flatten(), 140, 0.05)
76+
#plt.imshow(temp.reshape(image_dim,image_dim), cmap = cm.Greys_r)
77+
#plt.show()
78+
#plt.show(block=False)
79+
7180
if test_target > 50:
7281
test_image2 = test_image.astype(float)
73-
temp = logistic_transform(test_image2.flatten(), 120, 0.1)
82+
temp = logistic_transform(test_image2.flatten(), 140, 0.05)
7483
if numpy.isnan(temp).any():
7584
print "NaN found :("
7685
continue
@@ -79,7 +88,7 @@ def import_data(label, data_dir, image_dim):
7988
target = numpy.append(target, [1], axis=0)
8089
elif test_target == 0:
8190
test_image2 = test_image.astype(float)
82-
temp = logistic_transform(test_image2.flatten(), 120, 0.1)
91+
temp = logistic_transform(test_image2.flatten(), 140, 0.05)
8392
if numpy.isnan(temp).any():
8493
print "NaN found :("
8594
continue
@@ -88,7 +97,7 @@ def import_data(label, data_dir, image_dim):
8897
target = numpy.append(target, [0], axis=0)
8998
elif test_target == -1:
9099
test_image2 = test_image.astype(float)
91-
temp = logistic_transform(test_image2.flatten(), 120, 0.1)
100+
temp = logistic_transform(test_image2.flatten(), 140, 0.05)
92101
if numpy.isnan(temp).any():
93102
print "NaN found :("
94103
continue

code/rbm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,16 @@ def sample_h_given_v(self, v0_sample):
151151
# the visibles
152152
pre_sigmoid_h1, h1_mean = self.propup(v0_sample)
153153

154+
155+
# LARGER mu IS MORE SPARSE.
156+
mu = 0.000001 # mu = 0.01 is probably too small.
157+
# LOOKED AT THE CODE HERE: http://lrn2cre8.ofai.at/lrn2/doc/_modules/lrn2/models/srbm_goh.html#SRBM_Goh
154158
## DAN ADDED:#########################
155159
rank_0 = ((h1_mean.argsort(axis=0)).argsort(axis=0).astype(theano.config.floatX) + 1.)/T.shape(h1_mean)[0].astype(theano.config.floatX)
156160

157161
rank_1 = ((h1_mean.argsort(axis=1)).argsort(axis=1).astype(theano.config.floatX) + 1.)/T.shape(h1_mean)[1].astype(theano.config.floatX)
158162

159-
h1_mean = (1.-0.5)*(rank_0**((1./0.05)-1.))+0.5*(rank_1**((1./0.05)-1.))
163+
h1_mean = (1.-0.5)*(rank_0**((1./mu)-1.))+0.5*(rank_1**((1./mu)-1.))
160164

161165
#pre_sigmoid_h1_bin = T.log(h1_mean) - T.log(1. - h1_mean)
162166
#pre_sigmoid_h1 = pre_sigmoid_h1_bin

0 commit comments

Comments
 (0)