Skip to content

Commit 5378778

Browse files
liusiqi43 and copybara-github
authored and committed
Introduce a multiturn variant of the task.
PiperOrigin-RevId: 373348349 Change-Id: Ie2a58c09149fe160b9c6c902064ba3180d8cbaaf
1 parent d965ecb commit 5378778

4 files changed

Lines changed: 108 additions & 1 deletion

File tree

dm_control/locomotion/soccer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from dm_control.locomotion.soccer.pitch import RandomizedPitch
3131
from dm_control.locomotion.soccer.soccer_ball import regulation_soccer_ball
3232
from dm_control.locomotion.soccer.soccer_ball import SoccerBall
33+
from dm_control.locomotion.soccer.task import MultiturnTask
3334
from dm_control.locomotion.soccer.task import Task
3435
from dm_control.locomotion.soccer.team import Player
3536
from dm_control.locomotion.soccer.team import Team

dm_control/locomotion/soccer/task.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,52 @@ def before_step(self, physics, actions, random_state):
216216
def action_spec(self, physics):
  """Return a list with each player's walker action spec, in player order."""
  return [member.walker.action_spec for member in self.players]
219+
220+
221+
class MultiturnTask(Task):
  """Soccer task in which play carries on after a goal is scored.

  The task itself never requests termination and always reports a discount of
  one. When the arena reports a goal after a step, the initializer is invoked
  again in place, and the ball's entity trackers are re-initialized at the
  start of the following step.
  """

  def __init__(self,
               players,
               arena,
               ball=None,
               initializer=None,
               observables=None,
               disable_walker_contacts=False,
               nconmax_per_player=200,
               njmax_per_player=200,
               control_timestep=0.025,
               tracking_cameras=()):
    """Forwards all arguments unchanged to the base `Task` constructor."""
    super().__init__(
        players,
        arena,
        ball=ball,
        initializer=initializer,
        observables=observables,
        disable_walker_contacts=disable_walker_contacts,
        nconmax_per_player=nconmax_per_player,
        njmax_per_player=njmax_per_player,
        control_timestep=control_timestep,
        tracking_cameras=tracking_cameras)

    # Raised by `after_step` when a goal is detected and consumed by the next
    # `before_step`, which then re-initializes the ball's entity trackers.
    self._should_reset = False

  def should_terminate_episode(self, physics):
    """Always returns `False`: this task never ends the episode itself."""
    return False

  def get_discount(self, physics):
    """Returns a constant scalar float32 discount of one."""
    return np.ones((), np.float32)

  def before_step(self, physics, actions, random_state):
    """Runs the base pre-step logic, then performs any pending tracker reset."""
    super().before_step(physics, actions, random_state)
    if not self._should_reset:
      return
    self.ball.initialize_entity_trackers()
    self._should_reset = False

  def after_step(self, physics, random_state):
    """Runs the base post-step logic and re-initializes play after a goal."""
    super().after_step(physics, random_state)
    if not self.arena.detected_goal():
      return
    # A goal was detected: re-run the initializer now and defer the ball
    # tracker reset to the beginning of the next step.
    self._initializer(self, physics, random_state)
    self._should_reset = True

dm_control/locomotion/soccer/task_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,5 +558,62 @@ def test_ball_velocity(self):
558558
ball_velocity = env.physics.bind(ball_root_joint).qvel
559559
np.testing.assert_array_equal(ball_velocity, 0.)
560560

561+
562+
class _ScoringInitializer(soccer.Initializer):
  """Initialize the ball for home team to repeatedly score goals.

  Also counts how many times the initializer has been invoked, so tests can
  tell how often the task re-initialized play.
  """

  def __init__(self):
    self._num_calls = 0

  @property
  def num_calls(self):
    """Number of times this initializer has been called so far."""
    return self._num_calls

  def __call__(self, task, physics, random_state):
    """Places the ball and players, then increments the call counter."""
    # Initialize `ball` at x=2.0, y=0.0, z=1.5 with a large positive
    # x-velocity (no y or z component, no angular velocity).
    task.ball.set_pose(physics, [2.0, 0.0, 1.5])
    task.ball.set_velocity(
        physics, velocity=[100.0, 0.0, 0.0], angular_velocity=0.)
    for i, player in enumerate(task.players):
      # Keep the height and orientation chosen by `reinitialize_pose`, but
      # move player `i` to x = -5 * i on the y = 0 line and bring it to rest.
      player.walker.reinitialize_pose(physics, random_state)
      (_, _, z), quat = player.walker.get_pose(physics)
      player.walker.set_pose(physics, [-i * 5, 0.0, z], quat)
      player.walker.set_velocity(physics, velocity=0., angular_velocity=0.)

    self._num_calls += 1
584+
585+
586+
class MultiturnTaskTest(parameterized.TestCase):
  """Tests for `soccer.MultiturnTask`."""

  def test_multiple_goals(self):
    """Checks the episode runs to the time limit while goals keep scoring."""
    initializer = _ScoringInitializer()
    time_limit = 1.0
    control_timestep = 0.025
    env = composer.Environment(
        task=soccer.MultiturnTask(
            players=_home_team(1) + _away_team(1),
            arena=soccer.Pitch((20, 15), field_box=True),  # disable throw-in.
            initializer=initializer,
            control_timestep=control_timestep),
        time_limit=time_limit)

    timestep = env.reset()
    num_steps = 0
    rewards = [np.zeros(s.shape, s.dtype) for s in env.reward_spec()]
    while not timestep.last():
      actions = [spec.generate_value() for spec in env.action_spec()]
      timestep = env.step(actions)
      for reward, r_t in zip(rewards, timestep.reward):
        # In-place add, so the arrays stored in `rewards` accumulate.
        reward += r_t
      num_steps += 1

    # Compare against an integer step count: the raw float quotient is only
    # equal to an int when the division happens to be exact.
    expected_steps = int(round(time_limit / control_timestep))
    self.assertEqual(num_steps, expected_steps)

    num_scores = initializer.num_calls - 1  # discard initialization.
    self.assertEqual(num_scores, 6)
    self.assertEqual(rewards, [
        np.full((), num_scores, np.float32),
        np.full((), -num_scores, np.float32)
    ])
616+
617+
561618
if __name__ == "__main__":
562619
absltest.main()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def is_excluded(s):
177177

178178
setup(
179179
name='dm_control',
180-
version='0.0.372528912',
180+
version='0.0.373348349',
181181
description='Continuous control environments and MuJoCo Python bindings.',
182182
author='DeepMind',
183183
license='Apache License, Version 2.0',

0 commit comments

Comments
 (0)