Skip to content

Commit 690810d

Browse files
committed
init schedulers as needed because it's slow (TODO: faster way?)
1 parent fab6e8c commit 690810d

6 files changed

Lines changed: 99 additions & 54 deletions

File tree

Dockerfile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,14 @@ ADD convert-to-diffusers.py .
9494
RUN python3 convert-to-diffusers.py
9595
# RUN rm -rf checkpoints
9696

97+
# Loading a new scheduler for the first time takes an extra ~800ms,
98+
# so set this to your most common one.
99+
ARG DEFAULT_SCHEDULER="LMSDiscreteScheduler"
100+
ENV DEFAULT_SCHEDULER=${DEFAULT_SCHEDULER}
101+
97102
# Add your model weight files
98103
# (in this case we have a python script)
104+
ADD getScheduler.py .
99105
ADD loadModel.py .
100106
ADD download.py .
101107
RUN python3 download.py

app.py

Lines changed: 4 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from torch import autocast
55
from diffusers import (
66
pipelines as _pipelines,
7-
schedulers as _schedulers,
87
LMSDiscreteScheduler,
98
DDIMScheduler,
109
PNDMScheduler,
@@ -22,6 +21,7 @@
2221
import skimage
2322
import skimage.measure
2423
from PyPatchMatch import patch_match
24+
from getScheduler import getScheduler, SCHEDULERS
2525
import re
2626

2727
MODEL_ID = os.environ.get("MODEL_ID")
@@ -35,14 +35,6 @@
3535
"StableDiffusionInpaintPipelineLegacy",
3636
]
3737

38-
SCHEDULERS = [
39-
"LMSDiscreteScheduler",
40-
"DDIMScheduler",
41-
"PNDMScheduler",
42-
"EulerAncestralDiscreteScheduler",
43-
"EulerDiscreteScheduler",
44-
]
45-
4638
torch.set_grad_enabled(False)
4739

4840

@@ -95,30 +87,6 @@ def init():
9587
True,
9688
)
9789

98-
schedulers = {}
99-
"""
100-
# This was a nice idea but until we have default init vars for all schedulers
101-
# via from_config(), it's a no go.
102-
isScheduler = re.compile(r".+Scheduler$")
103-
for key, val in _schedulers.__dict__.items():
104-
if isScheduler.match(key):
105-
schedulers.update(
106-
{
107-
key: val.from_config(
108-
MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN
109-
)
110-
}
111-
)
112-
"""
113-
for scheduler_name in SCHEDULERS:
114-
schedulers.update(
115-
{
116-
scheduler_name: getattr(_schedulers, scheduler_name).from_config(
117-
MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN
118-
),
119-
}
120-
)
121-
12290
dummy_safety_checker = DummySafetyChecker()
12391

12492
if MODEL_ID == "ALL":
@@ -200,30 +168,14 @@ def inference(all_inputs: dict) -> dict:
200168
else:
201169
pipeline = model
202170

203-
# Check for use of all names
204-
scheduler_name = call_inputs.get("SCHEDULER", None)
205-
deprecated_map = {
206-
"LMS": "LMSDiscreteScheduler",
207-
"DDIM": "DDIMScheduler",
208-
"PNDM": "PNDMScheduler",
209-
}
210-
scheduler_renamed = deprecated_map.get(scheduler_name, None)
211-
if scheduler_renamed != None:
212-
print(
213-
f'[Deprecation Warning]: Scheduler "{scheduler_name}" is now '
214-
f'called "{scheduler_renamed}". Please rename as this will '
215-
f"stop working in a future release."
216-
)
217-
scheduler_name = scheduler_renamed
218-
219-
pipeline.scheduler = schedulers.get(scheduler_name, None)
171+
pipeline.scheduler = getScheduler(MODEL_ID, call_inputs.get("SCHEDULER", None))
220172
if pipeline.scheduler == None:
221173
return {
222174
"$error": {
223175
"code": "INVALID_SCHEDULER",
224176
"message": "",
225177
"requeted": call_inputs.get("SCHEDULER", None),
226-
"available": ", ".join(schedulers.keys()),
178+
"available": ", ".join(SCHEDULERS),
227179
}
228180
}
229181

@@ -286,10 +238,8 @@ def inference(all_inputs: dict) -> dict:
286238
x_m_e_a = call_inputs.get("xformers_memory_efficient_attention", None)
287239
if x_m_e_a != last_xformers_memory_efficient_attention:
288240
last_xformers_memory_efficient_attention = x_m_e_a
289-
if x_m_e_a == None:
241+
if x_m_e_a == None or x_m_e_a == True:
290242
pipeline.enable_xformers_memory_efficient_attention() # default on
291-
elif x_m_e_a == True:
292-
pipeline.enable_xformers_memory_efficient_attention()
293243
elif x_m_e_a == False:
294244
pipeline.disable_xformers_memory_efficient_attention()
295245
else:

download.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import os
55
from loadModel import loadModel, MODEL_IDS
6+
from getScheduler import getScheduler, SCHEDULERS, DEFAULT_SCHEDULER
67

78
MODEL_ID = os.environ.get("MODEL_ID")
89

@@ -16,6 +17,8 @@ def download_model():
1617
else:
1718
loadModel(MODEL_ID, False)
1819

20+
getScheduler(MODEL_ID, DEFAULT_SCHEDULER)
21+
1922

2023
if __name__ == "__main__":
2124
download_model()

getScheduler.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import os
import time

import torch
from diffusers import schedulers as _schedulers

# Token for private Hugging Face model repos (may be None for public models).
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
# Scheduler pre-initialized at image build time (see the Dockerfile's
# DEFAULT_SCHEDULER ARG/ENV); loading any other scheduler later is lazy.
DEFAULT_SCHEDULER = os.getenv("DEFAULT_SCHEDULER")

# Scheduler class names we know how to initialize via from_config().
SCHEDULERS = [
    "LMSDiscreteScheduler",
    "DDIMScheduler",
    "PNDMScheduler",
    "EulerAncestralDiscreteScheduler",
    "EulerDiscreteScheduler",
]

# NOTE(history): auto-discovering every "*Scheduler" attribute of
# diffusers.schedulers and eagerly calling from_config() on each was tried
# and abandoned: not all schedulers have default init vars for from_config(),
# and loading every scheduler up front is slow. Instead schedulers are
# initialized on demand and cached (see getScheduler below).
33+
34+
def initScheduler(MODEL_ID: str, scheduler_id: str):
    """Load a scheduler instance for the given model and report timing.

    Fetches the scheduler configuration from the model repo's "scheduler"
    subfolder. Initialization latency is printed because first-time loads
    are slow (the reason schedulers are cached by getScheduler).

    Returns None when scheduler_id is not a class in diffusers.schedulers.
    """
    print(f"Initializing {scheduler_id} for {MODEL_ID}...")
    start = time.time()
    # Pass a default so an unknown scheduler name yields None instead of
    # raising AttributeError (previously the None check was unreachable).
    scheduler_class = getattr(_schedulers, scheduler_id, None)
    if scheduler_class is None:
        return None

    initted_scheduler = scheduler_class.from_config(
        MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN
    )
    diff = round((time.time() - start) * 1000)
    print(f"Initialized {scheduler_id} for {MODEL_ID} in {diff}ms")

    return initted_scheduler
48+
49+
50+
# Per-model scheduler cache: {MODEL_ID: {scheduler_id: scheduler_instance}}.
schedulers = {}


def getScheduler(MODEL_ID: str, scheduler_id: str):
    """Return a scheduler instance for MODEL_ID, initializing it on demand.

    Instances are cached per model because each initialization is slow
    (~800ms). Deprecated short names ("LMS", "DDIM", "PNDM") are mapped to
    their full class names with a warning.

    Returns None when scheduler_id does not name a known scheduler.
    """
    schedulers_by_model = schedulers.get(MODEL_ID, None)
    if schedulers_by_model is None:
        schedulers_by_model = {}
        schedulers[MODEL_ID] = schedulers_by_model

    # Map deprecated short names to their current class names.
    deprecated_map = {
        "LMS": "LMSDiscreteScheduler",
        "DDIM": "DDIMScheduler",
        "PNDM": "PNDMScheduler",
    }
    scheduler_renamed = deprecated_map.get(scheduler_id, None)
    if scheduler_renamed is not None:
        # Bug fix: the warning previously interpolated the old name twice,
        # telling users a scheduler "is now called" its own old name.
        print(
            f'[Deprecation Warning]: Scheduler "{scheduler_id}" is now '
            f'called "{scheduler_renamed}". Please rename as this will '
            f"stop working in a future release."
        )
        scheduler_id = scheduler_renamed

    scheduler = schedulers_by_model.get(scheduler_id, None)
    if scheduler is None:
        scheduler = initScheduler(MODEL_ID, scheduler_id)
        schedulers_by_model[scheduler_id] = scheduler

    return scheduler

loadModel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import os
33
from diffusers import pipelines as _pipelines, StableDiffusionPipeline
4+
from getScheduler import getScheduler, SCHEDULERS, DEFAULT_SCHEDULER
45

56
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
67
PIPELINE = os.getenv("PIPELINE")
@@ -21,11 +22,15 @@ def loadModel(model_id: str, load=True):
2122
StableDiffusionPipeline if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
2223
)
2324

25+
print("DEFAULT SCHEDULER=" + DEFAULT_SCHEDULER)
26+
scheduler = getScheduler(model_id, DEFAULT_SCHEDULER)
27+
2428
model = pipeline.from_pretrained(
2529
model_id,
2630
revision="fp16",
2731
torch_dtype=torch.float16,
2832
use_auth_token=HF_AUTH_TOKEN,
33+
scheduler=scheduler,
2934
)
3035

3136
return model.to("cuda") if load else None

send.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import hashlib
66
from requests_futures.sessions import FuturesSession
77

8+
print()
89
print(os.environ)
10+
print()
911

1012

1113
def get_now():

0 commit comments

Comments
 (0)