@@ -154,7 +154,11 @@ def forward(self, tn_output):
154154 return hyps
155155
156156 def transducer_greedy_decode (
157- self , tn_output , hidden_state = None , return_hidden = False
157+ self ,
158+ tn_output ,
159+ hidden_state = None ,
160+ return_hidden = False ,
161+ max_symbols_per_step = 5 ,
158162 ):
159163 """Transducer greedy decoder is a greedy decoder over batch which apply Transducer rules:
160164 1- for each time step in the Transcription Network (TN) output:
@@ -176,6 +180,9 @@ def transducer_greedy_decode(
176180 return_hidden : bool
177181 Whether the return tuple should contain an extra 5th element with
178182 the hidden state at of the last step. See `hidden_state`.
183+ max_symbols_per_step : int
184+ Maximum number of non-blank symbols to decode per time step. This is
185+ useful to avoid infinite loops.
179186
180187 Returns
181188 -------
@@ -221,42 +228,49 @@ def transducer_greedy_decode(
221228
222229 # For each time step
223230 for t_step in range (tn_output .size (1 )):
224- # do unsqueeze over since tjoint must be have a 4 dim [B,T,U,Hidden]
225- log_probs = self ._joint_forward_step (
226- tn_output [:, t_step , :].unsqueeze (1 ).unsqueeze (1 ),
227- out_PN .unsqueeze (1 ),
228- )
229- # Sort outputs at time
230- logp_targets , positions = torch .max (
231- log_probs .squeeze (1 ).squeeze (1 ), dim = 1
232- )
233- # Batch hidden update
234- have_update_hyp = []
235- for i in range (positions .size (0 )):
236- # Update hiddens only if
237- # 1- current prediction is non blank
238- if positions [i ].item () != self .blank_id :
239- hyp ["prediction" ][i ].append (positions [i ].item ())
240- hyp ["logp_scores" ][i ] += logp_targets [i ]
241- input_PN [i ][0 ] = positions [i ]
242- have_update_hyp .append (i )
243- if len (have_update_hyp ) > 0 :
244- # Select sentence to update
245- # And do a forward steps + generated hidden
246- (
247- selected_input_PN ,
248- selected_hidden ,
249- ) = self ._get_sentence_to_update (
250- have_update_hyp , input_PN , hidden
251- )
252- selected_out_PN , selected_hidden = self ._forward_PN (
253- selected_input_PN , self .decode_network_lst , selected_hidden
231+ count = 0
232+ while count <= max_symbols_per_step : # avoid infinite loop
233+ # do unsqueeze over since tjoint must have a 4 dim [B,T,U,Hidden]
234+ log_probs = self ._joint_forward_step (
235+ tn_output [:, t_step , :].unsqueeze (1 ).unsqueeze (1 ),
236+ out_PN .unsqueeze (1 ),
254237 )
255- # update hiddens and out_PN
256- out_PN [have_update_hyp ] = selected_out_PN
257- hidden = self ._update_hiddens (
258- have_update_hyp , selected_hidden , hidden
238+ # Sort outputs at time
239+ logp_targets , positions = torch .max (
240+ log_probs .squeeze (1 ).squeeze (1 ), dim = 1
259241 )
242+ # Batch hidden update
243+ have_update_hyp = []
244+ for i in range (positions .size (0 )):
245+ # Update hiddens only if
246+ # 1- current prediction is non blank
247+ if positions [i ].item () != self .blank_id :
248+ hyp ["prediction" ][i ].append (positions [i ].item ())
249+ hyp ["logp_scores" ][i ] += logp_targets [i ]
250+ input_PN [i ][0 ] = positions [i ]
251+ have_update_hyp .append (i )
252+ if len (have_update_hyp ) > 0 :
253+ # Select sentence to update
254+ # And do a forward steps + generated hidden
255+ (
256+ selected_input_PN ,
257+ selected_hidden ,
258+ ) = self ._get_sentence_to_update (
259+ have_update_hyp , input_PN , hidden
260+ )
261+ selected_out_PN , selected_hidden = self ._forward_PN (
262+ selected_input_PN ,
263+ self .decode_network_lst ,
264+ selected_hidden ,
265+ )
266+ # update hiddens and out_PN
267+ out_PN [have_update_hyp ] = selected_out_PN
268+ hidden = self ._update_hiddens (
269+ have_update_hyp , selected_hidden , hidden
270+ )
271+ else :
272+ break
273+ count += 1
260274
261275 ret = (
262276 hyp ["prediction" ],