Skip to content

Commit 2c491a4

Browse files
committed
fix transducer loss inputs devices
1 parent e961fb4 commit 2c491a4

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

speechbrain/nnet/loss/transducer_loss.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ class TransducerLoss(Module):
301301
The TransducerLoss(nn.Module) uses Transducer(autograd.Function)
302302
to compute the forward-backward loss and gradients.
303303
304+
Input tensors must be on a cuda device.
305+
304306
Example
305307
-------
306308
>>> import torch
@@ -332,11 +334,18 @@ def __init__(self, blank=0, reduction="mean"):
332334
err_msg += "export NUMBAPRO_NVVM='/usr/local/cuda/nvvm/lib64/libnvvm.so' \n"
333335
err_msg += "================================ \n"
334336
err_msg += "If you use conda:\n"
335-
err_msg += "conda install numba cudatoolkit=9.0"
337+
err_msg += "conda install numba cudatoolkit=XX (XX is your cuda toolkit version)"
336338
raise ImportError(err_msg)
337339

338340
def forward(self, logits, labels, T, U):
339341
"""Computes the transducer loss."""
340342
# Transducer.apply function take log_probs tensor.
341-
log_probs = logits.log_softmax(-1)
342-
return self.loss(log_probs, labels, T, U, self.blank, self.reduction)
343+
if logits.device == labels.device == T.device == U.device == "cuda":
344+
log_probs = logits.log_softmax(-1)
345+
return self.loss(
346+
log_probs, labels, T, U, self.blank, self.reduction
347+
)
348+
else:
349+
raise ValueError(
350+
f"Found input tensors on {[logits.device, labels.device, T.device, U.device]}, but they must be on a 'cuda' device to use the transducer loss."
351+
)

0 commit comments

Comments
 (0)