Skip to content

Commit 7d0260b

Browse files
awan-10lekurile
andauthored
Add BLOOM huggingface inference example (deepspeedai#211)
This PR adds a bloom inference example (bigscience/bloom-3b) and a corresponding helper Pipeline class meant to mimic the functionality and API of the huggingface pipelines. This class was added in order to comprehend bloom meta tensors and checkpoint loading in a more organized way, that closely matched the existing examples. This PR also cleans up extra whitespace across the inference examples. Co-authored-by: Lev Kurilenko <lekurile@microsoft.com>
1 parent 8c1339c commit 7d0260b

6 files changed

Lines changed: 119 additions & 11 deletions

File tree

inference/huggingface/automatic-speech-recognition/test-wav2vec2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
dtype=torch.float,
2828
injection_policy={Wav2Vec2EncoderLayer: ('attention.out_proj','feed_forward.output_dense')},
2929
replace_with_kernel_inject=False)
30-
model.to(f'cuda:{local_rank}')
30+
model.to(f'cuda:{local_rank}')
3131
def map_to_array(batch):
3232
speech, _ = sf.read(batch["file"])
3333
batch["speech"] = speech

inference/huggingface/text-generation/run-generation-script/test-run-generation.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def main():
193193
required=False,
194194
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
195195
)
196-
196+
197197
parser.add_argument("--prompt", type=str, default="")
198198
parser.add_argument("--length", type=int, default=20)
199199
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
@@ -214,7 +214,7 @@ def main():
214214
parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
215215
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
216216

217-
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
217+
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
218218
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
219219
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
220220
parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
@@ -235,7 +235,7 @@ def main():
235235
args.n_gpu,
236236
args.fp16,
237237
)
238-
238+
239239
set_seed(args)
240240

241241
# Initialize the model and tokenizer
@@ -256,9 +256,9 @@ def main():
256256
if args.ds_inference:
257257
import deepspeed.module_inject as module_inject
258258
import deepspeed
259-
injection_policy={gpt2_transformer:
259+
injection_policy={gpt2_transformer:
260260
module_inject.replace_policy.HFGPT2LayerPolicy}
261-
model = deepspeed.init_inference(model,
261+
model = deepspeed.init_inference(model,
262262
mp_size=1,
263263
dtype=(torch.half if args.fp16 else torch.float),
264264
injection_policy=injection_policy,
@@ -293,7 +293,7 @@ def main():
293293
prefix = args.prefix if args.prefix else args.padding_text
294294
for ppt in prompt_text:
295295
eprompt.append(tokenizer.encode(prefix + ppt, add_special_tokens=False, return_tensors="pt"))
296-
296+
297297
latencies = []
298298
for encoded_prompt, ppt in zip(eprompt, prompt_text):
299299
encoded_prompt = encoded_prompt.to(args.device)
@@ -302,10 +302,10 @@ def main():
302302
input_ids = None
303303
else:
304304
input_ids = encoded_prompt
305-
305+
306306
torch.cuda.synchronize()
307307
t0 = time.time()
308-
308+
309309
output_sequences = model.generate(
310310
input_ids=input_ids,
311311
max_length=args.length + len(encoded_prompt[0]),
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
import torch
3+
import deepspeed
4+
import transformers
5+
6+
# Pipeline class to mimic HF pipeline
7+
from utils import Pipeline
8+
9+
model_name = 'bigscience/bloom-3b'
10+
dtype = torch.float16
11+
num_tokens = 100
12+
13+
# Get local gpu rank from torch.distributed/deepspeed launcher
14+
local_rank = int(os.getenv('LOCAL_RANK', '0'))
15+
world_size = int(os.getenv('WORLD_SIZE', '1'))
16+
17+
pipe = Pipeline(model_name=model_name,
18+
dtype=dtype
19+
)
20+
21+
pipe.model = deepspeed.init_inference(
22+
pipe.model,
23+
mp_size=world_size,
24+
dtype=dtype,
25+
replace_with_kernel_inject=True,
26+
base_dir=pipe.repo_root,
27+
checkpoint=pipe.checkpoints_json
28+
)
29+
30+
output = pipe('DeepSpeed is', num_tokens=num_tokens, do_sample=False)
31+
print(output)

inference/huggingface/text-generation/test-gpt2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@
2727
mp_size=world_size,
2828
dtype=torch.half,
2929
replace_with_kernel_inject=True)
30-
30+
3131
string = generator("DeepSpeed is", min_length=50, max_length=50, do_sample=True, use_cache=True)
3232
print(string)

inference/huggingface/text-generation/test-gptj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@
2727
mp_size=world_size,
2828
dtype=torch.half,
2929
replace_with_kernel_inject=True)
30-
30+
3131
string = generator("DeepSpeed is", min_length=50, max_length=50, do_sample=True, use_cache=True)
3232
print(string)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
'''
2+
Helper classes and functions for examples
3+
'''
4+
5+
import io
6+
from pathlib import Path
7+
import json
8+
import deepspeed
9+
import torch
10+
from huggingface_hub import snapshot_download
11+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
12+
13+
class Pipeline():
14+
'''Example helper class, meant to mimic HF pipelines'''
15+
def __init__(self,
16+
model_name='bigscience/bloom-3b',
17+
dtype=torch.float16,
18+
is_meta=True
19+
):
20+
self.model_name = model_name
21+
self.dtype = dtype
22+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23+
24+
if (is_meta):
25+
'''When meta tensors enabled, use checkpoints'''
26+
self.config = AutoConfig.from_pretrained(self.model_name)
27+
self.repo_root, self.checkpoints_json = self.generate_json()
28+
29+
with deepspeed.OnDevice(dtype=self.dtype, device="meta"):
30+
self.model = AutoModelForCausalLM.from_config(self.config, torch_dtype=self.dtype)
31+
else:
32+
self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
33+
34+
self.model.eval()
35+
36+
37+
def __call__(self,
38+
inputs=["test"],
39+
num_tokens=100,
40+
do_sample=False):
41+
if isinstance(inputs, str):
42+
input_list = [inputs]
43+
else:
44+
input_list = inputs
45+
46+
outputs = self.generate_outputs(input_list, num_tokens=num_tokens, do_sample=do_sample)
47+
return outputs
48+
49+
50+
def generate_json(self):
51+
repo_root = snapshot_download(self.model_name, allow_patterns=["*"], local_files_only=False, revision=None)
52+
53+
checkpoints_json = "checkpoints.json"
54+
55+
with io.open(checkpoints_json, "w", encoding="utf-8") as f:
56+
file_list = [str(entry) for entry in Path(repo_root).rglob("*.[bp][it][n]") if entry.is_file()]
57+
data = {"type": self.config.model_type, "checkpoints": file_list, "version": 1.0}
58+
json.dump(data, f)
59+
60+
return repo_root, checkpoints_json
61+
62+
63+
def generate_outputs(self,
64+
inputs=["test"],
65+
num_tokens=100,
66+
do_sample=False):
67+
generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=do_sample)
68+
69+
input_tokens = self.tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
70+
for t in input_tokens:
71+
if torch.is_tensor(input_tokens[t]):
72+
input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
73+
74+
outputs = self.model.generate(**input_tokens, **generate_kwargs)
75+
outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
76+
77+
return outputs

0 commit comments

Comments
 (0)