Skip to content

Commit 9c826ef

Browse files
authored
Add hooks to ensure saving of kmeans model works correctly (#3047)
1 parent 638ffd6 commit 9c826ef

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

  • speechbrain/integrations/audio_tokenizers

speechbrain/integrations/audio_tokenizers/kmeans.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import joblib
88
import torch
99

10+
from speechbrain.utils import checkpoints
1011

12+
13+
@checkpoints.register_checkpoint_hooks
1114
class MiniBatchKMeansSklearn(torch.nn.Module):
1215
"""A wrapper for scikit-learn MiniBatchKMeans, providing integration with PyTorch tensors.
1316
@@ -64,6 +67,7 @@ def to(self, device=None, **kwargs):
6467
self.device = device
6568
return super().to(device)
6669

70+
@checkpoints.mark_as_saver
6771
def save(self, path):
6872
"""Saves the model to the specified file.
6973
@@ -74,6 +78,7 @@ def save(self, path):
7478
"""
7579
joblib.dump(self.kmeans, path)
7680

81+
@checkpoints.mark_as_loader
7782
def load(self, path, end_of_epoch):
7883
"""Loads the model from the specified file.
7984

0 commit comments

Comments
 (0)