22import torch
33
44from torch import autocast
5- from diffusers import (
6- pipelines as _pipelines ,
7- LMSDiscreteScheduler ,
8- DDIMScheduler ,
9- PNDMScheduler ,
10- DiffusionPipeline ,
11- __version__ ,
12- )
5+ from diffusers import __version__
136import base64
147from io import BytesIO
158import PIL
2215import skimage .measure
2316from PyPatchMatch import patch_match
2417from getScheduler import getScheduler , SCHEDULERS
18+ from getPipeline import getPipelineForModel , listAvailablePipelines , clearPipelines
2519import re
2620import requests
2721from download import download_model
28- from precision import revision , torch_dtype
2922
3023RUNTIME_DOWNLOADS = os .getenv ("RUNTIME_DOWNLOADS" ) == "1"
3124USE_DREAMBOOTH = os .getenv ("USE_DREAMBOOTH" ) == "1"
3629PIPELINE = os .environ .get ("PIPELINE" )
3730HF_AUTH_TOKEN = os .getenv ("HF_AUTH_TOKEN" )
3831
39- PIPELINES = [
40- "StableDiffusionPipeline" ,
41- "StableDiffusionImg2ImgPipeline" ,
42- "StableDiffusionInpaintPipeline" ,
43- "StableDiffusionInpaintPipelineLegacy" ,
44- ]
45-
46- COMMUNITY_PIPELINES = [
47- "lpw_stable_diffusion" ,
48- ]
49-
5032torch .set_grad_enabled (False )
5133
5234
53- def createPipelinesFromModel (model , model_id ):
54- pipelines = dict ()
55-
56- for pipeline in PIPELINES :
57- if hasattr (_pipelines , pipeline ):
58- if hasattr (model , "components" ):
59- pipelines [pipeline ] = getattr (_pipelines , pipeline )(** model .components )
60- else :
61- pipelines [pipeline ] = getattr (_pipelines , pipeline )(
62- vae = model .vae ,
63- text_encoder = model .text_encoder ,
64- tokenizer = model .tokenizer ,
65- unet = model .unet ,
66- scheduler = model .scheduler ,
67- safety_checker = model .safety_checker ,
68- feature_extractor = model .feature_extractor ,
69- )
70- else :
71- print (f'Skipping non-existent pipeline "{ PIPELINE } "' )
72-
73- for pipeline in COMMUNITY_PIPELINES :
74- pipelines [pipeline ] = DiffusionPipeline .from_pretrained (
75- model_id ,
76- revision = revision ,
77- torch_dtype = torch_dtype ,
78- custom_pipeline = "./diffusers/examples/community/" + pipeline + ".py" ,
79- local_files_only = True ,
80- ** model .components ,
81- )
82-
83- return pipelines
84-
85-
8635class DummySafetyChecker :
8736 @staticmethod
8837 def __call__ (images , clip_input ):
@@ -93,8 +42,6 @@ def __call__(images, clip_input):
9342# Load your model to GPU as a global variable here using the variable name "model"
9443def init ():
9544 global model # needed for bananna optimizations
96- global pipelines
97- global schedulers
9845 global dummy_safety_checker
9946 global initTime
10047
@@ -120,9 +67,6 @@ def init():
12067 if not RUNTIME_DOWNLOADS :
12168 model = loadModel (MODEL_ID )
12269
123- if PIPELINE == "ALL" :
124- pipelines = createPipelinesFromModel (model , MODEL_ID )
125-
12670 send ("init" , "done" )
12771 initTime = get_now () - initStart
12872
@@ -200,13 +144,13 @@ def inference(all_inputs: dict) -> dict:
200144 downloaded_models .update ({model_id : True })
201145 model = loadModel (model_id )
202146 if PIPELINE == "ALL" :
203- pipelines = createPipelinesFromModel ( model , model_id )
147+ clearPipelines ( )
204148 last_model_id = model_id
205149
206150 if MODEL_ID == "ALL" :
207151 if last_model_id != model_id :
208152 model = loadModel (model_id )
209- pipelines = createPipelinesFromModel ( model , model_id )
153+ clearPipelines ( )
210154 last_model_id = model_id
211155 else :
212156 if model_id != MODEL_ID and not RUNTIME_DOWNLOADS :
@@ -220,7 +164,17 @@ def inference(all_inputs: dict) -> dict:
220164 }
221165
222166 if PIPELINE == "ALL" :
223- pipeline = pipelines .get (call_inputs .get ("PIPELINE" ))
167+ pipeline_name = call_inputs .get ("PIPELINE" )
168+ pipeline = getPipelineForModel (pipeline_name , model , model_id )
169+ if not pipeline :
170+ return {
171+ "$error" : {
172+ "code" : "NO_SUCH_PIPELINE" ,
173+ "message" : f'"{ pipeline_name } " is not an official nor community Diffusers pipelines' ,
174+ "requested" : pipeline_name ,
175+ "available" : listAvailablePipelines (),
176+ }
177+ }
224178 else :
225179 pipeline = model
226180
0 commit comments