11import cPickle
22
33import numpy
4+
45from pylearn .algorithms .mcRBM import mcRBM , mcRBMTrainer
6+ from pylearn .dataset_ops import image_patches
57import pylearn .datasets .cifar10
8+
69import theano
710from theano import tensor
811
912
1013def l2 (X ):
11- return numpy .sqrt ((X ** 2 ).sum ())
14+ return numpy .sqrt ((X ** 2 ).sum ())
15+
1216
1317def _default_rbm_alloc (n_I , n_K = 256 , n_J = 100 ):
1418 return mcRBM .alloc (n_I , n_K , n_J )
1519
20+
1621def _default_trainer_alloc (rbm , train_batch , batchsize , initial_lr_per_example ,
1722 l1_penalty , l1_penalty_start , persistent_chains ):
18- return mcRBMTrainer .alloc (rbm , train_batch , batchsize , l1_penalty = l1_penalty ,
19- l1_penalty_start = l1_penalty_start ,persistent_chains = persistent_chains )
23+ return mcRBMTrainer .alloc (rbm , train_batch , batchsize ,
24+ l1_penalty = l1_penalty ,
25+ l1_penalty_start = l1_penalty_start ,
26+ persistent_chains = persistent_chains )
2027
21- def test_reproduce_ranzato_hinton_2010 (dataset = 'MAR' ,
28+
29+ def test_reproduce_ranzato_hinton_2010 (dataset = 'MAR' ,
2230 n_train_iters = 5000 ,
23- rbm_alloc = _default_rbm_alloc ,
31+ rbm_alloc = _default_rbm_alloc ,
2432 trainer_alloc = _default_trainer_alloc ,
2533 lr_per_example = .075 ,
2634 l1_penalty = 1e-3 ,
@@ -30,36 +38,37 @@ def test_reproduce_ranzato_hinton_2010(dataset='MAR',
3038
3139 batchsize = 128
3240 ## specific to MAR dataset ##
33- n_vis = 105
34- n_patches = 10240
35- epoch_size = n_patches
41+ n_vis = 105
42+ n_patches = 10240
43+ epoch_size = n_patches
3644
37- tile = pylearn . dataset_ops . image_patches .save_filters_of_ranzato_hinton_2010
45+ tile = image_patches .save_filters_of_ranzato_hinton_2010
3846
3947 batch_idx = tensor .iscalar ()
40- batch_range = batch_idx * batchsize + numpy .arange (batchsize )
48+ batch_range = batch_idx * batchsize + numpy .arange (batchsize )
4149
42- train_batch = pylearn . dataset_ops . image_patches .ranzato_hinton_2010_op (batch_range )
50+ train_batch = image_patches .ranzato_hinton_2010_op (batch_range )
4351
4452 imgs_fn = theano .function ([batch_idx ], outputs = train_batch )
4553
4654 trainer = trainer_alloc (
4755 rbm_alloc (n_I = n_vis ),
4856 train_batch ,
49- batchsize ,
57+ batchsize ,
5058 initial_lr_per_example = lr_per_example ,
5159 l1_penalty = l1_penalty ,
5260 l1_penalty_start = l1_penalty_start ,
5361 persistent_chains = persistent_chains )
54- rbm = trainer .rbm
62+ rbm = trainer .rbm
5563
5664 if persistent_chains :
5765 grads = trainer .contrastive_grads ()
5866 learn_fn = theano .function ([batch_idx ],
5967 outputs = [grads [0 ].norm (2 ), grads [0 ].norm (2 ), grads [1 ].norm (2 )],
6068 updates = trainer .cd_updates ())
6169 else :
62- learn_fn = theano .function ([batch_idx ], outputs = [], updates = trainer .cd_updates ())
70+ learn_fn = theano .function ([batch_idx ], outputs = [],
71+ updates = trainer .cd_updates ())
6372
6473 if persistent_chains :
6574 smplr = trainer .sampler
@@ -70,46 +79,51 @@ def test_reproduce_ranzato_hinton_2010(dataset='MAR',
7079 cPickle .dump (
7180 pylearn .dataset_ops .cifar10 .random_cifar_patches_pca (
7281 n_vis , None , 'float32' , n_patches , R , C ,),
73- open ('test_mcRBM.pca.pkl' ,'w' ))
82+ open ('test_mcRBM.pca.pkl' , 'w' ))
7483
7584 print "Learning..."
7685 last_epoch = - 1
7786 for jj in xrange (n_train_iters ):
78- epoch = jj * batchsize / epoch_size
87+ epoch = jj * batchsize / epoch_size
7988
8089 print_jj = epoch != last_epoch
8190 last_epoch = epoch
8291
8392 if print_jj :
84- tile (imgs_fn (jj ), "imgs_%06i.png" % jj )
93+ tile (imgs_fn (jj ), "imgs_%06i.png" % jj )
8594 if persistent_chains :
86- tile (smplr .positions .get_value (borrow = True ), "sample_%06i.png" % jj )
87- tile (rbm .U .get_value (borrow = True ).T , "U_%06i.png" % jj )
88- tile (rbm .W .get_value (borrow = True ).T , "W_%06i.png" % jj )
95+ tile (smplr .positions .get_value (borrow = True ),
96+ "sample_%06i.png" % jj )
97+ tile (rbm .U .get_value (borrow = True ).T , "U_%06i.png" % jj )
98+ tile (rbm .W .get_value (borrow = True ).T , "W_%06i.png" % jj )
8999
90- print 'saving samples' , jj , 'epoch' , jj / (epoch_size / batchsize )
100+ print 'saving samples' , jj , 'epoch' , jj / (epoch_size / batchsize )
91101
92102 print 'l2(U)' , l2 (rbm .U .get_value (borrow = True )),
93103 print 'l2(W)' , l2 (rbm .W .get_value (borrow = True )),
94- print 'l1_penalty' ,
104+ print 'l1_penalty' ,
95105 try :
96106 print trainer .effective_l1_penalty .get_value ()
97107 except :
98108 print trainer .effective_l1_penalty
99109
100- print 'U min max' , rbm .U .get_value (borrow = True ).min (), rbm .U .get_value (borrow = True ).max (),
101- print 'W min max' , rbm .W .get_value (borrow = True ).min (), rbm .W .get_value (borrow = True ).max (),
102- print 'a min max' , rbm .a .get_value (borrow = True ).min (), rbm .a .get_value (borrow = True ).max (),
103- print 'b min max' , rbm .b .get_value (borrow = True ).min (), rbm .b .get_value (borrow = True ).max (),
104- print 'c min max' , rbm .c .get_value (borrow = True ).min (), rbm .c .get_value (borrow = True ).max ()
110+ print 'U min max' , rbm .U .get_value (borrow = True ).min (),
111+ print rbm .U .get_value (borrow = True ).max (),
112+ print 'W min max' , rbm .W .get_value (borrow = True ).min (),
113+ print rbm .W .get_value (borrow = True ).max (),
114+ print 'a min max' , rbm .a .get_value (borrow = True ).min (),
115+ print rbm .a .get_value (borrow = True ).max (),
116+ print 'b min max' , rbm .b .get_value (borrow = True ).min (),
117+ print rbm .b .get_value (borrow = True ).max (),
118+ print 'c min max' , rbm .c .get_value (borrow = True ).min (),
119+ print rbm .c .get_value (borrow = True ).max ()
105120
106121 if persistent_chains :
107122 print 'parts min' , smplr .positions .get_value (borrow = True ).min (),
108- print 'max' ,smplr .positions .get_value (borrow = True ).max (),
123+ print 'max' , smplr .positions .get_value (borrow = True ).max (),
109124 print 'HMC step' , smplr .stepsize .get_value (),
110125 print 'arate' , smplr .avg_acceptance_rate .get_value ()
111126
112-
113127 l2_of_Ugrad = learn_fn (jj )
114128
115129 if persistent_chains and print_jj :
@@ -125,10 +139,9 @@ def test_reproduce_ranzato_hinton_2010(dataset='MAR',
125139 if jj % 2000 == 0 :
126140 print ''
127141 print 'Saving rbm...'
128- cPickle .dump (rbm , open ('mcRBM.rbm.%06i.pkl' % jj , 'w' ), - 1 )
142+ cPickle .dump (rbm , open ('mcRBM.rbm.%06i.pkl' % jj , 'w' ), - 1 )
129143 if persistent_chains :
130144 print 'Saving sampler...'
131- cPickle .dump (smplr , open ('mcRBM.smplr.%06i.pkl' % jj , 'w' ), - 1 )
132-
145+ cPickle .dump (smplr , open ('mcRBM.smplr.%06i.pkl' % jj , 'w' ), - 1 )
133146
134147 return rbm , smplr
0 commit comments