Skip to content

Commit dd5a792

Browse files
hartikainen authored and alimuldal committed
dm_control: Import of refs/pull/121/head
Closes #121 PiperOrigin-RevId: 284276886 Change-Id: I3ac8abb8314f0cc8a15d740b516f06ec0f2ba1f2
1 parent 409a057 commit dd5a792

2 files changed

Lines changed: 270 additions & 0 deletions

File tree

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2019 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""Wrapper that scales actions to a specific range."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import dm_env
23+
from dm_env import specs
24+
import numpy as np
25+
26+
# Error message templates. They are module-level so that callers and tests can
# rebuild the exact message text with `str.format`.

# Positional field: the offending value returned by `env.action_spec()`.
_ACTION_SPEC_MUST_BE_BOUNDED_ARRAY = (
    "`env.action_spec()` must return a single "
    "`BoundedArray`, got: {}.")

# Fields: `name` of the bounds argument and the offending `bounds` values.
_MUST_BE_FINITE = (
    "All values in `{name}` must be finite, "
    "got: {bounds}.")

# Fields: `name` of the bounds argument, target `shape`, offending `bounds`.
_MUST_BROADCAST = (
    "`{name}` must be broadcastable to shape {shape}, "
    "got: {bounds}.")
31+
32+
33+
class Wrapper(dm_env.Environment):
  """Wraps a control environment to rescale actions to a specific range.

  Actions passed to `step` are expected to lie in `[minimum, maximum]` and are
  mapped affinely onto the bounds of the wrapped environment's `action_spec`
  before being forwarded. Attributes that are not defined on the wrapper are
  looked up on the wrapped environment.
  """
  __slots__ = ("_action_spec", "_env", "_transform")

  def __init__(self, env, minimum, maximum):
    """Initializes a new action scale Wrapper.

    Args:
      env: Instance of `dm_env.Environment` to wrap. Its `action_spec` must
        consist of a single `BoundedArray` with all-finite bounds.
      minimum: Scalar or array-like specifying element-wise lower bounds
        (inclusive) for the `action_spec` of the wrapped environment. Must be
        finite and broadcastable to the shape of the `action_spec`.
      maximum: Scalar or array-like specifying element-wise upper bounds
        (inclusive) for the `action_spec` of the wrapped environment. Must be
        finite and broadcastable to the shape of the `action_spec`.

    Raises:
      ValueError: If `env.action_spec()` is not a single `BoundedArray`.
      ValueError: If `env.action_spec()` has non-finite bounds.
      ValueError: If `minimum` or `maximum` contain non-finite values.
      ValueError: If `minimum` or `maximum` are not broadcastable to
        `env.action_spec().shape`.
      ValueError: If any element of `maximum` equals the corresponding element
        of `minimum`, since the rescaling would then divide by zero.
    """
    action_spec = env.action_spec()
    if not isinstance(action_spec, specs.BoundedArray):
      raise ValueError(_ACTION_SPEC_MUST_BE_BOUNDED_ARRAY.format(action_spec))

    minimum = np.array(minimum)
    maximum = np.array(maximum)
    shape = action_spec.shape
    orig_minimum = action_spec.minimum
    orig_maximum = action_spec.maximum
    orig_dtype = action_spec.dtype

    def validate(bounds, name):
      """Raises `ValueError` unless `bounds` is finite and fits `shape`."""
      if not np.all(np.isfinite(bounds)):
        raise ValueError(_MUST_BE_FINITE.format(name=name, bounds=bounds))
      try:
        np.broadcast_to(bounds, shape)
      except ValueError:
        raise ValueError(_MUST_BROADCAST.format(
            name=name, bounds=bounds, shape=shape))

    validate(minimum, "minimum")
    validate(maximum, "maximum")
    validate(orig_minimum, "env.action_spec().minimum")
    validate(orig_maximum, "env.action_spec().maximum")

    # A zero-width target range makes the scale factor below a division by
    # zero, which would silently produce inf/nan actions. Fail loudly instead.
    if np.any(maximum == minimum):
      raise ValueError(
          "`maximum` must differ from `minimum` in every element, got "
          "minimum: {}, maximum: {}.".format(minimum, maximum))

    # Ratio between the width of the original range and the new range.
    scale = (orig_maximum - orig_minimum) / (maximum - minimum)

    def transform(action):
      """Maps `action` from the new range onto the original range."""
      new_action = orig_minimum + scale * (action - minimum)
      # Hand the wrapped environment the dtype its own spec declares.
      return new_action.astype(orig_dtype, copy=False)

    dtype = np.result_type(minimum, maximum, orig_dtype)
    self._action_spec = action_spec.replace(
        minimum=minimum, maximum=maximum, dtype=dtype)
    self._env = env
    self._transform = transform

  def step(self, action):
    """Rescales `action` and steps the wrapped environment with it."""
    return self._env.step(self._transform(action))

  def reset(self):
    """Resets the wrapped environment and returns its result."""
    return self._env.reset()

  def observation_spec(self):
    """Returns the observation spec of the wrapped environment."""
    return self._env.observation_spec()

  def action_spec(self):
    """Returns the action spec with the rescaled bounds."""
    return self._action_spec

  def __getattr__(self, name):
    """Delegates lookups that fail on the wrapper to the wrapped environment.

    Raises:
      AttributeError: If `name` is one of the wrapper's own slots and that slot
        has not been set, or if the wrapped environment lacks the attribute.
    """
    # `__getattr__` only runs after normal lookup has failed. If one of our own
    # slots is unset (e.g. the instance was created without running
    # `__init__`), reading `self._env` below would re-enter this method and
    # recurse without bound, so report the missing attribute directly.
    if name in Wrapper.__slots__:
      raise AttributeError(name)
    return getattr(self._env, name)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright 2019 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""Tests for the action scale wrapper."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
# Internal dependencies.
23+
from absl.testing import absltest
24+
from absl.testing import parameterized
25+
from dm_control.rl import control
26+
from dm_control.suite.wrappers import action_scale
27+
from dm_env import specs
28+
import mock
29+
import numpy as np
30+
31+
32+
def make_action_spec(lower=(-1.,), upper=(1.,)):
  """Builds a float `BoundedArray` whose bounds are `lower` and `upper`.

  The two bounds are broadcast against each other first; the resulting common
  shape is used as the shape of the spec.
  """
  minimum, maximum = np.broadcast_arrays(lower, upper)
  spec_kwargs = dict(
      shape=minimum.shape, dtype=float, minimum=minimum, maximum=maximum)
  return specs.BoundedArray(**spec_kwargs)
36+
37+
38+
def make_mock_env(action_spec):
  """Returns a mock `control.Environment` whose `action_spec()` is fixed.

  Args:
    action_spec: Value that the mock's `action_spec()` method should return.

  Returns:
    A `mock.Mock` created with `spec=control.Environment`.
  """
  env = mock.Mock(spec=control.Environment)
  env.action_spec.return_value = action_spec
  return env
43+
44+
45+
class ActionScaleTest(parameterized.TestCase):
  """Tests for `action_scale.Wrapper` using a mocked environment."""

  def assertStepCalledOnceWithCorrectAction(self, env, expected_action):
    """Asserts `env.step` was called exactly once with `expected_action`."""
    # NB: `assert_called_once_with()` doesn't support numpy arrays.
    env.step.assert_called_once()
    actual_action = env.step.call_args_list[0][0][0]
    np.testing.assert_array_equal(expected_action, actual_action)

  @parameterized.parameters(
      {
          'minimum': np.r_[-1., -1.],
          'maximum': np.r_[1., 1.],
          'scaled_minimum': np.r_[-2., -2.],
          'scaled_maximum': np.r_[2., 2.],
      },
      {
          'minimum': np.r_[-2., -2.],
          'maximum': np.r_[2., 2.],
          'scaled_minimum': np.r_[-1., -1.],
          'scaled_maximum': np.r_[1., 1.],
      },
      {
          'minimum': np.r_[-1., -1.],
          'maximum': np.r_[1., 1.],
          'scaled_minimum': np.r_[-2., -2.],
          'scaled_maximum': np.r_[1., 1.],
      },
      {
          'minimum': np.r_[-1., -1.],
          'maximum': np.r_[1., 1.],
          'scaled_minimum': np.r_[-1., -1.],
          'scaled_maximum': np.r_[2., 2.],
      },
  )
  def test_step(self, minimum, maximum, scaled_minimum, scaled_maximum):
    """The ends of the scaled range map onto the ends of the original range."""
    action_spec = make_action_spec(lower=minimum, upper=maximum)
    env = make_mock_env(action_spec=action_spec)
    wrapped_env = action_scale.Wrapper(
        env, minimum=scaled_minimum, maximum=scaled_maximum)

    time_step = wrapped_env.step(scaled_minimum)
    self.assertStepCalledOnceWithCorrectAction(env, minimum)
    self.assertIs(time_step, env.step(minimum))

    # Clear the recorded calls so the "called once" check can be reused.
    env.reset_mock()

    time_step = wrapped_env.step(scaled_maximum)
    self.assertStepCalledOnceWithCorrectAction(env, maximum)
    self.assertIs(time_step, env.step(maximum))

  @parameterized.parameters(
      {
          'minimum': np.r_[-1., -1.],
          'maximum': np.r_[1., 1.],
      },
      {
          'minimum': np.r_[0, 1],
          'maximum': np.r_[2, 3],
      },
  )
  def test_correct_action_spec(self, minimum, maximum):
    """The wrapper's `action_spec` reports the requested bounds."""
    original_action_spec = make_action_spec(
        lower=np.r_[-2., -2.], upper=np.r_[2., 2.])
    env = make_mock_env(action_spec=original_action_spec)
    wrapped_env = action_scale.Wrapper(env, minimum=minimum, maximum=maximum)
    new_action_spec = wrapped_env.action_spec()
    np.testing.assert_array_equal(new_action_spec.minimum, minimum)
    np.testing.assert_array_equal(new_action_spec.maximum, maximum)

  @parameterized.parameters('reset', 'observation_spec', 'control_timestep')
  def test_method_delegated_to_underlying_env(self, method_name):
    """Calling the named method on the wrapper calls it on the wrapped env."""
    env = make_mock_env(action_spec=make_action_spec())
    wrapped_env = action_scale.Wrapper(env, minimum=0, maximum=1)
    env_method = getattr(env, method_name)
    wrapper_method = getattr(wrapped_env, method_name)
    out = wrapper_method()
    env_method.assert_called_once_with()
    self.assertIs(out, env_method())

  def test_invalid_action_spec_type(self):
    """A list of specs instead of a single `BoundedArray` is rejected."""
    action_spec = [make_action_spec()] * 2
    env = make_mock_env(action_spec=action_spec)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        action_scale._ACTION_SPEC_MUST_BE_BOUNDED_ARRAY.format(action_spec)):
      action_scale.Wrapper(env, minimum=0, maximum=1)

  @parameterized.parameters(
      {'name': 'minimum', 'bounds': np.r_[np.nan]},
      {'name': 'minimum', 'bounds': np.r_[-np.inf]},
      {'name': 'maximum', 'bounds': np.r_[np.inf]},
  )
  def test_non_finite_bounds(self, name, bounds):
    """NaN or infinite values in `minimum`/`maximum` are rejected."""
    kwargs = {'minimum': np.r_[-1.], 'maximum': np.r_[1.]}
    kwargs[name] = bounds
    env = make_mock_env(action_spec=make_action_spec())
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        action_scale._MUST_BE_FINITE.format(name=name, bounds=bounds)):
      action_scale.Wrapper(env, **kwargs)

  @parameterized.parameters(
      {'name': 'minimum', 'bounds': np.r_[1., 2., 3.]},
      # NB: `np.r_[[1.], [2.], [3.]]` concatenates into a flat shape-(3,) array
      # and would merely repeat the case above, so build the (3, 1) column
      # explicitly.
      {'name': 'minimum', 'bounds': np.array([[1.], [2.], [3.]])},
      {'name': 'maximum', 'bounds': np.r_[1., 2., 3.]},
  )
  def test_invalid_bounds_shape(self, name, bounds):
    """Bounds that cannot broadcast to the action shape are rejected."""
    shape = (2,)
    kwargs = {'minimum': np.zeros(shape), 'maximum': np.ones(shape)}
    kwargs[name] = bounds
    action_spec = make_action_spec(lower=[-1, -1], upper=[2, 3])
    env = make_mock_env(action_spec=action_spec)
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        action_scale._MUST_BROADCAST.format(
            name=name, bounds=bounds, shape=shape)):
      action_scale.Wrapper(env, **kwargs)
161+
162+
# Run the absl test runner only when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

0 commit comments

Comments
 (0)