Skip to content

Commit eccbea7

Browse files
committed
add code to support inference test
1 parent 5ad8ae9 commit eccbea7

File tree

1 file changed

+151
-55
lines changed

1 file changed

+151
-55
lines changed

tests/utils/recipe_tests.py

Lines changed: 151 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ def check_performance(
182182
check field of the csv recipe file
183183
For instance: performance_check=[train_log.txt, train loss, <=15, epoch: 2],
184184
will check the variable "train_loss" in the train_log.txt at epoch 2. It will
185-
raise an error if the train_loss is >15.
185+
raise an error if the train_loss is >15. If epoch is -1, we check the last
186+
line of the file.
186187
187188
Arguments
188189
---------
@@ -228,10 +229,18 @@ def check_performance(
228229

229230
# Filter the lines
230231
lines_filt = []
232+
last_line = ""
231233
for line in lines:
234+
if len(line.strip()) > 0:
235+
last_line = line
232236
if epoch in line:
233237
lines_filt.append(line)
234238

239+
# If epoch: -1, we take the last line of the file
240+
epoch_id = int(epoch.split(":")[-1].strip())
241+
if epoch_id == -1:
242+
lines_filt.append(last_line)
243+
235244
# Raising an error if there are no lines after applying the filter
236245
if len(lines_filt) == 0:
237246
print(
@@ -243,16 +252,15 @@ def check_performance(
243252
for line in lines_filt:
244253

245254
# Search variable value
246-
pattern = variable + ": " + "(.*?) "
247-
var_value = re.search(pattern, line)
255+
var_value = extract_value(line, variable)
248256

249257
if var_value is None:
250258
print(
251259
"\tERROR: The file %s of recipe %s does not contain the variable %s (needed for performance checks)"
252260
% (filename, recipe_id, variable)
253261
)
254262
return False
255-
var_value = float(var_value.group(1))
263+
var_value = float(var_value)
256264
check = check_threshold(threshold, var_value)
257265

258266
if not (check):
@@ -266,6 +274,41 @@ def check_performance(
266274
return check
267275

268276

277+
def extract_value(string, key):
    """Extracts from the input string the numeric value that follows a key.

    The value is returned as the matched text (not converted to a number), so
    the caller is responsible for calling ``float`` on it. For instance:

        input_string = "Epoch loaded: 49 - test loss: 4.71e-01, test PER: 14.21"
        print(extract_value(input_string, "test loss"))  # Output: 4.71e-01
        print(extract_value(input_string, "test PER"))  # Output: 14.21
        print(extract_value(input_string, "Epoch loaded"))  # Output: 49

    Arguments
    ---------
    string: str
        The input string. It should be in the "key: value" format shown above.
    key: str
        The key argument to extract. It is matched literally (regex special
        characters in the key are escaped).

    Returns
    -------
    value: str or None
        The text of the number that follows the first occurrence of
        "key:" in the string, or None if no such key/number pair is found.
    """
    # Accept integers ("15"), decimals ("14.21", ".5", "5.") and scientific
    # notation ("4.71e-01", "1e-3"). Requiring a decimal point would make
    # integer-valued logs look like the variable is missing.
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    pattern = re.escape(key) + r"\s*:\s*(?P<value>" + number + ")"

    # Search for the first "key: number" occurrence in the input string
    match = re.search(pattern, string)
    if match is None:
        return None
    return match.group("value")
310+
311+
269312
def check_threshold(threshold, value):
270313
"""Checks if the value satisfied the threshold constraints.
271314
@@ -411,56 +454,36 @@ def run_recipe_tests(
411454

412455
# Download all upfront
413456
if download_only:
414-
for i, recipe_id in enumerate(test_script.keys()):
415-
# If we are interested in performance checks only, skip
416-
check_str = test_check[recipe_id].strip()
417-
if run_tests_with_checks_only:
418-
if len(check_str) == 0:
419-
continue
457+
download_only_test(
458+
test_script,
459+
test_hparam,
460+
test_flag,
461+
test_check,
462+
run_opts,
463+
run_tests_with_checks_only,
464+
output_folder,
465+
)
466+
return False
420467

421-
print(
422-
"(%i/%i) Collecting pretrained models for %s..."
423-
% (i + 1, len(test_script.keys()), recipe_id)
424-
)
468+
# Run script (check how to get std out, std err and save them in files)
469+
check = True
470+
for i, recipe_id in enumerate(test_script.keys()):
425471

472+
# Check if the output folder is specified in test_field
473+
spec_outfold = False
474+
if "--output_folder" in test_flag[recipe_id]:
475+
pattern = r"--output_folder\s*=?\s*([^\s']+|'[^']*')"
476+
match = re.search(pattern, test_flag[recipe_id])
477+
output_fold = match.group(1).strip("'")
478+
spec_outfold = True
479+
else:
426480
output_fold = os.path.join(output_folder, recipe_id)
427481
os.makedirs(output_fold, exist_ok=True)
428-
stdout_file = os.path.join(output_fold, "stdout.txt")
429-
stderr_file = os.path.join(output_fold, "stderr.txt")
430-
431-
cmd = (
432-
"python -c 'import sys;from hyperpyyaml import load_hyperpyyaml;import speechbrain;"
433-
"hparams_file, run_opts, overrides = speechbrain.parse_arguments(sys.argv[1:]);"
434-
"fin=open(hparams_file);hparams = load_hyperpyyaml(fin, overrides);fin.close();"
435-
# 'speechbrain.create_experiment_directory(experiment_directory=hparams["output_folder"],'
436-
# 'hyperparams_to_save=hparams_file,overrides=overrides,);'
437-
)
438-
with open(test_hparam[recipe_id]) as hparam_file:
439-
for line in hparam_file:
440-
if "pretrainer" in line:
441-
cmd += 'hparams["pretrainer"].collect_files();hparams["pretrainer"].load_collected(device="cpu");'
442-
elif "from_pretrained" in line:
443-
field = line.split(":")[0].strip()
444-
cmd += f'hparams["{field}"]'
445-
cmd += (
446-
"' "
447-
+ test_hparam[recipe_id]
448-
+ " --output_folder="
449-
+ output_fold
450-
+ " "
451-
+ test_flag[recipe_id]
452-
+ " "
453-
+ run_opts
454-
)
455-
456-
# Prepare the test
457-
run_test_cmd(cmd, stdout_file, stderr_file)
458482

459-
return False
483+
# Create files for storing standard output and standard error
484+
stdout_file = os.path.join(output_fold, "stdout.txt")
485+
stderr_file = os.path.join(output_fold, "stderr.txt")
460486

461-
# Run script (check how to get std out, std err and save them in files)
462-
check = True
463-
for i, recipe_id in enumerate(test_script.keys()):
464487
# If we are interested in performance checks only, skip
465488
check_str = test_check[recipe_id].strip()
466489
if run_tests_with_checks_only:
@@ -472,11 +495,6 @@ def run_recipe_tests(
472495
% (i + 1, len(test_script.keys()), recipe_id)
473496
)
474497

475-
output_fold = os.path.join(output_folder, recipe_id)
476-
os.makedirs(output_fold, exist_ok=True)
477-
stdout_file = os.path.join(output_fold, "stdout.txt")
478-
stderr_file = os.path.join(output_fold, "stderr.txt")
479-
480498
# Check for setup scripts
481499
setup_script = os.path.join(
482500
"tests/recipes/setup",
@@ -494,14 +512,15 @@ def run_recipe_tests(
494512
+ test_script[recipe_id]
495513
+ " "
496514
+ test_hparam[recipe_id]
497-
+ " --output_folder="
498-
+ output_fold
499515
+ " "
500516
+ test_flag[recipe_id]
501517
+ " "
502518
+ run_opts
503519
)
504520

521+
if not spec_outfold:
522+
cmd = cmd + " --output_folder=" + output_fold
523+
505524
# add --debug if no do_checks to save testing time
506525
if not do_checks:
507526
cmd += " --debug --debug_persistently"
@@ -536,6 +555,83 @@ def run_recipe_tests(
536555
return check
537556

538557

558+
def download_only_test(
    test_script,
    test_hparam,
    test_flag,
    test_check,
    run_opts,
    run_tests_with_checks_only,
    output_folder,
):
    """Downloads only the needed data (useful for off-line tests).

    For every recipe, a small ``python -c`` program is assembled that loads the
    recipe hparams file and, depending on what the hparams file mentions,
    collects/loads the pretrainer files and accesses the ``from_pretrained``
    entries. The program is then executed through ``run_test_cmd`` with its
    standard output and standard error redirected to files.

    Arguments
    ---------
    test_script: dict
        A Dictionary containing recipe IDs as keys and test_scripts as values.
    test_hparam: dict
        A dictionary containing recipe IDs as keys and hparams as values.
    test_flag: dict
        A dictionary containing recipe IDs as keys and the test flags as values.
    test_check: dict
        A dictionary containing recipe IDs as keys and the checks as values.
    run_opts: str
        Running options to append to each test.
    run_tests_with_checks_only: bool
        If True skips all tests that do not have performance check criteria defined.
    output_folder: path
        The output folder where to store all the test outputs.
    """
    n_recipes = len(test_script)

    for i, recipe_id in enumerate(test_script.keys()):
        # If we are interested in performance checks only, skip the recipes
        # that do not define any check.
        check_str = test_check[recipe_id].strip()
        if run_tests_with_checks_only and not check_str:
            continue

        print(
            "(%i/%i) Collecting pretrained models for %s..."
            % (i + 1, n_recipes, recipe_id)
        )

        # Per-recipe folder holding the captured stdout/stderr of the command
        output_fold = os.path.join(output_folder, recipe_id)
        os.makedirs(output_fold, exist_ok=True)
        stdout_file = os.path.join(output_fold, "stdout.txt")
        stderr_file = os.path.join(output_fold, "stderr.txt")

        # Inline program: parse the arguments and load the hparams file.
        # NOTE: the experiment directory is not created here (that call was
        # left commented out in the original implementation).
        cmd = (
            "python -c 'import sys;from hyperpyyaml import load_hyperpyyaml;import speechbrain;"
            "hparams_file, run_opts, overrides = speechbrain.parse_arguments(sys.argv[1:]);"
            "fin=open(hparams_file);hparams = load_hyperpyyaml(fin, overrides);fin.close();"
        )

        # Scan the hparams file textually to decide which statements to add.
        with open(test_hparam[recipe_id]) as hparam_file:
            for line in hparam_file:
                if "pretrainer" in line:
                    cmd += 'hparams["pretrainer"].collect_files();hparams["pretrainer"].load_collected(device="cpu");'
                elif "from_pretrained" in line:
                    field = line.split(":")[0].strip()
                    # Terminate the statement with ";" so that any statement
                    # appended afterwards does not get glued to this one
                    # (which would make the inline program a syntax error).
                    # presumably accessing the entry is what is needed to fetch
                    # the model — TODO confirm
                    cmd += f'hparams["{field}"];'

        # Close the inline program and append the command-line arguments
        cmd += (
            "' "
            + test_hparam[recipe_id]
            + " --output_folder="
            + output_fold
            + " "
            + test_flag[recipe_id]
            + " "
            + run_opts
        )

        # Prepare the test
        run_test_cmd(cmd, stdout_file, stderr_file)
633+
634+
539635
def load_yaml_test(
540636
recipe_folder="tests/recipes",
541637
script_field="Script_file",

0 commit comments

Comments
 (0)