Skip to content

Commit 9c0b199

Browse files
sbohezalimuldal
authored andcommitted
Pass random_state to observable attributes that are Variation objects.
PiperOrigin-RevId: 317105039 Change-Id: I82c551e1ad99f726c56f96bb9ec189c87c106f8e
1 parent ee97ac1 commit 9c0b199

2 files changed

Lines changed: 34 additions & 13 deletions

File tree

dm_control/composer/observation/updater.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from __future__ import print_function
2121

2222
import collections
23+
import functools
2324
from absl import logging
2425

26+
from dm_control.composer import variation
2527
from dm_control.composer.observation import obs_buffer
2628
from dm_env import specs
2729
import numpy as np
@@ -37,6 +39,7 @@ class _EnabledObservable(object):
3739
"""Encapsulates an enabled observable, its buffer, and its update schedule."""
3840

3941
__slots__ = ('observable', 'observation_callable',
42+
'update_interval', 'delay', 'buffer_size',
4043
'buffer', 'update_schedule')
4144

4245
def __init__(self, observable, physics, random_state,
@@ -45,6 +48,16 @@ def __init__(self, observable, physics, random_state,
4548
self.observation_callable = (
4649
observable.observation_callable(physics, random_state))
4750

51+
self._bind_attribute_from_observable('update_interval',
52+
DEFAULT_UPDATE_INTERVAL,
53+
random_state)
54+
self._bind_attribute_from_observable('delay',
55+
DEFAULT_DELAY,
56+
random_state)
57+
self._bind_attribute_from_observable('buffer_size',
58+
DEFAULT_BUFFER_SIZE,
59+
random_state)
60+
4861
obs_spec = self.observable.array_spec
4962
if obs_spec is None:
5063
# We take an observation to determine the shape and dtype of the array.
@@ -58,11 +71,22 @@ def __init__(self, observable, physics, random_state,
5871
obs_array = np.asarray(obs_array)
5972
obs_spec = specs.Array(shape=obs_array.shape, dtype=obs_array.dtype)
6073
self.buffer = obs_buffer.Buffer(
61-
buffer_size=(observable.buffer_size or DEFAULT_BUFFER_SIZE),
74+
buffer_size=self.buffer_size,
6275
shape=obs_spec.shape, dtype=obs_spec.dtype,
6376
strip_singleton_buffer_dim=strip_singleton_buffer_dim)
6477
self.update_schedule = collections.deque()
6578

79+
def _bind_attribute_from_observable(self, attr, default_value, random_state):
80+
obs_attr = getattr(self.observable, attr)
81+
if obs_attr:
82+
if isinstance(obs_attr, variation.Variation):
83+
setattr(self, attr,
84+
functools.partial(obs_attr, random_state=random_state))
85+
else:
86+
setattr(self, attr, obs_attr)
87+
else:
88+
setattr(self, attr, default_value)
89+
6690

6791
def _call_if_callable(arg):
6892
if callable(arg):
@@ -137,7 +161,7 @@ def make_buffers_dict(observables):
137161

138162
self._step_counter = 0
139163
for enabled in self._enabled_list:
140-
first_delay = _call_if_callable(enabled.observable.delay or DEFAULT_DELAY)
164+
first_delay = _call_if_callable(enabled.delay)
141165
enabled.buffer.insert(
142166
0, first_delay,
143167
enabled.observation_callable())
@@ -225,14 +249,11 @@ def prepare_for_next_control_step(self):
225249
if self._enabled_structure is None:
226250
raise RuntimeError('`reset` must be called before `before_step`.')
227251
for enabled in self._enabled_list:
228-
update_interval = (
229-
enabled.observable.update_interval or DEFAULT_UPDATE_INTERVAL)
230-
delay = enabled.observable.delay or DEFAULT_DELAY
231-
buffer_size = enabled.observable.buffer_size or DEFAULT_BUFFER_SIZE
232-
233-
if (update_interval == DEFAULT_UPDATE_INTERVAL and delay == DEFAULT_DELAY
234-
and buffer_size < self._physics_steps_per_control_step):
235-
for i in reversed(range(buffer_size)):
252+
253+
if (enabled.update_interval == DEFAULT_UPDATE_INTERVAL
254+
and enabled.delay == DEFAULT_DELAY
255+
and enabled.buffer_size < self._physics_steps_per_control_step):
256+
for i in reversed(range(enabled.buffer_size)):
236257
next_step = (
237258
self._step_counter + self._physics_steps_per_control_step - i)
238259
next_delay = DEFAULT_DELAY
@@ -244,9 +265,9 @@ def prepare_for_next_control_step(self):
244265
last_scheduled_step = self._step_counter
245266
max_step = self._step_counter + 2 * self._physics_steps_per_control_step
246267
while last_scheduled_step < max_step:
247-
next_update_interval = _call_if_callable(update_interval)
268+
next_update_interval = _call_if_callable(enabled.update_interval)
248269
next_step = last_scheduled_step + next_update_interval
249-
next_delay = _call_if_callable(delay)
270+
next_delay = _call_if_callable(enabled.delay)
250271
enabled.update_schedule.append((next_step, next_delay))
251272
last_scheduled_step = next_step
252273
# Optimize the schedule by planning ahead and dropping unseen entries.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def find_data_files(package_dir, patterns):
166166

167167
setup(
168168
name='dm_control',
169-
version='0.0.316448566',
169+
version='0.0.317105039',
170170
description='Continuous control environments and MuJoCo Python bindings.',
171171
author='DeepMind',
172172
license='Apache License, Version 2.0',

0 commit comments

Comments
 (0)