Skip to content

Commit fb3047c

Browse files
committed
2 parents c451a31 + 5aa3106 commit fb3047c

2 files changed

Lines changed: 15 additions & 10 deletions

File tree

main.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
'hidden_dim': 9,
1818
'output_dim': 1,
1919
'periods': [1, 2, 4, 8, 16, 32, 64, 128, 256],
20-
'num_steps': 320,
20+
'num_steps': 10,
2121

22-
'learning_rate': 3e-4,
23-
'learning_rate_step': 200,
24-
'learning_rate_decay': 0.95,
22+
'learning_rate': 1e-2,
23+
'learning_rate_step': 50,
24+
'learning_rate_decay': 0.9,
25+
'optimizer': 'rmsprop',
2526
'momentum': 0.95,
26-
'max_gradient': 10,
27-
'max_epochs': 5000
27+
'max_epochs': 1000
2828
}
2929

3030
### Create the model ###
@@ -75,7 +75,6 @@
7575
outputs = outputs.reshape(-1)
7676
ground_truth = targets.reshape(-1)
7777

78-
print(outputs)
7978
# Final result
8079
print('')
8180
print('')

models/clockwork_rnn.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ def __init__(self, config):
1919
- Learning rate: learning_rate
2020
- Learning rate: learning_rate_step
2121
- Learning rate: learning_rate_decay
22+
- Optimized used (momentum, rmsprop): optimizer
2223
- Momentum: momentum
23-
- Gradient clipping (by norm): max_gradient
2424
"""
2525

2626
# Check config
27-
for v in ['input_dim', 'hidden_dim', 'output_dim', 'periods', 'num_steps', 'learning_rate', 'learning_rate_step', 'learning_rate_decay', 'momentum', 'max_gradient']:
27+
for v in ['input_dim', 'hidden_dim', 'output_dim', 'periods', 'num_steps', 'learning_rate', 'learning_rate_step', 'learning_rate_decay', 'optimizer', 'momentum']:
2828
if v not in config:
2929
print('Missing config[\'{}\']'.format(v))
3030
exit(1)
@@ -137,7 +137,13 @@ def create_optimizer(self):
137137
self.config['learning_rate_decay'], # Decay rate.
138138
staircase = True)
139139

140-
self.optimizer = tf.train.MomentumOptimizer(self.learning_rate, self.config['momentum'], use_nesterov = True)
140+
if self.config['optimizer'] == 'momentum':
141+
self.optimizer = tf.train.MomentumOptimizer(self.learning_rate, self.config['momentum'], use_nesterov = True)
142+
elif self.config['optimizer'] == 'rmsprop':
143+
self.optimizer = tf.train.RMSPropOptimizer(self.learning_rate)
144+
else:
145+
print('Unknown optimizer {}'.format(self.config['optimizer']))
146+
exit(1)
141147

142148
self.train_step = self.optimizer.minimize(self.loss, global_step = self.global_step)
143149

0 commit comments

Comments
 (0)