2222from tensorforce import util , TensorForceError
2323from tensorforce .models import DistributionModel
2424
25- from tensorforce .core .networks import Network , LayerBasedNetwork , Dense
25+ from tensorforce .core .networks import Network , LayerBasedNetwork , Dense , Linear , TFLayer , Nonlinearity
2626from tensorforce .core .optimizers import Optimizer , Synchronization
2727
2828
2929class DDPGCriticNetwork (LayerBasedNetwork ):
def __init__(self, scope='ddpg-critic-network', summary_labels=(), size_t0=400, size_t1=300):
    """
    Build the layer stack of the DDPG critic network.

    The stack consists of two Linear -> batch-normalization -> ReLU stages
    followed by a single-unit Dense output layer with tanh activation.

    Args:
        scope: Scope name forwarded to the parent network constructor.
        summary_labels: Summary labels forwarded to the parent network constructor.
        size_t0: Size of the first Linear layer.
        size_t1: Size of the second Linear layer.
    """
    super(DDPGCriticNetwork, self).__init__(scope=scope, summary_labels=summary_labels)

    # First stage: linear projection, batch normalization, ReLU.
    # Scope/activation names carry no surrounding whitespace so that both
    # stages are configured the same way.
    self.t0l = Linear(size=size_t0, scope='linear0')
    self.t0b = TFLayer(layer='batch_normalization', scope='batchnorm0', center=True, scale=True)
    self.t0n = Nonlinearity(name='relu', scope='relu0')

    # Second stage: same structure, sized by size_t1.
    self.t1l = Linear(size=size_t1, scope='linear1')
    self.t1b = TFLayer(layer='batch_normalization', scope='batchnorm1', center=True, scale=True)
    self.t1n = Nonlinearity(name='relu', scope='relu1')

    # Output layer: one unit, weights drawn uniformly from [-3e-3, 3e-3].
    self.t2d = Dense(
        size=1,
        activation='tanh',
        scope='dense0',
        weights=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3),
    )

    # Register the layers with the parent network in application order.
    self.add_layer(self.t0l)
    self.add_layer(self.t0b)
    self.add_layer(self.t0n)

    self.add_layer(self.t1l)
    self.add_layer(self.t1b)
    self.add_layer(self.t1n)

    self.add_layer(self.t2d)
4053
4154 def tf_apply (self , x , internals , update , return_internals = False ):
4255 assert x ['states' ], x ['actions' ]
@@ -59,13 +72,21 @@ def tf_apply(self, x, internals, update, return_internals=False):
5972
6073 x_actions = tf .reshape (tf .cast (x_actions , dtype = tf .float32 ), (- 1 , 1 ))
6174
62- out = self .t0 .tf_apply (x = x_states , update = update )
75+ out = self .t0l .apply (x = x_states , update = update )
76+ out = self .t0b .apply (x = out , update = update )
77+ out = self .t0n .apply (x = out , update = update )
6378
64- out = self .t1 .tf_apply (x = tf .concat ([out , x_actions ], axis = - 1 ), update = update )
79+ out = self .t1l .apply (x = tf .concat ([out , x_actions ], axis = - 1 ), update = update )
80+ out = self .t1b .apply (x = out , update = update )
81+ out = self .t1n .apply (x = out , update = update )
6582
66- out = self .t2 . tf_apply (x = out , update = update )
83+ out = self .t2d . apply (x = out , update = update )
6784
68- return out
85+ if return_internals :
86+ # Todo: Internals management
87+ return out , None
88+ else :
89+ return out
6990
7091
7192class DPGTargetModel (DistributionModel ):
@@ -139,6 +160,7 @@ def __init__(
139160 )
140161
141162 assert self .memory_spec ["include_next_states" ]
163+ assert self .requires_deterministic == True
142164
143165 def initialize (self , custom_getter ):
144166 super (DPGTargetModel , self ).initialize (custom_getter )
@@ -159,10 +181,6 @@ def initialize(self, custom_getter):
159181 self .target_distributions = self .create_distributions ()
160182
161183 # Critic
162- # self.critic = Network.from_spec(
163- # spec=self.critic_network_spec,
164- # kwargs=dict(scope='critic', summary_labels=self.summary_labels)
165- # )
166184 size_t0 = self .critic_network_spec ['size_t0' ]
167185 size_t1 = self .critic_network_spec ['size_t1' ]
168186
@@ -172,10 +190,6 @@ def initialize(self, custom_getter):
172190 kwargs = dict (summary_labels = self .summary_labels )
173191 )
174192
175- # self.target_critic = Network.from_spec(
176- # spec=self.critic_network_spec,
177- # kwargs=dict(scope='target-critic', summary_labels=self.summary_labels)
178- # )
179193 self .target_critic = DDPGCriticNetwork (scope = 'target-critic' , size_t0 = size_t0 , size_t1 = size_t1 )
180194
181195 # Target critic optimizer
@@ -216,7 +230,6 @@ def tf_target_actions_and_internals(self, states, internals, deterministic=True)
216230
def tf_loss_per_instance(self, states, internals, actions, terminal, reward, next_states, next_internals, update, reference=None):
    """
    Return the negated critic output for the given states and actions.

    Only `states`, `actions`, `internals` and `update` are passed on to the
    critic; the remaining arguments are accepted but not used here.
    """
    critic_input = {'states': states, 'actions': actions}
    q_value = self.critic.apply(critic_input, internals=internals, update=update)
    return -q_value
221234
222235 def tf_predict_target_q (self , states , internals , terminal , actions , reward , update ):
@@ -225,13 +238,17 @@ def tf_predict_target_q(self, states, internals, terminal, actions, reward, upda
225238
226239 def tf_optimization (self , states , internals , actions , terminal , reward , next_states = None , next_internals = None ):
227240 update = tf .constant (value = True )
241+
228242 # Predict actions from target actor
229- target_actions , target_internals = self .fn_target_actions_and_internals (
243+ next_target_actions , next_target_internals = self .fn_target_actions_and_internals (
230244 states = next_states , internals = next_internals , deterministic = True
231245 )
232246
233- predicted_q = self .fn_predict_target_q (states = next_states , internals = next_internals ,
234- actions = target_actions , terminal = terminal , reward = reward , update = update )
247+ # Predicted Q value of next states
248+ predicted_q = self .fn_predict_target_q (
249+ states = next_states , internals = next_internals , actions = next_target_actions , terminal = terminal ,
250+ reward = reward , update = update
251+ )
235252 predicted_q = tf .stop_gradient (input = predicted_q )
236253
237254 real_q = self .critic .apply (dict (states = states , actions = actions ), internals = internals , update = update )
@@ -264,7 +281,7 @@ def fn_critic_loss(predicted_q, real_q):
264281 next_internals = next_internals
265282 )
266283
267- # Update target network and baseline
284+ # Update target actor ( network) and critic
268285 network_distributions_variables = [
269286 variable for name in sorted (self .distributions )
270287 for variable in self .distributions [name ].get_variables (include_nontrainable = False )
@@ -316,7 +333,7 @@ def get_variables(self, include_submodules=False, include_nontrainable=False):
316333 ]
317334 model_variables += target_distributions_variables
318335
319- target_critic_variables = self .target_critic .get_variables ()
336+ target_critic_variables = self .target_critic .get_variables (include_nontrainable = include_nontrainable )
320337 model_variables += target_critic_variables
321338
322339 if include_nontrainable :
0 commit comments