Skip to content

Commit 8eca8fe

Browse files
committed
app.py: INVALID_INPUTS error; only enable/disable xformers as needed
1 parent dde607d commit 8eca8fe

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

app.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def truncateInputs(inputs: dict):
151151
return clone
152152

153153

154+
last_xformers_memory_efficient_attention = None
155+
154156
# Inference is ran for every server call
155157
# Reference your preloaded global model variable here.
156158
def inference(all_inputs: dict) -> dict:
@@ -159,15 +161,22 @@ def inference(all_inputs: dict) -> dict:
159161
global last_model_id
160162
global schedulers
161163
global dummy_safety_checker
164+
global last_xformers_memory_efficient_attention
162165

163166
print(json.dumps(truncateInputs(all_inputs), indent=2))
164167
model_inputs = all_inputs.get("modelInputs", None)
165168
call_inputs = all_inputs.get("callInputs", None)
166-
startRequestId = call_inputs.get("startRequestId", None)
167169

168-
# Fallback until all clients on new code
169-
if model_inputs == None:
170-
return {"$error": "UPGRADE CLIENT - no model_inputs specified"}
170+
if model_inputs == None or call_inputs == None:
171+
return {
172+
"$error": {
173+
"code": "INVALID_INPUTS",
174+
                "message": "Expecting an object like { modelInputs: {}, callInputs: {} } but got "
175+
+ json.dumps(all_inputs),
176+
}
177+
}
178+
179+
startRequestId = call_inputs.get("startRequestId", None)
171180

172181
model_id = call_inputs.get("MODEL_ID")
173182
if MODEL_ID == "ALL":
@@ -273,22 +282,25 @@ def inference(all_inputs: dict) -> dict:
273282
mask = mask.repeat(8, axis=0).repeat(8, axis=1)
274283
model_inputs["mask_image"] = PIL.Image.fromarray(mask)
275284

285+
# Turning on takes 3ms and turning off 1ms... don't worry, I've got your back :)
276286
x_m_e_a = call_inputs.get("xformers_memory_efficient_attention", None)
277-
if x_m_e_a == None:
278-
pipeline.enable_xformers_memory_efficient_attention() # default on
279-
elif x_m_e_a == True:
280-
pipeline.enable_xformers_memory_efficient_attention()
281-
elif x_m_e_a == False:
282-
pipeline.disable_xformers_memory_efficient_attention()
283-
else:
284-
return {
285-
"$error": {
286-
"code": "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE",
287-
"message": f'Model "{model_id}" not available on this container which hosts "{MODEL_ID}"',
288-
"requested": x_m_e_a,
289-
"available": [True, False],
287+
if x_m_e_a != last_xformers_memory_efficient_attention:
288+
last_xformers_memory_efficient_attention = x_m_e_a
289+
if x_m_e_a == None:
290+
pipeline.enable_xformers_memory_efficient_attention() # default on
291+
elif x_m_e_a == True:
292+
pipeline.enable_xformers_memory_efficient_attention()
293+
elif x_m_e_a == False:
294+
pipeline.disable_xformers_memory_efficient_attention()
295+
else:
296+
return {
297+
"$error": {
298+
"code": "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE",
299+
"message": f'Model "{model_id}" not available on this container which hosts "{MODEL_ID}"',
300+
"requested": x_m_e_a,
301+
"available": [True, False],
302+
}
290303
}
291-
}
292304

293305
# Run the model
294306
# with autocast("cuda"):

0 commit comments

Comments
 (0)