Commit 94de022

Fix inference api & add more description on inference engine tutorial (deepspeedai#1711)
1 parent 2662fde commit 94de022

3 files changed

Lines changed: 33 additions & 6 deletions

deepspeed/inference/engine.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ def __init__(self,
             replace_method: the injection method, this can be passed as auto if no injection-policy is defined, in which case the injection is automatic based on the available policies
             quantization_setting:
                 one of None, Tuple(mlp_extra_grouping, quantize_groups), quantize_groups
+            replace_with_kernel_inject: this flag need to be set to true to inject inference kernels for models such as, Bert, GPT2, GPT-Neo and GPT-J. Otherwise,
+            the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection)
         """

         super().__init__()

deepspeed/module_inject/replace_policy.py

Lines changed: 5 additions & 4 deletions
@@ -1,6 +1,7 @@
 from abc import ABC

 import torch
+from torch.nn.parameter import Parameter


 class DSPolicy(ABC):
@@ -66,8 +67,8 @@ def attention(self):
         vw = self.client_module.attention.self.value.weight
         vb = self.client_module.attention.self.value.bias

-        qkvw = torch.cat((qw, kw, vw), dim=0)
-        qkvb = torch.cat((qb, kb, vb), dim=0)
+        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0))
+        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0))

         return self.linear_layer, \
                qkvw, \
@@ -120,7 +121,7 @@ def attention(self):
         kw = self.client_module.attn.attention.k_proj.weight
         vw = self.client_module.attn.attention.v_proj.weight

-        qkvw = torch.cat((qw, kw, vw), dim=0)
+        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0))

         return self.linear_layer, \
                qkvw, \
@@ -164,7 +165,7 @@ def attention(self):
         kw = self.client_module.attn.k_proj.weight
         vw = self.client_module.attn.v_proj.weight

-        qkvw = torch.cat((qw, kw, vw), dim=0)
+        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0))

         return self.linear_layer, \
                qkvw, \
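
Note on the `Parameter` wrapping in the hunks above: the commit does not state why the fused tensors are wrapped, but one plausible reason is that `torch.cat` returns a plain `torch.Tensor` even when all of its inputs are `Parameter` objects, and `torch.nn.Module` refuses to assign a plain tensor over a registered parameter. The sketch below illustrates that behavior only; the `torch.nn.Linear` layers are hypothetical stand-ins for the real query, key and value projections, and nothing here is taken from the DeepSpeed code path.

```python
import torch
from torch.nn.parameter import Parameter

# Hypothetical stand-ins for the q/k/v projections of an attention layer.
hidden = 4
q, k, v = (torch.nn.Linear(hidden, hidden) for _ in range(3))

# torch.cat returns a plain Tensor, even though every input is a Parameter.
fused = torch.cat((q.weight, k.weight, v.weight), dim=0)
fused_is_parameter = isinstance(fused, Parameter)

# Wrapping the result yields a leaf Parameter of shape (3 * hidden, hidden).
qkvw = Parameter(fused)

# Assigning a plain tensor over a registered parameter raises TypeError;
# assigning a Parameter is accepted.
target = torch.nn.Linear(hidden, 3 * hidden)
try:
    target.weight = fused
    plain_assignment_ok = True
except TypeError:
    plain_assignment_ok = False
target.weight = qkvw
```

Whether this is the exact failure the commit addresses is an assumption; the sketch only shows that the wrapped and unwrapped results are treated differently by `torch.nn.Module`.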

docs/_tutorials/inference-tutorial.md

Lines changed: 26 additions & 2 deletions
@@ -8,7 +8,9 @@ DeepSpeed provides a seamless inference mode for compatible transformer based mo

 ## Initializing for Inference

-For inference with DeepSpeed, use `init_inference` API to load the model for inference. Here, you can specify the MP degree, and if the model has not been loaded with the appropriate checkpoint, you can also provide the checkpoint description using a `json` file. To inject the high-performance kernels, you can pass int the `replace_method` as `'auto'` for the compatible models, or define a new policy in [replace_policy class](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py) and pass in the `injection_policy` that specifies the different parameters of a Transformer layer, such as attention and feed-forward parts. The `injection_policy` shows the mapping between the parameters of the original layer implementation with the inference-customized Transformer layer.
+For inference with DeepSpeed, use `init_inference` API to load the model for inference. Here, you can specify the MP degree, and if the model has not been loaded with the appropriate checkpoint, you can also provide the checkpoint description using a `json` file or the checkpoint path.
+
+To inject the high-performance kernels, you need to set the `replace_with_kernel_inject` to True and pass int the `replace_method` as `'auto'` for the compatible models, or define a new policy in [replace_policy class](https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py) and pass in the `injection_policy` that specifies the different parameters of a Transformer layer, such as attention and feed-forward parts. The `injection_policy` shows the mapping between the parameters of the original layer implementation with the inference-customized Transformer layer.

 ```python
 # create the model
@@ -25,11 +27,33 @@ ds_engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.half,
                                  checkpoint=None if args.pre_load_checkpoint else args.checkpoint_json,
-                                 replace_method='auto')
+                                 replace_method='auto',
+                                 replace_with_kernel_inject=True)
 model = ds_engine.module
 output = model('Input String')
 ```

+To run inference with only model-parallelism for the models that we don't support kernels, you can pass an injection policy that shows the two specific linear layers on a Transformer Encoder/Decoder layer: 1) the attention output GeMM and 2) layer output GeMM. We need these part of the layer to add the required all-reduce communication between GPUs to merge the partial results across model-parallel ranks. Below, we bring an example that shows how you can use deepspeed-inference with a T5 model:
+
+
+```python
+# create the model
+import transformers
+from transformers.models.t5.modeling_t5 import T5Block
+
+import deepspeed
+
+pipe = pipeline("text2text-generation", model="google/t5-v1_1-small", device=local_rank)
+# Initialize the DeepSpeed-Inference engine
+pipe.model = deepspeed.init_inference(
+    pipe.model,
+    mp_size=world_size,
+    dtype=torch.float,
+    injection_policy={T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')}
+)
+output = pipe('Input String')
+```
+
 ## Loading Checkpoints

 For the models trained using HuggingFace, the model checkpoint can be pre-loaded using the `from_pretrained` API as shown above. For Megatron-LM models trained with model parallelism, we require a list of all the model parallel checkpoints passed in JSON config. Below we show how to load a Megatron-LM checkpoint trained using MP=2.
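
Note on the model-parallel paragraph added in this file (the attention output GeMM and the layer output GeMM): the tutorial text says these layers need an all-reduce to merge partial results across ranks. One way to see why, assuming the layer's weight is split along its input dimension, is that each rank then computes only a partial sum, and adding the partials element-wise recovers the full output. The pure-Python sketch below simulates two ranks in one process; `matvec`, `shard_input_dim`, and the sample values are hypothetical and are not DeepSpeed APIs.

```python
def matvec(weight, x):
    """Return weight @ x for a row-major list-of-lists matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def shard_input_dim(weight, x, ranks):
    """Give each simulated rank a slice of the input dimension.

    Each shard holds the matching columns of the weight and the
    matching entries of the input vector.
    """
    width = len(x) // ranks
    shards = []
    for rank in range(ranks):
        lo, hi = rank * width, (rank + 1) * width
        shards.append(([row[lo:hi] for row in weight], x[lo:hi]))
    return shards


# A 2x4 weight (output dim 2, input dim 4) and one input vector.
weight = [[1, 2, 3, 4],
          [5, 6, 7, 8]]
x = [1, 0, 2, 1]

# Reference result computed on a single device.
full = matvec(weight, x)

# Each simulated rank computes a partial output from its shard.
partials = [matvec(w, xs) for w, xs in shard_input_dim(weight, x, ranks=2)]

# The all-reduce step, simulated as an element-wise sum over ranks.
reduced = [sum(values) for values in zip(*partials)]
```

With these values the two partial outputs are `[1, 5]` and `[10, 22]`, and their sum `[11, 27]` matches the single-device result. In a real multi-GPU run the sum would be performed by a collective communication call rather than a local loop.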
