Skip to content

Commit 9927fa6

Browse files
committed
sync some work (next: fix pytorch conflict)
1 parent ae4cdaa commit 9927fa6

File tree

4 files changed

+610
-23
lines changed

4 files changed

+610
-23
lines changed

Dockerfile

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,24 @@ FROM base as output
5050
RUN mkdir /api
5151
WORKDIR /api
5252

53+
## XXXX playing around a lot.
54+
# pip installs pytorch 1.13 and uninstalls 1.12 (needed by xformers)
55+
# recomment conda update; didn't help. need to solve above issue.
56+
57+
RUN conda update -n base -c defaults conda
5358
# We need python 3.9 or 3.10 for xformers
5459
# Yes, we install pytorch twice... will switch base image in future
55-
# RUN conda update -n base -c defaults conda
5660
RUN conda create -n xformers python=3.10
5761
SHELL ["/opt/conda/bin/conda", "run", "--no-capture-output", "-n", "xformers", "/bin/bash", "-c"]
5862
RUN python --version
5963
RUN conda install -c pytorch -c conda-forge cudatoolkit=11.6 pytorch=1.12.1
6064
RUN conda install xformers -c xformers/label/dev
6165

6266
# Install python packages
63-
RUN pip3 install --upgrade pip
67+
# RUN pip3 install --upgrade pip
68+
RUN https_proxy="" REQUESTS_CA_BUNDLE="" conda install pip
6469
ADD requirements.txt requirements.txt
65-
RUN pip3 install -r requirements.txt
70+
RUN pip install -r requirements.txt
6671

6772
# Required to build flash attention
6873
# Turing: 7.5 (RTX 20s, Quadro), Ampere: 8.0 (A100), 8.6 (RTX 30s)
@@ -119,7 +124,13 @@ ARG USE_PATCHMATCH=0
119124
RUN if [ "$USE_PATCHMATCH" = "1" ] ; then apt-get install -yqq python3-opencv ; fi
120125
COPY --from=patchmatch /tmp/PyPatchMatch PyPatchMatch
121126

127+
ARG USE_DREAMBOOTH=1
128+
RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then \
129+
pip install -r diffusers/examples/dreambooth/requirements.txt ; \
130+
fi
131+
122132
# Add your custom app code, init() and inference()
133+
ADD train_dreambooth.py .
123134
ADD send.py .
124135
ADD app.py .
125136

app.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import skimage.measure
2323
from PyPatchMatch import patch_match
2424
from getScheduler import getScheduler, SCHEDULERS
25+
from train_dreambooth import TrainDreamBooth
2526
import re
2627

2728
MODEL_ID = os.environ.get("MODEL_ID")
@@ -116,6 +117,10 @@ def truncateInputs(inputs: dict):
116117
for item in ["init_image", "mask_image", "image"]:
117118
if item in modelInputs:
118119
modelInputs[item] = modelInputs[item][0:6] + "..."
120+
if "instance_images" in modelInputs:
121+
modelInputs["instance_images"] = list(
122+
map(lambda str: str[0:6] + "...", modelInputs["instance_images"])
123+
)
119124
return clone
120125

121126

@@ -209,6 +214,14 @@ def inference(all_inputs: dict) -> dict:
209214
model_inputs.get("mask_image"), "mask_image"
210215
)
211216

217+
if "instance_images" in model_inputs:
218+
model_inputs["instance_images"] = list(
219+
map(
220+
lambda str: decodeBase64Image(str, "instance_image"),
221+
model_inputs["instance_images"],
222+
)
223+
)
224+
212225
seed = model_inputs.get("seed", None)
213226
if seed == None:
214227
generator = torch.Generator(device="cuda")
@@ -259,6 +272,14 @@ def inference(all_inputs: dict) -> dict:
259272
# with autocast("cuda"):
260273
# image = pipeline(**model_inputs).images[0]
261274

275+
if call_inputs.get("train", None) == "dreambooth":
276+
result = TrainDreamBooth(pipeline, model_inputs)
277+
send("inference", "done", {"startRequestId": startRequestId})
278+
inferenceTime = get_now() - inferenceStart
279+
timings = {"init": initTime, "inference": inferenceTime}
280+
result.update({"timings": timings})
281+
return result
282+
262283
with torch.inference_mode():
263284
# autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
264285
# still broken in 0.5.1

test.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def main(tests_to_run, banana, extraCallInputs):
266266
name = "dreambooth"
267267
inputs = {
268268
"modelInputs": {
269-
"prompt": "girl with a pearl earing standing in a big room",
269+
"instance_prompt": "a photo of sks dog",
270270
"instance_images": list(
271271
map(b64encode_file, list(Path("tests/fixtures/dreambooth").iterdir()))
272272
),
@@ -276,29 +276,12 @@ def main(tests_to_run, banana, extraCallInputs):
276276
"MODEL_ID": "runwayml/stable-diffusion-v1-5",
277277
"PIPELINE": "StableDiffusionPipeline",
278278
"SCHEDULER": "LMSDiscreteScheduler",
279+
"train": "dreambooth",
279280
},
280281
}
281-
print(inputs)
282282

283-
# print(json.dumps(inputs, indent=4))
284-
exit
285-
286-
"""
287283
print("Running test: " + name)
288284
response = requests.post("http://localhost:8000/", json=inputs)
289285
result = response.json()
290286

291-
if result.get("images_base64", None) == "None":
292-
print(json.dumps(result, indent=4))
293-
print()
294-
exit
295-
296-
images_base64 = result.get("images_base64", None)
297-
if images_base64:
298-
for idx, image_byte_string in enumerate(images_base64):
299-
decode_and_save(image_byte_string, f"{name}_{idx}")
300-
else:
301-
decode_and_save(result["image_base64"], name)
302-
303-
print()
304-
"""
287+
print(result)

0 commit comments

Comments
 (0)