-
Notifications
You must be signed in to change notification settings - Fork 745
Expand file tree
/
Copy pathrotations.py
More file actions
146 lines (117 loc) · 5.04 KB
/
rotations.py
File metadata and controls
146 lines (117 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Variations in 3D rotations."""
from dm_control.composer.variation import base
from dm_control.composer.variation import variation_values
from dm_control.utils import transformations
import numpy as np
# Quaternion for "no rotation": unit scalar part followed by a zero vector part.
IDENTITY_QUATERNION = np.array([1.0, 0.0, 0.0, 0.0])
class UniformQuaternion(base.Variation):
  """Samples unit quaternions uniformly at random.

  Three uniform variates are drawn per call: one in [0, 1) that splits the
  squared norm between the two component pairs, and two angles in [0, 2*pi)
  that place each pair on its circle (Shoemake's construction).
  """

  def __call__(self, initial_value=None, current_value=None, random_state=None):
    # Fall back to NumPy's global random module if no generator was given.
    rng = random_state or np.random
    lows = [0.] * 3
    highs = [1., 2. * np.pi, 2. * np.pi]
    split, theta_a, theta_b = rng.uniform(lows, highs)
    radius_a = np.sqrt(1. - split)
    radius_b = np.sqrt(split)
    return np.array([radius_a * np.sin(theta_a),
                     radius_a * np.cos(theta_a),
                     radius_b * np.sin(theta_b),
                     radius_b * np.cos(theta_b)])

  def __eq__(self, other):
    # The variation carries no configuration, so any two instances are equal.
    return isinstance(other, UniformQuaternion)

  def __repr__(self):
    return "UniformQuaternion()"
class QuaternionFromAxisAngle(base.Variation):
  """Quaternion variation specified in terms of variations in axis and angle.

  Both `axis` and `angle` may be fixed values or variations. They are evaluated
  on every call, and the rotation vector `axis * angle` is converted to a
  quaternion with `transformations.axisangle_to_quat`.
  """

  def __init__(self, axis, angle):
    """Initializes the variation.

    Args:
      axis: The rotation axis: a fixed vector (list, tuple or NumPy array) or a
        variation producing one. NOTE(review): `axis * angle` is passed to
        `axisangle_to_quat` unnormalized, so `axis` presumably needs unit
        length for `angle` to be the actual rotation angle -- confirm.
      angle: The rotation angle: a fixed scalar or a variation producing one.
    """
    self._axis = axis
    self._angle = angle

  def __call__(self, initial_value=None, current_value=None, random_state=None):
    """Evaluates axis and angle and returns the corresponding quaternion."""
    random_state = random_state or np.random
    # Axis is evaluated before angle; keep this order so that a seeded
    # `random_state` yields the same samples as before.
    axis = variation_values.evaluate(
        self._axis, initial_value, current_value, random_state)
    angle = variation_values.evaluate(
        self._angle, initial_value, current_value, random_state)
    return transformations.axisangle_to_quat(np.asarray(axis) * angle)

  @staticmethod
  def _values_equal(a, b):
    """Compares two constructor arguments, tolerating NumPy arrays.

    On arrays, `==` is elementwise, so using its result in a boolean context
    raises ValueError (and mismatched shapes fail to broadcast). Arrays are
    therefore compared with `np.array_equal`, and everything else (scalars,
    lists, nested variations) uses `==`.
    """
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
      return bool(np.array_equal(a, b))
    return bool(a == b)

  def __eq__(self, other):
    if not isinstance(other, QuaternionFromAxisAngle):
      return False
    return (
        self._values_equal(self._axis, other._axis)
        and self._values_equal(self._angle, other._angle)
    )

  def __repr__(self):
    return (
        f"QuaternionFromAxisAngle(axis={self._axis}, angle={self._angle})"
    )
class QuaternionPreMultiply(base.Variation):
  """A variation that pre-multiplies an existing quaternion value.

  This variation takes a quaternion value generated by another variation and
  pre-multiplies it onto an existing value. In cumulative mode, the new
  quaternion is pre-multiplied onto the current value being varied. In
  non-cumulative mode, the new quaternion is pre-multiplied onto a fixed
  initial value.
  """

  def __init__(self, quat, cumulative=False):
    """Initializes the variation.

    Args:
      quat: The quaternion to pre-multiply: a fixed value (list, tuple or NumPy
        array) or a variation producing one.
      cumulative: If True, pre-multiply onto `current_value`; otherwise onto
        `initial_value`.
    """
    self._quat = quat
    self._cumulative = cumulative

  def __call__(self, initial_value=None, current_value=None, random_state=None):
    """Returns `transformations.quat_mul(quat, base)`.

    `base` is `current_value` in cumulative mode and `initial_value`
    otherwise. NOTE(review): `base` is used as-is, so callers presumably must
    not pass None for it -- confirm.
    """
    random_state = random_state or np.random
    q1 = variation_values.evaluate(self._quat, initial_value, current_value,
                                   random_state)
    # The value being varied is the right-hand operand of the product.
    q2 = current_value if self._cumulative else initial_value
    return transformations.quat_mul(np.asarray(q1), np.asarray(q2))

  @staticmethod
  def _values_equal(a, b):
    """Compares two constructor arguments, tolerating NumPy arrays.

    On arrays, `==` is elementwise, so using its result in a boolean context
    raises ValueError (and mismatched shapes fail to broadcast). Arrays are
    therefore compared with `np.array_equal`, and everything else (scalars,
    lists, nested variations) uses `==`.
    """
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
      return bool(np.array_equal(a, b))
    return bool(a == b)

  def __eq__(self, other):
    if not isinstance(other, QuaternionPreMultiply):
      return False
    return (
        self._values_equal(self._quat, other._quat)
        and self._cumulative == other._cumulative
    )

  def __repr__(self):
    return (
        f"QuaternionPreMultiply(quat={self._quat},"
        f" cumulative={self._cumulative})"
    )
class QuaternionRotate(base.Variation):
  """Variation that rotates a given vector by the given quaternion.

  The vector can either be an existing value passed at evaluation, or specified
  as a separate variation at construction. In the former case, cumulative mode
  determines whether to use the current or initial value of the vector. The
  quaternion is always specified by a variation at construction.
  """

  def __init__(self, quat, vec=None, cumulative=False):
    """Initializes the variation.

    Args:
      quat: The rotation quaternion: a fixed value or a variation producing
        one.
      vec: Optional vector to rotate: a fixed value or a variation. If None,
        the value passed at evaluation time is rotated instead.
      cumulative: Only used when `vec` is None. If True, rotate
        `current_value`; otherwise rotate `initial_value`.
    """
    self._quat = quat
    self._vec = vec
    self._cumulative = cumulative

  def __call__(self, initial_value=None, current_value=None, random_state=None):
    """Returns `transformations.quat_rotate(quat, vec)` for the chosen vector."""
    random_state = random_state or np.random
    # The quaternion is evaluated before the vector; keep this order so that a
    # seeded `random_state` yields the same samples as before.
    quat = variation_values.evaluate(
        self._quat, initial_value, current_value, random_state
    )
    if self._vec is None:
      vec = current_value if self._cumulative else initial_value
    else:
      vec = variation_values.evaluate(
          self._vec, initial_value, current_value, random_state
      )
    return transformations.quat_rotate(np.asarray(quat), np.asarray(vec))

  @staticmethod
  def _values_equal(a, b):
    """Compares two constructor arguments, tolerating NumPy arrays.

    On arrays, `==` is elementwise, so using its result in a boolean context
    raises ValueError (and mismatched shapes fail to broadcast). Arrays are
    therefore compared with `np.array_equal`, and everything else (None,
    scalars, lists, nested variations) uses `==`.
    """
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
      return bool(np.array_equal(a, b))
    return bool(a == b)

  def __eq__(self, other):
    if not isinstance(other, QuaternionRotate):
      return False
    return (
        self._values_equal(self._quat, other._quat)
        and self._values_equal(self._vec, other._vec)
        and self._cumulative == other._cumulative
    )

  def __repr__(self):
    return (
        f"QuaternionRotate(quat={self._quat}, vec={self._vec},"
        f" cumulative={self._cumulative})"
    )