@@ -1012,7 +1012,10 @@ def gaussian_statistics(
10121012 dim: int | tuple | None
10131013 The dimension or dimensions that the statistics should be computed over.
10141014 The other dimensions are retained in the output.
1015- If None, then scalar-valued statistics will be returned.
1015+ If None, then statistics will be computed over all dimensions and
1016+ scalar-valued statistics will be returned.
1015+ () has the same effect as None, which is nonsensical but is consistent
1018+ with torch.sum and friends.
10161019 mask: torch.Tensor | None
10171020 A boolean tensor with True for elements that should be considered, and
10181021 False for elements that should not be considered (e.g. that are after
@@ -1038,7 +1041,9 @@ def normalise_dimensions(
10381041 ) -> Tuple [tuple , tuple ]:
10391042 """Normalise "dim" and return (reduce_dimensions, keep_dimensions)."""
10401043 all_dimensions = range (len (x .shape ))
1041- if dim is None :
1044+ if dim is None or dim == ():
1045+ # dim == () is an exceptional case and replicates the strangeness
1046+ # of torch.sum(.., dim=()) and friends.
10421047 return (tuple (d for d in all_dimensions ), ())
10431048 elif isinstance (dim , int ):
10441049 return ((dim ,), tuple (d for d in all_dimensions if d != dim ))
@@ -1059,13 +1064,6 @@ def normalise_dimensions(
10591064 for d in keep_dimensions :
10601065 assert mask .size (d ) == 1
10611066
1062- if reduce_dimensions == ():
1063- if mask is None :
1064- return 1 , x , torch .zeros_like (x )
1065- else :
1066- # mask.numel == 1
1067- return int (torch .sum (mask )), mask * x , torch .zeros_like (x )
1068-
10691067 if mask is None :
10701068 number = math .prod (x .size (d ) for d in reduce_dimensions )
10711069 else :
@@ -1074,13 +1072,16 @@ def normalise_dimensions(
10741072 masked_data = x if mask is None else mask * x
10751073
10761074 # First keep the dimensions so that broadcasting works.
1077- mean_with_dims = torch .mean (masked_data , dim = dim , keepdim = True )
1078- mean = (
1079- torch .squeeze (mean_with_dims )
1080- if dim is None
1081- else torch .squeeze (mean_with_dims , dim = dim )
1075+ sum_with_dims = torch .sum (masked_data , dim = reduce_dimensions , keepdim = True )
1076+
1077+ mean_with_dims = sum_with_dims / number
1078+
1079+ mean = torch .squeeze (mean_with_dims , dim = reduce_dimensions )
1080+ central_squared_data = torch .square (x - mean_with_dims )
1081+ masked_squared_data = (
1082+ central_squared_data if mask is None else mask * central_squared_data
10821083 )
1083- variance = torch .mean ( torch . square ( masked_data - mean_with_dims ) , dim = dim )
1084+ variance = torch .sum ( masked_squared_data , dim = reduce_dimensions ) / number
10841085
10851086 return (number , mean , variance )
10861087
0 commit comments