Skip to content

Commit 3e87169

Browse files
committed
$timings improvements
* train_dreambooth will record train and upload times * test.py will show all $timings (not just init and inference)
1 parent 07ba610 commit 3e87169

File tree

3 files changed

+67
-11
lines changed

3 files changed

+67
-11
lines changed

app.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -276,8 +276,9 @@ def inference(all_inputs: dict) -> dict:
276276
result = TrainDreamBooth(model_id, pipeline, model_inputs)
277277
send("inference", "done", {"startRequestId": startRequestId})
278278
inferenceTime = get_now() - inferenceStart
279-
timings = {"init": initTime, "inference": inferenceTime}
280-
result.update({"timings": timings})
279+
timings = result.get("$timings", {})
280+
timings = {"init": initTime, "inference": inferenceTime, **timings}
281+
result.update({"$timings": timings})
281282
return result
282283

283284
with torch.inference_mode():

test.py

Lines changed: 36 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import json
88
import sys
99
import time
10+
import datetime
1011
import argparse
1112
import distutils
1213
from uuid import uuid4
@@ -76,8 +77,27 @@ def runTest(name, banana, extraCallInputs, extraModelInputs):
7677
"startOnly": False,
7778
}
7879
response = requests.post("https://api.banana.dev/start/v4/", json=payload)
79-
8080
result = response.json()
81+
callID = result.get("callID")
82+
83+
if result.get("finished", None) == False:
84+
while result.get("message", None) != "success":
85+
secondsSinceStart = round((time.time() - start) / 1000)
86+
print(str(datetime.datetime.now()) + f": t+{secondsSinceStart}s")
87+
print(json.dumps(result, indent=4))
88+
print
89+
payload = {
90+
"id": str(uuid4()),
91+
"created": int(time.time()),
92+
"longPoll": True,
93+
"apiKey": BANANA_API_KEY,
94+
"callID": callID,
95+
}
96+
response = requests.post(
97+
"https://api.banana.dev/check/v4/", json=payload
98+
)
99+
result = response.json()
100+
81101
modelOutputs = result.get("modelOutputs", None)
82102
if modelOutputs == None:
83103
finish = time.time() - start
@@ -91,13 +111,22 @@ def runTest(name, banana, extraCallInputs, extraModelInputs):
91111

92112
finish = time.time() - start
93113
timings = result.get("$timings")
114+
94115
if timings:
95-
init = timings.get("init") / 1000
96-
inference = timings.get("inference") / 1000
97-
print(
98-
f"Request took {finish:.1f}s ("
99-
+ f"init: {init:.1f}s, inference: {inference:.1f}s)"
100-
)
116+
timings_str = json.dumps(
117+
dict(
118+
map(
119+
lambda item: (
120+
item[0],
121+
f"{item[1]/1000:.1f}s"
122+
if item[1] > 1000
123+
else str(item[1]) + "ms",
124+
),
125+
timings.items(),
126+
)
127+
)
128+
).replace('"', "")[1:-1]
129+
print(f"Request took {finish:.1f}s ({timings_str})")
101130
else:
102131
print(f"Request took {finish:.1f}s")
103132

train_dreambooth.py

Lines changed: 28 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,9 @@
3535
from torchvision import transforms
3636
from tqdm.auto import tqdm
3737
from transformers import CLIPTextModel, CLIPTokenizer
38+
3839
from precision import revision, torch_dtype
40+
from send import send, get_now
3941

4042

4143
# Our original code in docker-diffusers-api:
@@ -119,8 +121,8 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs):
119121
args = argparse.Namespace(**params)
120122

121123
print(args)
122-
main(args, pipeline)
123-
return {"done": True}
124+
result = main(args, pipeline)
125+
return result
124126

125127

126128
# What follows is mostly the original train_dreambooth.py
@@ -571,6 +573,10 @@ def collate_fn(examples):
571573
progress_bar.set_description("Steps")
572574
global_step = 0
573575

576+
# DDA
577+
send("training", "start", {}, True)
578+
training_start = get_now()
579+
574580
for epoch in range(args.num_train_epochs):
575581
unet.train()
576582
if args.train_text_encoder:
@@ -657,6 +663,12 @@ def collate_fn(examples):
657663

658664
accelerator.wait_for_everyone()
659665

666+
# DDA
667+
send("training", "done")
668+
training_total = get_now() - training_start
669+
upload_start = 0
670+
upload_total = 0
671+
660672
# Create the pipeline using using the trained modules and save it.
661673
if accelerator.is_main_process:
662674
pipeline = StableDiffusionPipeline.from_pretrained(
@@ -669,6 +681,10 @@ def collate_fn(examples):
669681
pipeline.save_pretrained(args.output_dir)
670682

671683
if args.push_to_hub:
684+
# DDA
685+
send("uploading", "start", {}, True)
686+
upload_start = get_now()
687+
672688
repo.push_to_hub(
673689
commit_message="End of training",
674690
# DDA need to think about this, quite nice to not block, then could
@@ -678,4 +694,14 @@ def collate_fn(examples):
678694
auto_lfs_prune=True,
679695
)
680696

697+
# DDA
698+
send("uploading", "done")
699+
upload_total = get_now() - upload_start
700+
681701
accelerator.end_training()
702+
703+
# DDA
704+
return {
705+
"done": True,
706+
"$timings": {"training": training_total, "upload": upload_total},
707+
}

0 commit comments

Comments (0)