Skip to content

Commit dde44f0

Browse files
committed
Updated DPG target model to the new get_variables API
1 parent b11a351 commit dde44f0

2 files changed

Lines changed: 59 additions & 24 deletions

File tree

examples/configs/ddpg.json

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,8 @@
2121
"entropy_regularization": null,
2222

2323
"critic_network": {
24-
"size_t0": 400,
25-
"size_t1": 300
24+
"size_t0": 64,
25+
"size_t1": 64
2626
},
2727
"critic_optimizer": {
2828
"type": "adam",

tensorforce/models/dpg_target_model.py

Lines changed: 57 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,12 @@
2727

2828

2929
class DDPGCriticNetwork(LayerBasedNetwork):
30-
def __init__(self, scope='layerbased-network', summary_labels=(), size_t0=400, size_t1=300):
30+
def __init__(self, scope='ddpg-critic-network', summary_labels=(), size_t0=400, size_t1=300):
3131
super(DDPGCriticNetwork, self).__init__(scope=scope, summary_labels=summary_labels)
3232

33-
self.t0 = Dense(size=size_t0, activation='relu')
34-
self.t1 = Dense(size=size_t1, activation='relu')
35-
self.t2 = Dense(size=1, activation='tanh')
33+
self.t0 = Dense(size=size_t0, activation='relu', scope=scope + '/dense0')
34+
self.t1 = Dense(size=size_t1, activation='relu', scope=scope + '/dense1')
35+
self.t2 = Dense(size=1, activation='tanh', scope=scope + '/dense2')
3636

3737
self.add_layer(self.t0)
3838
self.add_layer(self.t1)
@@ -176,7 +176,7 @@ def initialize(self, custom_getter):
176176
# spec=self.critic_network_spec,
177177
# kwargs=dict(scope='target-critic', summary_labels=self.summary_labels)
178178
# )
179-
self.target_critic = DDPGCriticNetwork(scope='critic', size_t0=size_t0, size_t1=size_t1)
179+
self.target_critic = DDPGCriticNetwork(scope='target-critic', size_t0=size_t0, size_t1=size_t1)
180180

181181
# Target critic optimizer
182182
self.target_critic_optimizer = Synchronization(
@@ -220,7 +220,7 @@ def tf_target_actions_and_internals(self, states, internals, deterministic=True)
220220

221221
return actions, internals
222222

223-
def tf_loss_per_instance(self, states, internals, actions, terminal, reward, next_states, next_internals, update):
223+
def tf_loss_per_instance(self, states, internals, actions, terminal, reward, next_states, next_internals, update, reference=None):
224224
# Same as PGLogProbModel
225225
embedding = self.network.apply(x=states, internals=internals, update=update)
226226
log_probs = list()
@@ -279,8 +279,15 @@ def fn_critic_loss(predicted_q, real_q):
279279
)
280280

281281
# Update target network and baseline
282-
network_distributions_variables = self.get_distributions_variables(self.distributions)
283-
target_distributions_variables = self.get_distributions_variables(self.target_distributions)
282+
network_distributions_variables = [
283+
variable for name in sorted(self.distributions)
284+
for variable in self.distributions[name].get_variables(include_nontrainable=False)
285+
]
286+
287+
target_distributions_variables = [
288+
variable for name in sorted(self.target_distributions)
289+
for variable in self.target_distributions[name].get_variables(include_nontrainable=False)
290+
]
284291

285292
target_optimization = self.target_network_optimizer.minimize(
286293
time=self.timestep,
@@ -296,24 +303,52 @@ def fn_critic_loss(predicted_q, real_q):
296303

297304
return tf.group(critic_optimization, optimization, target_optimization, target_critic_optimization)
298305

299-
def get_variables(self, include_non_trainable=False):
300-
model_variables = super(DPGTargetModel, self).get_variables(include_non_trainable=include_non_trainable)
301-
critic_variables = self.critic.get_variables() + self.critic_optimizer.get_variables()
306+
def get_variables(self, include_submodules=False, include_nontrainable=False):
307+
model_variables = super(DPGTargetModel, self).get_variables(
308+
include_submodules=include_submodules,
309+
include_nontrainable=include_nontrainable
310+
)
311+
critic_variables = self.critic.get_variables(include_nontrainable=include_nontrainable)
312+
model_variables += critic_variables
302313

303-
if include_non_trainable:
304-
# Target network and optimizer variables only included if 'include_non_trainable' set
305-
target_variables = self.target_network.get_variables(include_non_trainable=include_non_trainable) \
306-
+ self.get_distributions_variables(self.target_distributions)\
307-
+ self.target_network_optimizer.get_variables()
314+
if include_nontrainable:
315+
critic_optimizer_variables = self.critic_optimizer.get_variables()
308316

309-
target_critic_variables = self.target_critic.get_variables() + self.target_critic_optimizer.get_variables()
317+
for variable in critic_optimizer_variables:
318+
if variable in model_variables:
319+
model_variables.remove(variable)
310320

311-
return model_variables + critic_variables + target_variables + target_critic_variables
312-
else:
313-
return model_variables + critic_variables
321+
model_variables += critic_optimizer_variables
322+
323+
if include_submodules:
324+
target_variables = self.target_network.get_variables(include_nontrainable=include_nontrainable)
325+
model_variables += target_variables
326+
327+
target_distributions_variables = [
328+
variable for name in sorted(self.target_distributions)
329+
for variable in self.target_distributions[name].get_variables(include_nontrainable=include_nontrainable)
330+
]
331+
model_variables += target_distributions_variables
332+
333+
target_critic_variables = self.target_critic.get_variables()
334+
model_variables += target_critic_variables
335+
336+
if include_nontrainable:
337+
target_optimizer_variables = self.target_network_optimizer.get_variables()
338+
model_variables += target_optimizer_variables
339+
340+
target_critic_optimizer_variables = self.target_critic_optimizer.get_variables()
341+
model_variables += target_critic_optimizer_variables
342+
343+
return model_variables
314344

315345
def get_summaries(self):
346+
target_network_summaries = self.target_network.get_summaries()
347+
target_distributions_summaries = [
348+
summary for name in sorted(self.target_distributions)
349+
for summary in self.target_distributions[name].get_summaries()
350+
]
351+
316352
# Todo: Critic summaries
317-
target_distributions_summaries = self.get_distributions_summaries(self.target_distributions)
318-
return super(DPGTargetModel, self).get_summaries() + self.target_network.get_summaries() \
353+
return super(DPGTargetModel, self).get_summaries() + target_network_summaries \
319354
+ target_distributions_summaries

0 commit comments

Comments
 (0)