2020FIXTURES = TESTS + os .sep + "fixtures"
2121OUTPUT = TESTS + os .sep + "output"
2222Path (OUTPUT ).mkdir (parents = True , exist_ok = True )
23- DEFAULT_MODEL_ID = "stabilityai/stable-diffusion-2"
2423
2524
2625def b64encode_file (filename : str ):
@@ -37,13 +36,27 @@ def output_path(filename: str):
3736 return os .path .join (OUTPUT , filename )
3837
3938
39+ # https://stackoverflow.com/a/1094933/1839099
40+ def sizeof_fmt (num , suffix = "B" ):
41+ for unit in ["" , "Ki" , "Mi" , "Gi" , "Ti" , "Pi" , "Ei" , "Zi" ]:
42+ if abs (num ) < 1024.0 :
43+ return f"{ num :3.1f} { unit } { suffix } "
44+ num /= 1024.0
45+ return f"{ num :.1f} Yi{ suffix } "
46+
47+
4048def decode_and_save (image_byte_string : str , name : str ):
4149 image_encoded = image_byte_string .encode ("utf-8" )
4250 image_bytes = BytesIO (base64 .b64decode (image_encoded ))
4351 image = Image .open (image_bytes )
4452 fp = output_path (name + ".png" )
4553 image .save (fp )
4654 print ("Saved " + fp )
55+ size_formatted = sizeof_fmt (os .path .getsize (fp ))
56+
57+ return (
58+ f"[{ image .width } x{ image .height } { image .format } image, { size_formatted } bytes]"
59+ )
4760
4861
4962all_tests = {}
@@ -56,10 +69,14 @@ def test(name, inputs):
5669
5770def runTest (name , banana , extraCallInputs , extraModelInputs ):
5871 inputs = all_tests .get (name )
72+ if not inputs .get ("callInputs" , None ):
73+ inputs .update ({"callInputs" : {}})
5974 inputs .get ("callInputs" ).update (extraCallInputs )
6075 inputs .get ("modelInputs" ).update (extraModelInputs )
6176
6277 print ("Running test: " + name )
78+ print (json .dumps (inputs , indent = 4 ))
79+ print ()
6380
6481 start = time .time ()
6582 if banana :
@@ -139,18 +156,35 @@ def runTest(name, banana, extraCallInputs, extraModelInputs):
139156 result .get ("images_base64" , None ) == None
140157 and result .get ("image_base64" , None ) == None
141158 ):
159+ error = result .get ("$error" , None )
160+ if error :
161+ code = error .get ("code" , None )
162+ name = error .get ("name" , None )
163+ message = error .get ("message" , None )
164+ stack = error .get ("stack" , None )
165+ if code and name and message and stack :
166+ print ()
167+ title = f"Exception { code } on container:"
168+ print (title )
169+ print ("-" * len (title ))
170+ # print(f'{name}("{message}")') - stack includes it.
171+ print (stack )
172+ return
173+
142174 print (json .dumps (result , indent = 4 ))
143175 print ()
144176 return
145177
146178 images_base64 = result .get ("images_base64" , None )
147179 if images_base64 :
148180 for idx , image_byte_string in enumerate (images_base64 ):
149- decode_and_save (image_byte_string , f"{ name } _{ idx } " )
181+ images_base64 [ idx ] = decode_and_save (image_byte_string , f"{ name } _{ idx } " )
150182 else :
151- decode_and_save (result ["image_base64" ], name )
183+ result [ "image_base64" ] = decode_and_save (result ["image_base64" ], name )
152184
153185 print ()
186+ print (json .dumps (result , indent = 4 ))
187+ print ()
154188
155189
156190test (
@@ -161,10 +195,10 @@ def runTest(name, banana, extraCallInputs, extraModelInputs):
161195 "num_inference_steps" : 20 ,
162196 },
163197 "callInputs" : {
164- "MODEL_ID" : DEFAULT_MODEL_ID ,
165- "PIPELINE" : "StableDiffusionPipeline" ,
166- "SCHEDULER" : "DPMSolverMultistepScheduler" ,
167- # "xformers_memory_efficient_attention": False,
198+ # "MODEL_ID": "<override_default>", # (default)
199+ # "PIPELINE": "StableDiffusionPipeline", # (default)
200+ # "SCHEDULER": "DPMSolverMultistepScheduler", # (default)
201+ # "xformers_memory_efficient_attention": False, # (default)
168202 },
169203 },
170204)
@@ -176,12 +210,7 @@ def runTest(name, banana, extraCallInputs, extraModelInputs):
176210 "modelInputs" : {
177211 "prompt" : "realistic field of grass" ,
178212 "num_images_per_prompt" : 2 ,
179- },
180- "callInputs" : {
181- "MODEL_ID" : DEFAULT_MODEL_ID ,
182- "PIPELINE" : "StableDiffusionPipeline" ,
183- "SCHEDULER" : "DPMSolverMultistepScheduler" ,
184- },
213+ }
185214 },
186215)
187216
@@ -194,9 +223,7 @@ def runTest(name, banana, extraCallInputs, extraModelInputs):
194223 "init_image" : b64encode_file ("sketch-mountains-input.jpg" ),
195224 },
196225 "callInputs" : {
197- "MODEL_ID" : DEFAULT_MODEL_ID ,
198226 "PIPELINE" : "StableDiffusionImg2ImgPipeline" ,
199- "SCHEDULER" : "DPMSolverMultistepScheduler" ,
200227 },
201228 },
202229)
@@ -274,9 +301,6 @@ def runTest(name, banana, extraCallInputs, extraModelInputs):
274301 # "push_to_hub": True,
275302 },
276303 "callInputs" : {
277- "MODEL_ID" : DEFAULT_MODEL_ID ,
278- "PIPELINE" : "StableDiffusionPipeline" ,
279- "SCHEDULER" : "DDPMScheduler" ,
280304 "train" : "dreambooth" ,
281305 # Option 2: store on S3. Note the **s3:///* (x3). See notes below.
282306 # "dest_url": "s3:///bucket/filename.tar.zst".
0 commit comments