Skip to content

Commit 38dd34e

Browse files
committed
added target baseline to ddpg
1 parent 9382c6c commit 38dd34e

4 files changed

Lines changed: 87 additions & 16 deletions

File tree

examples/configs/ddpg.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"update_mode": {
55
"unit": "timesteps",
66
"batch_size": 200,
7-
"frequency": 1
7+
"frequency": 200
88
},
99
"memory": {
1010
"type": "replay",

examples/configs/vpg.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"update_mode": {
55
"unit": "episodes",
66
"batch_size": 20,
7-
"frequency": 1
7+
"frequency": 20
88
},
99
"memory": {
1010
"type": "latest",

tensorforce/agents/ddpg_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
class DDPGAgent(LearningAgent):
2525
"""
26-
Deep Deterministic Policy Gradient agent as described by [Silver et al. (2014)]
27-
(http://proceedings.mlr.press/v32/silver14.pdf).
26+
Deep Deterministic Policy Gradient agent as described by [Lillicrap et al. (2016)]
27+
(https://arxiv.org/pdf/1509.02971.pdf).
2828
2929
"""
3030

tensorforce/models/pg_log_prob_target_model.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorforce import util
2323
from tensorforce.models import PGLogProbModel
2424

25+
from tensorforce.core.baselines import Baseline, AggregatedBaseline
2526
from tensorforce.core.networks import Network
2627
from tensorforce.core.optimizers import Synchronization
2728

@@ -67,6 +68,8 @@ def __init__(
6768
self.target_network = None
6869
self.target_optimizer = None
6970
self.target_distributions = None
71+
self.target_baseline = None
72+
self.target_baseline_optimizer = None
7073

7174
super(PGLogProbModel, self).__init__(
7275
states=states,
@@ -112,18 +115,73 @@ def initialize(self, custom_getter):
112115
# Target network distributions
113116
self.target_distributions = self.create_distributions()
114117

115-
def tf_pg_loss_per_instance(self, states, internals, actions, terminal, reward, next_states, next_internals, update):
116-
embedding = self.target_network.apply(x=states, internals=internals, update=update)
117-
log_probs = list()
118+
# Target baseline
119+
if self.baseline_mode:
120+
if all(name in self.states_spec for name in self.baseline_spec):
121+
# Implies AggregatedBaseline.
122+
assert self.baseline_mode == 'states'
123+
self.target_baseline = AggregatedBaseline(baselines=self.baseline_spec)
124+
else:
125+
self.target_baseline = Baseline.from_spec(
126+
spec=self.baseline_spec,
127+
kwargs=dict(
128+
summary_labels=self.summary_labels,
129+
scope='target_baseline'
130+
)
131+
)
132+
133+
# Target baseline optimizer
134+
self.target_baseline_optimizer = Synchronization(
135+
sync_frequency=self.target_sync_frequency,
136+
update_weight=self.target_update_weight
137+
)
138+
139+
def tf_reward_estimation(self, states, internals, terminal, reward, update):
140+
if self.baseline_mode is None:
141+
reward = self.fn_discounted_cumulative_reward(terminal=terminal, reward=reward, discount=self.discount)
118142

119-
for name, distribution in self.target_distributions.items():
120-
distr_params = distribution.parameterize(x=embedding)
121-
log_prob = distribution.log_probability(distr_params=distr_params, action=actions[name])
122-
collapsed_size = util.prod(util.shape(log_prob)[1:])
123-
log_prob = tf.reshape(tensor=log_prob, shape=(-1, collapsed_size))
124-
log_probs.append(log_prob)
125-
log_prob = tf.reduce_mean(input_tensor=tf.concat(values=log_probs, axis=1), axis=1)
126-
return -log_prob * reward
143+
else:
144+
assert self.target_baseline
145+
if self.baseline_mode == 'states':
146+
state_value = self.target_baseline.predict(
147+
states=states,
148+
internals=internals,
149+
update=update
150+
)
151+
152+
elif self.baseline_mode == 'network':
153+
embedding = self.target_network.apply(
154+
x=states,
155+
internals=internals,
156+
update=update
157+
)
158+
state_value = self.target_baseline.predict(
159+
states=embedding,
160+
internals=internals,
161+
update=update
162+
)
163+
164+
if self.gae_lambda is None:
165+
reward = self.fn_discounted_cumulative_reward(
166+
terminal=terminal,
167+
reward=reward,
168+
discount=self.discount
169+
)
170+
reward -= state_value
171+
172+
else:
173+
next_state_value = tf.concat(values=(state_value[1:], (0.0,)), axis=0)
174+
zeros = tf.zeros_like(tensor=next_state_value)
175+
next_state_value = tf.where(condition=terminal, x=zeros, y=next_state_value)
176+
td_residual = reward + self.discount * next_state_value - state_value
177+
gae_discount = self.discount * self.gae_lambda
178+
reward = self.fn_discounted_cumulative_reward(
179+
terminal=terminal,
180+
reward=td_residual,
181+
discount=gae_discount
182+
)
183+
184+
return reward
127185

128186
def tf_optimization(self, states, internals, actions, terminal, reward, next_states=None, next_internals=None):
129187
optimization = super(PGLogProbModel, self).tf_optimization(
@@ -145,6 +203,14 @@ def tf_optimization(self, states, internals, actions, terminal, reward, next_sta
145203
source_variables=self.network.get_variables() + network_distributions_variables
146204
)
147205

206+
if self.target_baseline:
207+
target_baseline_optimization = self.target_baseline_optimizer.minimize(
208+
time=self.timestep,
209+
variables=self.target_baseline.get_variables(),
210+
source_variables=self.baseline.get_variables()
211+
)
212+
return tf.group(optimization, target_optimization, target_baseline_optimization)
213+
148214
return tf.group(optimization, target_optimization)
149215

150216
def get_variables(self, include_non_trainable=False):
@@ -156,11 +222,16 @@ def get_variables(self, include_non_trainable=False):
156222
target_distributions_variables = self.get_distributions_variables(self.target_distributions)
157223
target_optimizer_variables = self.target_optimizer.get_variables()
158224

225+
if self.target_baseline:
226+
target_baseline_variables = self.target_baseline.get_variables()
227+
return model_variables + target_variables + target_optimizer_variables + \
228+
target_distributions_variables + target_baseline_variables
229+
159230
return model_variables + target_variables + target_optimizer_variables + target_distributions_variables
160231
else:
161232
return model_variables
162233

163234
def get_summaries(self):
164235
target_distributions_summaries = self.get_distributions_summaries(self.target_distributions)
165236
return super(PGLogProbModel, self).get_summaries() + self.target_network.get_summaries() \
166-
+ target_distributions_summaries
237+
+ target_distributions_summaries

0 commit comments

Comments
 (0)