@@ -533,10 +533,7 @@ def get_oracle_num_spkrs(rec_id, spkr_info):
533533
534534
535535def spectral_embedding_sb (
536- adjacency ,
537- n_components = 8 ,
538- norm_laplacian = True ,
539- drop_first = True ,
536+ adjacency , n_components = 8 , norm_laplacian = True , drop_first = True ,
540537):
541538 """Returns spectral embeddings.
542539
@@ -605,10 +602,7 @@ def spectral_embedding_sb(
605602 laplacian *= - 1
606603
607604 vals , diffusion_map = eigsh (
608- laplacian ,
609- k = n_components ,
610- sigma = 1.0 ,
611- which = "LM" ,
605+ laplacian , k = n_components , sigma = 1.0 , which = "LM" ,
612606 )
613607
614608 embedding = diffusion_map .T [n_components ::- 1 ]
@@ -624,11 +618,7 @@ def spectral_embedding_sb(
624618
625619
626620def spectral_clustering_sb (
627- affinity ,
628- n_clusters = 8 ,
629- n_components = None ,
630- random_state = None ,
631- n_init = 10 ,
621+ affinity , n_clusters = 8 , n_components = None , random_state = None , n_init = 10 ,
632622):
633623 """Performs spectral clustering.
634624
@@ -672,9 +662,7 @@ def spectral_clustering_sb(
672662 n_components = n_clusters if n_components is None else n_components
673663
674664 maps = spectral_embedding_sb (
675- affinity ,
676- n_components = n_components ,
677- drop_first = False ,
665+ affinity , n_components = n_components , drop_first = False ,
678666 )
679667
680668 _ , labels , _ = k_means (
@@ -705,16 +693,13 @@ def perform_sc(self, X, n_neighbors=10):
705693
706694 # Computation of affinity matrix
707695 connectivity = kneighbors_graph (
708- X ,
709- n_neighbors = n_neighbors ,
710- include_self = True ,
696+ X , n_neighbors = n_neighbors , include_self = True ,
711697 )
712698 self .affinity_matrix_ = 0.5 * (connectivity + connectivity .T )
713699
714700 # Perform spectral clustering on affinity matrix
715701 self .labels_ = spectral_clustering_sb (
716- self .affinity_matrix_ ,
717- n_clusters = self .n_clusters ,
702+ self .affinity_matrix_ , n_clusters = self .n_clusters ,
718703 )
719704 return self
720705
@@ -1168,9 +1153,7 @@ def do_AHC(diary_obj, out_rttm_file, rec_id, k_oracle=4, p_val=0.3):
11681153 num_of_spk = k_oracle
11691154
11701155 clustering = AgglomerativeClustering (
1171- n_clusters = num_of_spk ,
1172- affinity = "cosine" ,
1173- linkage = "ward" ,
1156+ n_clusters = num_of_spk , affinity = "cosine" , linkage = "ward" ,
11741157 ).fit (diary_obj .stat1 )
11751158 labels = clustering .labels_
11761159
0 commit comments