 from getScheduler import getScheduler, SCHEDULERS
 import re
 import requests
+from download import download_model
 
+RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
 USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
 if USE_DREAMBOOTH:
     from train_dreambooth import TrainDreamBooth
@@ -99,10 +101,11 @@ def init():
         last_model_id = None
         return
 
-    model = loadModel(MODEL_ID)
+    if not RUNTIME_DOWNLOADS:
+        model = loadModel(MODEL_ID)
 
-    if PIPELINE == "ALL":
-        pipelines = createPipelinesFromModel(model)
+        if PIPELINE == "ALL":
+            pipelines = createPipelinesFromModel(model)
 
     send("init", "done")
     initTime = get_now() - initStart
@@ -113,6 +116,7 @@ def decodeBase64Image(imageStr: str, name: str) -> PIL.Image:
     print(f'Decoded image "{name}": {image.format} {image.width}x{image.height}')
     return image
 
+
 def getFromUrl(url: str, name: str) -> PIL.Image:
     response = requests.get(url)
     image = PIL.Image.open(BytesIO(response.content))
@@ -135,6 +139,7 @@ def truncateInputs(inputs: dict):
 
 
 last_xformers_memory_efficient_attention = {}
+downloaded_models = {}
 
 # Inference is ran for every server call
 # Reference your preloaded global model variable here.
@@ -162,13 +167,31 @@ def inference(all_inputs: dict) -> dict:
     startRequestId = call_inputs.get("startRequestId", None)
 
     model_id = call_inputs.get("MODEL_ID")
+
+    if RUNTIME_DOWNLOADS:
+        global downloaded_models
+        if not downloaded_models.get(model_id, None):
+            model_url = call_inputs.get("MODEL_URL", None)
+            if not model_url:
+                return {
+                    "$error": {
+                        "code": "NO_MODEL_URL",
+                        "message": "Currently RUNTIME_DOWNLOADS requires a MODEL_URL callInput",
+                    }
+                }
+            download_model(model_id=model_id, model_url=model_url)
+            downloaded_models.update({model_id: True})
+            model = loadModel(model_id)
+            if PIPELINE == "ALL":
+                pipelines = createPipelinesFromModel(model)
+
     if MODEL_ID == "ALL":
         if last_model_id != model_id:
             model = loadModel(model_id)
             pipelines = createPipelinesFromModel(model)
             last_model_id = model_id
     else:
-        if model_id != MODEL_ID:
+        if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:
             return {
                 "$error": {
                     "code": "MODEL_MISMATCH",
@@ -183,7 +206,7 @@ def inference(all_inputs: dict) -> dict:
     else:
         pipeline = model
 
-    pipeline.scheduler = getScheduler(MODEL_ID, call_inputs.get("SCHEDULER", None))
+    pipeline.scheduler = getScheduler(model_id, call_inputs.get("SCHEDULER", None))
     if pipeline.scheduler == None:
         return {
             "$error": {
@@ -263,7 +286,7 @@ def inference(all_inputs: dict) -> dict:
             return {
                 "$error": {
                     "code": "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE",
-                    "message": f'Model "{model_id}" not available on this container which hosts "{MODEL_ID}"',
+                    "message": f"x_m_e_a expects True or False, not: {x_m_e_a}",
                     "requested": x_m_e_a,
                     "available": [True, False],
                 }
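
Usage note: below is a minimal sketch of what a client request might look like once this change ships and the container is started with RUNTIME_DOWNLOADS=1. Only MODEL_ID, MODEL_URL, and the NO_MODEL_URL error come from the diff above; the payload shape ("callInputs" / "modelInputs" keys), the endpoint URL, and the example model id and download URL are illustrative assumptions, not part of this commit.

# Hypothetical client call against a container started with RUNTIME_DOWNLOADS=1.
# Endpoint URL, payload keys, and example values are assumptions for illustration.
import requests

payload = {
    "callInputs": {
        # On first use of this MODEL_ID the server calls download_model() and then
        # loadModel(); omitting MODEL_URL triggers the NO_MODEL_URL error above.
        "MODEL_ID": "stabilityai/stable-diffusion-2-1",       # assumed example id
        "MODEL_URL": "https://example.com/models/sd-2-1.tar",  # assumed example URL
    },
    "modelInputs": {"prompt": "a watercolor painting of a lighthouse"},
}

response = requests.post("http://localhost:8000/", json=payload)  # assumed local endpoint
print(response.json())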