Skip to content

Commit 9161dc7

Browse files
committed
Added a few more examples in processing/diarization.py
1 parent 8f8799a commit 9161dc7

2 files changed

Lines changed: 125 additions & 13 deletions

File tree

recipes/AMI/Diarization/hyperparams.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
1111
# Folders
1212
# data+annotation: http://groups.inf.ed.ac.uk/ami/download/
1313
data_folder: /network/datasets/ami/amicorpus
14-
manual_annot_folder: /miniscratch/ravanelm/ami_public_manual/
14+
manual_annot_folder: /network/tmp1/dawalatn/AMI_MANUAL/
1515
output_folder: results/ami/sd_xvector/
1616
save_folder: !ref <output_folder>/save
1717
device: 'cuda:0'
@@ -49,8 +49,8 @@ vad_type: 'oracle'
4949
max_subseg_dur: 3.0
5050
overlap: 1.5
5151

52-
# Cluster parameters
53-
affinity: 'nn'
52+
# Spectral Clustering parameters
53+
affinity: 'cos'
5454
max_num_spkrs: 10
5555
oracle_n_spkrs: False
5656

speechbrain/processing/diarization.py

Lines changed: 122 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def write_rttm(segs_list, out_rttm_file):
339339
#######################################
340340

341341

342-
def graph_connected_component(graph, node_id):
342+
def _graph_connected_component(graph, node_id):
343343
"""
344344
Find the largest graph connected components that contains one
345345
given node.
@@ -382,7 +382,7 @@ def graph_connected_component(graph, node_id):
382382
return connected_nodes
383383

384384

385-
def graph_is_connected(graph):
385+
def _graph_is_connected(graph):
386386
"""
387387
Return whether the graph is connected (True) or Not (False)
388388
@@ -403,10 +403,10 @@ def graph_is_connected(graph):
403403
return n_connected_components == 1
404404
else:
405405
# dense graph, find all connected components start from node 0
406-
return graph_connected_component(graph, 0).sum() == graph.shape[0]
406+
return _graph_connected_component(graph, 0).sum() == graph.shape[0]
407407

408408

409-
def set_diag(laplacian, value, norm_laplacian):
409+
def _set_diag(laplacian, value, norm_laplacian):
410410
"""
411411
Set the diagonal of the laplacian matrix and convert it to a sparse
412412
format well suited for eigenvalue decomposition.
@@ -451,7 +451,7 @@ def set_diag(laplacian, value, norm_laplacian):
451451
return laplacian
452452

453453

454-
def deterministic_vector_sign_flip(u):
454+
def _deterministic_vector_sign_flip(u):
455455
"""
456456
Modify the sign of vectors for reproducibility. Flips the sign of
457457
elements of all the vectors (rows of u) such that the absolute
@@ -474,7 +474,7 @@ def deterministic_vector_sign_flip(u):
474474
return u
475475

476476

477-
def check_random_state(seed):
477+
def _check_random_state(seed):
478478
"""
479479
Turn seed into a np.random.RandomState instance.
480480
@@ -554,13 +554,46 @@ def spectral_embedding_sb(
554554
If True, then compute normalized Laplacian.
555555
drop_first : bool
556556
Whether to drop the first eigenvector.
557+
558+
Returns
559+
-------
560+
embedding : array
561+
Spectral embeddings for each sample
562+
563+
Example
564+
-------
565+
>>> import numpy as np
566+
>>> from speechbrain.processing import diarization as diar
567+
>>> affinity = np.array([[1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5],
568+
... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
569+
... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
570+
... [0.5, 0, 0, 1, 1, 1, 0, 0, 0, 0],
571+
... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
572+
... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
573+
... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
574+
... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
575+
... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
576+
... [0.5, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
577+
>>> embs = diar.spectral_embedding_sb(affinity, 3)
578+
>>> # Notice similar embeddings
579+
>>> print(np.around(embs , decimals=3))
580+
[[ 0.075 0.244 0.285]
581+
[ 0.083 0.356 -0.203]
582+
[ 0.083 0.356 -0.203]
583+
[ 0.26 -0.149 0.154]
584+
[ 0.29 -0.218 -0.11 ]
585+
[ 0.29 -0.218 -0.11 ]
586+
[-0.198 -0.084 -0.122]
587+
[-0.198 -0.084 -0.122]
588+
[-0.198 -0.084 -0.122]
589+
[-0.167 -0.044 0.316]]
557590
"""
558591

559592
# Whether to drop the first eigenvector
560593
if drop_first:
561594
n_components = n_components + 1
562595

563-
if not graph_is_connected(adjacency):
596+
if not _graph_is_connected(adjacency):
564597
warnings.warn(
565598
"Graph is not fully connected, spectral embedding"
566599
" may not work as expected."
@@ -570,7 +603,7 @@ def spectral_embedding_sb(
570603
adjacency, normed=norm_laplacian, return_diag=True
571604
)
572605

573-
laplacian = set_diag(laplacian, 1, norm_laplacian)
606+
laplacian = _set_diag(laplacian, 1, norm_laplacian)
574607

575608
laplacian *= -1
576609

@@ -583,7 +616,7 @@ def spectral_embedding_sb(
583616
if norm_laplacian:
584617
embedding = embedding / dd
585618

586-
embedding = deterministic_vector_sign_flip(embedding)
619+
embedding = _deterministic_vector_sign_flip(embedding)
587620
if drop_first:
588621
return embedding[1:n_components].T
589622
else:
@@ -608,9 +641,31 @@ def spectral_clustering_sb(
608641
A pseudo random number generator used by kmeans.
609642
n_init : int
610643
Number of time the k-means algorithm will be run with different centroid seeds.
644+
645+
Returns
646+
-------
647+
labels : array
648+
Cluster label for each sample
649+
650+
Example
651+
-------
652+
>>> import numpy as np
653+
>>> from speechbrain.processing import diarization as diar
654+
>>> affinity = np.array([[1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5],
655+
... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
656+
... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
657+
... [0.5, 0, 0, 1, 1, 1, 0, 0, 0, 0],
658+
... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
659+
... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
660+
... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
661+
... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
662+
... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
663+
... [0.5, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
664+
>>> labs = diar.spectral_clustering_sb(affinity, 3)
665+
>>> # print (labs) # [2 2 2 1 1 1 0 0 0 0]
611666
"""
612667

613-
random_state = check_random_state(random_state)
668+
random_state = _check_random_state(random_state)
614669
n_components = n_clusters if n_components is None else n_components
615670

616671
maps = spectral_embedding_sb(
@@ -666,6 +721,63 @@ class Spec_Clust_unorm:
666721
---------
667722
Von Luxburg, U. A tutorial on spectral clustering. Stat Comput 17, 395–416 (2007).
668723
https://doi.org/10.1007/s11222-007-9033-z
724+
725+
Example
726+
-------
727+
>>> import numpy as np
>>> from speechbrain.processing import diarization as diar
728+
>>> clust = diar.Spec_Clust_unorm(min_num_spkrs=2, max_num_spkrs=10)
729+
>>> emb = [[ 2.1, 3.1, 4.1, 4.2, 3.1],
730+
... [ 2.2, 3.1, 4.2, 4.2, 3.2],
731+
... [ 2.0, 3.0, 4.0, 4.1, 3.0],
732+
... [ 8.0, 7.0, 7.0, 8.1, 9.0],
733+
... [ 8.1, 7.1, 7.2, 8.1, 9.2],
734+
... [ 8.3, 7.4, 7.0, 8.4, 9.0],
735+
... [ 0.3, 0.4, 0.4, 0.5, 0.8],
736+
... [ 0.4, 0.3, 0.6, 0.7, 0.8],
737+
... [ 0.2, 0.3, 0.2, 0.3, 0.7],
738+
... [ 0.3, 0.4, 0.4, 0.4, 0.7],]
739+
>>> # Estimating similarity matrix
740+
>>> sim_mat = clust.get_sim_mat(emb)
741+
>>> print (np.around(sim_mat[5:,5:], decimals=3))
742+
[[1. 0.957 0.961 0.904 0.966]
743+
[0.957 1. 0.977 0.982 0.997]
744+
[0.961 0.977 1. 0.928 0.972]
745+
[0.904 0.982 0.928 1. 0.976]
746+
[0.966 0.997 0.972 0.976 1. ]]
747+
>>> # Pruning
748+
>>> prunned_sim_mat = clust.p_pruning(sim_mat, 0.3)
749+
>>> print (np.around(prunned_sim_mat[5:,5:], decimals=3))
750+
[[1. 0. 0. 0. 0. ]
751+
[0. 1. 0. 0.982 0.997]
752+
[0. 0.977 1. 0. 0.972]
753+
[0. 0.982 0. 1. 0.976]
754+
[0. 0.997 0. 0.976 1. ]]
755+
>>> # Symmetrization
756+
>>> sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
757+
>>> print (np.around(sym_prund_sim_mat[5:,5:], decimals=3))
758+
[[1. 0. 0. 0. 0. ]
759+
[0. 1. 0.489 0.982 0.997]
760+
[0. 0.489 1. 0. 0.486]
761+
[0. 0.982 0. 1. 0.976]
762+
[0. 0.997 0.486 0.976 1. ]]
763+
>>> # Laplacian
764+
>>> laplacian = clust.get_laplacian(sym_prund_sim_mat)
765+
>>> print (np.around(laplacian[5:,5:], decimals=3))
766+
[[ 1.999 0. 0. 0. 0. ]
767+
[ 0. 2.468 -0.489 -0.982 -0.997]
768+
[ 0. -0.489 0.975 0. -0.486]
769+
[ 0. -0.982 0. 1.958 -0.976]
770+
[ 0. -0.997 -0.486 -0.976 2.458]]
771+
>>> # Spectral Embeddings
772+
>>> spec_emb, num_of_spk = clust.get_spec_embs(laplacian, 3)
773+
>>> print(num_of_spk)
774+
3
775+
>>> # Clustering
776+
>>> clust.cluster_embs(spec_emb, num_of_spk)
777+
>>> # print (clust.labels_) # [0 0 0 2 2 2 1 1 1 1]
778+
>>> # Complete spectral clustering
779+
>>> clust.do_spec_clust(emb, k_oracle=3, p_val=0.3)
780+
>>> # print(clust.labels_) # [0 0 0 2 2 2 1 1 1 1]
669781
"""
670782

671783
def __init__(self, min_num_spkrs=2, max_num_spkrs=10):

0 commit comments

Comments (0)