Fix input normalization and global normalization variance calculation#2835

Merged
pplantinga merged 36 commits into speechbrain:develop from pplantinga:fix-input-norm-stdev
Apr 26, 2025

Conversation

@pplantinga
Collaborator

@pplantinga pplantinga commented Feb 25, 2025

Fixes #2815

InputNormalization (and GlobalNorm) has several bugs, including the one listed above.

  • Mainly, it is mathematically incorrect to average the standard deviation across inputs; instead, the variance should be computed against the running mean rather than the per-tensor mean, which more closely mimics the behavior of BatchNorm.
  • Mean and variance normalization was applied to the padding (it should be ignored).
  • If mean_norm is false and std_norm is true, the mean was not subtracted before scaling and added back afterward, leading to incorrect variance normalization.
  • update_until_epoch was set to 3 by default, which in most cases is twice the number of updates needed.
  • Updates were applied directly to .data, which is not recommended because it bypasses the gradient system. This probably doesn't matter for the normal case where normalization is applied at the input, but it could matter if it's used anywhere else in the network.

Maybe some of these settings were not getting used (a cursory git grep shows mostly default settings), but we should fix the bugs anyway. This PR adds extensive unittests (thanks Claude) and hopefully catches most of the edge cases.
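The standard-deviation bug in the first bullet can be sketched roughly as follows (buffer names and update weights here are illustrative, not the actual SpeechBrain code):

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 100, 40)  # batch, time, features

# Old (buggy) idea: average each batch's std into the running std.
# The mean of per-batch stds is not the std of the pooled data.
run_std = torch.ones(40)
current_std = x.std(dim=(0, 1))
run_std = 0.9 * run_std + 0.1 * current_std

# Fixed idea: accumulate variance relative to the *running* mean,
# and take the square root only when normalizing.
run_mean = torch.zeros(40)
run_var = torch.ones(40)
current_var = (x - run_mean).square().mean(dim=(0, 1))
run_var = 0.9 * run_var + 0.1 * current_var
normalized = (x - run_mean) / (run_var + 1e-10).sqrt()
```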

@pplantinga pplantinga added bug Something isn't working correctness Functionality not objectively broken, but may be surprising or wrong e.g. regarding literature labels Feb 25, 2025
@pplantinga pplantinga added this to the v1.0.3 milestone Feb 25, 2025
@pplantinga pplantinga self-assigned this Feb 25, 2025
@pplantinga
Collaborator Author

The failing test here is one of the ASR CTC integration tests. Looking across all three examples in the folder, the train loss is always higher than before. I'm not quite sure what to make of this, but it seems to be a real effect of correcting the standard deviation calculation. I went back to the develop version and changed one line:

- current_std = torch.std(x, dim=0).detach().data
+ current_std = (x - self.glob_mean.detach()).square().mean().sqrt().data

And got a very similar increase in the training loss.

I suppose a next step could be to try training a full recipe of some kind that includes input norm and see if the result is comparable to what we had before.

@pplantinga
Collaborator Author

I ran the ASR template example using current develop and this branch; the results seem to me to alleviate any concern about the higher train loss.

Current develop:

epoch: 1, lr: 1.00e+00 - train loss: 1.54 - valid loss: 1.33, valid CER: 6.42, valid WER: 13.37
epoch: 2, lr: 1.00e+00 - train loss: 1.25 - valid loss: 1.36, valid CER: 12.60, valid WER: 22.30
epoch: 3, lr: 8.00e-01 - train loss: 1.07 - valid loss: 1.34, valid CER: 9.63, valid WER: 15.76
epoch: 4, lr: 8.00e-01 - train loss: 9.96e-01 - valid loss: 1.35, valid CER: 8.09, valid WER: 14.64
epoch: 5, lr: 8.00e-01 - train loss: 9.72e-01 - valid loss: 1.34, valid CER: 6.59, valid WER: 12.12
epoch: 6, lr: 8.00e-01 - train loss: 1.39 - valid loss: 1.36, valid CER: 7.63, valid WER: 14.27
epoch: 7, lr: 6.40e-01 - train loss: 1.32 - valid loss: 1.33, valid CER: 6.06, valid WER: 12.22
epoch: 8, lr: 6.40e-01 - train loss: 1.33 - valid loss: 1.33, valid CER: 4.68, valid WER: 10.35
epoch: 9, lr: 6.40e-01 - train loss: 1.28 - valid loss: 1.33, valid CER: 7.39, valid WER: 13.90
epoch: 10, lr: 5.12e-01 - train loss: 1.27 - valid loss: 1.32, valid CER: 10.23, valid WER: 19.99
epoch: 11, lr: 4.10e-01 - train loss: 1.25 - valid loss: 1.31, valid CER: 8.17, valid WER: 14.97
epoch: 12, lr: 4.10e-01 - train loss: 1.25 - valid loss: 1.31, valid CER: 7.17, valid WER: 13.73
epoch: 13, lr: 4.10e-01 - train loss: 1.24 - valid loss: 1.30, valid CER: 4.96, valid WER: 11.54
epoch: 14, lr: 4.10e-01 - train loss: 1.22 - valid loss: 1.30, valid CER: 5.46, valid WER: 10.87
epoch: 15, lr: 4.10e-01 - train loss: 1.22 - valid loss: 1.31, valid CER: 6.17, valid WER: 12.36
Epoch loaded: 8 - test loss: 1.32, test CER: 4.25, test WER: 7.36

This PR:

epoch: 1, lr: 1.00e+00 - train loss: 1.57 - valid loss: 1.34, valid CER: 7.63, valid WER: 14.22
epoch: 2, lr: 1.00e+00 - train loss: 1.25 - valid loss: 1.36, valid CER: 6.43, valid WER: 12.91
epoch: 3, lr: 1.00e+00 - train loss: 1.12 - valid loss: 1.38, valid CER: 7.67, valid WER: 13.50
epoch: 4, lr: 8.00e-01 - train loss: 9.99e-01 - valid loss: 1.35, valid CER: 5.82, valid WER: 12.32
epoch: 5, lr: 8.00e-01 - train loss: 9.74e-01 - valid loss: 1.35, valid CER: 9.89, valid WER: 18.54
epoch: 6, lr: 6.40e-01 - train loss: 1.37 - valid loss: 1.33, valid CER: 9.37, valid WER: 16.54
epoch: 7, lr: 6.40e-01 - train loss: 1.34 - valid loss: 1.33, valid CER: 9.50, valid WER: 16.81
epoch: 8, lr: 5.12e-01 - train loss: 1.28 - valid loss: 1.32, valid CER: 6.62, valid WER: 12.88
epoch: 9, lr: 5.12e-01 - train loss: 1.26 - valid loss: 1.31, valid CER: 8.46, valid WER: 16.59
epoch: 10, lr: 4.10e-01 - train loss: 1.25 - valid loss: 1.30, valid CER: 6.54, valid WER: 13.33
epoch: 11, lr: 4.10e-01 - train loss: 1.26 - valid loss: 1.30, valid CER: 5.38, valid WER: 10.58
epoch: 12, lr: 4.10e-01 - train loss: 1.24 - valid loss: 1.31, valid CER: 5.15, valid WER: 11.32
epoch: 13, lr: 3.28e-01 - train loss: 1.22 - valid loss: 1.30, valid CER: 7.03, valid WER: 14.19
epoch: 14, lr: 2.62e-01 - train loss: 1.21 - valid loss: 1.29, valid CER: 5.59, valid WER: 12.02
epoch: 15, lr: 2.62e-01 - train loss: 1.21 - valid loss: 1.30, valid CER: 5.33, valid WER: 11.12
Epoch loaded: 11 - test loss: 1.30, test CER: 3.64, test WER: 6.89

@pplantinga
Collaborator Author

Okay, I found the issue: the mean/stdev should be computed only over the time dimension, not all dimensions. I will fix this and update.

@TParcollet
Collaborator

@pplantinga @mravanelli the speaker normalisation is making this class so complicated for no reason. By no reason I mean -- we literally don't have a recipe using it. What about just removing it?

Contributor

@rogiervd rogiervd left a comment


I'm new to this code, so I may have some naive questions (@TParcollet suggested I look at this, so feel free to blame him). Can I check that my assumptions are correct?

  • Under norm_type=global, the existing implementation was not correct, so this changes the behaviour only of that, and not of the other three norm_types.
  • Under norm_type=global, avg_factor=None, the expected behaviour is that the statistics represent those of all the data seen so far.
  • Normalisation is meant to be per-dimension, i.e. the global mean and variance have the shape of a single feature vector. (This assumption I'm getting not from the code but because e.g. HTK used to do this.)

What is the desired behaviour under norm_type=global, avg_factor=0.5 (or another value)? I don't think I understand what avg_factor does.

Shouldn't the change to the algorithm, by using Welford's algorithm, be made to the norm_type=speaker too, if norm_type=speaker is retained?

Comment thread speechbrain/processing/features.py
Comment thread speechbrain/processing/features.py Outdated
self.glob_std = (1 - self.weight) * self.glob_std.to(
current_std
) + self.weight * current_std
# Should never get here, this is simply for safety
Contributor


So should that be an assert instead of a ValueError?

Comment thread speechbrain/processing/features.py Outdated
delta2 = new_tensor - mean
new_var = (delta * delta2).sum() / (delta.numel() - 1)
var = (var * weight + new_var * new_weight) / (weight + new_weight)
var = ddp_all_reduce(var, torch.distributed.ReduceOp.AVG)
Contributor


Assuming that var means "variance" and not "(non-central) second moment", this does not at first glance look correct.

Collaborator Author


Correct me if I'm wrong (stats is not my strongest suit), but I'm pretty sure this computes the central second moment, as delta and delta2 are the difference between X and E[X], making this E[(X - E[X])^2]. The one thing that is perhaps slightly incorrect is the - 1 correction which should only be applied at the end, but we are applying it every time. However, in general I expect this to be a very minor adjustment as n should usually be large.

You can also check the latest version, after my fix to compute over the length only, but it should be the same as here.

Collaborator Author

@pplantinga pplantinga Feb 28, 2025


Or did you mean that the ddp reduction is not done correctly? This could easily be the case, let me look at the reduction part once more.

Contributor


What I mean is that if you want to get the variance of all data (which I still assume is what you want for norm_type=global, avg_factor=None), then this is not correct.

Imagine one node has data [4,6] and another node [-4,-6]. Each node has a local variance of 1, so the average you compute here is also 1, but the overall variance is 26.
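This counterexample checks out numerically, and the fix is the pooled ("parallel") variance combination, which adds a between-means term (a sketch of the textbook formula, not the PR's exact code):

```python
import torch

a = torch.tensor([4.0, 6.0])    # data on node 1
b = torch.tensor([-4.0, -6.0])  # data on node 2

# Local (population) variances are both 1, so a plain average gives 1 ...
avg_of_vars = (a.var(unbiased=False) + b.var(unbiased=False)) / 2

# ... but the variance of the pooled data is 26.
pooled = torch.cat([a, b])
true_var = pooled.var(unbiased=False)

# Correct combination: account for the shift between the two means.
n_a, n_b = a.numel(), b.numel()
mean_a, mean_b = a.mean(), b.mean()
delta = mean_b - mean_a
combined_mean = (n_a * mean_a + n_b * mean_b) / (n_a + n_b)
combined_var = (
    n_a * a.var(unbiased=False) + n_b * b.var(unbiased=False)
    + n_a * n_b * delta**2 / (n_a + n_b)
) / (n_a + n_b)
```

Here `combined_var` recovers 26, matching the pooled computation.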

@pplantinga
Collaborator Author

> @pplantinga @mravanelli the speaker normalisation is making this class so complicated for no reason. By no reason I mean -- we literally don't have a recipe using it. What about just removing it?

I'm okay with removing it, but we would still have to handle the old checkpoints which all have this.

> Under norm_type=global, the existing implementation was not correct, so this changes the behaviour only of that, and not of the other three norm_types.

Well, speaker norm was not correct either, but looking at my implementation here I can see that I haven't fixed it. I'll go ahead and fix it if we end up keeping it.

> What is the desired behaviour under norm_type=global, avg_factor=0.5 (or another value)? I don't think I understand what avg_factor does.

The avg_factor updates the stats according to (1 - avg_factor) * old_stat + avg_factor * new_stat, e.g. if avg_factor is 0.1, we would include 10% of the new stat and 90% of the old stat. This might be desirable if you expect some distribution shift over training and want to weight newer samples more strongly.
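In code, this update is just an exponential moving average (illustrative names, not the real class):

```python
import torch

def ema_update(old_stat, new_stat, avg_factor=0.1):
    # avg_factor weights the new statistic; 1 - avg_factor keeps the old.
    return (1 - avg_factor) * old_stat + avg_factor * new_stat

running_mean = torch.zeros(4)
batch_mean = torch.ones(4)
running_mean = ema_update(running_mean, batch_mean)  # 10% of the new stat
```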

> Shouldn't the change to the algorithm, by using Welford's algorithm, be made to norm_type=speaker too, if norm_type=speaker is retained?

Perhaps. This change should be straightforward if we decide to keep speaker normalization.

@pplantinga
Collaborator Author

pplantinga commented Feb 28, 2025

One more annoying issue: the original code averages over time first and then over batch, but here I've averaged over time and batch jointly. I wouldn't expect this to make a huge difference on new runs, but it can affect performance when loading old checkpoints. I guess I'll edit this PR to average over time first and then batch, even though this is probably slower and adds complexity (especially for the standard deviation).
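The two reduction orders genuinely differ when lengths vary, because averaging per-utterance means weights every utterance equally rather than every frame. A tiny made-up example:

```python
import torch

# Two "utterances" of different lengths (padding already stripped).
long_utt = torch.tensor([0.0, 0.0, 0.0, 4.0])  # 4 frames, mean 1.0
short_utt = torch.tensor([2.0])                # 1 frame, mean 2.0

# Time first, then batch: every utterance counts equally.
per_utt = torch.stack([long_utt.mean(), short_utt.mean()]).mean()  # 1.5

# Time and batch jointly: every frame counts equally.
joint = torch.cat([long_utt, short_utt]).mean()  # (0+0+0+4+2)/5 = 1.2
```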

@TParcollet
Collaborator

We honestly don't want to make the code slower for the sake of backward compatibility, imho ... so if it's actually slower ... well ... maybe we should also check with the profiler how impactful this new normalisation is.

@pplantinga
Collaborator Author

Okay, I should have rechecked this part: the normalization dimension didn't matter for global norm, although it still affects norm_type=batch. The real cause of the performance difference when loading old checkpoints was the padding not getting normalized (I explicitly excluded it). I can certainly update the PR to normalize the padding again, but do we want an option to exclude the padding? My assumption is that most users would not want the padding normalized, but it seems our checkpoints might depend on this point.

For norm_type=batch, is it more important to keep backwards compatibility or to update to the correct algorithm? I don't see any instances of norm_type=batch in the repo. I guess this norm type is equivalent to torch.nn.BatchNorm1d(track_running_stats=False). We could also deprecate this one in addition to the speaker norm.

@rogiervd
Contributor

rogiervd commented Mar 3, 2025

It would still be really helpful for me to understand whether these are true:

  • Under norm_type=global, avg_factor=None, the expected behaviour is that the statistics represent those of all the data seen so far.
  • Normalisation is meant to be per-dimension, i.e. the global mean and variance have the shape of a single feature vector. (This assumption I'm getting not from the code but because e.g. HTK used to do this.)

@pplantinga you said

> The avg_factor updates the stats according to (1 - avg_factor) * old_stat + avg_factor * new_stat, e.g. if avg_factor is 0.1, we would include 10% of the new stat and 90% of the old stat. This might be desirable if you expect some distribution shift over training and want to weight newer samples more strongly.

That much is obvious. So let me try and be clearer.

Welford's algorithm is an incremental method for computing the overall mean and variance of all the data you've put into it. Is there a similar statement you can make when you use avg_factor?

For clarity, my feeling is that the answer may be something like: "It seems that with what we're trying to achieve in this new version, avg_factor no longer has any meaning, and we should take it out." This is all based on the idea that the standard behaviour is to compute the mean and variance of the whole data. I might be wrong though.

@pplantinga
Collaborator Author

> It would still be really helpful for me to understand whether these are true:
>
>   • Under norm_type=global, avg_factor=None, the expected behaviour is that the statistics represent those of all the data seen so far.

Yes, this is the expected behavior.

>   • Normalisation is meant to be per-dimension, i.e. the global mean and variance have the shape of a single feature vector. (This assumption I'm getting not from the code but because e.g. HTK used to do this.)

I think the answer is yes here, assuming you meant that the reduction is done over time and batch dimensions only. The first version of the PR did not have this, but I have updated it to match.

> Welford's algorithm is an incremental method for computing the overall mean and variance of all the data you've put into it. Is there a similar statement you can make when you use avg_factor?
>
> For clarity, my feeling is that the answer may be something like: "It seems that with what we're trying to achieve in this new version, avg_factor no longer has any meaning, and we should take it out." This is all based on the idea that the standard behaviour is to compute the mean and variance of the whole data. I might be wrong though.

Aha, I understand your question now. I think the avg_factor is used to achieve behavior similar to BatchNorm on the input, which the PyTorch documentation calls "momentum". But if that use case is covered by BatchNorm, perhaps it is better to just deprecate it here.

@pplantinga
Collaborator Author

Okay, this should be ready to review once more @TParcollet @rogiervd

I have verified the performance is the same when loading old checkpoints. The only remaining step I will complete before merge is to train a new model to see if performance is roughly the same.

@TParcollet
Collaborator

TParcollet commented Mar 3, 2025

@pplantinga I think that what @rogiervd meant is that Welford's algorithm renders the momentum meaningless? (Test still failing :p)

@rogiervd
Contributor

I have to admit I still can't wrap my head around x - run_mean with the old mean.

Perhaps we should ditch Welford's algorithm for the sake of readability then.

Ah, maybe now I understand what's going on. Welford's algorithm per Wikipedia works for adding a single data point. That's not what we want to do, right? Presumably what we need to implement is https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm.

I also don't know that it's worth getting unbiased variance estimates (dividing by n-1 instead of n), and how this is undone when combining statistics.

I'm fine either way honestly. The main reason to keep it imo is so that the examples/tests have the variance that people expect.

Do people expect an unbiased estimate of the variance? I could be mistaken but I don't think the square root of the unbiased estimate of the variance is an unbiased estimate of the standard deviation.

My main concern about this is that I'm not sure whether, in terms of the Wikipedia page, the code correctly deals with the distinction between s^2_n and sigma^2_n.
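The unbiased-estimator point above is easy to see numerically: the square root is concave, so by Jensen's inequality the mean of sqrt(unbiased variance) falls short of the true sigma (a quick Monte Carlo sketch, not project code):

```python
import torch

torch.manual_seed(0)
# 100k samples of size n=5 drawn from a unit normal, so sigma = 1 exactly.
samples = torch.randn(100_000, 5)

# sqrt of the unbiased variance estimate, per sample ...
stds = samples.var(dim=1, unbiased=True).sqrt()

# ... averages to roughly 0.94, not 1: a biased estimate of sigma.
mean_std = stds.mean()
```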

In #2857 I offer an alternative implementation for this which in my mind is simpler. At least, I can understand it.

I'll do my best to integrate your PR into this one so this one is more understandable, while maintaining your commit / authorship credit.

In that case, wouldn't you want to review and merge my PR? Even if it is pretty long.

@pplantinga
Collaborator Author

Okay, I have merged @rogiervd's branch into this one, to hopefully clarify the computation of the Gaussian statistics. This no longer includes the unbiased variance calculation or Welford's algorithm. I added the mask and std_norm=False handling back in. So this should be ready to review, @TParcollet.

Collaborator

@TParcollet TParcollet left a comment


Only minor comments! I trust @rogiervd and @pplantinga for the correctness! This is an important change, so we may want a last review from @rogiervd to make sure we are all happy with this.

Comment thread speechbrain/processing/features.py
Comment thread speechbrain/processing/features.py Outdated
if isinstance(dim, int):
dim = (dim,)

# Use output size to compute N from the mask. Assumes N will
Collaborator


Assuming N is the count? Maybe we should clarify this.

Collaborator


I think I don't understand what these 5 lines do :/

Comment thread speechbrain/processing/features.py
Comment thread speechbrain/processing/features.py Outdated
Comment thread speechbrain/processing/features.py
Comment thread speechbrain/processing/features.py Outdated
self._load_statistics_dict(stats)


def get_mask(x, lengths=None, length_dim=1):
Collaborator


We have a billion functions doing this in SB, I wonder if we should do something about it.

Collaborator Author


I mean, it's unsatisfying, but not a major issue I guess.

Comment thread speechbrain/processing/features.py
@TParcollet
Collaborator

TParcollet commented Apr 2, 2025

@pplantinga any chance that you could run a simple time.time() comparison to check the impact of this new code compared to the previous one on a big random batch? Then it will be good to go for me!

@pplantinga
Collaborator Author

Hm, yes this is probably worth checking. I was a bit concerned about this given the complexity/length of the code and checks.

Contributor

@rogiervd rogiervd left a comment


Thanks a lot! This looks generally good.

I do have a number of small remaining questions and comments.

The tensor to compute the statistics over.
mask: torch.Tensor
Padding mask to exclude padding from the statistics computation.
All dimensions other than `dim` should be ones (e.g. [B, T, 1, ...])
Contributor


But the other dimensions should be the same as the matching dimensions in x, right? Why not say that?

Collaborator Author


Added

Comment thread speechbrain/processing/features.py Outdated
matching dimensions in "x".
The other dimensions should have size 1.
If None, then scalar-valued statistics will be returned.
compute_var: bool
Contributor


This is not used, I think?

Comment thread speechbrain/processing/features.py Outdated
If None, then scalar-valued statistics will be returned.
compute_var: bool
Whether to compute the variance, in order to speed computation
when it is not needed. Returns `None` for variance if `False`.
Contributor


Returns None for variance if False

Could you make this more precise? Or remove it.

Collaborator Author


Removed

The combined mean.
variance
The combined variance, relative to the new mean.
Returns None if either statistics set has variance of None
Contributor


Instead of "Returns", "Is"?

Collaborator Author


Removed

Comment thread speechbrain/processing/features.py Outdated
"""
Update the running count, running mean, and running standard deviation
by integrating new data x from multiple processes.
def mean_std_update(x, mask, dim, run_count, run_mean, run_std=None):
Contributor


Why remove types?

Collaborator Author


Ach, this was a copy error, should be fixed

Comment thread speechbrain/processing/features.py Outdated

# Convert relative lengths to absolute lengths, then compute boolean mask
max_len = x.size(length_dim)
abs_lengths = lengths.unsqueeze(1) * max_len
Contributor


(lengths * max_len).unsqueeze(1) should be faster and use less memory, right?

Contributor


And shouldn't it be (lengths * max_len + some_small_constants).unsqueeze(1) to stop rounding errors?

Note e.g. on my computer (presumably IEEE 754)

>>> (15 / 22) * 22
14.999999999999998

Collaborator Author


I added a test for this, and it was still not failing, though I'm not exactly sure why. I added this code for safety anyway.
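A rounding-safe version of the mask construction might look like this. This is an illustrative sketch (the merged code may differ; the PR adds a small constant, while `torch.round` shown here achieves a similar guard):

```python
import torch

def get_mask(x, lengths, length_dim=1):
    """Boolean mask that is True for real frames and False for padding.

    Rounding guards against float error such as (15/22) * 22 evaluating
    to 14.999999999999998 and truncating a real frame.
    """
    max_len = x.size(length_dim)
    abs_lengths = torch.round(lengths * max_len)
    positions = torch.arange(max_len, device=x.device)
    return positions.unsqueeze(0) < abs_lengths.unsqueeze(1)

x = torch.randn(2, 22)
mask = get_mask(x, torch.tensor([15 / 22, 1.0]))  # 15 and 22 real frames
```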

Comment thread tests/unittests/test_input_norm.py Outdated
@pytest.fixture
def sample_data(self):
"""Create sample data for testing."""
# Create a batch of 3 sequences with 4 features and variable lengths
Contributor


Is a two-dimensional tensor good enough?

x = torch.ones(3, 4, 2) # Batch size 3, seq len 4, feature dim 2

# Relative lengths: 100%, 75%, 50%
lengths = torch.tensor([1.0, 0.75, 0.5])
Contributor


Don't you want to test with values like 15/22?

Found by running

for numerator in range(1, 50):
  for denominator in range(numerator, 100):
    if (numerator/denominator) * denominator < numerator:
      print(numerator, denominator)

Collaborator Author


Added a test for this.

Comment thread speechbrain/processing/features.py Outdated
dim: Union[int, tuple, None] = None,
mask: Optional[torch.Tensor] = None,
dim: Union[int, tuple, None] = None,
compute_var: bool = True,
Contributor


Do you test this? I didn't see it.

Collaborator Author


Removed

Comment thread speechbrain/processing/features.py
@TParcollet
Collaborator

@pplantinga I've looked both at a profiler trace and a simple time.time() check, and this is not making the normalisation slower. I think that once you have fixed the small changes requested by Rogier, we can simply merge! (after the release of 1.0.2, just to have a bit of safety net :p)

Contributor

@rogiervd rogiervd left a comment


Thanks a lot for humouring my questions and suggestions. I think this is good!

@TParcollet
Collaborator

@pplantinga feel free to merge this PR!

@pplantinga pplantinga merged commit 9e068c4 into speechbrain:develop Apr 26, 2025
6 of 9 checks passed
@pplantinga pplantinga deleted the fix-input-norm-stdev branch April 26, 2025 16:29
pplantinga added a commit to pplantinga/speechbrain that referenced this pull request Jun 2, 2025
…speechbrain#2835)

Co-authored-by: Rogier van Dalen <r.vandalen@samsung.com>
Co-authored-by: Parcollet Titouan <parcollet.titouan@gmail.com>

Successfully merging this pull request may close: InputNormalization with "global" is incorrect (#2815)