2020from __future__ import print_function
2121
2222import collections
23+ import functools
2324from absl import logging
2425
26+ from dm_control .composer import variation
2527from dm_control .composer .observation import obs_buffer
2628from dm_env import specs
2729import numpy as np
@@ -37,6 +39,7 @@ class _EnabledObservable(object):
3739 """Encapsulates an enabled observable, its buffer, and its update schedule."""
3840
3941 __slots__ = ('observable' , 'observation_callable' ,
42+ 'update_interval' , 'delay' , 'buffer_size' ,
4043 'buffer' , 'update_schedule' )
4144
4245 def __init__ (self , observable , physics , random_state ,
@@ -45,6 +48,16 @@ def __init__(self, observable, physics, random_state,
4548 self .observation_callable = (
4649 observable .observation_callable (physics , random_state ))
4750
51+ self ._bind_attribute_from_observable ('update_interval' ,
52+ DEFAULT_UPDATE_INTERVAL ,
53+ random_state )
54+ self ._bind_attribute_from_observable ('delay' ,
55+ DEFAULT_DELAY ,
56+ random_state )
57+ self ._bind_attribute_from_observable ('buffer_size' ,
58+ DEFAULT_BUFFER_SIZE ,
59+ random_state )
60+
4861 obs_spec = self .observable .array_spec
4962 if obs_spec is None :
5063 # We take an observation to determine the shape and dtype of the array.
@@ -58,11 +71,22 @@ def __init__(self, observable, physics, random_state,
5871 obs_array = np .asarray (obs_array )
5972 obs_spec = specs .Array (shape = obs_array .shape , dtype = obs_array .dtype )
6073 self .buffer = obs_buffer .Buffer (
61- buffer_size = ( observable .buffer_size or DEFAULT_BUFFER_SIZE ) ,
74+ buffer_size = self .buffer_size ,
6275 shape = obs_spec .shape , dtype = obs_spec .dtype ,
6376 strip_singleton_buffer_dim = strip_singleton_buffer_dim )
6477 self .update_schedule = collections .deque ()
6578
79+ def _bind_attribute_from_observable (self , attr , default_value , random_state ):
80+ obs_attr = getattr (self .observable , attr )
81+ if obs_attr :
82+ if isinstance (obs_attr , variation .Variation ):
83+ setattr (self , attr ,
84+ functools .partial (obs_attr , random_state = random_state ))
85+ else :
86+ setattr (self , attr , obs_attr )
87+ else :
88+ setattr (self , attr , default_value )
89+
6690
6791def _call_if_callable (arg ):
6892 if callable (arg ):
@@ -137,7 +161,7 @@ def make_buffers_dict(observables):
137161
138162 self ._step_counter = 0
139163 for enabled in self ._enabled_list :
140- first_delay = _call_if_callable (enabled .observable . delay or DEFAULT_DELAY )
164+ first_delay = _call_if_callable (enabled .delay )
141165 enabled .buffer .insert (
142166 0 , first_delay ,
143167 enabled .observation_callable ())
@@ -225,14 +249,11 @@ def prepare_for_next_control_step(self):
225249 if self ._enabled_structure is None :
226250 raise RuntimeError ('`reset` must be called before `before_step`.' )
227251 for enabled in self ._enabled_list :
228- update_interval = (
229- enabled .observable .update_interval or DEFAULT_UPDATE_INTERVAL )
230- delay = enabled .observable .delay or DEFAULT_DELAY
231- buffer_size = enabled .observable .buffer_size or DEFAULT_BUFFER_SIZE
232-
233- if (update_interval == DEFAULT_UPDATE_INTERVAL and delay == DEFAULT_DELAY
234- and buffer_size < self ._physics_steps_per_control_step ):
235- for i in reversed (range (buffer_size )):
252+
253+ if (enabled .update_interval == DEFAULT_UPDATE_INTERVAL
254+ and enabled .delay == DEFAULT_DELAY
255+ and enabled .buffer_size < self ._physics_steps_per_control_step ):
256+ for i in reversed (range (enabled .buffer_size )):
236257 next_step = (
237258 self ._step_counter + self ._physics_steps_per_control_step - i )
238259 next_delay = DEFAULT_DELAY
@@ -244,9 +265,9 @@ def prepare_for_next_control_step(self):
244265 last_scheduled_step = self ._step_counter
245266 max_step = self ._step_counter + 2 * self ._physics_steps_per_control_step
246267 while last_scheduled_step < max_step :
247- next_update_interval = _call_if_callable (update_interval )
268+ next_update_interval = _call_if_callable (enabled . update_interval )
248269 next_step = last_scheduled_step + next_update_interval
249- next_delay = _call_if_callable (delay )
270+ next_delay = _call_if_callable (enabled . delay )
250271 enabled .update_schedule .append ((next_step , next_delay ))
251272 last_scheduled_step = next_step
252273 # Optimize the schedule by planning ahead and dropping unseen entries.
0 commit comments