Skip to content

Commit cc72c9e

Browse files
committed
fixed hard-coded device
1 parent 091b3ce commit cc72c9e

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

recipes/fluent-speech-commands/direct/hparams/train.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
5656
limit: !ref <number_of_epochs>
5757

5858
# Models
59-
asr_model: !apply:speechbrain.pretrained.EncoderDecoderASR.from_hparams
60-
source: speechbrain/asr-crdnn-rnnlm-librispeech
61-
run_opts: {"device":"cuda:0"}
59+
asr_model_source: speechbrain/asr-crdnn-rnnlm-librispeech
6260

6361
slu_enc: !new:speechbrain.nnet.containers.Sequential
6462
input_shape: [null, null, !ref <ASR_encoder_dim>]

recipes/fluent-speech-commands/direct/train.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,14 @@ def text_pipeline(semantics):
322322
run_on_main(hparams["pretrainer"].collect_files)
323323
hparams["pretrainer"].load_collected(device=run_opts["device"])
324324

325+
# Download pretrained ASR model
326+
from speechbrain.pretrained import EncoderDecoderASR
327+
328+
hparams["asr_model"] = EncoderDecoderASR.from_hparams(
329+
source=hparams["asr_model_source"],
330+
run_opts={"device": run_opts["device"]},
331+
)
332+
325333
# Brain class initialization
326334
slu_brain = SLU(
327335
modules=hparams["modules"],

0 commit comments

Comments
 (0)