Skip to content

Commit a1614df

Browse files
committed
schedulers: local_files_only speedsup from_config from 800ms to 1ms
1 parent 690810d commit a1614df

4 files changed

Lines changed: 12 additions & 16 deletions

File tree

Dockerfile

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,6 @@ ADD convert-to-diffusers.py .
9494
RUN python3 convert-to-diffusers.py
9595
# RUN rm -rf checkpoints
9696

97-
# Loading a new scheduler for the first time takes an extra ~800ms,
98-
# so set this to your most common one.
99-
ARG DEFAULT_SCHEDULER="LMSDiscreteScheduler"
100-
ENV DEFAULT_SCHEDULER=${DEFAULT_SCHEDULER}
101-
10297
# Add your model weight files
10398
# (in this case we have a python script)
10499
ADD getScheduler.py .

download.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import os
55
from loadModel import loadModel, MODEL_IDS
6-
from getScheduler import getScheduler, SCHEDULERS, DEFAULT_SCHEDULER
76

87
MODEL_ID = os.environ.get("MODEL_ID")
98

@@ -17,8 +16,6 @@ def download_model():
1716
else:
1817
loadModel(MODEL_ID, False)
1918

20-
getScheduler(MODEL_ID, DEFAULT_SCHEDULER)
21-
2219

2320
if __name__ == "__main__":
2421
download_model()

getScheduler.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from diffusers import schedulers as _schedulers
55

66
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
7-
DEFAULT_SCHEDULER = os.getenv("DEFAULT_SCHEDULER")
87

98
SCHEDULERS = [
109
"LMSDiscreteScheduler",
@@ -14,6 +13,9 @@
1413
"EulerDiscreteScheduler",
1514
]
1615

16+
DEFAULT_SCHEDULER = os.getenv("DEFAULT_SCHEDULER", SCHEDULERS[0])
17+
18+
1719
"""
1820
# This was a nice idea but until we have default init vars for all schedulers
1921
# via from_config(), it's a no go. In any case, loading a scheduler takes time
@@ -31,15 +33,18 @@
3133
"""
3234

3335

34-
def initScheduler(MODEL_ID: str, scheduler_id: str):
36+
def initScheduler(MODEL_ID: str, scheduler_id: str, download=False):
3537
print(f"Initializing {scheduler_id} for {MODEL_ID}...")
3638
start = time.time()
3739
scheduler = getattr(_schedulers, scheduler_id)
3840
if scheduler == None:
3941
return None
4042

4143
inittedScheduler = scheduler.from_config(
42-
MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN
44+
MODEL_ID,
45+
subfolder="scheduler",
46+
use_auth_token=HF_AUTH_TOKEN,
47+
local_files_only=not download,
4348
)
4449
diff = round((time.time() - start) * 1000)
4550
print(f"Initialized {scheduler_id} for {MODEL_ID} in {diff}ms")
@@ -50,7 +55,7 @@ def initScheduler(MODEL_ID: str, scheduler_id: str):
5055
schedulers = {}
5156

5257

53-
def getScheduler(MODEL_ID: str, scheduler_id: str):
58+
def getScheduler(MODEL_ID: str, scheduler_id: str, download=False):
5459
schedulersByModel = schedulers.get(MODEL_ID, None)
5560
if schedulersByModel == None:
5661
schedulersByModel = {}
@@ -73,7 +78,7 @@ def getScheduler(MODEL_ID: str, scheduler_id: str):
7378

7479
scheduler = schedulersByModel.get(scheduler_id, None)
7580
if scheduler == None:
76-
scheduler = initScheduler(MODEL_ID, scheduler_id)
81+
scheduler = initScheduler(MODEL_ID, scheduler_id, download)
7782
schedulersByModel.update({scheduler_id: scheduler})
7883

7984
return scheduler

loadModel.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import os
33
from diffusers import pipelines as _pipelines, StableDiffusionPipeline
4-
from getScheduler import getScheduler, SCHEDULERS, DEFAULT_SCHEDULER
4+
from getScheduler import getScheduler, DEFAULT_SCHEDULER
55

66
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
77
PIPELINE = os.getenv("PIPELINE")
@@ -22,8 +22,7 @@ def loadModel(model_id: str, load=True):
2222
StableDiffusionPipeline if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
2323
)
2424

25-
print("DEFAULT SCHEDULER=" + DEFAULT_SCHEDULER)
26-
scheduler = getScheduler(model_id, DEFAULT_SCHEDULER)
25+
scheduler = getScheduler(model_id, DEFAULT_SCHEDULER, not load)
2726

2827
model = pipeline.from_pretrained(
2928
model_id,

0 commit comments

Comments
 (0)