44from diffusers import schedulers as _schedulers
55
66HF_AUTH_TOKEN = os .getenv ("HF_AUTH_TOKEN" )
7- DEFAULT_SCHEDULER = os .getenv ("DEFAULT_SCHEDULER" )
87
98SCHEDULERS = [
109 "LMSDiscreteScheduler" ,
1413 "EulerDiscreteScheduler" ,
1514]
1615
16+ DEFAULT_SCHEDULER = os .getenv ("DEFAULT_SCHEDULER" , SCHEDULERS [0 ])
17+
18+
1719"""
1820# This was a nice idea but until we have default init vars for all schedulers
1921# via from_config(), it's a no go. In any case, loading a scheduler takes time
3133"""
3234
3335
34- def initScheduler (MODEL_ID : str , scheduler_id : str ):
36+ def initScheduler (MODEL_ID : str , scheduler_id : str , download = False ):
3537 print (f"Initializing { scheduler_id } for { MODEL_ID } ..." )
3638 start = time .time ()
3739 scheduler = getattr (_schedulers , scheduler_id )
3840 if scheduler == None :
3941 return None
4042
4143 inittedScheduler = scheduler .from_config (
42- MODEL_ID , subfolder = "scheduler" , use_auth_token = HF_AUTH_TOKEN
44+ MODEL_ID ,
45+ subfolder = "scheduler" ,
46+ use_auth_token = HF_AUTH_TOKEN ,
47+ local_files_only = not download ,
4348 )
4449 diff = round ((time .time () - start ) * 1000 )
4550 print (f"Initialized { scheduler_id } for { MODEL_ID } in { diff } ms" )
@@ -50,7 +55,7 @@ def initScheduler(MODEL_ID: str, scheduler_id: str):
5055schedulers = {}
5156
5257
53- def getScheduler (MODEL_ID : str , scheduler_id : str ):
58+ def getScheduler (MODEL_ID : str , scheduler_id : str , download = False ):
5459 schedulersByModel = schedulers .get (MODEL_ID , None )
5560 if schedulersByModel == None :
5661 schedulersByModel = {}
@@ -73,7 +78,7 @@ def getScheduler(MODEL_ID: str, scheduler_id: str):
7378
7479 scheduler = schedulersByModel .get (scheduler_id , None )
7580 if scheduler == None :
76- scheduler = initScheduler (MODEL_ID , scheduler_id )
81+ scheduler = initScheduler (MODEL_ID , scheduler_id , download )
7782 schedulersByModel .update ({scheduler_id : scheduler })
7883
7984 return scheduler
0 commit comments