Skip to content

Commit 6ba0024

Browse files
sywangyitjruwase
andauthored
Enable autoTP for bloom (deepspeedai#3035)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
1 parent 514b020 commit 6ba0024

2 files changed

Lines changed: 66 additions & 3 deletions

File tree

deepspeed/inference/engine.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,58 @@
3434
INFERENCE_MODEL_TIMER = "model-forward-inference"
3535

3636

37+
def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
38+
"""
39+
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
40+
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
41+
`softmax(l+a) = softmax(l)`. Based on
42+
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
43+
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
44+
45+
Args:
46+
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
47+
attention_mask (`torch.Tensor`):
48+
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
49+
num_heads (`int`, *required*):
50+
number of heads
51+
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
52+
dtype of the output tensor
53+
"""
54+
import math
55+
batch_size, seq_length = attention_mask.shape
56+
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
57+
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
58+
device=attention_mask.device,
59+
dtype=torch.float32)
60+
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
61+
slopes = torch.pow(base, powers)
62+
63+
if closest_power_of_2 != num_heads:
64+
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
65+
device=attention_mask.device,
66+
dtype=torch.float32)
67+
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
68+
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
69+
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
70+
71+
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
72+
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
73+
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
74+
# => the query_length dimension will then be broadcasted correctly
75+
# This is more or less identical to T5's relative position bias:
76+
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
77+
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
78+
alibi = slopes[..., None] * arange_tensor
79+
if dist.is_initialized():
80+
num_heads_per_rank = int(num_heads / dist.get_world_size())
81+
offset = dist.get_rank() * num_heads_per_rank
82+
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
83+
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
84+
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
85+
else:
86+
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
87+
88+
3789
class InferenceEngine(Module):
3890
inference_mp_group = None
3991
inference_ep_group = None
@@ -86,8 +138,14 @@ def __init__(self, model, config):
86138
self.model_profile_enabled = False
87139
self._model_times = []
88140

89-
# This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
90-
self.remove_mask_prepare_for_bloom()
141+
if not self.injection_dict and config.replace_with_kernel_inject:
142+
# This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
143+
self.remove_mask_prepare_for_bloom()
144+
145+
if self.injection_dict or not config.replace_with_kernel_inject:
146+
# This is a hack to redefine the alibi func due to TP
147+
if config.tensor_parallel.tp_size > 1:
148+
self.build_alibi_tensor()
91149

92150
if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
93151
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
@@ -178,6 +236,11 @@ def remove_mask_prepare_for_bloom(self):
178236
if hasattr(self.module.transformer, '_prepare_attn_mask'):
179237
self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
180238

239+
def build_alibi_tensor(self):
240+
if hasattr(self.module, 'transformer'):
241+
if hasattr(self.module.transformer, 'build_alibi_tensor'):
242+
self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
243+
181244
def _pre_forward_hook(self, module, *inputs, **kwargs):
182245
if self.use_cuda_events:
183246
self.timers(INFERENCE_MODEL_TIMER).start()

deepspeed/module_inject/auto_tp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_module_list(model):
3232
return mlist
3333

3434
def supported(model):
35-
unsupported = ['bloom', 'codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
35+
unsupported = ['codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
3636
model = str(model)
3737
key = re.search(r": (.*?)Model", model)
3838
if key is None:

0 commit comments

Comments
 (0)