Skip to content

Commit 5c950a5

Browse files
committed
add pytorch configs as pydantic dataclasses
1 parent e32c6c3 commit 5c950a5

11 files changed

Lines changed: 191 additions & 56 deletions

File tree

deeplabcut/core/config/project_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from pydantic.dataclasses import dataclass
1616
from dataclasses import field
1717

18+
from deeplabcut.core.config.config_mixin import ConfigMixin
19+
1820

1921
@dataclass
2022
class ProjectConfig(ConfigMixin):

deeplabcut/core/weight_init.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import numpy as np
2020

2121
from deeplabcut.core.types import PydanticNDArray
22-
22+
from deeplabcut.core.config import ConfigMixin
2323

2424
@dataclass
25-
class WeightInitialization:
25+
class WeightInitialization(ConfigMixin):
2626
"""Configures weights initialization when transfer learning or fine-tuning models
2727
2828
Args:
@@ -128,6 +128,7 @@ def from_dict(data: dict) -> "WeightInitialization":
128128
@staticmethod
129129
def from_dict_legacy(data: dict) -> "WeightInitialization":
130130
"""Deals with weight initialization that were created before 3.0.0rc5"""
131+
131132
import deeplabcut.pose_estimation_pytorch.modelzoo.utils as utils
132133

133134
conversion_array = data.get("conversion_array")
@@ -189,14 +190,16 @@ def build(
189190
Returns:
190191
The built WeightInitialization.
191192
"""
192-
from deeplabcut.modelzoo import build_weight_init
193+
193194
deprecation_warning = (
194195
"The `WeightInitialization.build` is deprecated and will be removed in a "
195196
"future version of DeepLabCut. Please use `build_weight_init` from "
196197
"`deeplabcut.modelzoo` instead."
197198
)
198199
warnings.warn(deprecation_warning, DeprecationWarning)
199200

201+
from deeplabcut.modelzoo import build_weight_init
202+
200203
return build_weight_init(
201204
cfg,
202205
super_animal,

deeplabcut/pose_estimation_pytorch/config/data.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@
1717
from enum import Enum
1818
from typing import Literal
1919

20+
from deeplabcut.core.config.config_mixin import ConfigMixin
21+
2022

2123
class DataLoaderType(str, Enum):
2224
DLCLoader = "DLCLoader"
2325
COCOLoader = "COCOLoader"
2426

2527

2628
@dataclass
27-
class DataLoaderConfig:
29+
class DataLoaderConfig(ConfigMixin):
2830
"""Base class for data loader configuration.
2931
3032
Attributes:
@@ -33,6 +35,7 @@ class DataLoaderConfig:
3335

3436
type: str
3537

38+
3639
@dataclass
3740
class DLCLoaderConfig(DataLoaderConfig):
3841
"""Configuration for DeepLabCut Loader.
@@ -44,6 +47,7 @@ class DLCLoaderConfig(DataLoaderConfig):
4447
shuffle: Index of the shuffle for which to load data
4548
modelprefix: The modelprefix for the shuffle
4649
"""
50+
4751
type: Literal[DataLoaderType.DLCLoader]
4852
config: str | dict
4953
trainset_index: int = 0
@@ -63,6 +67,7 @@ class COCOLoaderConfig(DataLoaderConfig):
6367
test_json_path: Path of the json file containing the test annotations relative to project_root
6468
val_json_path: Path of the json file containing the validation annotations relative to project_root
6569
"""
70+
6671
type: Literal[DataLoaderType.COCOLoader]
6772
project_root: str = ""
6873
train_json_path: str = "train.json"
@@ -71,13 +76,13 @@ class COCOLoaderConfig(DataLoaderConfig):
7176

7277

7378
@dataclass
74-
class DataTransformationConfig:
79+
class DataTransformationConfig(ConfigMixin):
7580
"""Data transformation configuration.
7681
7782
Attributes:
7883
resize: Resize transformation configuration
7984
longest_max_size: Maximum size for longest edge
80-
hflip: Horizontal flip configuration
85+
hflip: Horizontal flip configuration
8186
affine: Affine transformation configuration
8287
random_bbox_transform: Random bbox transformation configuration
8388
crop_sampling: Crop sampling configuration
@@ -112,8 +117,9 @@ class DataTransformationConfig:
112117
top_down_crop: dict | None = None
113118
collate: dict | None = None
114119

120+
115121
@dataclass(frozen=True)
116-
class GenSamplingConfig:
122+
class GenSamplingConfig(ConfigMixin):
117123
"""Configuration for CTD models.
118124
119125
Args:
@@ -150,8 +156,9 @@ def to_dict(self) -> dict:
150156
"miss_prob": self.miss_prob,
151157
}
152158

159+
153160
@dataclass
154-
class DataConfig:
161+
class DataConfig(ConfigMixin):
155162
"""Complete data configuration.
156163
157164
Attributes:
@@ -168,10 +175,13 @@ class DataConfig:
168175
gen_sampling: GenSamplingConfig | None = None
169176
inference: DataTransformationConfig | None = None
170177
train: DataTransformationConfig | None = None
171-
loader: DLCLoaderConfig | COCOLoaderConfig | None = Field(default=None, discriminator="type")
178+
loader: DLCLoaderConfig | COCOLoaderConfig | None = Field(
179+
default=None, discriminator="type"
180+
)
172181

173-
@field_validator('train', 'inference', mode='before')
182+
@field_validator("train", "inference", mode="before")
174183
@classmethod
175184
def validate_transforms(cls, v):
176185
from deeplabcut.pose_estimation_pytorch.data import build_transforms
177-
build_transforms(v)
186+
187+
build_transforms(v)

deeplabcut/pose_estimation_pytorch/config/inference.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from dataclasses import field
1616
from typing import Any
1717

18+
from deeplabcut.core.config.config_mixin import ConfigMixin
19+
1820

1921
@dataclass
20-
class MultithreadingConfig:
22+
class MultithreadingConfig(ConfigMixin):
2123
"""Multithreading configuration for inference.
2224
2325
Attributes:
@@ -32,7 +34,7 @@ class MultithreadingConfig:
3234

3335

3436
@dataclass
35-
class CompileConfig:
37+
class CompileConfig(ConfigMixin):
3638
"""Model compilation configuration for inference optimization.
3739
3840
Attributes:
@@ -45,7 +47,7 @@ class CompileConfig:
4547

4648

4749
@dataclass
48-
class AutocastConfig:
50+
class AutocastConfig(ConfigMixin):
4951
"""Automatic mixed precision configuration.
5052
5153
Attributes:
@@ -55,30 +57,33 @@ class AutocastConfig:
5557

5658
enabled: bool = False
5759

60+
5861
@dataclass
59-
class EvaluationConfig:
62+
class EvaluationConfig(ConfigMixin):
6063
"""Configuration for evaluation metrics computation.
61-
64+
6265
Attributes:
6366
pcutoff: Confidence threshold for RMSE computation. Can be:
6467
- float: Single threshold for all bodyparts
6568
- list[float]: One value per bodypart (and unique bodypart if any)
6669
- dict[str, float]: Mapping bodypart names to thresholds
6770
comparison_bodyparts: Subset of bodyparts to compute metrics for.
6871
Can be "all", None (all bodyparts), or a list of bodypart names.
69-
per_keypoint_evaluation: Whether to compute train and test RMSE
72+
per_keypoint_evaluation: Whether to compute train and test RMSE
7073
for each keypoint individually.
71-
force_multi_animal: If True, use multi-animal evaluation even if
74+
force_multi_animal: If True, use multi-animal evaluation even if
7275
loader contains only a single animal.
7376
"""
77+
7478
mode: Literal["train", "test", "all"] = "all"
7579
pcutoff: float | list[float] | dict[str, float] = 0.6
7680
comparison_bodyparts: Literal["all"] | list[str] | None = "all"
7781
per_keypoint_evaluation: bool = False
7882
force_multi_animal: bool = False
7983

84+
8085
@dataclass
81-
class InferenceConfig:
86+
class InferenceConfig(ConfigMixin):
8287
"""Complete inference configuration.
8388
8489
Attributes:

deeplabcut/pose_estimation_pytorch/config/logger.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,32 @@
1414
from pydantic.dataclasses import dataclass
1515
from enum import Enum
1616

17+
from deeplabcut.core.config.config_mixin import ConfigMixin
18+
19+
1720
class LoggerType(str, Enum):
1821
WandbLogger = "WandbLogger"
1922
CSVLogger = "CSVLogger"
2023

24+
2125
@dataclass
22-
class LoggerConfig:
26+
class LoggerConfig(ConfigMixin):
2327
"""Base configuration for all loggers.
24-
28+
2529
Attributes:
2630
type: The type of logger to use (WandbLogger or CSVLogger)
2731
"""
32+
2833
type: str
2934

3035

3136
@dataclass
32-
class WandbLoggerConfig(LoggerConfig):
37+
class WandbLoggerConfig(LoggerConfig): #
3338
"""Configuration for Weights & Biases (wandb) logger.
34-
39+
3540
This logger tracks experiments and logs data to Weights & Biases.
3641
Refer to: https://docs.wandb.ai/guides for more information.
37-
42+
3843
Attributes:
3944
type: Logger type (should be 'WandbLogger')
4045
project_name: The name of the wandb project
@@ -45,6 +50,7 @@ class WandbLoggerConfig(LoggerConfig):
4550
train_folder: The path of the folder containing training files.
4651
wandb_kwargs: Additional keyword arguments to pass to wandb.init
4752
"""
53+
4854
type: Literal[LoggerType.WandbLogger]
4955
project_name: str = "deeplabcut"
5056
run_name: str = "tmp"
@@ -55,16 +61,17 @@ class WandbLoggerConfig(LoggerConfig):
5561

5662

5763
@dataclass
58-
class CSVLoggerConfig(LoggerConfig):
64+
class CSVLoggerConfig(LoggerConfig): #
5965
"""Configuration for CSV logger.
60-
66+
6167
This logger saves training stats and metrics to a CSV file.
62-
68+
6369
Attributes:
6470
type: Logger type (should be 'CSVLogger')
6571
train_folder: The path of the folder containing training files.
6672
log_filename: The name of the file in which to store training stats
6773
"""
74+
6875
type: Literal[LoggerType.CSVLogger]
6976
train_folder: str = ""
7077
log_filename: str = "learning_stats.csv"

deeplabcut/pose_estimation_pytorch/config/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
from pydantic.dataclasses import dataclass
1414
from dataclasses import field
1515

16+
from deeplabcut.core.config.config_mixin import ConfigMixin
17+
1618

1719
@dataclass
18-
class ModelConfig:
20+
class ModelConfig(ConfigMixin):
1921
"""Complete model configuration.
2022
2123
Attributes:
@@ -34,7 +36,7 @@ class ModelConfig:
3436

3537

3638
@dataclass
37-
class DetectorModelConfig:
39+
class DetectorModelConfig(ConfigMixin):
3840
"""Configuration for detector models
3941
4042
Attributes:

deeplabcut/pose_estimation_pytorch/config/pose.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515
from pydantic import Field
1616
from enum import Enum
1717

18+
from deeplabcut.core.config.config_mixin import ConfigMixin
1819
from deeplabcut.pose_estimation_pytorch.config.project import ProjectConfig
1920
from deeplabcut.pose_estimation_pytorch.config.data import DataConfig
2021
from deeplabcut.pose_estimation_pytorch.config.training import TrainSettingsConfig
2122
from deeplabcut.pose_estimation_pytorch.config.runner import RunnerConfig
2223
from deeplabcut.pose_estimation_pytorch.config.inference import InferenceConfig
23-
from deeplabcut.pose_estimation_pytorch.config.model import ModelConfig, DetectorModelConfig
24-
from deeplabcut.pose_estimation_pytorch.config.logger import CSVLoggerConfig, WandbLoggerConfig
24+
from deeplabcut.pose_estimation_pytorch.config.model import (
25+
ModelConfig,
26+
DetectorModelConfig,
27+
)
28+
from deeplabcut.pose_estimation_pytorch.config.logger import (
29+
CSVLoggerConfig,
30+
WandbLoggerConfig,
31+
)
2532

2633

2734
class MethodType(str, Enum):
@@ -83,7 +90,7 @@ class NetType(str, Enum):
8390

8491

8592
@dataclass
86-
class DetectorConfig:
93+
class DetectorConfig(ConfigMixin):
8794
model: DetectorModelConfig
8895
device: str = "auto"
8996
data: DataConfig | None = None
@@ -93,7 +100,7 @@ class DetectorConfig:
93100

94101

95102
@dataclass
96-
class PoseConfig:
103+
class PoseConfig(ConfigMixin):
97104
"""Main configuration class for DeepLabCut pose estimation models.
98105
99106
This is the top-level configuration that brings together all the different
@@ -121,7 +128,9 @@ class PoseConfig:
121128
metadata: ProjectConfig | None = None
122129
data: DataConfig | None = None
123130
inference: InferenceConfig = field(default_factory=InferenceConfig)
124-
logger: CSVLoggerConfig | WandbLoggerConfig | None = Field(default=None, discriminator="type")
131+
logger: CSVLoggerConfig | WandbLoggerConfig | None = Field(
132+
default=None, discriminator="type"
133+
)
125134
with_center_keypoints: bool = False
126135
runner: RunnerConfig | None = None
127136
train_settings: TrainSettingsConfig | None = None

0 commit comments

Comments
 (0)