Skip to content

Commit db36290

Browse files
DeepMind authored and alimuldal committed
Locomotion reference pose task: Add more general quaternion difference reward computation.
PiperOrigin-RevId: 369839626 Change-Id: I19a67a299e06143ba8fca86b6a00b1aff51f53fb
1 parent faff446 commit db36290

2 files changed

Lines changed: 96 additions & 8 deletions

File tree

dm_control/locomotion/tasks/reference_pose/rewards.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,8 @@ def bounded_quat_dist(source, target):
3636
Returns:
3737
quaternion distance.
3838
"""
39+
source /= np.linalg.norm(source)
40+
target /= np.linalg.norm(target)
3941
default_dist = tr.quat_dist(source, target)
4042
anti_dist = tr.quat_dist(source, -np.asarray(target))
4143
min_dist = np.minimum(default_dist, anti_dist)
@@ -55,14 +57,16 @@ def compute_squared_differences(walker_features, reference_features,
5557
if 'quaternion' not in k:
5658
squared_differences[k] = np.sum(
5759
(walker_features[k] - reference_features[k])**2)
58-
quat_dists = np.array([
59-
bounded_quat_dist(w, r)
60-
for w, r in zip(walker_features['body_quaternions'],
61-
reference_features['body_quaternions'])
62-
])
63-
squared_differences['body_quaternions'] = np.sum(quat_dists**2)
64-
squared_differences['quaternion'] = bounded_quat_dist(
65-
walker_features['quaternion'], reference_features['quaternion'])**2
60+
elif 'quaternions' in k:
61+
quat_dists = np.array([
62+
bounded_quat_dist(w, r)
63+
for w, r in zip(walker_features[k], reference_features[k])
64+
])
65+
squared_differences[k] = np.sum(quat_dists**2)
66+
else:
67+
squared_differences[k] = bounded_quat_dist(walker_features[k],
68+
reference_features[k])**2
69+
6670
return squared_differences
6771

6872

Lines changed: 84 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2021 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Tests for dm_control.locomotion.tasks.reference_pose.rewards."""
16+
17+
from absl.testing import absltest
18+
from dm_control.locomotion.tasks.reference_pose import rewards
19+
import numpy as np
20+
21+
# Non-quaternion features. For these keys the expected result is the sum of
# squared element-wise differences between walker and reference values.
WALKER_FEATURES = {
    'scalar': 0.,
    'vector': np.ones(3),
    'match': 0.1,
}

REFERENCE_FEATURES = {
    'scalar': 1.5,
    'vector': np.full(3, 2),
    'match': 0.1,
}

# Quaternion features. A key containing 'quaternions' (plural) holds a sequence
# of quaternions; any other key containing 'quaternion' holds a single one.
# The first matched quaternion is deliberately not unit-length: the distance
# computation normalises its inputs before comparing them.
QUATERNION_FEATURES = {
    'unmatched_quaternion': (1., 0., 0., 0.),
    'matched_quaternions': [(1., 0., 1., 0.), (0.707, 0.707, 0., 0.)],
}

REFERENCE_QUATERNION_FEATURES = {
    'unmatched_quaternion': (0., 0., 0., 1.),
    'matched_quaternions': [(1., 0., 1., 0.), (0.707, 0.707, 0., 0.)],
}

# Distance between the two differing single quaternions, computed once here so
# the expectation table below stays readable.
_UNMATCHED_QUAT_DIST = rewards.bounded_quat_dist(
    QUATERNION_FEATURES['unmatched_quaternion'],
    REFERENCE_QUATERNION_FEATURES['unmatched_quaternion'])

# Expected squared difference for every feature key used in the tests.
EXPECTED_DIFFERENCES = {
    'scalar': 2.25,  # (0 - 1.5)**2
    'vector': 3.,  # three components, each (1 - 2)**2
    'match': 0.,
    'unmatched_quaternion': np.sum(_UNMATCHED_QUAT_DIST)**2,
    'matched_quaternions': 0.,
}

# Feature keys that the exclusion test asks to be left out of the result.
EXCLUDE_KEYS = ('scalar', 'match')
55+
56+
57+
class RewardsTest(absltest.TestCase):
  """Tests for `rewards.compute_squared_differences`."""

  def _assert_differences_match(self, differences, expected_keys):
    """Checks both the set of returned keys and each per-feature value.

    Args:
      differences: dict returned by `rewards.compute_squared_differences`.
      expected_keys: iterable of feature names that `differences` must contain
        exactly (no more, no fewer).
    """
    # Compare the key sets first: looping over `differences` alone would pass
    # vacuously if the result were empty or silently missing a feature.
    self.assertCountEqual(differences.keys(), expected_keys)
    for key, difference in differences.items():
      # `subTest` reports which feature failed instead of stopping at the
      # first mismatch.
      with self.subTest(key=key):
        self.assertAlmostEqual(difference, EXPECTED_DIFFERENCES[key])

  def test_compute_squared_differences(self):
    """Basic usage."""
    differences = rewards.compute_squared_differences(
        WALKER_FEATURES, REFERENCE_FEATURES)
    self._assert_differences_match(differences, WALKER_FEATURES.keys())

  def test_compute_squared_differences_exclude_keys(self):
    """Test excluding some keys from squared difference computation."""
    differences = rewards.compute_squared_differences(
        WALKER_FEATURES, REFERENCE_FEATURES, exclude_keys=EXCLUDE_KEYS)
    for key in EXCLUDE_KEYS:
      self.assertNotIn(key, differences)
    # Excluding some keys must not drop the remaining features.
    remaining_keys = [k for k in WALKER_FEATURES if k not in EXCLUDE_KEYS]
    self.assertCountEqual(differences.keys(), remaining_keys)

  def test_compute_squared_differences_quaternion(self):
    """Test that quaternions use a different distance computation."""
    differences = rewards.compute_squared_differences(
        QUATERNION_FEATURES, REFERENCE_QUATERNION_FEATURES)
    self._assert_differences_match(differences, QUATERNION_FEATURES.keys())


if __name__ == '__main__':
  absltest.main()

0 commit comments

Comments
 (0)