44
55from hmc import HMC_sampler
66
7+
78def sampler_on_nd_gaussian (sampler_cls , burnin , n_samples , dim = 10 ):
8- batchsize = 3
9+ batchsize = 3
910
1011 rng = numpy .random .RandomState (123 )
1112
1213 # Define a covariance and mu for a gaussian
13- mu = numpy .array (rng .rand (dim ) * 10 , dtype = theano .config .floatX )
14+ mu = numpy .array (rng .rand (dim ) * 10 , dtype = theano .config .floatX )
1415 cov = numpy .array (rng .rand (dim , dim ), dtype = theano .config .floatX )
1516 cov = (cov + cov .T ) / 2.
1617 cov [numpy .arange (dim ), numpy .arange (dim )] = 1.0
1718 cov_inv = linalg .inv (cov )
1819
1920 # Define energy function for a multi-variate Gaussian
2021 def gaussian_energy (x ):
21- return 0.5 * (theano .tensor .dot ((x - mu ),cov_inv )* (x - mu )).sum (axis = 1 )
22+ return 0.5 * (theano .tensor .dot ((x - mu ), cov_inv ) *
23+ (x - mu )).sum (axis = 1 )
2224
2325 # Declared shared random variable for positions
2426 position = rng .randn (batchsize , dim ).astype (theano .config .floatX )
@@ -29,11 +31,12 @@ def gaussian_energy(x):
2931 initial_stepsize = 1e-3 , stepsize_max = 0.5 )
3032
3133 # Start with a burn-in process
32- garbage = [sampler .draw () for r in xrange (burnin )] #burn-in
33- # Draw `n_samples`: result is a 3D tensor of dim [n_samples, batchsize, dim]
34+ garbage = [sampler .draw () for r in xrange (burnin )] # burn-in Draw
35+ # `n_samples`: result is a 3D tensor of dim [n_samples, batchsize,
36+ # dim]
3437 _samples = numpy .asarray ([sampler .draw () for r in xrange (n_samples )])
3538 # Flatten to [n_samples * batchsize, dim]
36- samples = _samples .T .reshape (dim ,- 1 ).T
39+ samples = _samples .T .reshape (dim , - 1 ).T
3740
3841 print '****** TARGET VALUES ******'
3942 print 'target mean:' , mu
@@ -49,10 +52,11 @@ def gaussian_energy(x):
4952
5053 return sampler
5154
55+
5256def test_hmc ():
5357 sampler = sampler_on_nd_gaussian (HMC_sampler .new_from_shared_positions ,
5458 burnin = 1000 , n_samples = 1000 , dim = 5 )
55- assert abs (sampler .avg_acceptance_rate - sampler .target_acceptance_rate ) < .1
59+ assert abs (sampler .avg_acceptance_rate -
60+ sampler .target_acceptance_rate ) < .1
5661 assert sampler .stepsize .get_value () >= sampler .stepsize_min
5762 assert sampler .stepsize .get_value () <= sampler .stepsize_max
58-
0 commit comments