22Helper classes and functions for examples
33'''
44
5+ import os
56import io
67from pathlib import Path
78import json
@@ -19,15 +20,19 @@ def __init__(self,
1920 ):
2021 self .model_name = model_name
2122 self .dtype = dtype
23+
24+ # the Deepspeed team made these so it's super fast to load (~1 minute), rather than wait 10-20min loading time.
25+ self .tp_presharded_models = ["microsoft/bloom-deepspeed-inference-int8" , "microsoft/bloom-deepspeed-inference-fp16" ]
26+
2227 self .tokenizer = AutoTokenizer .from_pretrained (model_name )
2328
2429 if (is_meta ):
2530 '''When meta tensors enabled, use checkpoints'''
2631 self .config = AutoConfig .from_pretrained (self .model_name )
27- self .repo_root , self .checkpoints_json = self .generate_json ()
32+ self .repo_root , self .checkpoints_json = self ._generate_json ()
2833
29- with deepspeed .OnDevice (dtype = self . dtype , device = "meta" ):
30- self .model = AutoModelForCausalLM .from_config (self .config , torch_dtype = self . dtype )
34+ with deepspeed .OnDevice (dtype = torch . float16 , device = "meta" ):
35+ self .model = AutoModelForCausalLM .from_config (self .config )
3136 else :
3237 self .model = AutoModelForCausalLM .from_pretrained (self .model_name )
3338
@@ -47,15 +52,19 @@ def __call__(self,
4752 return outputs
4853
4954
def _generate_json(self):
    '''Download the model snapshot and return it with its checkpoints JSON path.

    For repos listed in ``self.tp_presharded_models`` the checkpoints config
    file found inside the downloaded snapshot is used as-is and nothing is
    written. For every other model a ``checkpoints.json`` manifest is written
    to the current working directory, listing the checkpoint files found
    under the snapshot.

    Returns:
        A ``(repo_root, checkpoints_json)`` tuple: the directory returned by
        ``snapshot_download`` and the path of the checkpoints JSON file.
    '''
    repo_root = snapshot_download(self.model_name, allow_patterns=["*"], local_files_only=False, revision=None)

    if self.model_name in self.tp_presharded_models:
        # tp presharded repos come with their own checkpoints config file,
        # so return its path without (over)writing anything.
        return repo_root, os.path.join(repo_root, "ds_inference_config.json")

    checkpoints_json = "checkpoints.json"

    # Collect the manifest contents before opening the output file so that a
    # failure while scanning does not leave a truncated checkpoints.json.
    # NOTE(review): the glob needs a three-character suffix (e.g. ".bin");
    # it does not match ".pt" -- confirm that is intended.
    file_list = [str(entry) for entry in Path(repo_root).rglob("*.[bp][it][n]") if entry.is_file()]
    data = {"type": self.config.model_type, "checkpoints": file_list, "version": 1.0}

    with io.open(checkpoints_json, "w", encoding="utf-8") as f:
        json.dump(data, f)

    return repo_root, checkpoints_json
6170
0 commit comments