99)
1010import base64
1111from io import BytesIO
12- import os
1312import PIL
1413import json
14+ from loadModel import loadModel
1515
1616from APP_VARS import MODEL_ID
1717
@@ -54,8 +54,6 @@ def init():
5454 global schedulers
5555 global dummy_safety_checker
5656
57- HF_AUTH_TOKEN = os .getenv ("HF_AUTH_TOKEN" )
58-
5957 schedulers = {
6058 "LMS" : LMSDiscreteScheduler (
6159 beta_start = 0.00085 , beta_end = 0.012 , beta_schedule = "scaled_linear"
@@ -77,15 +75,7 @@ def init():
7775 last_model_id = None
7876 return
7977
80- print ("Loading model " + MODEL_ID )
81-
82- model = _pipelines .StableDiffusionPipeline .from_pretrained (
83- MODEL_ID ,
84- revision = "fp16" ,
85- torch_dtype = torch .float16 ,
86- use_auth_token = HF_AUTH_TOKEN ,
87- ).to ("cuda" )
88-
78+ model = loadModel (MODEL_ID )
8979 pipelines = createPipelinesFromModel (MODEL_ID )
9080
9181
@@ -120,17 +110,9 @@ def inference(all_inputs: dict) -> dict:
120110 return {"$error" : "UPGRADE CLIENT - no model_inputs specified" }
121111
122112 if MODEL_ID == "ALL" :
123- HF_AUTH_TOKEN = os .getenv ("HF_AUTH_TOKEN" )
124113 model_id = call_inputs .get ("MODEL_ID" )
125114 if last_model_id != model_id :
126- print ("Loading model " + model_id )
127- model = _pipelines .StableDiffusionPipeline .from_pretrained (
128- model_id ,
129- revision = "fp16" ,
130- torch_dtype = torch .float16 ,
131- use_auth_token = HF_AUTH_TOKEN ,
132- ).to ("cuda" )
133-
115+ model = loadModel (model_id )
134116 pipelines = createPipelinesFromModel (model_id )
135117 last_model_id = model_id
136118
@@ -155,18 +137,11 @@ def inference(all_inputs: dict) -> dict:
155137 # seed = model_inputs.get("seed", None)
156138 # strength = model_inputs.get("strength", 0.75)
157139
158- if call_inputs .get ("PIPELINE" ) in [
159- "StableDiffusionImg2ImgPipeline" ,
160- "StableDiffusionInpaintPipeline" ,
161- ]:
162- model_inputs .update (
163- {"init_image" : decodeBase64Image (model_inputs .get ("init_image" ))}
164- )
140+ if "init_image" in model_inputs :
141+ model_inputs ["init_image" ] = decodeBase64Image (model_inputs .get ("init_image" ))
165142
166- if all_inputs .get ("PIPELINE" ) == "StableDiffusionInpaintPipeline" :
167- model_inputs .update (
168- {"mask_image" : decodeBase64Image (model_inputs .get ("mask_image" ))}
169- )
143+ if "mask_image" in model_inputs :
144+ model_inputs ["mask_image" ] = decodeBase64Image (model_inputs .get ("mask_image" ))
170145
171146 seed = model_inputs .get ("seed" , None )
172147 if seed == None :
0 commit comments