Skip to content

Commit c94d5e2

Browse files
Adel-Moumen and Copilot authored
Implement per-key padding configuration in PaddedBatch (#3008)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent fac0367 commit c94d5e2

2 files changed

Lines changed: 170 additions & 3 deletions

File tree

speechbrain/dataio/batch.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,14 @@ class PaddedBatch:
4343
padding_func : callable, optional
4444
Called with a list of tensors to be padded together. Needs to return
4545
two tensors: the padded data, and another tensor for the data lengths.
46-
padding_kwargs : dict
46+
padding_kwargs : dict, None
4747
(Optional) Extra kwargs to pass to padding_func. E.G. mode, value
48+
This is used as the default padding configuration for all keys.
49+
per_key_padding_kwargs : dict, None
50+
(Optional) Per-key padding configuration. Keys in this dict should match
51+
the keys in the examples. Each value should be a dict with padding parameters
52+
(e.g., {'value': -100, 'mode': 'constant'}). If a key is not in this dict,
53+
the global padding_kwargs will be used.
4854
apply_default_convert : bool
4955
Whether to apply PyTorch default_convert (numpy to torch recursively,
5056
etc.) on all data. Default:True, usually does the right thing.
@@ -111,6 +117,26 @@ class PaddedBatch:
111117
... )
112118
>>> batch.text
113119
[['Hello'], ['How', 'are', 'you?']]
120+
>>> # Per-key padding configuration:
121+
>>> batch = PaddedBatch(
122+
... [
123+
... {
124+
... "wav": torch.tensor([1, 2, 3]),
125+
... "labels": torch.tensor([1, 2]),
126+
... },
127+
... {"wav": torch.tensor([4, 5]), "labels": torch.tensor([3])},
128+
... ],
129+
... per_key_padding_kwargs={
130+
... "wav": {"value": 0},
131+
... "labels": {"value": -100},
132+
... },
133+
... )
134+
>>> batch.wav.data
135+
tensor([[1, 2, 3],
136+
[4, 5, 0]])
137+
>>> batch.labels.data
138+
tensor([[ 1, 2],
139+
[ 3, -100]])
114140
115141
"""
116142

@@ -120,10 +146,15 @@ def __init__(
120146
padded_keys=None,
121147
device_prep_keys=None,
122148
padding_func=batch_pad_right,
123-
padding_kwargs={},
149+
padding_kwargs=None,
150+
per_key_padding_kwargs=None,
124151
apply_default_convert=True,
125152
nonpadded_stack=True,
126153
):
154+
padding_kwargs = padding_kwargs if padding_kwargs is not None else {}
155+
per_key_padding_kwargs = (
156+
per_key_padding_kwargs if per_key_padding_kwargs is not None else {}
157+
)
127158
self.__length = len(examples)
128159
self.__keys = list(examples[0].keys())
129160
self.__padded_keys = []
@@ -138,7 +169,13 @@ def __init__(
138169
):
139170
# Padding and PaddedData
140171
self.__padded_keys.append(key)
141-
padded = PaddedData(*padding_func(values, **padding_kwargs))
172+
173+
# Use per-key padding config if available, otherwise fall back to global padding_kwargs
174+
if key in per_key_padding_kwargs:
175+
key_padding_kwargs = per_key_padding_kwargs[key]
176+
else:
177+
key_padding_kwargs = padding_kwargs
178+
padded = PaddedData(*padding_func(values, **key_padding_kwargs))
142179
setattr(self, key, padded)
143180
else:
144181
# Default PyTorch collate usually does the right thing

tests/unittests/test_batching.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,133 @@ def test_pin_memory():
8080
)
8181
batch.pin_memory()
8282
assert batch.foo.data.is_pinned()
83+
84+
85+
def test_paddedbatch_per_key_padding(device):
    """Check that each key can receive its own padding configuration."""
    from speechbrain.dataio.batch import PaddedBatch

    examples = [
        {
            "wav": torch.tensor([1, 2, 3]).to(device),
            "labels": torch.tensor([1, 2]).to(device),
        },
        {
            "wav": torch.tensor([4, 5]).to(device),
            "labels": torch.tensor([3]).to(device),
        },
    ]

    # Give each key a distinct pad value: 0 for the waveform, -100 for labels.
    per_key_padding_kwargs = {
        "wav": {"value": 0},  # Pad wav with 0
        "labels": {"value": -100},  # Pad labels with -100
    }

    batch = PaddedBatch(examples, per_key_padding_kwargs=per_key_padding_kwargs)

    # "wav" must be padded with 0 while the real samples stay untouched.
    assert torch.all(batch.wav.data[1, 2:] == 0)
    assert torch.all(
        batch.wav.data[0, :3] == torch.tensor([1, 2, 3]).to(device)
    )

    # "labels" must be padded with -100 while the real samples stay untouched.
    assert torch.all(batch.labels.data[1, 1:] == -100)
    assert torch.all(
        batch.labels.data[0, :2] == torch.tensor([1, 2]).to(device)
    )
119+
120+
121+
def test_paddedbatch_mixed_padding_config(device):
    """Check that per-key settings override the global padding_kwargs
    only for the keys they mention; all other keys use the global config."""
    from speechbrain.dataio.batch import PaddedBatch

    examples = [
        {
            "wav": torch.tensor([1, 2, 3]).to(device),
            "labels": torch.tensor([1, 2]).to(device),
            "features": torch.tensor([0.1, 0.2]).to(device),
        },
        {
            "wav": torch.tensor([4, 5]).to(device),
            "labels": torch.tensor([3]).to(device),
            "features": torch.tensor([0.3]).to(device),
        },
    ]

    # Default config applied to every key not listed below.
    padding_kwargs = {"value": 0}
    # Only "labels" gets special padding; "wav"/"features" fall back to global.
    per_key_padding_kwargs = {"labels": {"value": -100}}

    batch = PaddedBatch(
        examples,
        padding_kwargs=padding_kwargs,
        per_key_padding_kwargs=per_key_padding_kwargs,
    )

    # "wav" padded via the global config (0).
    assert torch.all(batch.wav.data[1, 2:] == 0)
    # "labels" padded via the per-key override (-100).
    assert torch.all(batch.labels.data[1, 1:] == -100)
    # "features" padded via the global config (0).
    assert torch.all(batch.features.data[1, 1:] == 0)
160+
161+
162+
def test_paddedbatch_numpy_arrays():
    """Check that numpy inputs are converted to torch tensors and that
    per-key padding still applies after the conversion."""
    from speechbrain.dataio.batch import PaddedBatch

    examples = [
        {"wav": np.array([1, 2, 3]), "labels": np.array([1, 2])},
        {"wav": np.array([4, 5]), "labels": np.array([3])},
    ]
    per_key_padding_kwargs = {"wav": {"value": 0}, "labels": {"value": -100}}

    batch = PaddedBatch(examples, per_key_padding_kwargs=per_key_padding_kwargs)

    # The numpy arrays must come out as torch tensors.
    assert isinstance(batch.wav.data, torch.Tensor)
    assert isinstance(batch.labels.data, torch.Tensor)

    # Each key is padded with its own configured value.
    assert torch.all(batch.wav.data[1, 2:] == 0)
    assert torch.all(batch.labels.data[1, 1:] == -100)
182+
183+
184+
def test_paddedbatch_backward_compatibility(device):
    """Check that the per-key API reproduces the legacy global-kwargs
    behavior when both specify the same pad value."""
    from speechbrain.dataio.batch import PaddedBatch

    examples = [
        {
            "wav": torch.tensor([1, 2, 3]).to(device),
            "labels": torch.tensor([1, 2]).to(device),
        },
        {
            "wav": torch.tensor([4, 5]).to(device),
            "labels": torch.tensor([3]).to(device),
        },
    ]

    # Legacy behavior: one global padding configuration.
    batch_old = PaddedBatch(examples, padding_kwargs={"value": 0})

    # New behavior: the same value supplied per key.
    batch_new = PaddedBatch(
        examples,
        per_key_padding_kwargs={"wav": {"value": 0}, "labels": {"value": 0}},
    )

    # Data and lengths must be identical between the two code paths.
    assert torch.allclose(batch_old.wav.data, batch_new.wav.data)
    assert torch.allclose(batch_old.labels.data, batch_new.labels.data)
    assert torch.allclose(batch_old.wav.lengths, batch_new.wav.lengths)
    assert torch.allclose(batch_old.labels.lengths, batch_new.labels.lengths)

0 commit comments

Comments
 (0)