Skip to content

Commit 87a2f3b

Browse files
author
Rogier van Dalen
committed
Fix calculation with masks
1 parent 60dc766 commit 87a2f3b

2 files changed

Lines changed: 23 additions & 21 deletions

File tree

speechbrain/processing/features.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,10 @@ def gaussian_statistics(
10121012
dim: int | tuple | None
10131013
The dimension or dimensions that the statistics should be computed over.
10141014
The other dimensions are retained in the output.
1015-
If None, then scalar-valued statistics will be returned.
1015+
If None, then statistics will be computed over all dimensions and
1016+
scalar-valued statistics will be returned.
1017+
() has the same effect as None, which is nonsensical but is consistent
1018+
with torch.sum and friends.
10161019
mask: torch.Tensor | None
10171020
A boolean tensor with True for elements that should be considered, and
10181021
False for elements that should not be considered (e.g. that are after
@@ -1038,7 +1041,9 @@ def normalise_dimensions(
10381041
) -> Tuple[tuple, tuple]:
10391042
"""Normalise "dim" and return (reduce_dimensions, keep_dimensions)."""
10401043
all_dimensions = range(len(x.shape))
1041-
if dim is None:
1044+
if dim is None or dim == ():
1045+
# dim == () is an exceptional case and replicates the strangeness
1046+
# of torch.sum(..., dim=()) and friends.
10421047
return (tuple(d for d in all_dimensions), ())
10431048
elif isinstance(dim, int):
10441049
return ((dim,), tuple(d for d in all_dimensions if d != dim))
@@ -1059,13 +1064,6 @@ def normalise_dimensions(
10591064
for d in keep_dimensions:
10601065
assert mask.size(d) == 1
10611066

1062-
if reduce_dimensions == ():
1063-
if mask is None:
1064-
return 1, x, torch.zeros_like(x)
1065-
else:
1066-
# mask.numel == 1
1067-
return int(torch.sum(mask)), mask * x, torch.zeros_like(x)
1068-
10691067
if mask is None:
10701068
number = math.prod(x.size(d) for d in reduce_dimensions)
10711069
else:
@@ -1074,13 +1072,16 @@ def normalise_dimensions(
10741072
masked_data = x if mask is None else mask * x
10751073

10761074
# First keep the dimensions so that broadcasting works.
1077-
mean_with_dims = torch.mean(masked_data, dim=dim, keepdim=True)
1078-
mean = (
1079-
torch.squeeze(mean_with_dims)
1080-
if dim is None
1081-
else torch.squeeze(mean_with_dims, dim=dim)
1075+
sum_with_dims = torch.sum(masked_data, dim=reduce_dimensions, keepdim=True)
1076+
1077+
mean_with_dims = sum_with_dims / number
1078+
1079+
mean = torch.squeeze(mean_with_dims, dim=reduce_dimensions)
1080+
central_squared_data = torch.square(x - mean_with_dims)
1081+
masked_squared_data = (
1082+
central_squared_data if mask is None else mask * central_squared_data
10821083
)
1083-
variance = torch.mean(torch.square(masked_data - mean_with_dims), dim=dim)
1084+
variance = torch.sum(masked_squared_data, dim=reduce_dimensions) / number
10841085

10851086
return (number, mean, variance)
10861087

tests/unittests/test_input_norm.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def normalise_dimensions(
2121
"""Ensure dimensions object is a tuple."""
2222
if isinstance(dimensions, int):
2323
return (dimensions,)
24-
elif dimensions is None:
24+
elif dimensions is None or dimensions == ():
2525
# All dimensions
2626
return tuple(range(num_dimensions))
2727
assert isinstance(dimensions, tuple)
@@ -58,17 +58,18 @@ def reference_gaussian_statistics(
5858
# Start by pretending that dimensions=() and then roll them up one by one.
5959
all_count = 1
6060
masked_data = x if mask is None else mask * x
61-
mean = masked_data
62-
variance_statistics = np.square(masked_data)
61+
sum = masked_data
62+
sum_squares = np.square(masked_data)
6363

6464
for dimension in sorted(dimensions, reverse=True):
6565
all_count *= x.shape[dimension]
66-
mean = np.mean(mean, axis=dimension)
67-
variance_statistics = np.mean(variance_statistics, axis=dimension)
66+
sum = np.sum(sum, axis=dimension)
67+
sum_squares = np.sum(sum_squares, axis=dimension)
6868

6969
count = all_count if mask is None else np.sum(mask)
7070

71-
variance = variance_statistics - np.square(mean)
71+
mean = sum / count
72+
variance = (sum_squares / count) - np.square(mean)
7273

7374
return count, mean, variance
7475

0 commit comments

Comments
 (0)