Skip to content

Commit b2be68d

Browse files
saran-tcopybara-github
authored andcommitted
Implement __getitem__ for Composer variations.
PiperOrigin-RevId: 307904417 Change-Id: I44f9f2ec70878d95a5f6341dac8dc4d03ffe6d0f
1 parent 4b1df5f commit b2be68d

3 files changed

Lines changed: 24 additions & 1 deletion

File tree

dm_control/composer/variation/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import operator
2424

2525
from dm_control.composer.variation import variation_values
26+
import numpy as np
2627
import six
2728

2829

@@ -82,6 +83,9 @@ def __pow__(self, other):
8283
def __rpow__(self, other):
8384
return _BinaryOperation(operator.pow, other, self)
8485

86+
def __getitem__(self, index):
87+
return _GetItemOperation(self, index)
88+
8589

8690
class _BinaryOperation(Variation):
8791
"""Represents the result of applying a binary operator to two Variations."""
@@ -97,3 +101,15 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
97101
second_value = variation_values.evaluate(
98102
self._second, initial_value, current_value, random_state)
99103
return self._op(first_value, second_value)
104+
105+
106+
class _GetItemOperation(Variation):
107+
108+
def __init__(self, variation, index):
109+
self._variation = variation
110+
self._index = index
111+
112+
def __call__(self, initial_value=None, current_value=None, random_state=None):
113+
value = variation_values.evaluate(
114+
self._variation, initial_value, current_value, random_state)
115+
return np.asarray(value)[self._index]

dm_control/composer/variation/variation_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
from absl.testing import parameterized
2525
from dm_control.composer import variation
2626
from dm_control.composer.variation import deterministic
27+
import numpy as np
2728

2829

2930
class VariationTest(parameterized.TestCase):
3031

3132
def setUp(self):
33+
super(VariationTest, self).setUp()
3234
self.value_1 = 3
3335
self.variation_1 = deterministic.Constant(self.value_1)
3436
self.value_2 = 5
@@ -47,6 +49,11 @@ def test_operator(self, name):
4749
variation.evaluate(func(self.variation_1, self.variation_2)),
4850
func(self.value_1, self.value_2))
4951

52+
def test_getitem(self):
53+
value = deterministic.Constant(np.array([4, 5, 6, 7, 8]))
54+
np.testing.assert_array_equal(
55+
variation.evaluate(value[[3, 1]]),
56+
[7, 5])
5057

5158
if __name__ == '__main__':
5259
absltest.main()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def find_data_files(package_dir, patterns):
166166

167167
setup(
168168
name='dm_control',
169-
version='0.0.307814181',
169+
version='0.0.307904417',
170170
description='Continuous control environments and MuJoCo Python bindings.',
171171
author='DeepMind',
172172
license='Apache License, Version 2.0',

0 commit comments

Comments
 (0)