Skip to content

Commit 139dd99

Browse files
committed
fix testscript invalid config: only add detector updates for td models
1 parent 69033a2 commit 139dd99

1 file changed

Lines changed: 14 additions & 6 deletions

File tree

examples/testscript_pytorch_multi_animal.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,25 @@ def main(
6767
"runner.device": device,
6868
"runner.snapshots.save_epochs": save_epochs,
6969
"runner.snapshots.max_snapshots": max_snapshots_to_keep,
70-
"detector.train_settings.display_iters": 1,
71-
"detector.train_settings.epochs": detector_epochs,
72-
"detector.train_settings.batch_size": detector_batch_size,
73-
"detector.train_settings.dataloader_workers": 0,
74-
"detector.runner.snapshots.save_epochs": save_epochs,
75-
"detector.runner.snapshots.max_snapshots": max_snapshots_to_keep,
7670
"logger": logger,
7771
}
72+
73+
# Only add detector config updates for top-down models
74+
if is_model_top_down(net_type):
75+
pytorch_cfg_updates.update({
76+
"detector.train_settings.display_iters": 1,
77+
"detector.train_settings.epochs": detector_epochs,
78+
"detector.train_settings.batch_size": detector_batch_size,
79+
"detector.train_settings.dataloader_workers": 0,
80+
"detector.runner.snapshots.save_epochs": save_epochs,
81+
"detector.runner.snapshots.max_snapshots": max_snapshots_to_keep,
82+
})
83+
84+
# Only add conditional top-down config updates for conditional top-down models
7885
if is_model_cond_top_down(net_type):
7986
pytorch_cfg_updates["inference.conditions.shuffle"] = conditions_shuffle
8087
pytorch_cfg_updates["inference.conditions.snapshot_index"] = -1
88+
8189
run(
8290
config_path=config_path,
8391
train_fraction=train_frac,

0 commit comments

Comments
 (0)