Skip to content

Commit 8b4e1f9

Browse files
committed
pep8
1 parent b3ba144 commit 8b4e1f9

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

code/mcrbm/test_hmc.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,23 @@
44

55
from hmc import HMC_sampler
66

7+
78
def sampler_on_nd_gaussian(sampler_cls, burnin, n_samples, dim=10):
8-
batchsize=3
9+
batchsize = 3
910

1011
rng = numpy.random.RandomState(123)
1112

1213
# Define a covariance and mu for a gaussian
13-
mu = numpy.array(rng.rand(dim) * 10, dtype=theano.config.floatX)
14+
mu = numpy.array(rng.rand(dim) * 10, dtype=theano.config.floatX)
1415
cov = numpy.array(rng.rand(dim, dim), dtype=theano.config.floatX)
1516
cov = (cov + cov.T) / 2.
1617
cov[numpy.arange(dim), numpy.arange(dim)] = 1.0
1718
cov_inv = linalg.inv(cov)
1819

1920
# Define energy function for a multi-variate Gaussian
2021
def gaussian_energy(x):
21-
return 0.5 * (theano.tensor.dot((x-mu),cov_inv)*(x-mu)).sum(axis=1)
22+
return 0.5 * (theano.tensor.dot((x - mu), cov_inv) *
23+
(x - mu)).sum(axis=1)
2224

2325
# Declared shared random variable for positions
2426
position = rng.randn(batchsize, dim).astype(theano.config.floatX)
@@ -29,11 +31,12 @@ def gaussian_energy(x):
2931
initial_stepsize=1e-3, stepsize_max=0.5)
3032

3133
# Start with a burn-in process
32-
garbage = [sampler.draw() for r in xrange(burnin)] #burn-in
33-
# Draw `n_samples`: result is a 3D tensor of dim [n_samples, batchsize, dim]
34+
garbage = [sampler.draw() for r in xrange(burnin)] # burn-in Draw
35+
# `n_samples`: result is a 3D tensor of dim [n_samples, batchsize,
36+
# dim]
3437
_samples = numpy.asarray([sampler.draw() for r in xrange(n_samples)])
3538
# Flatten to [n_samples * batchsize, dim]
36-
samples = _samples.T.reshape(dim,-1).T
39+
samples = _samples.T.reshape(dim, -1).T
3740

3841
print '****** TARGET VALUES ******'
3942
print 'target mean:', mu
@@ -49,10 +52,11 @@ def gaussian_energy(x):
4952

5053
return sampler
5154

55+
5256
def test_hmc():
5357
sampler = sampler_on_nd_gaussian(HMC_sampler.new_from_shared_positions,
5458
burnin=1000, n_samples=1000, dim=5)
55-
assert abs(sampler.avg_acceptance_rate - sampler.target_acceptance_rate) < .1
59+
assert abs(sampler.avg_acceptance_rate -
60+
sampler.target_acceptance_rate) < .1
5661
assert sampler.stepsize.get_value() >= sampler.stepsize_min
5762
assert sampler.stepsize.get_value() <= sampler.stepsize_max
58-

0 commit comments

Comments
 (0)