Skip to content

Commit a601051

Browse files
authored
Add missing mean_stat_per_model method to StatObject_SB (#3029)
1 parent 1f98c52 commit a601051

1 file changed

Lines changed: 19 additions & 8 deletions

File tree

speechbrain/processing/PLDA_LDA.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,19 @@ def sum_stat_per_model(self):
223223
session_per_model[idx] += self.get_model_stat1(model).shape[0]
224224
return sts_per_model, session_per_model
225225

226+
def mean_stat_per_model(self):
227+
"""Average the zero- and first-order statistics per model and store
228+
them in a new StatObject_SB.
229+
230+
Returns
231+
-------
232+
a StatObject_SB object with the statistics averaged per model.
233+
"""
234+
sts_per_model, session_per_model = self.sum_stat_per_model()
235+
sts_per_model.stat0 = sts_per_model.stat0 / session_per_model[:, None]
236+
sts_per_model.stat1 = sts_per_model.stat1 / session_per_model[:, None]
237+
return sts_per_model
238+
226239
def center_stat1(self, mu):
227240
"""Center first order statistics.
228241
@@ -721,10 +734,13 @@ def fast_PLDA_scoring(
721734
enroll_ctr = copy.deepcopy(enroll)
722735
test_ctr = copy.deepcopy(test)
723736

724-
# If models are not unique, compute the mean per model, display a warning
737+
# If models are not unique, require the user to average them first
725738
if not numpy.unique(enroll_ctr.modelset).shape == enroll_ctr.modelset.shape:
726-
# logging.warning("Enrollment models are not unique, average i-vectors")
727-
enroll_ctr = enroll_ctr.mean_stat_per_model()
739+
raise ValueError(
740+
"Enrollment models are not unique. Call "
741+
"enroll.mean_stat_per_model() before passing to "
742+
"fast_PLDA_scoring() to average statistics per model."
743+
)
728744

729745
# Remove missing models and test segments
730746
if check_missing:
@@ -736,11 +752,6 @@ def fast_PLDA_scoring(
736752
enroll_ctr.center_stat1(mu)
737753
test_ctr.center_stat1(mu)
738754

739-
# If models are not unique, compute the mean per model, display a warning
740-
if not numpy.unique(enroll_ctr.modelset).shape == enroll_ctr.modelset.shape:
741-
# logging.warning("Enrollment models are not unique, average i-vectors")
742-
enroll_ctr = enroll_ctr.mean_stat_per_model()
743-
744755
# Compute constant component of the PLDA distribution
745756
invSigma = linalg.inv(Sigma)
746757
I_spk = numpy.eye(F.shape[1], dtype="float")

0 commit comments

Comments
 (0)