2424from dm_control import composer
2525from dm_control import mjcf
2626from dm_control .composer .observation import observable
27+ from dm_control .locomotion .walkers import initializers
2728from dm_control .locomotion .walkers import legacy_base
2829import numpy as np
2930from PIL import Image
@@ -104,6 +105,7 @@ def _asset_png_with_background_rgba_bytes(asset_fname, background_rgba):
104105
105106
106107class BoxHeadObservables (legacy_base .WalkerObservables ):
108+ """BoxHead observables with low-res camera and modulo'd rotational joints."""
107109
108110 def __init__ (self , entity , camera_resolution ):
109111 self ._camera_resolution = camera_resolution
@@ -115,6 +117,43 @@ def egocentric_camera(self):
115117 return observable .MJCFCamera (self ._entity .egocentric_camera ,
116118 width = width , height = height )
117119
120+ @property
121+ def proprioception (self ):
122+ proprioception = super (BoxHeadObservables , self ).proprioception
123+ if self ._entity .observable_camera_joints :
124+ return proprioception + [self .camera_joints_pos , self .camera_joints_vel ]
125+ return proprioception
126+
127+ @composer .observable
128+ def camera_joints_pos (self ):
129+
130+ def _sin (value , random_state ):
131+ del random_state
132+ return np .sin (value )
133+
134+ def _cos (value , random_state ):
135+ del random_state
136+ return np .cos (value )
137+
138+ sin_rotation_joints = observable .MJCFFeature (
139+ 'qpos' , self ._entity .observable_camera_joints , corruptor = _sin )
140+
141+ cos_rotation_joints = observable .MJCFFeature (
142+ 'qpos' , self ._entity .observable_camera_joints , corruptor = _cos )
143+
144+ def _camera_joints (physics ):
145+ return np .concatenate ([
146+ sin_rotation_joints (physics ),
147+ cos_rotation_joints (physics )
148+ ], - 1 )
149+
150+ return observable .Generic (_camera_joints )
151+
152+ @composer .observable
153+ def camera_joints_vel (self ):
154+ return observable .MJCFFeature (
155+ 'qvel' , self ._entity .observable_camera_joints )
156+
118157
119158class BoxHead (legacy_base .Walker ):
120159 """A rollable and jumpable ball with a head."""
@@ -146,7 +185,8 @@ def _build(self,
146185 Raises:
147186 ValueError: if received invalid walker_id.
148187 """
149- super (BoxHead , self )._build (initializer = initializer )
188+ super (BoxHead , self )._build (
189+ initializer = initializer or initializers .NoOpInitializer ())
150190 xml_path = os .path .join (_ASSETS_PATH , 'boxhead.xml' )
151191 self ._mjcf_root = mjcf .from_xml_string (resources .GetResource (xml_path , 'r' ))
152192 if name :
@@ -246,7 +286,9 @@ def set_pose(self, physics, position=None, quaternion=None):
246286 1 - 2 * (quaternion [2 ] ** 2 + quaternion [3 ] ** 2 ))
247287 physics .bind (self ._mjcf_root .find ('joint' , 'steer' )).qpos = z_angle
248288
249- def initialize_episode (self , physics , unused_random_state ):
289+ def initialize_episode (self , physics , random_state ):
290+ self .reinitialize_pose (physics , random_state )
291+
250292 if self ._camera_control :
251293 _compensate_gravity (physics ,
252294 self ._mjcf_root .find ('body' , 'egocentric_camera' ))
@@ -279,6 +321,15 @@ def end_effectors(self):
279321 def observable_joints (self ):
280322 return (self ._mjcf_root .find ('joint' , 'kick' ),)
281323
324+ @composer .cached_property
325+ def observable_camera_joints (self ):
326+ if self ._camera_control :
327+ return (
328+ self ._mjcf_root .find ('joint' , 'camera_yaw' ),
329+ self ._mjcf_root .find ('joint' , 'camera_pitch' ),
330+ )
331+ return ()
332+
282333 @composer .cached_property
283334 def egocentric_camera (self ):
284335 return self ._mjcf_root .find ('camera' , 'egocentric' )
0 commit comments