Skip to content

Commit 57bbb9b

Browse files
author
Rogier van Dalen
committed
Add support for masks
1 parent 557260a commit 57bbb9b

2 files changed

Lines changed: 116 additions & 32 deletions

File tree

speechbrain/processing/features.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"""
3737

3838
import math
39-
from typing import Tuple, Union
39+
from typing import Optional, Tuple, Union
4040

4141
import torch
4242
from torch.distributed import ReduceOp
@@ -996,7 +996,11 @@ def forward(self, x):
996996
return cw_x
997997

998998

999-
def gaussian_statistics(x: torch.Tensor, dim: Union[int, tuple, None] = None):
999+
def gaussian_statistics(
1000+
x: torch.Tensor,
1001+
dim: Union[int, tuple, None] = None,
1002+
mask: Optional[torch.Tensor] = None,
1003+
):
10001004
"""
10011005
Compute first- and second-order moments of data, and return them as the
10021006
count, mean, and variance of a vector over one or more dimensions.
@@ -1009,6 +1013,14 @@ def gaussian_statistics(x: torch.Tensor, dim: Union[int, tuple, None] = None):
10091013
The dimension or dimensions that the statistics should be computed over.
10101014
The other dimensions are retained in the output.
10111015
If None, then scalar-valued statistics will be returned.
1016+
mask: torch.Tensor | None
1017+
A boolean tensor with True for elements that should be considered, and
1018+
False for elements that should not be considered (e.g. that are after
1019+
the end of utterances).
1020+
This tensor should have the same number of dimensions as "x".
1021+
The dimensions indicated by "dim" should have the same size as the
1022+
matching dimensions in "x".
1023+
The other dimensions should have size 1.
10121024
10131025
Returns
10141026
-------
@@ -1021,26 +1033,54 @@ def gaussian_statistics(x: torch.Tensor, dim: Union[int, tuple, None] = None):
10211033
The variance.
10221034
"""
10231035

1024-
if dim is None:
1025-
number = math.prod(x.shape)
1026-
elif isinstance(dim, int):
1027-
number = x.shape[dim]
1028-
else:
1029-
assert isinstance(dim, tuple)
1030-
if dim == ():
1036+
def normalise_dimensions(
1037+
x: torch.Tensor, dim: Union[int, tuple, None]
1038+
) -> Tuple[tuple, tuple]:
1039+
"""Normalise "dim" and return (reduce_dimensions, keep_dimensions)."""
1040+
all_dimensions = range(len(x.shape))
1041+
if dim is None:
1042+
return (tuple(d for d in all_dimensions), ())
1043+
elif isinstance(dim, int):
1044+
return ((dim,), tuple(d for d in all_dimensions if d != dim))
1045+
else:
1046+
assert isinstance(dim, tuple)
1047+
return (dim, tuple(d for d in all_dimensions if d not in dim))
1048+
1049+
(reduce_dimensions, keep_dimensions) = normalise_dimensions(x, dim)
1050+
1051+
# Compute the number of elements that the statistics are computed over;
1052+
# the mask shape is validated just below.
1053+
1054+
# Check that the mask is shaped correctly.
1055+
if mask is not None:
1056+
assert len(mask.shape) == len(x.shape)
1057+
for d in reduce_dimensions:
1058+
assert mask.size(d) == x.size(d)
1059+
for d in keep_dimensions:
1060+
assert mask.size(d) == 1
1061+
1062+
if reduce_dimensions == ():
1063+
if mask is None:
10311064
return 1, x, torch.zeros_like(x)
1032-
number = 1
1033-
for d in dim:
1034-
number *= x.shape[d]
1065+
else:
1066+
# mask.numel() == 1 here, since every kept dimension has size 1
1067+
return int(torch.sum(mask)), mask * x, torch.zeros_like(x)
1068+
1069+
if mask is None:
1070+
number = math.prod(x.size(d) for d in reduce_dimensions)
1071+
else:
1072+
number = int(torch.sum(mask))
1073+
1074+
masked_data = x if mask is None else mask * x
10351075

10361076
# First keep the dimensions so that broadcasting works.
1037-
mean_with_dims = torch.mean(x, dim=dim, keepdim=True)
1077+
mean_with_dims = torch.mean(masked_data, dim=dim, keepdim=True)
10381078
mean = (
10391079
torch.squeeze(mean_with_dims)
10401080
if dim is None
10411081
else torch.squeeze(mean_with_dims, dim=dim)
10421082
)
1043-
variance = torch.mean(torch.square(x - mean_with_dims), dim=dim)
1083+
variance = torch.mean(torch.square(masked_data - mean_with_dims), dim=dim)
10441084

10451085
return (number, mean, variance)
10461086

@@ -1141,17 +1181,23 @@ def combine_gaussian_statistics_distributed(
11411181
return (global_count, global_mean, global_variance)
11421182

11431183

1144-
def mean_std_update(x, mask, dim, run_count, run_mean, run_std=None):
1184+
def mean_std_update(
1185+
x: torch.Tensor,
1186+
mask: Optional[torch.Tensor],
1187+
dim: Union[int, tuple, None],
1188+
run_count,
1189+
run_mean,
1190+
run_std=None,
1191+
):
11451192
"""
11461193
Update the running count, running mean, and running standard deviation
11471194
by integrating new data x from multiple processes.
11481195
"""
1149-
assert torch.all(mask), "Not implemented yet"
11501196

11511197
# TODO implement run_std is None
11521198
current_statistics = (run_count, run_mean, torch.square(run_std))
11531199
new_statistics = combine_gaussian_statistics_distributed(
1154-
gaussian_statistics(x, dim=dim)
1200+
gaussian_statistics(x, dim=dim, mask=mask)
11551201
)
11561202
(count, mean, variance) = combine_gaussian_statistics(
11571203
current_statistics, new_statistics

tests/unittests/test_input_norm.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33

44
import functools
5-
from typing import List, Tuple, Union
5+
from typing import List, Optional, Tuple, Union
66

77
import numpy as np
88
import pytest
@@ -15,31 +15,59 @@
1515
)
1616

1717

18+
def normalise_dimensions(
19+
dimensions: Union[int, tuple, None], num_dimensions: int
20+
):
21+
"""Ensure dimensions object is a tuple."""
22+
if isinstance(dimensions, int):
23+
return (dimensions,)
24+
elif dimensions is None:
25+
# All dimensions
26+
return tuple(range(num_dimensions))
27+
assert isinstance(dimensions, tuple)
28+
return dimensions
29+
30+
31+
def random_mask_numpy(
32+
generator: np.random.Generator,
33+
data_shape: tuple,
34+
dimensions: Union[int, tuple, None],
35+
):
36+
dimensions_set = set(normalise_dimensions(dimensions, len(data_shape)))
37+
38+
mask_shape = tuple(
39+
(data_shape[d] if d in dimensions_set else 1)
40+
for d in range(len(data_shape))
41+
)
42+
43+
return generator.integers(0, 2, size=mask_shape, dtype=bool)
44+
45+
1846
def reference_gaussian_statistics(
19-
x: np.ndarray, dimensions: Union[int, tuple, None]
47+
x: np.ndarray,
48+
dimensions: Union[int, tuple, None],
49+
mask: Optional[np.ndarray],
2050
) -> Tuple[int, np.ndarray, np.ndarray]:
2151
"""
2252
Compute reference count, mean, variance with Numpy, in the simplest way
2353
possible.
2454
"""
2555
# Ensure dimensions object is a tuple.
26-
if isinstance(dimensions, int):
27-
dimensions = (dimensions,)
28-
elif dimensions is None:
29-
# All dimensions
30-
dimensions = tuple(range(len(x.shape)))
31-
assert isinstance(dimensions, tuple)
56+
dimensions = normalise_dimensions(dimensions, len(x.shape))
3257

3358
# Start by pretending that dimensions=() and then roll them up one by one.
34-
count = 1
35-
mean = x
36-
variance_statistics = np.square(x)
59+
all_count = 1
60+
masked_data = x if mask is None else mask * x
61+
mean = masked_data
62+
variance_statistics = np.square(masked_data)
3763

3864
for dimension in sorted(dimensions, reverse=True):
39-
count *= x.shape[dimension]
65+
all_count *= x.shape[dimension]
4066
mean = np.mean(mean, axis=dimension)
4167
variance_statistics = np.mean(variance_statistics, axis=dimension)
4268

69+
count = all_count if mask is None else np.sum(mask)
70+
4371
variance = variance_statistics - np.square(mean)
4472

4573
return count, mean, variance
@@ -67,7 +95,8 @@ def reference_gaussian_statistics(
6795
(0, 1, 3),
6896
],
6997
)
70-
def test_gaussian_statistics(size, dimensions):
98+
@pytest.mark.parametrize("use_mask", [False, True])
99+
def test_gaussian_statistics(size, dimensions, use_mask: bool):
71100
if isinstance(dimensions, tuple):
72101
if any(dimension >= len(size) for dimension in dimensions):
73102
return
@@ -78,11 +107,20 @@ def test_gaussian_statistics(size, dimensions):
78107

79108
x = generator.uniform(low=-5, high=+5, size=size)
80109

110+
if use_mask:
111+
mask = random_mask_numpy(generator, size, dimensions)
112+
else:
113+
mask = None
114+
81115
reference_count, reference_mean, reference_variance = (
82-
reference_gaussian_statistics(x, dimensions=dimensions)
116+
reference_gaussian_statistics(x, dimensions=dimensions, mask=mask)
83117
)
84118

85-
count, mean, variance = gaussian_statistics(torch.tensor(x), dim=dimensions)
119+
count, mean, variance = gaussian_statistics(
120+
torch.tensor(x),
121+
dim=dimensions,
122+
mask=None if mask is None else torch.tensor(mask),
123+
)
86124

87125
assert count == reference_count
88126
assert mean.shape == reference_mean.shape

0 commit comments

Comments
 (0)