@@ -9,6 +9,49 @@ def test_parse_arguments():
99 assert overrides == "seed: 3\n data_folder: TIMIT"
1010
1111
12+ def test_parse_arguments_both_formats ():
13+ """Test that both --arg=value and --arg value formats work correctly."""
14+ from speechbrain .utils .run_opts import RunOptions
15+
16+ # Test with --arg=value format
17+ filename1 , run_opts1 , overrides1 = RunOptions .from_command_line_args (
18+ ["params.yaml" , "--device=cuda:0" , "--seed=42" ]
19+ )
20+ assert filename1 == "params.yaml"
21+ assert run_opts1 ["device" ] == "cuda:0"
22+ assert "device" in run_opts1 .overridden_args
23+
24+ # Test with --arg value format
25+ filename2 , run_opts2 , overrides2 = RunOptions .from_command_line_args (
26+ [
27+ "params.yaml" ,
28+ "--device" ,
29+ "cuda:1" ,
30+ "--max_grad_norm" ,
31+ "10.0" ,
32+ "--seed" ,
33+ "99" ,
34+ ]
35+ )
36+ assert filename2 == "params.yaml"
37+ assert run_opts2 ["device" ] == "cuda:1"
38+ assert run_opts2 ["max_grad_norm" ] == 10.0
39+ assert "device" in run_opts2 .overridden_args
40+ assert "max_grad_norm" in run_opts2 .overridden_args
41+ assert overrides2 == "seed: 99"
42+
43+ # Test with mixed formats
44+ filename3 , run_opts3 , overrides3 = RunOptions .from_command_line_args (
45+ ["params.yaml" , "--device=cuda:2" , "--max_grad_norm" , "5.0" , "--seed=7" ]
46+ )
47+ assert filename3 == "params.yaml"
48+ assert run_opts3 ["device" ] == "cuda:2"
49+ assert run_opts3 ["max_grad_norm" ] == 5.0
50+ assert "device" in run_opts3 .overridden_args
51+ assert "max_grad_norm" in run_opts3 .overridden_args
52+ assert overrides3 == "seed: 7"
53+
54+
1255def test_brain (device ):
1356 import torch
1457 from torch .optim import SGD
0 commit comments