@@ -105,15 +105,17 @@ def sample_h_given_v(self, v0_sample):
105105 # compute the activation of the hidden units given a sample of the visibles
106106 h1_mean = T .nnet .sigmoid (T .dot (v0_sample , self .W ) + self .hbias )
107107 # get a sample of the hiddens given their activation
108- h1_sample = self .theano_rng .binomial (size = h1_mean .shape , n = 1 , prob = h1_mean )
108+ h1_sample = self .theano_rng .binomial (size = h1_mean .shape , n = 1 , prob = h1_mean ,
109+ dtype = theano .config .floatX )
109110 return [h1_mean , h1_sample ]
110111
111112 def sample_v_given_h (self , h0_sample ):
112113 ''' This function infers state of visible units given hidden units '''
113114 # compute the activation of the visible given the hidden sample
114115 v1_mean = T .nnet .sigmoid (T .dot (h0_sample , self .W .T ) + self .vbias )
115116 # get a sample of the visible given their activation
116- v1_sample = self .theano_rng .binomial (size = v1_mean .shape ,n = 1 ,prob = v1_mean )
117+ v1_sample = self .theano_rng .binomial (size = v1_mean .shape ,n = 1 ,prob = v1_mean ,
118+ dtype = theano .config .floatX )
117119 return [v1_mean , v1_sample ]
118120
119121 def gibbs_hvh (self , h0_sample ):
@@ -159,10 +161,14 @@ def cd(self, lr = 0.1, persistent=None):
159161 [nv_mean , nv_sample , nh_mean , nh_sample ] = self .gibbs_hvh (chain_start )
160162
161163 # determine gradients on RBM parameters
162- g_vbias = T .sum ( self .input - nv_mean , axis = 0 )/ self .batch_size
163- g_hbias = T .sum ( ph_mean - nh_mean , axis = 0 )/ self .batch_size
164- g_W = T .dot (ph_mean .T , self .input )/ self .batch_size - \
165- T .dot (nh_mean .T , nv_mean )/ self .batch_size
164+ # cast batch_size to floatX, because its type is int64,
165+ # and otherwise the gradients are upcasted to float64,
166+ # even when floatX == float32
167+ batch_size = T .cast (self .batch_size , dtype = theano .config .floatX )
168+ g_vbias = T .sum ( self .input - nv_mean , axis = 0 )/ batch_size
169+ g_hbias = T .sum ( ph_mean - nh_mean , axis = 0 )/ batch_size
170+ g_W = T .dot (ph_mean .T , self .input )/ batch_size - \
171+ T .dot (nh_mean .T , nv_mean )/ batch_size
166172
167173 gparams = [g_W .T , g_hbias , g_vbias ]
168174
0 commit comments