Quaternion Networks - Bug Fixes & Improvements #2464
mravanelli merged 21 commits into speechbrain:develop from
Conversation
TParcollet
left a comment
Thanks! LGTM, see my questions.
```python
# (batch, channel, time)
x = x.transpose(1, -1)

if self.max_norm is not None:
```
Is renorming the individual quaternion components actually strictly equivalent to renorming the quaternion?
No, I don't believe it is strictly equivalent. I couldn't find any references for how to approach it, so I went with the simplest idea, which was the component-wise renorm.
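A small numeric sketch (values chosen for illustration, not from the PR) of why the two schemes are not equivalent: clipping each component's norm independently changes both the quaternion's direction and its magnitude, while scaling all four components by a shared factor preserves the direction.

```python
import torch

# One quaternion weight (1, 2, 2, 4), magnitude 5.
r, i, j, k = (torch.tensor([v]) for v in (1.0, 2.0, 2.0, 4.0))
magnitude = torch.sqrt(r**2 + i**2 + j**2 + k**2)  # tensor([5.])
max_norm = 1.0

# Component-wise: renorm each component on its own.
cw = [
    torch.renorm(w.view(1, 1), p=2, dim=0, maxnorm=max_norm)
    for w in (r, i, j, k)
]
cw_mag = torch.sqrt(sum(w**2 for w in cw))
# Each component is clipped to 1, so the resulting magnitude is 2,
# not max_norm, and the direction (1:2:2:4) is lost.

# Magnitude-based: one shared scale factor per quaternion.
factor = torch.clamp(magnitude, max=max_norm) / magnitude
mb = [w * factor for w in (r, i, j, k)]
mb_mag = torch.sqrt(sum(w**2 for w in mb))
# Magnitude is exactly max_norm; the component ratios are unchanged.
```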
@Drew-Wagner, could you please fix the conflicts, merge the latest development, and do the last modifications?
a95986e to d7b5b7f
```python
def renorm_quaternion_weights_inplace(
    r_weight, i_weight, j_weight, k_weight, max_norm
):
    """Renorms the magnitude of the quaternion-valued weights.

    Arguments
    ---------
    r_weight : torch.Parameter
        The real-component weights.
    i_weight : torch.Parameter
        The i-component weights.
    j_weight : torch.Parameter
        The j-component weights.
    k_weight : torch.Parameter
        The k-component weights.
    max_norm : float
        The maximum norm of the magnitude of the quaternion weights.
    """
    weight_magnitude = torch.sqrt(
        r_weight.data**2
        + i_weight.data**2
        + j_weight.data**2
        + k_weight.data**2
    )
    renormed_weight_magnitude = torch.renorm(
        weight_magnitude, p=2, dim=0, maxnorm=max_norm
    )
    factor = renormed_weight_magnitude / weight_magnitude

    r_weight.data *= factor
    i_weight.data *= factor
    j_weight.data *= factor
    k_weight.data *= factor
```
@TParcollet Please review this implementation, which renorms the weights according to the magnitude of the quaternions rather than by individual components.
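For reference, the helper from the diff can be exercised standalone. The shapes and values below are illustrative, not taken from the PR:

```python
import torch


def renorm_quaternion_weights_inplace(
    r_weight, i_weight, j_weight, k_weight, max_norm
):
    """Renorm the quaternion magnitudes in place (as in the diff above)."""
    weight_magnitude = torch.sqrt(
        r_weight.data**2
        + i_weight.data**2
        + j_weight.data**2
        + k_weight.data**2
    )
    renormed = torch.renorm(weight_magnitude, p=2, dim=0, maxnorm=max_norm)
    factor = renormed / weight_magnitude
    r_weight.data *= factor
    i_weight.data *= factor
    j_weight.data *= factor
    k_weight.data *= factor


# Four component tensors standing in for the r/i/j/k weight parameters.
weights = [torch.nn.Parameter(torch.randn(4, 8)) for _ in range(4)]
renorm_quaternion_weights_inplace(*weights, max_norm=1.0)

# torch.renorm with dim=0 clips each row of the magnitude tensor,
# so every row's quaternion-magnitude vector now has 2-norm <= 1.0.
mag = torch.sqrt(sum(w.data**2 for w in weights))
assert bool((mag.norm(p=2, dim=1) <= 1.0 + 1e-6).all())
```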
- A view was incorrectly being applied to broadcast tensors together
- Adds `max_norm` option
- `max_norm` - `swap` - rename `.b` -> `.bias`
- `in_channels` must be divided by `groups` when creating kernels - add checks to ensure divisibility
- The mean was not being subtracted from the input
- `rsqrt` was 8x faster
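On the last point: `torch.rsqrt(x)` computes the same value as `1 / torch.sqrt(x)` in a single fused kernel, which is where the reported speedup comes from. A minimal equivalence check (the speedup ratio itself depends on hardware and is not verified here):

```python
import torch

# Sample strictly positive inputs so both expressions are well defined.
x = torch.rand(1000) + 0.1

slow = 1.0 / torch.sqrt(x)  # two kernels: sqrt, then divide
fast = torch.rsqrt(x)       # one fused reciprocal-sqrt kernel

assert torch.allclose(slow, fast)
```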
d7b5b7f to 0697c73
Thank you @Drew-Wagner for these fixes!
What does this PR do?
This PR fixes several bugs which prevented the use of the quaternion network modules, and completes the collection by implementing avg and max pooling. No tests existed for the quaternion networks; this PR introduces a minimal (and still incomplete) set of tests.
Several bugs were present in the existing quaternion network modules and are fixed by this PR.
Several adjustments were made to improve compatibility of the QConv interface with regular convolution modules:

- A `swap` option was added for QConv2d
- A `max_norm` option was added for the QConv and QLinear modules
- `.b` was renamed to `.bias`

A QPooling2d module is added which implements avg and max pooling.
Breaking Changes:

- `.b` to `.bias`; however, given the number of bugs present, it seems unlikely that anyone was depending on this.

Before submitting
PR review
Reviewer checklist