|
1 | | -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | | -# |
3 | | -# Permission is hereby granted, free of charge, to any person obtaining a |
4 | | -# copy of this software and associated documentation files (the "Software"), |
5 | | -# to deal in the Software without restriction, including without limitation |
6 | | -# the rights to use, copy, modify, merge, publish, distribute, sublicense, |
7 | | -# and/or sell copies of the Software, and to permit persons to whom the |
8 | | -# Software is furnished to do so, subject to the following conditions: |
9 | | -# |
10 | | -# The above copyright notice and this permission notice shall be included in |
11 | | -# all copies or substantial portions of the Software. |
12 | | -# |
13 | | -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
14 | | -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
15 | | -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL |
16 | | -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
17 | | -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
18 | | -# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
19 | | -# DEALINGS IN THE SOFTWARE. |
20 | | -# |
21 | | -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES |
22 | | -# SPDX-License-Identifier: MIT |
23 | | - |
24 | | - |
25 | | -from collections import namedtuple |
26 | | -from itertools import product |
27 | | -from typing import Dict |
28 | | - |
29 | | -import torch |
30 | | -from torch import Tensor |
31 | | - |
32 | | -from se3_transformer.runtime.utils import degree_to_dim |
33 | | - |
34 | | -FiberEl = namedtuple('FiberEl', ['degree', 'channels']) |
35 | | - |
36 | | - |
37 | | -class Fiber(dict): |
38 | | - """ |
39 | | - Describes the structure of some set of features. |
40 | | - Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1. |
41 | | - Type-0 features: invariant scalars |
42 | | - Type-1 features: equivariant 3D vectors |
43 | | - Type-2 features: equivariant symmetric traceless matrices |
44 | | - ... |
45 | | -
|
46 | | - As inputs to a SE3 layer, there can be many features of the same types, and many features of different types. |
47 | | - The 'multiplicity' or 'number of channels' is the number of features of a given type. |
48 | | - This class puts together all the degrees and their multiplicities in order to describe |
49 | | - the inputs, outputs or hidden features of SE3 layers. |
50 | | - """ |
51 | | - |
52 | | - def __init__(self, structure): |
53 | | - if isinstance(structure, dict): |
54 | | - structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])] |
55 | | - elif not isinstance(structure[0], FiberEl): |
56 | | - structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1]))) |
57 | | - self.structure = structure |
58 | | - super().__init__({d: m for d, m in self.structure}) |
59 | | - |
60 | | - @property |
61 | | - def degrees(self): |
62 | | - return sorted([t.degree for t in self.structure]) |
63 | | - |
64 | | - @property |
65 | | - def channels(self): |
66 | | - return [self[d] for d in self.degrees] |
67 | | - |
68 | | - @property |
69 | | - def num_features(self): |
70 | | - """ Size of the resulting tensor if all features were concatenated together """ |
71 | | - return sum(t.channels * degree_to_dim(t.degree) for t in self.structure) |
72 | | - |
73 | | - @staticmethod |
74 | | - def create(num_degrees: int, num_channels: int): |
75 | | - """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """ |
76 | | - return Fiber([(degree, num_channels) for degree in range(num_degrees)]) |
77 | | - |
78 | | - @staticmethod |
79 | | - def from_features(feats: Dict[str, Tensor]): |
80 | | - """ Infer the Fiber structure from a feature dict """ |
81 | | - structure = {} |
82 | | - for k, v in feats.items(): |
83 | | - degree = int(k) |
84 | | - assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)' |
85 | | - assert v.shape[-1] == degree_to_dim(degree) |
86 | | - structure[degree] = v.shape[-2] |
87 | | - return Fiber(structure) |
88 | | - |
89 | | - def __getitem__(self, degree: int): |
90 | | - """ fiber[degree] returns the multiplicity for this degree """ |
91 | | - return dict(self.structure).get(degree, 0) |
92 | | - |
93 | | - def __iter__(self): |
94 | | - """ Iterate over namedtuples (degree, channels) """ |
95 | | - return iter(self.structure) |
96 | | - |
97 | | - def __mul__(self, other): |
98 | | - """ |
99 | | - If other in an int, multiplies all the multiplicities by other. |
100 | | - If other is a fiber, returns the cartesian product. |
101 | | - """ |
102 | | - if isinstance(other, Fiber): |
103 | | - return product(self.structure, other.structure) |
104 | | - elif isinstance(other, int): |
105 | | - return Fiber({t.degree: t.channels * other for t in self.structure}) |
106 | | - |
107 | | - def __add__(self, other): |
108 | | - """ |
109 | | - If other in an int, add other to all the multiplicities. |
110 | | - If other is a fiber, add the multiplicities of the fibers together. |
111 | | - """ |
112 | | - if isinstance(other, Fiber): |
113 | | - return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure}) |
114 | | - elif isinstance(other, int): |
115 | | - return Fiber({t.degree: t.channels + other for t in self.structure}) |
116 | | - |
117 | | - def __repr__(self): |
118 | | - return str(self.structure) |
119 | | - |
120 | | - @staticmethod |
121 | | - def combine_max(f1, f2): |
122 | | - """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """ |
123 | | - new_dict = dict(f1.structure) |
124 | | - for k, m in f2.structure: |
125 | | - new_dict[k] = max(new_dict.get(k, 0), m) |
126 | | - |
127 | | - return Fiber(list(new_dict.items())) |
128 | | - |
129 | | - @staticmethod |
130 | | - def combine_selectively(f1, f2): |
131 | | - """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """ |
132 | | - # only use orders which occur in fiber f1 |
133 | | - new_dict = dict(f1.structure) |
134 | | - for k in f1.degrees: |
135 | | - if k in f2.degrees: |
136 | | - new_dict[k] += f2[k] |
137 | | - return Fiber(list(new_dict.items())) |
138 | | - |
139 | | - def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int): |
140 | | - # dict(N, num_channels, 2d+1) -> (N, num_heads, -1) |
141 | | - fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in |
142 | | - self.degrees] |
143 | | - fibers = torch.cat(fibers, -1) |
144 | | - return fibers |
| 1 | +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# Permission is hereby granted, free of charge, to any person obtaining a |
| 4 | +# copy of this software and associated documentation files (the "Software"), |
| 5 | +# to deal in the Software without restriction, including without limitation |
| 6 | +# the rights to use, copy, modify, merge, publish, distribute, sublicense, |
| 7 | +# and/or sell copies of the Software, and to permit persons to whom the |
| 8 | +# Software is furnished to do so, subject to the following conditions: |
| 9 | +# |
| 10 | +# The above copyright notice and this permission notice shall be included in |
| 11 | +# all copies or substantial portions of the Software. |
| 12 | +# |
| 13 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 14 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 15 | +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL |
| 16 | +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 17 | +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
| 18 | +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
| 19 | +# DEALINGS IN THE SOFTWARE. |
| 20 | +# |
| 21 | +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES |
| 22 | +# SPDX-License-Identifier: MIT |
| 23 | + |
| 24 | + |
| 25 | +from collections import namedtuple |
| 26 | +from itertools import product |
| 27 | +from typing import Dict |
| 28 | + |
| 29 | +import torch |
| 30 | +from torch import Tensor |
| 31 | + |
| 32 | +from se3_transformer.runtime.utils import degree_to_dim |
| 33 | + |
| 34 | +FiberEl = namedtuple('FiberEl', ['degree', 'channels']) |
| 35 | + |
| 36 | + |
| 37 | +class Fiber(dict): |
| 38 | + """ |
| 39 | + Describes the structure of some set of features. |
| 40 | + Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1. |
| 41 | + Type-0 features: invariant scalars |
| 42 | + Type-1 features: equivariant 3D vectors |
| 43 | + Type-2 features: equivariant symmetric traceless matrices |
| 44 | + ... |
| 45 | +
|
| 46 | + As inputs to a SE3 layer, there can be many features of the same types, and many features of different types. |
| 47 | + The 'multiplicity' or 'number of channels' is the number of features of a given type. |
| 48 | + This class puts together all the degrees and their multiplicities in order to describe |
| 49 | + the inputs, outputs or hidden features of SE3 layers. |
| 50 | + """ |
| 51 | + |
| 52 | + def __init__(self, structure): |
| 53 | + if isinstance(structure, dict): |
| 54 | + structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])] |
| 55 | + elif not isinstance(structure[0], FiberEl): |
| 56 | + structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1]))) |
| 57 | + self.structure = structure |
| 58 | + super().__init__({d: m for d, m in self.structure}) |
| 59 | + |
| 60 | + @property |
| 61 | + def degrees(self): |
| 62 | + return sorted([t.degree for t in self.structure]) |
| 63 | + |
| 64 | + @property |
| 65 | + def channels(self): |
| 66 | + return [self[d] for d in self.degrees] |
| 67 | + |
| 68 | + @property |
| 69 | + def num_features(self): |
| 70 | + """ Size of the resulting tensor if all features were concatenated together """ |
| 71 | + return sum(t.channels * degree_to_dim(t.degree) for t in self.structure) |
| 72 | + |
| 73 | + @staticmethod |
| 74 | + def create(num_degrees: int, num_channels: int): |
| 75 | + """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """ |
| 76 | + return Fiber([(degree, num_channels) for degree in range(num_degrees)]) |
| 77 | + |
| 78 | + @staticmethod |
| 79 | + def from_features(feats: Dict[str, Tensor]): |
| 80 | + """ Infer the Fiber structure from a feature dict """ |
| 81 | + structure = {} |
| 82 | + for k, v in feats.items(): |
| 83 | + degree = int(k) |
| 84 | + assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)' |
| 85 | + assert v.shape[-1] == degree_to_dim(degree) |
| 86 | + structure[degree] = v.shape[-2] |
| 87 | + return Fiber(structure) |
| 88 | + |
| 89 | + def __getitem__(self, degree: int): |
| 90 | + """ fiber[degree] returns the multiplicity for this degree """ |
| 91 | + return dict(self.structure).get(degree, 0) |
| 92 | + |
| 93 | + def __iter__(self): |
| 94 | + """ Iterate over namedtuples (degree, channels) """ |
| 95 | + return iter(self.structure) |
| 96 | + |
| 97 | + def __mul__(self, other): |
| 98 | + """ |
| 99 | + If other in an int, multiplies all the multiplicities by other. |
| 100 | + If other is a fiber, returns the cartesian product. |
| 101 | + """ |
| 102 | + if isinstance(other, Fiber): |
| 103 | + return product(self.structure, other.structure) |
| 104 | + elif isinstance(other, int): |
| 105 | + return Fiber({t.degree: t.channels * other for t in self.structure}) |
| 106 | + |
| 107 | + def __add__(self, other): |
| 108 | + """ |
| 109 | + If other in an int, add other to all the multiplicities. |
| 110 | + If other is a fiber, add the multiplicities of the fibers together. |
| 111 | + """ |
| 112 | + if isinstance(other, Fiber): |
| 113 | + return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure}) |
| 114 | + elif isinstance(other, int): |
| 115 | + return Fiber({t.degree: t.channels + other for t in self.structure}) |
| 116 | + |
| 117 | + def __repr__(self): |
| 118 | + return str(self.structure) |
| 119 | + |
| 120 | + @staticmethod |
| 121 | + def combine_max(f1, f2): |
| 122 | + """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """ |
| 123 | + new_dict = dict(f1.structure) |
| 124 | + for k, m in f2.structure: |
| 125 | + new_dict[k] = max(new_dict.get(k, 0), m) |
| 126 | + |
| 127 | + return Fiber(list(new_dict.items())) |
| 128 | + |
| 129 | + @staticmethod |
| 130 | + def combine_selectively(f1, f2): |
| 131 | + """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """ |
| 132 | + # only use orders which occur in fiber f1 |
| 133 | + new_dict = dict(f1.structure) |
| 134 | + for k in f1.degrees: |
| 135 | + if k in f2.degrees: |
| 136 | + new_dict[k] += f2[k] |
| 137 | + return Fiber(list(new_dict.items())) |
| 138 | + |
| 139 | + def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int): |
| 140 | + # dict(N, num_channels, 2d+1) -> (N, num_heads, -1) |
| 141 | + fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in |
| 142 | + self.degrees] |
| 143 | + fibers = torch.cat(fibers, -1) |
| 144 | + return fibers |
0 commit comments