|
4 | 4 | from torch import autocast |
5 | 5 | from diffusers import ( |
6 | 6 | pipelines as _pipelines, |
7 | | - schedulers as _schedulers, |
8 | 7 | LMSDiscreteScheduler, |
9 | 8 | DDIMScheduler, |
10 | 9 | PNDMScheduler, |
|
22 | 21 | import skimage |
23 | 22 | import skimage.measure |
24 | 23 | from PyPatchMatch import patch_match |
| 24 | +from getScheduler import getScheduler, SCHEDULERS |
25 | 25 | import re |
26 | 26 |
|
27 | 27 | MODEL_ID = os.environ.get("MODEL_ID") |
|
35 | 35 | "StableDiffusionInpaintPipelineLegacy", |
36 | 36 | ] |
37 | 37 |
|
38 | | -SCHEDULERS = [ |
39 | | - "LMSDiscreteScheduler", |
40 | | - "DDIMScheduler", |
41 | | - "PNDMScheduler", |
42 | | - "EulerAncestralDiscreteScheduler", |
43 | | - "EulerDiscreteScheduler", |
44 | | -] |
45 | | - |
46 | 38 | torch.set_grad_enabled(False) |
47 | 39 |
|
48 | 40 |
|
@@ -95,30 +87,6 @@ def init(): |
95 | 87 | True, |
96 | 88 | ) |
97 | 89 |
|
98 | | - schedulers = {} |
99 | | - """ |
100 | | - # This was a nice idea but until we have default init vars for all schedulers |
101 | | - # via from_config(), it's a no go. |
102 | | - isScheduler = re.compile(r".+Scheduler$") |
103 | | - for key, val in _schedulers.__dict__.items(): |
104 | | - if isScheduler.match(key): |
105 | | - schedulers.update( |
106 | | - { |
107 | | - key: val.from_config( |
108 | | - MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN |
109 | | - ) |
110 | | - } |
111 | | - ) |
112 | | - """ |
113 | | - for scheduler_name in SCHEDULERS: |
114 | | - schedulers.update( |
115 | | - { |
116 | | - scheduler_name: getattr(_schedulers, scheduler_name).from_config( |
117 | | - MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN |
118 | | - ), |
119 | | - } |
120 | | - ) |
121 | | - |
122 | 90 | dummy_safety_checker = DummySafetyChecker() |
123 | 91 |
|
124 | 92 | if MODEL_ID == "ALL": |
@@ -200,30 +168,14 @@ def inference(all_inputs: dict) -> dict: |
200 | 168 | else: |
201 | 169 | pipeline = model |
202 | 170 |
|
203 | | - # Check for use of all names |
204 | | - scheduler_name = call_inputs.get("SCHEDULER", None) |
205 | | - deprecated_map = { |
206 | | - "LMS": "LMSDiscreteScheduler", |
207 | | - "DDIM": "DDIMScheduler", |
208 | | - "PNDM": "PNDMScheduler", |
209 | | - } |
210 | | - scheduler_renamed = deprecated_map.get(scheduler_name, None) |
211 | | - if scheduler_renamed != None: |
212 | | - print( |
213 | | - f'[Deprecation Warning]: Scheduler "{scheduler_name}" is now ' |
214 | | - f'called "{scheduler_renamed}". Please rename as this will ' |
215 | | - f"stop working in a future release." |
216 | | - ) |
217 | | - scheduler_name = scheduler_renamed |
218 | | - |
219 | | - pipeline.scheduler = schedulers.get(scheduler_name, None) |
| 171 | + pipeline.scheduler = getScheduler(MODEL_ID, call_inputs.get("SCHEDULER", None)) |
220 | 172 | if pipeline.scheduler == None: |
221 | 173 | return { |
222 | 174 | "$error": { |
223 | 175 | "code": "INVALID_SCHEDULER", |
224 | 176 | "message": "", |
225 | 177 | "requeted": call_inputs.get("SCHEDULER", None), |
226 | | - "available": ", ".join(schedulers.keys()), |
| 178 | + "available": ", ".join(SCHEDULERS), |
227 | 179 | } |
228 | 180 | } |
229 | 181 |
|
@@ -286,10 +238,8 @@ def inference(all_inputs: dict) -> dict: |
286 | 238 | x_m_e_a = call_inputs.get("xformers_memory_efficient_attention", None) |
287 | 239 | if x_m_e_a != last_xformers_memory_efficient_attention: |
288 | 240 | last_xformers_memory_efficient_attention = x_m_e_a |
289 | | - if x_m_e_a == None: |
| 241 | + if x_m_e_a == None or x_m_e_a == True: |
290 | 242 | pipeline.enable_xformers_memory_efficient_attention() # default on |
291 | | - elif x_m_e_a == True: |
292 | | - pipeline.enable_xformers_memory_efficient_attention() |
293 | 243 | elif x_m_e_a == False: |
294 | 244 | pipeline.disable_xformers_memory_efficient_attention() |
295 | 245 | else: |
|
0 commit comments