Skip to content

Commit c3b5fce

Browse files
Adel-MoumenCopilot
andauthored
fix issue with --arg value and --arg=value (#2999)
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
1 parent 18b41a4 commit c3b5fce

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

speechbrain/utils/run_opts.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,16 @@ def from_command_line_args(cls, arg_list=None):
266266

267267
# Go through arg list to see which were set
268268
# NOTE: Slight risk of collisions if an arg value matches an arg name
269-
overridden_args = {
270-
arg_mapping[arg] for arg in arg_list if arg in arg_mapping
271-
}
272-
269+
overridden_args = set()
270+
for arg in arg_list:
271+
# Handle both --arg=value and --arg value formats
272+
if arg.startswith("--") and "=" in arg:
273+
# Split on first = to get the argument name
274+
arg_name = arg.split("=", 1)[0]
275+
if arg_name in arg_mapping:
276+
overridden_args.add(arg_mapping[arg_name])
277+
elif arg in arg_mapping:
278+
overridden_args.add(arg_mapping[arg])
273279
# Add a record of which args were specified
274280
run_opts = cls(
275281
**{**vars(parsed_args), "overridden_args": overridden_args}

tests/unittests/test_core.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,49 @@ def test_parse_arguments():
99
assert overrides == "seed: 3\ndata_folder: TIMIT"
1010

1111

12+
def test_parse_arguments_both_formats():
13+
"""Test that both --arg=value and --arg value formats work correctly."""
14+
from speechbrain.utils.run_opts import RunOptions
15+
16+
# Test with --arg=value format
17+
filename1, run_opts1, overrides1 = RunOptions.from_command_line_args(
18+
["params.yaml", "--device=cuda:0", "--seed=42"]
19+
)
20+
assert filename1 == "params.yaml"
21+
assert run_opts1["device"] == "cuda:0"
22+
assert "device" in run_opts1.overridden_args
23+
24+
# Test with --arg value format
25+
filename2, run_opts2, overrides2 = RunOptions.from_command_line_args(
26+
[
27+
"params.yaml",
28+
"--device",
29+
"cuda:1",
30+
"--max_grad_norm",
31+
"10.0",
32+
"--seed",
33+
"99",
34+
]
35+
)
36+
assert filename2 == "params.yaml"
37+
assert run_opts2["device"] == "cuda:1"
38+
assert run_opts2["max_grad_norm"] == 10.0
39+
assert "device" in run_opts2.overridden_args
40+
assert "max_grad_norm" in run_opts2.overridden_args
41+
assert overrides2 == "seed: 99"
42+
43+
# Test with mixed formats
44+
filename3, run_opts3, overrides3 = RunOptions.from_command_line_args(
45+
["params.yaml", "--device=cuda:2", "--max_grad_norm", "5.0", "--seed=7"]
46+
)
47+
assert filename3 == "params.yaml"
48+
assert run_opts3["device"] == "cuda:2"
49+
assert run_opts3["max_grad_norm"] == 5.0
50+
assert "device" in run_opts3.overridden_args
51+
assert "max_grad_norm" in run_opts3.overridden_args
52+
assert overrides3 == "seed: 7"
53+
54+
1255
def test_brain(device):
1356
import torch
1457
from torch.optim import SGD

0 commit comments

Comments
 (0)