|
31 | 31 |
|
32 | 32 |
|
33 | 33 | class HiddenLayer(object): |
34 | | - def __init__(self, rng, input, n_in, n_out, activation = T.tanh): |
| 34 | + def __init__(self, rng, input, n_in, n_out, W = None, b = None, activation = T.tanh): |
35 | 35 | """ |
36 | 36 | Typical hidden layer of a MLP: units are fully-connected and have |
37 | 37 | sigmoidal activation function. Weight matrix W is of shape (n_in,n_out) |
@@ -70,19 +70,25 @@ def __init__(self, rng, input, n_in, n_out, activation = T.tanh): |
70 | 70 | # should use 4 times larger initial weights for sigmoid |
71 | 71 | # compared to tanh |
72 | 72 | # We have no info for other function, so we use the same as tanh. |
73 | | - W_values = numpy.asarray( rng.uniform( |
74 | | - low = - numpy.sqrt(6./(n_in+n_out)), |
75 | | - high = numpy.sqrt(6./(n_in+n_out)), |
76 | | - size = (n_in, n_out)), dtype = theano.config.floatX) |
77 | | - if activation == theano.tensor.nnet.sigmoid: |
78 | | - W_values *= 4 |
| 73 | + if W is None: |
| 74 | + W_values = numpy.asarray( rng.uniform( |
| 75 | + low = - numpy.sqrt(6./(n_in+n_out)), |
| 76 | + high = numpy.sqrt(6./(n_in+n_out)), |
| 77 | + size = (n_in, n_out)), dtype = theano.config.floatX) |
| 78 | + if activation == theano.tensor.nnet.sigmoid: |
| 79 | + W_values *= 4 |
79 | 80 |
|
80 | | - self.W = theano.shared(value = W_values, name ='W') |
| 81 | + W = theano.shared(value = W_values, name ='W') |
81 | 82 |
|
82 | | - b_values = numpy.zeros((n_out,), dtype= theano.config.floatX) |
83 | | - self.b = theano.shared(value= b_values, name ='b') |
| 83 | + if b is None: |
| 84 | + b_values = numpy.zeros((n_out,), dtype= theano.config.floatX) |
| 85 | + b = theano.shared(value= b_values, name ='b') |
84 | 86 |
|
85 | | - self.output = activation(T.dot(input, self.W) + self.b) |
| 87 | + self.W = W |
| 88 | + self.b = b |
| 89 | + |
| 90 | + lin_output = T.dot(input, self.W) + self.b |
| 91 | + self.output = lin_output if activation is None else activation(lin_output) |
86 | 92 | # parameters of the model |
87 | 93 | self.params = [self.W, self.b] |
88 | 94 |
|
|
0 commit comments