Skip to content

Commit 1ccbaad

Browse files
committed
feat(pipelines): allow calling of ALL PIPELINES (official+community)
1 parent 35b0b06 commit 1ccbaad

4 files changed

Lines changed: 115 additions & 62 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
all pipelines. Using `init_image` now shows a deprecation warning and
77
will be removed in future.
88

9+
* **ALL THE PIPELINES**. We no longer load a list of hard-coded pipelines
10+
in `init()`. Instead, we init and cache each on first use (for faster
11+
first calls on cold boots), and *all* pipelines, both official diffusers
12+
and community pipelines, are available.
13+
[Full details](https://banana-forums.dev/t/all-your-pipelines-are-belong-to-us/83)
14+
915
* **Changed `sd-base` to `diffusers-api` as the default tag / name used
1016
in the README examples and optional [./build][build script].
1117

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@ Please give credit and link back to this repo if you use it in a public project.
77

88
## Features
99

10-
* Pipelines: txt2img, img2img and inpainting in a single container
1110
* Models: stable-diffusion, waifu-diffusion, and easy to add others (e.g. jp-sd)
11+
* Pipelines: txt2img, img2img and inpainting in a single container
12+
13+
(
14+
[all diffusers official and community pipelines](https://banana-forums.dev/t/all-your-pipelines-are-belong-to-us/83) are wrapped, but untested)
1215
* All model inputs supported, including setting nsfw filter per request
1316
* *Permute* base config to multiple forks based on yaml config with vars
1417
* Optionally send signed event logs / performance data to a REST endpoint

app.py

Lines changed: 15 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,7 @@
22
import torch
33

44
from torch import autocast
5-
from diffusers import (
6-
pipelines as _pipelines,
7-
LMSDiscreteScheduler,
8-
DDIMScheduler,
9-
PNDMScheduler,
10-
DiffusionPipeline,
11-
__version__,
12-
)
5+
from diffusers import __version__
136
import base64
147
from io import BytesIO
158
import PIL
@@ -22,10 +15,10 @@
2215
import skimage.measure
2316
from PyPatchMatch import patch_match
2417
from getScheduler import getScheduler, SCHEDULERS
18+
from getPipeline import getPipelineForModel, listAvailablePipelines, clearPipelines
2519
import re
2620
import requests
2721
from download import download_model
28-
from precision import revision, torch_dtype
2922

3023
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
3124
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
@@ -36,53 +29,9 @@
3629
PIPELINE = os.environ.get("PIPELINE")
3730
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
3831

39-
PIPELINES = [
40-
"StableDiffusionPipeline",
41-
"StableDiffusionImg2ImgPipeline",
42-
"StableDiffusionInpaintPipeline",
43-
"StableDiffusionInpaintPipelineLegacy",
44-
]
45-
46-
COMMUNITY_PIPELINES = [
47-
"lpw_stable_diffusion",
48-
]
49-
5032
torch.set_grad_enabled(False)
5133

5234

53-
def createPipelinesFromModel(model, model_id):
54-
pipelines = dict()
55-
56-
for pipeline in PIPELINES:
57-
if hasattr(_pipelines, pipeline):
58-
if hasattr(model, "components"):
59-
pipelines[pipeline] = getattr(_pipelines, pipeline)(**model.components)
60-
else:
61-
pipelines[pipeline] = getattr(_pipelines, pipeline)(
62-
vae=model.vae,
63-
text_encoder=model.text_encoder,
64-
tokenizer=model.tokenizer,
65-
unet=model.unet,
66-
scheduler=model.scheduler,
67-
safety_checker=model.safety_checker,
68-
feature_extractor=model.feature_extractor,
69-
)
70-
else:
71-
print(f'Skipping non-existent pipeline "{PIPELINE}"')
72-
73-
for pipeline in COMMUNITY_PIPELINES:
74-
pipelines[pipeline] = DiffusionPipeline.from_pretrained(
75-
model_id,
76-
revision=revision,
77-
torch_dtype=torch_dtype,
78-
custom_pipeline="./diffusers/examples/community/" + pipeline + ".py",
79-
local_files_only=True,
80-
**model.components,
81-
)
82-
83-
return pipelines
84-
85-
8635
class DummySafetyChecker:
8736
@staticmethod
8837
def __call__(images, clip_input):
@@ -93,8 +42,6 @@ def __call__(images, clip_input):
9342
# Load your model to GPU as a global variable here using the variable name "model"
9443
def init():
9544
global model # needed for bananna optimizations
96-
global pipelines
97-
global schedulers
9845
global dummy_safety_checker
9946
global initTime
10047

@@ -120,9 +67,6 @@ def init():
12067
if not RUNTIME_DOWNLOADS:
12168
model = loadModel(MODEL_ID)
12269

123-
if PIPELINE == "ALL":
124-
pipelines = createPipelinesFromModel(model, MODEL_ID)
125-
12670
send("init", "done")
12771
initTime = get_now() - initStart
12872

@@ -200,13 +144,13 @@ def inference(all_inputs: dict) -> dict:
200144
downloaded_models.update({model_id: True})
201145
model = loadModel(model_id)
202146
if PIPELINE == "ALL":
203-
pipelines = createPipelinesFromModel(model, model_id)
147+
clearPipelines()
204148
last_model_id = model_id
205149

206150
if MODEL_ID == "ALL":
207151
if last_model_id != model_id:
208152
model = loadModel(model_id)
209-
pipelines = createPipelinesFromModel(model, model_id)
153+
clearPipelines()
210154
last_model_id = model_id
211155
else:
212156
if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:
@@ -220,7 +164,17 @@ def inference(all_inputs: dict) -> dict:
220164
}
221165

222166
if PIPELINE == "ALL":
223-
pipeline = pipelines.get(call_inputs.get("PIPELINE"))
167+
pipeline_name = call_inputs.get("PIPELINE")
168+
pipeline = getPipelineForModel(pipeline_name, model, model_id)
169+
if not pipeline:
170+
return {
171+
"$error": {
172+
"code": "NO_SUCH_PIPELINE",
173+
"message": f'"{pipeline_name}" is not an official nor community Diffusers pipelines',
174+
"requested": pipeline_name,
175+
"available": listAvailablePipelines(),
176+
}
177+
}
224178
else:
225179
pipeline = model
226180

getPipeline.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import time
2+
import os, fnmatch
3+
from diffusers import (
4+
DiffusionPipeline,
5+
pipelines as diffusers_pipelines,
6+
)
7+
from precision import revision, torch_dtype
8+
9+
_pipelines = {}
10+
_availableCommunityPipelines = None
11+
12+
13+
def listAvailablePipelines():
14+
return (
15+
list(
16+
filter(
17+
lambda key: key.endswith("Pipeline"),
18+
list(diffusers_pipelines.__dict__.keys()),
19+
)
20+
)
21+
+ availableCommunityPipelines()
22+
)
23+
24+
25+
def availableCommunityPipelines():
26+
global _availableCommunityPipelines
27+
if not _availableCommunityPipelines:
28+
_availableCommunityPipelines = list(
29+
map(
30+
lambda s: s[0:-3],
31+
fnmatch.filter(os.listdir("diffusers/examples/community"), "*.py"),
32+
)
33+
)
34+
35+
return _availableCommunityPipelines
36+
37+
38+
def clearPipelines():
39+
"""
40+
Clears the pipeline cache. Important to call this when changing the
41+
loaded model, as pipelines include references to the model and would
42+
therefore prevent memory being reclaimed after unloading the previous
43+
model.
44+
"""
45+
pipelines = {}
46+
47+
48+
def getPipelineForModel(pipeline_name: str, model, model_id):
49+
"""
50+
Inits a new pipeline, re-using components from a previously loaded
51+
model. The pipeline is cached and future calls with the same
52+
arguments will return the previously initted instance. Be sure
53+
to call `clearPipelines()` if loading a new model, to allow the
54+
previous model to be garbage collected.
55+
"""
56+
pipeline = _pipelines.get(pipeline_name)
57+
if pipeline:
58+
return pipeline
59+
60+
start = time.time()
61+
62+
if hasattr(diffusers_pipelines, pipeline_name):
63+
if hasattr(model, "components"):
64+
pipeline = getattr(diffusers_pipelines, pipeline_name)(**model.components)
65+
else:
66+
pipeline = getattr(diffusers_pipelines, pipeline_name)(
67+
vae=model.vae,
68+
text_encoder=model.text_encoder,
69+
tokenizer=model.tokenizer,
70+
unet=model.unet,
71+
scheduler=model.scheduler,
72+
safety_checker=model.safety_checker,
73+
feature_extractor=model.feature_extractor,
74+
)
75+
76+
elif pipeline_name in availableCommunityPipelines():
77+
pipeline = DiffusionPipeline.from_pretrained(
78+
model_id,
79+
revision=revision,
80+
torch_dtype=torch_dtype,
81+
custom_pipeline="./diffusers/examples/community/" + pipeline_name + ".py",
82+
local_files_only=True,
83+
**model.components,
84+
)
85+
86+
if pipeline:
87+
_pipelines.update({pipeline_name: pipeline})
88+
diff = round((time.time() - start) * 1000)
89+
print(f"Initialized {pipeline_name} for {model_id} in {diff}ms")
90+
return pipeline

0 commit comments

Comments
 (0)