2323
2424from dm_control import composer
2525from dm_control import mjcf
26+ from dm_control .composer .observation import observable
2627from dm_control .locomotion .walkers import legacy_base
2728import numpy as np
2829from PIL import Image
@@ -102,13 +103,27 @@ def _asset_png_with_background_rgba_bytes(asset_fname, background_rgba):
102103 return png_encoding
103104
104105
106+ class BoxHeadObservables (legacy_base .WalkerObservables ):
107+
108+ def __init__ (self , entity , camera_resolution ):
109+ self ._camera_resolution = camera_resolution
110+ super (BoxHeadObservables , self ).__init__ (entity )
111+
112+ @composer .observable
113+ def egocentric_camera (self ):
114+ width , height = self ._camera_resolution
115+ return observable .MJCFCamera (self ._entity .egocentric_camera ,
116+ width = width , height = height )
117+
118+
105119class BoxHead (legacy_base .Walker ):
106120 """A rollable and jumpable ball with a head."""
107121
108122 def _build (self ,
109123 name = 'walker' ,
110124 marker_rgba = None ,
111125 camera_control = False ,
126+ camera_resolution = (28 , 28 ),
112127 roll_gear = - 60 ,
113128 steer_gear = 55 ,
114129 walker_id = None ,
@@ -121,6 +136,7 @@ def _build(self,
121136 walkers (in multi-agent setting).
122137 camera_control: If `True`, the walker exposes two additional actuated
123138 degrees of freedom to control the egocentric camera height and tilt.
139+ camera_resolution: egocentric camera rendering resolution.
124140 roll_gear: gear determining forward acceleration.
125141 steer_gear: gear determining steering (spinning) torque.
126142 walker_id: (Optional) An integer in [0-10], this number will be shown on
@@ -176,6 +192,7 @@ def _build(self,
176192
177193 self ._root_joints = None
178194 self ._camera_control = camera_control
195+ self ._camera_resolution = camera_resolution
179196 if not camera_control :
180197 for name in ('camera_pitch' , 'camera_yaw' ):
181198 self ._mjcf_root .find ('actuator' , name ).remove ()
@@ -189,6 +206,9 @@ def _build(self,
189206 self ._prev_action = np .zeros (shape = self .action_spec .shape ,
190207 dtype = self .action_spec .dtype )
191208
209+ def _build_observables (self ):
210+ return BoxHeadObservables (self , camera_resolution = self ._camera_resolution )
211+
192212 @property
193213 def marker_geoms (self ):
194214 geoms = [
0 commit comments