Skip to content

Commit 0a8a86d

Browse files
committed
pep8
1 parent 47b2ac4 commit 0a8a86d

1 file changed

Lines changed: 45 additions & 32 deletions

File tree

code/mcrbm/test_mcrbm.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,34 @@
11
import cPickle
22

33
import numpy
4+
45
from pylearn.algorithms.mcRBM import mcRBM, mcRBMTrainer
6+
from pylearn.dataset_ops import image_patches
57
import pylearn.datasets.cifar10
8+
69
import theano
710
from theano import tensor
811

912

1013
def l2(X):
11-
return numpy.sqrt((X**2).sum())
14+
return numpy.sqrt((X ** 2).sum())
15+
1216

1317
def _default_rbm_alloc(n_I, n_K=256, n_J=100):
1418
return mcRBM.alloc(n_I, n_K, n_J)
1519

20+
1621
def _default_trainer_alloc(rbm, train_batch, batchsize, initial_lr_per_example,
1722
l1_penalty, l1_penalty_start, persistent_chains):
18-
return mcRBMTrainer.alloc(rbm, train_batch, batchsize, l1_penalty=l1_penalty,
19-
l1_penalty_start=l1_penalty_start,persistent_chains=persistent_chains)
23+
return mcRBMTrainer.alloc(rbm, train_batch, batchsize,
24+
l1_penalty=l1_penalty,
25+
l1_penalty_start=l1_penalty_start,
26+
persistent_chains=persistent_chains)
2027

21-
def test_reproduce_ranzato_hinton_2010(dataset='MAR',
28+
29+
def test_reproduce_ranzato_hinton_2010(dataset='MAR',
2230
n_train_iters=5000,
23-
rbm_alloc=_default_rbm_alloc,
31+
rbm_alloc=_default_rbm_alloc,
2432
trainer_alloc=_default_trainer_alloc,
2533
lr_per_example=.075,
2634
l1_penalty=1e-3,
@@ -30,36 +38,37 @@ def test_reproduce_ranzato_hinton_2010(dataset='MAR',
3038

3139
batchsize = 128
3240
## specific to MAR dataset ##
33-
n_vis=105
34-
n_patches=10240
35-
epoch_size=n_patches
41+
n_vis = 105
42+
n_patches = 10240
43+
epoch_size = n_patches
3644

37-
tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
45+
tile = image_patches.save_filters_of_ranzato_hinton_2010
3846

3947
batch_idx = tensor.iscalar()
40-
batch_range =batch_idx * batchsize + numpy.arange(batchsize)
48+
batch_range = batch_idx * batchsize + numpy.arange(batchsize)
4149

42-
train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_range)
50+
train_batch = image_patches.ranzato_hinton_2010_op(batch_range)
4351

4452
imgs_fn = theano.function([batch_idx], outputs=train_batch)
4553

4654
trainer = trainer_alloc(
4755
rbm_alloc(n_I=n_vis),
4856
train_batch,
49-
batchsize,
57+
batchsize,
5058
initial_lr_per_example=lr_per_example,
5159
l1_penalty=l1_penalty,
5260
l1_penalty_start=l1_penalty_start,
5361
persistent_chains=persistent_chains)
54-
rbm=trainer.rbm
62+
rbm = trainer.rbm
5563

5664
if persistent_chains:
5765
grads = trainer.contrastive_grads()
5866
learn_fn = theano.function([batch_idx],
5967
outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
6068
updates=trainer.cd_updates())
6169
else:
62-
learn_fn = theano.function([batch_idx], outputs=[], updates=trainer.cd_updates())
70+
learn_fn = theano.function([batch_idx], outputs=[],
71+
updates=trainer.cd_updates())
6372

6473
if persistent_chains:
6574
smplr = trainer.sampler
@@ -70,46 +79,51 @@ def test_reproduce_ranzato_hinton_2010(dataset='MAR',
7079
cPickle.dump(
7180
pylearn.dataset_ops.cifar10.random_cifar_patches_pca(
7281
n_vis, None, 'float32', n_patches, R, C,),
73-
open('test_mcRBM.pca.pkl','w'))
82+
open('test_mcRBM.pca.pkl', 'w'))
7483

7584
print "Learning..."
7685
last_epoch = -1
7786
for jj in xrange(n_train_iters):
78-
epoch = jj*batchsize / epoch_size
87+
epoch = jj * batchsize / epoch_size
7988

8089
print_jj = epoch != last_epoch
8190
last_epoch = epoch
8291

8392
if print_jj:
84-
tile(imgs_fn(jj), "imgs_%06i.png"%jj)
93+
tile(imgs_fn(jj), "imgs_%06i.png" % jj)
8594
if persistent_chains:
86-
tile(smplr.positions.get_value(borrow=True), "sample_%06i.png"%jj)
87-
tile(rbm.U.get_value(borrow=True).T, "U_%06i.png"%jj)
88-
tile(rbm.W.get_value(borrow=True).T, "W_%06i.png"%jj)
95+
tile(smplr.positions.get_value(borrow=True),
96+
"sample_%06i.png" % jj)
97+
tile(rbm.U.get_value(borrow=True).T, "U_%06i.png" % jj)
98+
tile(rbm.W.get_value(borrow=True).T, "W_%06i.png" % jj)
8999

90-
print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize)
100+
print 'saving samples', jj, 'epoch', jj / (epoch_size / batchsize)
91101

92102
print 'l2(U)', l2(rbm.U.get_value(borrow=True)),
93103
print 'l2(W)', l2(rbm.W.get_value(borrow=True)),
94-
print 'l1_penalty',
104+
print 'l1_penalty',
95105
try:
96106
print trainer.effective_l1_penalty.get_value()
97107
except:
98108
print trainer.effective_l1_penalty
99109

100-
print 'U min max', rbm.U.get_value(borrow=True).min(), rbm.U.get_value(borrow=True).max(),
101-
print 'W min max', rbm.W.get_value(borrow=True).min(), rbm.W.get_value(borrow=True).max(),
102-
print 'a min max', rbm.a.get_value(borrow=True).min(), rbm.a.get_value(borrow=True).max(),
103-
print 'b min max', rbm.b.get_value(borrow=True).min(), rbm.b.get_value(borrow=True).max(),
104-
print 'c min max', rbm.c.get_value(borrow=True).min(), rbm.c.get_value(borrow=True).max()
110+
print 'U min max', rbm.U.get_value(borrow=True).min(),
111+
print rbm.U.get_value(borrow=True).max(),
112+
print 'W min max', rbm.W.get_value(borrow=True).min(),
113+
print rbm.W.get_value(borrow=True).max(),
114+
print 'a min max', rbm.a.get_value(borrow=True).min(),
115+
print rbm.a.get_value(borrow=True).max(),
116+
print 'b min max', rbm.b.get_value(borrow=True).min(),
117+
print rbm.b.get_value(borrow=True).max(),
118+
print 'c min max', rbm.c.get_value(borrow=True).min(),
119+
print rbm.c.get_value(borrow=True).max()
105120

106121
if persistent_chains:
107122
print 'parts min', smplr.positions.get_value(borrow=True).min(),
108-
print 'max',smplr.positions.get_value(borrow=True).max(),
123+
print 'max', smplr.positions.get_value(borrow=True).max(),
109124
print 'HMC step', smplr.stepsize.get_value(),
110125
print 'arate', smplr.avg_acceptance_rate.get_value()
111126

112-
113127
l2_of_Ugrad = learn_fn(jj)
114128

115129
if persistent_chains and print_jj:
@@ -125,10 +139,9 @@ def test_reproduce_ranzato_hinton_2010(dataset='MAR',
125139
if jj % 2000 == 0:
126140
print ''
127141
print 'Saving rbm...'
128-
cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%jj, 'w'), -1)
142+
cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl' % jj, 'w'), -1)
129143
if persistent_chains:
130144
print 'Saving sampler...'
131-
cPickle.dump(smplr, open('mcRBM.smplr.%06i.pkl'%jj, 'w'), -1)
132-
145+
cPickle.dump(smplr, open('mcRBM.smplr.%06i.pkl' % jj, 'w'), -1)
133146

134147
return rbm, smplr

0 commit comments

Comments
 (0)