@@ -151,6 +151,8 @@ def truncateInputs(inputs: dict):
151151 return clone
152152
153153
154+ last_xformers_memory_efficient_attention = None
155+
154156# Inference is ran for every server call
155157# Reference your preloaded global model variable here.
156158def inference (all_inputs : dict ) -> dict :
@@ -159,15 +161,22 @@ def inference(all_inputs: dict) -> dict:
159161 global last_model_id
160162 global schedulers
161163 global dummy_safety_checker
164+ global last_xformers_memory_efficient_attention
162165
163166 print (json .dumps (truncateInputs (all_inputs ), indent = 2 ))
164167 model_inputs = all_inputs .get ("modelInputs" , None )
165168 call_inputs = all_inputs .get ("callInputs" , None )
166- startRequestId = call_inputs .get ("startRequestId" , None )
167169
168- # Fallback until all clients on new code
169- if model_inputs == None :
170- return {"$error" : "UPGRADE CLIENT - no model_inputs specified" }
170+ if model_inputs == None or call_inputs == None :
171+ return {
172+ "$error" : {
173+ "code" : "INVALID_INPUTS" ,
174+                "message" : "Expecting an object like { modelInputs: {}, callInputs: {} } but got "
175+ + json .dumps (all_inputs ),
176+ }
177+ }
178+
179+ startRequestId = call_inputs .get ("startRequestId" , None )
171180
172181 model_id = call_inputs .get ("MODEL_ID" )
173182 if MODEL_ID == "ALL" :
@@ -273,22 +282,25 @@ def inference(all_inputs: dict) -> dict:
273282 mask = mask .repeat (8 , axis = 0 ).repeat (8 , axis = 1 )
274283 model_inputs ["mask_image" ] = PIL .Image .fromarray (mask )
275284
285+ # Turning on takes 3ms and turning off 1ms... don't worry, I've got your back :)
276286 x_m_e_a = call_inputs .get ("xformers_memory_efficient_attention" , None )
277- if x_m_e_a == None :
278- pipeline .enable_xformers_memory_efficient_attention () # default on
279- elif x_m_e_a == True :
280- pipeline .enable_xformers_memory_efficient_attention ()
281- elif x_m_e_a == False :
282- pipeline .disable_xformers_memory_efficient_attention ()
283- else :
284- return {
285- "$error" : {
286- "code" : "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE" ,
287- "message" : f'Model "{ model_id } " not available on this container which hosts "{ MODEL_ID } "' ,
288- "requested" : x_m_e_a ,
289- "available" : [True , False ],
287+ if x_m_e_a != last_xformers_memory_efficient_attention :
288+ last_xformers_memory_efficient_attention = x_m_e_a
289+ if x_m_e_a == None :
290+ pipeline .enable_xformers_memory_efficient_attention () # default on
291+ elif x_m_e_a == True :
292+ pipeline .enable_xformers_memory_efficient_attention ()
293+ elif x_m_e_a == False :
294+ pipeline .disable_xformers_memory_efficient_attention ()
295+ else :
296+ return {
297+ "$error" : {
298+ "code" : "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE" ,
299+                "message" : "Invalid value for xformers_memory_efficient_attention; expected true, false, or unset" ,
300+ "requested" : x_m_e_a ,
301+ "available" : [True , False ],
302+ }
290303 }
291- }
292304
293305 # Run the model
294306 # with autocast("cuda"):
0 commit comments