Skip to content

Commit ce8607a

Browse files
saran-tcopybara-github
authored andcommitted
Split out non-essential methods in base.Walker into legacy_base.Walker to keep the base API clean.
All current walkers have been modified to inherit from `legacy_base.Walker`. PiperOrigin-RevId: 285169387 Change-Id: Iaaf3d3ffa0ad9451d5b0a7723cbad11e360b3358
1 parent fbdc9fd commit ce8607a

5 files changed

Lines changed: 338 additions & 285 deletions

File tree

dm_control/locomotion/soccer/boxhead.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from dm_control import composer
2525
from dm_control import mjcf
26-
from dm_control.locomotion.walkers import base
26+
from dm_control.locomotion.walkers import legacy_base
2727
import numpy as np
2828
from PIL import Image
2929
import six
@@ -102,7 +102,7 @@ def _asset_png_with_background_rgba_bytes(asset_fname, background_rgba):
102102
return png_encoding
103103

104104

105-
class BoxHead(base.Walker):
105+
class BoxHead(legacy_base.Walker):
106106
"""A rollable and jumpable ball with a head."""
107107

108108
def _build(self,

dm_control/locomotion/walkers/base.py

Lines changed: 4 additions & 267 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,11 @@
2424

2525
from dm_control import composer
2626
from dm_control.composer.observation import observable
27-
from dm_control.locomotion.walkers import initializers
28-
from dm_control.mujoco.wrapper.mjbindings import mjlib
2927

3028
from dm_env import specs
3129
import numpy as np
3230
import six
3331

34-
_RANGEFINDER_SCALE = 10.0
35-
_TOUCH_THRESHOLD = 1e-3
36-
3732

3833
def _make_readonly_float64_copy(value):
3934
if np.isscalar(value):
@@ -74,22 +69,12 @@ def __new__(cls, qpos=None, xpos=(0, 0, 0), xquat=(1, 0, 0, 0)):
7469
class Walker(composer.Robot):
7570
"""Abstract base class for Walker robots."""
7671

77-
def _build(self, initializer=None):
78-
self._initializer = initializer or initializers.UprightInitializer()
79-
8072
def create_root_joints(self, attachment_frame):
8173
attachment_frame.add('freejoint')
8274

83-
@property
84-
def upright_pose(self):
85-
return WalkerPose()
86-
8775
def _build_observables(self):
8876
return WalkerObservables(self)
8977

90-
def reinitialize_pose(self, physics, random_state):
91-
self._initializer.initialize_pose(physics, self, random_state)
92-
9378
def transform_vec_to_egocentric_frame(self, physics, vec_in_world_frame):
9479
"""Linearly transforms a world-frame vector into walker's egocentric frame.
9580
@@ -165,138 +150,10 @@ def transform_xmat_to_egocentric_frame(self, physics, xmat):
165150
def root_body(self):
166151
raise NotImplementedError
167152

168-
def aliveness(self, physics):
169-
"""A measure of the aliveness of the walker.
170-
171-
Aliveness measure could be used for deciding on termination (ant flipped
172-
over and it's impossible for it to recover), or used as a shaping reward
173-
to maintain an alive pose that we desired (humanoids remaining upright).
174-
175-
Args:
176-
physics: an instance of `Physics`.
177-
178-
Returns:
179-
a `float` in the range of [-1., 0.] where -1 means not alive and 0. means
180-
alive. In walkers for which the concept of aliveness does not make sense,
181-
the default implementation is to always return 0.0.
182-
"""
183-
return 0.
184-
185-
@abc.abstractproperty
186-
def ground_contact_geoms(self):
187-
"""Geoms in this walker that are expected to be in contact with the ground.
188-
189-
This property is used by some tasks to determine contact-based failure
190-
termination. It should only contain geoms that are expected to be in
191-
contact with the ground during "normal" locomotion. For example, for a
192-
humanoid model, this property would be expected to contain only the geoms
193-
that make up the two feet.
194-
195-
Note that certain specialized tasks may also allow geoms that are not listed
196-
here to be in contact with the ground. For example, a humanoid cartwheel
197-
task would also allow the hands to touch the ground in addition to the feet.
198-
"""
199-
raise NotImplementedError
200-
201-
def after_compile(self, physics, unused_random_state):
202-
super(Walker, self).after_compile(physics, unused_random_state)
203-
self._end_effector_geom_ids = set()
204-
for eff_body in self.end_effectors:
205-
eff_geom = eff_body.find_all('geom')
206-
self._end_effector_geom_ids |= set(physics.bind(eff_geom).element_id)
207-
self._body_geom_ids = set(
208-
physics.bind(geom).element_id
209-
for geom in self.mjcf_model.find_all('geom'))
210-
self._body_geom_ids.difference_update(self._end_effector_geom_ids)
211-
212-
@property
213-
def end_effector_geom_ids(self):
214-
return self._end_effector_geom_ids
215-
216-
@property
217-
def body_geom_ids(self):
218-
return self._body_geom_ids
219-
220-
def end_effector_contacts(self, physics):
221-
"""Collect the contacts with the end effectors.
222-
223-
This function returns any contacts being made with any of the end effectors,
224-
both the other geom with which contact is being made as well as the
225-
magnitude.
226-
227-
Args:
228-
physics: an instance of `Physics`.
229-
230-
Returns:
231-
a dict with as key a tuple of geom ids, of which one is an end effector,
232-
and as value the total magnitude of all contacts between these geoms
233-
"""
234-
return self.collect_contacts(physics, self._end_effector_geom_ids)
235-
236-
def body_contacts(self, physics):
237-
"""Collect the contacts with the body.
238-
239-
This function returns any contacts being made with any of body geoms, except
240-
the end effectors, both the other geom with which contact is being made as
241-
well as the magnitude.
242-
243-
Args:
244-
physics: an instance of `Physics`.
245-
246-
Returns:
247-
a dict with as key a tuple of geom ids, of which one is a body geom,
248-
and as value the total magnitude of all contacts between these geoms
249-
"""
250-
return self.collect_contacts(physics, self._body_geom_ids)
251-
252-
def collect_contacts(self, physics, geom_ids):
253-
contacts = {}
254-
forcetorque = np.zeros(6)
255-
for i, contact in enumerate(physics.data.contact):
256-
if ((contact.geom1 in geom_ids) or
257-
(contact.geom2 in geom_ids)) and contact.dist < contact.includemargin:
258-
mjlib.mj_contactForce(physics.model.ptr, physics.data.ptr, i,
259-
forcetorque)
260-
contacts[(contact.geom1, contact.geom2)] = (forcetorque[0]
261-
+ contacts.get(
262-
(contact.geom1,
263-
contact.geom2), 0.))
264-
return contacts
265-
266-
@abc.abstractproperty
267-
def end_effectors(self):
268-
raise NotImplementedError
269-
270153
@abc.abstractproperty
271154
def observable_joints(self):
272155
raise NotImplementedError
273156

274-
@abc.abstractproperty
275-
def egocentric_camera(self):
276-
raise NotImplementedError
277-
278-
@composer.cached_property
279-
def touch_sensors(self):
280-
return self._mjcf_root.sensor.get_children('touch')
281-
282-
@property
283-
def prev_action(self):
284-
"""Returns the actuation actions applied in the previous step.
285-
286-
Concrete walker implementations should provide caching mechanism themselves
287-
in order to access this observable (for example, through `apply_action`).
288-
"""
289-
raise NotImplementedError
290-
291-
def after_substep(self, physics, random_state):
292-
del random_state # Unused.
293-
# As of MuJoCo v2.0, updates to `mjData->subtree_linvel` will be skipped
294-
# unless these quantities are needed by the simulation. We need these in
295-
# order to calculate `torso_{x,y}vel`, so we therefore call `mj_subtreeVel`
296-
# explicitly.
297-
# TODO(b/123065920): Consider using a `subtreelinvel` sensor instead.
298-
mjlib.mj_subtreeVel(physics.model.ptr, physics.data.ptr)
299-
300157
@property
301158
def action_spec(self):
302159
minimum, maximum = zip(*[
@@ -327,35 +184,11 @@ def joints_pos(self):
327184
def joints_vel(self):
328185
return observable.MJCFFeature('qvel', self._entity.observable_joints)
329186

330-
@composer.observable
331-
def body_height(self):
332-
return observable.MJCFFeature('xpos', self._entity.root_body)[2]
333-
334-
@composer.observable
335-
def end_effectors_pos(self):
336-
"""Position of end effectors relative to torso, in the egocentric frame."""
337-
def relative_pos_in_egocentric_frame(physics):
338-
end_effector = physics.bind(self._entity.end_effectors).xpos
339-
torso = physics.bind(self._entity.root_body).xpos
340-
xmat = np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3))
341-
return np.reshape(np.dot(end_effector - torso, xmat), -1)
342-
return observable.Generic(relative_pos_in_egocentric_frame)
343-
344-
@composer.observable
345-
def world_zaxis(self):
346-
"""The world's z-vector in this Walker's torso frame."""
347-
return observable.MJCFFeature('xmat', self._entity.root_body)[6:]
348-
349187
@composer.observable
350188
def sensors_gyro(self):
351189
return observable.MJCFFeature('sensordata',
352190
self._entity.mjcf_model.sensor.gyro)
353191

354-
@composer.observable
355-
def sensors_velocimeter(self):
356-
return observable.MJCFFeature('sensordata',
357-
self._entity.mjcf_model.sensor.velocimeter)
358-
359192
@composer.observable
360193
def sensors_accelerometer(self):
361194
return observable.MJCFFeature('sensordata',
@@ -373,59 +206,8 @@ def sensors_torque(self):
373206

374207
@composer.observable
375208
def sensors_touch(self):
376-
return observable.MJCFFeature(
377-
'sensordata',
378-
self._entity.mjcf_model.sensor.touch,
379-
corruptor=
380-
lambda v, random_state: np.array(v > _TOUCH_THRESHOLD, dtype=np.float))
381-
382-
@composer.observable
383-
def sensors_rangefinder(self):
384-
def tanh_rangefinder(physics):
385-
raw = physics.bind(self._entity.mjcf_model.sensor.rangefinder).sensordata
386-
raw = np.array(raw)
387-
raw[raw == -1.0] = np.inf
388-
return _RANGEFINDER_SCALE * np.tanh(raw / _RANGEFINDER_SCALE)
389-
return observable.Generic(tanh_rangefinder)
390-
391-
@composer.observable
392-
def egocentric_camera(self):
393-
return observable.MJCFCamera(self._entity.egocentric_camera,
394-
width=64, height=64)
395-
396-
@composer.observable
397-
def position(self):
398-
return observable.MJCFFeature('xpos', self._entity.root_body)
399-
400-
@composer.observable
401-
def orientation(self):
402-
return observable.MJCFFeature('xmat', self._entity.root_body)
403-
404-
def add_egocentric_vector(self,
405-
name,
406-
world_frame_observable,
407-
enabled=True,
408-
origin_callable=None,
409-
**kwargs):
410-
411-
def _egocentric(physics, origin_callable=origin_callable):
412-
vec = world_frame_observable.observation_callable(physics)()
413-
origin_callable = origin_callable or (lambda physics: np.zeros(vec.size))
414-
delta = vec - origin_callable(physics)
415-
return self._entity.transform_vec_to_egocentric_frame(physics, delta)
416-
417-
self._observables[name] = observable.Generic(_egocentric, **kwargs)
418-
self._observables[name].enabled = enabled
419-
420-
def add_egocentric_xmat(self, name, xmat_observable, enabled=True, **kwargs):
421-
422-
def _egocentric(physics):
423-
return self._entity.transform_xmat_to_egocentric_frame(
424-
physics,
425-
xmat_observable.observation_callable(physics)())
426-
427-
self._observables[name] = observable.Generic(_egocentric, **kwargs)
428-
self._observables[name].enabled = enabled
209+
return observable.MJCFFeature('sensordata',
210+
self._entity.mjcf_model.sensor.touch)
429211

430212
# Semantic groupings of Walker observables.
431213
def _collect_from_attachments(self, attribute_name):
@@ -436,60 +218,15 @@ def _collect_from_attachments(self, attribute_name):
436218

437219
@property
438220
def proprioception(self):
439-
return ([self.joints_pos, self.joints_vel,
440-
self.body_height, self.end_effectors_pos, self.world_zaxis] +
221+
return ([self.joints_pos, self.joints_vel] +
441222
self._collect_from_attachments('proprioception'))
442223

443224
@property
444225
def kinematic_sensors(self):
445-
return ([self.sensors_gyro, self.sensors_velocimeter,
446-
self.sensors_accelerometer] +
226+
return ([self.sensors_gyro, self.sensors_accelerometer] +
447227
self._collect_from_attachments('kinematic_sensors'))
448228

449229
@property
450230
def dynamic_sensors(self):
451231
return ([self.sensors_force, self.sensors_torque, self.sensors_touch] +
452232
self._collect_from_attachments('dynamic_sensors'))
453-
454-
# Convenience observables for defining rewards and terminations.
455-
@composer.observable
456-
def veloc_strafe(self):
457-
return observable.MJCFFeature(
458-
'sensordata', self._entity.mjcf_model.sensor.velocimeter)[1]
459-
460-
@composer.observable
461-
def veloc_up(self):
462-
return observable.MJCFFeature(
463-
'sensordata', self._entity.mjcf_model.sensor.velocimeter)[2]
464-
465-
@composer.observable
466-
def veloc_forward(self):
467-
return observable.MJCFFeature(
468-
'sensordata', self._entity.mjcf_model.sensor.velocimeter)[0]
469-
470-
@composer.observable
471-
def gyro_backward_roll(self):
472-
return observable.MJCFFeature(
473-
'sensordata', self._entity.mjcf_model.sensor.gyro)[0]
474-
475-
@composer.observable
476-
def gyro_rightward_roll(self):
477-
return observable.MJCFFeature(
478-
'sensordata', self._entity.mjcf_model.sensor.gyro)[1]
479-
480-
@composer.observable
481-
def gyro_anticlockwise_spin(self):
482-
return observable.MJCFFeature(
483-
'sensordata', self._entity.mjcf_model.sensor.gyro)[2]
484-
485-
@composer.observable
486-
def torso_xvel(self):
487-
return observable.MJCFFeature('subtree_linvel', self._entity.root_body)[0]
488-
489-
@composer.observable
490-
def torso_yvel(self):
491-
return observable.MJCFFeature('subtree_linvel', self._entity.root_body)[1]
492-
493-
@composer.observable
494-
def prev_action(self):
495-
return observable.Generic(lambda _: self._entity.prev_action)

dm_control/locomotion/walkers/base_test.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ class FakeWalker(base.Walker):
3030

3131
def _build(self):
3232
self._mjcf_root = mjcf.RootElement(model='walker')
33-
self._egocentric_camera = self._mjcf_root.worldbody.add(
34-
'camera', name='egocentric', xyaxes=[0, -1, 0, 0, 0, 1])
3533
self._torso_body = self._mjcf_root.worldbody.add(
3634
'body', name='torso', xyaxes=[0, 1, 0, -1, 0, 0])
3735

@@ -47,22 +45,10 @@ def actuators(self):
4745
def root_body(self):
4846
return self._torso_body
4947

50-
@property
51-
def end_effectors(self):
52-
return []
53-
5448
@property
5549
def observable_joints(self):
5650
return []
5751

58-
@property
59-
def ground_contact_geoms(self):
60-
return []
61-
62-
@property
63-
def egocentric_camera(self):
64-
return self._egocentric_camera
65-
6652

6753
class BaseWalkerTest(absltest.TestCase):
6854

0 commit comments

Comments
 (0)