Skip to content

Commit 37ce613

Browse files
committed
Experiments on ecapa with different audio streams and aug. models
1 parent 742b49c commit 37ce613

2 files changed

Lines changed: 36 additions & 4 deletions

File tree

recipes/AMI/Diarization/experiment_beam.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def diarize_dataset(full_csv, split_type, n_lambdas, pval, n_neighbors=10):
276276
params["affinity"],
277277
n_neighbors,
278278
)
279+
# Maybe used for AHC
279280
if params["backend"] == "AHC":
280281
# call AHC
281282
threshold = pval # TODO: have to update the calling function
@@ -625,6 +626,11 @@ def check_dirs():
625626
"Tuning for p-value for SC (Multiple iterations over AMI Dev set)"
626627
)
627628
best_pval = dev_p_tuner(full_csv, "dev")
629+
630+
elif params["backend"] == "AHC":
631+
logger.info("Tuning for threshold-value for AHC")
632+
best_threshold = dev_threshold_tuner(full_csv, "dev")
633+
best_pval = best_threshold
628634
else:
629635
# This part (NN for unknown num of speakers) is WIP
630636
if params["oracle_n_spkrs"] is False:

speechbrain/processing/diarization.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,13 +1147,39 @@ def do_AHC(diary_obj, out_rttm_file, rec_id, k_oracle=4, p_val=0.3):
11471147

11481148
# p_val --> threshold_val (for AHC)
11491149

1150+
# Mean centering and normalizing the embeddings
1151+
# mu = diary_obj.get_mean_stat1()
1152+
# print ("mu = " , mu)
1153+
# print ("\n\n\n\n")
1154+
# print("stat0: ", diary_obj.stat0)
1155+
# diary_obj.center_stat1(mu)
1156+
diary_obj.norm_stat1()
1157+
1158+
import fastcluster
1159+
from scipy.cluster.hierarchy import fcluster
1160+
1161+
linkage_matrix = fastcluster.linkage(
1162+
diary_obj.stat1, method="ward", metric="cosine"
1163+
)
1164+
1165+
# processing
11501166
if k_oracle is not None:
11511167
print("ORACLE SPKRs...")
11521168
num_of_spk = k_oracle
1153-
clustering = AgglomerativeClustering(
1154-
n_clusters=num_of_spk, linkage="ward"
1155-
).fit(diary_obj.stat1)
1156-
labels = clustering.labels_
1169+
1170+
predicted_label = (
1171+
fcluster(linkage_matrix, num_of_spk, criterion="maxclust") - 1
1172+
)
1173+
1174+
# clustering = AgglomerativeClustering(
1175+
# n_clusters=num_of_spk,
1176+
# affinity="cosine",
1177+
# linkage="ward",
1178+
# ).fit(diary_obj.stat1)
1179+
# labels = clustering.labels_
1180+
1181+
labels = predicted_label
1182+
11571183
print("labels.shape (Oracle) = ", labels.shape)
11581184
else:
11591185
print("Using AHC threshold pval = ", p_val)

0 commit comments

Comments
 (0)