@@ -223,6 +223,19 @@ def sum_stat_per_model(self):
223223 session_per_model [idx ] += self .get_model_stat1 (model ).shape [0 ]
224224 return sts_per_model , session_per_model
225225
226+ def mean_stat_per_model (self ):
227+ """Average the zero- and first-order statistics per model and store
228+ them in a new StatObject_SB.
229+
230+ Returns
231+ -------
232+ a StatObject_SB object with the statistics averaged per model.
233+ """
234+ sts_per_model , session_per_model = self .sum_stat_per_model ()
235+ sts_per_model .stat0 = sts_per_model .stat0 / session_per_model [:, None ]
236+ sts_per_model .stat1 = sts_per_model .stat1 / session_per_model [:, None ]
237+ return sts_per_model
238+
226239 def center_stat1 (self , mu ):
227240 """Center first order statistics.
228241
@@ -721,10 +734,13 @@ def fast_PLDA_scoring(
721734 enroll_ctr = copy .deepcopy (enroll )
722735 test_ctr = copy .deepcopy (test )
723736
724- # If models are not unique, compute the mean per model, display a warning
737+ # If models are not unique, require the user to average them first
725738 if not numpy .unique (enroll_ctr .modelset ).shape == enroll_ctr .modelset .shape :
726- # logging.warning("Enrollment models are not unique, average i-vectors")
727- enroll_ctr = enroll_ctr .mean_stat_per_model ()
739+ raise ValueError (
740+ "Enrollment models are not unique. Call "
741+ "enroll.mean_stat_per_model() before passing to "
742+ "fast_PLDA_scoring() to average statistics per model."
743+ )
728744
729745 # Remove missing models and test segments
730746 if check_missing :
@@ -736,11 +752,6 @@ def fast_PLDA_scoring(
736752 enroll_ctr .center_stat1 (mu )
737753 test_ctr .center_stat1 (mu )
738754
739- # If models are not unique, compute the mean per model, display a warning
740- if not numpy .unique (enroll_ctr .modelset ).shape == enroll_ctr .modelset .shape :
741- # logging.warning("Enrollment models are not unique, average i-vectors")
742- enroll_ctr = enroll_ctr .mean_stat_per_model ()
743-
744755 # Compute constant component of the PLDA distribution
745756 invSigma = linalg .inv (Sigma )
746757 I_spk = numpy .eye (F .shape [1 ], dtype = "float" )
0 commit comments