1717from enum import Enum
1818from typing import Literal
1919
20+ from deeplabcut .core .config .config_mixin import ConfigMixin
21+
2022
2123class DataLoaderType (str , Enum ):
2224 DLCLoader = "DLCLoader"
2325 COCOLoader = "COCOLoader"
2426
2527
2628@dataclass
27- class DataLoaderConfig :
29+ class DataLoaderConfig ( ConfigMixin ) :
2830 """Base class for data loader configuration.
2931
3032 Attributes:
@@ -33,6 +35,7 @@ class DataLoaderConfig:
3335
3436 type : str
3537
38+
3639@dataclass
3740class DLCLoaderConfig (DataLoaderConfig ):
3841 """Configuration for DeepLabCut Loader.
@@ -44,6 +47,7 @@ class DLCLoaderConfig(DataLoaderConfig):
4447 shuffle: Index of the shuffle for which to load data
4548 modelprefix: The modelprefix for the shuffle
4649 """
50+
4751 type : Literal [DataLoaderType .DLCLoader ]
4852 config : str | dict
4953 trainset_index : int = 0
@@ -63,6 +67,7 @@ class COCOLoaderConfig(DataLoaderConfig):
6367 test_json_path: Path of the json file containing the test annotations relative to project_root
6468 val_json_path: Path of the json file containing the validation annotations relative to project_root
6569 """
70+
6671 type : Literal [DataLoaderType .COCOLoader ]
6772 project_root : str = ""
6873 train_json_path : str = "train.json"
@@ -71,13 +76,13 @@ class COCOLoaderConfig(DataLoaderConfig):
7176
7277
7378@dataclass
74- class DataTransformationConfig :
79+ class DataTransformationConfig ( ConfigMixin ) :
7580 """Data transformation configuration.
7681
7782 Attributes:
7883 resize: Resize transformation configuration
7984 longest_max_size: Maximum size for longest edge
80- hflip: Horizontal flip configuration
85+ hflip: Horizontal flip configuration
8186 affine: Affine transformation configuration
8287 random_bbox_transform: Random bbox transformation configuration
8388 crop_sampling: Crop sampling configuration
@@ -112,8 +117,9 @@ class DataTransformationConfig:
112117 top_down_crop : dict | None = None
113118 collate : dict | None = None
114119
120+
115121@dataclass (frozen = True )
116- class GenSamplingConfig :
122+ class GenSamplingConfig ( ConfigMixin ) :
117123 """Configuration for CTD models.
118124
119125 Args:
@@ -150,8 +156,9 @@ def to_dict(self) -> dict:
150156 "miss_prob" : self .miss_prob ,
151157 }
152158
159+
153160@dataclass
154- class DataConfig :
161+ class DataConfig ( ConfigMixin ) :
155162 """Complete data configuration.
156163
157164 Attributes:
@@ -168,10 +175,13 @@ class DataConfig:
168175 gen_sampling : GenSamplingConfig | None = None
169176 inference : DataTransformationConfig | None = None
170177 train : DataTransformationConfig | None = None
171- loader : DLCLoaderConfig | COCOLoaderConfig | None = Field (default = None , discriminator = "type" )
178+ loader : DLCLoaderConfig | COCOLoaderConfig | None = Field (
179+ default = None , discriminator = "type"
180+ )
172181
173- @field_validator (' train' , ' inference' , mode = ' before' )
182+ @field_validator (" train" , " inference" , mode = " before" )
174183 @classmethod
175184 def validate_transforms (cls , v ):
176185 from deeplabcut .pose_estimation_pytorch .data import build_transforms
177- build_transforms (v )
186+
187+ build_transforms (v )
0 commit comments