Skip to content

Commit 7af45cf

Browse files
committed
feat(pipelines): initial community pipeline support
1 parent 9527ece commit 7af45cf

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

app.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import re
2626
import requests
2727
from download import download_model
28+
from precision import revision
2829

2930
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
3031
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
@@ -42,11 +43,16 @@
4243
"StableDiffusionInpaintPipelineLegacy",
4344
]
4445

46+
COMMUNITY_PIPELINES = [
47+
"lpw_stable_diffusion",
48+
]
49+
4550
torch.set_grad_enabled(False)
4651

4752

48-
def createPipelinesFromModel(model):
53+
def createPipelinesFromModel(model, model_id):
4954
pipelines = dict()
55+
5056
for pipeline in PIPELINES:
5157
if hasattr(_pipelines, pipeline):
5258
if hasattr(model, "components"):
@@ -63,6 +69,16 @@ def createPipelinesFromModel(model):
6369
)
6470
else:
6571
print(f'Skipping non-existent pipeline "{PIPELINE}"')
72+
73+
for pipeline in COMMUNITY_PIPELINES:
74+
pipelines[pipeline] = DiffusionPipeline.from_pretrained(
75+
model_id,
76+
revision=revision,
77+
custom_pipeline="./diffusers/examples/community/" + pipeline + ".py",
78+
local_files_only=True,
79+
**model.components,
80+
)
81+
6682
return pipelines
6783

6884

@@ -104,7 +120,7 @@ def init():
104120
model = loadModel(MODEL_ID)
105121

106122
if PIPELINE == "ALL":
107-
pipelines = createPipelinesFromModel(model)
123+
pipelines = createPipelinesFromModel(model, MODEL_ID)
108124

109125
send("init", "done")
110126
initTime = get_now() - initStart
@@ -183,13 +199,13 @@ def inference(all_inputs: dict) -> dict:
183199
downloaded_models.update({model_id: True})
184200
model = loadModel(model_id)
185201
if PIPELINE == "ALL":
186-
pipelines = createPipelinesFromModel(model)
202+
pipelines = createPipelinesFromModel(model, model_id)
187203
last_model_id = model_id
188204

189205
if MODEL_ID == "ALL":
190206
if last_model_id != model_id:
191207
model = loadModel(model_id)
192-
pipelines = createPipelinesFromModel(model)
208+
pipelines = createPipelinesFromModel(model, model_id)
193209
last_model_id = model_id
194210
else:
195211
if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:
@@ -328,9 +344,12 @@ def inference(all_inputs: dict) -> dict:
328344
model_inputs.update({"generator": generator})
329345

330346
with torch.inference_mode():
347+
custom_pipeline_method = call_inputs.get("custom_pipeline_method", None)
348+
if custom_pipeline_method:
349+
images = getattr(pipeline, custom_pipeline_method)(**model_inputs).images
331350
# autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
332351
# still broken in 0.5.1
333-
if call_inputs.get("PIPELINE") != "StableDiffusionPipeline":
352+
elif call_inputs.get("PIPELINE") != "StableDiffusionPipeline":
334353
with autocast("cuda"):
335354
images = pipeline(**model_inputs).images
336355
else:

0 commit comments

Comments
 (0)