Skip to content

Commit b2ac6f8

Browse files
committed
training called correctly but diffusers fp16 bug breaks training
1 parent 39652f8 commit b2ac6f8

File tree

3 files changed

+95
-13
lines changed

3 files changed

+95
-13
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ ARG USE_DREAMBOOTH=1
127127
RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then \
128128
# By specifying the same torch version as conda, it won't download again.
129129
# Without this, it will upgrade torch, break xformers, make bigger image.
130-
pip install -r diffusers/examples/dreambooth/requirements.txt torch==1.12.1 ; \
130+
pip install -r diffusers/examples/dreambooth/requirements.txt bitsandbytes torch==1.12.1 ; \
131131
fi
132132

133133
# Add your custom app code, init() and inference()

app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def inference(all_inputs: dict) -> dict:
273273
# image = pipeline(**model_inputs).images[0]
274274

275275
if call_inputs.get("train", None) == "dreambooth":
276-
result = TrainDreamBooth(pipeline, model_inputs)
276+
result = TrainDreamBooth(model_id, pipeline, model_inputs)
277277
send("inference", "done", {"startRequestId": startRequestId})
278278
inferenceTime = get_now() - inferenceStart
279279
timings = {"init": initTime, "inference": inferenceTime}

train_dreambooth.py

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,83 @@
4141
# TODO, in download.py download the pretrained tokenizers, encoders, etc.
4242
# we might have them already.
4343

44-
45-
def TrainDreamBooth(pipeline, model_inputs):
44+
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
45+
46+
47+
def TrainDreamBooth(model_id: str, pipeline, model_inputs):
48+
# required inputs: instance_images instance_prompt
49+
50+
# TODO, not save at all... we're just getting it working
51+
# if its a hassle, in interim, at least save to unique dir
52+
if not os.path.exists("instance_data_dir"):
53+
os.mkdir("instance_data_dir")
54+
for i, image in enumerate(model_inputs["instance_images"]):
55+
image.save("instance_data_dir/image" + str(i) + ".png")
56+
del model_inputs["instance_images"]
57+
58+
# TODO allow pass through of seed
59+
del model_inputs["generator"]
60+
61+
# TODO in app.py
62+
torch.set_grad_enabled(True)
63+
64+
import subprocess
65+
66+
subprocess.run(["ls", "-l", "instance_data_dir"])
67+
68+
params = {
69+
# Defaults
70+
"pretrained_model_name_or_path": model_id, # DDA, TODO
71+
"revision": "fp16", # DDA, was: None
72+
"tokenizer_name": None,
73+
"instance_data_dir": "instance_data_dir", # DDA TODO
74+
"class_data_dir": None,
75+
# instance_prompt
76+
"class_prompt": None,
77+
"with_prior_preservation": False,
78+
"prior_loss_weight": 1.0,
79+
"num_class_images": 100,
80+
"output_dir": "text-inversion-model",
81+
"seed": None,
82+
"resolution": 512,
83+
"center_crop": None,
84+
"train_text_encoder": None,
85+
"train_batch_size": 1, # DDA, was: 4
86+
"sample_batch_size": 1, # DDA, was: 4,
87+
"num_train_epochs": 1,
88+
"max_train_steps": None,
89+
"gradient_accumulation_steps": 1,
90+
"gradient_checkpointing": None,
91+
"learning_rate": 5e-6,
92+
"scale_lr": False,
93+
"lr_scheduler": "constant",
94+
"lr_warmup_steps": 0, # DDA, was: 500,
95+
"use_8bit_adam": True, # DDA, was: None
96+
"adam_beta1": 0.9,
97+
"adam_beta2": 0.999,
98+
"adam_weight_decay": 1e-2,
99+
"adam_epsilon": 1e-08,
100+
"max_grad_norm": 1.0,
101+
"push_to_hub": None,
102+
"hub_token": HF_AUTH_TOKEN,
103+
"hub_model_id": None,
104+
"logging_dir": "logs",
105+
"mixed_precision": "fp16", # DDA, was: "no"
106+
"local_rank": -1,
107+
}
108+
109+
params.update(model_inputs)
46110
print(model_inputs)
47-
dict = {}
48-
args = argparse.Namespace(**dict)
111+
112+
# params.update(
113+
# {
114+
# "pipeline": pipeline,
115+
# }
116+
# )
117+
args = argparse.Namespace(**params)
118+
119+
print(args)
120+
main(args)
49121
return {"done": True}
50122

51123

@@ -208,12 +280,15 @@ def main(args):
208280
torch_dtype = (
209281
torch.float16 if accelerator.device.type == "cuda" else torch.float32
210282
)
211-
pipeline = StableDiffusionPipeline.from_pretrained(
212-
args.pretrained_model_name_or_path,
213-
torch_dtype=torch_dtype,
214-
safety_checker=None,
215-
revision=args.revision,
216-
)
283+
# DDA
284+
pipeline = args.pipeline
285+
pipeline.safety_checker = None
286+
# pipeline = StableDiffusionPipeline.from_pretrained(
287+
# args.pretrained_model_name_or_path,
288+
# torch_dtype=torch_dtype,
289+
# safety_checker=None,
290+
# revision=args.revision,
291+
# )
217292
pipeline.set_progress_bar_config(disable=True)
218293

219294
num_new_images = args.num_class_images - cur_class_images
@@ -270,29 +345,34 @@ def main(args):
270345
tokenizer = CLIPTokenizer.from_pretrained(
271346
args.tokenizer_name,
272347
revision=args.revision,
348+
use_auth_token=args.hub_token, # DDA
273349
)
274350
elif args.pretrained_model_name_or_path:
275351
tokenizer = CLIPTokenizer.from_pretrained(
276352
args.pretrained_model_name_or_path,
277353
subfolder="tokenizer",
278354
revision=args.revision,
355+
use_auth_token=args.hub_token, # DDA
279356
)
280357

281358
# Load models and create wrapper for stable diffusion
282359
text_encoder = CLIPTextModel.from_pretrained(
283360
args.pretrained_model_name_or_path,
284361
subfolder="text_encoder",
285362
revision=args.revision,
363+
use_auth_token=args.hub_token, # DDA
286364
)
287365
vae = AutoencoderKL.from_pretrained(
288366
args.pretrained_model_name_or_path,
289367
subfolder="vae",
290368
revision=args.revision,
369+
use_auth_token=args.hub_token, # DDA
291370
)
292371
unet = UNet2DConditionModel.from_pretrained(
293372
args.pretrained_model_name_or_path,
294373
subfolder="unet",
295374
revision=args.revision,
375+
use_auth_token=args.hub_token, # DDA
296376
)
297377

298378
vae.requires_grad_(False)
@@ -339,7 +419,9 @@ def main(args):
339419
)
340420

341421
noise_scheduler = DDPMScheduler.from_config(
342-
"CompVis/stable-diffusion-v1-4", subfolder="scheduler"
422+
"CompVis/stable-diffusion-v1-4",
423+
subfolder="scheduler",
424+
use_auth_token=args.hub_token, # DDA
343425
)
344426

345427
train_dataset = DreamBoothDataset(

0 commit comments

Comments (0)