Skip to content

Commit d850b5e

Browse files
authored
Update DLC3 PyTorch docs (#2804)
* update user_guide * update_pytorch_config
1 parent e98edcf commit d850b5e

2 files changed

Lines changed: 82 additions & 39 deletions

File tree

docs/pytorch/pytorch_config.md

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ runner:
288288
...
289289
scheduler: # optional: a learning rate scheduler
290290
...
291+
load_scheduler_state_dict: true/false # whether to load scheduler state when resuming training from a snapshot,
291292
snapshots: # parameters for the TorchSnapshotManager
292293
max_snapshots: 5 # the maximum number of snapshots to save (the "best" model does not count as one of them)
293294
save_epochs: 25 # the interval between each snapshot save
@@ -327,7 +328,7 @@ https://pytorch.org/docs/stable/optim.html). Examples:
327328
lr: 1e-4
328329
```
329330

330-
**Scheduler**: YYou can use [any scheduler](
331+
**Scheduler**: You can use [any scheduler](
331332
https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) defined in
332333
`torch.optim.lr_scheduler`, where the arguments given are arguments of the scheduler.
333334
The default scheduler is an LRListScheduler, which changes the learning rates at each
@@ -379,23 +380,43 @@ default.
379380
Additionally, you can log results to [Weights and Biases](https://wandb.ai/site), by adding a
380381
`WandbLogger`. Just make sure you're logged in to your `wandb` account before starting
381382
your training run (with `wandb login` from your shell). For more information, see their
382-
[tutorials](https://docs.wandb.ai/tutorials) and their documentation for
383-
[`wandb.init`](https://docs.wandb.ai/ref/python/init). You can also log images as they are seen by the model to `wandb`
384-
with the `image_log_interval`. This logs a random train and test image, as well as the
385-
targets and heatmaps for that image.
383+
[tutorials](https://docs.wandb.ai/tutorials) and their documentation for [`wandb.init`](https://docs.wandb.ai/ref/python/init).
386384

387385
Logging to `wandb` is a good way to keep track of what you've run, including performance
388386
and metrics.
389387

390388
```yaml
391389
logger:
392390
type: WandbLogger
393-
image_log_interval: 5 # how often images are logged to wandb (in epochs)
394391
project_name: my-dlc3-project # the name of the project where the run should be logged
395392
run_name: dekr-w32-shuffle0 # the name of the run to log
396393
... # any other argument you can pass to `wandb.init`, such as `tags: ["dekr", "split=0"]`
397394
```
398395

396+
You can also log images as they are seen by the model to `wandb`
397+
with the `image_log_interval`. This logs a random train and test image, as well as the
398+
targets and heatmaps for that image.
399+
400+
### Restarting Training at a Specific Checkpoint
401+
402+
If you wish to restart the training at a specific checkpoint, you can specify the
403+
full path of the checkpoint to the `resume_training_from` variable, as shown below. In this
404+
example, `snapshot-010.pt` will be loaded before training starts, and the model will
405+
continue to train from the 10th epoch on.
406+
407+
```yaml
408+
# model configuration
409+
...
410+
# weights from which to resume training
411+
resume_training_from: /Users/john/dlc-project-2021-06-22/dlc-models-pytorch/iteration-0/dlcJun22-trainset95shuffle0/train/snapshot-010.pt
412+
```
413+
414+
When continuing to train a model, you may want to modify the learning rate scheduling
415+
that was being used (by editing the configuration under the `scheduler` key). When doing
416+
so, you *must set `load_scheduler_state_dict: false`* in your `runner` config!
417+
Otherwise, the parameters for the scheduler your started training with will be loaded
418+
from the state dictionary, and your edits might not be kept!
419+
399420
## Training Top-Down Models
400421

401422
Top-down models are split into two main elements: a detector (localizing individuals in
@@ -437,15 +458,37 @@ detector:
437458
...
438459
```
439460

440-
Currently, the only detector available is a `FasterRCNN`. However, multiple variants are
441-
available (you can view the different variants on [torchvision's object detection page](
442-
https://pytorch.org/vision/stable/models.html#object-detection)). It's recommended to
443-
use the fastest detector that brings enough performance. The recommended variants
444-
are the following (from fastest to most powerful, taken from torchvision's
445-
documentation):
446-
447-
| name | Box MAP (larger = more powerful) | Params (larger = more powerful) | GFLOPS (larger = slower) |
448-
|-----------------------------------|----------------------------------:|--------------------------------:|----------------------------:|
449-
| fasterrcnn_mobilenet_v3_large_fpn | 32.8 | 19.4M | 4.49 |
450-
| fasterrcnn_resnet50_fpn | 37 | 41.8M | 134.38 |
451-
| fasterrcnn_resnet50_fpn_v2 | 46.7 | 43.7M | 280.37 |
461+
Currently, the only detectors available are `FasterRCNN` and `SSDLite`. However, multiple variants of
462+
`FasterRCNN` are available (you can view the different variants on
463+
[torchvision's object detection page](https://pytorch.org/vision/stable/models.html#object-detection)). It's recommended to use the fastest
464+
detector that brings enough performance. The recommended variants are the following
465+
(from fastest to most powerful, taken from torchvision's documentation):
466+
467+
| name | Box MAP (larger = more powerful) | Params (larger = more powerful) | GFLOPS (larger = slower) |
468+
|-----------------------------------|---------------------------------:|--------------------------------:|-------------------------:|
469+
| SSDLite | 21.3 | 3.4M | 0.58 |
470+
| fasterrcnn_mobilenet_v3_large_fpn | 32.8 | 19.4M | 4.49 |
471+
| fasterrcnn_resnet50_fpn | 37 | 41.8M | 134.38 |
472+
| fasterrcnn_resnet50_fpn_v2 | 46.7 | 43.7M | 280.37 |
473+
474+
475+
### Restarting Training of an Object Detector at a Specific Checkpoint
476+
477+
If you wish to restart the training of a detector at a specific checkpoint, you can
478+
specify the full path of the checkpoint to the detector's `resume_training_from` variable, as
479+
shown below. In this example, `snapshot-detector-020.pt` will be loaded before training
480+
starts, and the model will continue to train from the 20th epoch on.
481+
482+
```yaml
483+
detector:
484+
# detector configuration
485+
...
486+
# weights from which to resume training
487+
resume_training_from: /Users/john/dlc-project-2021-06-22/dlc-models-pytorch/iteration-0/dlcJun22-trainset95shuffle0/train/snapshot-detector-020.pt
488+
```
489+
490+
When continuing to train a detector, you may want to modify the learning rate scheduling
491+
that was being used (by editing the configuration under the `scheduler` key). When doing
492+
so, you *must set `load_scheduler_state_dict: false`* in your `detector`: `runner`
493+
config! Otherwise, the parameters for the scheduler your started training with will be
494+
loaded from the state dictionary, and your edits might not be kept!

docs/pytorch/user_guide.md

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,24 @@ print(available_models())
6969

7070
### Development State and Road Map 🚧
7171

72-
The table below describes the DeepLabCut API methods that have been implemented,
73-
as well as indications which options are not yet implemented, and which parameters
74-
are not valid for the DLC 3.0 API.
75-
76-
77-
| API Method | Implemented | Parameters not yet implemented | Parameters invalid for pytorch |
78-
|--------------------------------|:-----------:|-------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------|
79-
| `train_network` | 🟢 | `keepdeconvweights` | `maxiters`, `saveiters`, `allow_growth`, `autotune` |
80-
| `return_train_network_path` | 🟢 | | |
81-
| `evaluate_network` | 🟢 | `comparisonbodyparts`, `rescale`, `per_keypoint_evaluation` | |
82-
| `return_evaluate_network_data` | 🔴 | | `TFGPUinference`, `allow_growth` |
83-
| `analyze_videos` | 🟢 | `use_shelve`, `save_as_csv`, `in_random_order`, `batchsize`, `cropping`, `dynamic`, `robust_nframes`, `n_tracks`, `calibrate` | |
84-
| `create_tracking_dataset` | 🔴 | | |
85-
| `analyze_time_lapse_frames` | 🟠 | the name has changed to `analyze_images` to better reflect what it actually does (no video needed) | |
86-
| `convert_detections2tracklets` | 🟢 | `greedy`, `calibrate`, `window_size` | |
87-
| `extract_maps` | 🔴 | | |
88-
| `visualize_scoremaps` | 🔴 | | |
89-
| `visualize_locrefs` | 🔴 | | |
90-
| `visualize_paf` | 🔴 | | |
91-
| `extract_save_all_maps` | 🔴 | | |
92-
| `export_model` | 🔴 | | |
72+
The table below describes the DeepLabCut API methods that have been implemented for the
73+
PyTorch engine, as well as indications which options are not yet implemented, and which
74+
parameters are not valid for the DLC 3.0 PyTorch API.
75+
76+
77+
| API Method | Implemented | Parameters not yet implemented | Parameters invalid for pytorch |
78+
|--------------------------------|:-----------:|-----------------------------------------------------------------------------------------------------|-----------------------------------------------------|
79+
| `train_network` | 🟢 | `keepdeconvweights` | `maxiters`, `saveiters`, `allow_growth`, `autotune` |
80+
| `return_train_network_path` | 🟢 | | |
81+
| `evaluate_network` | 🟢 | `comparisonbodyparts`, `rescale`, `per_keypoint_evaluation` | |
82+
| `return_evaluate_network_data` | 🔴 | | `TFGPUinference`, `allow_growth` |
83+
| `analyze_videos` | 🟢 | `in_random_order`, `dynamic`, `n_tracks`, `calibrate` | |
84+
| `create_tracking_dataset` | 🔴 | | |
85+
| `analyze_time_lapse_frames` | 🟠 | the name has changed to `analyze_images` to better reflect what it actually does (no video needed) | |
86+
| `convert_detections2tracklets` | 🟢 | `greedy`, `calibrate`, `window_size` | |
87+
| `extract_maps` | 🟢 | | |
88+
| `visualize_scoremaps` | 🟢 | | |
89+
| `visualize_locrefs` | 🟢 | | |
90+
| `visualize_paf` | 🟢 | | |
91+
| `extract_save_all_maps` | 🟢 | | |
92+
| `export_model` | 🟢 | | |

0 commit comments

Comments
 (0)