@@ -80,3 +80,133 @@ def test_pin_memory():
8080 )
8181 batch .pin_memory ()
8282 assert batch .foo .data .is_pinned ()
83+
84+
def test_paddedbatch_per_key_padding(device):
    """Check that each key in a PaddedBatch can use its own padding value."""
    from speechbrain.dataio.batch import PaddedBatch

    examples = [
        {
            "wav": torch.tensor([1, 2, 3]).to(device),
            "labels": torch.tensor([1, 2]).to(device),
        },
        {
            "wav": torch.tensor([4, 5]).to(device),
            "labels": torch.tensor([3]).to(device),
        },
    ]

    # Distinct fill values per key: 0 for "wav", -100 for "labels".
    per_key_padding_kwargs = {
        "wav": {"value": 0},
        "labels": {"value": -100},
    }

    batch = PaddedBatch(examples, per_key_padding_kwargs=per_key_padding_kwargs)

    # The shorter wav is tail-padded with 0 while the real samples survive.
    assert torch.all(batch.wav.data[1, 2:] == 0)
    assert torch.all(
        batch.wav.data[0, :3] == torch.tensor([1, 2, 3]).to(device)
    )

    # Labels use their dedicated -100 fill value.
    assert torch.all(batch.labels.data[1, 1:] == -100)
    assert torch.all(
        batch.labels.data[0, :2] == torch.tensor([1, 2]).to(device)
    )
120+
def test_paddedbatch_mixed_padding_config(device):
    """Check mixing a global padding config with per-key overrides.

    Keys without a per-key entry fall back to the global ``padding_kwargs``.
    """
    from speechbrain.dataio.batch import PaddedBatch

    examples = [
        {
            "wav": torch.tensor([1, 2, 3]).to(device),
            "labels": torch.tensor([1, 2]).to(device),
            "features": torch.tensor([0.1, 0.2]).to(device),
        },
        {
            "wav": torch.tensor([4, 5]).to(device),
            "labels": torch.tensor([3]).to(device),
            "features": torch.tensor([0.3]).to(device),
        },
    ]

    # Default fill value for every key...
    padding_kwargs = {"value": 0}
    # ...except "labels", which is overridden per key.
    per_key_padding_kwargs = {"labels": {"value": -100}}

    batch = PaddedBatch(
        examples,
        padding_kwargs=padding_kwargs,
        per_key_padding_kwargs=per_key_padding_kwargs,
    )

    # "wav" and "features" use the global fill value of 0.
    assert torch.all(batch.wav.data[1, 2:] == 0)
    assert torch.all(batch.features.data[1, 1:] == 0)

    # "labels" uses its per-key override of -100.
    assert torch.all(batch.labels.data[1, 1:] == -100)
161+
def test_paddedbatch_numpy_arrays():
    """Check that numpy inputs are converted to tensors and per-key padded."""
    from speechbrain.dataio.batch import PaddedBatch

    examples = [
        {"wav": np.array([1, 2, 3]), "labels": np.array([1, 2])},
        {"wav": np.array([4, 5]), "labels": np.array([3])},
    ]

    per_key_padding_kwargs = {"wav": {"value": 0}, "labels": {"value": -100}}
    batch = PaddedBatch(examples, per_key_padding_kwargs=per_key_padding_kwargs)

    # numpy arrays must come out as torch tensors after collation.
    for padded in (batch.wav, batch.labels):
        assert isinstance(padded.data, torch.Tensor)

    # Per-key fill values must be honored for the converted tensors too.
    assert torch.all(batch.wav.data[1, 2:] == 0)
    assert torch.all(batch.labels.data[1, 1:] == -100)
183+
def test_paddedbatch_backward_compatibility(device):
    """Check the per-key API reproduces the legacy global-config behavior.

    Padding every key with the same value via ``per_key_padding_kwargs``
    must give identical data and lengths to the old ``padding_kwargs`` path.
    """
    from speechbrain.dataio.batch import PaddedBatch

    examples = [
        {
            "wav": torch.tensor([1, 2, 3]).to(device),
            "labels": torch.tensor([1, 2]).to(device),
        },
        {
            "wav": torch.tensor([4, 5]).to(device),
            "labels": torch.tensor([3]).to(device),
        },
    ]

    # Legacy path: a single global padding config.
    batch_old = PaddedBatch(examples, padding_kwargs={"value": 0})

    # New path: the same value supplied per key.
    batch_new = PaddedBatch(
        examples,
        per_key_padding_kwargs={"wav": {"value": 0}, "labels": {"value": 0}},
    )

    # Data and relative lengths must match between the two configurations.
    for key in ("wav", "labels"):
        old, new = getattr(batch_old, key), getattr(batch_new, key)
        assert torch.allclose(old.data, new.data)
        assert torch.allclose(old.lengths, new.lengths)