Skip to content

Commit a2ad4e3

Browse files
Leonard Hasenclevercopybara-github
authored andcommitted
Variation broadcaster that allows a variation to be broadcasted to multiple callers.
PiperOrigin-RevId: 683243432 Change-Id: I6ed10e3cb14a39fa5d82886d6dab9faa0ce634d9
1 parent 2456cfa commit a2ad4e3

2 files changed

Lines changed: 169 additions & 0 deletions

File tree

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2024 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""A broadcaster that allows sharing of variation values across many callers."""
17+
18+
import collections
19+
import weakref
20+
21+
from dm_control.composer import variation
22+
23+
24+
class VariationBroadcaster:
25+
"""Allows a variation to be broadcasted to multiple callers.
26+
27+
This class wraps a `Variation` object and generates multiple proxies that
28+
can be used in place of the wrapped `Variation`. The broadcaster updates its
29+
value in rounds. At the beginning of each round, the broadcaster re-evaluates
30+
the wrapped `Variation` and caches the new value internally. When a proxy
31+
is called, the broadcaster will return this cached value, thus ensuring that
32+
all proxied values are the same. The round ends when all of the proxies have
33+
been called exactly once. It is an error to call any particular proxy more
34+
than once per round.
35+
"""
36+
37+
def __init__(self, wrapped_variation: variation.Variation):
38+
self._wrapped_variation = wrapped_variation
39+
self._cached_values = weakref.WeakKeyDictionary()
40+
41+
def get_proxy(self) -> variation.Variation:
42+
"""Returns a `Variation` to be used in place of the wrapped `Variation`."""
43+
new_proxy = _BroadcastedValueProxy(self)
44+
self._cached_values[new_proxy] = collections.deque()
45+
return new_proxy
46+
47+
def _get_value(self, proxy, random_state):
48+
"""Returns the variation value for a proxy owned by this broadcaster."""
49+
cached_values = self._cached_values[proxy]
50+
if not cached_values:
51+
new_value = variation.evaluate(
52+
self._wrapped_variation, None, None, random_state)
53+
for values in self._cached_values.values():
54+
values.append(new_value)
55+
return cached_values.popleft()
56+
57+
58+
class _BroadcastedValueProxy(variation.Variation):
59+
60+
def __init__(self, broadcaster):
61+
self._broadcaster = broadcaster
62+
63+
def __call__(self, initial_value=None, current_value=None, random_state=None):
64+
value = self._broadcaster._get_value(self, random_state) # pylint: disable=protected-access
65+
return value
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2024 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
from absl.testing import absltest
17+
from dm_control.composer import variation
18+
from dm_control.composer.variation import distributions
19+
from dm_control.composer.variation import variation_broadcaster
20+
import numpy as np
21+
22+
23+
class VariationBroadcasterTest(absltest.TestCase):
24+
25+
def test_can_generate_values(self):
26+
random_state = np.random.RandomState(2348)
27+
expected_values = [random_state.uniform(0, 1) for _ in range(5)]
28+
29+
random_state = np.random.RandomState(2348)
30+
broadcaster = variation_broadcaster.VariationBroadcaster(
31+
distributions.Uniform(0, 1)
32+
)
33+
proxy_1 = broadcaster.get_proxy()
34+
proxy_2 = broadcaster.get_proxy()
35+
proxy_3 = broadcaster.get_proxy()
36+
37+
self.assertEqual(
38+
variation.evaluate(proxy_1, random_state=random_state),
39+
expected_values[0],
40+
)
41+
self.assertEqual(
42+
variation.evaluate(proxy_2, random_state=random_state),
43+
expected_values[0],
44+
)
45+
self.assertEqual(
46+
variation.evaluate(proxy_3, random_state=random_state),
47+
expected_values[0],
48+
)
49+
50+
self.assertEqual(
51+
variation.evaluate(proxy_1, random_state=random_state),
52+
expected_values[1],
53+
)
54+
self.assertEqual(
55+
variation.evaluate(proxy_1, random_state=random_state),
56+
expected_values[2],
57+
)
58+
59+
self.assertEqual(
60+
variation.evaluate(proxy_2, random_state=random_state),
61+
expected_values[1],
62+
)
63+
self.assertEqual(
64+
variation.evaluate(proxy_3, random_state=random_state),
65+
expected_values[1],
66+
)
67+
self.assertEqual(
68+
variation.evaluate(proxy_3, random_state=random_state),
69+
expected_values[2],
70+
)
71+
72+
self.assertEqual(
73+
variation.evaluate(proxy_3, random_state=random_state),
74+
expected_values[3],
75+
)
76+
self.assertEqual(
77+
variation.evaluate(proxy_1, random_state=random_state),
78+
expected_values[3],
79+
)
80+
self.assertEqual(
81+
variation.evaluate(proxy_2, random_state=random_state),
82+
expected_values[2],
83+
)
84+
85+
self.assertEqual(
86+
variation.evaluate(proxy_1, random_state=random_state),
87+
expected_values[4],
88+
)
89+
self.assertEqual(
90+
variation.evaluate(proxy_2, random_state=random_state),
91+
expected_values[3],
92+
)
93+
self.assertEqual(
94+
variation.evaluate(proxy_2, random_state=random_state),
95+
expected_values[4],
96+
)
97+
self.assertEqual(
98+
variation.evaluate(proxy_3, random_state=random_state),
99+
expected_values[4],
100+
)
101+
102+
103+
if __name__ == '__main__':
104+
absltest.main()

0 commit comments

Comments
 (0)