Fix input normalization and global normalization variance calculation #2835

pplantinga merged 36 commits into speechbrain:develop from
Conversation
The failing test here is one of the ASR CTC integration tests; looking across all three examples in the folder, the train loss is always higher than before. I'm not quite sure what to make of this, but it seems to be a real effect of correcting the standard deviation calculation. I went back to the develop version and changed one line:

```diff
- current_std = torch.std(x, dim=0).detach().data
+ current_std = (x - self.glob_mean.detach()).square().mean().sqrt().data
```

and got a very similar increase in the training loss. I suppose a next step could be to try training a full recipe of some kind that includes input norm and see if the result is comparable to what we had before.
I ran the ASR template example using current develop and this branch, which seems to me to alleviate any concern that the train loss is higher. (Results were attached for "Current" and "This PR".)
Okay, I found the issue: the mean/stdev should be computed only over the time dimension, not all dimensions. I will fix this and update. |
@pplantinga @mravanelli the speaker normalisation is making this class so complicated for no reason. By no reason I mean: we literally don't have a recipe using it. What about just removing it?
rogiervd left a comment
I'm new to this code, so maybe these are naive questions (@TParcollet suggested I look at this, so feel free to blame him). Can I check that my assumptions are correct?
- Under `norm_type=global`, the existing implementation was not correct, so this changes the behaviour only of that, and not of the other three `norm_type`s.
- Under `norm_type=global, avg_factor=None`, the expected behaviour is that the statistics represent those of all the data seen so far.
- Normalisation is meant to be per-dimension, i.e. the global mean and variance have the shape of a single feature vector. (This assumption I'm getting not from the code but because e.g. HTK used to do this.)
What is the desired behaviour under norm_type=global, avg_factor=0.5 (or another value)? I don't think I understand what avg_factor does.
Shouldn't the change to the algorithm, by using Welford's algorithm, be made to the norm_type=speaker too, if norm_type=speaker is retained?
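For reference, the per-dimension assumption in the third bullet would look something like this minimal sketch (the shapes and dimension layout here are illustrative assumptions of mine, not taken from the PR):

```python
import torch

# Hypothetical input of shape [batch, time, features].
x = torch.randn(8, 100, 40)

# Per-dimension statistics: one mean/std per feature, i.e. shape [40],
# the shape of a single feature vector.
mean = x.mean(dim=(0, 1))
std = x.std(dim=(0, 1))

normed = (x - mean) / (std + 1e-10)
```

After this, each feature dimension of `normed` has roughly zero mean and unit variance.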
```
self.glob_std = (1 - self.weight) * self.glob_std.to(
    current_std
) + self.weight * current_std
```

```
# Should never get here, this is simply for safety
```
So should that be an assert instead of a ValueError?
```
delta2 = new_tensor - mean
new_var = (delta * delta2).sum() / (delta.numel() - 1)
var = (var * weight + new_var * new_weight) / (weight + new_weight)
var = ddp_all_reduce(var, torch.distributed.ReduceOp.AVG)
```
Assuming that var means "variance" and not "(non-central) second moment", this does not at first glance look correct.
Correct me if I'm wrong (stats is not my strongest suit), but I'm pretty sure this computes the central second moment, as delta and delta2 are the difference between X and E[X], making this E[(X - E[X])^2]. The one thing that is perhaps slightly incorrect is the `- 1` correction, which should only be applied at the end, but we are applying it every time. However, in general I expect this to be a very minor adjustment, as n should usually be large.
You can also check the latest version, after my fix to compute over the length only, but it should be the same as here.
Or did you mean that the ddp reduction is not done correctly? This could easily be the case, let me look at the reduction part once more.
What I mean is that if you want to get the variance of all data (which I still assume is what you want for norm_type=global, avg_factor=None), then this is not correct.
Imagine one node has data [4,6] and another node [-4,-6]. Each node has a local variance of 1, so the average you compute here is also 1, but the overall variance is 26.
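This counterexample is easy to check directly; the snippet below is mine, not from the PR, and uses population variance (`unbiased=False`):

```python
import torch

a = torch.tensor([4.0, 6.0])    # data on node 1
b = torch.tensor([-4.0, -6.0])  # data on node 2

# Each node's local (population) variance is 1, so averaging them gives 1 ...
local_avg = (a.var(unbiased=False) + b.var(unbiased=False)) / 2

# ... but the variance of the pooled data is 26.
pooled_var = torch.cat([a, b]).var(unbiased=False)

print(local_avg.item(), pooled_var.item())  # 1.0 26.0
```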
I'm okay with removing it, but we would still have to handle the old checkpoints which all have this.
Well, speaker norm was not correct either, but looking at my implementation here I can see that I haven't fixed it. I'll go ahead and fix it if we end up keeping it.
Perhaps. This change should be straightforward if we decide to keep speaker normalization.
One more annoying issue: the original code averages over time first, then batch, but here I've averaged over time and batch at the same time. I wouldn't expect this to make a huge difference on new runs, but it can affect the performance of loading old checkpoints. I guess I'll edit this PR to average over time first and then batch, even though this is probably slower and adds complexity (especially for the standard deviation).
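A toy example (mine, not from the PR) of why the order matters once sequences have different lengths: averaging over time first weights each utterance equally, while a joint average weights each frame equally.

```python
import torch

seq_a = torch.tensor([1.0, 2.0, 3.0, 4.0])  # utterance of length 4
seq_b = torch.tensor([10.0, 20.0])          # utterance of length 2

# Time first, then batch: each utterance counts equally.
time_then_batch = torch.stack([seq_a.mean(), seq_b.mean()]).mean()  # 8.75

# Time and batch jointly: each frame counts equally.
joint = torch.cat([seq_a, seq_b]).mean()  # 40 / 6 ≈ 6.67
```

The two agree only when all sequences have the same length, which is why padding/length handling decides which statistic you get.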
We honestly don't want to make the code slower for the sake of backward compatibility, imho... so if it's actually slower... well... maybe we should also check with the profiler how impactful this new normalisation is.
Okay, I should have rechecked this part, as the normalization dimension didn't matter for global norm, although it still affects … For the …
It would still be really helpful for me to understand whether these are true:
@pplantinga you said:

> …

That much is obvious. So let me try to be clearer. Welford's algorithm is an incremental method for computing the overall mean and variance of all the data you've put into it. Is there a similar statement you can make when you use …? For clarity, my feeling is that the answer may be something like: "It seems like with what we're trying to achieve in this new version, there is no meaning any more …"
Yes, this is the expected behavior.
I think the answer is yes here, assuming you meant that the reduction is done over time and batch dimensions only. The first version of the PR did not have this, but I have updated it to match.
Aha, I understand your question now. I think the …
Okay, this should be ready to review once more @TParcollet @rogiervd. I have verified the performance is the same when loading old checkpoints. The only remaining step I will complete before merge is to train a new model to see if performance is roughly the same.
@pplantinga I think that what @rogiervd meant is that Welford's algorithm renders the momentum meaningless? (Test still failing :p)
Ah, maybe now I understand what's going on. Welford's algorithm per Wikipedia works for adding a single data point. That's not what we want to do, right? Presumably what we need to implement is https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm.
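A sketch of that parallel combination (Chan et al., as described on the linked Wikipedia page); function and variable names here are mine:

```python
import torch

def combine(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
    """Merge (count, mean, sum of squared deviations M2) of two chunks."""
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta ** 2 * n_a * n_b / n
    return n, mean, m2

def chunk_stats(x):
    """Local statistics of one chunk: count, mean, M2."""
    return x.numel(), x.mean(), ((x - x.mean()) ** 2).sum()

# The two-node counterexample from earlier in the thread:
a = torch.tensor([4.0, 6.0])
b = torch.tensor([-4.0, -6.0])
n, mean, m2 = combine(*chunk_stats(a), *chunk_stats(b))
print(mean.item(), (m2 / n).item())  # 0.0 26.0
```

Unlike averaging per-node variances, this recovers the pooled variance of 26 exactly.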
Do people expect an unbiased estimate of the variance? I could be mistaken but I don't think the square root of the unbiased estimate of the variance is an unbiased estimate of the standard deviation. My main concern about this is that I'm not sure whether, in terms of the Wikipedia page, the code correctly deals with the distinction between s^2_n and sigma^2_n.
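The standard-deviation point can be checked numerically: by Jensen's inequality, E[sqrt(s²)] < sigma even though E[s²] = sigma² for the unbiased estimator. A quick Monte Carlo check (mine, not from the PR):

```python
import torch

torch.manual_seed(0)
# 100k samples of size n=5 from N(0, 1), so the true sigma is 1.
samples = torch.randn(100_000, 5)

# The unbiased variance estimates average close to 1 ...
var_mean = samples.var(dim=1, unbiased=True).mean()

# ... but the average of their square roots sits noticeably below 1 (~0.94).
std_mean = samples.std(dim=1, unbiased=True).mean()

print(var_mean.item(), std_mean.item())
```

The bias shrinks as n grows, which matches the earlier remark that the correction matters little for large n.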
In that case, wouldn't you want to review and merge my PR? Even if it is pretty long.
Okay, I have merged @rogiervd's branch into this one, to hopefully clarify the computation of the Gaussian statistics. This no longer includes the unbiased variance calculation nor Welford's algorithm. I added the mask and `std_norm=False` back in. So this should be ready to review @TParcollet.
TParcollet left a comment
Only minor comment! I trust @rogiervd and @pplantinga for the correctness! This is an important change, so we may want a last review from @rogiervd to make sure we are all happy with this.
```
if isinstance(dim, int):
    dim = (dim,)
```

```
# Use output size to compute N from the mask. Assumes N will
```
Assuming N is the count? Maybe we should clarify this.
I think I don't understand what these 5 lines do :/
```
self._load_statistics_dict(stats)
```

```
def get_mask(x, lengths=None, length_dim=1):
```
We have 1 billion functions doing this in SB, I wonder if we should do smth about it.
I mean, it's unsatisfying, but not a major issue I guess.
@pplantinga any chance that you could run a simple time.time comparison to check the impact of this new code compared to the previous one on a big random batch? Then it will be good to me!
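Something like this rough sketch would do for the timing; the `normalize` function below is a stand-in of my own, not the actual SpeechBrain normalization call:

```python
import time
import torch

def avg_time(fn, x, n_iter=50):
    """Average wall-clock time per call, after one warm-up call."""
    fn(x)  # warm-up
    start = time.time()
    for _ in range(n_iter):
        fn(x)
    return (time.time() - start) / n_iter

# A big random batch: [batch, time, features].
x = torch.randn(64, 2000, 80)

def normalize(t):
    # Stand-in for the normalization being profiled.
    mean = t.mean(dim=(0, 1))
    std = t.std(dim=(0, 1))
    return (t - mean) / (std + 1e-10)

print(f"{avg_time(normalize, x) * 1e3:.2f} ms per call")
```

Running the same harness against the old and new implementations would give a direct before/after comparison.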
Hm, yes, this is probably worth checking. I was a bit concerned about this, given the complexity/length of the code and checks.
rogiervd left a comment
Thanks a lot! This looks generally good.
I do have a number of small remaining questions and comments.
```
The tensor to compute the statistics over.
mask: torch.Tensor
    Padding mask to exclude padding from the statistics computation.
    All dimensions other than `dim` should be ones (e.g. [B, T, 1, ...])
```
But the other dimensions should be the same as the matching dimensions in x, right? Why not say that?
```
    matching dimensions in "x".
    The other dimensions should have size 1.
    If None, then scalar-valued statistics will be returned.
compute_var: bool
```
This is not used, I think?
```
    If None, then scalar-valued statistics will be returned.
compute_var: bool
    Whether to compute the variance, in order to speed computation
    when it is not needed. Returns `None` for variance if `False`.
```
> Returns `None` for variance if `False`

Could you make this more precise? Or remove it.
```
    The combined mean.
variance
    The combined variance, relative to the new mean.
    Returns None if either statistics set has variance of None
```
Instead of "Returns", "Is"?
```
"""
Update the running count, running mean, and running standard deviation
by integrating new data x from multiple processes.
def mean_std_update(x, mask, dim, run_count, run_mean, run_std=None):
```
Ach, this was a copy error, should be fixed
```
# Convert relative lengths to absolute lengths, then compute boolean mask
max_len = x.size(length_dim)
abs_lengths = lengths.unsqueeze(1) * max_len
```
(lengths * max_len).unsqueeze(1) should be faster and use less memory, right?
And shouldn't it be `(lengths * max_len + some_small_constant).unsqueeze(1)` to stop rounding errors?
Note e.g. on my computer (presumably IEEE 754)
>>> (15 / 22) * 22
14.999999999999998
I added a test for this, and it was still not failing, though I'm not exactly sure why. I added this code for safety anyway.
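The failure mode and the epsilon guard can be reproduced with plain Python floats (illustrative only; the constant `1e-6` below is my choice, not the value used in the PR):

```python
import math

rel = 15 / 22          # relative length stored as a float
max_len = 22

# Naive conversion rounds just below the true length ...
naive = rel * max_len  # 14.999999999999998
assert math.floor(naive) == 14  # a floor-based mask would drop the last frame

# ... while a small additive constant guards against it.
safe = rel * max_len + 1e-6
assert math.floor(safe) == 15
```

Whether the bug actually triggers depends on how the mask truncates the value, which may explain why the test did not fail.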
```
@pytest.fixture
def sample_data(self):
    """Create sample data for testing."""
    # Create a batch of 3 sequences with 4 features and variable lengths
```
Is a two-dimensional tensor good enough?
```
x = torch.ones(3, 4, 2)  # Batch size 3, seq len 4, feature dim 2

# Relative lengths: 100%, 75%, 50%
lengths = torch.tensor([1.0, 0.75, 0.5])
```
Don't you want to test with values like 15/22?
Found by running:

```python
for numerator in range(1, 50):
    for denominator in range(numerator, 100):
        if (numerator / denominator) * denominator < numerator:
            print(numerator, denominator)
```
Added a test for this.
```
dim: Union[int, tuple, None] = None,
mask: Optional[torch.Tensor] = None,
dim: Union[int, tuple, None] = None,
compute_var: bool = True,
```
Do you test this? I didn't see it.
@pplantinga I've looked both at a profiler trace and a simple time.time() check, and this is not making the normalisation slower. I think that once you have fixed the small changes requested by Rogier, we can simply merge! (after the release of 1.0.2, just to have a bit of a safety net :p)
rogiervd left a comment
Thanks a lot for humouring my questions and suggestions. I think this is good!
@pplantinga feel free to merge this PR!
…speechbrain#2835)
Co-authored-by: Rogier van Dalen <r.vandalen@samsung.com>
Co-authored-by: Parcollet Titouan <parcollet.titouan@gmail.com>
Fixes #2815
`InputNormalization` (and `GlobalNorm`) has several bugs, including the one listed above.
- … `BatchNorm`. Instead, the variance should be computed against the running mean, not the tensor mean.
- When `mean_norm` is false and `std_norm` is true, the mean is not subtracted and added back, leading to incorrect variance normalization.
- `update_until_epoch` was set to 3 by default, which in most cases is twice the number of updates needed.
- `.data` is used, which is not recommended because it bypasses the gradient system. This probably doesn't matter for normal cases where this is used at the input, but it could matter if it's used anywhere else in the network.

Maybe some of these settings were not getting used (a cursory `git grep` shows mostly default settings), but we should fix the bugs anyway. This PR adds extensive unit tests (thanks Claude) and hopefully catches most of the edge cases.