|
34 | 34 | INFERENCE_MODEL_TIMER = "model-forward-inference" |
35 | 35 |
|
36 | 36 |
|
| 37 | +def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: |
| 38 | + """ |
| 39 | + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it |
| 40 | + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value |
| 41 | + `softmax(l+a) = softmax(l)`. Based on |
| 42 | + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 |
| 43 | + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. |
| 44 | +
|
| 45 | + Args: |
| 46 | + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) |
| 47 | + attention_mask (`torch.Tensor`): |
| 48 | + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). |
| 49 | + num_heads (`int`, *required*): |
| 50 | + number of heads |
| 51 | + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): |
| 52 | + dtype of the output tensor |
| 53 | + """ |
| 54 | + import math |
| 55 | + batch_size, seq_length = attention_mask.shape |
| 56 | + closest_power_of_2 = 2**math.floor(math.log2(num_heads)) |
| 57 | + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), |
| 58 | + device=attention_mask.device, |
| 59 | + dtype=torch.float32) |
| 60 | + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) |
| 61 | + slopes = torch.pow(base, powers) |
| 62 | + |
| 63 | + if closest_power_of_2 != num_heads: |
| 64 | + extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), |
| 65 | + device=attention_mask.device, |
| 66 | + dtype=torch.float32) |
| 67 | + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) |
| 68 | + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) |
| 69 | + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) |
| 70 | + |
| 71 | + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention |
| 72 | + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) |
| 73 | + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) |
| 74 | + # => the query_length dimension will then be broadcasted correctly |
| 75 | + # This is more or less identical to T5's relative position bias: |
| 76 | + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 |
| 77 | + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] |
| 78 | + alibi = slopes[..., None] * arange_tensor |
| 79 | + if dist.is_initialized(): |
| 80 | + num_heads_per_rank = int(num_heads / dist.get_world_size()) |
| 81 | + offset = dist.get_rank() * num_heads_per_rank |
| 82 | + alibi = alibi.view(batch_size, num_heads, 1, seq_length) |
| 83 | + alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] |
| 84 | + return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) |
| 85 | + else: |
| 86 | + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) |
| 87 | + |
| 88 | + |
37 | 89 | class InferenceEngine(Module): |
38 | 90 | inference_mp_group = None |
39 | 91 | inference_ep_group = None |
@@ -86,8 +138,14 @@ def __init__(self, model, config): |
86 | 138 | self.model_profile_enabled = False |
87 | 139 | self._model_times = [] |
88 | 140 |
|
89 | | - # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture |
90 | | - self.remove_mask_prepare_for_bloom() |
| 141 | + if not self.injection_dict and config.replace_with_kernel_inject: |
| 142 | + # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture |
| 143 | + self.remove_mask_prepare_for_bloom() |
| 144 | + |
| 145 | + if self.injection_dict or not config.replace_with_kernel_inject: |
| 146 | + # This is a hack to redefine the alibi func due to TP |
| 147 | + if config.tensor_parallel.tp_size > 1: |
| 148 | + self.build_alibi_tensor() |
91 | 149 |
|
92 | 150 | if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph: |
93 | 151 | assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ |
@@ -178,6 +236,11 @@ def remove_mask_prepare_for_bloom(self): |
178 | 236 | if hasattr(self.module.transformer, '_prepare_attn_mask'): |
179 | 237 | self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask |
180 | 238 |
|
| 239 | + def build_alibi_tensor(self): |
| 240 | + if hasattr(self.module, 'transformer'): |
| 241 | + if hasattr(self.module.transformer, 'build_alibi_tensor'): |
| 242 | + self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor |
| 243 | + |
181 | 244 | def _pre_forward_hook(self, module, *inputs, **kwargs): |
182 | 245 | if self.use_cuda_events: |
183 | 246 | self.timers(INFERENCE_MODEL_TIMER).start() |
|
0 commit comments