Skip to content

Commit 6c3f4b3

Browse files
authored
Replace configuration dictionaries with DictConfigs.
Part 3. of migration to typed configs, see issue #3193
2 parents db57e94 + 94cdc66 commit 6c3f4b3

10 files changed

Lines changed: 324 additions & 235 deletions

File tree

deeplabcut/core/config/config_mixin.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Callable
1+
from typing import Callable, Mapping
22
from typing_extensions import Self
33
from pathlib import Path
44
from dataclasses import asdict, fields
5+
from enum import Enum
56

67
from omegaconf import OmegaConf, DictConfig
78
from pydantic import TypeAdapter
@@ -77,11 +78,29 @@ def from_any(
7778
)
7879

7980
@classmethod
80-
def from_yaml(cls, yaml_path: str | Path) -> Self:
81-
return cls.from_dict(read_config_as_dict(yaml_path))
81+
def from_yaml(cls, yaml_path: str | Path, ignore_empty: bool = True) -> Self:
82+
"""
83+
Load a configuration from a YAML file.
84+
85+
Args:
86+
yaml_path: Path to the YAML configuration file.
87+
ignore_empty: If True, empty/None values in the YAML are ignored and
88+
dataclass defaults are used instead. Defaults to True.
89+
90+
Returns:
91+
A new instance of the configuration class.
92+
"""
93+
# NOTE @deruyter92 2026-02-05: Default ignore_empty is now set to True to match
94+
# the prior behaviour of read_config. We should consider changing this to False
95+
# for stricter validation.
96+
yaml_dict = read_config_as_dict(yaml_path)
97+
if ignore_empty:
98+
yaml_dict = {k: v for k, v in yaml_dict.items() if v is not None}
99+
return cls.from_dict(yaml_dict)
82100

83101
def to_yaml(self, yaml_path: str | Path, overwrite: bool = True) -> None:
84-
data = CommentedMap(self.to_dict())
102+
dict_data = self.to_dict_normalized()
103+
data = CommentedMap(dict_data)
85104
for f in fields(self):
86105
if (comment := f.metadata.get("comment")):
87106
data.yaml_set_comment_before_after_key(f.name, before=comment)
@@ -90,6 +109,9 @@ def to_yaml(self, yaml_path: str | Path, overwrite: bool = True) -> None:
90109
def to_dict(self) -> dict:
91110
return asdict(self)
92111

112+
def to_dict_normalized(self) -> dict:
113+
return _normalize_for_serialization(self.to_dict())
114+
93115
def to_dictconfig(self) -> DictConfig:
94116
return OmegaConf.create(self.to_dict())
95117

@@ -99,3 +121,14 @@ def print(
99121
print_fn: Callable[[str], None] | None = None,
100122
) -> None:
101123
pretty_print(config=self.to_dict(), indent=indent, print_fn=print_fn)
124+
125+
126+
def _normalize_for_serialization(obj):
127+
"""Recursively normalize Paths to strings and Enums to values."""
128+
if isinstance(obj, Path):
129+
return str(obj)
130+
elif isinstance(obj, Enum):
131+
return obj.value
132+
elif isinstance(obj, Mapping):
133+
return type(obj)({k: _normalize_for_serialization(v) for k, v in obj.items()})
134+
return obj

deeplabcut/core/config/project_config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""Project configuration classes for DeepLabCut pose estimation models."""
1212

1313
from typing import Any
14+
from pathlib import Path
1415

1516
from pydantic.dataclasses import dataclass
1617
from dataclasses import field
@@ -76,8 +77,8 @@ class ProjectConfig(ConfigMixin):
7677
identity: bool | None = None
7778

7879
# Project path
79-
project_path: str = field(default="", metadata={"comment": "\nProject path (change when moving around)"})
80-
pose_config_path: str = ""
80+
project_path: Path = field(default=Path(), metadata={"comment": "\nProject path (change when moving around)"})
81+
pose_config_path: Path = Path()
8182

8283
# Engine
8384
engine: str = field(
@@ -149,7 +150,7 @@ class ProjectConfig(ConfigMixin):
149150
y2: int | None = None
150151

151152
# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
152-
corner2move2: list[list[str]] | None = field(
153+
corner2move2: list[int] | None = field(
153154
default=None,
154155
metadata={"comment": "\nRefinement configuration (parameters from annotation dataset configuration also relevant in this stage)"},
155156
)

deeplabcut/core/config/utils.py

Lines changed: 52 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313

1414
import warnings
1515
from pathlib import Path
16-
from typing import Callable
16+
from typing import TYPE_CHECKING, Callable
17+
18+
if TYPE_CHECKING:
19+
from deeplabcut.core.config.project_config import ProjectConfig, ProjectConfig3D
1720

1821
import yaml
1922
import ruamel.yaml.representer
2023
from ruamel.yaml import YAML
24+
from omegaconf import DictConfig
25+
from pydantic import ValidationError
2126

2227
from deeplabcut.core.engine import Engine
2328

@@ -97,126 +102,10 @@ def create_config_template(multianimal: bool = False) -> tuple:
97102
Returns:
98103
(cfg_file, ruamelFile) for further editing and dumping.
99104
"""
100-
if multianimal:
101-
yaml_str = """\
102-
# Project definitions (do not edit)
103-
Task:
104-
scorer:
105-
date:
106-
multianimalproject:
107-
identity:
108-
\n
109-
# Project path (change when moving around)
110-
project_path:
111-
\n
112-
# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
113-
engine: pytorch
114-
\n
115-
# Annotation data set configuration (and individual video cropping parameters)
116-
video_sets:
117-
individuals:
118-
uniquebodyparts:
119-
multianimalbodyparts:
120-
bodyparts:
121-
\n
122-
# Fraction of video to start/stop when extracting frames for labeling/refinement
123-
start:
124-
stop:
125-
numframes2pick:
126-
\n
127-
# Plotting configuration
128-
skeleton:
129-
skeleton_color:
130-
pcutoff:
131-
dotsize:
132-
alphavalue:
133-
colormap:
134-
\n
135-
# Training,Evaluation and Analysis configuration
136-
TrainingFraction:
137-
iteration:
138-
default_net_type:
139-
default_augmenter:
140-
default_track_method:
141-
snapshotindex:
142-
detector_snapshotindex:
143-
batch_size:
144-
\n
145-
# Cropping Parameters (for analysis and outlier frame detection)
146-
cropping:
147-
#if cropping is true for analysis, then set the values here:
148-
x1:
149-
x2:
150-
y1:
151-
y2:
152-
\n
153-
# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
154-
corner2move2:
155-
move2corner:
156-
\n
157-
# Conversion tables to fine-tune SuperAnimal weights
158-
SuperAnimalConversionTables:
159-
"""
160-
else:
161-
yaml_str = """\
162-
# Project definitions (do not edit)
163-
Task:
164-
scorer:
165-
date:
166-
multianimalproject:
167-
identity:
168-
\n
169-
# Project path (change when moving around)
170-
project_path:
171-
\n
172-
# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
173-
engine: pytorch
174-
\n
175-
# Annotation data set configuration (and individual video cropping parameters)
176-
video_sets:
177-
bodyparts:
178-
\n
179-
# Fraction of video to start/stop when extracting frames for labeling/refinement
180-
start:
181-
stop:
182-
numframes2pick:
183-
\n
184-
# Plotting configuration
185-
skeleton:
186-
skeleton_color:
187-
pcutoff:
188-
dotsize:
189-
alphavalue:
190-
colormap:
191-
\n
192-
# Training,Evaluation and Analysis configuration
193-
TrainingFraction:
194-
iteration:
195-
default_net_type:
196-
default_augmenter:
197-
snapshotindex:
198-
detector_snapshotindex:
199-
batch_size:
200-
detector_batch_size:
201-
\n
202-
# Cropping Parameters (for analysis and outlier frame detection)
203-
cropping:
204-
#if cropping is true for analysis, then set the values here:
205-
x1:
206-
x2:
207-
y1:
208-
y2:
209-
\n
210-
# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
211-
corner2move2:
212-
move2corner:
213-
\n
214-
# Conversion tables to fine-tune SuperAnimal weights
215-
SuperAnimalConversionTables:
216-
"""
217-
105+
warnings.warn("This function is deprecated. Use deeplabcut.core.config.ProjectConfig instead.")
106+
from deeplabcut.core.config.project_config import ProjectConfig
218107
ruamelFile = YAML()
219-
cfg_file = ruamelFile.load(yaml_str)
108+
cfg_file = ProjectConfig(multianimalproject=multianimal).to_dict()
220109
return cfg_file, ruamelFile
221110

222111

@@ -256,52 +145,54 @@ def create_config_template_3d() -> tuple:
256145
return cfg_file_3d, ruamelFile_3d
257146

258147

259-
def read_config(configname: str | Path) -> dict:
148+
def read_config(configname: str | Path, ignore_empty: bool = True) -> DictConfig:
260149
"""
261150
Reads structured config file defining a project.
262-
Applies default values and repairs (engine, detector_snapshotindex, project_path) and writes back if needed.
263-
"""
264-
ruamelFile = YAML()
265-
path = Path(configname)
266-
if path.exists():
267-
try:
268-
with open(path, "r") as f:
269-
cfg = ruamelFile.load(f)
270-
curr_dir = str(Path(configname).parent.resolve())
271-
272-
if cfg.get("engine") is None:
273-
cfg["engine"] = Engine.TF.aliases[0]
274-
write_project_config(configname, cfg)
275-
276-
if cfg.get("detector_snapshotindex") is None:
277-
cfg["detector_snapshotindex"] = -1
278-
279-
if cfg.get("detector_batch_size") is None:
280-
cfg["detector_batch_size"] = 1
281-
282-
if cfg["project_path"] != curr_dir:
283-
cfg["project_path"] = curr_dir
284-
write_project_config(configname, cfg)
285-
except Exception as err:
286-
if len(err.args) > 2:
287-
if (
288-
err.args[2]
289-
== "could not determine a constructor for the tag '!!python/tuple'"
290-
):
291-
with open(path, "r") as ymlfile:
292-
cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
293-
write_project_config(configname, cfg)
294-
else:
295-
raise
296-
else:
297-
raise FileNotFoundError(
298-
f"Config file at {path} not found. Please make sure that the file exists and/or that you passed the path of the config file correctly!"
299-
)
300-
return cfg
301151
152+
Applies default values and repairs (engine, detector_snapshotindex, project_path)
153+
and writes back if needed.
302154
303-
def write_project_config(configname: str | Path, cfg: dict) -> None:
155+
Args:
156+
configname: Path to the project configuration file (config.yaml).
157+
ignore_empty: If True, empty/None values in the YAML are ignored and
158+
dataclass defaults are used instead. If False, empty values represent None.
159+
Defaults to True.
160+
161+
Returns:
162+
The project configuration as a DictConfig.
163+
"""
164+
# NOTE @deruyter92 2026-02-05: Default ignore_empty is now set to True to match
165+
# the prior behaviour of read_config. We should consider changing this to False
166+
# for stricter validation.
167+
from deeplabcut.core.config.project_config import ProjectConfig
168+
path = Path(configname)
169+
project_config = ProjectConfig.from_yaml(path, ignore_empty=ignore_empty)
170+
171+
# NOTE @deruyter92 2026-02-02: copied old behaviour of writing the config back to the file.
172+
# We should consider separating the writing and reading instead of having inplace edits during reading.
173+
curr_dir = str(Path(configname).parent.resolve())
174+
if project_config.project_path != curr_dir:
175+
project_config.project_path = curr_dir
176+
project_config.to_yaml(configname)
177+
return project_config.to_dictconfig()
178+
179+
180+
def write_project_config(
181+
configname: str | Path,
182+
cfg: dict | ProjectConfig | DictConfig,
183+
) -> None:
304184
"""Write structured project config file (config.yaml) preserving template order."""
185+
from deeplabcut.core.config.project_config import ProjectConfig
186+
187+
try:
188+
project_config: ProjectConfig = ProjectConfig.from_any(cfg)
189+
project_config.to_yaml(configname)
190+
return
191+
except ValidationError as e:
192+
warnings.warn(
193+
f"Invalid configuration! Validation error in config file {cfg}. Error: {e}"
194+
"Reverting to legacy config file writing."
195+
)
305196
with open(configname, "w") as cf:
306197
cfg_file, ruamelFile = create_config_template(
307198
cfg.get("multianimalproject", False)

deeplabcut/core/types.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,38 @@
1010
GetPydanticSchema(
1111
lambda _s, h: h(InstanceOf[np.ndarray]), lambda _s, h: h(InstanceOf[np.ndarray])
1212
),
13-
]
13+
]
14+
15+
16+
class DeprecatedArgument:
17+
"""Singleton sentinel class for deprecated arguments.
18+
19+
Use this as a default value to distinguish between "argument not provided"
20+
and "argument explicitly set to None".
21+
22+
Usage:
23+
from deeplabcut.core.types import DEPRECATED_ARGUMENT, DeprecatedArgument
24+
25+
def func(old_arg=DEPRECATED_ARGUMENT):
26+
if isinstance(old_arg, DeprecatedArgument):
27+
# old_arg was not provided
28+
else:
29+
# old_arg was explicitly provided (even if None)
30+
"""
31+
32+
__slots__ = ()
33+
_instance: "DeprecatedArgument | None" = None
34+
35+
def __new__(cls) -> "DeprecatedArgument":
36+
if cls._instance is None:
37+
cls._instance = super().__new__(cls)
38+
return cls._instance
39+
40+
def __bool__(self) -> bool:
41+
return False
42+
43+
def __repr__(self) -> str:
44+
return "<deprecated argument>"
45+
46+
47+
DEPRECATED_ARGUMENT = DeprecatedArgument()

0 commit comments

Comments
 (0)