@@ -182,7 +182,8 @@ def check_performance(
182182 check field of the csv recipe file
183183 For instance: performance_check=[train_log.txt, train loss, <=15, epoch: 2]),
184184 will check the variable "train_loss" in the train_log.txt at epoch 2. It will
185- raise an error if the train_loss is >15.
185+ raise an error if the train_loss is >15. If epoch is -1, we check the last
186+ line of the file.
186187
187188 Arguments
188189 ---------
@@ -228,10 +229,18 @@ def check_performance(
228229
229230 # Fitler the lines
230231 lines_filt = []
232+ last_line = ""
231233 for line in lines :
234+ if len (line .strip ()) > 0 :
235+ last_line = line
232236 if epoch in line :
233237 lines_filt .append (line )
234238
239+ # If epoch: -1, we take the last line of the file
240+ epoch_id = int (epoch .split (":" )[- 1 ].strip ())
241+ if epoch_id == - 1 :
242+ lines_filt .append (last_line )
243+
235244 # Raising an error if there are no lines after applying the filter
236245 if len (lines_filt ) == 0 :
237246 print (
@@ -243,16 +252,15 @@ def check_performance(
243252 for line in lines_filt :
244253
245254 # Search variable value
246- pattern = variable + ": " + "(.*?) "
247- var_value = re .search (pattern , line )
255+ var_value = extract_value (line , variable )
248256
249257 if var_value is None :
250258 print (
251259 "\t ERROR: The file %s of recipe %s does not contain the variable %s (needed for performance checks)"
252260 % (filename , recipe_id , variable )
253261 )
254262 return False
255- var_value = float (var_value . group ( 1 ) )
263+ var_value = float (var_value )
256264 check = check_threshold (threshold , var_value )
257265
258266 if not (check ):
@@ -266,6 +274,41 @@ def check_performance(
266274 return check
267275
268276
277+ def extract_value (string , key ):
278+ """Extracts from the input string the value given a key. For instance:
279+ input_string = "Epoch loaded: 49 - test loss: 4.71e-01, test PER: 14.21"
280+ print(extract_value(input_string, "test loss")) # Output: 0.471
281+ print(extract_value(input_string, "test PER")) # Output: 14.21
282+
283+ Arguments
284+ ---------
285+ string: str
286+ The input string. It should be in the format mentioned above.
287+ key: str
288+ The key argument to extract.
289+
290+ Returns
291+ ---------
292+ value: float or str
293+ The value corresponding to the specified key.
294+ """
295+ escaped_key = re .escape (key )
296+
297+ # Create the regular expression pattern to match the argument and its corresponding value
298+ pattern = r"(?P<key>{})\s*:\s*(?P<value>[-+]?\d*\.\d+([eE][-+]?\d+)?)" .format (
299+ escaped_key
300+ )
301+
302+ # Search for the pattern in the input string
303+ match = re .search (pattern , string )
304+
305+ if match :
306+ value = match .group ("value" )
307+ return value
308+ else :
309+ return None
310+
311+
269312def check_threshold (threshold , value ):
270313 """Checks if the value satisfied the threshold constraints.
271314
@@ -411,56 +454,36 @@ def run_recipe_tests(
411454
412455 # Download all upfront
413456 if download_only :
414- for i , recipe_id in enumerate (test_script .keys ()):
415- # If we are interested in performance checks only, skip
416- check_str = test_check [recipe_id ].strip ()
417- if run_tests_with_checks_only :
418- if len (check_str ) == 0 :
419- continue
457+ download_only_test (
458+ test_script ,
459+ test_hparam ,
460+ test_flag ,
461+ test_check ,
462+ run_opts ,
463+ run_tests_with_checks_only ,
464+ output_folder ,
465+ )
466+ return False
420467
421- print (
422- "(%i/%i) Collecting pretrained models for %s..."
423- % (i + 1 , len (test_script .keys ()), recipe_id )
424- )
468+ # Run script (check how to get std out, std err and save them in files)
469+ check = True
470+ for i , recipe_id in enumerate (test_script .keys ()):
425471
472+ # Check if the output folder is specified in test_field
473+ spec_outfold = False
474+ if "--output_folder" in test_flag [recipe_id ]:
475+ pattern = r"--output_folder\s*=?\s*([^\s']+|'[^']*')"
476+ match = re .search (pattern , test_flag [recipe_id ])
477+ output_fold = match .group (1 ).strip ("'" )
478+ spec_outfold = True
479+ else :
426480 output_fold = os .path .join (output_folder , recipe_id )
427481 os .makedirs (output_fold , exist_ok = True )
428- stdout_file = os .path .join (output_fold , "stdout.txt" )
429- stderr_file = os .path .join (output_fold , "stderr.txt" )
430-
431- cmd = (
432- "python -c 'import sys;from hyperpyyaml import load_hyperpyyaml;import speechbrain;"
433- "hparams_file, run_opts, overrides = speechbrain.parse_arguments(sys.argv[1:]);"
434- "fin=open(hparams_file);hparams = load_hyperpyyaml(fin, overrides);fin.close();"
435- # 'speechbrain.create_experiment_directory(experiment_directory=hparams["output_folder"],'
436- # 'hyperparams_to_save=hparams_file,overrides=overrides,);'
437- )
438- with open (test_hparam [recipe_id ]) as hparam_file :
439- for line in hparam_file :
440- if "pretrainer" in line :
441- cmd += 'hparams["pretrainer"].collect_files();hparams["pretrainer"].load_collected(device="cpu");'
442- elif "from_pretrained" in line :
443- field = line .split (":" )[0 ].strip ()
444- cmd += f'hparams["{ field } "]'
445- cmd += (
446- "' "
447- + test_hparam [recipe_id ]
448- + " --output_folder="
449- + output_fold
450- + " "
451- + test_flag [recipe_id ]
452- + " "
453- + run_opts
454- )
455-
456- # Prepare the test
457- run_test_cmd (cmd , stdout_file , stderr_file )
458482
459- return False
483+ # Create files for storing standard input and standard output
484+ stdout_file = os .path .join (output_fold , "stdout.txt" )
485+ stderr_file = os .path .join (output_fold , "stderr.txt" )
460486
461- # Run script (check how to get std out, std err and save them in files)
462- check = True
463- for i , recipe_id in enumerate (test_script .keys ()):
464487 # If we are interested in performance checks only, skip
465488 check_str = test_check [recipe_id ].strip ()
466489 if run_tests_with_checks_only :
@@ -472,11 +495,6 @@ def run_recipe_tests(
472495 % (i + 1 , len (test_script .keys ()), recipe_id )
473496 )
474497
475- output_fold = os .path .join (output_folder , recipe_id )
476- os .makedirs (output_fold , exist_ok = True )
477- stdout_file = os .path .join (output_fold , "stdout.txt" )
478- stderr_file = os .path .join (output_fold , "stderr.txt" )
479-
480498 # Check for setup scripts
481499 setup_script = os .path .join (
482500 "tests/recipes/setup" ,
@@ -494,14 +512,15 @@ def run_recipe_tests(
494512 + test_script [recipe_id ]
495513 + " "
496514 + test_hparam [recipe_id ]
497- + " --output_folder="
498- + output_fold
499515 + " "
500516 + test_flag [recipe_id ]
501517 + " "
502518 + run_opts
503519 )
504520
521+ if not spec_outfold :
522+ cmd = cmd + " --output_folder=" + output_fold
523+
505524 # add --debug if no do_checks to save testing time
506525 if not do_checks :
507526 cmd += " --debug --debug_persistently"
@@ -536,6 +555,83 @@ def run_recipe_tests(
536555 return check
537556
538557
558+ def download_only_test (
559+ test_script ,
560+ test_hparam ,
561+ test_flag ,
562+ test_check ,
563+ run_opts ,
564+ run_tests_with_checks_only ,
565+ output_folder ,
566+ ):
567+ """Downloads only the needed data (useful for off-line tests).
568+
569+ Arguments
570+ ---------
571+ test_script: dict
572+ A Dictionary containing recipe IDs as keys and test_scripts as values.
573+ test_hparam: dict
574+ A dictionary containing recipe IDs as keys and hparams as values.
575+ test_flag: dict
576+ A dictionary containing recipe IDs as keys and the test flags as values.
577+ test_check: dict
578+ A dictionary containing recipe IDs as keys and the checks as values.
579+ run_opts: str
580+ Running options to append to each test.
581+ run_tests_with_checks_only: str
582+ Running options to append to each test.
583+ run_tests_with_checks_only: bool
584+ If True skips all tests that do not have performance check criteria defined.
585+ output_folder: path
586+ The output folder where to store all the test outputs.
587+ """
588+
589+ for i , recipe_id in enumerate (test_script .keys ()):
590+ # If we are interested in performance checks only, skip
591+ check_str = test_check [recipe_id ].strip ()
592+ if run_tests_with_checks_only :
593+ if len (check_str ) == 0 :
594+ continue
595+
596+ print (
597+ "(%i/%i) Collecting pretrained models for %s..."
598+ % (i + 1 , len (test_script .keys ()), recipe_id )
599+ )
600+
601+ output_fold = os .path .join (output_folder , recipe_id )
602+ os .makedirs (output_fold , exist_ok = True )
603+ stdout_file = os .path .join (output_fold , "stdout.txt" )
604+ stderr_file = os .path .join (output_fold , "stderr.txt" )
605+
606+ cmd = (
607+ "python -c 'import sys;from hyperpyyaml import load_hyperpyyaml;import speechbrain;"
608+ "hparams_file, run_opts, overrides = speechbrain.parse_arguments(sys.argv[1:]);"
609+ "fin=open(hparams_file);hparams = load_hyperpyyaml(fin, overrides);fin.close();"
610+ # 'speechbrain.create_experiment_directory(experiment_directory=hparams["output_folder"],'
611+ # 'hyperparams_to_save=hparams_file,overrides=overrides,);'
612+ )
613+ with open (test_hparam [recipe_id ]) as hparam_file :
614+ for line in hparam_file :
615+ if "pretrainer" in line :
616+ cmd += 'hparams["pretrainer"].collect_files();hparams["pretrainer"].load_collected(device="cpu");'
617+ elif "from_pretrained" in line :
618+ field = line .split (":" )[0 ].strip ()
619+ cmd += f'hparams["{ field } "]'
620+ cmd += (
621+ "' "
622+ + test_hparam [recipe_id ]
623+ + " --output_folder="
624+ + output_fold
625+ + " "
626+ + test_flag [recipe_id ]
627+ + " "
628+ + run_opts
629+ )
630+
631+ # Prepare the test
632+ run_test_cmd (cmd , stdout_file , stderr_file )
633+
634+
539635def load_yaml_test (
540636 recipe_folder = "tests/recipes" ,
541637 script_field = "Script_file" ,
0 commit comments