Skip to content

Commit fe94a92

Browse files
authored
Merge pull request #2016 from sangeet2020/HF_Whisper
Change needed in Whisper fine-tuning recipe to accommodate transformers 4.30.0
2 parents a8265bf + ce53199 commit fe94a92

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

speechbrain/lobes/models/huggingface_whisper.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,18 @@ def __init__(
103103
self._n_fft = feature_extractor.n_fft
104104
self._hop_length = feature_extractor.hop_length
105105
self._n_samples = feature_extractor.n_samples
106+
# The following breaking changes were introduced in transformers>=4.29:
107+
# 1) mel_filters.shape = (..., feature_extractor.feature_size) instead of (feature_extractor.feature_size, ...)
108+
# 2) mel_filters.dtype = float64 instead of float32
109+
# The following code fixes the issue in a backward compatible way
110+
mel_filters = feature_extractor.mel_filters
111+
if mel_filters.shape[0] != feature_extractor.feature_size:
112+
mel_filters = mel_filters.T
113+
assert mel_filters.shape[0] == feature_extractor.feature_size
106114
self.register_buffer(
107-
"_mel_filters", torch.as_tensor(feature_extractor.mel_filters)
115+
"_mel_filters", torch.as_tensor(mel_filters, dtype=torch.float32)
108116
)
117+
#################################################################
109118

110119
self.model = WhisperModel.from_pretrained(source, cache_dir=save_path)
111120

0 commit comments

Comments
 (0)