2424 MakeTrain_pose_yaml ,
2525 MakeTest_pose_yaml ,
2626 MakeInference_yaml ,
27+ pad_train_test_indices ,
2728)
2829from deeplabcut .utils import auxiliaryfunctions , auxfun_models , auxfun_multianimal
2930
@@ -99,6 +100,8 @@ def create_multianimaltraining_dataset(
99100 windows2linux = False ,
100101 net_type = None ,
101102 numdigits = 2 ,
103+ crop_size = (400 , 400 ),
104+ crop_sampling = "hybrid" ,
102105 paf_graph = None ,
103106 trainIndices = None ,
104107 testIndices = None ,
@@ -133,6 +136,18 @@ def create_multianimaltraining_dataset(
133136
134137 numdigits: int, optional
135138
139+ crop_size: tuple of int, optional
140+ Dimensions (width, height) of the crops for data augmentation.
141+ Default is 400x400.
142+
143+ crop_sampling: str, optional
144+ Crop centers sampling method. Must be either:
145+ "uniform" (randomly over the image),
146+ "keypoints" (randomly over the annotated keypoints),
147+ "density" (weighing preferentially dense regions of keypoints),
148+ or "hybrid" (alternating randomly between "uniform" and "density").
149+ Default is "hybrid".
150+
136151 paf_graph: list of lists, optional (default=None)
137152 If not None, overwrite the default complete graph. This is useful for advanced users who
138153 already know a good graph, or simply want to use a specific one. Note that, in that case,
@@ -155,6 +170,12 @@ def create_multianimaltraining_dataset(
155170 >>> deeplabcut.create_multianimaltraining_dataset(r'C:\\ Users\\ Ulf\\ looming-task\\ config.yaml',Shuffles=[3,17,5])
156171 --------
157172 """
173+ if len (crop_size ) != 2 or not all (isinstance (v , int ) for v in crop_size ):
174+ raise ValueError ("Crop size must be a tuple of two integers (width, height)." )
175+
176+ if crop_sampling not in ("uniform" , "keypoints" , "density" , "hybrid" ):
177+ raise ValueError (f"Invalid sampling { crop_sampling } . Must be "
178+ f"either 'uniform', 'keypoints', 'density', or 'hybrid." )
158179
159180 # Loading metadata from config file:
160181 cfg = auxiliaryfunctions .read_config (config )
@@ -170,15 +191,6 @@ def create_multianimaltraining_dataset(
170191 return
171192 Data = Data [scorer ]
172193
173- def strip_cropped_image_name (path ):
174- # utility function to split different crops from same image into either train or test!
175- head , filename = os .path .split (path )
176- if cfg ["croppedtraining" ]:
177- filename = filename .split ("c" )[0 ]
178- return os .path .join (head , filename )
179-
180- img_names = Data .index .map (strip_cropped_image_name ).unique ()
181-
182194 if net_type is None : # loading & linking pretrained models
183195 net_type = cfg .get ("default_net_type" , "dlcrnet_ms5" )
184196 elif not any (net in net_type for net in ("resnet" , "eff" , "dlc" , "mob" )):
@@ -236,19 +248,12 @@ def strip_cropped_image_name(path):
236248 if trainIndices is None and testIndices is None :
237249 splits = []
238250 for shuffle in Shuffles : # Creating shuffles starting from 1
239- for trainFraction in cfg ["TrainingFraction" ]:
240- train_inds_temp , test_inds_temp = SplitTrials (
241- range (len (img_names )), trainFraction
251+ for train_frac in cfg ["TrainingFraction" ]:
252+ train_inds , test_inds = SplitTrials (
253+ range (len (Data )), train_frac
242254 )
243- # Map back to the original indices.
244- temp = [re .escape (name ) for i , name in enumerate (img_names )
245- if i in test_inds_temp ]
246- mask = Data .index .str .contains ("|" .join (temp ))
247- testIndices = np .flatnonzero (mask )
248- trainIndices = np .flatnonzero (~ mask )
249-
250255 splits .append (
251- (trainFraction , shuffle , (trainIndices , testIndices ))
256+ (train_frac , shuffle , (train_inds , test_inds ))
252257 )
253258 else :
254259 if len (trainIndices ) != len (testIndices ) != len (Shuffles ):
@@ -265,6 +270,12 @@ def strip_cropped_image_name(path):
265270 print (
266271 f"You passed a split with the following fraction: { int (100 * trainFraction )} %"
267272 )
273+ # Now that the training fraction is guaranteed to be correct,
274+ # the values added to pad the indices are removed.
275+ train_inds = np .asarray (train_inds )
276+ train_inds = train_inds [train_inds != - 1 ]
277+ test_inds = np .asarray (test_inds )
278+ test_inds = test_inds [test_inds != - 1 ]
268279 splits .append (
269280 (trainFraction , Shuffles [shuffle ], (train_inds , test_inds ))
270281 )
@@ -387,6 +398,8 @@ def strip_cropped_image_name(path):
387398 "num_idchannel" : len (cfg ["individuals" ])
388399 if cfg .get ("identity" , False )
389400 else 0 ,
401+ "crop_size" : list (crop_size ),
402+ "crop_sampling" : crop_sampling ,
390403 }
391404
392405 trainingdata = MakeTrain_pose_yaml (
@@ -441,3 +454,115 @@ def strip_cropped_image_name(path):
441454 )
442455 else :
443456 pass
457+
458+
def convert_cropped_to_standard_dataset(
    config_path,
    recreate_datasets=True,
    delete_crops=True,
    back_up=True,
):
    """Convert a project annotated on cropped images back to a standard dataset.

    Restores the original (uncropped) video sets in the project config,
    optionally deletes the cropped labeled-data folders, and recreates the
    training datasets so that train/test splits of the cropped images are
    mapped back onto the corresponding original images.

    Parameters
    ----------
    config_path : str
        Full path of the project's config.yaml file.

    recreate_datasets : bool, optional (default=True)
        If True, re-merge the annotation data and recreate the training
        datasets (shuffles, net type and PAF graph are recovered from the
        existing Documentation pickles and pose_cfg.yaml files).

    delete_crops : bool, optional (default=True)
        If True, remove the "_cropped" labeled-data folders from disk.

    back_up : bool, optional (default=True)
        If True, copy the whole project to "<project_path>_bak" first.

    Raises
    ------
    FileNotFoundError
        If datasets are recreated but no pose_cfg.yaml can be located
        under the project's "dlc-models" folder.
    """
    import pandas as pd
    import pickle
    import shutil
    from deeplabcut.generate_training_dataset import trainingsetmanipulation
    from deeplabcut.utils import read_plainconfig, write_config

    cfg = auxiliaryfunctions.read_config(config_path)
    # Pop with a default so projects that never had these keys are treated
    # as "not cropped" instead of raising KeyError before the guard below.
    videos_orig = cfg.pop("video_sets_original", None)
    is_cropped = cfg.pop("croppedtraining", None)
    if videos_orig is None or not is_cropped:
        print("Labeled data do not appear to be cropped. "
              "Project will remain unchanged...")
        return

    project_path = cfg["project_path"]

    if back_up:
        print("Backing up project...")
        shutil.copytree(project_path, project_path + "_bak", symlinks=True)

    if delete_crops:
        print("Deleting crops...")
        data_path = os.path.join(project_path, "labeled-data")
        for video in cfg["video_sets"]:
            _, filename, _ = trainingsetmanipulation._robust_path_split(video)
            if "_cropped" in video:  # One can never be too safe...
                shutil.rmtree(os.path.join(data_path, filename), ignore_errors=True)

    # Restore the original videos and persist the cleaned config.
    cfg["video_sets"] = videos_orig
    write_config(config_path, cfg)

    if not recreate_datasets:
        return

    datasets_folder = os.path.join(
        project_path, auxiliaryfunctions.GetTrainingSetFolder(cfg),
    )
    df_old = pd.read_hdf(
        os.path.join(datasets_folder, "CollectedData_" + cfg["scorer"] + ".h5"),
    )

    def strip_cropped_image_name(path):
        # Map a cropped image path back to its original image path:
        # drop the "_cropped" folder suffix and the trailing "c<idx>"
        # crop marker appended to the file stem.
        head, filename = os.path.split(path)
        head = head.replace("_cropped", "")
        # os.path.splitext is robust to stems containing extra dots,
        # unlike filename.split(".") which would fail to unpack.
        file, ext = os.path.splitext(filename)
        file = file.split("c")[0]
        return os.path.join(head, file + ext)

    img_names_old = np.asarray(
        [strip_cropped_image_name(img) for img in df_old.index.to_list()]
    )
    df = merge_annotateddatasets(cfg, datasets_folder, False)
    img_names = df.index.to_numpy()
    train_idx = []
    test_idx = []
    pickle_files = []
    for filename in os.listdir(datasets_folder):
        if filename.endswith("pickle"):
            pickle_file = os.path.join(datasets_folder, filename)
            pickle_files.append(pickle_file)
            if filename.startswith("Docu"):
                # Documentation pickles store the old train/test indices
                # (into the cropped frames); remap them onto the merged,
                # uncropped annotation index.
                with open(pickle_file, "rb") as f:
                    _, train_inds, test_inds, train_frac = pickle.load(f)
                train_inds_temp = np.flatnonzero(
                    np.isin(img_names, img_names_old[train_inds])
                )
                test_inds_temp = np.flatnonzero(
                    np.isin(img_names, img_names_old[test_inds])
                )
                train_inds, test_inds = pad_train_test_indices(
                    train_inds_temp, test_inds_temp, train_frac
                )
                train_idx.append(train_inds)
                test_idx.append(test_inds)

    # Search a pose_config.yaml file to parse missing information
    pose_config_path = ""
    for dirpath, _, filenames in os.walk(
        os.path.join(project_path, "dlc-models")
    ):
        for file in filenames:
            if file.endswith("pose_cfg.yaml"):
                pose_config_path = os.path.join(dirpath, file)
                break
    if not pose_config_path:
        # Fail early with a clear message rather than letting
        # read_plainconfig("") raise an opaque FileNotFoundError.
        raise FileNotFoundError(
            "No pose_cfg.yaml found under 'dlc-models'; cannot infer the "
            "net type and PAF graph needed to recreate the datasets."
        )
    pose_cfg = read_plainconfig(pose_config_path)
    net_type = pose_cfg["net_type"]
    if net_type == "resnet_50" and pose_cfg.get("multi_stage", False):
        net_type = "dlcrnet_ms5"

    # Clean the training-datasets folder prior to recreating the data pickles
    shuffle_inds = set()
    for file in pickle_files:
        os.remove(file)
        shuffle_inds.add(int(re.findall(r"shuffle(\d+)", file)[0]))
    create_multianimaltraining_dataset(
        config_path,
        trainIndices=train_idx,
        testIndices=test_idx,
        Shuffles=sorted(shuffle_inds),
        net_type=net_type,
        paf_graph=pose_cfg["partaffinityfield_graph"],
        crop_size=pose_cfg.get("crop_size", [400, 400]),
        crop_sampling=pose_cfg.get("crop_sampling", "hybrid"),
    )
0 commit comments