Skip to content

Commit 6b15629

Browse files
authored
Add presharded model support to BLOOM example (deepspeedai#213)
Add presharded model support to BLOOM example Co-authored-by: Lev Kurilenko <lekurile@microsoft.com>
1 parent 7d0260b commit 6b15629

1 file changed

Lines changed: 18 additions & 9 deletions

File tree

  • inference/huggingface/text-generation

inference/huggingface/text-generation/utils.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Helper classes and functions for examples
33
'''
44

5+
import os
56
import io
67
from pathlib import Path
78
import json
@@ -19,15 +20,19 @@ def __init__(self,
1920
):
2021
self.model_name = model_name
2122
self.dtype = dtype
23+
24+
# the Deepspeed team made these so it's super fast to load (~1 minute), rather than wait 10-20min loading time.
25+
self.tp_presharded_models = ["microsoft/bloom-deepspeed-inference-int8", "microsoft/bloom-deepspeed-inference-fp16"]
26+
2227
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
2328

2429
if (is_meta):
2530
'''When meta tensors enabled, use checkpoints'''
2631
self.config = AutoConfig.from_pretrained(self.model_name)
27-
self.repo_root, self.checkpoints_json = self.generate_json()
32+
self.repo_root, self.checkpoints_json = self._generate_json()
2833

29-
with deepspeed.OnDevice(dtype=self.dtype, device="meta"):
30-
self.model = AutoModelForCausalLM.from_config(self.config, torch_dtype=self.dtype)
34+
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
35+
self.model = AutoModelForCausalLM.from_config(self.config)
3136
else:
3237
self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
3338

@@ -47,15 +52,19 @@ def __call__(self,
4752
return outputs
4853

4954

50-
def generate_json(self):
55+
def _generate_json(self):
5156
repo_root = snapshot_download(self.model_name, allow_patterns=["*"], local_files_only=False, revision=None)
5257

53-
checkpoints_json = "checkpoints.json"
58+
if (self.model_name in self.tp_presharded_models):
59+
# tp presharded repos come with their own checkpoints config file
60+
checkpoints_json = os.path.join(repo_root, "ds_inference_config.json")
61+
else:
62+
checkpoints_json = "checkpoints.json"
5463

55-
with io.open(checkpoints_json, "w", encoding="utf-8") as f:
56-
file_list = [str(entry) for entry in Path(repo_root).rglob("*.[bp][it][n]") if entry.is_file()]
57-
data = {"type": self.config.model_type, "checkpoints": file_list, "version": 1.0}
58-
json.dump(data, f)
64+
with io.open(checkpoints_json, "w", encoding="utf-8") as f:
65+
file_list = [str(entry) for entry in Path(repo_root).rglob("*.[bp][it][n]") if entry.is_file()]
66+
data = {"type": self.config.model_type, "checkpoints": file_list, "version": 1.0}
67+
json.dump(data, f)
5968

6069
return repo_root, checkpoints_json
6170

0 commit comments

Comments
 (0)