Skip to content

Commit 3d74d33

Browse files
jeylau and AlexEMG authored
Smart, keypoint-aware image cropping augmentation (#1334)
* Implement custom imgaug augmenter for smart cropping fully integrated in pipeline * Update maDLC_UserGuide.md Co-authored-by: Alexander Mathis <alexander@deeplabcut.org>
1 parent 596db65 commit 3d74d33

28 files changed

Lines changed: 733 additions & 908 deletions

deeplabcut/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
from deeplabcut.generate_training_dataset import (
6161
create_training_model_comparison,
6262
create_multianimaltraining_dataset,
63-
cropimagesandlabels,
6463
)
6564
from deeplabcut.generate_training_dataset import (
6665
dropannotationfileentriesduetodeletedimages,

deeplabcut/create_project/new.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def create_new_project(
220220
cfg_file["skeleton"] = [["bodypart1", "bodypart2"], ["objectA", "bodypart3"]]
221221
cfg_file["default_augmenter"] = "default"
222222
cfg_file["default_net_type"] = "resnet_50"
223-
cfg_file["croppedtraining"] = False
224223

225224
# common parameters:
226225
cfg_file["Task"] = project

deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py

Lines changed: 145 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MakeTrain_pose_yaml,
2525
MakeTest_pose_yaml,
2626
MakeInference_yaml,
27+
pad_train_test_indices,
2728
)
2829
from deeplabcut.utils import auxiliaryfunctions, auxfun_models, auxfun_multianimal
2930

@@ -99,6 +100,8 @@ def create_multianimaltraining_dataset(
99100
windows2linux=False,
100101
net_type=None,
101102
numdigits=2,
103+
crop_size=(400, 400),
104+
crop_sampling="hybrid",
102105
paf_graph=None,
103106
trainIndices=None,
104107
testIndices=None,
@@ -133,6 +136,18 @@ def create_multianimaltraining_dataset(
133136
134137
numdigits: int, optional
135138
139+
crop_size: tuple of int, optional
140+
Dimensions (width, height) of the crops for data augmentation.
141+
Default is 400x400.
142+
143+
crop_sampling: str, optional
144+
Crop centers sampling method. Must be either:
145+
"uniform" (randomly over the image),
146+
"keypoints" (randomly over the annotated keypoints),
147+
"density" (weighing preferentially dense regions of keypoints),
148+
or "hybrid" (alternating randomly between "uniform" and "density").
149+
Default is "hybrid".
150+
136151
paf_graph: list of lists, optional (default=None)
137152
If not None, overwrite the default complete graph. This is useful for advanced users who
138153
already know a good graph, or simply want to use a specific one. Note that, in that case,
@@ -155,6 +170,12 @@ def create_multianimaltraining_dataset(
155170
>>> deeplabcut.create_multianimaltraining_dataset(r'C:\\Users\\Ulf\\looming-task\\config.yaml',Shuffles=[3,17,5])
156171
--------
157172
"""
173+
if len(crop_size) != 2 or not all(isinstance(v, int) for v in crop_size):
174+
raise ValueError("Crop size must be a tuple of two integers (width, height).")
175+
176+
if crop_sampling not in ("uniform", "keypoints", "density", "hybrid"):
177+
raise ValueError(f"Invalid sampling {crop_sampling}. Must be "
178+
f"either 'uniform', 'keypoints', 'density', or 'hybrid'.")
158179

159180
# Loading metadata from config file:
160181
cfg = auxiliaryfunctions.read_config(config)
@@ -170,15 +191,6 @@ def create_multianimaltraining_dataset(
170191
return
171192
Data = Data[scorer]
172193

173-
def strip_cropped_image_name(path):
174-
# utility function to split different crops from same image into either train or test!
175-
head, filename = os.path.split(path)
176-
if cfg["croppedtraining"]:
177-
filename = filename.split("c")[0]
178-
return os.path.join(head, filename)
179-
180-
img_names = Data.index.map(strip_cropped_image_name).unique()
181-
182194
if net_type is None: # loading & linking pretrained models
183195
net_type = cfg.get("default_net_type", "dlcrnet_ms5")
184196
elif not any(net in net_type for net in ("resnet", "eff", "dlc", "mob")):
@@ -236,19 +248,12 @@ def strip_cropped_image_name(path):
236248
if trainIndices is None and testIndices is None:
237249
splits = []
238250
for shuffle in Shuffles: # Creating shuffles starting from 1
239-
for trainFraction in cfg["TrainingFraction"]:
240-
train_inds_temp, test_inds_temp = SplitTrials(
241-
range(len(img_names)), trainFraction
251+
for train_frac in cfg["TrainingFraction"]:
252+
train_inds, test_inds = SplitTrials(
253+
range(len(Data)), train_frac
242254
)
243-
# Map back to the original indices.
244-
temp = [re.escape(name) for i, name in enumerate(img_names)
245-
if i in test_inds_temp]
246-
mask = Data.index.str.contains("|".join(temp))
247-
testIndices = np.flatnonzero(mask)
248-
trainIndices = np.flatnonzero(~mask)
249-
250255
splits.append(
251-
(trainFraction, shuffle, (trainIndices, testIndices))
256+
(train_frac, shuffle, (train_inds, test_inds))
252257
)
253258
else:
254259
if len(trainIndices) != len(testIndices) != len(Shuffles):
@@ -265,6 +270,12 @@ def strip_cropped_image_name(path):
265270
print(
266271
f"You passed a split with the following fraction: {int(100 * trainFraction)}%"
267272
)
273+
# Now that the training fraction is guaranteed to be correct,
274+
# the values added to pad the indices are removed.
275+
train_inds = np.asarray(train_inds)
276+
train_inds = train_inds[train_inds != -1]
277+
test_inds = np.asarray(test_inds)
278+
test_inds = test_inds[test_inds != -1]
268279
splits.append(
269280
(trainFraction, Shuffles[shuffle], (train_inds, test_inds))
270281
)
@@ -387,6 +398,8 @@ def strip_cropped_image_name(path):
387398
"num_idchannel": len(cfg["individuals"])
388399
if cfg.get("identity", False)
389400
else 0,
401+
"crop_size": list(crop_size),
402+
"crop_sampling": crop_sampling,
390403
}
391404

392405
trainingdata = MakeTrain_pose_yaml(
@@ -441,3 +454,115 @@ def strip_cropped_image_name(path):
441454
)
442455
else:
443456
pass
457+
458+
459+
def convert_cropped_to_standard_dataset(
    config_path,
    recreate_datasets=True,
    delete_crops=True,
    back_up=True,
):
    """Convert a project trained on cropped images back to a standard one.

    Restores the original (uncropped) video sets in the project config,
    optionally deletes the cropped labeled-data folders, and recreates the
    training datasets so the previous train/test splits are remapped onto
    the original, uncropped images.

    Parameters
    ----------
    config_path : str
        Full path of the project's config.yaml file.
    recreate_datasets : bool, optional (default=True)
        If True, rebuild the training datasets from the original images,
        preserving the previous train/test splits.
    delete_crops : bool, optional (default=True)
        If True, remove the "_cropped" labeled-data folders from disk.
    back_up : bool, optional (default=True)
        If True, copy the whole project tree to "<project_path>_bak" first.

    Raises
    ------
    FileNotFoundError
        If no pose_cfg.yaml can be located under the project's dlc-models
        folder while recreating the datasets.
    """
    import pandas as pd
    import pickle
    import shutil
    from deeplabcut.generate_training_dataset import trainingsetmanipulation
    from deeplabcut.utils import read_plainconfig, write_config

    cfg = auxiliaryfunctions.read_config(config_path)
    # Pop with a default of None so projects that were never cropped (and
    # therefore lack these keys) hit the graceful early return below
    # instead of raising KeyError.
    videos_orig = cfg.pop("video_sets_original", None)
    is_cropped = cfg.pop("croppedtraining", None)
    if videos_orig is None or not is_cropped:
        print("Labeled data do not appear to be cropped. "
              "Project will remain unchanged...")
        return

    project_path = cfg["project_path"]

    if back_up:
        print("Backing up project...")
        shutil.copytree(project_path, project_path + "_bak", symlinks=True)

    if delete_crops:
        print("Deleting crops...")
        data_path = os.path.join(project_path, "labeled-data")
        for video in cfg["video_sets"]:
            _, filename, _ = trainingsetmanipulation._robust_path_split(video)
            if "_cropped" in video:  # One can never be too safe...
                shutil.rmtree(
                    os.path.join(data_path, filename), ignore_errors=True
                )

    # Restore the original videos and persist the cleaned config.
    cfg["video_sets"] = videos_orig
    write_config(config_path, cfg)

    if not recreate_datasets:
        return

    datasets_folder = os.path.join(
        project_path, auxiliaryfunctions.GetTrainingSetFolder(cfg),
    )
    df_old = pd.read_hdf(
        os.path.join(datasets_folder, "CollectedData_" + cfg["scorer"] + ".h5"),
    )

    def strip_cropped_image_name(path):
        # Map a cropped image path (".../video_cropped/img001c2.png")
        # back to its original name (".../video/img001.png").
        head, filename = os.path.split(path)
        head = head.replace("_cropped", "")
        # splitext is robust to extra dots in the file name, unlike an
        # exact two-way split on ".".
        root, ext = os.path.splitext(filename)
        root = root.split("c")[0]  # Drop the crop-index suffix.
        return os.path.join(head, root + ext)

    img_names_old = np.asarray(
        [strip_cropped_image_name(img) for img in df_old.index.to_list()]
    )
    df = merge_annotateddatasets(cfg, datasets_folder, False)
    img_names = df.index.to_numpy()
    train_idx = []
    test_idx = []
    pickle_files = []
    for filename in os.listdir(datasets_folder):
        if filename.endswith("pickle"):
            pickle_file = os.path.join(datasets_folder, filename)
            pickle_files.append(pickle_file)
            if filename.startswith("Docu"):
                # Documentation pickles store the original split; remap the
                # old (cropped) frame indices onto the merged, uncropped
                # frames via their stripped image names.
                with open(pickle_file, "rb") as f:
                    _, train_inds, test_inds, train_frac = pickle.load(f)
                train_inds_temp = np.flatnonzero(
                    np.isin(img_names, img_names_old[train_inds])
                )
                test_inds_temp = np.flatnonzero(
                    np.isin(img_names, img_names_old[test_inds])
                )
                # Pad the indices (with -1) so the downstream training
                # fraction check still holds; padding is stripped there.
                train_inds, test_inds = pad_train_test_indices(
                    train_inds_temp, test_inds_temp, train_frac
                )
                train_idx.append(train_inds)
                test_idx.append(test_inds)

    # Search a pose_cfg.yaml file to parse missing information.
    pose_config_path = ""
    for dirpath, _, filenames in os.walk(
        os.path.join(project_path, "dlc-models")
    ):
        for file in filenames:
            if file.endswith("pose_cfg.yaml"):
                pose_config_path = os.path.join(dirpath, file)
                break
    if not pose_config_path:
        # Fail loudly with a clear message rather than letting
        # read_plainconfig choke on an empty path.
        raise FileNotFoundError(
            "No pose_cfg.yaml found in the project's dlc-models folder."
        )
    pose_cfg = read_plainconfig(pose_config_path)
    net_type = pose_cfg["net_type"]
    if net_type == "resnet_50" and pose_cfg.get("multi_stage", False):
        net_type = "dlcrnet_ms5"

    # Clean the training-datasets folder prior to recreating the data pickles.
    shuffle_inds = set()
    for file in pickle_files:
        os.remove(file)
        shuffle_inds.add(int(re.findall(r"shuffle(\d+)", file)[0]))
    create_multianimaltraining_dataset(
        config_path,
        trainIndices=train_idx,
        testIndices=test_idx,
        Shuffles=sorted(shuffle_inds),
        net_type=net_type,
        paf_graph=pose_cfg["partaffinityfield_graph"],
        crop_size=pose_cfg.get("crop_size", [400, 400]),
        crop_sampling=pose_cfg.get("crop_sampling", "hybrid"),
    )

0 commit comments

Comments (0)