Skip to content

Commit b2c1023

Browse files
committed
removing fastcluster dependency from AHC
1 parent 23bb211 commit b2c1023

1 file changed

Lines changed: 4 additions & 25 deletions

File tree

speechbrain/processing/diarization.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,38 +1147,17 @@ def do_AHC(diary_obj, out_rttm_file, rec_id, k_oracle=4, p_val=0.3):
11471147

11481148
# p_val --> threshold_val (for AHC)
11491149

1150-
# Mean centering and normalizing the embeddings
1151-
# mu = diary_obj.get_mean_stat1()
1152-
# print ("mu = " , mu)
1153-
# print ("\n\n\n\n")
1154-
# print("stat0: ", diary_obj.stat0)
1155-
# diary_obj.center_stat1(mu)
11561150
diary_obj.norm_stat1()
11571151

1158-
import fastcluster
1159-
from scipy.cluster.hierarchy import fcluster
1160-
1161-
linkage_matrix = fastcluster.linkage(
1162-
diary_obj.stat1, method="ward", metric="cosine"
1163-
)
1164-
11651152
# processing
11661153
if k_oracle is not None:
11671154
print("ORACLE SPKRs...")
11681155
num_of_spk = k_oracle
11691156

1170-
predicted_label = (
1171-
fcluster(linkage_matrix, num_of_spk, criterion="maxclust") - 1
1172-
)
1173-
1174-
# clustering = AgglomerativeClustering(
1175-
# n_clusters=num_of_spk,
1176-
# affinity="cosine",
1177-
# linkage="ward",
1178-
# ).fit(diary_obj.stat1)
1179-
# labels = clustering.labels_
1180-
1181-
labels = predicted_label
1157+
clustering = AgglomerativeClustering(
1158+
n_clusters=num_of_spk, affinity="cosine", linkage="ward",
1159+
).fit(diary_obj.stat1)
1160+
labels = clustering.labels_
11821161

11831162
print("labels.shape (Oracle) = ", labels.shape)
11841163
else:

0 commit comments

Comments
 (0)