|
22 | 22 | import skimage.measure |
23 | 23 | from PyPatchMatch import patch_match |
24 | 24 | from getScheduler import getScheduler, SCHEDULERS |
| 25 | +from train_dreambooth import TrainDreamBooth |
25 | 26 | import re |
26 | 27 |
|
27 | 28 | MODEL_ID = os.environ.get("MODEL_ID") |
@@ -116,6 +117,10 @@ def truncateInputs(inputs: dict): |
116 | 117 | for item in ["init_image", "mask_image", "image"]: |
117 | 118 | if item in modelInputs: |
118 | 119 | modelInputs[item] = modelInputs[item][0:6] + "..." |
| 120 | + if "instance_images" in modelInputs: |
| 121 | + modelInputs["instance_images"] = list( |
| 122 | + map(lambda str: str[0:6] + "...", modelInputs["instance_images"]) |
| 123 | + ) |
119 | 124 | return clone |
120 | 125 |
|
121 | 126 |
|
@@ -209,6 +214,14 @@ def inference(all_inputs: dict) -> dict: |
209 | 214 | model_inputs.get("mask_image"), "mask_image" |
210 | 215 | ) |
211 | 216 |
|
| 217 | + if "instance_images" in model_inputs: |
| 218 | + model_inputs["instance_images"] = list( |
| 219 | + map( |
| 220 | + lambda str: decodeBase64Image(str, "instance_image"), |
| 221 | + model_inputs["instance_images"], |
| 222 | + ) |
| 223 | + ) |
| 224 | + |
212 | 225 | seed = model_inputs.get("seed", None) |
213 | 226 | if seed == None: |
214 | 227 | generator = torch.Generator(device="cuda") |
@@ -259,6 +272,14 @@ def inference(all_inputs: dict) -> dict: |
259 | 272 | # with autocast("cuda"): |
260 | 273 | # image = pipeline(**model_inputs).images[0] |
261 | 274 |
|
| 275 | + if call_inputs.get("train", None) == "dreambooth": |
| 276 | + result = TrainDreamBooth(pipeline, model_inputs) |
| 277 | + send("inference", "done", {"startRequestId": startRequestId}) |
| 278 | + inferenceTime = get_now() - inferenceStart |
| 279 | + timings = {"init": initTime, "inference": inferenceTime} |
| 280 | + result.update({"timings": timings}) |
| 281 | + return result |
| 282 | + |
262 | 283 | with torch.inference_mode(): |
263 | 284 | # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1 |
264 | 285 | # still broken in 0.5.1 |
|
0 commit comments