2525import re
2626import requests
2727from download import download_model
28+ from precision import revision
2829
2930RUNTIME_DOWNLOADS = os .getenv ("RUNTIME_DOWNLOADS" ) == "1"
3031USE_DREAMBOOTH = os .getenv ("USE_DREAMBOOTH" ) == "1"
4243 "StableDiffusionInpaintPipelineLegacy" ,
4344]
4445
46+ COMMUNITY_PIPELINES = [
47+ "lpw_stable_diffusion" ,
48+ ]
49+
4550torch .set_grad_enabled (False )
4651
4752
48- def createPipelinesFromModel (model ):
53+ def createPipelinesFromModel (model , model_id ):
4954 pipelines = dict ()
55+
5056 for pipeline in PIPELINES :
5157 if hasattr (_pipelines , pipeline ):
5258 if hasattr (model , "components" ):
@@ -63,6 +69,16 @@ def createPipelinesFromModel(model):
6369 )
6470 else :
6571 print (f'Skipping non-existent pipeline "{ PIPELINE } "' )
72+
73+ for pipeline in COMMUNITY_PIPELINES :
74+ pipelines [pipeline ] = DiffusionPipeline .from_pretrained (
75+ model_id ,
76+ revision = revision ,
77+ custom_pipeline = "./diffusers/examples/community/" + pipeline + ".py" ,
78+ local_files_only = True ,
79+ ** model .components ,
80+ )
81+
6682 return pipelines
6783
6884
@@ -104,7 +120,7 @@ def init():
104120 model = loadModel (MODEL_ID )
105121
106122 if PIPELINE == "ALL" :
107- pipelines = createPipelinesFromModel (model )
123+ pipelines = createPipelinesFromModel (model , MODEL_ID )
108124
109125 send ("init" , "done" )
110126 initTime = get_now () - initStart
@@ -183,13 +199,13 @@ def inference(all_inputs: dict) -> dict:
183199 downloaded_models .update ({model_id : True })
184200 model = loadModel (model_id )
185201 if PIPELINE == "ALL" :
186- pipelines = createPipelinesFromModel (model )
202+ pipelines = createPipelinesFromModel (model , model_id )
187203 last_model_id = model_id
188204
189205 if MODEL_ID == "ALL" :
190206 if last_model_id != model_id :
191207 model = loadModel (model_id )
192- pipelines = createPipelinesFromModel (model )
208+ pipelines = createPipelinesFromModel (model , model_id )
193209 last_model_id = model_id
194210 else :
195211 if model_id != MODEL_ID and not RUNTIME_DOWNLOADS :
@@ -328,9 +344,12 @@ def inference(all_inputs: dict) -> dict:
328344 model_inputs .update ({"generator" : generator })
329345
330346 with torch .inference_mode ():
347+ custom_pipeline_method = call_inputs .get ("custom_pipeline_method" , None )
348+ if custom_pipeline_method :
349+ images = getattr (pipeline , custom_pipeline_method )(** model_inputs ).images
331350 # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
332351 # still broken in 0.5.1
333- if call_inputs .get ("PIPELINE" ) != "StableDiffusionPipeline" :
352+ elif call_inputs .get ("PIPELINE" ) != "StableDiffusionPipeline" :
334353 with autocast ("cuda" ):
335354 images = pipeline (** model_inputs ).images
336355 else :
0 commit comments