@@ -19,12 +19,12 @@ def __init__(self, config):
1919 - Learning rate: learning_rate
2020 - Learning rate: learning_rate_step
2121 - Learning rate: learning_rate_decay
22+ - Optimized used (momentum, rmsprop): optimizer
2223 - Momentum: momentum
23- - Gradient clipping (by norm): max_gradient
2424 """
2525
2626 # Check config
27- for v in ['input_dim' , 'hidden_dim' , 'output_dim' , 'periods' , 'num_steps' , 'learning_rate' , 'learning_rate_step' , 'learning_rate_decay' , 'momentum ' , 'max_gradient ' ]:
27+ for v in ['input_dim' , 'hidden_dim' , 'output_dim' , 'periods' , 'num_steps' , 'learning_rate' , 'learning_rate_step' , 'learning_rate_decay' , 'optimizer ' , 'momentum ' ]:
2828 if v not in config :
2929 print ('Missing config[\' {}\' ]' .format (v ))
3030 exit (1 )
@@ -137,7 +137,13 @@ def create_optimizer(self):
137137 self .config ['learning_rate_decay' ], # Decay rate.
138138 staircase = True )
139139
140- self .optimizer = tf .train .MomentumOptimizer (self .learning_rate , self .config ['momentum' ], use_nesterov = True )
140+ if self .config ['optimizer' ] == 'momentum' :
141+ self .optimizer = tf .train .MomentumOptimizer (self .learning_rate , self .config ['momentum' ], use_nesterov = True )
142+ elif self .config ['optimizer' ] == 'rmsprop' :
143+ self .optimizer = tf .train .RMSPropOptimizer (self .learning_rate )
144+ else :
145+ print ('Unknown optimizer {}' .format (self .config ['optimizer' ]))
146+ exit (1 )
141147
142148 self .train_step = self .optimizer .minimize (self .loss , global_step = self .global_step )
143149
0 commit comments