@@ -558,5 +558,62 @@ def test_ball_velocity(self):
558558 ball_velocity = env .physics .bind (ball_root_joint ).qvel
559559 np .testing .assert_array_equal (ball_velocity , 0. )
560560
561+
562+ class _ScoringInitializer (soccer .Initializer ):
563+ """Initialize the ball for home team to repeatedly score goals."""
564+
565+ def __init__ (self ):
566+ self ._num_calls = 0
567+
568+ @property
569+ def num_calls (self ):
570+ return self ._num_calls
571+
572+ def __call__ (self , task , physics , random_state ):
573+ # Initialize `ball` along the y-axis with a positive y-velocity.
574+ task .ball .set_pose (physics , [2.0 , 0.0 , 1.5 ])
575+ task .ball .set_velocity (
576+ physics , velocity = [100.0 , 0.0 , 0.0 ], angular_velocity = 0. )
577+ for i , player in enumerate (task .players ):
578+ player .walker .reinitialize_pose (physics , random_state )
579+ (_ , _ , z ), quat = player .walker .get_pose (physics )
580+ player .walker .set_pose (physics , [- i * 5 , 0.0 , z ], quat )
581+ player .walker .set_velocity (physics , velocity = 0. , angular_velocity = 0. )
582+
583+ self ._num_calls += 1
584+
585+
586+ class MultiturnTaskTest (parameterized .TestCase ):
587+
588+ def test_multiple_goals (self ):
589+ initializer = _ScoringInitializer ()
590+ time_limit = 1.0
591+ control_timestep = 0.025
592+ env = composer .Environment (
593+ task = soccer .MultiturnTask (
594+ players = _home_team (1 ) + _away_team (1 ),
595+ arena = soccer .Pitch ((20 , 15 ), field_box = True ), # disable throw-in.
596+ initializer = initializer ,
597+ control_timestep = control_timestep ),
598+ time_limit = time_limit )
599+
600+ timestep = env .reset ()
601+ num_steps = 0
602+ rewards = [np .zeros (s .shape , s .dtype ) for s in env .reward_spec ()]
603+ while not timestep .last ():
604+ timestep = env .step ([spec .generate_value () for spec in env .action_spec ()])
605+ for reward , r_t in zip (rewards , timestep .reward ):
606+ reward += r_t
607+ num_steps += 1
608+ self .assertEqual (num_steps , time_limit / control_timestep )
609+
610+ num_scores = initializer .num_calls - 1 # discard initialization.
611+ self .assertEqual (num_scores , 6 )
612+ self .assertEqual (rewards , [
613+ np .full ((), num_scores , np .float32 ),
614+ np .full ((), - num_scores , np .float32 )
615+ ])
616+
617+
561618if __name__ == "__main__" :
562619 absltest .main ()
0 commit comments