Skip to content

Commit c283065

Browse files
committed
updated formula to new weight initialization formula
1 parent af7aa84 commit c283065

1 file changed

Lines changed: 16 additions & 8 deletions

File tree

code/convolutional_mlp.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,10 @@ def __init__(self, rng, input, filter_shape, image_shape, poolsize=(2,2)):
4949
"""
5050
assert image_shape[1]==filter_shape[1]
5151
self.input = input
52-
53-
# initialize weight values: the fan-in of each hidden neuron is
54-
# restricted by the size of the receptive fields.
55-
fan_in = numpy.prod(filter_shape[1:])
56-
W_values = numpy.asarray( rng.uniform( \
57-
low = -numpy.sqrt(3./fan_in), \
58-
high = numpy.sqrt(3./fan_in), \
59-
size = filter_shape), dtype = theano.config.floatX)
52+
53+
# initialize weights to temporary values until we know the shape of the output feature
54+
# maps
55+
W_values = numpy.zeros(filter_shape, dtype=theano.config.floatX)
6056
self.W = theano.shared(value = W_values)
6157

6258
# the bias is a 1D tensor -- one bias per output feature map
@@ -67,6 +63,18 @@ def __init__(self, rng, input, filter_shape, image_shape, poolsize=(2,2)):
6763
conv_out = conv.conv2d(input, self.W,
6864
filter_shape=filter_shape, image_shape=image_shape)
6965

66+
# there are "num input feature maps * filter height * filter width" inputs
67+
# to each hidden unit
68+
fan_in = numpy.prod(filter_shape[1:])
69+
# each unit in the lower layer receives a gradient from:
70+
# "num output feature maps * filter height * filter width" / pooling size
71+
fan_out = filter_shape[0] * numpy.prod(filter_shape[2:]) / numpy.prod(poolsize)
72+
# replace weight values with random weights
73+
W_bound = numpy.sqrt(6./(fan_in + fan_out))
74+
self.W.value = numpy.asarray(
75+
rng.uniform(low=-W_bound, high=W_bound, size=filter_shape),
76+
dtype = theano.config.floatX)
77+
7078
# downsample each feature map individually, using maxpooling
7179
pooled_out = downsample.max_pool2D(conv_out, poolsize, ignore_border=True)
7280

0 commit comments

Comments
 (0)