-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathaudio_io.py
More file actions
228 lines (192 loc) · 6.38 KB
/
audio_io.py
File metadata and controls
228 lines (192 loc) · 6.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""
Lightweight soundfile-based audio I/O compatibility layer.
This module provides a minimal compatibility wrapper for audio I/O operations
using soundfile (pysoundfile) library, replacing torchaudio's load, save, and
info functions.
Example
-------
>>> from speechbrain.dataio import audio_io
>>> import torch
>>> # Save audio file
>>> waveform = torch.randn(1, 16000)
>>> tmpdir = getfixture("tmpdir")
>>> audio_io.save(tmpdir / "example.wav", waveform, 16000)
>>> # Load audio file
>>> audio, sr = audio_io.load(tmpdir / "example.wav")
>>> # Get audio metadata
>>> info = audio_io.info(tmpdir / "example.wav")
>>> info.duration
1.0
Authors
* Peter Plantinga 2025
"""
import dataclasses
import numpy as np
import soundfile as sf
import torch
@dataclasses.dataclass
class AudioInfo:
    """Metadata describing an audio file, shaped like ``torchaudio.info`` output.

    Attributes
    ----------
    sample_rate : int
        Sampling rate of the file, in Hz.
    frames : int
        Number of frames stored in the file.
    channels : int
        Number of interleaved audio channels.
    subtype : str
        Sample encoding reported by soundfile (e.g. 'PCM_16', 'PCM_24').
    format : str
        Container format reported by soundfile (e.g. 'WAV', 'FLAC').
    """

    sample_rate: int
    frames: int
    channels: int
    subtype: str
    format: str

    @property
    def num_frames(self):
        """torchaudio-style name for ``frames``."""
        return self.frames

    @property
    def num_channels(self):
        """torchaudio-style name for ``channels``."""
        return self.channels

    @property
    def duration(self):
        """Length of the audio in seconds (0.0 unless the sample rate is positive)."""
        if self.sample_rate > 0:
            return self.frames / self.sample_rate
        # Avoid dividing by a zero/negative sample rate.
        return 0.0
def load(
    path,
    *,
    channels_first=True,
    dtype=None,
    always_2d=True,
    frame_offset=0,
    num_frames=-1,
):
    """Read an audio file from disk with soundfile and return a torch tensor.

    Arguments
    ---------
    path : str
        Location of the audio file.
    channels_first : bool
        When True the result is laid out as (channels, frames); when False
        as (frames, channels). Has no effect when `always_2d` is False and the
        file is mono, since the result is then 1D.
        Default: True.
    dtype : torch.dtype, optional
        Element type of the returned tensor; falls back to the torch default
        dtype. Types that soundfile cannot decode into directly are read as
        float32 and then cast to the requested type.
    always_2d : bool
        When True, mono files still produce a 2D tensor. When False, mono
        files produce a 1D tensor of shape (frames,).
        Default: True.
    frame_offset : int
        How many frames to skip from the start of the file. Default: 0.
    num_frames : int
        How many frames to read; -1 means read until the end. Default: -1.

    Returns
    -------
    tensor : torch.Tensor
        The decoded waveform.
    sample_rate : int
        Sampling rate of the file.

    Raises
    ------
    RuntimeError
        Wrapping any error raised while resolving the dtype or reading the file.
    """
    try:
        target_dtype = dtype if dtype else torch.get_default_dtype()

        # str(torch.float32) == "torch.float32" -> keep the part after the dot.
        _prefix, type_name = str(target_dtype).split(".")
        read_as = type_name if type_name in sf._ffi_types else "float32"

        # soundfile yields (frames, channels), or (frames,) for mono when
        # always_2d is False.
        samples, rate = sf.read(
            path,
            start=frame_offset,
            frames=num_frames,
            dtype=read_as,
            always_2d=always_2d,
        )

        waveform = torch.from_numpy(samples).to(target_dtype)
        if channels_first and waveform.ndim == 2:
            waveform = waveform.t()

        return waveform, int(rate)
    except Exception as e:
        raise RuntimeError(f"Failed to load audio from {path}: {e}") from e
def save(path, src, sample_rate, channels_first=True, subtype=None):
    """Save audio to file using soundfile.

    Arguments
    ---------
    path : str
        Path where to save the audio file.
    src : torch.Tensor or numpy.ndarray
        Audio waveform. Can be:
        - 1D tensor/array: (frames,) - mono
        - 2D tensor/array:
          - (channels, frames) if channels_first=True
          - (frames, channels) if channels_first=False
        Half-precision floating inputs (float16 / bfloat16) are upcast to
        float32 before writing, since soundfile only accepts float64, float32,
        int32 and int16 sample buffers.
    sample_rate : int
        Sample rate for the audio file.
    channels_first : bool
        If True, input is assumed to be (channels, frames).
        If False, input is assumed to be (frames, channels).
        Ignored if input is 1D tensor/array.
        Default: True.
    subtype : str, optional
        Audio encoding subtype (e.g., 'PCM_16', 'PCM_24', 'PCM_32', 'FLOAT').
        If None, soundfile will choose an appropriate subtype based on the file format.
        Default: None.

    Raises
    ------
    RuntimeError
        If the input has an unsupported shape or the file cannot be written.
    """
    try:
        if isinstance(src, torch.Tensor):
            tensor = src.detach().cpu()
            # numpy has no bfloat16 (``.numpy()`` raises) and soundfile rejects
            # float16 buffers, so upcast half-precision tensors first.
            if tensor.dtype in (torch.float16, torch.bfloat16):
                tensor = tensor.float()
            audio_np = tensor.numpy()
        else:
            audio_np = np.asarray(src)
            # soundfile cannot write float16 buffers either.
            if audio_np.dtype == np.float16:
                audio_np = audio_np.astype(np.float32)

        # Validate the rank before any reshaping.
        if audio_np.ndim not in (1, 2):
            raise ValueError(
                f"Unsupported audio shape: {audio_np.shape}. "
                "Expected 1D frames or 2D channels and frames."
            )

        # soundfile expects (frames, channels).
        if audio_np.ndim == 2 and channels_first:
            audio_np = audio_np.T

        sf.write(path, audio_np, sample_rate, subtype=subtype)
    except Exception as e:
        raise RuntimeError(f"Failed to save audio to {path}: {e}") from e
def info(path):
    """Read an audio file's header through soundfile and describe it.

    Arguments
    ---------
    path : str
        Location of the audio file.

    Returns
    -------
    AudioInfo
        Metadata for the file: sample_rate, frames, channels, subtype and
        format (duration is derived from these).

    Raises
    ------
    RuntimeError
        Wrapping any error raised while reading the header.
    """
    try:
        header = sf.info(path)
        fields = dict(
            sample_rate=header.samplerate,
            frames=header.frames,
            channels=header.channels,
            subtype=header.subtype,
            format=header.format,
        )
        return AudioInfo(**fields)
    except Exception as e:
        raise RuntimeError(f"Failed to get info for {path}: {e}") from e
def list_audio_backends():
    """Report which audio backends this module can use.

    Returns
    -------
    list of str
        Names of usable backends; this implementation only supports
        soundfile, so the result is always ['soundfile'].
    """
    backends = ["soundfile"]
    return backends