2222from tensorforce import util
2323from tensorforce .models import PGLogProbModel
2424
25+ from tensorforce .core .baselines import Baseline , AggregatedBaseline
2526from tensorforce .core .networks import Network
2627from tensorforce .core .optimizers import Synchronization
2728
@@ -67,6 +68,8 @@ def __init__(
6768 self .target_network = None
6869 self .target_optimizer = None
6970 self .target_distributions = None
71+ self .target_baseline = None
72+ self .target_baseline_optimizer = None
7073
7174 super (PGLogProbModel , self ).__init__ (
7275 states = states ,
@@ -112,18 +115,73 @@ def initialize(self, custom_getter):
112115 # Target network distributions
113116 self .target_distributions = self .create_distributions ()
114117
115- def tf_pg_loss_per_instance (self , states , internals , actions , terminal , reward , next_states , next_internals , update ):
116- embedding = self .target_network .apply (x = states , internals = internals , update = update )
117- log_probs = list ()
118+ # Target baseline
119+ if self .baseline_mode :
120+ if all (name in self .states_spec for name in self .baseline_spec ):
121+ # Implies AggregatedBaseline.
122+ assert self .baseline_mode == 'states'
123+ self .target_baseline = AggregatedBaseline (baselines = self .baseline_spec )
124+ else :
125+ self .target_baseline = Baseline .from_spec (
126+ spec = self .baseline_spec ,
127+ kwargs = dict (
128+ summary_labels = self .summary_labels ,
129+ scope = 'target_baseline'
130+ )
131+ )
132+
133+ # Target baseline optimizer
134+ self .target_baseline_optimizer = Synchronization (
135+ sync_frequency = self .target_sync_frequency ,
136+ update_weight = self .target_update_weight
137+ )
138+
139+ def tf_reward_estimation (self , states , internals , terminal , reward , update ):
140+ if self .baseline_mode is None :
141+ reward = self .fn_discounted_cumulative_reward (terminal = terminal , reward = reward , discount = self .discount )
118142
119- for name , distribution in self .target_distributions .items ():
120- distr_params = distribution .parameterize (x = embedding )
121- log_prob = distribution .log_probability (distr_params = distr_params , action = actions [name ])
122- collapsed_size = util .prod (util .shape (log_prob )[1 :])
123- log_prob = tf .reshape (tensor = log_prob , shape = (- 1 , collapsed_size ))
124- log_probs .append (log_prob )
125- log_prob = tf .reduce_mean (input_tensor = tf .concat (values = log_probs , axis = 1 ), axis = 1 )
126- return - log_prob * reward
143+ else :
144+ assert self .target_baseline
145+ if self .baseline_mode == 'states' :
146+ state_value = self .target_baseline .predict (
147+ states = states ,
148+ internals = internals ,
149+ update = update
150+ )
151+
152+ elif self .baseline_mode == 'network' :
153+ embedding = self .target_network .apply (
154+ x = states ,
155+ internals = internals ,
156+ update = update
157+ )
158+ state_value = self .target_baseline .predict (
159+ states = embedding ,
160+ internals = internals ,
161+ update = update
162+ )
163+
164+ if self .gae_lambda is None :
165+ reward = self .fn_discounted_cumulative_reward (
166+ terminal = terminal ,
167+ reward = reward ,
168+ discount = self .discount
169+ )
170+ reward -= state_value
171+
172+ else :
173+ next_state_value = tf .concat (values = (state_value [1 :], (0.0 ,)), axis = 0 )
174+ zeros = tf .zeros_like (tensor = next_state_value )
175+ next_state_value = tf .where (condition = terminal , x = zeros , y = next_state_value )
176+ td_residual = reward + self .discount * next_state_value - state_value
177+ gae_discount = self .discount * self .gae_lambda
178+ reward = self .fn_discounted_cumulative_reward (
179+ terminal = terminal ,
180+ reward = td_residual ,
181+ discount = gae_discount
182+ )
183+
184+ return reward
127185
128186 def tf_optimization (self , states , internals , actions , terminal , reward , next_states = None , next_internals = None ):
129187 optimization = super (PGLogProbModel , self ).tf_optimization (
@@ -145,6 +203,14 @@ def tf_optimization(self, states, internals, actions, terminal, reward, next_sta
145203 source_variables = self .network .get_variables () + network_distributions_variables
146204 )
147205
206+ if self .target_baseline :
207+ target_baseline_optimization = self .target_baseline_optimizer .minimize (
208+ time = self .timestep ,
209+ variables = self .target_baseline .get_variables (),
210+ source_variables = self .baseline .get_variables ()
211+ )
212+ return tf .group (optimization , target_optimization , target_baseline_optimization )
213+
148214 return tf .group (optimization , target_optimization )
149215
150216 def get_variables (self , include_non_trainable = False ):
@@ -156,11 +222,16 @@ def get_variables(self, include_non_trainable=False):
156222 target_distributions_variables = self .get_distributions_variables (self .target_distributions )
157223 target_optimizer_variables = self .target_optimizer .get_variables ()
158224
225+ if self .target_baseline :
226+ target_baseline_variables = self .target_baseline .get_variables ()
227+ return model_variables + target_variables + target_optimizer_variables + \
228+ target_distributions_variables + target_baseline_variables
229+
159230 return model_variables + target_variables + target_optimizer_variables + target_distributions_variables
160231 else :
161232 return model_variables
162233
163234 def get_summaries (self ):
164235 target_distributions_summaries = self .get_distributions_summaries (self .target_distributions )
165236 return super (PGLogProbModel , self ).get_summaries () + self .target_network .get_summaries () \
166- + target_distributions_summaries
237+ + target_distributions_summaries
0 commit comments