|
41 | 41 | # TODO, in download.py download the pretrained tokenizers, encoders, etc. |
42 | 42 | # we might have them already. |
43 | 43 |
|
44 | | - |
45 | | -def TrainDreamBooth(pipeline, model_inputs): |
| 44 | +HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN") |
| 45 | + |
| 46 | + |
| 47 | +def TrainDreamBooth(model_id: str, pipeline, model_inputs): |
| 48 | + # required inputs: instance_images instance_prompt |
| 49 | + |
| 50 | +    # TODO: don't save at all... we're just getting it working |
| 51 | +    # if it's a hassle, in the interim, at least save to a unique dir |
| 52 | + if not os.path.exists("instance_data_dir"): |
| 53 | + os.mkdir("instance_data_dir") |
| 54 | + for i, image in enumerate(model_inputs["instance_images"]): |
| 55 | + image.save("instance_data_dir/image" + str(i) + ".png") |
| 56 | + del model_inputs["instance_images"] |
| 57 | + |
| 58 | + # TODO allow pass through of seed |
| 59 | + del model_inputs["generator"] |
| 60 | + |
| 61 | + # TODO in app.py |
| 62 | + torch.set_grad_enabled(True) |
| 63 | + |
| 64 | + import subprocess |
| 65 | + |
| 66 | + subprocess.run(["ls", "-l", "instance_data_dir"]) |
| 67 | + |
| 68 | + params = { |
| 69 | + # Defaults |
| 70 | + "pretrained_model_name_or_path": model_id, # DDA, TODO |
| 71 | + "revision": "fp16", # DDA, was: None |
| 72 | + "tokenizer_name": None, |
| 73 | + "instance_data_dir": "instance_data_dir", # DDA TODO |
| 74 | + "class_data_dir": None, |
| 75 | + # instance_prompt |
| 76 | + "class_prompt": None, |
| 77 | + "with_prior_preservation": False, |
| 78 | + "prior_loss_weight": 1.0, |
| 79 | + "num_class_images": 100, |
| 80 | + "output_dir": "text-inversion-model", |
| 81 | + "seed": None, |
| 82 | + "resolution": 512, |
| 83 | + "center_crop": None, |
| 84 | + "train_text_encoder": None, |
| 85 | + "train_batch_size": 1, # DDA, was: 4 |
| 86 | + "sample_batch_size": 1, # DDA, was: 4, |
| 87 | + "num_train_epochs": 1, |
| 88 | + "max_train_steps": None, |
| 89 | + "gradient_accumulation_steps": 1, |
| 90 | + "gradient_checkpointing": None, |
| 91 | + "learning_rate": 5e-6, |
| 92 | + "scale_lr": False, |
| 93 | + "lr_scheduler": "constant", |
| 94 | + "lr_warmup_steps": 0, # DDA, was: 500, |
| 95 | + "use_8bit_adam": True, # DDA, was: None |
| 96 | + "adam_beta1": 0.9, |
| 97 | + "adam_beta2": 0.999, |
| 98 | + "adam_weight_decay": 1e-2, |
| 99 | + "adam_epsilon": 1e-08, |
| 100 | + "max_grad_norm": 1.0, |
| 101 | + "push_to_hub": None, |
| 102 | + "hub_token": HF_AUTH_TOKEN, |
| 103 | + "hub_model_id": None, |
| 104 | + "logging_dir": "logs", |
| 105 | + "mixed_precision": "fp16", # DDA, was: "no" |
| 106 | + "local_rank": -1, |
| 107 | + } |
| 108 | + |
| 109 | + params.update(model_inputs) |
46 | 110 | print(model_inputs) |
47 | | - dict = {} |
48 | | - args = argparse.Namespace(**dict) |
| 111 | + |
| 112 | + # params.update( |
| 113 | + # { |
| 114 | + # "pipeline": pipeline, |
| 115 | + # } |
| 116 | + # ) |
| 117 | + args = argparse.Namespace(**params) |
| 118 | + |
| 119 | + print(args) |
| 120 | + main(args) |
49 | 121 | return {"done": True} |
50 | 122 |
|
51 | 123 |
|
@@ -208,12 +280,15 @@ def main(args): |
208 | 280 | torch_dtype = ( |
209 | 281 | torch.float16 if accelerator.device.type == "cuda" else torch.float32 |
210 | 282 | ) |
211 | | - pipeline = StableDiffusionPipeline.from_pretrained( |
212 | | - args.pretrained_model_name_or_path, |
213 | | - torch_dtype=torch_dtype, |
214 | | - safety_checker=None, |
215 | | - revision=args.revision, |
216 | | - ) |
| 283 | + # DDA |
| 284 | + pipeline = args.pipeline |
| 285 | + pipeline.safety_checker = None |
| 286 | + # pipeline = StableDiffusionPipeline.from_pretrained( |
| 287 | + # args.pretrained_model_name_or_path, |
| 288 | + # torch_dtype=torch_dtype, |
| 289 | + # safety_checker=None, |
| 290 | + # revision=args.revision, |
| 291 | + # ) |
217 | 292 | pipeline.set_progress_bar_config(disable=True) |
218 | 293 |
|
219 | 294 | num_new_images = args.num_class_images - cur_class_images |
@@ -270,29 +345,34 @@ def main(args): |
270 | 345 | tokenizer = CLIPTokenizer.from_pretrained( |
271 | 346 | args.tokenizer_name, |
272 | 347 | revision=args.revision, |
| 348 | + use_auth_token=args.hub_token, # DDA |
273 | 349 | ) |
274 | 350 | elif args.pretrained_model_name_or_path: |
275 | 351 | tokenizer = CLIPTokenizer.from_pretrained( |
276 | 352 | args.pretrained_model_name_or_path, |
277 | 353 | subfolder="tokenizer", |
278 | 354 | revision=args.revision, |
| 355 | + use_auth_token=args.hub_token, # DDA |
279 | 356 | ) |
280 | 357 |
|
281 | 358 | # Load models and create wrapper for stable diffusion |
282 | 359 | text_encoder = CLIPTextModel.from_pretrained( |
283 | 360 | args.pretrained_model_name_or_path, |
284 | 361 | subfolder="text_encoder", |
285 | 362 | revision=args.revision, |
| 363 | + use_auth_token=args.hub_token, # DDA |
286 | 364 | ) |
287 | 365 | vae = AutoencoderKL.from_pretrained( |
288 | 366 | args.pretrained_model_name_or_path, |
289 | 367 | subfolder="vae", |
290 | 368 | revision=args.revision, |
| 369 | + use_auth_token=args.hub_token, # DDA |
291 | 370 | ) |
292 | 371 | unet = UNet2DConditionModel.from_pretrained( |
293 | 372 | args.pretrained_model_name_or_path, |
294 | 373 | subfolder="unet", |
295 | 374 | revision=args.revision, |
| 375 | + use_auth_token=args.hub_token, # DDA |
296 | 376 | ) |
297 | 377 |
|
298 | 378 | vae.requires_grad_(False) |
@@ -339,7 +419,9 @@ def main(args): |
339 | 419 | ) |
340 | 420 |
|
341 | 421 | noise_scheduler = DDPMScheduler.from_config( |
342 | | - "CompVis/stable-diffusion-v1-4", subfolder="scheduler" |
| 422 | + "CompVis/stable-diffusion-v1-4", |
| 423 | + subfolder="scheduler", |
| 424 | + use_auth_token=args.hub_token, # DDA |
343 | 425 | ) |
344 | 426 |
|
345 | 427 | train_dataset = DreamBoothDataset( |
|
0 commit comments