Merged
Changes from 1 commit
5 changes: 3 additions & 2 deletions speechbrain/lobes/models/huggingface_whisper.py
@@ -104,7 +104,8 @@ def __init__(
self._hop_length = feature_extractor.hop_length
self._n_samples = feature_extractor.n_samples
self.register_buffer(
"_mel_filters", torch.as_tensor(feature_extractor.mel_filters)
"_mel_filters",
torch.as_tensor(feature_extractor.mel_filters, dtype=torch.float32),
Comment thread (marked as resolved by Adel-Moumen, outdated)
)

self.model = WhisperModel.from_pretrained(source, cache_dir=save_path)
@@ -244,7 +245,7 @@ def _log_mel_spectrogram(self, audio):
magnitudes = stft[..., :-1].abs() ** 2

filters = self._mel_filters
mel_spec = filters @ magnitudes
mel_spec = filters.transpose(0, 1) @ magnitudes
Comment thread (marked as resolved by Adel-Moumen, outdated)

Collaborator:

I don't understand why this is coming from the transformers library... To me, it seems to be related to the torch library...

Contributor Author (@sangeet2020, Jun 20, 2023):

What's the torch and transformers version you are using? I'll try the exact same versions.
Can you try installing torch 2.0 with transformers 4.26 vs transformers 4.32? I can confirm that I kept the torch version fixed throughout my experiments and varied only the transformers version, and the error vanishes as per the table above.

Collaborator:

In transformers >= 4.30, they changed how they compute feature_extractor.mel_filters. As a result, in transformers >= 4.30 the shape of feature_extractor.mel_filters is (201, 80), while in previous versions it was (80, 201). This causes a problem in our _log_mel_spectrogram function (which we copied from OpenAI) when we calculate mel_spec = filters @ magnitudes: it expects the filters to be (80, 201).
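A minimal sketch of the shape mismatch described above, assuming n_fft=400 (hence 201 frequency bins); the tensors are random placeholders for the real filters and spectrogram:

```python
import torch

# Illustrative shapes only: the point is the layout change between versions.
n_mels, n_freqs, n_frames = 80, 201, 10

magnitudes = torch.rand(n_freqs, n_frames)          # power spectrogram, (201, 10)

filters_old = torch.rand(n_mels, n_freqs)           # transformers <= 4.28 layout: (80, 201)
filters_new = filters_old.transpose(0, 1)           # transformers >= 4.30 layout: (201, 80)

mel_old = filters_old @ magnitudes                  # works: (80, 201) @ (201, 10) -> (80, 10)
mel_new = filters_new.transpose(0, 1) @ magnitudes  # new layout needs the transpose first

assert mel_old.shape == mel_new.shape == (n_mels, n_frames)
assert torch.allclose(mel_old, mel_new)
```

Attempting `filters_new @ magnitudes` directly would raise a shape error, which is the failure this PR works around.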

Collaborator:

Thanks Pooneh, I see. Do you know why they changed their feature_extractor.mel_filters?

Collaborator:

It is related to this PR on transformers:
huggingface/transformers#21998

Basically, what they do is replace the hand-rolled STFT in the different models, including Whisper, with the one from audio_utils.

In transformers <= 4.28:
To get self.mel_filters, they use a Whisper-specific function, get_mel_filters in transformers/models/whisper/feature_extraction_whisper.py, which returns (n_mels, n_freqs) --> (80, 201).
Then, to calculate the log-Mel spectrogram, they call their own stft function in transformers/models/whisper/feature_extraction_whisper.py and finally call _np_extract_fbank_features, which is basically the same as our _log_mel_spectrogram function in speechbrain/lobes/models/huggingface_whisper.py (itself the same as the OpenAI function).

In version >= 4.29:
To get self.mel_filters, they use the mel_filter_bank function in transformers/audio_utils.py.
It is the same as the following torchaudio function:
https://pytorch.org/audio/main/generated/torchaudio.functional.melscale_fbanks.html
This returns (n_freqs, n_mels) --> (201, 80).
Then they call the spectrogram function in audio_utils.py to generate the log-Mel spectrogram, and they transpose mel_filters:
spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
This function is quite similar to MelScale in torchaudio:
https://pytorch.org/audio/main/generated/torchaudio.transforms.MelScale.html

Final points:

  1. It seems they are trying to unify all audio functions across the different models and use the same implementations as torchaudio and librosa. These models are affected by this change:
  • CLAP
  • M-CTC-T
  • SpeechT5
  • TVLT
  • Whisper
  2. Based on the implementation, I think the fix suggested by @sangeet2020 makes sense.
  3. I am going to run all common_voice recipes for the major release. Maybe it would be better to merge this change so I can test the Whisper recipes and make sure no bugs are introduced by this change.
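The >= 4.29 code path described above can be sketched in numpy; the arrays are random stand-ins and mel_floor is taken from the line quoted from audio_utils:

```python
import numpy as np

# Illustrative shapes matching the discussion: 201 frequency bins, 80 mel bands.
n_freqs, n_mels, n_frames = 201, 80, 10

mel_filters = np.random.rand(n_freqs, n_mels)   # mel_filter_bank layout: (n_freqs, n_mels)
spec = np.random.rand(n_freqs, n_frames)        # power spectrogram

mel_floor = 1e-10
# transformers' audio_utils.spectrogram transposes the filters before the matmul:
mel_spec = np.maximum(mel_floor, np.dot(mel_filters.T, spec))

assert mel_spec.shape == (n_mels, n_frames)
```

So the transpose is already built into transformers' own pipeline; the bug only appears when external code such as SpeechBrain's _log_mel_spectrogram reuses the filters with the old (n_mels, n_freqs) assumption.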

Collaborator:

@Adel-Moumen and @sangeet2020 I am thinking about applying the transpose only if the shape is not (n_mels, n_freqs), so it would also work with older versions of transformers.


log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(