3636"""
3737
3838import math
39- from typing import Tuple , Union
39+ from typing import Optional , Tuple , Union
4040
4141import torch
4242from torch .distributed import ReduceOp
@@ -996,7 +996,11 @@ def forward(self, x):
996996 return cw_x
997997
998998
def gaussian_statistics(
    x: torch.Tensor,
    dim: Union[int, tuple, None] = None,
    mask: Optional[torch.Tensor] = None,
):
    """
    Compute first- and second-order moments of data, and return them as the
    count, mean, and variance of a vector over one or more dimensions.

    Parameters
    ----------
    x: torch.Tensor
        The data.
    dim: int | tuple | None
        The dimension or dimensions that the statistics should be computed
        over.  The other dimensions are retained in the output.
        If None, then scalar-valued statistics will be returned.
    mask: torch.Tensor | None
        A boolean tensor with True for elements that should be considered,
        and False for elements that should not be considered (e.g. that are
        after the end of utterances).
        This tensor should have the same number of dimensions as "x".
        The dimensions indicated by "dim" should have the same size as the
        matching dimensions in "x".
        The other dimensions should have size 1.

    Returns
    -------
    int
        The number of elements that the statistics are computed over.
    torch.Tensor
        The mean.
    torch.Tensor
        The variance.
    """

    def normalise_dimensions(
        x: torch.Tensor, dim: Union[int, tuple, None]
    ) -> Tuple[tuple, tuple]:
        """Normalise "dim" and return (reduce_dimensions, keep_dimensions)."""
        all_dimensions = range(len(x.shape))
        if dim is None:
            return (tuple(all_dimensions), ())
        elif isinstance(dim, int):
            return ((dim,), tuple(d for d in all_dimensions if d != dim))
        else:
            assert isinstance(dim, tuple)
            return (dim, tuple(d for d in all_dimensions if d not in dim))

    (reduce_dimensions, keep_dimensions) = normalise_dimensions(x, dim)

    # Check that the mask is shaped correctly: full size on the reduced
    # dimensions, broadcastable (size 1) on the kept dimensions.
    if mask is not None:
        assert len(mask.shape) == len(x.shape)
        for d in reduce_dimensions:
            assert mask.size(d) == x.size(d)
        for d in keep_dimensions:
            assert mask.size(d) == 1

    if reduce_dimensions == ():
        # Nothing to reduce over: every element is its own statistic.
        if mask is None:
            return 1, x, torch.zeros_like(x)
        else:
            # mask.numel() == 1, so the single element is kept or dropped.
            return int(torch.sum(mask)), mask * x, torch.zeros_like(x)

    if mask is None:
        number = math.prod(x.size(d) for d in reduce_dimensions)
        # First keep the dimensions so that broadcasting works.
        mean_with_dims = torch.mean(x, dim=dim, keepdim=True)
        variance = torch.mean(torch.square(x - mean_with_dims), dim=dim)
    else:
        number = int(torch.sum(mask))
        # BUG FIX: normalise by the number of unmasked elements, not by the
        # full size of the reduced dimensions.  Using torch.mean on
        # "mask * x" divides by the total element count, which biases the
        # mean towards zero and lets masked-out positions contribute
        # mean**2 terms to the variance — inconsistent with the returned
        # "number".
        # max(number, 1) keeps an all-False mask from dividing by zero; the
        # statistics then degenerate to (0, zeros, zeros).
        denominator = max(number, 1)
        mean_with_dims = (
            torch.sum(mask * x, dim=reduce_dimensions, keepdim=True)
            / denominator
        )
        # Only unmasked residuals contribute to the variance.
        variance = (
            torch.sum(
                mask * torch.square(x - mean_with_dims),
                dim=reduce_dimensions,
            )
            / denominator
        )

    mean = (
        torch.squeeze(mean_with_dims)
        if dim is None
        else torch.squeeze(mean_with_dims, dim=dim)
    )

    return (number, mean, variance)
10461086
@@ -1141,17 +1181,23 @@ def combine_gaussian_statistics_distributed(
11411181 return (global_count , global_mean , global_variance )
11421182
11431183
1144- def mean_std_update (x , mask , dim , run_count , run_mean , run_std = None ):
1184+ def mean_std_update (
1185+ x : torch .Tensor ,
1186+ mask : Optional [torch .Tensor ],
1187+ dim : Union [int , tuple , None ],
1188+ run_count ,
1189+ run_mean ,
1190+ run_std = None ,
1191+ ):
11451192 """
11461193 Update the running count, running mean, and running standard deviation
11471194 by integrating new data x from multiple processes.
11481195 """
1149- assert torch .all (mask ), "Not implemented yet"
11501196
11511197 # TODO implement run_std is None
11521198 current_statistics = (run_count , run_mean , torch .square (run_std ))
11531199 new_statistics = combine_gaussian_statistics_distributed (
1154- gaussian_statistics (x , dim = dim )
1200+ gaussian_statistics (x , dim = dim , mask = mask )
11551201 )
11561202 (count , mean , variance ) = combine_gaussian_statistics (
11571203 current_statistics , new_statistics
0 commit comments