Skip to content

Commit d2d5ad1

Browse files
committed
refactor loadModel. use model_inputs to check {init,mask}_image
1 parent 9c07a6b commit d2d5ad1

4 files changed

Lines changed: 31 additions & 53 deletions

File tree

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ ADD DOWNLOAD_VARS.py .
2525

2626
# Add your model weight files
2727
# (in this case we have a python script)
28+
ADD loadModel.py .
2829
ADD download.py .
2930
RUN python3 download.py
3031

app.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
)
1010
import base64
1111
from io import BytesIO
12-
import os
1312
import PIL
1413
import json
14+
from loadModel import loadModel
1515

1616
from APP_VARS import MODEL_ID
1717

@@ -54,8 +54,6 @@ def init():
5454
global schedulers
5555
global dummy_safety_checker
5656

57-
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
58-
5957
schedulers = {
6058
"LMS": LMSDiscreteScheduler(
6159
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -77,15 +75,7 @@ def init():
7775
last_model_id = None
7876
return
7977

80-
print("Loading model " + MODEL_ID)
81-
82-
model = _pipelines.StableDiffusionPipeline.from_pretrained(
83-
MODEL_ID,
84-
revision="fp16",
85-
torch_dtype=torch.float16,
86-
use_auth_token=HF_AUTH_TOKEN,
87-
).to("cuda")
88-
78+
model = loadModel(MODEL_ID)
8979
pipelines = createPipelinesFromModel(MODEL_ID)
9080

9181

@@ -120,17 +110,9 @@ def inference(all_inputs: dict) -> dict:
120110
return {"$error": "UPGRADE CLIENT - no model_inputs specified"}
121111

122112
if MODEL_ID == "ALL":
123-
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
124113
model_id = call_inputs.get("MODEL_ID")
125114
if last_model_id != model_id:
126-
print("Loading model " + model_id)
127-
model = _pipelines.StableDiffusionPipeline.from_pretrained(
128-
model_id,
129-
revision="fp16",
130-
torch_dtype=torch.float16,
131-
use_auth_token=HF_AUTH_TOKEN,
132-
).to("cuda")
133-
115+
model = loadModel(model_id)
134116
pipelines = createPipelinesFromModel(model_id)
135117
last_model_id = model_id
136118

@@ -155,18 +137,11 @@ def inference(all_inputs: dict) -> dict:
155137
# seed = model_inputs.get("seed", None)
156138
# strength = model_inputs.get("strength", 0.75)
157139

158-
if call_inputs.get("PIPELINE") in [
159-
"StableDiffusionImg2ImgPipeline",
160-
"StableDiffusionInpaintPipeline",
161-
]:
162-
model_inputs.update(
163-
{"init_image": decodeBase64Image(model_inputs.get("init_image"))}
164-
)
140+
if "init_image" in model_inputs:
141+
model_inputs["init_image"] = decodeBase64Image(model_inputs.get("init_image"))
165142

166-
if all_inputs.get("PIPELINE") == "StableDiffusionInpaintPipeline":
167-
model_inputs.update(
168-
{"mask_image": decodeBase64Image(model_inputs.get("mask_image"))}
169-
)
143+
if "mask_image" in model_inputs:
144+
model_inputs["mask_image"] = decodeBase64Image(model_inputs.get("mask_image"))
170145

171146
seed = model_inputs.get("seed", None)
172147
if seed == None:

download.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,18 @@
11
# In this file, we define download_model
22
# It runs during container build time to get model weights built into the container
33

4-
from diffusers import StableDiffusionPipeline
5-
import torch
6-
import os
74
from DOWNLOAD_VARS import MODEL_ID
8-
9-
MODEL_IDS = ["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion"]
5+
from loadModel import loadModel, MODEL_IDS
106

117

128
def download_model():
139
# do a dry run of loading the huggingface model, which will download weights at build time
14-
# Set auth token which is required to download stable diffusion model weights
15-
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
16-
17-
# Bad for production serverless, great for local dev & preview deploys
10+
# For local dev & preview deploys, download all the models (terrible for serverless deploys)
1811
if MODEL_ID == "ALL":
1912
for MODEL_I in MODEL_IDS:
20-
StableDiffusionPipeline.from_pretrained(
21-
MODEL_I,
22-
revision="fp16",
23-
torch_dtype=torch.float16,
24-
use_auth_token=HF_AUTH_TOKEN,
25-
)
13+
loadModel(MODEL_I, False)
2614
else:
27-
model = StableDiffusionPipeline.from_pretrained(
28-
MODEL_ID,
29-
revision="fp16",
30-
torch_dtype=torch.float16,
31-
use_auth_token=HF_AUTH_TOKEN,
32-
)
15+
loadModel(MODEL_ID, False)
3316

3417

3518
if __name__ == "__main__":

loadModel.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
import torch
import os
from diffusers import StableDiffusionPipeline

# Token for gated Hugging Face repos; read once at import time.
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
MODEL_IDS = ["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion"]


def loadModel(model_id: str, load=True):
    """Fetch the Stable Diffusion pipeline weights for *model_id*.

    With ``load=True`` (the default) the pipeline is moved to CUDA and
    returned, ready for inference.  With ``load=False`` this is a
    download-only dry run (used at container build time) and ``None``
    is returned.
    """
    action = "Loading" if load else "Downloading"
    print(action + " model: " + model_id)

    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=HF_AUTH_TOKEN,
    )

    if load:
        return pipeline.to("cuda")
    return None

0 commit comments

Comments
 (0)