@@ -67,17 +67,25 @@ def main(
6767 "runner.device" : device ,
6868 "runner.snapshots.save_epochs" : save_epochs ,
6969 "runner.snapshots.max_snapshots" : max_snapshots_to_keep ,
70- "detector.train_settings.display_iters" : 1 ,
71- "detector.train_settings.epochs" : detector_epochs ,
72- "detector.train_settings.batch_size" : detector_batch_size ,
73- "detector.train_settings.dataloader_workers" : 0 ,
74- "detector.runner.snapshots.save_epochs" : save_epochs ,
75- "detector.runner.snapshots.max_snapshots" : max_snapshots_to_keep ,
7670 "logger" : logger ,
7771 }
72+
73+ # Only add detector config updates for top-down models
74+ if is_model_top_down (net_type ):
75+ pytorch_cfg_updates .update ({
76+ "detector.train_settings.display_iters" : 1 ,
77+ "detector.train_settings.epochs" : detector_epochs ,
78+ "detector.train_settings.batch_size" : detector_batch_size ,
79+ "detector.train_settings.dataloader_workers" : 0 ,
80+ "detector.runner.snapshots.save_epochs" : save_epochs ,
81+ "detector.runner.snapshots.max_snapshots" : max_snapshots_to_keep ,
82+ })
83+
84+ # Only add conditional top-down config updates for conditional top-down models
7885 if is_model_cond_top_down (net_type ):
7986 pytorch_cfg_updates ["inference.conditions.shuffle" ] = conditions_shuffle
8087 pytorch_cfg_updates ["inference.conditions.snapshot_index" ] = - 1
88+
8189 run (
8290 config_path = config_path ,
8391 train_fraction = train_frac ,
0 commit comments