Skip to content

Commit 7abc4ac

Browse files
committed
feat(app): runtime downloads with MODEL_URL
1 parent c20f013 commit 7abc4ac

1 file changed

Lines changed: 29 additions & 6 deletions

File tree

app.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
from getScheduler import getScheduler, SCHEDULERS
2525
import re
2626
import requests
27+
from download import download_model
2728

29+
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
2830
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
2931
if USE_DREAMBOOTH:
3032
from train_dreambooth import TrainDreamBooth
@@ -99,10 +101,11 @@ def init():
99101
last_model_id = None
100102
return
101103

102-
model = loadModel(MODEL_ID)
104+
if not RUNTIME_DOWNLOADS:
105+
model = loadModel(MODEL_ID)
103106

104-
if PIPELINE == "ALL":
105-
pipelines = createPipelinesFromModel(model)
107+
if PIPELINE == "ALL":
108+
pipelines = createPipelinesFromModel(model)
106109

107110
send("init", "done")
108111
initTime = get_now() - initStart
@@ -113,6 +116,7 @@ def decodeBase64Image(imageStr: str, name: str) -> PIL.Image:
113116
print(f'Decoded image "{name}": {image.format} {image.width}x{image.height}')
114117
return image
115118

119+
116120
def getFromUrl(url: str, name: str) -> PIL.Image:
117121
response = requests.get(url)
118122
image = PIL.Image.open(BytesIO(response.content))
@@ -135,6 +139,7 @@ def truncateInputs(inputs: dict):
135139

136140

137141
last_xformers_memory_efficient_attention = {}
142+
downloaded_models = {}
138143

139144
# Inference is ran for every server call
140145
# Reference your preloaded global model variable here.
@@ -162,13 +167,31 @@ def inference(all_inputs: dict) -> dict:
162167
startRequestId = call_inputs.get("startRequestId", None)
163168

164169
model_id = call_inputs.get("MODEL_ID")
170+
171+
if RUNTIME_DOWNLOADS:
172+
global downloaded_models
173+
if not downloaded_models.get(model_id, None):
174+
model_url = call_inputs.get("MODEL_URL", None)
175+
if not model_url:
176+
return {
177+
"$error": {
178+
"code": "NO_MODEL_URL",
179+
"message": "Currently RUNTIME_DOWNLOADS requires a MODEL_URL callInput",
180+
}
181+
}
182+
download_model(model_id=model_id, model_url=model_url)
183+
downloaded_models.update({model_id: True})
184+
model = loadModel(model_id)
185+
if PIPELINE == "ALL":
186+
pipelines = createPipelinesFromModel(model)
187+
165188
if MODEL_ID == "ALL":
166189
if last_model_id != model_id:
167190
model = loadModel(model_id)
168191
pipelines = createPipelinesFromModel(model)
169192
last_model_id = model_id
170193
else:
171-
if model_id != MODEL_ID:
194+
if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:
172195
return {
173196
"$error": {
174197
"code": "MODEL_MISMATCH",
@@ -183,7 +206,7 @@ def inference(all_inputs: dict) -> dict:
183206
else:
184207
pipeline = model
185208

186-
pipeline.scheduler = getScheduler(MODEL_ID, call_inputs.get("SCHEDULER", None))
209+
pipeline.scheduler = getScheduler(model_id, call_inputs.get("SCHEDULER", None))
187210
if pipeline.scheduler == None:
188211
return {
189212
"$error": {
@@ -263,7 +286,7 @@ def inference(all_inputs: dict) -> dict:
263286
return {
264287
"$error": {
265288
"code": "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE",
266-
"message": f'Model "{model_id}" not available on this container which hosts "{MODEL_ID}"',
289+
"message": f"x_m_e_a expects True or False, not: {x_m_e_a}",
267290
"requested": x_m_e_a,
268291
"available": [True, False],
269292
}

0 commit comments

Comments (0)