2727
2828
2929class DDPGCriticNetwork (LayerBasedNetwork ):
30- def __init__ (self , scope = 'layerbased -network' , summary_labels = (), size_t0 = 400 , size_t1 = 300 ):
30+ def __init__ (self , scope = 'ddpg-critic -network' , summary_labels = (), size_t0 = 400 , size_t1 = 300 ):
3131 super (DDPGCriticNetwork , self ).__init__ (scope = scope , summary_labels = summary_labels )
3232
33- self .t0 = Dense (size = size_t0 , activation = 'relu' )
34- self .t1 = Dense (size = size_t1 , activation = 'relu' )
35- self .t2 = Dense (size = 1 , activation = 'tanh' )
33+ self .t0 = Dense (size = size_t0 , activation = 'relu' , scope = scope + '/dense0' )
34+ self .t1 = Dense (size = size_t1 , activation = 'relu' , scope = scope + '/dense1' )
35+ self .t2 = Dense (size = 1 , activation = 'tanh' , scope = scope + '/dense2' )
3636
3737 self .add_layer (self .t0 )
3838 self .add_layer (self .t1 )
@@ -176,7 +176,7 @@ def initialize(self, custom_getter):
176176 # spec=self.critic_network_spec,
177177 # kwargs=dict(scope='target-critic', summary_labels=self.summary_labels)
178178 # )
179- self .target_critic = DDPGCriticNetwork (scope = 'critic' , size_t0 = size_t0 , size_t1 = size_t1 )
179+ self .target_critic = DDPGCriticNetwork (scope = 'target- critic' , size_t0 = size_t0 , size_t1 = size_t1 )
180180
181181 # Target critic optimizer
182182 self .target_critic_optimizer = Synchronization (
@@ -220,7 +220,7 @@ def tf_target_actions_and_internals(self, states, internals, deterministic=True)
220220
221221 return actions , internals
222222
223- def tf_loss_per_instance (self , states , internals , actions , terminal , reward , next_states , next_internals , update ):
223+ def tf_loss_per_instance (self , states , internals , actions , terminal , reward , next_states , next_internals , update , reference = None ):
224224 # Same as PGLogProbModel
225225 embedding = self .network .apply (x = states , internals = internals , update = update )
226226 log_probs = list ()
@@ -279,8 +279,15 @@ def fn_critic_loss(predicted_q, real_q):
279279 )
280280
281281 # Update target network and baseline
282- network_distributions_variables = self .get_distributions_variables (self .distributions )
283- target_distributions_variables = self .get_distributions_variables (self .target_distributions )
282+ network_distributions_variables = [
283+ variable for name in sorted (self .distributions )
284+ for variable in self .distributions [name ].get_variables (include_nontrainable = False )
285+ ]
286+
287+ target_distributions_variables = [
288+ variable for name in sorted (self .target_distributions )
289+ for variable in self .target_distributions [name ].get_variables (include_nontrainable = False )
290+ ]
284291
285292 target_optimization = self .target_network_optimizer .minimize (
286293 time = self .timestep ,
@@ -296,24 +303,52 @@ def fn_critic_loss(predicted_q, real_q):
296303
297304 return tf .group (critic_optimization , optimization , target_optimization , target_critic_optimization )
298305
def get_variables(self, include_submodules=False, include_nontrainable=False):
    """
    Collects the variables of this model, its critic and, on request, its target components.

    Args:
        include_submodules: If true, the variables of the target network, the target
            distributions and the target critic are appended as well (plus the two
            target optimizers when `include_nontrainable` is also set).
        include_nontrainable: If true, optimizer variables are appended, and the flag is
            forwarded to every network/distribution queried here.

    Returns:
        List of variables.
    """
    # Copy the parent's result: the list is extended in place below, and the parent
    # may hand out a list it still holds on to.
    model_variables = list(super(DPGTargetModel, self).get_variables(
        include_submodules=include_submodules,
        include_nontrainable=include_nontrainable
    ))

    critic_variables = self.critic.get_variables(include_nontrainable=include_nontrainable)
    model_variables += critic_variables

    if include_nontrainable:
        critic_optimizer_variables = self.critic_optimizer.get_variables()

        # Drop an already-collected occurrence of each optimizer variable before
        # appending the optimizer's list, so it is not listed twice.
        for variable in critic_optimizer_variables:
            if variable in model_variables:
                model_variables.remove(variable)

        model_variables += critic_optimizer_variables

    if include_submodules:
        target_variables = self.target_network.get_variables(include_nontrainable=include_nontrainable)
        model_variables += target_variables

        # Sorted by name so the resulting order is deterministic.
        target_distributions_variables = [
            variable for name in sorted(self.target_distributions)
            for variable in self.target_distributions[name].get_variables(
                include_nontrainable=include_nontrainable
            )
        ]
        model_variables += target_distributions_variables

        # Forward the flag here as well, matching how the critic itself is queried above.
        target_critic_variables = self.target_critic.get_variables(
            include_nontrainable=include_nontrainable
        )
        model_variables += target_critic_variables

        if include_nontrainable:
            target_optimizer_variables = self.target_network_optimizer.get_variables()
            model_variables += target_optimizer_variables

            target_critic_optimizer_variables = self.target_critic_optimizer.get_variables()
            model_variables += target_critic_optimizer_variables

    return model_variables
314344
def get_summaries(self):
    """
    Returns the parent model's summaries followed by those of the target network
    and of the target distributions (visited in sorted name order).
    """
    network_part = self.target_network.get_summaries()

    distribution_part = list()
    for key in sorted(self.target_distributions):
        distribution_part.extend(self.target_distributions[key].get_summaries())

    # Todo: Critic summaries
    own_part = super(DPGTargetModel, self).get_summaries()
    return own_part + network_part + distribution_part
0 commit comments