-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathrefactoring_checks.py
More file actions
519 lines (445 loc) · 19 KB
/
refactoring_checks.py
File metadata and controls
519 lines (445 loc) · 19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
#!/usr/bin/env python3
"""This is a test script for creating a list of expected outcomes (before refactoring);
then, manual editing might change YAMLs and/or code; another test runs to compare results
(after refactoring to before). The target is a list of known HF repos.
The goal is to identify to which extent changes break existing functionality.
Then, larger changes to code base can be rolled out more assured.
Authors
* Andreas Nautsch, 2022, 2023
"""
import importlib # noqa
import os
import subprocess
import sys
from copy import deepcopy
from glob import glob
import torch # noqa
import yaml
from hyperpyyaml import load_hyperpyyaml
from torch.utils.data import DataLoader
from tqdm import tqdm
import speechbrain # noqa
from speechbrain.dataio.dataloader import LoopedLoader, make_dataloader
from speechbrain.inference.interfaces import foreign_class # noqa
from speechbrain.utils.distributed import run_on_main # noqa
from speechbrain.utils.train_logger import FileTrainLogger
def init(
    new_interfaces_git="https://github.com/speechbrain/speechbrain",
    new_interfaces_branch="hf-interface-testing",
    new_interfaces_local_dir="tests/tmp/hf_interfaces",
):
    """Initialises a PR branch to: https://github.com/speechbrain/speechbrain/tree/hf-interface-testing

    Skipped if the path given as `new_interfaces_local_dir` already exists
    (e.g. by DIY init instead of via this script).

    Arguments
    ---------
    new_interfaces_git: str
        Your git repo (or default: `https://github.com/speechbrain/speechbrain`);
        can be specified in tests/utils/overrides.yaml
    new_interfaces_branch: str
        Default is `hf-interface-testing` (a git branch); can be specified in tests/utils/overrides.yaml
    new_interfaces_local_dir: str
        Default is `tests/tmp/hf_interfaces` (a local path); can be specified in tests/utils/overrides.yaml

    Returns
    -------
    str
        Local path of `updates_pretrained_models` where the updated HF yaml/interface files can be found.
    """
    # set up git etc.; an existing local dir is reused as-is
    # (no check whether it is a valid clone or on the right branch)
    if not os.path.exists(new_interfaces_local_dir):
        # clone repo with PR on updates_pretrained_models into local folder
        cmd_out_clone = subprocess.run(
            ["git", "clone", new_interfaces_git, new_interfaces_local_dir],
            capture_output=True,
        )
        print(f"\tgit clone log: {cmd_out_clone}")
        # switch to the branch containing updates_pretrained_models; running git
        # with `cwd=` (instead of os.chdir) never changes this process's working
        # directory, so it cannot be left inside the clone if the checkout raises
        cmd_out_co = subprocess.run(
            ["git", "checkout", new_interfaces_branch],
            capture_output=True,
            cwd=new_interfaces_local_dir,
        )
        print(f"\tgit checkout log: {cmd_out_co}")
    # return the local path with updates_pretrained_models
    return f"{new_interfaces_local_dir}/updates_pretrained_models"
def _remove_link(path):
    """Delete `path` if present, including a dangling symlink (os.path.exists misses those)."""
    if os.path.lexists(path):
        os.unlink(path)


def _relink(target, link):
    """Make `link` a symlink to `target`, replacing whatever was at `link` before.

    The target is stored as an absolute path: a relative symlink target is
    resolved relative to the directory containing the link, not to the current
    working directory, so the relative paths used in this script would
    otherwise produce dangling links.
    """
    _remove_link(link)
    os.symlink(os.path.abspath(target), link)


def get_model(repo, values, updates_dir=None, run_opts=None):
    """Fetches a pretrained model with the option to re-specify its hyperparameters & interface.

    Arguments
    ---------
    repo: str
        Source of pretrained model (assuming it's within the HF speechbrain collection).
    values: dict
        Interface specification; `cls` names the interface class and the optional
        `foreign` names a custom Python file holding that class.
        Example: speechbrain:hf-interface-testing/updates_pretrained_models/ssl-wav2vec2-base-librispeech/test.yaml
    updates_dir: str
        Local folder with yaml/interface updates; None (default) = take original yaml/interface specification.
    run_opts: dict
        Run options, such as device

    Returns
    -------
    A pretrained model with a speechbrain.inference interface or a custom interface.
    """
    # get the pretrained class; model & predictions
    kwargs = {
        "source": f"speechbrain/{repo}",
        "savedir": f"pretrained_models/{repo}",
    }
    # a "foreign" entry means a custom model which has its own Python filename
    is_foreign = "foreign" in values
    hparams = f"pretrained_models/{repo}/hyperparams.yaml"
    custom = (
        f"pretrained_models/{repo}/{values['foreign']}" if is_foreign else None
    )
    # prepare model loading: is it the old -or- the new yaml/interface?
    if updates_dir is not None:
        # testing the refactoring; assuming all model data has been loaded already
        kwargs["source"] = f"{updates_dir}/{repo}"
        _relink(f"{updates_dir}/{repo}/hyperparams.yaml", hparams)
        if is_foreign:
            _relink(f"{updates_dir}/{repo}/{values['foreign']}", custom)
    else:
        # re:testing on develop? => unlink anything before (dangling links too)
        # and re:link from cached HF hub
        _remove_link(hparams)
        if is_foreign:
            _remove_link(custom)
    if run_opts is not None:
        kwargs["run_opts"] = run_opts
    print(f"\trepo: {repo}")
    # load pretrained model either via custom interface or specified pretrained class
    if is_foreign:
        kwargs["pymodule_file"] = values["foreign"]
        kwargs["classname"] = values["cls"]
        return foreign_class(**kwargs)
    print(f"\tspeechbrain.inference.{values['cls']}")
    print(f"\tobj.from_hparams({kwargs})")
    # resolve the (possibly dotted) class path with attribute lookups instead of
    # eval-ing a string assembled from YAML content
    obj = speechbrain.inference
    for attr in values["cls"].split("."):
        obj = getattr(obj, attr)
    return obj.from_hparams(**kwargs)
def get_prediction(repo, values, updates_dir=None):
    """Gets the prediction for one predefined audio example, pattern: {repo}/{values["sample"]} (see HF model card).

    Falls back to `tests/samples/single-mic/example1.wav` if the sample cannot be
    loaded. Errors raised by the inference call itself are not caught.

    Arguments
    ---------
    repo: str
        Source of pretrained model (assuming it's within the HF speechbrain collection).
    values: dict
        Interface specification; reads `sample` (audio file in the HF repo) and
        `fnx` (name of the model method to call on a batch of one).
        Examples: speechbrain:hf-interface-testing/updates_pretrained_models/ssl-wav2vec2-base-librispeech/test.yaml
                  speechbrain:hf-interface-testing/updates_pretrained_models/asr-wav2vec2-librispeech/test.yaml
    updates_dir: str
        Controls whether/not we are in the refactored results (None: expected results; before refactoring).

    Returns
    -------
    Cleaned-up prediction results for yaml output (result logging & comparison through yaml de/serialization).
    """

    def sanitize(data):
        # turn tensors into plain Python lists/scalars: yaml.dump would otherwise
        # emit python-specific (numpy) tags that yaml.safe_load cannot read back
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().tolist()
        return data

    # get the pretrained model (before/after yaml/interface update)
    model = get_model(repo=repo, values=values, updates_dir=updates_dir)
    savedir = f"pretrained_models/{repo}"
    try:
        # only the audio loading falls back to the local example; a failure of the
        # inference call must not be misreported as missing audio and retried
        try:
            wav = model.load_audio(f"{repo}/{values['sample']}", savedir=savedir)
        except Exception:
            print(f"\tWARNING - no audio found on HF: {repo}/{values['sample']}")
            wav = model.load_audio(
                "tests/samples/single-mic/example1.wav", savedir=savedir
            )
        # resolve the (possibly dotted) method named in the spec
        predict = model
        for attr in values["fnx"].split("."):
            predict = getattr(predict, attr)
        # simulate batch from single file (relative length 1.0)
        prediction = predict(wav.unsqueeze(0), torch.tensor([1.0]))
    finally:
        del model
    # keep the first (only) batch entry of every returned output
    return [sanitize(x[0]) for x in prediction]
def gather_expected_results(
    glob_filter="*",
    new_interfaces_git="https://github.com/speechbrain/speechbrain",
    new_interfaces_branch="hf-interface-testing",
    new_interfaces_local_dir="tests/tmp/hf_interfaces",
    yaml_path="tests/tmp/refactoring_results.yaml",
):
    """Before refactoring HF YAMLs and/or code, gather prediction results.

    Repos that already have an entry in `yaml_path` are skipped. The results file
    is rewritten after each newly processed repo.

    Arguments
    ---------
    glob_filter: str
        Filter for a repo subset or a specific repo.
    new_interfaces_git: str
        Your git repo (or default: `https://github.com/speechbrain/speechbrain`);
        can be specified in tests/utils/overrides.yaml
    new_interfaces_branch: str
        Default is `hf-interface-testing` (a git branch); can be specified in tests/utils/overrides.yaml
    new_interfaces_local_dir: str
        Default is `tests/tmp/hf_interfaces` (a local path); can be specified in tests/utils/overrides.yaml
    yaml_path : str
        Path where to store/load refactoring testing results for later comparison.
    """
    # load results, if existing -or- new from scratch
    results = {}
    if os.path.exists(yaml_path):
        with open(yaml_path, encoding="utf-8") as yaml_in:
            # an empty file loads as None; treat it as "no results yet"
            results = yaml.safe_load(yaml_in) or {}
    # go through each repo
    updates_dir = init(
        new_interfaces_git, new_interfaces_branch, new_interfaces_local_dir
    )
    repos = map(os.path.basename, glob(f"{updates_dir}/{glob_filter}"))
    for repo in repos:
        # skip if results are there
        if repo in results:
            continue
        # get values
        with open(
            f"{updates_dir}/{repo}/test.yaml", encoding="utf-8"
        ) as yaml_test:
            values = load_hyperpyyaml(yaml_test)
        print(f"Collecting results for: {repo} w/ values={values}")
        prediction = get_prediction(repo, values)
        # extend the results and persist them right away
        results[repo] = {"before": prediction}
        with open(yaml_path, "w", encoding="utf-8") as yaml_out:
            yaml.dump(results, yaml_out, default_flow_style=None)
def gather_refactoring_results(
    glob_filter="*",
    new_interfaces_git="https://github.com/speechbrain/speechbrain",
    new_interfaces_branch="hf-interface-testing",
    new_interfaces_local_dir="tests/tmp/hf_interfaces",
    yaml_path="tests/tmp/refactoring_results.yaml",
):
    """After refactoring HF YAMLs and/or code, gather prediction results.

    For every repo with expected ("before") results but no "after" entry yet,
    computes the refactored prediction and stores whether it equals "before"
    under "same". The results file is rewritten after each processed repo.

    Arguments
    ---------
    glob_filter: str
        Filter for a repo subset or a specific repo.
    new_interfaces_git: str
        Your git repo (or default: `https://github.com/speechbrain/speechbrain`);
        can be specified in tests/utils/overrides.yaml
    new_interfaces_branch: str
        Default is `hf-interface-testing` (a git branch); can be specified in tests/utils/overrides.yaml
    new_interfaces_local_dir: str
        Default is `tests/tmp/hf_interfaces` (a local path); can be specified in tests/utils/overrides.yaml
    yaml_path: str
        Path where to store/load refactoring testing results for later comparison.

    Raises
    ------
    FileNotFoundError
        If `yaml_path` does not exist (expected results were not gathered yet).
    """
    # expected results need to exist; fail early, before cloning anything
    if not os.path.exists(yaml_path):
        raise FileNotFoundError(
            f"No expected results found at {yaml_path}; "
            "run gather_expected_results first."
        )
    with open(yaml_path, encoding="utf-8") as yaml_in:
        # an empty file loads as None
        results = yaml.safe_load(yaml_in) or {}
    # go through each repo
    updates_dir = init(
        new_interfaces_git, new_interfaces_branch, new_interfaces_local_dir
    )
    repos = map(os.path.basename, glob(f"{updates_dir}/{glob_filter}"))
    for repo in repos:
        # nothing to compare against without expected results
        if "before" not in results.get(repo, {}):
            print(f"\tWARNING - no expected results for: {repo}; skipping")
            continue
        # skip if results are there
        if "after" in results[repo]:
            continue
        # get values
        with open(
            f"{updates_dir}/{repo}/test.yaml", encoding="utf-8"
        ) as yaml_test:
            values = load_hyperpyyaml(yaml_test)
        print(f"Collecting refactoring results for: {repo} w/ values={values}")
        # extend the results
        results[repo]["after"] = get_prediction(repo, values, updates_dir)
        results[repo]["same"] = (
            results[repo]["before"] == results[repo]["after"]
        )
        # update
        with open(yaml_path, "w", encoding="utf-8") as yaml_out:
            yaml.dump(results, yaml_out, default_flow_style=None)
        print(f"\tsame: {results[repo]['same']}")
def test_performance(
    repo, values, run_opts, updates_dir=None, recipe_overrides=None
):
    """Runs the evaluation partition of a recipe dataset for a pretrained model.

    The Python snippets in `values` (`dataio`, `test_datasets`, `predicted`,
    `targeted`, `to_stats`) are exec'd/eval'd inside this function and therefore
    see its local variables. Several assignments below are marked `# noqa` because
    they exist only for those snippets, so local names such as `model`, `batch`,
    `wavs`, `wav_lens`, `predictions`, `predicted`, `targeted`, `ids` and `metric`
    must not be renamed.

    Arguments
    ---------
    repo: str
        Source of pretrained model (assuming it's within the HF speechbrain collection).
    values: dict
        Interface specification.
        Examples: speechbrain:hf-interface-testing/updates_pretrained_models/ssl-wav2vec2-base-librispeech/test.yaml
                  speechbrain:hf-interface-testing/updates_pretrained_models/asr-wav2vec2-librispeech/test.yaml
    run_opts: dict
        Run options, such as device
    updates_dir: str
        Controls whether/not we are in the refactored results (None: expected results; before refactoring).
    recipe_overrides: dict, optional
        Recipe YAMLs contain placeholders and flags which need to be overwritten (e.g. data_folder & skip_prep).
        See: overrides.yaml. None (default) means no additional overrides.

    Returns
    -------
    dict
        Per test set (the keys of `test_datasets`; 0..n-1 for a list/tuple or a single
        dataset), the summary of each metric listed in `values["performance"]`,
        computed on that test set alone.
    """
    # avoid a shared mutable default argument
    if recipe_overrides is None:
        recipe_overrides = {}
    # Dataset depending file structure
    tmp_dir = f"tests/tmp/{values['dataset']}"
    speechbrain.create_experiment_directory(experiment_directory=tmp_dir)
    stats_meta = {
        f"[{values['dataset']}] - {'BEFORE' if updates_dir is None else 'AFTER'}": repo
    }
    # Load pretrained
    model = get_model(
        repo=repo, values=values, updates_dir=updates_dir, run_opts=run_opts
    )  # noqa
    # Dataio preparation; we need the test sets only
    with open(values["recipe_yaml"], encoding="utf-8") as fin:
        recipe_hparams = load_hyperpyyaml(
            fin, values["overrides"] | recipe_overrides
        )
    # Dataset preparation is assumed to be done through recipes; before running this.
    # NOTE(review): this relies on names defined by exec() being visible to the
    # following eval() calls through the function's locals() dict; PEP 667
    # (Python 3.13) makes locals() return independent snapshots in functions,
    # which may break that - TODO verify on 3.13.
    exec(values["dataio"])
    test_datasets = deepcopy(eval(values["test_datasets"]))
    # harmonise into a {key: dataset} mapping
    if not isinstance(test_datasets, dict):
        if isinstance(test_datasets, (list, tuple)):
            test_datasets = dict(enumerate(test_datasets))
        else:
            test_datasets = {0: test_datasets}
    # prepare testing
    logger = FileTrainLogger(save_file=f"{tmp_dir}/{repo}.log")
    reporting = deepcopy(values["performance"])
    # one empty tracker per metric, used as a template: each test set gets a
    # fresh copy so that its statistics do not include the previous sets' data
    tracker_templates = {}
    for metric, specs in reporting.items():
        tracker_templates[metric] = deepcopy(recipe_hparams[specs["handler"]]())
    test_loader_kwargs = deepcopy(recipe_hparams[values["test_loader"]])
    del recipe_hparams
    stats = {}
    for k, test_set in test_datasets.items():  # keys are test_clean, test_other etc
        for metric in reporting.keys():
            reporting[metric]["tracker"] = deepcopy(tracker_templates[metric])
        if not isinstance(test_set, (DataLoader, LoopedLoader)):
            test_set = make_dataloader(test_set, **test_loader_kwargs)
        with torch.no_grad():
            for batch in tqdm(test_set, dynamic_ncols=True, disable=False):
                batch = batch.to(model.device)
                wavs, wav_lens = batch.sig
                wavs, wav_lens = (  # noqa
                    wavs.to(model.device),
                    wav_lens.to(model.device),
                )
                predictions = eval(  # noqa
                    f"model.{values['fnx']}(wavs, wav_lens)"
                )
                predicted = eval(values["predicted"])  # noqa
                targeted = eval(values["targeted"])  # noqa
                ids = batch.id  # noqa
                for metric in reporting.keys():
                    reporting[metric]["tracker"].append(
                        *eval(values["to_stats"])
                    )
        stats[k] = {}
        for metric, specs in reporting.items():
            stats[k][metric] = specs["tracker"].summarize(specs["field"])
        logger.log_stats(
            stats_meta=stats_meta | {"set": k}, test_stats=stats[k]
        )
    return stats
# run first w/ "--after=False" on latest develop, then checkout the refactoring branch and run w/ "--after=True"
# PYTHONPATH=`realpath .` python tests/utils/refactoring_checks.py tests/utils/overrides.yaml --LibriSpeech_data="" --CommonVoice_EN_data="" --CommonVoice_FR_data="" --IEMOCAP_data="" --after=False
if __name__ == "__main__":
    # NOTE: kept at module level on purpose; the names below become globals of
    # __main__, which the exec'd/eval'd test.yaml snippets in test_performance
    # can see
    hparams_file, run_opts, overrides = speechbrain.parse_arguments(
        sys.argv[1:]
    )
    with open(hparams_file, encoding="utf-8") as fin:
        dataset_overrides = load_hyperpyyaml(fin, overrides)
    # go through each repo
    updates_dir = init(
        dataset_overrides["new_interfaces_git"],
        dataset_overrides["new_interfaces_branch"],
        dataset_overrides["new_interfaces_local_dir"],
    )
    # load results, if existing -or- new from scratch
    yaml_path = f"{dataset_overrides['new_interfaces_local_dir']}.yaml"
    results = {}
    if os.path.exists(yaml_path):
        with open(yaml_path, encoding="utf-8") as yaml_in:
            # an empty results file loads as None
            results = yaml.safe_load(yaml_in) or {}
    repos = map(
        os.path.basename,
        glob(f"{updates_dir}/{dataset_overrides['glob_filter']}"),
    )
    for repo in repos:
        # get values
        with open(
            f"{updates_dir}/{repo}/test.yaml", encoding="utf-8"
        ) as yaml_test:
            values = load_hyperpyyaml(yaml_test)
        # for this testing, some fields need to exist; skip otherwise
        if any(
            entry not in values
            for entry in (
                "dataset",
                "overrides",
                "dataio",
                "test_datasets",
                "test_loader",
                "performance",
                "predicted",
            )
        ):
            continue
        # skip if the dataset path is not given (or not configured at all)
        if not dataset_overrides.get(f"{values['dataset']}_data"):
            continue
        print(f"Run tests on: {repo}")
        if repo not in results:
            results[repo] = {}
        # Before refactoring
        if "before" not in results[repo]:
            results[repo]["before"] = test_performance(
                repo,
                values,
                updates_dir=None,
                run_opts=run_opts,
                recipe_overrides=dataset_overrides[values["dataset"]],
            )
            # update
            with open(yaml_path, "w", encoding="utf-8") as yaml_out:
                yaml.dump(results, yaml_out, default_flow_style=None)
        # After refactoring
        if (
            "after" not in results[repo]
            and dataset_overrides["after"] is True
        ):
            results[repo]["after"] = test_performance(
                repo,
                values,
                run_opts=run_opts,
                updates_dir=updates_dir,
                recipe_overrides=dataset_overrides[values["dataset"]],
            )
            results[repo]["same"] = (
                results[repo]["before"] == results[repo]["after"]
            )
            print(f"\tbefore: {results[repo]['before']}")
            print(f"\t after: {results[repo]['after']}")
            print(f"\t same: {results[repo]['same']}")
            # update
            with open(yaml_path, "w", encoding="utf-8") as yaml_out:
                yaml.dump(results, yaml_out, default_flow_style=None)
    # update
    with open(yaml_path, "w", encoding="utf-8") as yaml_out:
        yaml.dump(results, yaml_out, default_flow_style=None)