-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathmetric_stats.py
More file actions
1425 lines (1234 loc) · 48.2 KB
/
metric_stats.py
File metadata and controls
1425 lines (1234 loc) · 48.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""The ``metric_stats`` module provides an abstract class for storing
statistics produced over the course of an experiment and summarizing them.
Authors:
* Peter Plantinga 2020
* Mirco Ravanelli 2020
* Gaëlle Laperrière 2021
* Sahar Ghannay 2021
"""
from typing import Callable, Optional
import torch
from joblib import Parallel, delayed
from speechbrain.dataio.dataio import (
extract_concepts_values,
merge_char,
split_word,
)
from speechbrain.dataio.wer import print_alignments, print_wer_summary
from speechbrain.utils.data_utils import undo_padding
from speechbrain.utils.edit_distance import (
EDIT_SYMBOLS,
_str_equals,
wer_details_for_batch,
wer_summary,
)
class MetricStats:
    """A default class for storing and summarizing arbitrary metrics.

    More complex metrics can be created by sub-classing this class.

    Arguments
    ---------
    metric : function
        The function to use to compute the relevant metric. Should take
        at least two arguments (predictions and targets) and can
        optionally take the relative lengths of either or both arguments.
        Not usually used in sub-classes.
    n_jobs : int
        The number of jobs to use for computing the metric. If this is
        more than one, every sample is processed individually, otherwise
        the whole batch is passed at once.
    batch_eval : bool
        When True it feeds the evaluation metric with the batched input.
        When False and n_jobs=1, it performs metric evaluation one-by-one
        in a sequential way. When False and n_jobs>1, the evaluation
        runs in parallel over the different inputs using joblib.

    Example
    -------
    >>> from speechbrain.nnet.losses import l1_loss
    >>> loss_stats = MetricStats(metric=l1_loss)
    >>> loss_stats.append(
    ...     ids=["utterance1", "utterance2"],
    ...     predictions=torch.tensor([[0.1, 0.2], [0.2, 0.3]]),
    ...     targets=torch.tensor([[0.1, 0.2], [0.1, 0.2]]),
    ...     reduction="batch",
    ... )
    >>> stats = loss_stats.summarize()
    >>> stats["average"]
    0.050...
    >>> stats["max_score"]
    0.100...
    >>> stats["max_id"]
    'utterance2'
    """

    def __init__(self, metric, n_jobs=1, batch_eval=True):
        self.metric = metric
        self.n_jobs = n_jobs
        self.batch_eval = batch_eval
        self.clear()

    def clear(self):
        """Creates empty container for storage, removing existing stats."""
        self.scores = []
        self.ids = []
        self.summary = {}

    def append(self, ids, *args, **kwargs):
        """Store a particular set of metric scores.

        Arguments
        ---------
        ids : list
            List of ids corresponding to utterances.
        *args : tuple
            Arguments to pass to the metric function.
        **kwargs : dict
            Arguments to pass to the metric function.
        """
        self.ids.extend(ids)

        if self.batch_eval:
            # Whole batch goes through the metric in one call; detach so
            # stored scores do not keep the autograd graph alive.
            scores = self.metric(*args, **kwargs).detach()
        elif "predict" not in kwargs or "target" not in kwargs:
            # Per-sample evaluation needs named predict/target arguments.
            raise ValueError(
                "Must pass 'predict' and 'target' as kwargs if batch_eval=False"
            )
        elif self.n_jobs == 1:
            # Sequential evaluation: loop over the inputs one at a time.
            scores = sequence_evaluation(metric=self.metric, **kwargs)
        else:
            # Parallel evaluation over inputs via joblib workers.
            scores = multiprocess_evaluation(
                metric=self.metric, n_jobs=self.n_jobs, **kwargs
            )

        self.scores.extend(scores)

    def summarize(self, field=None):
        """Summarize the metric scores, returning relevant stats.

        Arguments
        ---------
        field : str
            If provided, only returns selected statistic. If not,
            returns all computed statistics.

        Returns
        -------
        float or dict
            Returns a float if ``field`` is provided, otherwise
            returns a dictionary containing all computed stats.
        """
        scores_tensor = torch.tensor(self.scores)
        lowest = torch.argmin(scores_tensor)
        highest = torch.argmax(scores_tensor)

        self.summary = {
            "average": float(sum(self.scores) / len(self.scores)),
            "min_score": float(self.scores[lowest]),
            "min_id": self.ids[lowest],
            "max_score": float(self.scores[highest]),
            "max_id": self.ids[highest],
        }

        return self.summary if field is None else self.summary[field]

    def write_stats(self, filestream, verbose=False):
        """Write all relevant statistics to file.

        Arguments
        ---------
        filestream : file-like object
            A stream for the stats to be written to.
        verbose : bool
            Whether to also print the stats to stdout.
        """
        if not self.summary:
            self.summarize()

        summary = self.summary
        message = (
            f"Average score: {summary['average']}\n"
            f"Min error: {summary['min_score']} id: {summary['min_id']}\n"
            f"Max error: {summary['max_score']} id: {summary['max_id']}\n"
        )

        filestream.write(message)
        if verbose:
            print(message)
def multiprocess_evaluation(
    metric, predict, target, lengths=None, n_jobs=8, max_attempts=10
):
    """Runs metric evaluation in parallel over multiple jobs.

    Arguments
    ---------
    metric : function
        The metric to apply to each (prediction, target) pair.
    predict : torch.tensor
        Batch of predictions.
    target : torch.tensor
        Batch of targets.
    lengths : torch.tensor
        Relative lengths (0-1) of each item; if provided, used to strip
        padding before evaluation.
    n_jobs : int
        Number of parallel joblib workers.
    max_attempts : int
        Maximum number of evaluation attempts before the last error is
        re-raised. Previously the retry loop was unbounded, which hung
        forever on persistent (non-timeout) failures.

    Returns
    -------
    list
        One score per (prediction, target) pair.
    """
    if lengths is not None:
        # Relative lengths -> absolute lengths, then strip padding and
        # move data to CPU for the worker processes.
        lengths = (lengths * predict.size(1)).round().int().cpu()
        predict = [p[:length].cpu() for p, length in zip(predict, lengths)]
        target = [t[:length].cpu() for t, length in zip(target, lengths)]

    for attempt in range(max_attempts):
        try:
            return Parallel(n_jobs=n_jobs, timeout=30)(
                delayed(metric)(p, t) for p, t in zip(predict, target)
            )
        except Exception as e:
            # Transient failures (e.g. worker timeouts) are retried;
            # after max_attempts the error is propagated instead of
            # looping forever.
            print(e)
            print("Evaluation timeout...... (will try again)")
            if attempt == max_attempts - 1:
                raise
def sequence_evaluation(metric, predict, target, lengths=None):
    """Runs metric evaluation sequentially over the inputs.

    Arguments
    ---------
    metric : function
        The metric to apply to each (prediction, target) pair.
    predict : torch.tensor
        Batch of predictions.
    target : torch.tensor
        Batch of targets.
    lengths : torch.tensor
        Relative lengths (0-1) of each item; if provided, used to strip
        padding before evaluation.

    Returns
    -------
    list
        One score per (prediction, target) pair.
    """
    if lengths is not None:
        # Relative lengths -> absolute lengths, then strip padding.
        lengths = (lengths * predict.size(1)).round().int().cpu()
        predict = [p[:length].cpu() for p, length in zip(predict, lengths)]
        target = [t[:length].cpu() for t, length in zip(target, lengths)]

    return [metric(p, t) for p, t in zip(predict, target)]
class ErrorRateStats(MetricStats):
    """A class for tracking error rates (e.g., WER, PER).

    Arguments
    ---------
    merge_tokens : bool
        Whether to merge the successive tokens (used for e.g.,
        creating words out of character tokens).
        See ``speechbrain.dataio.dataio.merge_char``.
    split_tokens : bool
        Whether to split tokens (used for e.g. creating
        characters out of word tokens).
        See ``speechbrain.dataio.dataio.split_word``.
    space_token : str
        The character to use for boundaries. Used with ``merge_tokens``
        this represents character to split on after merge.
        Used with ``split_tokens`` the sequence is joined with
        this token in between, and then the whole sequence is split.
    keep_values : bool
        Whether to keep the values of the concepts or not.
    extract_concepts_values : bool
        Process the predict and target to keep only concepts and values.
    tag_in : str
        Start of the concept ('<' for example).
    tag_out : str
        End of the concept ('>' for example).
    equality_comparator : Callable[[str, str], bool]
        The function used to check whether two words are equal.

    Example
    -------
    >>> cer_stats = ErrorRateStats()
    >>> i2l = {0: "a", 1: "b"}
    >>> cer_stats.append(
    ...     ids=["utterance1"],
    ...     predict=torch.tensor([[0, 1, 1]]),
    ...     target=torch.tensor([[0, 1, 0]]),
    ...     target_len=torch.ones(1),
    ...     ind2lab=lambda batch: [[i2l[int(x)] for x in seq] for seq in batch],
    ... )
    >>> stats = cer_stats.summarize()
    >>> stats["WER"]
    33.33...
    >>> stats["insertions"]
    0
    >>> stats["deletions"]
    0
    >>> stats["substitutions"]
    1
    """

    def __init__(
        self,
        merge_tokens=False,
        split_tokens=False,
        space_token="_",
        keep_values=True,
        extract_concepts_values=False,
        tag_in="",
        tag_out="",
        equality_comparator: Callable[[str, str], bool] = _str_equals,
    ):
        # clear() is inherited from MetricStats and initializes
        # ids/scores/summary; no metric function is needed here because
        # scores come from wer_details_for_batch in append().
        self.clear()
        self.merge_tokens = merge_tokens
        self.split_tokens = split_tokens
        self.space_token = space_token
        self.extract_concepts_values = extract_concepts_values
        self.keep_values = keep_values
        self.tag_in = tag_in
        self.tag_out = tag_out
        self.equality_comparator = equality_comparator

    def append(
        self,
        ids,
        predict,
        target,
        predict_len=None,
        target_len=None,
        ind2lab=None,
    ):
        """Add stats to the relevant containers.

        * See MetricStats.append()

        Arguments
        ---------
        ids : list
            List of ids corresponding to utterances.
        predict : torch.tensor
            A predicted output, for comparison with the target output
        target : torch.tensor
            The correct reference output, for comparison with the prediction.
        predict_len : torch.tensor
            The predictions relative lengths, used to undo padding if
            there is padding present in the predictions.
        target_len : torch.tensor
            The target outputs' relative lengths, used to undo padding if
            there is padding present in the target.
        ind2lab : callable
            Callable that maps from indices to labels, operating on batches,
            for writing alignments.
        """
        self.ids.extend(ids)

        # NOTE: the transformation order below matters: padding is removed
        # first, then indices are mapped to labels, and only then are the
        # label sequences merged/split/filtered.
        if predict_len is not None:
            predict = undo_padding(predict, predict_len)

        if target_len is not None:
            target = undo_padding(target, target_len)

        if ind2lab is not None:
            predict = ind2lab(predict)
            target = ind2lab(target)

        # Optionally rebuild words from characters (or vice versa),
        # using space_token as the boundary marker.
        if self.merge_tokens:
            predict = merge_char(predict, space=self.space_token)
            target = merge_char(target, space=self.space_token)

        if self.split_tokens:
            predict = split_word(predict, space=self.space_token)
            target = split_word(target, space=self.space_token)

        # Optionally keep only concept (and value) tokens delimited by
        # tag_in/tag_out, e.g. for SLU-style concept error rates.
        if self.extract_concepts_values:
            predict = extract_concepts_values(
                predict,
                self.keep_values,
                self.tag_in,
                self.tag_out,
                space=self.space_token,
            )
            target = extract_concepts_values(
                target,
                self.keep_values,
                self.tag_in,
                self.tag_out,
                space=self.space_token,
            )

        # One detailed WER record (with alignment) per utterance.
        scores = wer_details_for_batch(
            ids,
            target,
            predict,
            compute_alignments=True,
            equality_comparator=self.equality_comparator,
        )

        self.scores.extend(scores)

    def summarize(self, field=None):
        """Summarize the error_rate and return relevant statistics.

        * See MetricStats.summarize()
        """
        self.summary = wer_summary(self.scores)

        # Add additional, more generic key
        self.summary["error_rate"] = self.summary["WER"]

        if field is not None:
            return self.summary[field]
        else:
            return self.summary

    def write_stats(self, filestream):
        """Write all relevant info (e.g., error rate alignments) to file.

        * See MetricStats.write_stats()
        """
        if not self.summary:
            self.summarize()

        print_wer_summary(self.summary, filestream)
        print_alignments(self.scores, filestream)
class WeightedErrorRateStats(MetricStats):
    """Metric that reweighs the WER from :class:`~ErrorRateStats` with any
    chosen method. This does not edit the sequence of found edits
    (insertion/deletion/substitution) but multiplies their impact on the metric
    by a value between 0 and 1 as returned by the cost function.

    Arguments
    ---------
    base_stats : ErrorRateStats
        The base WER calculator to use.
    cost_function : Callable[[str, Optional[str], Optional[str]], float]
        Cost function of signature `fn(edit_symbol, a, b) -> float`, where the
        returned value, between 0 and 1, is the weight that should be assigned
        to a particular edit in the weighted WER calculation.
        In the case of insertions and deletions, either of `a` or `b` may be
        `None`. In the case of substitutions, `a` and `b` will never be `None`.
    weight_name : str
        Prefix to be prepended to each metric name (e.g. `xxx_wer`)
    """

    def __init__(
        self,
        base_stats: ErrorRateStats,
        cost_function: Callable[[str, Optional[str], Optional[str]], float],
        weight_name: str = "weighted",
    ):
        self.clear()
        self.base_stats = base_stats
        self.cost_function = cost_function
        self.weight_name = weight_name

    def append(self, *args, **kwargs):
        """Append function, which should **NOT** be used for the weighted error
        rate stats. Please append to the specified `base_stats` instead.
        `WeightedErrorRateStats` reuses the scores from the base
        :class:`~ErrorRateStats` class.

        Arguments
        ---------
        *args : tuple
            Ignored.
        **kwargs : dict
            Ignored.
        """
        raise ValueError(
            "Cannot append to a WeightedErrorRateStats. "
            "You should only append to the base ErrorRateStats."
        )

    def summarize(self, field=None):
        """Returns a dict containing some detailed WER statistics after
        weighting every edit with a weight determined by `cost_function`
        (returning `0.0` for no error, `1.0` for the default error behavior, and
        anything in between).

        Does not require :meth:`~ErrorRateStats.summarize` to have been called.

        Full set of fields, **each of which are prepended with
        `<weight_name_specified_at_init>_`**:

        - `wer`: Weighted WER (ratio `*100`)
        - `insertions`: Weighted insertions
        - `substitutions`: Weighted substitutions
        - `deletions`: Weighted deletions
        - `num_edits`: Sum of weighted insertions/substitutions/deletions

        Additionally, a `scores` list is populated by this function for each
        pair of sentences. Each entry of that list is a dict, with the fields:

        - `key`: the ID of the utterance.
        - `WER`, `insertions`, `substitutions`, `deletions`, `num_edits` with
          the same semantics as described above, but at sentence level rather
          than global.

        Arguments
        ---------
        field : str, optional
            The field to return, if you are only interested in one of them.
            If specified, a single `float` is returned, otherwise, a dict is.

        Returns
        -------
        dict from str to float, if `field is None`
            A dictionary of the fields documented above.
        float, if `field is not None`
            The single field selected by `field`.
        """
        # BUGFIX: recompute from scratch. Previously, repeated summarize()
        # calls kept appending per-utterance entries to self.scores, so the
        # list accumulated duplicates across calls.
        self.scores = []

        weighted_insertions = 0.0
        weighted_substitutions = 0.0
        weighted_deletions = 0.0
        total = 0.0

        for i, utterance in enumerate(self.base_stats.scores):
            utt_weighted_insertions = 0.0
            utt_weighted_substitutions = 0.0
            utt_weighted_deletions = 0.0
            utt_total = 0.0

            # Walk the base alignment; a_idx/b_idx index into the reference
            # and hypothesis token lists respectively (None for ins/del).
            for edit_symbol, a_idx, b_idx in utterance["alignment"]:
                a = (
                    utterance["ref_tokens"][a_idx]
                    if a_idx is not None
                    else None
                )
                b = (
                    utterance["hyp_tokens"][b_idx]
                    if b_idx is not None
                    else None
                )

                if edit_symbol != EDIT_SYMBOLS["eq"]:
                    # Each actual edit contributes its weighted cost.
                    pair_score = self.cost_function(edit_symbol, a, b)

                    if edit_symbol == EDIT_SYMBOLS["ins"]:
                        utt_weighted_insertions += pair_score
                    elif edit_symbol == EDIT_SYMBOLS["del"]:
                        utt_weighted_deletions += pair_score
                    elif edit_symbol == EDIT_SYMBOLS["sub"]:
                        utt_weighted_substitutions += pair_score

                # Every alignment entry (incl. matches) counts toward the
                # utterance length used as the WER denominator.
                utt_total += 1.0

            utt_weighted_edits = (
                utt_weighted_insertions
                + utt_weighted_substitutions
                + utt_weighted_deletions
            )
            utt_weighted_wer_ratio = utt_weighted_edits / utt_total

            self.scores.append(
                {
                    "key": self.base_stats.ids[i],
                    "WER": utt_weighted_wer_ratio * 100.0,
                    "insertions": utt_weighted_insertions,
                    "substitutions": utt_weighted_substitutions,
                    "deletions": utt_weighted_deletions,
                    "num_edits": utt_weighted_edits,
                }
            )

            weighted_insertions += utt_weighted_insertions
            weighted_substitutions += utt_weighted_substitutions
            weighted_deletions += utt_weighted_deletions
            total += utt_total

        weighted_edits = (
            weighted_insertions + weighted_substitutions + weighted_deletions
        )
        weighted_wer_ratio = weighted_edits / total

        self.summary = {
            f"{self.weight_name}_wer": weighted_wer_ratio * 100.0,
            f"{self.weight_name}_insertions": weighted_insertions,
            f"{self.weight_name}_substitutions": weighted_substitutions,
            f"{self.weight_name}_deletions": weighted_deletions,
            f"{self.weight_name}_num_edits": weighted_edits,
        }

        if field is not None:
            return self.summary[field]
        else:
            return self.summary

    def write_stats(self, filestream):
        """Write all relevant info to file; here, only the weighted info as
        returned by `summarize`.

        See :meth:`~ErrorRateStats.write_stats`.
        """
        if not self.summary:
            self.summarize()

        print(f"Weighted WER metrics ({self.weight_name}):", file=filestream)

        for k, v in self.summary.items():
            print(f"{k}: {v}", file=filestream)
class EmbeddingErrorRateSimilarity:
    """Implements the similarity function from the EmbER metric as defined by
    https://www.isca-archive.org/interspeech_2022/roux22_interspeech.pdf

    This metric involves a dictionary to map a token to a single word embedding.
    Substitutions in the WER get weighted down when the embeddings are similar
    enough. The goal is to reduce the impact of substitution errors with small
    semantic impact. Only substitution errors get weighted.

    This is done by computing the cosine similarity between the two embeddings,
    then weighing the substitution with `high_similarity_weight` if
    `similarity >= threshold` or with `low_similarity_weight` otherwise (e.g.
    a substitution with high similarity could be weighted down to matter 10% as
    much as a substitution with low similarity).

    .. note ::
        The cited paper recommended `(1.0, 0.1, 0.4)` as defaults for fastText
        French embeddings, chosen empirically. When using different embeddings,
        you might want to test other values; thus we don't provide defaults.

    Arguments
    ---------
    embedding_function : Callable[[str], Optional[torch.Tensor]]
        Function that returns an embedding (as a :class:`torch.Tensor`) from a
        word. If no corresponding embedding could be found for the word, should
        return `None`. In that case, `low_similarity_weight` will be chosen.
    low_similarity_weight : float
        Weight applied to the substitution if `cosine_similarity < threshold`.
    high_similarity_weight : float
        Weight applied to the substitution if `cosine_similarity >= threshold`.
    threshold : float
        Cosine similarity threshold used to select by how much a substitution
        error should be weighed for this word.
    """

    def __init__(
        self,
        embedding_function: Callable[[str], Optional[torch.Tensor]],
        low_similarity_weight: float,
        high_similarity_weight: float,
        threshold: float,
    ):
        self.embedding_function = embedding_function
        self.low_similarity_weight = low_similarity_weight
        self.high_similarity_weight = high_similarity_weight
        self.threshold = threshold

    def _substitution_weight(self, a: Optional[str], b: Optional[str]) -> float:
        """Weight for a single substitution, based on embedding similarity.

        Falls back to `low_similarity_weight` whenever either word is
        missing/empty or has no embedding.
        """
        # Missing or empty words cannot be embedded -> full-cost bucket.
        if not a or not b:
            return self.low_similarity_weight

        a_emb = self.embedding_function(a)
        if a_emb is None:
            return self.low_similarity_weight

        b_emb = self.embedding_function(b)
        if b_emb is None:
            return self.low_similarity_weight

        similarity = torch.nn.functional.cosine_similarity(
            a_emb, b_emb, dim=0
        ).item()

        if similarity >= self.threshold:
            return self.high_similarity_weight
        return self.low_similarity_weight

    def __call__(
        self, edit_symbol: str, a: Optional[str], b: Optional[str]
    ) -> float:
        """Returns the weight that should be associated with a specific edit
        in the WER calculation.

        Compatible candidate for the cost function of
        :class:`~WeightedErrorRateStats` so an instance of this class can be
        passed as a `cost_function`.

        Arguments
        ---------
        edit_symbol: str
            Edit symbol as assigned by the WER functions, see `EDIT_SYMBOLS`.
        a: str, optional
            First word to compare (if present)
        b: str, optional
            Second word to compare (if present)

        Returns
        -------
        float
            Weight to assign to the edit.
            For actual edits, either `low_similarity_weight` or
            `high_similarity_weight` depending on the embedding distance and
            threshold.
        """
        # Insertions and deletions always keep their full cost.
        if edit_symbol in (EDIT_SYMBOLS["ins"], EDIT_SYMBOLS["del"]):
            return 1.0

        if edit_symbol == EDIT_SYMBOLS["sub"]:
            return self._substitution_weight(a, b)

        # Equality (or any other symbol): no error, no cost.
        return 0.0
class BinaryMetricStats(MetricStats):
    """Tracks binary metrics, such as precision, recall, F1, EER, etc.

    Arguments
    ---------
    positive_label : int
        The label value considered "positive" (default 1).
    """

    def __init__(self, positive_label=1):
        self.clear()
        self.positive_label = positive_label

    def clear(self):
        """Clears the stored metrics."""
        self.ids = []
        self.scores = []
        self.labels = []
        self.summary = {}

    def append(self, ids, scores, labels):
        """Appends scores and labels to internal lists.

        Does not compute metrics until time of summary, since
        automatic thresholds (e.g., EER) need full set of scores.

        Arguments
        ---------
        ids : list
            The string ids for the samples.
        scores : list
            The scores corresponding to the ids.
        labels : list
            The labels corresponding to the ids.
        """
        self.ids.extend(ids)
        self.scores.extend(scores.detach())
        self.labels.extend(labels.detach())

    def summarize(
        self, field=None, threshold=None, max_samples=None, beta=1, eps=1e-8
    ):
        """Compute statistics using a full set of scores.

        Full set of fields:
        - TP - True Positive
        - TN - True Negative
        - FP - False Positive
        - FN - False Negative
        - FAR - False Acceptance Rate
        - FRR - False Rejection Rate
        - DER - Detection Error Rate (EER if no threshold passed)
        - threshold - threshold (EER threshold if no threshold passed)
        - precision - Precision (positive predictive value)
        - recall - Recall (sensitivity)
        - F-score - Balance of precision and recall (equal if beta=1)
        - MCC - Matthews Correlation Coefficient

        Arguments
        ---------
        field : str
            A key for selecting a single statistic. If not provided,
            a dict with all statistics is returned.
        threshold : float
            If no threshold is provided, equal error rate is used.
        max_samples: float
            How many samples to keep for positive/negative scores.
            If no max_samples is provided, all scores are kept.
            Only effective when threshold is None.
        beta : float
            How much to weight precision vs recall in F-score. Default
            of 1. is equal weight, while higher values weight recall
            higher, and lower values weight precision higher.
        eps : float
            A small value to avoid dividing by zero.

        Returns
        -------
        summary
            if field is specified, only returns the score for that field.
            if field is None, returns the full set of fields.
        """
        # First summarize() call after append(): stack the per-sample
        # tensors into single score/label tensors.
        if isinstance(self.scores, list):
            self.scores = torch.stack(self.scores)
            self.labels = torch.stack(self.labels)

        if threshold is None:
            positive_scores = self.scores[
                (self.labels == self.positive_label).nonzero(as_tuple=True)
            ]
            negative_scores = self.scores[
                (self.labels != self.positive_label).nonzero(as_tuple=True)
            ]
            # Optionally subsample each class (uniform stride over the
            # sorted scores) to bound the cost of the EER sweep.
            if max_samples is not None:
                if len(positive_scores) > max_samples:
                    positive_scores, _ = torch.sort(positive_scores)
                    positive_scores = positive_scores[
                        [
                            i
                            for i in range(
                                0,
                                len(positive_scores),
                                int(len(positive_scores) / max_samples),
                            )
                        ]
                    ]
                if len(negative_scores) > max_samples:
                    negative_scores, _ = torch.sort(negative_scores)
                    negative_scores = negative_scores[
                        [
                            i
                            for i in range(
                                0,
                                len(negative_scores),
                                int(len(negative_scores) / max_samples),
                            )
                        ]
                    ]
            # The EER value itself is discarded; only its threshold is used.
            eer, threshold = EER(positive_scores, negative_scores)

        pred = (self.scores > threshold).float()
        true = self.labels

        TP = self.summary["TP"] = float(pred.mul(true).sum())
        TN = self.summary["TN"] = float((1.0 - pred).mul(1.0 - true).sum())
        FP = self.summary["FP"] = float(pred.mul(1.0 - true).sum())
        FN = self.summary["FN"] = float((1.0 - pred).mul(true).sum())

        self.summary["FAR"] = FP / (FP + TN + eps)
        self.summary["FRR"] = FN / (TP + FN + eps)
        self.summary["DER"] = (FP + FN) / (TP + TN + eps)
        self.summary["threshold"] = threshold
        self.summary["precision"] = TP / (TP + FP + eps)
        self.summary["recall"] = TP / (TP + FN + eps)
        # BUGFIX: eps added to the denominator, consistent with the other
        # ratios above; without it, a batch with no positives (TP=FN=FP=0)
        # raised ZeroDivisionError.
        self.summary["F-score"] = (
            (1.0 + beta**2.0)
            * TP
            / ((1.0 + beta**2.0) * TP + beta**2.0 * FN + FP + eps)
        )
        self.summary["MCC"] = (TP * TN - FP * FN) / (
            (TP + FP) * (TP + FN) * (TN + FP) * (TN + FN) + eps
        ) ** 0.5

        if field is not None:
            return self.summary[field]
        else:
            return self.summary
def EER(positive_scores, negative_scores):
    """Computes the EER (and its threshold).

    Arguments
    ---------
    positive_scores : torch.tensor
        The scores from entries of the same class.
    negative_scores : torch.tensor
        The scores from entries of different classes.

    Returns
    -------
    EER : float
        The EER score.
    threshold : float
        The corresponding threshold for the EER score.

    Example
    -------
    >>> positive_scores = torch.tensor([0.6, 0.7, 0.8, 0.5])
    >>> negative_scores = torch.tensor([0.4, 0.3, 0.2, 0.1])
    >>> val_eer, threshold = EER(positive_scores, negative_scores)
    >>> val_eer
    0.0
    """
    # Candidate thresholds: every unique observed score...
    merged = torch.cat([positive_scores, negative_scores])
    thresholds = torch.unique(torch.sort(merged)[0])

    # ...plus each midpoint between consecutive unique scores.
    midpoints = (thresholds[0:-1] + thresholds[1:]) / 2
    thresholds, _ = torch.sort(torch.cat([thresholds, midpoints]))

    n_positive = positive_scores.shape[0]
    n_negative = negative_scores.shape[0]

    # Track the threshold where FAR and FRR are closest to each other.
    best_index = 0
    best_FRR = 0
    best_FAR = 0
    for index, candidate in enumerate(thresholds):
        FRR = (positive_scores <= candidate).sum(0).float() / n_positive
        FAR = (negative_scores > candidate).sum(0).float() / n_negative

        if index == 0 or (FAR - FRR).abs().item() < abs(best_FAR - best_FRR):
            best_index = index
            best_FRR = FRR.item()
            best_FAR = FAR.item()

    # It is possible that eer != fpr != fnr. We return (FAR + FRR) / 2 as EER.
    return float((best_FAR + best_FRR) / 2), float(thresholds[best_index])
def minDCF(
    positive_scores, negative_scores, c_miss=1.0, c_fa=1.0, p_target=0.01
):
    """Computes the minDCF metric normally used to evaluate speaker verification
    systems. The min_DCF is the minimum of the following C_det function computed
    within the defined threshold range:

    C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 -p_target)

    where p_miss is the missing probability and p_fa is the probability of having
    a false alarm.

    Arguments
    ---------
    positive_scores : torch.tensor
        The scores from entries of the same class.
    negative_scores : torch.tensor
        The scores from entries of different classes.
    c_miss : float
        Cost assigned to a missing error (default 1.0).
    c_fa : float
        Cost assigned to a false alarm (default 1.0).
    p_target: float
        Prior probability of having a target (default 0.01).

    Returns
    -------
    minDCF : float
        The minDCF score.
    threshold : float
        The corresponding threshold for the minDCF score.

    Example
    -------
    >>> positive_scores = torch.tensor([0.6, 0.7, 0.8, 0.5])
    >>> negative_scores = torch.tensor([0.4, 0.3, 0.2, 0.1])
    >>> val_minDCF, threshold = minDCF(positive_scores, negative_scores)
    >>> val_minDCF
    0.0
    """
    # Candidate thresholds: every unique observed score plus each midpoint
    # between consecutive unique scores.
    merged, _ = torch.sort(torch.cat([positive_scores, negative_scores]))
    merged = torch.unique(merged)
    midpoints = (merged[0:-1] + merged[1:]) / 2
    thresholds, _ = torch.sort(torch.cat([merged, midpoints]))

    # p_miss[j]: fraction of positives rejected at thresholds[j].
    # Broadcasting (P, 1) against (T,) yields a (P, T) comparison matrix.
    p_miss = (
        (positive_scores.unsqueeze(1) <= thresholds).float().mean(dim=0)
    )

    # p_fa[j]: fraction of negatives accepted at thresholds[j].
    p_fa = (
        (negative_scores.unsqueeze(1) > thresholds).float().mean(dim=0)
    )

    # Detection cost at every threshold; keep the minimum.
    c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
    c_min, min_index = torch.min(c_det, dim=0)

    return float(c_min), float(thresholds[min_index])
class ClassificationStats(MetricStats):
"""Computes statistics pertaining to multi-label classification tasks, as
well as tasks that can be loosely interpreted as such for the purpose of evaluations.
Example
-------
>>> import sys
>>> from speechbrain.utils.metric_stats import ClassificationStats
>>> cs = ClassificationStats()
>>> cs.append(
... ids=["ITEM1", "ITEM2", "ITEM3", "ITEM4"],
... predictions=[
... "M EY K AH",
... "T EY K",
... "B AE D",
... "M EY K",
... ],
... targets=[
... "M EY K",
... "T EY K",
... "B AE D",
... "M EY K",
... ],
... categories=["make", "take", "bad", "make"],
... )
>>> cs.write_stats(sys.stdout)
Overall Accuracy: 75%
<BLANKLINE>
Class-Wise Accuracy
-------------------
bad -> B AE D : 1 / 1 (100.00%)
make -> M EY K: 1 / 2 (50.00%)
take -> T EY K: 1 / 1 (100.00%)
<BLANKLINE>
Confusion
---------
Target: bad -> B AE D
-> B AE D : 1 / 1 (100.00%)
Target: make -> M EY K
-> M EY K : 1 / 2 (50.00%)
-> M EY K AH: 1 / 2 (50.00%)
Target: take -> T EY K
-> T EY K : 1 / 1 (100.00%)
>>> summary = cs.summarize()
>>> summary["accuracy"]
0.75
>>> summary["classwise_stats"][("bad", "B AE D")]
{'total': 1.0, 'correct': 1.0, 'accuracy': 1.0}
>>> summary["classwise_stats"][("make", "M EY K")]
{'total': 2.0, 'correct': 1.0, 'accuracy': 0.5}
>>> summary["keys"]
[('bad', 'B AE D'), ('make', 'M EY K'), ('take', 'T EY K')]
>>> summary["predictions"]
['B AE D', 'M EY K', 'M EY K AH', 'T EY K']
>>> summary["classwise_total"]
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 2.0, ('take', 'T EY K'): 1.0}
>>> summary["classwise_correct"]
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 1.0, ('take', 'T EY K'): 1.0}
>>> summary["classwise_accuracy"]
{('bad', 'B AE D'): 1.0, ('make', 'M EY K'): 0.5, ('take', 'T EY K'): 1.0}
"""