Skip to content

Commit 43c8ad1

Browse files
Adapting transducer greedy decoding (#2975)
Co-authored-by: Parcollet Titouan <parcollet.titouan@gmail.com> Co-authored-by: Youness Dkhissi <youness.dkhissi@orange.com>
1 parent dadc1d3 commit 43c8ad1

2 files changed

Lines changed: 50 additions & 35 deletions

File tree

speechbrain/decoders/transducer.py

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,11 @@ def forward(self, tn_output):
154154
return hyps
155155

156156
def transducer_greedy_decode(
157-
self, tn_output, hidden_state=None, return_hidden=False
157+
self,
158+
tn_output,
159+
hidden_state=None,
160+
return_hidden=False,
161+
max_symbols_per_step=5,
158162
):
159163
"""Transducer greedy decoder is a greedy decoder over batch which apply Transducer rules:
160164
1- for each time step in the Transcription Network (TN) output:
@@ -176,6 +180,9 @@ def transducer_greedy_decode(
176180
return_hidden : bool
177181
Whether the return tuple should contain an extra 5th element with
178182
the hidden state of the last step. See `hidden_state`.
183+
max_symbols_per_step : int
184+
Maximum number of non-blank symbols to decode per time step. This is
185+
useful to avoid infinite loops.
179186
180187
Returns
181188
-------
@@ -221,42 +228,49 @@ def transducer_greedy_decode(
221228

222229
# For each time step
223230
for t_step in range(tn_output.size(1)):
224-
# do unsqueeze since tjoint must have 4 dims [B,T,U,Hidden]
225-
log_probs = self._joint_forward_step(
226-
tn_output[:, t_step, :].unsqueeze(1).unsqueeze(1),
227-
out_PN.unsqueeze(1),
228-
)
229-
# Sort outputs at time
230-
logp_targets, positions = torch.max(
231-
log_probs.squeeze(1).squeeze(1), dim=1
232-
)
233-
# Batch hidden update
234-
have_update_hyp = []
235-
for i in range(positions.size(0)):
236-
# Update hiddens only if
237-
# 1- current prediction is non blank
238-
if positions[i].item() != self.blank_id:
239-
hyp["prediction"][i].append(positions[i].item())
240-
hyp["logp_scores"][i] += logp_targets[i]
241-
input_PN[i][0] = positions[i]
242-
have_update_hyp.append(i)
243-
if len(have_update_hyp) > 0:
244-
# Select sentence to update
245-
# And do a forward step + generate hidden
246-
(
247-
selected_input_PN,
248-
selected_hidden,
249-
) = self._get_sentence_to_update(
250-
have_update_hyp, input_PN, hidden
251-
)
252-
selected_out_PN, selected_hidden = self._forward_PN(
253-
selected_input_PN, self.decode_network_lst, selected_hidden
231+
count = 0
232+
while count <= max_symbols_per_step: # avoid infinite loop
233+
# do unsqueeze since tjoint must have 4 dims [B,T,U,Hidden]
234+
log_probs = self._joint_forward_step(
235+
tn_output[:, t_step, :].unsqueeze(1).unsqueeze(1),
236+
out_PN.unsqueeze(1),
254237
)
255-
# update hiddens and out_PN
256-
out_PN[have_update_hyp] = selected_out_PN
257-
hidden = self._update_hiddens(
258-
have_update_hyp, selected_hidden, hidden
238+
# Sort outputs at time
239+
logp_targets, positions = torch.max(
240+
log_probs.squeeze(1).squeeze(1), dim=1
259241
)
242+
# Batch hidden update
243+
have_update_hyp = []
244+
for i in range(positions.size(0)):
245+
# Update hiddens only if
246+
# 1- current prediction is non blank
247+
if positions[i].item() != self.blank_id:
248+
hyp["prediction"][i].append(positions[i].item())
249+
hyp["logp_scores"][i] += logp_targets[i]
250+
input_PN[i][0] = positions[i]
251+
have_update_hyp.append(i)
252+
if len(have_update_hyp) > 0:
253+
# Select sentence to update
254+
# And do a forward step + generate hidden
255+
(
256+
selected_input_PN,
257+
selected_hidden,
258+
) = self._get_sentence_to_update(
259+
have_update_hyp, input_PN, hidden
260+
)
261+
selected_out_PN, selected_hidden = self._forward_PN(
262+
selected_input_PN,
263+
self.decode_network_lst,
264+
selected_hidden,
265+
)
266+
# update hiddens and out_PN
267+
out_PN[have_update_hyp] = selected_out_PN
268+
hidden = self._update_hiddens(
269+
have_update_hyp, selected_hidden, hidden
270+
)
271+
else:
272+
break
273+
count += 1
260274

261275
ret = (
262276
hyp["prediction"],

speechbrain/inference/ASR.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def _get_audio_stream(
612612
sample_rate=self.audio_normalizer.sample_rate,
613613
format="fltp", # torch.float32
614614
num_channels=1,
615+
buffer_chunk_size=-1, # avoiding the problem of dropping first chunks
615616
)
616617

617618
for (chunk,) in streamer.stream():

0 commit comments

Comments
 (0)