Skip to content

Commit c60e606

Browse files
committed
fixed hard-coded device
1 parent cc72c9e commit c60e606

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

recipes/timers-and-such/direct/hparams/train.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
6666
limit: !ref <number_of_epochs>
6767

6868
# Models
69-
asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
70-
source: speechbrain/asr-crdnn-rnnlm-librispeech
71-
run_opts: {"device":"cuda:0"}
69+
asr_model_source: speechbrain/asr-crdnn-rnnlm-librispeech
7270

7371
slu_enc: !new:speechbrain.nnet.containers.Sequential
7472
input_shape: [null, null, !ref <ASR_encoder_dim>]

recipes/timers-and-such/direct/train.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,14 @@ def text_pipeline(semantics):
344344
run_on_main(hparams["pretrainer"].collect_files)
345345
hparams["pretrainer"].load_collected(device=run_opts["device"])
346346

347+
# Download pretrained ASR model
348+
from speechbrain.pretrained import EncoderDecoderASR
349+
350+
hparams["asr_model"] = EncoderDecoderASR.from_hparams(
351+
source=hparams["asr_model_source"],
352+
run_opts={"device": run_opts["device"]},
353+
)
354+
347355
# Brain class initialization
348356
slu_brain = SLU(
349357
modules=hparams["modules"],

0 commit comments

Comments
 (0)