Skip to content

Commit e0db814

Browse files
authored
Add device comprehension to BLOOM Pipeline utility class (deepspeedai#217)
This PR adds device comprehension to the BLOOM Pipeline utility class to expand support for devices and also support the case where the DeepSpeed init_inference API isn't used.
1 parent 6b15629 commit e0db814

2 files changed

Lines changed: 17 additions & 3 deletions

File tree

inference/huggingface/text-generation/test-bloom.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
world_size = int(os.getenv('WORLD_SIZE', '1'))
1616

1717
pipe = Pipeline(model_name=model_name,
18-
dtype=dtype
18+
dtype=dtype,
19+
is_meta=True,
20+
device=local_rank
1921
)
2022

2123
pipe.model = deepspeed.init_inference(

inference/huggingface/text-generation/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,21 @@ class Pipeline():
1616
def __init__(self,
1717
model_name='bigscience/bloom-3b',
1818
dtype=torch.float16,
19-
is_meta=True
19+
is_meta=True,
20+
device=-1
2021
):
2122
self.model_name = model_name
2223
self.dtype = dtype
2324

25+
if isinstance(device, torch.device):
26+
self.device = device
27+
elif isinstance(device, str):
28+
self.device = torch.device(device)
29+
elif device < 0:
30+
self.device = torch.device("cpu")
31+
else:
32+
self.device = torch.device(f"cuda:{device}")
33+
2434
# the Deepspeed team made these so it's super fast to load (~1 minute), rather than wait 10-20min loading time.
2535
self.tp_presharded_models = ["microsoft/bloom-deepspeed-inference-int8", "microsoft/bloom-deepspeed-inference-fp16"]
2636

@@ -78,7 +88,9 @@ def generate_outputs(self,
7888
input_tokens = self.tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
7989
for t in input_tokens:
8090
if torch.is_tensor(input_tokens[t]):
81-
input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
91+
input_tokens[t] = input_tokens[t].to(self.device)
92+
93+
self.model.cuda().to(self.device)
8294

8395
outputs = self.model.generate(**input_tokens, **generate_kwargs)
8496
outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

0 commit comments

Comments
 (0)