3535from torchvision import transforms
3636from tqdm .auto import tqdm
3737from transformers import CLIPTextModel , CLIPTokenizer
38+
3839from precision import revision , torch_dtype
40+ from send import send , get_now
3941
4042
4143# Our original code in docker-diffusers-api:
@@ -119,8 +121,8 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs):
119121 args = argparse .Namespace (** params )
120122
121123 print (args )
122- main (args , pipeline )
123- return { "done" : True }
124+ result = main (args , pipeline )
125+ return result
124126
125127
126128# What follows is mostly the original train_dreambooth.py
@@ -571,6 +573,10 @@ def collate_fn(examples):
571573 progress_bar .set_description ("Steps" )
572574 global_step = 0
573575
576+ # DDA
577+ send ("training" , "start" , {}, True )
578+ training_start = get_now ()
579+
574580 for epoch in range (args .num_train_epochs ):
575581 unet .train ()
576582 if args .train_text_encoder :
@@ -657,6 +663,12 @@ def collate_fn(examples):
657663
658664 accelerator .wait_for_everyone ()
659665
666+ # DDA
667+ send ("training" , "done" )
668+ training_total = get_now () - training_start
669+ upload_start = 0
670+ upload_total = 0
671+
660672 # Create the pipeline using using the trained modules and save it.
661673 if accelerator .is_main_process :
662674 pipeline = StableDiffusionPipeline .from_pretrained (
@@ -669,6 +681,10 @@ def collate_fn(examples):
669681 pipeline .save_pretrained (args .output_dir )
670682
671683 if args .push_to_hub :
684+ # DDA
685+ send ("uploading" , "start" , {}, True )
686+ upload_start = get_now ()
687+
672688 repo .push_to_hub (
673689 commit_message = "End of training" ,
674690 # DDA need to think about this, quite nice to not block, then could
@@ -678,4 +694,14 @@ def collate_fn(examples):
678694 auto_lfs_prune = True ,
679695 )
680696
697+ # DDA
698+ send ("uploading" , "done" )
699+ upload_total = get_now () - upload_start
700+
681701 accelerator .end_training ()
702+
703+ # DDA
704+ return {
705+ "done" : True ,
706+ "$timings" : {"training" : training_total , "upload" : upload_total },
707+ }
0 commit comments