diff --git a/fms/models/__init__.py b/fms/models/__init__.py index 1604ddbce..b6eb24c41 100644 --- a/fms/models/__init__.py +++ b/fms/models/__init__.py @@ -500,6 +500,7 @@ def model_wrap(model): gpt_bigcode, gpt_oss, granite, + granite41, granite_moe_hybrid, llama, llava_next, @@ -518,6 +519,7 @@ def model_wrap(model): "gpt_bigcode", "gpt_oss", "granite", + "granite41", "granite_moe_hybrid", "llama", "llava_next", diff --git a/fms/models/granite41.py b/fms/models/granite41.py new file mode 100644 index 000000000..31108a79e --- /dev/null +++ b/fms/models/granite41.py @@ -0,0 +1,600 @@ +import logging +import math +import re +from dataclasses import dataclass +from typing import Any, Mapping, Optional, Tuple +from typing_extensions import Unpack + +import torch +import torch.nn as nn + +from fms import models +from fms.distributed.strategy import DistributedStrategy, NoOpStrategy +from fms.modules.attention import ( + AttentionKwargs, + MultiHeadAttention, + get_attention_type, +) +from fms.modules.feedforward import GatedLinearUnit +from fms.modules.layernorm import LayerNormParameterized +from fms.modules.linear import get_linear_type +from fms.modules.positions import RotaryEmbedding +from fms.utils import serialization +from fms.utils.activation import str_to_activation +from fms.utils.config import ModelConfig +from fms.utils.headless import gather_outputs + + +logger = logging.getLogger(__name__) + + +@dataclass +class Granite41Config(ModelConfig): + src_vocab_size: int = 100352 # can be set by tokenizer + emb_dim: int = 4096 + norm_eps: float = 1e-5 + nheads: int = 32 + head_dim: int = 128 # getattr(config, "head_dim", emb_dim // nheads) + kvheads: int = 8 + nlayers: int = 4 + pad_id: int = -1 + hidden_grow_factor: float = 3.125 + multiple_of: int = 256 + activation_fn: str = "swish" + p_dropout: float = 0.0 + max_expected_seq_len: int = 131072 + ntk_scaling: bool = False + attn_bias: bool = False + mlp_bias: bool = False + tie_heads: bool = True + rope_theta: float = 10_000_000.0 + embedding_multiplier: float = 12.0 + logits_scaling: float = 16.0 + residual_multiplier: float = 0.22 + attention_multiplier: float = 0.0078125 + linear_config: Optional[Mapping[str, Any]] = None + window_length: int = 128 + fused_weights: bool = True + + +class Granite41Block(nn.Module): + def __init__(self, config: Granite41Config, rotary_emb: RotaryEmbedding, window_length: int = 0): + super(Granite41Block, self).__init__() + self.config = config + emb_kq = self.config.head_dim + emb_v = self.config.head_dim + + self.ln = LayerNormParameterized( + self.config.emb_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=self.config.norm_eps, + use_high_precision_pow=True, + ) + self.ff_ln = LayerNormParameterized( + self.config.emb_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=self.config.norm_eps, + use_high_precision_pow=True, + ) + + if self.config.kvheads == 0: + kvheads = self.config.nheads + else: + kvheads = self.config.kvheads + assert self.config.nheads % self.config.kvheads == 0 + + self.attn = MultiHeadAttention( + self.config.emb_dim, + emb_kq, + emb_v, + self.config.nheads, + kvheads, + p_dropout=self.config.p_dropout, + use_bias=self.config.attn_bias, + position_encoder=rotary_emb, + fused=self.config.fused_weights, + linear_config=self.config.linear_config, + scale_factor=self.config.attention_multiplier, + has_sinks=window_length > 0, + ) + + self.ff_sub_layer = GatedLinearUnit( + self.config.emb_dim, + hidden_grow_factor=self.config.hidden_grow_factor, + multiple_of=self.config.multiple_of, + activation_fn=str_to_activation(self.config.activation_fn), + p_dropout=self.config.p_dropout, + use_bias=self.config.mlp_bias, + fused=self.config.fused_weights, + linear_config=self.config.linear_config, + ) + + if self.config.p_dropout != 0: + self.dropout = nn.Dropout(self.config.p_dropout) + + self.window_length = window_length + + def forward( + self, + x, + *, + position_ids=None, + past_key_value_state=None, + use_cache=False, + **attn_kwargs: Unpack[AttentionKwargs], + ): + if self.window_length > 0: + old_attn_name = attn_kwargs.get("attn_name", "") + if "sdpa" in old_attn_name: + attn_kwargs["attn_name"] = "sdpa_with_sinks" + elif "paged" in old_attn_name: + attn_kwargs["attn_name"] = "spyre_paged_attn_with_sinks" + # if the cache is not empty, we need to get the kv cache for self and cross attention + self_attn_past_key_value = past_key_value_state + + # first we do MHA and Add&Norm + residual = x + x = self.ln(x) + x = self.attn( + q=x, + position_ids=position_ids, + past_key_value_state=self_attn_past_key_value, + use_cache=use_cache, + sliding_window=self.window_length, + **attn_kwargs, + ) + cache = None + if use_cache: + x, cache = x + if self.config.p_dropout != 0: + x = self.dropout(x) + # residual connection + x = x * self.config.residual_multiplier + residual + + # then we do FF and Add&Norm + residual = x + x = self.ff_ln(x) + x = self.ff_sub_layer(x) + if self.config.p_dropout != 0: + x = self.dropout(x) + # another residual + x = x * self.config.residual_multiplier + residual + + if self.window_length > 0: + attn_kwargs["attn_name"] = old_attn_name + + if use_cache: + return (x, cache) + else: + return x + + +class Granite41Headless(nn.Module): + def __init__( + self, + config: Optional[Granite41Config] = None, + distributed_strategy: DistributedStrategy = NoOpStrategy, + **kwargs, + ): + super(Granite41Headless, self).__init__() + if config is not None: + self.config = config + else: + self.config = Granite41Config() + self.config = self.config.updated(**kwargs) + self.distributed_strategy = distributed_strategy + + self.width = self.config.emb_dim + self.pad_id = self.config.pad_id + self.max_expected_seq_len = self.config.max_expected_seq_len + + self.embedding = nn.Embedding( + self.config.src_vocab_size, + self.config.emb_dim, + padding_idx=self.config.pad_id, + ) + + rope_scaling = {"rope_type": "ntk" if self.config.ntk_scaling else "regular"} + + self.rot_emb = RotaryEmbedding( + dim=self.config.head_dim, + scaling=rope_scaling, + max_seq_len=self.config.max_expected_seq_len, + ratio=self.config.rope_theta, + ) + # RoPE init + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + layers = [] + for i in range(self.config.nlayers): + window_length = 0 if i % 2 == 0 else self.config.window_length + block: nn.Module = Granite41Block(self.config, self.rot_emb, window_length) + block = self.distributed_strategy.distribute_layer(block, i) + layers.append(block) + self.layers = nn.ModuleList(layers) + + dec_norm = LayerNormParameterized( + self.config.emb_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=self.config.norm_eps, + use_high_precision_pow=True, + ) + self.dec_norm = self.distributed_strategy.distribute_module( + dec_norm, final_layers=True + ) + + if self.config.p_dropout: + self.dropout = nn.Dropout(self.config.p_dropout) + + def reset_parameters(self): + nn.init.trunc_normal_( + self.embedding.weight, mean=0.0, std=self.config.emb_dim**-0.5 + ) + + # RoPE init + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + # Call reset_parameters for relevant sub-layers + for m in self.modules(): + if ( + isinstance(m, MultiHeadAttention) + or isinstance(m, GatedLinearUnit) + or isinstance(m, LayerNormParameterized) + ): + m.reset_parameters() + + def _clean_up_rot_emb_cache( + self, + cached_freqs: dict[Optional[torch.device], dict[int, torch.Tensor]], + max_seq_len_cached: dict[Optional[torch.device], int], + ): + # remove meta tensors from cached_freqs + for dev in list(cached_freqs.keys()): + for alp in list(cached_freqs[dev].keys()): + if cached_freqs[dev][alp].device == torch.device("meta"): + del cached_freqs[dev][alp] + if len(cached_freqs[dev]) == 0: + del cached_freqs[dev] + del max_seq_len_cached[dev] + + def post_init(self): + # This function is called in `get_model` after the model is + # fully initalized on the correct device + + self._clean_up_rot_emb_cache( + self.rot_emb.cached_freqs, + self.rot_emb.max_seq_len_cached, + ) + + # init RoPE on the right device(s) + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + def forward( + self, + x_in, + position_ids=None, + past_key_value_states=None, + use_cache=False, + **attn_kwargs: Unpack[AttentionKwargs], + ): + # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified + # x_in: batch_size x seq_len x emb_dim if input is already embedded, otherwise batch_size x seq_len + # mask: batch_size x seq_len x seq_len + # bias: nheads x seq_len x seq_len + if past_key_value_states is None or len(past_key_value_states) == 0: + past_key_value_states = [None for _ in range(len(self.layers))] + + if x_in.dim() == 2: # input is not already embedded + x_in = self.embedding(x_in) + x_in = x_in * self.config.embedding_multiplier + + # this is the output cache for all the decoder layers + present_key_value_states = [] + + for i, layer in enumerate(self.layers): + output = layer( + x=x_in, + position_ids=position_ids, + past_key_value_state=past_key_value_states[i], + use_cache=use_cache, + **attn_kwargs, + ) + + if use_cache: + x_in, present_key_value_state = output + present_key_value_states.append(present_key_value_state) + + else: + x_in = output + + dec_out = x_in + dec_out = self.dec_norm(dec_out) + if self.config.p_dropout: + dec_out = self.dropout(dec_out) + + return dec_out, present_key_value_states + + +class Granite41(nn.Module): + def __init__( + self, + config: Optional[Granite41Config] = None, + distributed_strategy: DistributedStrategy = NoOpStrategy, + **kwargs, + ): + super(Granite41, self).__init__() + if config is not None: + self.config = config + else: + self.config = Granite41Config() + self.config = self.config.updated(**kwargs) + self.distributed_strategy = distributed_strategy + + self.base_model = Granite41Headless(self.config, self.distributed_strategy) + self.head = nn.Linear( + self.config.emb_dim, self.config.src_vocab_size, bias=False + ) + + @classmethod + def from_config(cls, config: Granite41Config) -> "Granite41": + return cls(config) + + def get_config(self) -> Granite41Config: + return self.config + + def reset_parameters(self): + self.head.weight.data.normal_( + 0, + 1 / math.sqrt(math.sqrt(self.config.emb_dim * self.config.src_vocab_size)), + ) + self.base_model.reset_parameters() + + def post_init(self): + # if this model ties weights, they are tied here + if self.config.tie_heads: + # handle assignment of non-meta weights to meta parameters + if self.head.weight.device == torch.device("meta"): + self.head.weight = self.base_model.embedding.weight + else: + self.base_model.embedding.weight = self.head.weight + + self.base_model.post_init() + + def forward( + self, + x: torch.LongTensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, + use_cache: bool = False, + last_n_tokens: int = 0, + **attn_kwargs: Unpack[AttentionKwargs], + ): + get_attention_type(**attn_kwargs)["validate_attn_kwargs"]( + input_ids=x, + position_ids=position_ids, + past_key_value_states=past_key_value_states, + **attn_kwargs, + ) + + output, cache = self.base_model( + x, + position_ids, + past_key_value_states, + use_cache, + **attn_kwargs, + ) + + output = gather_outputs(output, last_n_tokens, **attn_kwargs) + preds = self.head(output) + preds = preds / self.config.logits_scaling + + if use_cache: + return preds, cache + else: + return preds + + +_8b_config = Granite41Config( +) + +_architecture_name = "granite41" + + +def _granite_factory_factory(config): + def factory(**kwargs): + return Granite41(config, **kwargs) + + return factory + + +models.register_model(_architecture_name, "8b", _granite_factory_factory(_8b_config)) + + +def _weight_fusion( + input_sd: Mapping, model_config: Optional[Granite41Config] = None, **kwargs +): + has_fused_weights = True + if model_config: + if not model_config.fused_weights: + has_fused_weights = False + + new_sd = input_sd + if has_fused_weights: + new_sd = serialization._mlp_glu_unfused_to_fused_adapter_step( + serialization._attn_unfused_to_fused_step(new_sd) + ) + return new_sd + + +serialization.register_adapter_step(_architecture_name, "weight_fusion", _weight_fusion) + + +def _hf_to_fms_names(input_sd: Mapping[str, Any], **kwargs) -> Mapping[str, Any]: + replacements = [ + (r"^lm_head.weight", "head.weight"), + (r"^model.embed_tokens.weight", "base_model.embedding.weight"), + (r"^model.norm", "base_model.dec_norm"), + (r"^model.layers", "base_model.layers"), + (r"self_attn\.k_proj", "attn.in_proj.key"), + (r"self_attn\.v_proj", "attn.in_proj.value"), + (r"self_attn\.q_proj", "attn.in_proj.query"), + (r"self_attn\.o_proj", "attn.dense"), + (r"mlp\.gate_proj", "ff_sub_layer.wg"), + (r"mlp\.up_proj", "ff_sub_layer.w1"), + (r"mlp\.down_proj", "ff_sub_layer.w2"), + (r"input_layernorm", "ln"), + (r"post_attention_layernorm", "ff_ln"), + ] + new_sd = {} + for name, param in input_sd.items(): + new_name = name + for pattern, repl in replacements: + new_name = re.sub(pattern, repl, new_name) + new_sd[new_name] = param + return new_sd + + +serialization.register_adapter_step( + _architecture_name, "hf_to_fms_names", _hf_to_fms_names +) + + +serialization.register_adapter_step( + _architecture_name, + "weight_expansion_for_mismatched_head_dim", + serialization._weight_expansion_for_mismatched_head_dim, # type: ignore[arg-type] +) + + +def _get_rope_params(linear_type: str) -> list[str]: + if "gptq" in linear_type: + return ["qweight", "scales", "qzeros", "bias"] + elif "int8" in linear_type: + # quantize_weight is fms-model-optimizer identifier of weight clip values + return ["weight", "bias", "quantize_weight"] + elif "fp8" in linear_type: + return ["weight", "weight_scale", "input_scale", "bias"] + else: # torch.nn.Linear + return ["weight", "bias"] + + +def _hf_to_fms_rope( + input_sd: Mapping[str, Any], model_config: Optional[Granite41Config] = None, **kwargs +) -> Mapping[str, Any]: + new_sd = {} + + if model_config: + head_size = model_config.emb_dim // model_config.nheads + else: + logger.warning("Missing model_config, assuming defaults for head_size") + head_size = 128 # Good default for most models + + for name, param in input_sd.items(): + # Some checkpoints have weights in different precisions, which can have + # auxiliary tensors (see _get_rope_params e.g. gptq, fp8). + # Thus, we need to get rope_params per parameter. + linear_type_str = "torch_linear" + if model_config and model_config.linear_config: + linear_type_str = get_linear_type( + model_config.linear_config, + module_name=name, + ) + rope_params = _get_rope_params(linear_type_str) + trans_required_pattern = re.compile( + f"base_model.layers.[0-9]+.attn.in_proj.(query|key).({'|'.join(rope_params)})$" + ) + + # hf -> fms requires a transpose operation for the query and key + # weight and bias parameters for Llama models + # This transpose is due to the different implementation of RoPE in + # HF and FMS. While FMS follows the original RoPE paper + # (https://arxiv.org/abs/2104.09864), HF has its own implementation + # that doesn't respect the order of outputs. This is OK as long as you + # rearrange the weights of the query and key projections, as the + # combination projection + RoPE ends up producing the same outputs. + # Therefore, to make FMS produce the correct order of outputs when + # loading from an HF checkpoint, we need to undo the transformation + # that HF does from the original Meta weights + is_gptq_2d_qparam = "gptq" in linear_type_str and param.dim() == 2 + if bool(trans_required_pattern.match(name)) and param.numel() > 1: + temp = param + if is_gptq_2d_qparam: + # GPTQ qweights are [in_feat, out_feat] (unlike usual [out_feat, in_feat]) + # and are fully transposed before & after process. + # GPTQ scales and qzeros are also transposed accordingly + temp = temp.transpose(0, 1) + # num_heads is used in the transformation required for hf->fms + # can't be precomputed because q and k might have different num_heads + num_heads = temp.size(0) // head_size + + if temp.dim() == 2: # weight + temp_view = temp.view(num_heads, 2, -1, temp.size(1)) + else: # 1-dim parameters + temp_view = temp.view(num_heads, 2, -1) + temp = temp_view.transpose(1, 2).reshape(*temp.size()) + + if is_gptq_2d_qparam: + temp = temp.transpose(0, 1) + + new_sd[name] = temp + else: + new_sd[name] = param + + return new_sd + + +def _hf_gptq_granite_check( + input_sd: Mapping[str, Any], model_config: Optional[Granite41Config] = None, **kwargs +) -> Mapping[str, Any]: + has_fused_weights = True + linear_type = "torch_linear" + if model_config: + if not model_config.fused_weights: + has_fused_weights = False + if model_config.linear_config: + linear_type = model_config.linear_config["linear_type"] + + if not callable(linear_type) and "gptq" in linear_type and has_fused_weights: + raise ValueError( + "GPTQ HF granite checkpoints cannot be loaded into a model with fused weights" + ) + + return input_sd + + +serialization.register_adapter_step( + _architecture_name, "hf_gptq_fusion_check", _hf_gptq_granite_check +) + +serialization.register_adapter_step( + _architecture_name, "hf_to_fms_rope", _hf_to_fms_rope +) + +serialization.register_adapter( + _architecture_name, + "hf", + [ + "hf_to_fms_names", + "hf_to_fms_rope", + "hf_gptq_fusion_check", + "weight_fusion", + ], +) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index 95acdb8c7..26a7c6089 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -437,15 +437,17 @@ def adjusted_qk( freqs = self.cached_freqs[q.device.index][alpha][position_ids] freqs = freqs.float() # 1 L D/2 2 2 + q_start = freqs.size(1) - q.size(1) + k_start = freqs.size(1) - k.size(1) q_out = ( - freqs[:, -q.size(1) :, None, :, :, :] + freqs.narrow(1, q_start, q.size(1)).unsqueeze(2) .mul(q_.unsqueeze(-2)) .sum(5) .flatten(3) ).type_as(q) k_out = ( - freqs[:, -k.size(1) :, None, :, :, :] + freqs.narrow(1, k_start, k.size(1)).unsqueeze(2) .mul(k_.unsqueeze(-2)) .sum(5) .flatten(3) diff --git a/fms/utils/headless.py b/fms/utils/headless.py index 17b0cf178..03a6ec2f4 100644 --- a/fms/utils/headless.py +++ b/fms/utils/headless.py @@ -22,6 +22,8 @@ def gather_outputs( base_model_output = base_model_output[:, -1, :] # this is the base case elif last_n_tokens > 0 and base_model_output.shape[1] >= last_n_tokens: - base_model_output = base_model_output[:, -last_n_tokens:, :] + base_model_output = base_model_output.narrow( + 1, base_model_output.shape[1] - last_n_tokens, last_n_tokens + ) return base_model_output diff --git a/fms/utils/spyre/paged.py b/fms/utils/spyre/paged.py index bedb229d5..2b1f3e4d0 100644 --- a/fms/utils/spyre/paged.py +++ b/fms/utils/spyre/paged.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional, Tuple +from typing import List, NotRequired, Optional, Tuple from fms.modules.attention import ( AttentionKwargs, @@ -143,6 +143,126 @@ class SpyrePagedAttentionKwargs(AttentionKwargs): mask: Optional[torch.Tensor] # prefill mask +@custom_op("spyre::paged_attn_compute_with_sinks", mutates_args={}, device_types=["cpu", "cuda"]) +def paged_attn_compute_with_sinks( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + current_tkv_mask: torch.Tensor, + left_padded_prompt_mask: torch.Tensor, + block_table: torch.Tensor, + sinks: torch.Tensor, + sliding_window: int, +) -> torch.Tensor: + # torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype), + output = torch.zeros_like(query) + num_query_heads = query.shape[2] + num_kv_heads = value_cache.shape[2] + head_size = value_cache.shape[3] + block_size = value_cache.shape[1] + seq_len_q = query.shape[1] + num_seqs = query.shape[0] + + block_tables_lst = block_table.tolist() + + seq_lens_lst = current_tkv_mask.tolist() + for i in range(num_seqs): + q = query[i] + block_table = block_tables_lst[i] + start_pos = int(left_padded_prompt_mask[i].item()) + seq_len = int(seq_lens_lst[i]) + seq_len_q_i = seq_len_q + + keys_lst: list[torch.Tensor] = [] + values_lst: list[torch.Tensor] = [] + for j in range(start_pos, seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + k = k.reshape(num_kv_heads, head_size) + keys_lst.append(k) + + v = value_cache[block_number, block_offset, :, :] + values_lst.append(v) + keys = torch.stack(keys_lst, dim=0) + values = torch.stack(values_lst, dim=0) + seq_len_kv = keys.shape[0] + + # cut the pads for first prefill + if q.shape[0] > seq_len_kv: + seq_len_q_i = seq_len_kv + q = q[-seq_len_kv:] + + if num_kv_heads > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1) + values = torch.repeat_interleave( + values, num_query_heads // num_kv_heads, dim=1 + ) + + # Generate mask for prefix attention + mask = torch.ones((1, seq_len_q_i, seq_len_kv), dtype=torch.bool) + mask[:, :, -seq_len_q_i:] = torch.tril(mask[:, :, -seq_len_q_i:]) + mask = torch.where(mask.logical_not(), -torch.inf, 0.0).to( + device=query.device, dtype=query.dtype + ) + # truncate for sliding window kv_cache + mask = mask[..., -seq_len_q_i:, -seq_len_kv:] + if 0 < sliding_window < seq_len_kv: + mask += torch.tril( + mask.new_full((seq_len_q_i, seq_len_kv), -torch.inf), + diagonal=(seq_len_kv - seq_len_q_i) - sliding_window, + ) + + # Sink code + # https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py#L153 + # from gpt-oss open ai implementation + S = sinks.reshape(-1, 1, 1).expand(-1, seq_len_q_i, -1) # type: ignore + if scale is None: + scale = 1.0 / math.sqrt(head_size) + scale = math.sqrt(scale) + QK = torch.einsum("qhd,khd->hqk", q*scale, keys*scale) + QK += mask + QK = torch.cat([QK, S], dim=-1) + QK = QK - QK.max(dim=-1, keepdim=True).values + W = torch.softmax(QK, dim=-1) + W = W[..., :-1] # drop the attention sinks after done + attn = torch.einsum("hqk,khd->qhd", W, values) + output[i][-seq_len_q_i:] = attn + return output + + +@paged_attn_compute_with_sinks.register_fake +def paged_attn_compute_with_sinks_meta( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + current_tkv_mask: torch.Tensor, + left_padded_prompt_mask: torch.Tensor, + block_table: torch.Tensor, + sinks: torch.Tensor, + sliding_window: int, +) -> torch.Tensor: + return torch.zeros_like(query) + + +class SpyrePagedAttentionSinkKwargs(SpyrePagedAttentionKwargs): + """ + The sinks attention kwargs to be passed to fms model forward. + + sinks: torch.Tensor + this is the tensor weights for the sinks + sliding_window: int + this is the sliding window size for sinks attention + """ + + sinks: NotRequired[torch.Tensor] + sliding_window: NotRequired[int | None] + + def __spyre_paged_store_op( keys: torch.Tensor, values: torch.Tensor, @@ -189,6 +309,31 @@ def __spyre_paged_compute_op( ) +def __spyre_paged_compute_sinks_op( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + nheads: int, + kvheads: int, + p_dropout: float, + scale_factor: Optional[float], + **attn_kwargs, +) -> torch.Tensor: + if scale_factor is None: + scale_factor = 1 / math.sqrt(query.shape[-1]) + return torch.ops.spyre.paged_attn_compute_with_sinks( + query, + key_cache, + value_cache, + scale_factor, + attn_kwargs["current_tkv_mask"], + attn_kwargs["left_padded_prompt_mask"], + attn_kwargs["block_table"], + attn_kwargs["sinks"], + attn_kwargs["sliding_window"], + ) + + def __spyre_paged_validate_attn_kwargs_op( input_ids: torch.Tensor, position_ids: torch.Tensor, @@ -238,3 +383,13 @@ def __spyre_paged_validate_attn_kwargs_op( compute_decode_op=__spyre_paged_compute_op, validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op, ) + + +register_attention_op( + "spyre_paged_attn_with_sinks", + __spyre_paged_store_op, + _sdpa_compute_op, + is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) is None, + compute_decode_op=__spyre_paged_compute_sinks_op, + validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op, +)