Skip to content

Commit ea74070

Browse files
committed
updated ddpg
1 parent aa793c2 commit ea74070

4 files changed

Lines changed: 85 additions & 27 deletions

File tree

examples/configs/ddpg.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,16 @@
2828
"type": "adam",
2929
"learning_rate": 1e-3
3030
},
31-
"target_sync_frequency": 64,
31+
"target_sync_frequency": 1,
3232
"target_update_weight": 0.999,
3333

34+
"actions_exploration": {
35+
"type": "ornstein_uhlenbeck",
36+
"sigma": 0.2,
37+
"mu": 0.0,
38+
"theta": 0.15
39+
},
40+
3441
"saver": {
3542
"directory": null,
3643
"seconds": 600
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
[
2+
{
3+
"type": "linear",
4+
"size": 64
5+
},
6+
{
7+
"type": "tf_layer",
8+
"layer": "batch_normalization"
9+
},
10+
{
11+
"type": "nonlinearity",
12+
"name": "relu"
13+
},
14+
15+
16+
{
17+
"type": "linear",
18+
"size": 64
19+
},
20+
{
21+
"type": "tf_layer",
22+
"layer": "batch_normalization"
23+
},
24+
{
25+
"type": "nonlinearity",
26+
"name": "relu"
27+
},
28+
29+
{
30+
"type": "dense",
31+
"size": 1,
32+
"activation": "tanh"
33+
}
34+
]

tensorforce/core/networks/layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def tf_apply(self, x, update):
185185
if self.first_scope is None:
186186
# Store scope of first call since regularization losses will be registered there.
187187
self.first_scope = tf.contrib.framework.get_name_scope()
188-
return self.layer(inputs=x)
188+
return self.layer(inputs=x, training=update)
189189

190190
def tf_regularization_loss(self):
191191
regularization_losses = tf.get_collection(

tensorforce/models/dpg_target_model.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,34 @@
2222
from tensorforce import util, TensorForceError
2323
from tensorforce.models import DistributionModel
2424

25-
from tensorforce.core.networks import Network, LayerBasedNetwork, Dense
25+
from tensorforce.core.networks import Network, LayerBasedNetwork, Dense, Linear, TFLayer, Nonlinearity
2626
from tensorforce.core.optimizers import Optimizer, Synchronization
2727

2828

2929
class DDPGCriticNetwork(LayerBasedNetwork):
3030
def __init__(self, scope='ddpg-critic-network', summary_labels=(), size_t0=400, size_t1=300):
3131
super(DDPGCriticNetwork, self).__init__(scope=scope, summary_labels=summary_labels)
3232

33-
self.t0 = Dense(size=size_t0, activation='relu', scope=scope + '/dense0')
34-
self.t1 = Dense(size=size_t1, activation='relu', scope=scope + '/dense1')
35-
self.t2 = Dense(size=1, activation='tanh', scope=scope + '/dense2')
33+
self.t0l = Linear(size=size_t0, scope='linear0')
34+
self.t0b = TFLayer(layer='batch_normalization', scope='batchnorm0', center=True, scale=True)
35+
self.t0n = Nonlinearity(name='relu', scope='relu0')
3636

37-
self.add_layer(self.t0)
38-
self.add_layer(self.t1)
39-
self.add_layer(self.t2)
37+
self.t1l = Linear(size=size_t1, scope='linear1')
38+
self.t1b = TFLayer(layer='batch_normalization', scope='batchnorm1', center=True, scale=True)
39+
self.t1n = Nonlinearity(name='relu', scope='relu1')
40+
41+
self.t2d = Dense(size=1, activation='tanh', scope='dense0',
42+
weights=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
43+
44+
self.add_layer(self.t0l)
45+
self.add_layer(self.t0b)
46+
self.add_layer(self.t0n)
47+
48+
self.add_layer(self.t1l)
49+
self.add_layer(self.t1b)
50+
self.add_layer(self.t1n)
51+
52+
self.add_layer(self.t2d)
4053

4154
def tf_apply(self, x, internals, update, return_internals=False):
4255
assert x['states'], x['actions']
@@ -59,13 +72,21 @@ def tf_apply(self, x, internals, update, return_internals=False):
5972

6073
x_actions = tf.reshape(tf.cast(x_actions, dtype=tf.float32), (-1, 1))
6174

62-
out = self.t0.tf_apply(x=x_states, update=update)
75+
out = self.t0l.apply(x=x_states, update=update)
76+
out = self.t0b.apply(x=out, update=update)
77+
out = self.t0n.apply(x=out, update=update)
6378

64-
out = self.t1.tf_apply(x=tf.concat([out, x_actions], axis=-1), update=update)
79+
out = self.t1l.apply(x=tf.concat([out, x_actions], axis=-1), update=update)
80+
out = self.t1b.apply(x=out, update=update)
81+
out = self.t1n.apply(x=out, update=update)
6582

66-
out = self.t2.tf_apply(x=out, update=update)
83+
out = self.t2d.apply(x=out, update=update)
6784

68-
return out
85+
if return_internals:
86+
# TODO: internals management is not implemented yet; None is returned as a placeholder for the internals.
87+
return out, None
88+
else:
89+
return out
6990

7091

7192
class DPGTargetModel(DistributionModel):
@@ -139,6 +160,7 @@ def __init__(
139160
)
140161

141162
assert self.memory_spec["include_next_states"]
163+
assert self.requires_deterministic == True
142164

143165
def initialize(self, custom_getter):
144166
super(DPGTargetModel, self).initialize(custom_getter)
@@ -159,10 +181,6 @@ def initialize(self, custom_getter):
159181
self.target_distributions = self.create_distributions()
160182

161183
# Critic
162-
# self.critic = Network.from_spec(
163-
# spec=self.critic_network_spec,
164-
# kwargs=dict(scope='critic', summary_labels=self.summary_labels)
165-
# )
166184
size_t0 = self.critic_network_spec['size_t0']
167185
size_t1 = self.critic_network_spec['size_t1']
168186

@@ -172,10 +190,6 @@ def initialize(self, custom_getter):
172190
kwargs=dict(summary_labels=self.summary_labels)
173191
)
174192

175-
# self.target_critic = Network.from_spec(
176-
# spec=self.critic_network_spec,
177-
# kwargs=dict(scope='target-critic', summary_labels=self.summary_labels)
178-
# )
179193
self.target_critic = DDPGCriticNetwork(scope='target-critic', size_t0=size_t0, size_t1=size_t1)
180194

181195
# Target critic optimizer
@@ -216,7 +230,6 @@ def tf_target_actions_and_internals(self, states, internals, deterministic=True)
216230

217231
def tf_loss_per_instance(self, states, internals, actions, terminal, reward, next_states, next_internals, update, reference=None):
218232
q = self.critic.apply(dict(states=states, actions=actions), internals=internals, update=update)
219-
220233
return -q
221234

222235
def tf_predict_target_q(self, states, internals, terminal, actions, reward, update):
@@ -225,13 +238,17 @@ def tf_predict_target_q(self, states, internals, terminal, actions, reward, upda
225238

226239
def tf_optimization(self, states, internals, actions, terminal, reward, next_states=None, next_internals=None):
227240
update = tf.constant(value=True)
241+
228242
# Predict actions from target actor
229-
target_actions, target_internals = self.fn_target_actions_and_internals(
243+
next_target_actions, next_target_internals = self.fn_target_actions_and_internals(
230244
states=next_states, internals=next_internals, deterministic=True
231245
)
232246

233-
predicted_q = self.fn_predict_target_q(states=next_states, internals=next_internals,
234-
actions=target_actions, terminal=terminal, reward=reward, update=update)
247+
# Predicted Q value of next states
248+
predicted_q = self.fn_predict_target_q(
249+
states=next_states, internals=next_internals, actions=next_target_actions, terminal=terminal,
250+
reward=reward, update=update
251+
)
235252
predicted_q = tf.stop_gradient(input=predicted_q)
236253

237254
real_q = self.critic.apply(dict(states=states, actions=actions), internals=internals, update=update)
@@ -264,7 +281,7 @@ def fn_critic_loss(predicted_q, real_q):
264281
next_internals=next_internals
265282
)
266283

267-
# Update target network and baseline
284+
# Update target actor (network) and critic
268285
network_distributions_variables = [
269286
variable for name in sorted(self.distributions)
270287
for variable in self.distributions[name].get_variables(include_nontrainable=False)
@@ -316,7 +333,7 @@ def get_variables(self, include_submodules=False, include_nontrainable=False):
316333
]
317334
model_variables += target_distributions_variables
318335

319-
target_critic_variables = self.target_critic.get_variables()
336+
target_critic_variables = self.target_critic.get_variables(include_nontrainable=include_nontrainable)
320337
model_variables += target_critic_variables
321338

322339
if include_nontrainable:

0 commit comments

Comments
 (0)