From e00c4284893bb2cfe8bc6c3382c63a6deb7e078b Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Fri, 20 Feb 2026 10:36:25 -0300 Subject: [PATCH 01/98] Initial working version Signed-off-by: Flavia Beo --- fms/models/__init__.py | 2 + fms/models/hf/config_utils/__init__.py | 1 + fms/models/hf/config_utils/param_builders.py | 22 + fms/models/qwen3.py | 640 +++++++++++++++++++ fms/modules/attention.py | 61 +- tests/models/hf_equivalence/test_qwen3.py | 292 +++++++++ 6 files changed, 1014 insertions(+), 4 deletions(-) create mode 100644 fms/models/qwen3.py create mode 100644 tests/models/hf_equivalence/test_qwen3.py diff --git a/fms/models/__init__.py b/fms/models/__init__.py index ff6932d31..aa43810ba 100644 --- a/fms/models/__init__.py +++ b/fms/models/__init__.py @@ -502,6 +502,7 @@ def model_wrap(model): mistral, mistral3, mixtral, + qwen3, roberta, siglip_vision, mpnet, @@ -517,6 +518,7 @@ def model_wrap(model): "mistral", "mistral3", "mixtral", + "qwen3", "roberta", "siglip_vision", "mpnet", diff --git a/fms/models/hf/config_utils/__init__.py b/fms/models/hf/config_utils/__init__.py index 3bb7d588c..58b5f1a56 100644 --- a/fms/models/hf/config_utils/__init__.py +++ b/fms/models/hf/config_utils/__init__.py @@ -44,6 +44,7 @@ "LlavaNextForConditionalGeneration": ("llava_next", pb.build_llava_next_params), "MPNetForMaskedLM": ("mpnet", pb.build_mpnet_params), "BertForMaskedLM": ("bert", pb.build_bert_params), + "Qwen3ForCausalLM": ("qwen3", pb.build_qwen3_embeddings_params), "Mistral3ForConditionalGeneration": ("mistral3", pb.build_mistral3_params), # Classify arches have some extra keys for labels "RobertaForSequenceClassification": ("roberta_classification", partial(pb.build_roberta_params, is_classify=True)), diff --git a/fms/models/hf/config_utils/param_builders.py b/fms/models/hf/config_utils/param_builders.py index 9fdf77c9f..fd57179b6 100644 --- a/fms/models/hf/config_utils/param_builders.py +++ b/fms/models/hf/config_utils/param_builders.py @@ -346,6 +346,28 @@ def build_mistral3_params(config: PretrainedConfig) -> dict: return config_params +def build_qwen3_embeddings_params(config: PretrainedConfig) -> dict: + """Param builder for mapping Qwen3ForCausalLM to FMS.""" + config_params = { + "norm_eps": config.rms_norm_eps, + "bos_token_id": config.bos_token_id, + "eos_token_id": config.eos_token_id, + "initializer_range": config.initializer_range, + "activation_fn": config.hidden_act, + "emb_dim": config.hidden_size, + "max_expected_seq_len": config.max_position_embeddings, + "kvheads": config.num_key_value_heads, + "p_dropout": config.attention_dropout, + "rope_base": config.rope_theta, + "head_dim": getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ), + } + return model_params_with_common_opts( + config, config_params, inner_dim=config.intermediate_size + ) + + def model_params_with_common_opts( config: PretrainedConfig, config_params: dict, inner_dim: int ) -> dict: diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py new file mode 100644 index 000000000..514737c27 --- /dev/null +++ b/fms/models/qwen3.py @@ -0,0 +1,640 @@ +import logging +import math +import re +from dataclasses import dataclass, field +from typing import Any, Mapping, Optional, Tuple +from typing_extensions import Unpack + +import torch +import torch.nn as nn + +from fms import models +from fms.distributed.strategy import ( + DistributedStrategy, + NoOpStrategy, + TensorParallelStrategy, +) +from fms.modules.attention import ( + AttentionKwargs, + MultiHeadAttention, + get_attention_type, +) +from fms.modules.feedforward import GatedLinearUnit +from fms.modules.head import LinearClassificationHead +from fms.modules.layernorm import LayerNormParameterized +from fms.modules.linear import get_linear_type +from fms.modules.positions import RotaryEmbedding +from fms.utils import serialization +from fms.utils.activation import str_to_activation +from fms.utils.config import ModelConfig +from fms.utils.headless import gather_outputs + + +logger = logging.getLogger(__name__) + + +""" +Qwen3 Model Implementation + +Based on the Qwen3 architecture from HuggingFace: +https://huggingface.co/Qwen/Qwen3-Embedding-0.6B + +Architecture mapping from HuggingFace config.json: +- attention_bias: False -> attn_bias +- attention_dropout: 0.0 -> p_dropout +- head_dim: 128 -> head_dim +- hidden_act: "silu" -> activation_fn +- hidden_size: 1024 -> emb_dim +- intermediate_size: 3072 -> used to calculate hidden_grow_factor +- max_position_embeddings: 32768 -> max_expected_seq_len +- num_attention_heads: 16 -> nheads +- num_hidden_layers: 28 -> nlayers +- num_key_value_heads: 8 -> kvheads +- rms_norm_eps: 1e-06 -> norm_eps +- rope_theta: 1000000 -> rope_theta +- vocab_size: 151669 -> src_vocab_size +- tie_word_embeddings: true -> tie_heads +""" + + +@dataclass +class Qwen3Config(ModelConfig): + src_vocab_size: int = 151_669 + emb_dim: int = 1024 + norm_eps: float = 1e-6 + nheads: int = 16 + kvheads: int = 8 + nlayers: int = 28 + pad_id: int = -1 + hidden_grow_factor: float = 3072 / 1024 # intermediate_size / hidden_size + multiple_of: int = 256 + activation_fn: str = "swish" # silu is same as swish + p_dropout: float = 0.0 + max_expected_seq_len: int = 32768 + attn_bias: bool = False + mlp_bias: bool = False + tie_heads: bool = True + rope_theta: float = 1000000.0 + rope_scaling: dict = field(default_factory=lambda: {}) + head_dim: int = 128 + linear_config: Optional[Mapping[str, Any]] = None + fused_weights: bool = False + + +class Qwen3Block(nn.Module): + def __init__(self, config: Qwen3Config, rotary_emb: RotaryEmbedding): + super(Qwen3Block, self).__init__() + self.config = config + emb_kq = self.config.head_dim + emb_v = self.config.head_dim + + self.ln = LayerNormParameterized( + self.config.emb_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=self.config.norm_eps, + use_high_precision_pow=True, + ) + self.ff_ln = LayerNormParameterized( + self.config.emb_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=self.config.norm_eps, + use_high_precision_pow=True, + ) + + if self.config.kvheads == 0: + kvheads = self.config.nheads + else: + kvheads = self.config.kvheads + assert self.config.nheads % self.config.kvheads == 0 + + self.attn = MultiHeadAttention( + self.config.emb_dim, + emb_kq, + emb_v, + self.config.nheads, + kvheads, + p_dropout=self.config.p_dropout, + use_bias=self.config.attn_bias, + position_encoder=rotary_emb, + fused=self.config.fused_weights, + linear_config=self.config.linear_config, + norm_eps=self.config.norm_eps, + head_dim=self.config.head_dim, + ) + self.ff_sub_layer = GatedLinearUnit( + self.config.emb_dim, + hidden_grow_factor=self.config.hidden_grow_factor, + multiple_of=self.config.multiple_of, + activation_fn=str_to_activation(self.config.activation_fn), + p_dropout=self.config.p_dropout, + use_bias=self.config.mlp_bias, + fused=self.config.fused_weights, + linear_config=self.config.linear_config, + ) + + if self.config.p_dropout != 0: + self.dropout = nn.Dropout(self.config.p_dropout) + + def forward( + self, + x, + *, + position_ids=None, + past_key_value_state=None, + use_cache=False, + **attn_kwargs: Unpack[AttentionKwargs], + ): + # if the cache is not empty, we need to get the kv cache for self attention + self_attn_past_key_value = past_key_value_state + + # first we do MHA and Add&Norm + residual = x + x = self.ln(x) + x = self.attn( + q=x, + position_ids=position_ids, + past_key_value_state=self_attn_past_key_value, + use_cache=use_cache, + **attn_kwargs, + ) + cache = None + if use_cache: + x, cache = x + if self.config.p_dropout != 0: + x = self.dropout(x) + # residual connection + x = x + residual + + # then we do FF and Add&Norm + residual = x + x = self.ff_ln(x) + x = self.ff_sub_layer(x) + if self.config.p_dropout != 0: + x = self.dropout(x) + # another residual + x = x + residual + + if use_cache: + return (x, cache) + else: + return x + + +class Qwen3Headless(nn.Module): + def __init__( + self, + config: Optional[Qwen3Config] = None, + distributed_strategy: DistributedStrategy = NoOpStrategy, + **kwargs, + ): + super(Qwen3Headless, self).__init__() + if config is not None: + self.config = config + else: + self.config = Qwen3Config() + self.config = self.config.updated(**kwargs) + self.distributed_strategy = distributed_strategy + + embedding = nn.Embedding( + self.config.src_vocab_size, self.config.emb_dim, self.config.pad_id + ) + # TP does not work with tied weights + if ( + not isinstance(self.distributed_strategy, TensorParallelStrategy) + or not self.config.tie_heads + ): + self.embedding = self.distributed_strategy.distribute_module(embedding) + else: + logger.warning( + "You're using TP on a model with tied weights between head and embedding. " + "The tied weights won't be sharded, which can result in unexpected OOMs." + ) + self.embedding = embedding + + self.rot_emb = RotaryEmbedding( + dim=self.config.head_dim, + scaling=self.config.rope_scaling, + max_seq_len=self.config.max_expected_seq_len, + ratio=self.config.rope_theta, + ) + # RoPE init + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + layers = [] + for i in range(self.config.nlayers): + block: nn.Module = Qwen3Block(self.config, self.rot_emb) + block = self.distributed_strategy.distribute_layer(block, i) + layers.append(block) + self.layers = nn.ModuleList(layers) + + dec_norm = LayerNormParameterized( + self.config.emb_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=self.config.norm_eps, + use_high_precision_pow=True, + ) + self.dec_norm = self.distributed_strategy.distribute_module( + dec_norm, final_layers=True + ) + + if self.config.p_dropout: + self.dropout = nn.Dropout(self.config.p_dropout) + + def get_config(self) -> Qwen3Config: + return self.config + + @classmethod + def from_config(cls, config: Qwen3Config) -> "Qwen3Headless": + return cls(config) + + def reset_parameters(self): + assert isinstance(self.embedding, torch.nn.Embedding) + nn.init.trunc_normal_( + self.embedding.weight, mean=0.0, std=self.config.emb_dim**-0.5 + ) + + # RoPE init + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + # Call reset_parameters for relevant sub-layers + for m in self.modules(): + if ( + isinstance(m, MultiHeadAttention) + or isinstance(m, GatedLinearUnit) + or isinstance(m, LayerNormParameterized) + ): + m.reset_parameters() + + def validate_reset_parameters(self): + # Verifies that the above self.reset_parameters() executed correctly. + # This may not always be the case for distributed settings with sharded tensors, + # such as FSDP or TP. Note that performing this check may require unsharding / + # re-materializing the full model on a single rank to access the underlying tensors. + tolerance = 1e-3 + + def check_close(x): + assert x.mean().abs() < tolerance + assert x.std().sub(0.02).abs() < tolerance + + with torch.no_grad(): + for p in self.parameters(): + assert p.isnan().int().sum() == 0 + assert p.isinf().int().sum() == 0 + for m in self.modules(): + if isinstance(m, LayerNormParameterized): + if m.elementwise_scale: + assert m.weight.sum() == m.weight.numel() + if m.elementwise_shift: + assert m.bias.add(1).sum() == m.bias.numel() + elif isinstance(m, nn.Embedding): + check_close(m.weight) + elif isinstance(m, GatedLinearUnit): + check_close(m.w1.weight) + check_close(m.w2.weight) + check_close(m.wg.weight) + elif isinstance(m, MultiHeadAttention): + if m.fused: + check_close(m.in_proj.qkv_fused.weight) + else: + check_close(m.in_proj.query.weight) + check_close(m.in_proj.key.weight) + check_close(m.in_proj.value.weight) + check_close(m.dense.weight) + + def _clean_up_rot_emb_cache( + self, + cached_freqs: dict[Optional[torch.device], dict[int, torch.Tensor]], + max_seq_len_cached: dict[Optional[torch.device], int], + ): + # remove meta tensors from cached_freqs + for dev in list(cached_freqs.keys()): + for alp in list(cached_freqs[dev].keys()): + if cached_freqs[dev][alp].device == torch.device("meta"): + del cached_freqs[dev][alp] + if len(cached_freqs[dev]) == 0: + del cached_freqs[dev] + del max_seq_len_cached[dev] + + def post_init(self): + # This function is called in `get_model` after the model is + # fully initalized on the correct device + self._clean_up_rot_emb_cache( + self.rot_emb.cached_freqs, + self.rot_emb.max_seq_len_cached, + ) + + # init RoPE on the right device(s) + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + def forward( + self, + x_in, + position_ids=None, + past_key_value_states=None, + use_cache=False, + **attn_kwargs: Unpack[AttentionKwargs], + ): + # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified + # x_in: batch_size x seq_len + # mask: batch_size x seq_len x seq_len + # bias: nheads x seq_len x seq_len + if past_key_value_states is None or len(past_key_value_states) == 0: + past_key_value_states = [None for _ in range(len(self.layers))] + x_in = self.embedding(x_in) + + # this is the output cache for all the decoder layers + present_key_value_states = [] + + for i, layer in enumerate(self.layers): + output = layer( + x=x_in, + position_ids=position_ids, + past_key_value_state=past_key_value_states[i], + use_cache=use_cache, + **attn_kwargs, + ) + + if use_cache: + x_in, present_key_value_state = output + present_key_value_states.append(present_key_value_state) + else: + x_in = output + + dec_out = x_in + dec_out = self.dec_norm(dec_out) + if self.config.p_dropout: + dec_out = self.dropout(dec_out) + + return dec_out, present_key_value_states + + +class Qwen3(nn.Module): + def __init__( + self, + config: Optional[Qwen3Config] = None, + distributed_strategy: DistributedStrategy = NoOpStrategy, + **kwargs, + ): + super(Qwen3, self).__init__() + if config is not None: + self.config = config + else: + self.config = Qwen3Config() + self.config = self.config.updated(**kwargs) + self.distributed_strategy = distributed_strategy + + self.base_model = Qwen3Headless(self.config, self.distributed_strategy) + head = LinearClassificationHead( + self.config.emb_dim, self.config.src_vocab_size, bias=False + ) + # TP does not work with tied weights + if ( + not isinstance(self.distributed_strategy, TensorParallelStrategy) + or not self.config.tie_heads + ): + self.head = self.distributed_strategy.distribute_module(head) + else: + self.head = head + + def get_config(self) -> Qwen3Config: + return self.config + + @classmethod + def from_config(cls, config: Qwen3Config) -> "Qwen3": + return cls(config) + + def reset_parameters(self): + # Call reset_parameters for relevant sub-layers + assert isinstance(self.head, torch.nn.Linear) + self.head.weight.data.normal_( + 0, + 1 / math.sqrt(math.sqrt(self.config.emb_dim * self.config.src_vocab_size)), + ) + self.base_model.reset_parameters() + + def validate_reset_parameters(self): + # Verifies that the above self.reset_parameters() executed correctly. + # This may not always be the case for distributed settings with sharded tensors, + # such as FSDP or TP. Note that performing this check may require unsharding / + # re-materializing the full model on a single rank to access the underlying tensors. + tolerance = 1e-3 + + def check_close(x): + assert x.mean().abs() < tolerance + assert x.std().sub(0.02).abs() < tolerance + + with torch.no_grad(): + for p in self.parameters(): + assert p.isnan().int().sum() == 0 + assert p.isinf().int().sum() == 0 + self.base_model.validate_reset_parameters() + check_close(self.head.weight) + + def post_init(self): + self.base_model.post_init() + + # if this model ties weights, they are tied here + if self.config.tie_heads: + # handle assignment of non-meta weights to meta parameters + if self.head.weight.device == torch.device("meta"): + self.head.weight = self.base_model.embedding.weight + else: + self.base_model.embedding.weight = self.head.weight + + def forward( + self, + x: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, + use_cache: bool = False, + last_n_tokens: int = 0, + **attn_kwargs: Unpack[AttentionKwargs], + ): + get_attention_type(**attn_kwargs)["validate_attn_kwargs"]( + input_ids=x, + position_ids=position_ids, + past_key_value_states=past_key_value_states, + **attn_kwargs, + ) + output, cache = self.base_model( + x, position_ids, past_key_value_states, use_cache, **attn_kwargs + ) + + output = gather_outputs(output, last_n_tokens, **attn_kwargs) + preds = self.head(output) + + if use_cache: + return preds, cache + else: + return preds + + +# Register Qwen3 variants with the model registration API + +# Qwen3-Embedding-0.6B configuration +_0_6b_config = Qwen3Config( + src_vocab_size=151_669, + emb_dim=1024, + norm_eps=1e-6, + nheads=16, + kvheads=8, + nlayers=28, + hidden_grow_factor=3072 / 1024, + max_expected_seq_len=32768, + rope_theta=1_000_000.0, + head_dim=128, + tie_heads=True, +) + +_architecture_name = "qwen3" + + +def _qwen3_factory_factory(config): + def factory(**kwargs): + return Qwen3(config, **kwargs) + + return factory + + +# HuggingFace checkpoint adapter +def _hf_to_fms_names(hf_sd: Mapping[str, Any], model_config: Optional[Qwen3Config] = None) -> Mapping[str, Any]: + """ + Convert HuggingFace Qwen3 state dict to FMS format + """ + replacements = [ + (r"^norm.weight", "base_model.dec_norm.weight"), + (r"^embed_tokens.weight", "base_model.embedding.weight"), + (r"layers", "base_model.layers"), + (r"self_attn.k_proj.weight", "attn.in_proj.key.weight"), + (r"self_attn.k_norm.weight", "attn.in_proj.k_norm.weight"), + (r"self_attn.v_proj.weight", "attn.in_proj.value.weight"), + (r"self_attn.q_proj.weight", "attn.in_proj.query.weight"), + (r"self_attn.q_norm.weight", "attn.in_proj.q_norm.weight"), + (r"self_attn.o_proj.weight", "attn.dense.weight"), + (r"mlp.gate_proj.weight", "ff_sub_layer.wg.weight"), + (r"mlp.up_proj.weight", "ff_sub_layer.w1.weight"), + (r"mlp.down_proj.weight", "ff_sub_layer.w2.weight"), + (r"input_layernorm.weight", "ln.weight"), + (r"post_attention_layernorm.weight", "ff_ln.weight"), + ] + + new_sd = {} + for name, param in hf_sd.items(): + new_name = name + for pattern, repl in replacements: + new_name = re.sub(pattern, repl, new_name) + new_sd[new_name] = param + + return new_sd + +def _hf_to_fms_rope( + input_sd: Mapping[str, Any], model_config: Optional[Qwen3Config] = None, **kwargs +) -> Mapping[str, Any]: + new_sd = {} + + if model_config: + head_size = model_config.emb_dim // model_config.nheads + else: + logger.warning("Missing model_config, assuming defaults for head_size") + head_size = 128 # Good default for most models + + for name, param in input_sd.items(): + # Some checkpoints have weights in different precisions, which can have + # auxiliary tensors (see _get_rope_params e.g. gptq, fp8). + # Thus, we need to get rope_params per parameter. + linear_type_str = "torch_linear" + if model_config and model_config.linear_config: + linear_type_str = get_linear_type( + model_config.linear_config, + module_name=name, + ) + rope_params = _get_rope_params(linear_type_str) + trans_required_pattern = re.compile( + f"base_model.layers.[0-9]+.attn.in_proj.(query|key).({'|'.join(rope_params)})$" + ) + + # hf -> fms requires a transpose operation for the query and key + # weight and bias parameters for Llama models + # This transpose is due to the different implementation of RoPE in + # HF and FMS. While FMS follows the original RoPE paper + # (https://arxiv.org/abs/2104.09864), HF has its own implementation + # that doesn't respect the order of outputs. This is OK as long as you + # rearrange the weights of the query and key projections, as the + # combination projection + RoPE ends up producing the same outputs. + # Therefore, to make FMS produce the correct order of outputs when + # loading from an HF checkpoint, we need to undo the transformation + # that HF does from the original Meta weights + is_gptq_2d_qparam = "gptq" in linear_type_str and param.dim() == 2 + if bool(trans_required_pattern.match(name)) and param.numel() > 1: + temp = param + if is_gptq_2d_qparam: + # GPTQ qweights are [in_feat, out_feat] (unlike usual [out_feat, in_feat]) + # and are fully transposed before & after process. + # GPTQ scales and qzeros are also transposed accordingly + temp = temp.transpose(0, 1) + # num_heads is used in the transformation required for hf->fms + # can't be precomputed because q and k might have different num_heads + num_heads = temp.size(0) // head_size + + if temp.dim() == 2: # weight + temp_view = temp.view(num_heads, 2, -1, temp.size(1)) + else: # 1-dim parameters + temp_view = temp.view(num_heads, 2, -1) + temp = temp_view.transpose(1, 2).reshape(*temp.size()) + + if is_gptq_2d_qparam: + temp = temp.transpose(0, 1) + + new_sd[name] = temp + else: + new_sd[name] = param + + return new_sd + + +models.register_model( + _architecture_name, "0.6b", _qwen3_factory_factory(_0_6b_config) +) + +def _get_rope_params(linear_type: str) -> list[str]: + if "gptq" in linear_type: + return ["qweight", "scales", "qzeros", "bias"] + elif "int8" in linear_type: + # quantize_weight is fms-model-optimizer identifier of weight clip values + return ["weight", "bias", "quantize_weight"] + elif "fp8" in linear_type: + return ["weight", "weight_scale", "input_scale", "bias"] + else: # torch.nn.Linear + return ["weight", "bias"] + +serialization.register_adapter_step( + _architecture_name, "hf_to_fms_rope", _hf_to_fms_rope +) + +serialization.register_adapter_step( + _architecture_name, "hf_to_fms_names", _hf_to_fms_names +) + +serialization.register_adapter( + _architecture_name, + "hf", + ["hf_to_fms_names"], +) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 95f323c0f..3ee94d668 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -1,5 +1,6 @@ import abc import functools +from token import OP from typing import ( Any, Callable, @@ -12,6 +13,7 @@ ) from typing_extensions import NotRequired, Unpack +from fms.modules.layernorm import LayerNormParameterized import torch import torch.distributed from torch import Tensor, nn @@ -384,6 +386,8 @@ def __init__( emb_v_per_head: int, use_bias: bool, linear_config: Optional[Mapping[str, Any]] = None, + norm_eps: Optional[float] = None, + head_dim: Optional[int] = None, *args, **kwargs, ): @@ -399,6 +403,9 @@ def __init__( **kwargs, ) + self.head_dim = head_dim or emb_kq_per_head + self.norm_eps = norm_eps or 1e-5 + self.query = get_linear( self.emb_dim, self.nheads * self.emb_kq_per_head, @@ -418,6 +425,27 @@ def __init__( linear_config=linear_config, ) + if norm_eps: + self.norm = True + self.q_norm = LayerNormParameterized( + head_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=norm_eps, + use_high_precision_pow=True, + ) + self.k_norm = LayerNormParameterized( + head_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=norm_eps, + use_high_precision_pow=True, + ) + else: + self.norm = False + def reset_parameters(self): for m in self.modules(): if isinstance(m, nn.Linear): @@ -439,10 +467,29 @@ def forward( "both k and v must either be given as tensors or both None" ) - # b x h x qlen x ds - queries = self.query(q) - keys = self.key(k) - values = self.value(v) + # Project queries, keys, values + queries = self.query(q) # b x qlen x (nheads * head_dim) + keys = self.key(k) # b x klen x (kvheads * head_dim) + values = self.value(v) # b x vlen x (kvheads * head_dim) + + # Apply normalization if enabled + # Normalization should be applied per-head, so we need to reshape first + if self.norm: + batch_size, q_len, _ = queries.shape + k_len = keys.shape[1] + + # Reshape to separate heads: b x len x heads x head_dim + queries = queries.view(batch_size, q_len, self.nheads, self.head_dim) + keys = keys.view(batch_size, k_len, self.kvheads, self.head_dim) + + # Apply normalization per head + queries = self.q_norm(queries) + keys = self.k_norm(keys) + + # Reshape back: b x len x (heads * head_dim) + queries = queries.view(batch_size, q_len, -1) + keys = keys.view(batch_size, k_len, -1) + return queries, keys, values @@ -576,6 +623,8 @@ def __init__( fused: bool = True, linear_config: Optional[Mapping[str, Any]] = None, scale_factor: Optional[float] = None, + norm_eps: Optional[float] = None, + head_dim: Optional[int] = None, ): super(MultiHeadAttention, self).__init__() self.nheads = nheads @@ -588,6 +637,8 @@ def __init__( self.fused = fused self.linear_config = linear_config self.scale_factor = scale_factor + self.norm_eps = norm_eps + self.head_dim = head_dim self.in_proj: QKV = (FusedQKV if self.fused else UnfusedQKV)( self.emb_dim, @@ -597,6 +648,8 @@ def __init__( self.emb_v_per_head, self.use_bias, linear_config=linear_config, + norm_eps=self.norm_eps, + head_dim=self.head_dim, ) self.dense = get_linear( diff --git a/tests/models/hf_equivalence/test_qwen3.py b/tests/models/hf_equivalence/test_qwen3.py new file mode 100644 index 000000000..966425527 --- /dev/null +++ b/tests/models/hf_equivalence/test_qwen3.py @@ -0,0 +1,292 @@ +import pytest +import torch +import torch.nn.functional as F +import random + +from fms.models import get_model +from fms.utils.generation import pad_input_ids + +device = "cpu" +SEED = 42 +random.seed(SEED) +torch.manual_seed(SEED) +torch.use_deterministic_algorithms(True) + + +def _get_inputs(tokenizer, prompt="Hello, how are you?"): + """Tokenize input prompt""" + encoded = tokenizer(prompt, return_tensors="pt") + input_ids = encoded["input_ids"] + return input_ids + + +def _get_hf_model_output(model_path, inputs): + """Get output from HuggingFace model""" + from transformers import AutoModel, AutoTokenizer + + model = AutoModel.from_pretrained( + model_path, torch_dtype=torch.float32 + ).to(device) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + model.eval() + with torch.no_grad(): + outputs = model(**inputs) + # The model uses the last token's representation as the embedding + embeddings = outputs.last_hidden_state[:, -1, :] + # Normalize embeddings for cosine similarity + embeddings = F.normalize(embeddings, p=2, dim=1) + + query_embedding = embeddings[0] + doc_embeddings = embeddings[1:] + + return query_embedding, doc_embeddings + + +def _get_fms_model_output(model_path, inputs): + """Get output from FMS model""" + model = get_model( + "hf_pretrained", + model_path, + data_type=torch.float32, + device_type=device, + ) + + model.eval() + torch.set_grad_enabled(False) + + # Get input_ids from the inputs dict + input_ids = inputs["input_ids"].to(device) + + # Prepare inputs for FMS - this will create appropriate mask and position_ids + input_ids_padded, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) + input_ids_padded = input_ids_padded.to(device) + + with torch.no_grad(): + # Get embeddings from base model (before LM head) + embeddings, _ = model.base_model( + input_ids_padded, + mask=padding_kwargs["mask"].to(device), + position_ids=padding_kwargs["position_ids"].to(device), + ) + # The model uses the last token's representation as the embedding + embeddings = embeddings[:, -1, :] + # Normalize embeddings for cosine similarity + embeddings = F.normalize(embeddings, p=2, dim=1) + + query_embedding = embeddings[0] + doc_embeddings = embeddings[1:] + + return query_embedding, doc_embeddings + + +@pytest.mark.slow +def test_qwen3_embedding_0_6b_equivalence(): + """ + Test equivalence between FMS and HuggingFace implementations of Qwen3-Embedding-0.6B. + + This test: + 1. Loads both HF and FMS versions of the model + 2. Compares logits for the same input + 3. Compares generated sequences + + Note: This test requires downloading the model from HuggingFace Hub. + """ + model_path = "Qwen/Qwen3-Embedding-0.6B" + + # Skip if model is not available locally + try: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) + except Exception as e: + pytest.skip(f"Model not available: {e}") + + + # Prepare input + query = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: What is the capital of China?" + documents = [ + "The capital of China is Beijing.", + "That is a very fast car." + ] + input_texts = [query] + documents + + # 3. Tokenize + inputs = tokenizer(input_texts, padding=True, truncation=True, return_tensors='pt', max_length=8192) + + + # Get outputs from both models + hf_query_embedding, hf_doc_embeddings = _get_hf_model_output(model_path, inputs) + fms_query_embedding, fms_doc_embeddings = _get_fms_model_output(model_path, inputs) + + hf_scores = (hf_query_embedding @ hf_doc_embeddings.T) + fms_scores = (fms_query_embedding @ fms_doc_embeddings.T) + + + # First sentence contains the awnser to the query. + # It's score should be always the highest. + assert hf_scores[0] > hf_scores[1] + assert fms_scores[0] > fms_scores[1] + assert fms_scores[0] > 0.7 + assert hf_scores[0] > 0.7 + assert hf_scores[0] > fms_scores[1] + assert fms_scores[0] > hf_scores[1] + + +def test_qwen3_forward_pass(): + """ + Test basic forward pass of Qwen3 model. + + This is a simpler test that just verifies the model can be loaded + and produces reasonable outputs without comparing to HF. + """ + model_path = "Qwen/Qwen3-Embedding-0.6B" + + try: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) + except Exception as e: + pytest.skip(f"Model not available: {e}") + + # Load FMS model + model = get_model( + "hf_pretrained", + model_path, + data_type=torch.float32, + device_type=device, + ) + + model.eval() + + # Prepare input + prompt = "Hello, world!" + input_ids = _get_inputs(tokenizer, prompt) + input_ids_padded, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) + input_ids_padded = input_ids_padded.to(device) + + # Forward pass + with torch.no_grad(): + logits = model( + input_ids_padded, + mask=padding_kwargs["mask"].to(device), + position_ids=padding_kwargs["position_ids"].to(device), + ) + + # Basic sanity checks + assert logits.shape[0] == 1, "Batch size should be 1" + assert logits.shape[1] == input_ids.shape[1], "Sequence length should match input" + assert logits.shape[2] == 151669, "Vocab size should be 151669" + + # Check that logits are reasonable (not NaN or Inf) + assert not torch.isnan(logits).any(), "Logits contain NaN" + assert not torch.isinf(logits).any(), "Logits contain Inf" + + # Check that logits have reasonable range + assert logits.abs().max() < 100, "Logits have unreasonable magnitude" + + +def test_qwen3_parameter_count(): + """ + Test that FMS and HF models have the same number of parameters. + """ + model_path = "Qwen/Qwen3-Embedding-0.6B" + + try: + from transformers import AutoModelForCausalLM + except Exception as e: + pytest.skip(f"Transformers not available: {e}") + + try: + # Load both models + hf_model = AutoModelForCausalLM.from_pretrained(model_path) + fms_model = get_model("hf_pretrained", model_path) + + # Count parameters + def count_parameters(model): + return sum(p.numel() for p in model.parameters()) + + hf_params = count_parameters(hf_model) + fms_params = count_parameters(fms_model) + + assert hf_params == fms_params, \ + f"Parameter count mismatch: HF {hf_params} vs FMS {fms_params}" + + # Verify it's approximately 0.6B parameters + assert 500_000_000 < fms_params < 700_000_000, \ + f"Expected ~600M parameters, got {fms_params}" + + except Exception as e: + pytest.skip(f"Could not load models: {e}") + + +def test_qwen3_with_cache(): + """ + Test that KV caching works correctly in Qwen3. + """ + model_path = "Qwen/Qwen3-Embedding-0.6B" + + try: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) + except Exception as e: + pytest.skip(f"Model not available: {e}") + + # Load FMS model + model = get_model( + "hf_pretrained", + model_path, + data_type=torch.float32, + device_type=device, + ) + + model.eval() + + # Prepare input + prompt = "The quick brown fox" + input_ids = _get_inputs(tokenizer, prompt) + input_ids_padded, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) + input_ids_padded = input_ids_padded.to(device) + + # Forward pass without cache + with torch.no_grad(): + logits_no_cache = model( + input_ids_padded, + mask=padding_kwargs["mask"].to(device), + position_ids=padding_kwargs["position_ids"].to(device), + use_cache=False, + ) + + # Forward pass with cache + with torch.no_grad(): + output_with_cache = model( + input_ids_padded, + mask=padding_kwargs["mask"].to(device), + position_ids=padding_kwargs["position_ids"].to(device), + use_cache=True, + ) + + if isinstance(output_with_cache, tuple): + logits_with_cache, cache = output_with_cache + else: + logits_with_cache = output_with_cache + cache = None + + # Logits should be the same regardless of caching + torch.testing.assert_close( + logits_no_cache, + logits_with_cache, + rtol=1e-5, + atol=1e-5, + msg="Logits differ when using cache" + ) + + # Cache should be returned when use_cache=True + if cache is not None: + assert len(cache) == 28, f"Expected 28 layers of cache, got {len(cache)}" + + +if __name__ == "__main__": + test_qwen3_forward_pass() + test_qwen3_with_cache() + test_qwen3_parameter_count() + test_qwen3_embedding_0_6b_equivalence() + From c23034a7b45589a6db1d26db01d51d98f0d8d598 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Fri, 20 Feb 2026 10:39:51 -0300 Subject: [PATCH 02/98] Fix mypy/ruff issues Signed-off-by: Flavia Beo --- fms/models/qwen3.py | 11 ++- fms/modules/attention.py | 33 +++---- tests/models/hf_equivalence/test_qwen3.py | 103 +++++++++++----------- 3 files changed, 74 insertions(+), 73 deletions(-) diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index 514737c27..25d098780 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -515,7 +515,9 @@ def factory(**kwargs): # HuggingFace checkpoint adapter -def _hf_to_fms_names(hf_sd: Mapping[str, Any], model_config: Optional[Qwen3Config] = None) -> Mapping[str, Any]: +def _hf_to_fms_names( + hf_sd: Mapping[str, Any], model_config: Optional[Qwen3Config] = None +) -> Mapping[str, Any]: """ Convert HuggingFace Qwen3 state dict to FMS format """ @@ -545,6 +547,7 @@ def _hf_to_fms_names(hf_sd: Mapping[str, Any], model_config: Optional[Qwen3Confi return new_sd + def _hf_to_fms_rope( input_sd: Mapping[str, Any], model_config: Optional[Qwen3Config] = None, **kwargs ) -> Mapping[str, Any]: @@ -610,9 +613,8 @@ def _hf_to_fms_rope( return new_sd -models.register_model( - _architecture_name, "0.6b", _qwen3_factory_factory(_0_6b_config) -) +models.register_model(_architecture_name, "0.6b", _qwen3_factory_factory(_0_6b_config)) + def _get_rope_params(linear_type: str) -> list[str]: if "gptq" in linear_type: @@ -625,6 +627,7 @@ def _get_rope_params(linear_type: str) -> list[str]: else: # torch.nn.Linear return ["weight", "bias"] + serialization.register_adapter_step( _architecture_name, "hf_to_fms_rope", _hf_to_fms_rope ) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 3ee94d668..244a58ed3 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -1,6 +1,5 @@ import abc import functools -from token import OP from typing import ( Any, Callable, @@ -405,7 +404,7 @@ def __init__( self.head_dim = head_dim or emb_kq_per_head self.norm_eps = norm_eps or 1e-5 - + self.query = get_linear( self.emb_dim, self.nheads * self.emb_kq_per_head, @@ -436,13 +435,13 @@ def __init__( use_high_precision_pow=True, ) self.k_norm = LayerNormParameterized( - head_dim, - elementwise_scale=True, - elementwise_shift=False, - use_mean=False, - eps=norm_eps, - use_high_precision_pow=True, - ) + head_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=norm_eps, + use_high_precision_pow=True, + ) else: self.norm = False @@ -469,27 +468,27 @@ def forward( # Project queries, keys, values queries = self.query(q) # b x qlen x (nheads * head_dim) - keys = self.key(k) # b x klen x (kvheads * head_dim) - values = self.value(v) # b x vlen x (kvheads * head_dim) - + keys = self.key(k) # b x klen x (kvheads * head_dim) + values = self.value(v) # b x vlen x (kvheads * head_dim) + # Apply normalization if enabled # Normalization should be applied per-head, so we need to reshape first if self.norm: batch_size, q_len, _ = queries.shape k_len = keys.shape[1] - + # Reshape to separate heads: b x len x heads x head_dim queries = queries.view(batch_size, q_len, self.nheads, self.head_dim) keys = keys.view(batch_size, k_len, self.kvheads, self.head_dim) - + # Apply normalization per head queries = self.q_norm(queries) keys = self.k_norm(keys) - + # Reshape back: b x len x (heads * head_dim) queries = queries.view(batch_size, q_len, -1) keys = keys.view(batch_size, k_len, -1) - + return queries, keys, values @@ -507,6 +506,8 @@ def __init__( emb_v_per_head: int, use_bias: bool, linear_config: Optional[Mapping[str, Any]] = None, + norm_eps: Optional[float] = None, + head_dim: Optional[int] = None, *args, **kwargs, ): diff --git a/tests/models/hf_equivalence/test_qwen3.py b/tests/models/hf_equivalence/test_qwen3.py index 966425527..521be2e59 100644 --- a/tests/models/hf_equivalence/test_qwen3.py +++ b/tests/models/hf_equivalence/test_qwen3.py @@ -22,13 +22,10 @@ def _get_inputs(tokenizer, prompt="Hello, how are you?"): def _get_hf_model_output(model_path, inputs): """Get output from HuggingFace model""" - from transformers import AutoModel, AutoTokenizer + from transformers import AutoModel + + model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32).to(device) - model = AutoModel.from_pretrained( - model_path, torch_dtype=torch.float32 - ).to(device) - tokenizer = AutoTokenizer.from_pretrained(model_path) - model.eval() with torch.no_grad(): outputs = model(**inputs) @@ -39,7 +36,7 @@ def _get_hf_model_output(model_path, inputs): query_embedding = embeddings[0] doc_embeddings = embeddings[1:] - + return query_embedding, doc_embeddings @@ -51,13 +48,13 @@ def _get_fms_model_output(model_path, inputs): data_type=torch.float32, device_type=device, ) - + model.eval() torch.set_grad_enabled(False) # Get input_ids from the inputs dict input_ids = inputs["input_ids"].to(device) - + # Prepare inputs for FMS - this will create appropriate mask and position_ids input_ids_padded, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) input_ids_padded = input_ids_padded.to(device) @@ -76,7 +73,7 @@ def _get_fms_model_output(model_path, inputs): query_embedding = embeddings[0] doc_embeddings = embeddings[1:] - + return query_embedding, doc_embeddings @@ -84,43 +81,40 @@ def _get_fms_model_output(model_path, inputs): def test_qwen3_embedding_0_6b_equivalence(): """ Test equivalence between FMS and HuggingFace implementations of Qwen3-Embedding-0.6B. - + This test: 1. Loads both HF and FMS versions of the model 2. Compares logits for the same input 3. Compares generated sequences - + Note: This test requires downloading the model from HuggingFace Hub. """ model_path = "Qwen/Qwen3-Embedding-0.6B" - + # Skip if model is not available locally try: from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) except Exception as e: pytest.skip(f"Model not available: {e}") - # Prepare input query = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: What is the capital of China?" - documents = [ - "The capital of China is Beijing.", - "That is a very fast car." - ] + documents = ["The capital of China is Beijing.", "That is a very fast car."] input_texts = [query] + documents # 3. Tokenize - inputs = tokenizer(input_texts, padding=True, truncation=True, return_tensors='pt', max_length=8192) - + inputs = tokenizer( + input_texts, padding=True, truncation=True, return_tensors="pt", max_length=8192 + ) - # Get outputs from both models + # Get outputs from both models hf_query_embedding, hf_doc_embeddings = _get_hf_model_output(model_path, inputs) fms_query_embedding, fms_doc_embeddings = _get_fms_model_output(model_path, inputs) - hf_scores = (hf_query_embedding @ hf_doc_embeddings.T) - fms_scores = (fms_query_embedding @ fms_doc_embeddings.T) - + hf_scores = hf_query_embedding @ hf_doc_embeddings.T + fms_scores = fms_query_embedding @ fms_doc_embeddings.T # First sentence contains the awnser to the query. # It's score should be always the highest. @@ -135,18 +129,19 @@ def test_qwen3_embedding_0_6b_equivalence(): def test_qwen3_forward_pass(): """ Test basic forward pass of Qwen3 model. - + This is a simpler test that just verifies the model can be loaded and produces reasonable outputs without comparing to HF. """ model_path = "Qwen/Qwen3-Embedding-0.6B" - + try: from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) except Exception as e: pytest.skip(f"Model not available: {e}") - + # Load FMS model model = get_model( "hf_pretrained", @@ -154,15 +149,15 @@ def test_qwen3_forward_pass(): data_type=torch.float32, device_type=device, ) - + model.eval() - + # Prepare input prompt = "Hello, world!" input_ids = _get_inputs(tokenizer, prompt) input_ids_padded, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) input_ids_padded = input_ids_padded.to(device) - + # Forward pass with torch.no_grad(): logits = model( @@ -170,16 +165,16 @@ def test_qwen3_forward_pass(): mask=padding_kwargs["mask"].to(device), position_ids=padding_kwargs["position_ids"].to(device), ) - + # Basic sanity checks assert logits.shape[0] == 1, "Batch size should be 1" assert logits.shape[1] == input_ids.shape[1], "Sequence length should match input" assert logits.shape[2] == 151669, "Vocab size should be 151669" - + # Check that logits are reasonable (not NaN or Inf) assert not torch.isnan(logits).any(), "Logits contain NaN" assert not torch.isinf(logits).any(), "Logits contain Inf" - + # Check that logits have reasonable range assert logits.abs().max() < 100, "Logits have unreasonable magnitude" @@ -189,31 +184,33 @@ def test_qwen3_parameter_count(): Test that FMS and HF models have the same number of parameters. """ model_path = "Qwen/Qwen3-Embedding-0.6B" - + try: from transformers import AutoModelForCausalLM except Exception as e: pytest.skip(f"Transformers not available: {e}") - + try: # Load both models hf_model = AutoModelForCausalLM.from_pretrained(model_path) fms_model = get_model("hf_pretrained", model_path) - + # Count parameters def count_parameters(model): return sum(p.numel() for p in model.parameters()) - + hf_params = count_parameters(hf_model) fms_params = count_parameters(fms_model) - - assert hf_params == fms_params, \ + + assert hf_params == fms_params, ( f"Parameter count mismatch: HF {hf_params} vs FMS {fms_params}" - + ) + # Verify it's approximately 0.6B parameters - assert 500_000_000 < fms_params < 700_000_000, \ + assert 500_000_000 < fms_params < 700_000_000, ( f"Expected ~600M parameters, got {fms_params}" - + ) + except Exception as e: pytest.skip(f"Could not load models: {e}") @@ -223,13 +220,14 @@ def test_qwen3_with_cache(): Test that KV caching works correctly in Qwen3. """ model_path = "Qwen/Qwen3-Embedding-0.6B" - + try: from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) except Exception as e: pytest.skip(f"Model not available: {e}") - + # Load FMS model model = get_model( "hf_pretrained", @@ -237,15 +235,15 @@ def test_qwen3_with_cache(): data_type=torch.float32, device_type=device, ) - + model.eval() - + # Prepare input prompt = "The quick brown fox" input_ids = _get_inputs(tokenizer, prompt) input_ids_padded, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) input_ids_padded = input_ids_padded.to(device) - + # Forward pass without cache with torch.no_grad(): logits_no_cache = model( @@ -254,7 +252,7 @@ def test_qwen3_with_cache(): position_ids=padding_kwargs["position_ids"].to(device), use_cache=False, ) - + # Forward pass with cache with torch.no_grad(): output_with_cache = model( @@ -263,22 +261,22 @@ def test_qwen3_with_cache(): position_ids=padding_kwargs["position_ids"].to(device), use_cache=True, ) - + if isinstance(output_with_cache, tuple): logits_with_cache, cache = output_with_cache else: logits_with_cache = output_with_cache cache = None - + # Logits should be the same regardless of caching torch.testing.assert_close( logits_no_cache, logits_with_cache, rtol=1e-5, atol=1e-5, - msg="Logits differ when using cache" + msg="Logits differ when using cache", ) - + # Cache should be returned when use_cache=True if cache is not None: assert len(cache) == 28, f"Expected 28 layers of cache, got {len(cache)}" @@ -289,4 +287,3 @@ def test_qwen3_with_cache(): test_qwen3_with_cache() test_qwen3_parameter_count() test_qwen3_embedding_0_6b_equivalence() - From 8997ee6f956cab62dd1fc03e9c418e6be998754d Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 23 Feb 2026 14:27:52 -0300 Subject: [PATCH 03/98] Adds config and register 4B model Signed-off-by: Flavia Beo --- fms/models/qwen3.py | 25 ++++- tests/models/hf_equivalence/test_qwen3.py | 53 +++++++++- tests/models/test_qwen3_embeddings.py | 123 ++++++++++++++++++++++ 3 files changed, 194 insertions(+), 7 deletions(-) create mode 100644 tests/models/test_qwen3_embeddings.py diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index 25d098780..4920e0640 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -491,7 +491,7 @@ def forward( # Qwen3-Embedding-0.6B configuration _0_6b_config = Qwen3Config( - src_vocab_size=151_669, + src_vocab_size=151669, emb_dim=1024, norm_eps=1e-6, nheads=16, @@ -499,7 +499,21 @@ def forward( nlayers=28, hidden_grow_factor=3072 / 1024, max_expected_seq_len=32768, - rope_theta=1_000_000.0, + rope_theta=1000000.0, + head_dim=128, + tie_heads=True, +) + +_4b_config = Qwen3Config( + src_vocab_size=151665, + emb_dim=2560, + norm_eps=1e-6, + nheads=32, + kvheads=8, + nlayers=36, + hidden_grow_factor=9728 / 2560, + max_expected_seq_len=40960, + rope_theta=1000000.0, head_dim=128, tie_heads=True, ) @@ -514,6 +528,10 @@ def factory(**kwargs): return factory +models.register_model(_architecture_name, "0.6b", _qwen3_factory_factory(_0_6b_config)) +models.register_model(_architecture_name, "4b", _qwen3_factory_factory(_4b_config)) + + # HuggingFace checkpoint adapter def _hf_to_fms_names( hf_sd: Mapping[str, Any], model_config: Optional[Qwen3Config] = None @@ -613,9 +631,6 @@ def _hf_to_fms_rope( return new_sd -models.register_model(_architecture_name, "0.6b", _qwen3_factory_factory(_0_6b_config)) - - def _get_rope_params(linear_type: str) -> list[str]: if "gptq" in linear_type: return ["qweight", "scales", "qzeros", "bias"] diff --git a/tests/models/hf_equivalence/test_qwen3.py b/tests/models/hf_equivalence/test_qwen3.py index 521be2e59..c9f67ecf4 100644 --- a/tests/models/hf_equivalence/test_qwen3.py +++ b/tests/models/hf_equivalence/test_qwen3.py @@ -84,8 +84,8 @@ def test_qwen3_embedding_0_6b_equivalence(): This test: 1. Loads both HF and FMS versions of the model - 2. Compares logits for the same input - 3. Compares generated sequences + 2. Compares scores for the same input + 3. Compares scores against the original Qwen3-Embedding-0.6B scores Note: This test requires downloading the model from HuggingFace Hub. """ @@ -126,6 +126,55 @@ def test_qwen3_embedding_0_6b_equivalence(): assert fms_scores[0] > hf_scores[1] +@pytest.mark.slow +def test_qwen3_embedding_4b_equivalence(): + """ + Test equivalence between FMS and HuggingFace implementations of Qwen3-Embedding-4B. + + This test: + 1. Loads both HF and FMS versions of the model + 2. Compares scores for the same input + 3. Compares scores against the original Qwen3-Embedding-4B scores + + Note: This test requires downloading the model from HuggingFace Hub. + """ + model_path = "Qwen/Qwen3-Embedding-4B" + + # Skip if model is not available locally + try: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_path) + except Exception as e: + pytest.skip(f"Model not available: {e}") + + # Prepare input + query = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: What is the capital of China?" + documents = ["The capital of China is Beijing.", "That is a very fast car."] + input_texts = [query] + documents + + # 3. Tokenize + inputs = tokenizer( + input_texts, padding=True, truncation=True, return_tensors="pt", max_length=8192 + ) + + # Get outputs from both models + hf_query_embedding, hf_doc_embeddings = _get_hf_model_output(model_path, inputs) + fms_query_embedding, fms_doc_embeddings = _get_fms_model_output(model_path, inputs) + + hf_scores = hf_query_embedding @ hf_doc_embeddings.T + fms_scores = fms_query_embedding @ fms_doc_embeddings.T + + # First sentence contains the awnser to the query. + # It's score should be always the highest. + assert hf_scores[0] > hf_scores[1] + assert fms_scores[0] > fms_scores[1] + assert fms_scores[0] > 0.7 + assert hf_scores[0] > 0.7 + assert hf_scores[0] > fms_scores[1] + assert fms_scores[0] > hf_scores[1] + + def test_qwen3_forward_pass(): """ Test basic forward pass of Qwen3 model. diff --git a/tests/models/test_qwen3_embeddings.py b/tests/models/test_qwen3_embeddings.py new file mode 100644 index 000000000..5842c77cf --- /dev/null +++ b/tests/models/test_qwen3_embeddings.py @@ -0,0 +1,123 @@ +import pytest +import torch + +from fms.models.qwen3 import Qwen3, Qwen3Config, Qwen3Headless +from fms.testing._internal.model_test_suite import ( + ConfigFixtureMixin, + ModelCompileTestSuite, + ModelConfigTestSuite, + ModelConsistencyTestSuite, + ModelFixtureMixin, +) +from fms.utils.config import ModelConfig + + +class Qwen3Fixtures(ConfigFixtureMixin, ModelFixtureMixin): + """ + Base Qwen3 Fixtures that can be re-used for other purposes + + This will include the config and model signatures + """ + + @pytest.fixture(scope="class", autouse=True) + def uninitialized_model(self, config: Qwen3Config): + return Qwen3(config) + + @pytest.fixture(scope="class", autouse=True) + def config(self) -> ModelConfig: + return Qwen3Config( + src_vocab_size=384, + emb_dim=64, + norm_eps=1e-05, + nheads=32, + head_dim=64 // 32, + kvheads=16, + nlayers=2, + hidden_grow_factor=2.0, + max_expected_seq_len=1024, + ) + + @pytest.fixture(scope="class", autouse=True) + def model(self, uninitialized_model: torch.nn.Module): + """include this fixture to get a model that is fully initialized""" + + torch.random.manual_seed(5) + sd = uninitialized_model.state_dict() + params = sorted(sd.keys()) + for key in params: + parameter = sd[key] + opt_parameter_initialized = self._maybe_get_initialized_parameter( + key, parameter + ) + if opt_parameter_initialized is not None: + parameter.copy_(opt_parameter_initialized) + else: + values = torch.randn_like(parameter) + values -= 0.5 + values /= 20.0 + parameter.copy_(values) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = uninitialized_model.to(device) + + # Pre-compute RoPE frequencies for the target device to avoid graph breaks + model.base_model.rot_emb.compute_freqs_cis( + torch.device(device), model.base_model.config.max_expected_seq_len + ) + + return model + + +class TestQwen3( + ModelConfigTestSuite, + ModelConsistencyTestSuite, + ModelCompileTestSuite, + Qwen3Fixtures, +): + # x is the main parameter for this model which is the input tensor + _get_signature_params = ["x"] + + def test_model_compile_no_graph_breaks(self, model): + """Test that Qwen3 model is compilable without graph breaks""" + import platform + from torch._dynamo.exc import TorchDynamoException + from torch._dynamo.testing import CompileCounterWithBackend + from fms.testing.comparison import get_signature + + if platform.system() != "Linux": + pytest.skip( + f"pytorch compile is more stable on Linux, skipping as current platform is {platform.platform()}" + ) + + try: + torch._dynamo.reset() + + # Move model to the appropriate device before compilation + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + + cnt = CompileCounterWithBackend("inductor") + compiled_model = torch.compile(model=model, backend=cnt, fullgraph=True) + assert cnt.frame_count == 0 + + optional_params = ( + self._get_signature_optional_params.copy() + if self._get_signature_optional_params + else {} + ) + # default attn_algorithm won't compile on CPU for older pytorch versions + optional_params["attn_algorithm"] = "math" + + get_signature( + compiled_model, + params=self._get_signature_params, + optional_params=optional_params, + logits_getter_fn=self._get_signature_logits_getter_fn, + device=device, + ) + assert cnt.frame_count == 1 + except TorchDynamoException as e: + pytest.fail(f"Failed to get signature of full-graph compiled model:\n{e}") + + def test_model_unfused(self, model, signature): + pytest.skip("weight unfuse is not implemented") From 4b85f17c2916ca039103ae0954e5a3fe99b4acac Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 23 Feb 2026 14:36:39 -0300 Subject: [PATCH 04/98] Adds model expectations Signed-off-by: Flavia Beo --- .../models.test_qwen3_embeddings.TestQwen3.test_model_output | 1 + ...models.test_qwen3_embeddings.TestQwen3.test_model_weight_keys | 1 + 2 files changed, 2 insertions(+) create mode 100644 tests/resources/expectations/models.test_qwen3_embeddings.TestQwen3.test_model_output create mode 100644 tests/resources/expectations/models.test_qwen3_embeddings.TestQwen3.test_model_weight_keys diff --git a/tests/resources/expectations/models.test_qwen3_embeddings.TestQwen3.test_model_output b/tests/resources/expectations/models.test_qwen3_embeddings.TestQwen3.test_model_output new file mode 100644 index 000000000..b2313e4ca --- /dev/null +++ b/tests/resources/expectations/models.test_qwen3_embeddings.TestQwen3.test_model_output @@ -0,0 +1 @@ +0.02183489501476288, 0.0669936090707779, 0.029024794697761536, 0.004201903939247131, 0.0, 0.0398319810628891, 0.05619408190250397, 0.013838574290275574, 0.02368016541004181, 0.024958208203315735, 0.057019785046577454, 0.018252834677696228, 0.04230411350727081, 0.05843259394168854, 0.04746481776237488, 0.03072647750377655 \ No newline at end of file diff --git a/tests/resources/expectations/models.test_qwen3_embeddings.TestQwen3.test_model_weight_keys b/tests/resources/expectations/models.test_qwen3_embeddings.TestQwen3.test_model_weight_keys new file mode 100644 index 000000000..bb517e2a9 --- /dev/null +++ b/tests/resources/expectations/models.test_qwen3_embeddings.TestQwen3.test_model_weight_keys @@ -0,0 +1 @@ +base_model.dec_norm.weight,base_model.embedding.weight,base_model.layers.0.attn.dense.weight,base_model.layers.0.attn.in_proj.k_norm.weight,base_model.layers.0.attn.in_proj.key.weight,base_model.layers.0.attn.in_proj.q_norm.weight,base_model.layers.0.attn.in_proj.query.weight,base_model.layers.0.attn.in_proj.value.weight,base_model.layers.0.ff_ln.weight,base_model.layers.0.ff_sub_layer.w1.weight,base_model.layers.0.ff_sub_layer.w2.weight,base_model.layers.0.ff_sub_layer.wg.weight,base_model.layers.0.ln.weight,base_model.layers.1.attn.dense.weight,base_model.layers.1.attn.in_proj.k_norm.weight,base_model.layers.1.attn.in_proj.key.weight,base_model.layers.1.attn.in_proj.q_norm.weight,base_model.layers.1.attn.in_proj.query.weight,base_model.layers.1.attn.in_proj.value.weight,base_model.layers.1.ff_ln.weight,base_model.layers.1.ff_sub_layer.w1.weight,base_model.layers.1.ff_sub_layer.w2.weight,base_model.layers.1.ff_sub_layer.wg.weight,base_model.layers.1.ln.weight,head.weight \ No newline at end of file From a031640e34b5e0540800dc44bee69229c3739b70 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 23 Feb 2026 14:37:52 -0300 Subject: [PATCH 05/98] Remove unused import Signed-off-by: Flavia Beo --- tests/models/test_qwen3_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_qwen3_embeddings.py b/tests/models/test_qwen3_embeddings.py index 5842c77cf..38268a29d 100644 --- a/tests/models/test_qwen3_embeddings.py +++ b/tests/models/test_qwen3_embeddings.py @@ -1,7 +1,7 @@ import pytest import torch -from fms.models.qwen3 import Qwen3, Qwen3Config, Qwen3Headless +from fms.models.qwen3 import Qwen3, Qwen3Config from fms.testing._internal.model_test_suite import ( ConfigFixtureMixin, ModelCompileTestSuite, From 7a76e8f6054e038aa952aa3463baa8f4f25d3f01 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Wed, 25 Feb 2026 15:28:16 -0300 Subject: [PATCH 06/98] Changes added to support model.compile Signed-off-by: Flavia Beo --- fms/models/qwen3.py | 26 +-- scripts/qwen3_embedding_compiled_example.py | 203 ++++++++++++++++++++ tests/models/hf_equivalence/test_qwen3.py | 2 +- 3 files changed, 213 insertions(+), 18 deletions(-) create mode 100644 scripts/qwen3_embedding_compiled_example.py diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index 4920e0640..843532619 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -17,7 +17,6 @@ from fms.modules.attention import ( AttentionKwargs, MultiHeadAttention, - get_attention_type, ) from fms.modules.feedforward import GatedLinearUnit from fms.modules.head import LinearClassificationHead @@ -27,7 +26,6 @@ from fms.utils import serialization from fms.utils.activation import str_to_activation from fms.utils.config import ModelConfig -from fms.utils.headless import gather_outputs logger = logging.getLogger(__name__) @@ -148,16 +146,13 @@ def forward( use_cache=False, **attn_kwargs: Unpack[AttentionKwargs], ): - # if the cache is not empty, we need to get the kv cache for self attention - self_attn_past_key_value = past_key_value_state - # first we do MHA and Add&Norm residual = x x = self.ln(x) x = self.attn( q=x, position_ids=position_ids, - past_key_value_state=self_attn_past_key_value, + past_key_value_state=past_key_value_state, use_cache=use_cache, **attn_kwargs, ) @@ -457,7 +452,10 @@ def post_init(self): if self.head.weight.device == torch.device("meta"): self.head.weight = self.base_model.embedding.weight else: - self.base_model.embedding.weight = self.head.weight + # For torch.compile compatibility, copy weights instead of tying them + # This avoids graph tracing issues with shared tensors in certain backends + with torch.no_grad(): + self.base_model.embedding.weight.copy_(self.head.weight) def forward( self, @@ -468,23 +466,16 @@ def forward( last_n_tokens: int = 0, **attn_kwargs: Unpack[AttentionKwargs], ): - get_attention_type(**attn_kwargs)["validate_attn_kwargs"]( - input_ids=x, - position_ids=position_ids, - past_key_value_states=past_key_value_states, - **attn_kwargs, - ) output, cache = self.base_model( x, position_ids, past_key_value_states, use_cache, **attn_kwargs ) - output = gather_outputs(output, last_n_tokens, **attn_kwargs) - preds = self.head(output) + output = self.head(output) if use_cache: - return preds, cache + return output, cache else: - return preds + return output # Register Qwen3 variants with the model registration API @@ -540,6 +531,7 @@ def _hf_to_fms_names( Convert HuggingFace Qwen3 state dict to FMS format """ replacements = [ + (r"^lm_head.weight", "head.weight"), (r"^norm.weight", "base_model.dec_norm.weight"), (r"^embed_tokens.weight", "base_model.embedding.weight"), (r"layers", "base_model.layers"), diff --git a/scripts/qwen3_embedding_compiled_example.py b/scripts/qwen3_embedding_compiled_example.py new file mode 100644 index 000000000..8d491883f --- /dev/null +++ b/scripts/qwen3_embedding_compiled_example.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +""" +Example script to run Qwen3-Embedding-0.6B with torch.compile for optimized inference. + +This script demonstrates: +1. Loading the Qwen3-Embedding-0.6B model from HuggingFace +2. Compiling the model with torch.compile for better performance +3. Computing embeddings for queries and documents +4. Calculating similarity scores + +Usage: + python scripts/qwen3_embedding_compiled_example.py +""" + +import torch +import torch.nn.functional as F +from fms.models import get_model +from fms.utils.generation import pad_input_ids + +# Configuration +MODEL_PATH = "Qwen/Qwen3-Embedding-0.6B" +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +COMPILE_MODE = "default" # Options: "default", "reduce-overhead", "max-autotune" + + +def load_and_compile_model(model_path: str, device: str, compile_mode: str = "default"): + """ + Load the Qwen3 model and compile it for optimized inference. + + Args: + model_path: Path to the model (HuggingFace model ID or local path) + device: Device to run on ("cuda" or "cpu") + compile_mode: Compilation mode for torch.compile + + Returns: + Compiled model ready for inference + """ + print(f"Loading model from {model_path}...") + model = get_model( + "hf_pretrained", + model_path, + data_type=torch.float32, + device_type=device, + ) + + model.eval() + torch.set_grad_enabled(False) + + print(f"Compiling model with mode='{compile_mode}'...") + # Compile the base model for better performance + # We compile base_model instead of the full model since we only need embeddings + model = torch.compile( # type: ignore[assignment,arg-type] + model, # type: ignore[arg-type] + mode=compile_mode, + ) + + print("Model loaded and compiled successfully!") + return model + + +def get_embeddings(model, input_ids, device): + """ + Get normalized embeddings from the model. + + Args: + model: The Qwen3 model + input_ids: Input token IDs [batch_size, seq_len] + device: Device to run on + + Returns: + Normalized embeddings [batch_size, emb_dim] + """ + # Prepare inputs for FMS - this will create appropriate mask and position_ids + input_ids_padded, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) + input_ids_padded = input_ids_padded.to(device) + + with torch.no_grad(): + # Get embeddings from base model (before LM head) + embeddings = model( + input_ids_padded, + mask=padding_kwargs["mask"].to(device), + position_ids=padding_kwargs["position_ids"].to(device), + ) + # The model uses the last token's representation as the embedding + embeddings = embeddings[:, -1, :] + # Normalize embeddings for cosine similarity + embeddings = F.normalize(embeddings, p=2, dim=1) + + return embeddings + + +def compute_similarity_scores( + query_embedding: torch.Tensor, doc_embeddings: torch.Tensor +): + """ + Compute cosine similarity scores between query and documents. + + Args: + query_embedding: Query embedding [emb_dim] + doc_embeddings: Document embeddings [num_docs, emb_dim] + + Returns: + Similarity scores [num_docs] + """ + scores = query_embedding @ doc_embeddings.T + return scores + + +def main(): + """Main execution function.""" + print("=" * 80) + print("Qwen3-Embedding-0.6B Compiled Model Example") + print("=" * 80) + print() + + # Check if transformers is available + try: + from transformers import AutoTokenizer + except ImportError: + print( + "Error: transformers library is required. Install with: pip install transformers" + ) + return + + # Load tokenizer + print(f"Loading tokenizer from {MODEL_PATH}...") + try: + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) + except Exception as e: + print(f"Error loading tokenizer: {e}") + print("Make sure you have internet connection or the model is cached locally.") + return + + # Load and compile model + try: + model = load_and_compile_model(MODEL_PATH, DEVICE, COMPILE_MODE) + except Exception as e: + print(f"Error loading model: {e}") + return + + print() + print("-" * 80) + print("Running inference example...") + print("-" * 80) + print() + + # Prepare example data + query = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: What is the capital of China?" + documents = [ + "The capital of China is Beijing.", + "That is a very fast car.", + "Beijing is a major city in China and serves as the nation's capital.", + ] + + print(f"Query: {query}") + print() + print("Documents:") + for i, doc in enumerate(documents, 1): + print(f" {i}. {doc}") + print() + + # Tokenize inputs + input_texts = [query] + documents + inputs = tokenizer( + input_texts, padding=True, truncation=True, return_tensors="pt", max_length=8192 + ) + input_ids = inputs["input_ids"].to(DEVICE) + + # Get embeddings + print("Computing embeddings...") + embeddings = get_embeddings(model, input_ids, DEVICE) + + # Split into query and document embeddings + query_embedding = embeddings[0] + doc_embeddings = embeddings[1:] + + # Compute similarity scores + scores = compute_similarity_scores(query_embedding, doc_embeddings) + + # Display results + print() + print("-" * 80) + print("Results:") + print("-" * 80) + print() + print("Similarity Scores:") + for i, (doc, score) in enumerate(zip(documents, scores), 1): + print(f" Document {i}: {score.item():.4f}") + print(f' "{doc}"') + print() + + # Find most relevant document + best_idx = int(scores.argmax().item()) + print( + f"Most relevant document: Document {best_idx + 1} (score: {scores[best_idx].item():.4f})" + ) + print(f' "{documents[best_idx]}"') + + +if __name__ == "__main__": + main() + +# Made with Bob diff --git a/tests/models/hf_equivalence/test_qwen3.py b/tests/models/hf_equivalence/test_qwen3.py index c9f67ecf4..24b2b339c 100644 --- a/tests/models/hf_equivalence/test_qwen3.py +++ b/tests/models/hf_equivalence/test_qwen3.py @@ -61,7 +61,7 @@ def _get_fms_model_output(model_path, inputs): with torch.no_grad(): # Get embeddings from base model (before LM head) - embeddings, _ = model.base_model( + embeddings = model( input_ids_padded, mask=padding_kwargs["mask"].to(device), position_ids=padding_kwargs["position_ids"].to(device), From c593a95f9f740b2c02bbfad263c35eb8b67cff6d Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Tue, 3 Mar 2026 16:28:18 -0600 Subject: [PATCH 07/98] :bug: Fix pad token id resolution in padding function Signed-off-by: Gaurav-Kumbhat --- fms/utils/generation.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/fms/utils/generation.py b/fms/utils/generation.py index 51d5f3064..ae749aefc 100644 --- a/fms/utils/generation.py +++ b/fms/utils/generation.py @@ -19,6 +19,7 @@ def pad_input_ids( is_causal_mask=True, padding_side="left", position_ids_offset=0, + pad_token_id=0, ) -> Tuple[torch.Tensor, MutableMapping[str, Any]]: """ Convert a list of Tensors to a rectangular tensor. Return extra padding kwargs for the position_ids and mask, since @@ -34,7 +35,8 @@ def pad_input_ids( position_ids_offset: int some models are trained with position_ids that do not start at 0 but at pad_id + 1. The default parameter here will work for most models, but for example MPNet requires passing a real pad_id. - + pad_token_id: int + the token ID to use for padding. Default is 0. Returns ------- Tuple[torch.Tensor, MutableMapping[str, Any]] @@ -49,25 +51,34 @@ def pad_input_ids( position_ids_list = [] for input_ids_i in input_ids_list: seq_len = input_ids_i.size(0) - pads = torch.zeros( - max_len - seq_len, dtype=torch.long, device=input_ids_i.device + pads = torch.full( + (max_len - seq_len,), + pad_token_id, + dtype=torch.long, + device=input_ids_i.device, ) non_pads = torch.ones(seq_len, dtype=torch.bool, device=input_ids_i.device) # Setting this to 0, however if 0 is the eos, we will end up truncating the output if using truncate_after_eos # once this workflow works for nested tensor, this can probably be removed - pos_ids_pads = pads + pos_ids_pads = torch.zeros( + max_len - seq_len, dtype=torch.long, device=input_ids_i.device + ) pos_ids_seq = torch.arange( 0, seq_len, dtype=torch.long, device=input_ids_i.device ) if padding_side == "left": padded_input_ids_list.append(torch.cat((pads, input_ids_i))) - mask_list.append(torch.cat((pads.bool(), non_pads))) + mask_list.append( + torch.cat((pads.bool(), non_pads)) + ) # This will be False for pad tokens position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq))) elif padding_side == "right": padded_input_ids_list.append(torch.cat((input_ids_i, pads))) - mask_list.append(torch.cat((non_pads, pads.bool()))) + mask_list.append( + torch.cat((non_pads, pads.bool())) + ) # This will be False for pad tokens position_ids_list.append(torch.cat((pos_ids_seq, pos_ids_pads))) else: raise NotImplementedError("padding_side must be 'right' or left'") From c48ce21adf16d3d6e9803782fe8ad38e6643c9c6 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Wed, 4 Mar 2026 15:16:19 -0600 Subject: [PATCH 08/98] :sparkles: Add pad token id to generate function Signed-off-by: Gaurav-Kumbhat --- fms/utils/generation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms/utils/generation.py b/fms/utils/generation.py index ae749aefc..ab7bd60a5 100644 --- a/fms/utils/generation.py +++ b/fms/utils/generation.py @@ -184,6 +184,7 @@ def generate( use_cache: bool = False, contiguous_cache: bool = False, eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, timing: str = "", prepare_model_inputs_hook: Optional[ Callable[ @@ -220,6 +221,7 @@ def generate( past_key_value_states args in forward method. contiguous_cache: ensures the cache is contiguous in device memory eos_token_id: the optional token id representing the end of sequence + pad_token_id: the optional token id representing the pad token timing: whether to measure timings: "per-token" for each token generation time, "e2e" for full generation loop. Both options make `generate` return a tuple with the following information: From 73d1aafd628900bdf035d639bcecd7ad5ac54616 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Fri, 6 Mar 2026 18:35:47 +0000 Subject: [PATCH 09/98] :construction: Initiate ministral3 Signed-off-by: Gaurav-Kumbhat --- fms/models/__init__.py | 2 + fms/models/hf/config_utils/__init__.py | 2 + fms/models/hf/config_utils/param_builders.py | 70 +++++- fms/models/hf/modeling_hf_adapter.py | 10 +- fms/models/hf/utils.py | 15 ++ fms/models/ministral3.py | 231 +++++++++++++++++++ 6 files changed, 328 insertions(+), 2 deletions(-) create mode 100644 fms/models/ministral3.py diff --git a/fms/models/__init__.py b/fms/models/__init__.py index ff6932d31..92487a50c 100644 --- a/fms/models/__init__.py +++ b/fms/models/__init__.py @@ -500,6 +500,7 @@ def model_wrap(model): llama, llava_next, mistral, + ministral3, mistral3, mixtral, roberta, @@ -515,6 +516,7 @@ def model_wrap(model): "llama", "llava_next", "mistral", + "ministral3", "mistral3", "mixtral", "roberta", diff --git a/fms/models/hf/config_utils/__init__.py b/fms/models/hf/config_utils/__init__.py index 3bb7d588c..5db646a57 100644 --- a/fms/models/hf/config_utils/__init__.py +++ b/fms/models/hf/config_utils/__init__.py @@ -45,6 +45,8 @@ "MPNetForMaskedLM": ("mpnet", pb.build_mpnet_params), "BertForMaskedLM": ("bert", pb.build_bert_params), "Mistral3ForConditionalGeneration": ("mistral3", pb.build_mistral3_params), + # NOTE: Special case for Ministral3 + "FMSMinistral3ForConditionalGeneration": ("ministral3", pb.build_ministral3_params), # Classify arches have some extra keys for labels "RobertaForSequenceClassification": ("roberta_classification", partial(pb.build_roberta_params, is_classify=True)), "BertForSequenceClassification": ("bert_classification", partial(pb.build_bert_params, is_classify=True)), diff --git a/fms/models/hf/config_utils/param_builders.py b/fms/models/hf/config_utils/param_builders.py index 9fdf77c9f..84bca517a 100644 --- a/fms/models/hf/config_utils/param_builders.py +++ b/fms/models/hf/config_utils/param_builders.py @@ -11,6 +11,7 @@ # Used for mistral3 from fms.models.pixtral_vision import PixtralVisionConfig from fms.models.mistral import MistralConfig +from fms.models.ministral3 import Ministral3TextConfig from transformers import PretrainedConfig @@ -299,6 +300,13 @@ def build_pixtral_params(config: PretrainedConfig) -> dict: # we use the same default in Pixtral's encoder, which is 1e-5, # but should be aware in case this is changed and added to the # config in future releases. + + # To handle cases such as ministral3 + if hasattr(config, "rope_parameters"): + rope_theta = config.rope_parameters["rope_theta"] + else: + rope_theta = config.rope_theta + config_params = { "hidden_size": config.hidden_size, "intermediate_size": config.intermediate_size, @@ -308,7 +316,7 @@ def build_pixtral_params(config: PretrainedConfig) -> dict: "image_size": config.image_size, "patch_size": config.patch_size, "hidden_act": config.hidden_act, - "rope_theta": config.rope_theta, + "rope_theta": rope_theta, "attention_dropout": config.attention_dropout, "initializer_range": config.initializer_range, } @@ -320,6 +328,10 @@ def build_pixtral_params(config: PretrainedConfig) -> dict: def build_mistral3_params(config: PretrainedConfig) -> dict: """Param builder for mapping Mistral3ForConditionalGeneration to FMS.""" + ## NOTE: Since ministral3 and mistral3 uses same architecture class + # we are combining their build params function into one. They also + # use same vision model + # Sanity checks – we currently support only Mistral text + Pixtral vision if getattr(config.text_config, "model_type", None) != "mistral": raise ValueError( @@ -346,6 +358,62 @@ def build_mistral3_params(config: PretrainedConfig) -> dict: return config_params +def build_ministral3_params(config: PretrainedConfig) -> dict: + """Param builder for ministral3 mapping Mistral3ForConditionalGeneration to FMS.""" + + ## NOTE: Since ministral3 and mistral3 uses same architecture class + # we are combining their build params function into one. They also + # use same vision model + + # Sanity checks – we currently support only Mistral text + Pixtral vision + if getattr(config.text_config, "model_type", None) != "ministral3": + raise ValueError( + "FMS implementation of Mistral3 currently supports only 'mistral' language model" + ) + + if getattr(config.vision_config, "model_type", None) != "pixtral": + raise ValueError( + "FMS implementation of Mistral3 currently supports only 'pixtral' vision tower" + ) + config_params = { + "projector_hidden_act": config.projector_hidden_act, + "multimodal_projector_bias": config.multimodal_projector_bias, + "spatial_merge_size": config.spatial_merge_size, + "image_token_index": config.image_token_index, + "vision_feature_layer": config.vision_feature_layer, + } + # Handle text / vision subconfigs, respectively + text_config_params = build_ministral3_text_params(config.text_config) + config_params["text_config"] = Ministral3TextConfig(**text_config_params) + + vision_config_params = build_pixtral_params(config.vision_config) + config_params["vision_config"] = PixtralVisionConfig(**vision_config_params) + return config_params + +def build_ministral3_text_params(config: PretrainedConfig) -> dict:# + """Param builder for mapping MistralForCausalLM to FMS.""" + config_params = { + "activation_fn": config.hidden_act, + "src_vocab_size": config.vocab_size, + "nheads": config.num_attention_heads, + "nlayers": config.num_hidden_layers, + "emb_dim": config.hidden_size, + "max_expected_seq_len": config.max_position_embeddings, + "kvheads": config.num_key_value_heads, + "p_dropout": config.attention_dropout, + "head_dim": ( + getattr(config, "head_dim", None) + or config.hidden_size // config.num_attention_heads + ), + "norm_eps": config.rms_norm_eps, + "rope_parameters": config.rope_parameters, + "sliding_window": config.sliding_window, + } + return model_params_with_common_opts( + config, config_params, inner_dim=config.intermediate_size + ) + + def model_params_with_common_opts( config: PretrainedConfig, config_params: dict, inner_dim: int ) -> dict: diff --git a/fms/models/hf/modeling_hf_adapter.py b/fms/models/hf/modeling_hf_adapter.py index ca1ed8adb..4d7870393 100644 --- a/fms/models/hf/modeling_hf_adapter.py +++ b/fms/models/hf/modeling_hf_adapter.py @@ -1,13 +1,14 @@ import abc import copy import os +from packaging.version import Version from typing import Callable, Dict, Optional, Tuple, Union import torch from torch import nn from torch.nn.modules.loss import _Loss from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin -from transformers.modeling_utils import no_init_weights +from transformers import __version__ as tf_version from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -15,6 +16,13 @@ ) from transformers.utils import ModelOutput, is_torch_fx_proxy +## Address transformers API changes +if Version(tf_version) > Version("5.0.0"): + from transformers.initialization import no_init_weights +else: + from transformers.modeling_utils import no_init_weights + + from fms.models.hf.utils import mask_2d_to_3d, mask_2d_to_3d_bidirectional diff --git a/fms/models/hf/utils.py b/fms/models/hf/utils.py index 03a73b121..680877a9e 100644 --- a/fms/models/hf/utils.py +++ b/fms/models/hf/utils.py @@ -1,4 +1,5 @@ import os.path +import logging from typing import Any, Dict, Optional, Union import torch @@ -14,6 +15,7 @@ from fms.models import get_model from fms.models.hf.config_utils import _FMS_MODEL_CONFIG_REGISTRY +logger = logging.getLogger(__name__) def register_fms_models(): """Register all FMS models with huggingface AutoModels""" @@ -159,6 +161,19 @@ def infer_model_configuration( config = AutoConfig.from_pretrained(model_path) + ## HACK to map Mistral3ForConditionalGeneration to Ministral3 class successfully + + if config.architectures[0] == "Mistral3ForConditionalGeneration" and \ + config.text_config.model_type == "ministral3": + config.architectures = ["FMSMinistral3ForConditionalGeneration"] + logger.warning( + "%s architecture detected with ministral3 text_config.model_type." + "This will get remapped to FMSMinistral3ForConditionalGeneration for" + "building params and configuring class accordingly!" + % config.architectures[0] + ) + + config_params = _FMS_MODEL_CONFIG_REGISTRY.hf_config_to_fms_config_params( config, model_path=model_path if download_weights else None, diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py new file mode 100644 index 000000000..4ce038ec3 --- /dev/null +++ b/fms/models/ministral3.py @@ -0,0 +1,231 @@ +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Dict, Mapping, Optional, Tuple +from typing_extensions import Unpack + +import torch +import torch.nn as nn + +from fms import models +from fms.distributed.strategy import ( + DistributedStrategy, + NoOpStrategy, +) + +from fms.modules.attention import AttentionKwargs + +from fms.utils import serialization +from fms.utils.activation import str_to_activation +from fms.utils.config import ModelConfig + +from fms.modules.layernorm import LayerNormParameterized +from fms.models.mistral import Mistral +from fms.models.mistral3 import Mistral3 +from fms.models.pixtral_vision import PixtralVisionConfig, PixtralVisionModel + +logger = logging.getLogger(__name__) + + +_architecture_name = "ministral3" + +@dataclass +class Ministral3TextConfig(ModelConfig): + src_vocab_size: int = 131072 + nheads: int = 32 + nlayers: int = 40 + hidden_grow_factor: float = 16384 / 5120 # intermediate_size / hidden_size:emb_dim + multiple_of: int = 256 # borrowed from llama + tie_heads: bool = False + p_dropout: float = 0.0 + activation_fn: str = "silu" + emb_dim: int = 5120 + head_dim: int = 128 # getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + max_expected_seq_len: int = 262144 + kvheads: int = 8 + norm_eps: float = 1e-5 + sliding_window: int = 4000 # null for ministral3 in the model itself + rope_parameters: Dict = field(default_factory=dict) + fused_weights: bool = True # FMS Specific -- For CPU/GPU = T, AIU = F + pad_id: int = -1 # borrowed from granite, we do need it + linear_config: Optional[Mapping[str, Any]] = None # To support quantization + + + +@dataclass +class Ministral3Config(ModelConfig): + """ + Composite configuration for the FMS Ministral3 multimodal model. + + This wraps a Ministral3TextConfig (text) config for Ministral3 & Pixtral vision encoder. + The current defaults correspond to Ministral3 14B, 8B, i.e., + https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512 + """ + + text_config: Ministral3TextConfig = field(default_factory=Ministral3TextConfig) + vision_config: PixtralVisionConfig = field(default_factory=PixtralVisionConfig) + projector_hidden_act: str = "gelu" + multimodal_projector_bias: bool = False + spatial_merge_size: int = 2 + image_token_index: int = 10 + vision_feature_layer: int | list[int] = -1 + ### FMS Specific + fused_weights: bool = True # True For CPU/GPU = T, False for AIU + +_14b_config = Ministral3Config() + + +# =============== Modeling ====================== + +class Ministral3(Mistral3): + pass + + + +# =============== Registration ================== + +def _ministral3_factory_factory(config): + def factory(**kwargs): + return Mistral3(config, **kwargs) + + return factory + + +models.register_model(_architecture_name, "14b", _ministral3_factory_factory(_14b_config)) + + +# =============== Serialization ================== + + +def _weight_fusion( + input_sd: Mapping[str, Any], model_config: Optional[Ministral3Config] = None, **kwargs +) -> Mapping[str, Any]: + has_fused_weights = True + if model_config: + if not model_config.fused_weights: + has_fused_weights = False + + new_sd = input_sd + if has_fused_weights: + new_sd = serialization._mlp_glu_unfused_to_fused_adapter_step( + serialization._attn_unfused_to_fused_step(new_sd) + ) + return new_sd + + +serialization.register_adapter_step(_architecture_name, "weight_fusion", _weight_fusion) + + +def _hf_to_fms_names(input_sd: Mapping[str, Any], **kwargs) -> Mapping[str, Any]: + replacements = replacements = [ + # Language Model + (r"^language_model.lm_head.weight", "language_model.head.weight"), + ( + r"^language_model.model.embed_tokens.weight", + "language_model.base_model.embedding.weight", + ), + (r"^language_model.model.norm", "language_model.base_model.dec_norm"), + (r"^language_model.model.layers", "language_model.base_model.layers"), + (r"self_attn\.k_proj", "attn.in_proj.key"), + (r"self_attn\.v_proj", "attn.in_proj.value"), + (r"self_attn\.q_proj", "attn.in_proj.query"), + (r"self_attn\.o_proj", "attn.dense"), + (r"mlp\.gate_proj", "ff_sub_layer.wg"), + (r"mlp\.up_proj", "ff_sub_layer.w1"), + (r"mlp\.down_proj", "ff_sub_layer.w2"), + (r"input_layernorm", "ln"), + (r"post_attention_layernorm", "ff_ln"), + # Vision Model + (r"feed_forward\.gate_proj", "ff_sub_layer.wg"), + (r"feed_forward\.up_proj", "ff_sub_layer.w1"), + (r"feed_forward\.down_proj", "ff_sub_layer.w2"), + (r"attention\.k_proj", "attn.in_proj.key"), + (r"attention\.v_proj", "attn.in_proj.value"), + (r"attention\.q_proj", "attn.in_proj.query"), + (r"attention\.o_proj", "attn.dense"), + ] + new_sd = {} + for name, param in input_sd.items(): + new_name = name + for pattern, repl in replacements: + new_name = re.sub(pattern, repl, new_name) + new_sd[new_name] = param + return new_sd + + +serialization.register_adapter_step( + _architecture_name, "hf_to_fms_names", _hf_to_fms_names +) + + +def _hf_to_fms_rope( + input_sd: Mapping[str, Any], model_config: Optional[Ministral3Config] = None, **kwargs +) -> Mapping[str, Any]: + new_sd = {} + + if model_config is None: + # It Fall back to values for Ministral3; ModelConfig should really not be + # optional here though, as setting the wrong head dimensions can cause a + # lot of confusion. + lm_head_dim = 128 + vision_head_dim = 64 + logger.warning("Missing model_config, assuming default text/vision head sizes") + else: + text_config = model_config.text_config + vision_config = model_config.vision_config + lm_head_dim = text_config.head_dim + vision_head_dim = vision_config.hidden_size // vision_config.nheads + + # TODO: Update this if we ever need gptq for this model arch, + # this assusmes torchj linear layers. + rope_params = ["weight", "bias"] + # Match on either the language model or vision tower attn qk + trans_required_pattern = re.compile( + "|".join( + [ + f"language_model.base_model.layers.[0-9]+.attn.in_proj.(query|key).({'|'.join(rope_params)})", + f"vision_tower.transformer.layers.[0-9]+.attn.in_proj.(query|key).({'|'.join(rope_params)})", + ] + ) + ) + for name, param in input_sd.items(): + # hf -> fms requires a transpose operation for the query and key + # weight and bias parameters for Llama models + # This transpose is due to the different implementation of RoPE in + # HF and FMS. While FMS follows the original RoPE paper + # (https://arxiv.org/abs/2104.09864), HF has its own implementation + # that doesn't respect the order of outputs. This is OK as long as you + # rearrange the weights of the query and key projections, as the + # combination projection + RoPE ends up producing the same outputs. + # Therefore, to make FMS produce the correct order of outputs when + # loading from an HF checkpoint, we need to undo the transformation + # that HF does from the original Meta weights: + if bool(trans_required_pattern.search(name)): + head_dim = lm_head_dim if "language" in name else vision_head_dim + temp = param + # num_heads is used in the transformation required for hf->fms + # can't be precomputed because q and k might have different num_heads + num_heads = temp.size(0) // head_dim + + if temp.dim() == 2: # weight + temp_view = temp.view(num_heads, 2, -1, temp.size(1)) + else: # bias + temp_view = temp.view(num_heads, 2, -1) + temp = temp_view.transpose(1, 2).reshape(*temp.size()) + + new_sd[name] = temp + else: + new_sd[name] = param + + return new_sd + + +serialization.register_adapter_step( + _architecture_name, "hf_to_fms_rope", _hf_to_fms_rope +) + +serialization.register_adapter( + _architecture_name, + "hf", + ["hf_to_fms_names", "hf_to_fms_rope", "weight_fusion"], +) From 9db6d71510c3a88f7be628b542c77477a5ee90a4 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Fri, 6 Mar 2026 15:12:59 -0600 Subject: [PATCH 10/98] :package: Make param builder work with transformers 5.x.x refactors Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/config_utils/param_builders.py | 41 +++++++++++++------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/fms/models/hf/config_utils/param_builders.py b/fms/models/hf/config_utils/param_builders.py index 84bca517a..e70c2ad55 100644 --- a/fms/models/hf/config_utils/param_builders.py +++ b/fms/models/hf/config_utils/param_builders.py @@ -15,6 +15,20 @@ from transformers import PretrainedConfig +def reverse_rope_param_lookup(config: PretrainedConfig): + """This function allows fetching the rope_theta from the config + allowing compatibility with transformers 5.0 changes + """ + if hasattr(config, "rope_parameters"): + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = getattr(config.rope_parameters, "rope_scaling", None) + else: + rope_theta = config.rope_theta + rope_scaling = getattr(config, "rope_scaling", None) + + return rope_theta, rope_scaling + + def build_llama_params(config: PretrainedConfig) -> dict: """Param builder for mapping LlamaForCausalLM to FMS.""" config_params = { @@ -27,11 +41,9 @@ def build_llama_params(config: PretrainedConfig) -> dict: "max_expected_seq_len": config.max_position_embeddings, } # New in Llama 3 - rope_theta = getattr(config, "rope_theta", None) + rope_theta, rope_scaling = reverse_rope_param_lookup(config) if rope_theta is not None: config_params["rope_theta"] = rope_theta - # New in Llama 3.1 - rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None: config_params["rope_scaling"] = rope_scaling @@ -57,6 +69,7 @@ def build_gpt_bigcode_params(config: PretrainedConfig) -> dict: def build_mixtral_params(config: PretrainedConfig) -> dict: """Param builder for mapping MixtralForCausalLM to FMS.""" inner_dim = config.intermediate_size + rope_theta, _ = reverse_rope_param_lookup(config) config_params = { "dim": config.hidden_size, "hidden_dim": inner_dim, @@ -64,7 +77,7 @@ def build_mixtral_params(config: PretrainedConfig) -> dict: "kv_heads": config.num_key_value_heads, "num_experts": config.num_local_experts, "top_k_experts": config.num_experts_per_tok, - "rope_base": config.rope_theta, + "rope_base": rope_theta, "max_expected_seq_len": config.max_position_embeddings, } return model_params_with_common_opts(config, config_params, inner_dim=inner_dim) @@ -98,6 +111,7 @@ def build_roberta_params(config: PretrainedConfig, is_classify: bool = False) -> def build_granite_params(config: PretrainedConfig) -> dict: """Param builder for mapping GraniteForCausalLM to FMS.""" + rope_theta, _ = reverse_rope_param_lookup(config) config_params = { "attn_bias": getattr(config, "attention_bias", False), "mlp_bias": getattr(config, "mlp_bias", False), @@ -110,7 +124,7 @@ def build_granite_params(config: PretrainedConfig) -> dict: "attention_multiplier": config.attention_multiplier, "logits_scaling": config.logits_scaling, "embedding_multiplier": config.embedding_multiplier, - "rope_theta": config.rope_theta, + "rope_theta": rope_theta, "activation_fn": config.hidden_act, "head_dim": getattr( config, "head_dim", config.hidden_size // config.num_attention_heads @@ -127,6 +141,8 @@ def build_granite_moe_hybrid_params(config: PretrainedConfig) -> dict: # granite-v4 dense version. In future, based on the configuration # we may route to different architectures or classes. + rope_theta, _ = reverse_rope_param_lookup(config) + config_params = { "attn_bias": getattr(config, "attention_bias", False), "kvheads": config.num_key_value_heads, @@ -138,7 +154,7 @@ def build_granite_moe_hybrid_params(config: PretrainedConfig) -> dict: "attention_multiplier": config.attention_multiplier, "logits_scaling": config.logits_scaling, "embedding_multiplier": config.embedding_multiplier, - "rope_theta": config.rope_theta, + "rope_theta": rope_theta, "activation_fn": config.hidden_act, "head_dim": getattr( config, "head_dim", config.hidden_size // config.num_attention_heads @@ -151,6 +167,7 @@ def build_granite_moe_hybrid_params(config: PretrainedConfig) -> dict: def build_mistral_params(config: PretrainedConfig) -> dict: """Param builder for mapping MistralForCausalLM to FMS.""" + rope_theta, _ = reverse_rope_param_lookup(config) config_params = { "activation_fn": config.hidden_act, "emb_dim": config.hidden_size, @@ -162,7 +179,7 @@ def build_mistral_params(config: PretrainedConfig) -> dict: or config.hidden_size // config.num_attention_heads ), "norm_eps": config.rms_norm_eps, - "rope_base": config.rope_theta, + "rope_base": rope_theta, "sliding_window": config.sliding_window, } return model_params_with_common_opts( @@ -302,10 +319,7 @@ def build_pixtral_params(config: PretrainedConfig) -> dict: # config in future releases. # To handle cases such as ministral3 - if hasattr(config, "rope_parameters"): - rope_theta = config.rope_parameters["rope_theta"] - else: - rope_theta = config.rope_theta + rope_theta, _ = reverse_rope_param_lookup(config) config_params = { "hidden_size": config.hidden_size, @@ -390,8 +404,9 @@ def build_ministral3_params(config: PretrainedConfig) -> dict: config_params["vision_config"] = PixtralVisionConfig(**vision_config_params) return config_params -def build_ministral3_text_params(config: PretrainedConfig) -> dict:# - """Param builder for mapping MistralForCausalLM to FMS.""" + +def build_ministral3_text_params(config: PretrainedConfig) -> dict: # + """Param builder for mapping ministral3 with MistralForCausalLM to FMS.""" config_params = { "activation_fn": config.hidden_act, "src_vocab_size": config.vocab_size, From c0f72ac6a6faf93bff70b1fd27cdf894fcb0accc Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Fri, 6 Mar 2026 15:14:28 -0600 Subject: [PATCH 11/98] :construction: Make ministral3 tensors load into FMS format correctly Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/utils.py | 23 ++- fms/models/ministral3.py | 302 +++++++++++++++++++++++++++++++++++-- fms/utils/serialization.py | 8 + 3 files changed, 311 insertions(+), 22 deletions(-) diff --git a/fms/models/hf/utils.py b/fms/models/hf/utils.py index 680877a9e..afe9fddbc 100644 --- a/fms/models/hf/utils.py +++ b/fms/models/hf/utils.py @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) + def register_fms_models(): """Register all FMS models with huggingface AutoModels""" from fms.models.hf import ( @@ -140,6 +141,11 @@ def infer_model_configuration( ): ignore_patterns = ["*.safetensors"] allow_patterns.append("*.pt") + elif isinstance(model_id_or_path, str) and model_id_or_path.startswith( + "mistralai/Ministral" + ): + ignore_patterns = ["consolidated.safetensors"] + allow_patterns.append("*.safetensors*") elif isinstance(model_id_or_path, str) and model_id_or_path.startswith( "mistralai/Mistral" ): @@ -163,16 +169,17 @@ def infer_model_configuration( ## HACK to map Mistral3ForConditionalGeneration to Ministral3 class successfully - if config.architectures[0] == "Mistral3ForConditionalGeneration" and \ - config.text_config.model_type == "ministral3": + if ( + config.architectures[0] == "Mistral3ForConditionalGeneration" + and config.text_config.model_type == "ministral3" + ): config.architectures = ["FMSMinistral3ForConditionalGeneration"] logger.warning( - "%s architecture detected with ministral3 text_config.model_type." - "This will get remapped to FMSMinistral3ForConditionalGeneration for" - "building params and configuring class accordingly!" - % config.architectures[0] - ) - + "%s architecture detected with ministral3 text_config.model_type. " + "This will get remapped to FMSMinistral3ForConditionalGeneration for" + "building params and configuring class accordingly!" + % config.architectures[0] + ) config_params = _FMS_MODEL_CONFIG_REGISTRY.hf_config_to_fms_config_params( config, diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 4ce038ec3..64737a84a 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -1,3 +1,4 @@ +import math import logging import re from dataclasses import dataclass, field @@ -13,15 +14,18 @@ NoOpStrategy, ) -from fms.modules.attention import AttentionKwargs - -from fms.utils import serialization -from fms.utils.activation import str_to_activation from fms.utils.config import ModelConfig - +from fms.utils import serialization +from fms.utils.headless import gather_outputs +from fms.modules.attention import ( + AttentionKwargs, + MultiHeadAttention, + get_attention_type, +) +from fms.modules.feedforward import GatedLinearUnit from fms.modules.layernorm import LayerNormParameterized -from fms.models.mistral import Mistral -from fms.models.mistral3 import Mistral3 +from fms.models.mistral import MistralBlock +from fms.models.mistral3 import Mistral3, Mistral3MultiModalProjector from fms.models.pixtral_vision import PixtralVisionConfig, PixtralVisionModel logger = logging.getLogger(__name__) @@ -29,6 +33,7 @@ _architecture_name = "ministral3" + @dataclass class Ministral3TextConfig(ModelConfig): src_vocab_size: int = 131072 @@ -44,14 +49,13 @@ class Ministral3TextConfig(ModelConfig): max_expected_seq_len: int = 262144 kvheads: int = 8 norm_eps: float = 1e-5 - sliding_window: int = 4000 # null for ministral3 in the model itself + sliding_window: int = 4000 # null for ministral3 in the model itself rope_parameters: Dict = field(default_factory=dict) fused_weights: bool = True # FMS Specific -- For CPU/GPU = T, AIU = F pad_id: int = -1 # borrowed from granite, we do need it linear_config: Optional[Mapping[str, Any]] = None # To support quantization - @dataclass class Ministral3Config(ModelConfig): """ @@ -72,33 +76,301 @@ class Ministral3Config(ModelConfig): ### FMS Specific fused_weights: bool = True # True For CPU/GPU = T, False for AIU + _14b_config = Ministral3Config() # =============== Modeling ====================== + +class Ministral3Headless(nn.Module): + def __init__( + self, + config: Ministral3TextConfig, + distributed_strategy: DistributedStrategy = NoOpStrategy, + ): + super(Ministral3Headless, self).__init__() + self.config = config + self.distributed_strategy = distributed_strategy + + self.embedding = nn.Embedding( + self.config.src_vocab_size, + self.config.emb_dim, + padding_idx=self.config.pad_id, + ) + + # TODO: + self.rot_emb = None + # self.rot_emb = RotaryEmbedding( + # dim=self.config.head_dim, + # scaling=self.config.rope_scaling, + # max_seq_len=self.config.max_expected_seq_len, + # ratio=self.config.rope_base, + # ) + + # RoPE init + # for device in set( + # [param.device for param in self.parameters()] + # + [buffer.device for buffer in self.buffers()] + # ): + # self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + layers = [] + for i in range(self.config.nlayers): + block: nn.Module = MistralBlock(self.config, self.rot_emb) + block = self.distributed_strategy.distribute_layer(block, i) + layers.append(block) + self.layers = nn.ModuleList(layers) + + dec_norm = LayerNormParameterized( + self.config.emb_dim, + elementwise_scale=True, + elementwise_shift=False, + use_mean=False, + eps=self.config.norm_eps, + use_high_precision_pow=True, + ) + self.dec_norm = self.distributed_strategy.distribute_module( + dec_norm, final_layers=True + ) + + if self.config.p_dropout: + self.dropout = nn.Dropout(self.config.p_dropout) + + def reset_parameters(self): + nn.init.trunc_normal_( + self.embedding.weight, mean=0.0, std=self.config.emb_dim**-0.5 + ) + + # RoPE init + # TODO: + # for device in set( + # [param.device for param in self.parameters()] + # + [buffer.device for buffer in self.buffers()] + # ): + # self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + # Call reset_parameters for relevant sub-layers + for m in self.modules(): + if ( + isinstance(m, MultiHeadAttention) + or isinstance(m, GatedLinearUnit) + or isinstance(m, LayerNormParameterized) + ): + m.reset_parameters() + + def _clean_up_rot_emb_cache( + self, + cached_freqs: dict[Optional[torch.device], dict[int, torch.Tensor]], + max_seq_len_cached: dict[Optional[torch.device], int], + ): + # remove meta tensors from cached_freqs + for dev in list(cached_freqs.keys()): + for alp in list(cached_freqs[dev].keys()): + if cached_freqs[dev][alp].device == torch.device("meta"): + del cached_freqs[dev][alp] + if len(cached_freqs[dev]) == 0: + del cached_freqs[dev] + del max_seq_len_cached[dev] + + def post_init(self): + pass + + # def post_init(self): + # This function is called in `get_model` after the model is + # fully initalized on the correct device + + # TODO: + # self._clean_up_rot_emb_cache( + # self.rot_emb.cached_freqs, + # self.rot_emb.max_seq_len_cached, + # ) + + # init RoPE on the right device(s) + # TODO: + # for device in set( + # [param.device for param in self.parameters()] + # + [buffer.device for buffer in self.buffers()] + # ): + # self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + + def forward( + self, + x_in, + position_ids=None, + past_key_value_states=None, + use_cache=False, + **attn_kwargs: Unpack[AttentionKwargs], + ): + # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified + # x_in: batch_size x seq_len + # mask: batch_size x seq_len x seq_len + # bias: nheads x seq_len x seq_len + if past_key_value_states is None or len(past_key_value_states) == 0: + past_key_value_states = [None for _ in range(len(self.layers))] + + if x_in.dim() == 2: # input is not already embedded + x_in = self.embedding(x_in) + + # this is the output cache for all the decoder layers + present_key_value_states = [] + + for i, layer in enumerate(self.layers): + output = layer( + x=x_in, + position_ids=position_ids, + past_key_value_state=past_key_value_states[i], + use_cache=use_cache, + **attn_kwargs, + ) + + if use_cache: + x_in, present_key_value_state = output + present_key_value_states.append(present_key_value_state) + + else: + x_in = output + + dec_out = x_in + dec_out = self.dec_norm(dec_out) + if self.config.p_dropout: + dec_out = self.dropout(dec_out) + + return dec_out, present_key_value_states + + +class Ministral3Text(nn.Module): + def __init__( + self, + config: Optional[Ministral3TextConfig] = None, + distributed_strategy: DistributedStrategy = NoOpStrategy, + **kwargs, + ): + super(Ministral3Text, self).__init__() + if config is not None: + self.config = config + else: + self.config = Ministral3TextConfig() + self.config = self.config.updated(**kwargs) + self.distributed_strategy = distributed_strategy + + self.base_model = Ministral3Headless(self.config, self.distributed_strategy) + self.head = nn.Linear( + self.config.emb_dim, self.config.src_vocab_size, bias=False + ) + + @classmethod + def from_config(cls, config: Ministral3TextConfig) -> "Ministral3": + return cls(config) + + def get_config(self) -> Ministral3TextConfig: + return self.config + + def reset_parameters(self): + self.head.weight.data.normal_( + 0, + 1 / math.sqrt(math.sqrt(self.config.emb_dim * self.config.src_vocab_size)), + ) + self.base_model.reset_parameters() + + def post_init(self): + # if this model ties weights, they are tied here + if self.config.tie_heads: + # handle assignment of non-meta weights to meta parameters + if self.head.weight.device == torch.device("meta"): + self.head.weight = self.base_model.embedding.weight + else: + self.base_model.embedding.weight = self.head.weight + + self.base_model.post_init() + + def forward( + self, + x: torch.LongTensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, + use_cache: bool = False, + last_n_tokens: int = 0, + **attn_kwargs: Unpack[AttentionKwargs], + ): + get_attention_type(**attn_kwargs)["validate_attn_kwargs"]( + input_ids=x, + position_ids=position_ids, + past_key_value_states=past_key_value_states, + **attn_kwargs, + ) + output, cache = self.base_model( + x, position_ids, past_key_value_states, use_cache, **attn_kwargs + ) + + output = gather_outputs(output, last_n_tokens, **attn_kwargs) + preds = self.head(output) + + if use_cache: + return preds, cache + else: + return preds + + class Ministral3(Mistral3): - pass + def __init__( + self, + config: Optional[Ministral3Config] = None, + distributed_strategy: DistributedStrategy = NoOpStrategy, + **kwargs, + ): + super().__init__() + + if config is not None: + self.config = config + else: + self.config = Ministral3Config() + + self.config = self.config.updated(**kwargs) + + # Ensure weight fusion correctly propogates; + # NOTE: since pixtral is only run as a standalone model + if not self.config.fused_weights: + self.config.text_config.fused_weights = False + self.config.vision_config.fused_weights = False + + self.distributed_strategy = distributed_strategy + # Currently, we always use mistral for the LLM + self.language_model = Ministral3Text( + self.config.text_config, self.distributed_strategy + ) + # Vision encoder and projector for multimodal features + self.vision_tower = PixtralVisionModel( + self.config.vision_config, self.distributed_strategy + ) + self.multi_modal_projector = Mistral3MultiModalProjector( + self.config, + ) # =============== Registration ================== + def _ministral3_factory_factory(config): def factory(**kwargs): - return Mistral3(config, **kwargs) + return Ministral3(config, **kwargs) return factory -models.register_model(_architecture_name, "14b", _ministral3_factory_factory(_14b_config)) +models.register_model( + _architecture_name, "14b", _ministral3_factory_factory(_14b_config) +) # =============== Serialization ================== def _weight_fusion( - input_sd: Mapping[str, Any], model_config: Optional[Ministral3Config] = None, **kwargs + input_sd: Mapping[str, Any], + model_config: Optional[Ministral3Config] = None, + **kwargs, ) -> Mapping[str, Any]: has_fused_weights = True if model_config: @@ -159,7 +431,9 @@ def _hf_to_fms_names(input_sd: Mapping[str, Any], **kwargs) -> Mapping[str, Any] def _hf_to_fms_rope( - input_sd: Mapping[str, Any], model_config: Optional[Ministral3Config] = None, **kwargs + input_sd: Mapping[str, Any], + model_config: Optional[Ministral3Config] = None, + **kwargs, ) -> Mapping[str, Any]: new_sd = {} diff --git a/fms/utils/serialization.py b/fms/utils/serialization.py index 17e40adc0..9887113dd 100644 --- a/fms/utils/serialization.py +++ b/fms/utils/serialization.py @@ -400,6 +400,14 @@ def load_state_dict( file_list = list(model_path.glob(glob_pattern_possibility)) if len(file_list) > 0: checkpoints = sorted(file_list) + # Filter out consolidated.safetensors for HF models when loading from local path + # as it contains duplicate keys in original Mistral format + if source == "hf": + checkpoints = [ + ckpt + for ckpt in checkpoints + if ckpt.name != "consolidated.safetensors" + ] break if model_path.is_file(): From b471b573b4954bba2855fe10b70278cfdc52a0ba Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Sun, 8 Mar 2026 16:31:14 -0500 Subject: [PATCH 12/98] :construction: Make cached yarn vs HF impl test pass Signed-off-by: Gaurav-Kumbhat --- fms/modules/positions.py | 253 ++++++++++++++++++++++++++++++++ tests/modules/test_positions.py | 148 ++++++++++++++++++- 2 files changed, 400 insertions(+), 1 deletion(-) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index 5cd51eed5..6db1cc8be 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -1,6 +1,7 @@ from collections import defaultdict import copy import math +from statistics import quantiles from typing import MutableMapping, Optional, Tuple import torch @@ -594,3 +595,255 @@ def adjusted_qk( q_out = mulq.sum(5).flatten(3).type_as(q) k_out = mulk.sum(5).flatten(3).type_as(k) return q_out, k_out + + + +class CachedYarnRotaryEmbedding(PositionEncoder): + def __init__( + self, + dim: int, # Rotary dimension + max_position_embeddings: int, + base: float, # Rope theta + scaling_factor: float, # factor + *, + extrapolation_factor: float = 1.0, + attn_factor: Optional[float] = None, + beta_fast: float = 32.0, + beta_slow: float = 1.0, + mscale: float = 1.0, + mscale_all_dim: float = 1.0, + ): + """ + This implements Yarn scaling rotary embedding. + + Credits to Peng et al. github.com/jquesnelle/yarn + Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py + Ref: https://github.com/mistralai/vllm-release/blob/3e21dacb79471ebf946e72e67a5ca14ebcc598c1/vllm/model_executor/layers/rotary_embedding.py#L268 + """ + + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings # original_max_position_embeddings + self.base = base + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.beta_fast = beta_fast # low + self.beta_slow = beta_slow # high + + self.cached_freqs: dict[int, Tuple[torch.Tensor, torch.Tensor]] = {} + + # magnitude scaling factor + self.mscale = float( + self._yarn_get_mscale(mscale)) + + self.mscale_all_dim = float( + self._yarn_get_mscale(mscale_all_dim)) + + # NOTE: We are not computing attn_factor based on mscale here, since its not requried for ministral3 + if attn_factor is None: + attn_factor = float(self.mscale / self.mscale_all_dim) + + self.attn_factor = attn_factor + + # TODO: Currently llama_4_scaling is not applied + + + def _yarn_get_mscale(self, mscale: float = 1) -> float: + if self.scaling_factor <= 1: + return 1.0 + return 0.1 * mscale * math.log(self.scaling_factor) + 1.0 + + + def _compute_cos_sin_cache(self, inv_freq: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the cos and sin cache for the rotary embedding to avoid computing + while doing the forward pass. + Args: + inv_freq: The precomputed inverse frequency tensor + Returns: + Tuple of (cos_cache, sin_cache) each with shape [max_pos, dim/2] + """ + t = torch.arange( + int(self.max_position_embeddings * self.scaling_factor), + device=device, + dtype=torch.float32 + ) + freqs = torch.outer(t, inv_freq).float() + + # Apply mscale and compute cos/sin + cos = (freqs.cos() * self.attn_factor) + sin = (freqs.sin() * self.attn_factor) + + return cos, sin + + # def _rotate(self, x: torch.Tensor) -> torch.Tensor: + # """ + # Rotate the input tensor + # Args: + # x: The input tensor + # Returns: + # The rotated tensor + # """ + # x1 = x[..., ::2] + # x2 = x[..., 1::2] + # x = torch.stack((-x2, x1), dim=-1) + # return x.flatten(-2) + + def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: + """ + Compute the frequencies for the rotary embedding. + Args: + device: device to compute frequencies on + """ + + if device == torch.device("meta"): + # Protect from initializing on spyre device + raise AssertionError("Attempted to init yarn freqs on meta device") + + if device.index in self.cached_freqs: + return self.cached_freqs[device.index] + + freqs =( + self.base + **(torch.arange(0, self.dim, 2, device=device).float() / self.dim)) + + inv_freq_extrapolation = 1.0 / freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) + + # NOTE: math.floor and math.ceil being used here are referred to as "truncate" option + low = math.floor( + self.dim * math.log(self.max_position_embeddings / + (self.beta_fast * 2 * math.pi))) / (2 * + math.log(self.base) + ) + high = math.ceil( + self.dim * math.log(self.max_position_embeddings / + (self.beta_slow * 2 * math.pi))) / (2 * + math.log(self.base) + ) + + # Make sure values are not going outside range + low = max(low, 0) + high = min(high, self.dim - 1) + + if low == high: + high += 0.001 # Prevent singularity + + # Get n-dimensional rotational scaling corrected for extrapolation + linear_func = ( + torch.arange(self.dim // 2, dtype=torch.float32, device=device + ) - low) / (high - low) + + # Compute ramp function (clamped linear interpolation) + ramp_func = torch.clamp(linear_func, 0, 1) + + # inv_freq_extrapolation_factor is the weight for extrapolation + # (1 - ramp_func) means: use extrapolation for low frequencies (< low) + # ramp_func means: use interpolation for high frequencies (> high) + inv_freq_extrapolation_factor = 1 - ramp_func + + # Blend between interpolation and extrapolation + # Note: extrapolation_factor is applied to the extrapolation frequencies + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor * self.extrapolation_factor + ) + + + # Cache the computed frequencies for this device + cos_cache, sin_cache = self._compute_cos_sin_cache(inv_freq, device) + self.cached_freqs[device.index] = (cos_cache, sin_cache) + + return self.cached_freqs[device.index] + + def adjusted_qk( + self, + q: torch.Tensor, + k: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + past_kv_state: Optional[Tuple[torch.Tensor | None, torch.Tensor | None]] = None, + use_cache=False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function applies 1D rotary embeddings to the queries and keys using interleaved rotation. + Args + ---- + q : torch.Tensor + Embedded query tensor, expected size is B x S x H x D + where B=batch, S=sequence length, H=num heads, D=head dimension + k : torch.Tensor + Embedded key tensor, expected size is B x S x H x D + where B=batch, S=sequence length, H=num heads, D=head dimension + position_ids : Optional[torch.LongTensor] + The position of each of the tokens encoded in q and k. This is important in + kv-caching and left-padding situations, for which the rotation to be applied might + not always be the pre-cached position 0...S. For kv-caching without dynamic batching + or variable per-row left padding position_ids is shared for all the batch. + past_kv_state : Optional[Tuple[torch.Tensor | None, torch.Tensor | None]] + Past key-value states for caching + use_cache : bool + Whether to use KV caching + """ + + assert len(q.size()) == 4 + assert len(k.size()) == 4 + + seq_len = max(k.size(1), q.size(1)) + if position_ids is None: + # Compute position_ids based on cache config + position_ids = torch.arange( + 0, seq_len, dtype=torch.long, device=q.device + ).repeat(k.size(0), 1) + if ( + use_cache + and past_kv_state is not None + and past_kv_state[0] is not None + and past_kv_state[0].numel() > 0 + ): + position_ids += past_kv_state[0].size(2) + + # Fetch the cos and sin values for the given position_ids from cache + cos_cache, sin_cache = self.compute_freqs_cis(q.device) + + # Index by position_ids: [B, L] -> [B, L, rotary_dim/2] + cos = cos_cache[position_ids] + sin = sin_cache[position_ids] + + # Unsqueeze to add head dimension: [B, L, rotary_dim/2] -> [B, L, 1, rotary_dim/2] + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) + + # Only apply rotation to the first self.dim dimensions + # Extract the rotary portion + q_rope = q[..., : self.dim] # [B, L, H, rotary_dim] + k_rope = k[..., : self.dim] # [B, L, H, rotary_dim] + + # Reshape for interleaved rotation + # From [B, L, H, rotary_dim] to [B, L, H, rotary_dim/2, 2] for interleaved pairs + q_ = q_rope.float().view(*q_rope.size()[:-1], -1, 2) # B L H rotary_dim/2 2 + k_ = k_rope.float().view(*k_rope.size()[:-1], -1, 2) # B L H rotary_dim/2 2 + + # Apply interleaved rotation: [x, y] -> [x*cos - y*sin, y*cos + x*sin] + # cos and sin have shape [B, L, 1, rotary_dim/2], broadcast to [B, L, H, rotary_dim/2] + q_out = torch.stack( + [ + q_[..., 0] * cos - q_[..., 1] * sin, # x*cos - y*sin + q_[..., 1] * cos + q_[..., 0] * sin, # y*cos + x*sin + ], + dim=-1, + ).flatten(-2).type_as(q) + + k_out = torch.stack( + [ + k_[..., 0] * cos - k_[..., 1] * sin, + k_[..., 1] * cos + k_[..., 0] * sin, + ], + dim=-1, + ).flatten(-2).type_as(k) + + # Concatenate with the non-rotated portion if rotary_dim < head_dim + if self.dim < q.size(-1): + q_out = torch.cat([q_out, q[..., self.dim :]], dim=-1) + k_out = torch.cat([k_out, k[..., self.dim :]], dim=-1) + + return q_out, k_out \ No newline at end of file diff --git a/tests/modules/test_positions.py b/tests/modules/test_positions.py index 0c5a55a0b..735d16498 100644 --- a/tests/modules/test_positions.py +++ b/tests/modules/test_positions.py @@ -4,7 +4,7 @@ import pytest import torch -from fms.modules.positions import RotaryEmbedding, PixtralRotaryEmbedding +from fms.modules.positions import RotaryEmbedding, PixtralRotaryEmbedding, CachedYarnRotaryEmbedding class RotaryEmbeddingTests(unittest.TestCase): @@ -429,3 +429,149 @@ def permute_fms_to_hf(tensor): # Compare results torch.testing.assert_close(adjusted_query_fms, query_hf, rtol=1e-4, atol=1e-5) torch.testing.assert_close(adjusted_key_fms, key_hf, rtol=1e-4, atol=1e-5) + + + +class CachedYarnRotaryEmbeddingTests(unittest.TestCase): + def test_args(self): + """Test that CachedYarnRotaryEmbedding validates input shapes correctly""" + q = torch.ones(2, 4, 1, 16, dtype=torch.float) # b s h e + k = 2 * torch.ones(2, 4, 1, 16, dtype=torch.float) # b s h e + yarn_rope = CachedYarnRotaryEmbedding( + dim=16, + max_position_embeddings=32, + base=10000, + scaling_factor=1.0, + ) + + with self.assertRaises(AssertionError): + qr, kr = yarn_rope.adjusted_qk(q.squeeze(), k) + + with self.assertRaises(AssertionError): + qr, kr = yarn_rope.adjusted_qk(q, k.squeeze()) + + # This should not throw, as position_ids is optional + qr, kr = yarn_rope.adjusted_qk(q, k) + + # This should not throw + qr, kr = yarn_rope.adjusted_qk( + q, + k, + torch.arange(0, q.size(1), device=q.device, dtype=torch.long).unsqueeze(0), + None, + ) + + + def test_meta_device_error(self): + """Test that attempting to compute on meta device raises an error""" + yarn_rope = CachedYarnRotaryEmbedding( + dim=16, + max_position_embeddings=32, + base=10000, + scaling_factor=1.0, + ) + + with self.assertRaises(AssertionError): + yarn_rope.compute_freqs_cis(torch.device("meta")) + + + def test_hf_fms_equivalence(self): + """Test that FMS CachedYarn RoPE matches HF Transformers implementation""" + try: + from transformers.models.ministral3.modeling_ministral3 import ( + Ministral3Config, + Ministral3RotaryEmbedding, + apply_rotary_pos_emb, + ) + except ImportError: + self.skipTest("Unable to import Transformer's Ministral3 Model / Config") + + + # Configuration + dim = 32 + num_heads = 4 + rope_theta = 1000000000.0 + beta_fast = 32.0 + beta_slow = 1.0 + scaling_factor = 16.0 + original_max_position_embeddings = 16384 # Dummy Value + llama_4_scaling_beta = 0.1 + mscale = 1.0 + mscale_all_dim = 1.0 + + + # Create sample inputs + seq_len = 8 + q = torch.ones(2, seq_len, num_heads, dim, dtype=torch.float) # B x S x H x D + k = 2 * torch.ones(2, seq_len, num_heads, dim, dtype=torch.float) # B x S x H x D + + position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(2, 1) + + + ############ Get HF results + hf_config = Ministral3Config( + **{ + "hidden_size": 1024, + "head_dim": dim, + "rope_parameters": { + "rope_type": "yarn", + "rope_theta": rope_theta, + "beta_fast": beta_fast, + "beta_slow": beta_slow, + "factor": scaling_factor, + "original_max_position_embeddings": original_max_position_embeddings, + "llama_4_scaling_beta": llama_4_scaling_beta, + "mscale": mscale, + "mscale_all_dim": mscale_all_dim + } + } + ) + + transformers_emb = Ministral3RotaryEmbedding(hf_config) + cos, sin = transformers_emb(q, position_ids) + query_hf, key_hf = apply_rotary_pos_emb( + q, k, cos, sin, unsqueeze_dim=2 + ) + + ############ Get FMS results + fms_emb = CachedYarnRotaryEmbedding( + dim, + max_position_embeddings=original_max_position_embeddings, + base=rope_theta, + scaling_factor=scaling_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + mscale=mscale, + ) + + # FMS adjusted_qk expects [B, S, H, D] format + # Both HF and FMS receive the same input tensors + query_fms, key_fms = fms_emb.adjusted_qk( + q, k, position_ids + ) + + # Convert FMS output back to HF format for comparison + def permute_fms_to_hf(tensor): + """ + Permute tensor from FMS RoPE format to HF RoPE format. + FMS: [x0, y0, x1, y1, x2, y2, ..., x15, y15] (interleaved pairs) + HF: [x0, x1, x2, ..., x15, y0, y1, y2, ..., y15] (split halves) + """ + *batch_dims, head_dim = tensor.shape + half_dim = head_dim // 2 + # Reshape to separate interleaved pairs + paired = tensor.reshape(*batch_dims, half_dim, 2) + # Split into first and second elements of each pair + first_half = paired[..., 0] # x0, x1, x2, ..., x15 + second_half = paired[..., 1] # y0, y1, y2, ..., y15 + # Concatenate: first all x's, then all y's + return torch.cat([first_half, second_half], dim=-1) + + adjusted_query_fms = permute_fms_to_hf(query_fms) + adjusted_key_fms = permute_fms_to_hf(key_fms) + + # Compare results + torch.testing.assert_close(adjusted_query_fms, query_hf, rtol=1e-3, atol=1e-4) + torch.testing.assert_close(adjusted_key_fms, key_hf, rtol=1e-3, atol=1e-4) + + From 67b034c4e80111acdb2c33e6849d15ea4990383c Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Sun, 8 Mar 2026 16:43:08 -0500 Subject: [PATCH 13/98] :coffin: Cleanup Signed-off-by: Gaurav-Kumbhat --- fms/modules/positions.py | 81 +++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 47 deletions(-) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index 6db1cc8be..c830cb6c5 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -630,7 +630,7 @@ def __init__( self.beta_fast = beta_fast # low self.beta_slow = beta_slow # high - self.cached_freqs: dict[int, Tuple[torch.Tensor, torch.Tensor]] = {} + self.cached_freqs: dict[int, torch.Tensor] = {} # magnitude scaling factor self.mscale = float( @@ -654,14 +654,14 @@ def _yarn_get_mscale(self, mscale: float = 1) -> float: return 0.1 * mscale * math.log(self.scaling_factor) + 1.0 - def _compute_cos_sin_cache(self, inv_freq: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + def _compute_cos_sin_cache(self, inv_freq: torch.Tensor, device: torch.device) -> torch.Tensor: """ - Compute the cos and sin cache for the rotary embedding to avoid computing + Compute the rotation matrix cache for the rotary embedding to avoid computing while doing the forward pass. Args: inv_freq: The precomputed inverse frequency tensor Returns: - Tuple of (cos_cache, sin_cache) each with shape [max_pos, dim/2] + Rotation matrices with shape [max_pos, dim/2, 2, 2] """ t = torch.arange( int(self.max_position_embeddings * self.scaling_factor), @@ -674,20 +674,13 @@ def _compute_cos_sin_cache(self, inv_freq: torch.Tensor, device: torch.device) - cos = (freqs.cos() * self.attn_factor) sin = (freqs.sin() * self.attn_factor) - return cos, sin - - # def _rotate(self, x: torch.Tensor) -> torch.Tensor: - # """ - # Rotate the input tensor - # Args: - # x: The input tensor - # Returns: - # The rotated tensor - # """ - # x1 = x[..., ::2] - # x2 = x[..., 1::2] - # x = torch.stack((-x2, x1), dim=-1) - # return x.flatten(-2) + # Construct rotation matrices: [max_pos, dim/2, 2, 2] + # Matrix form: [[cos, -sin], [sin, cos]] + freqs_cis = torch.stack( + [cos, -sin, sin, cos], dim=-1 + ).view(*cos.shape, 2, 2) + + return freqs_cis def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: """ @@ -750,9 +743,9 @@ def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: ) - # Cache the computed frequencies for this device - cos_cache, sin_cache = self._compute_cos_sin_cache(inv_freq, device) - self.cached_freqs[device.index] = (cos_cache, sin_cache) + # Cache the computed rotation matrices for this device + freqs_cis = self._compute_cos_sin_cache(inv_freq, device) + self.cached_freqs[device.index] = freqs_cis return self.cached_freqs[device.index] @@ -802,16 +795,11 @@ def adjusted_qk( ): position_ids += past_kv_state[0].size(2) - # Fetch the cos and sin values for the given position_ids from cache - cos_cache, sin_cache = self.compute_freqs_cis(q.device) + # Fetch the rotation matrices from cache + freqs_cis = self.compute_freqs_cis(q.device) - # Index by position_ids: [B, L] -> [B, L, rotary_dim/2] - cos = cos_cache[position_ids] - sin = sin_cache[position_ids] - - # Unsqueeze to add head dimension: [B, L, rotary_dim/2] -> [B, L, 1, rotary_dim/2] - cos = cos.unsqueeze(2) - sin = sin.unsqueeze(2) + # Index by position_ids: [B, L] -> [B, L, rotary_dim/2, 2, 2] + freqs = freqs_cis[position_ids] # Only apply rotation to the first self.dim dimensions # Extract the rotary portion @@ -823,23 +811,22 @@ def adjusted_qk( q_ = q_rope.float().view(*q_rope.size()[:-1], -1, 2) # B L H rotary_dim/2 2 k_ = k_rope.float().view(*k_rope.size()[:-1], -1, 2) # B L H rotary_dim/2 2 - # Apply interleaved rotation: [x, y] -> [x*cos - y*sin, y*cos + x*sin] - # cos and sin have shape [B, L, 1, rotary_dim/2], broadcast to [B, L, H, rotary_dim/2] - q_out = torch.stack( - [ - q_[..., 0] * cos - q_[..., 1] * sin, # x*cos - y*sin - q_[..., 1] * cos + q_[..., 0] * sin, # y*cos + x*sin - ], - dim=-1, - ).flatten(-2).type_as(q) - - k_out = torch.stack( - [ - k_[..., 0] * cos - k_[..., 1] * sin, - k_[..., 1] * cos + k_[..., 0] * sin, - ], - dim=-1, - ).flatten(-2).type_as(k) + # Apply rotation using matrix multiplication + # freqs: [B, L, rotary_dim/2, 2, 2] + # Add head dimension: [B, L, 1, rotary_dim/2, 2, 2] + # q_, k_: [B, L, H, rotary_dim/2, 2] + q_out = ( + freqs[:, :, None, :, :, :] + .mul(q_.unsqueeze(-2)) + .sum(-1) + .flatten(-2) + ).type_as(q) + k_out = ( + freqs[:, :, None, :, :, :] + .mul(k_.unsqueeze(-2)) + .sum(-1) + .flatten(-2) + ).type_as(k) # Concatenate with the non-rotated portion if rotary_dim < head_dim if self.dim < q.size(-1): From 1beb8593b2b638aa4daf1d27cdd0a1d00d43120f Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Sun, 8 Mar 2026 16:43:58 -0500 Subject: [PATCH 14/98] :art: Fix formatting Signed-off-by: Gaurav-Kumbhat --- fms/modules/positions.py | 80 +++++++++++++++------------------ tests/modules/test_positions.py | 32 ++++++------- 2 files changed, 48 insertions(+), 64 deletions(-) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index c830cb6c5..a0009faca 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -597,14 +597,13 @@ def adjusted_qk( return q_out, k_out - class CachedYarnRotaryEmbedding(PositionEncoder): def __init__( self, - dim: int, # Rotary dimension + dim: int, # Rotary dimension max_position_embeddings: int, - base: float, # Rope theta - scaling_factor: float, # factor + base: float, # Rope theta + scaling_factor: float, # factor *, extrapolation_factor: float = 1.0, attn_factor: Optional[float] = None, @@ -623,21 +622,21 @@ def __init__( super().__init__() self.dim = dim - self.max_position_embeddings = max_position_embeddings # original_max_position_embeddings + self.max_position_embeddings = ( + max_position_embeddings # original_max_position_embeddings + ) self.base = base self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor - self.beta_fast = beta_fast # low - self.beta_slow = beta_slow # high + self.beta_fast = beta_fast # low + self.beta_slow = beta_slow # high self.cached_freqs: dict[int, torch.Tensor] = {} # magnitude scaling factor - self.mscale = float( - self._yarn_get_mscale(mscale)) + self.mscale = float(self._yarn_get_mscale(mscale)) - self.mscale_all_dim = float( - self._yarn_get_mscale(mscale_all_dim)) + self.mscale_all_dim = float(self._yarn_get_mscale(mscale_all_dim)) # NOTE: We are not computing attn_factor based on mscale here, since its not requried for ministral3 if attn_factor is None: @@ -647,14 +646,14 @@ def __init__( # TODO: Currently llama_4_scaling is not applied - def _yarn_get_mscale(self, mscale: float = 1) -> float: if self.scaling_factor <= 1: return 1.0 return 0.1 * mscale * math.log(self.scaling_factor) + 1.0 - - def _compute_cos_sin_cache(self, inv_freq: torch.Tensor, device: torch.device) -> torch.Tensor: + def _compute_cos_sin_cache( + self, inv_freq: torch.Tensor, device: torch.device + ) -> torch.Tensor: """ Compute the rotation matrix cache for the rotary embedding to avoid computing while doing the forward pass. @@ -666,19 +665,17 @@ def _compute_cos_sin_cache(self, inv_freq: torch.Tensor, device: torch.device) - t = torch.arange( int(self.max_position_embeddings * self.scaling_factor), device=device, - dtype=torch.float32 + dtype=torch.float32, ) freqs = torch.outer(t, inv_freq).float() # Apply mscale and compute cos/sin - cos = (freqs.cos() * self.attn_factor) - sin = (freqs.sin() * self.attn_factor) + cos = freqs.cos() * self.attn_factor + sin = freqs.sin() * self.attn_factor # Construct rotation matrices: [max_pos, dim/2, 2, 2] # Matrix form: [[cos, -sin], [sin, cos]] - freqs_cis = torch.stack( - [cos, -sin, sin, cos], dim=-1 - ).view(*cos.shape, 2, 2) + freqs_cis = torch.stack([cos, -sin, sin, cos], dim=-1).view(*cos.shape, 2, 2) return freqs_cis @@ -696,24 +693,22 @@ def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: if device.index in self.cached_freqs: return self.cached_freqs[device.index] - freqs =( - self.base - **(torch.arange(0, self.dim, 2, device=device).float() / self.dim)) + freqs = self.base ** ( + torch.arange(0, self.dim, 2, device=device).float() / self.dim + ) inv_freq_extrapolation = 1.0 / freqs inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) # NOTE: math.floor and math.ceil being used here are referred to as "truncate" option low = math.floor( - self.dim * math.log(self.max_position_embeddings / - (self.beta_fast * 2 * math.pi))) / (2 * - math.log(self.base) - ) + self.dim + * math.log(self.max_position_embeddings / (self.beta_fast * 2 * math.pi)) + ) / (2 * math.log(self.base)) high = math.ceil( - self.dim * math.log(self.max_position_embeddings / - (self.beta_slow * 2 * math.pi))) / (2 * - math.log(self.base) - ) + self.dim + * math.log(self.max_position_embeddings / (self.beta_slow * 2 * math.pi)) + ) / (2 * math.log(self.base)) # Make sure values are not going outside range low = max(low, 0) @@ -724,8 +719,8 @@ def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: # Get n-dimensional rotational scaling corrected for extrapolation linear_func = ( - torch.arange(self.dim // 2, dtype=torch.float32, device=device - ) - low) / (high - low) + torch.arange(self.dim // 2, dtype=torch.float32, device=device) - low + ) / (high - low) # Compute ramp function (clamped linear interpolation) ramp_func = torch.clamp(linear_func, 0, 1) @@ -738,11 +733,12 @@ def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: # Blend between interpolation and extrapolation # Note: extrapolation_factor is applied to the extrapolation frequencies inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + - inv_freq_extrapolation * inv_freq_extrapolation_factor * self.extrapolation_factor + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation + * inv_freq_extrapolation_factor + * self.extrapolation_factor ) - # Cache the computed rotation matrices for this device freqs_cis = self._compute_cos_sin_cache(inv_freq, device) self.cached_freqs[device.index] = freqs_cis @@ -816,16 +812,10 @@ def adjusted_qk( # Add head dimension: [B, L, 1, rotary_dim/2, 2, 2] # q_, k_: [B, L, H, rotary_dim/2, 2] q_out = ( - freqs[:, :, None, :, :, :] - .mul(q_.unsqueeze(-2)) - .sum(-1) - .flatten(-2) + freqs[:, :, None, :, :, :].mul(q_.unsqueeze(-2)).sum(-1).flatten(-2) ).type_as(q) k_out = ( - freqs[:, :, None, :, :, :] - .mul(k_.unsqueeze(-2)) - .sum(-1) - .flatten(-2) + freqs[:, :, None, :, :, :].mul(k_.unsqueeze(-2)).sum(-1).flatten(-2) ).type_as(k) # Concatenate with the non-rotated portion if rotary_dim < head_dim @@ -833,4 +823,4 @@ def adjusted_qk( q_out = torch.cat([q_out, q[..., self.dim :]], dim=-1) k_out = torch.cat([k_out, k[..., self.dim :]], dim=-1) - return q_out, k_out \ No newline at end of file + return q_out, k_out diff --git a/tests/modules/test_positions.py b/tests/modules/test_positions.py index 735d16498..fc88cf174 100644 --- a/tests/modules/test_positions.py +++ b/tests/modules/test_positions.py @@ -4,7 +4,11 @@ import pytest import torch -from fms.modules.positions import RotaryEmbedding, PixtralRotaryEmbedding, CachedYarnRotaryEmbedding +from fms.modules.positions import ( + RotaryEmbedding, + PixtralRotaryEmbedding, + CachedYarnRotaryEmbedding, +) class RotaryEmbeddingTests(unittest.TestCase): @@ -431,7 +435,6 @@ def permute_fms_to_hf(tensor): torch.testing.assert_close(adjusted_key_fms, key_hf, rtol=1e-4, atol=1e-5) - class CachedYarnRotaryEmbeddingTests(unittest.TestCase): def test_args(self): """Test that CachedYarnRotaryEmbedding validates input shapes correctly""" @@ -461,7 +464,6 @@ def test_args(self): None, ) - def test_meta_device_error(self): """Test that attempting to compute on meta device raises an error""" yarn_rope = CachedYarnRotaryEmbedding( @@ -474,7 +476,6 @@ def test_meta_device_error(self): with self.assertRaises(AssertionError): yarn_rope.compute_freqs_cis(torch.device("meta")) - def test_hf_fms_equivalence(self): """Test that FMS CachedYarn RoPE matches HF Transformers implementation""" try: @@ -486,7 +487,6 @@ def test_hf_fms_equivalence(self): except ImportError: self.skipTest("Unable to import Transformer's Ministral3 Model / Config") - # Configuration dim = 32 num_heads = 4 @@ -494,20 +494,20 @@ def test_hf_fms_equivalence(self): beta_fast = 32.0 beta_slow = 1.0 scaling_factor = 16.0 - original_max_position_embeddings = 16384 # Dummy Value + original_max_position_embeddings = 16384 # Dummy Value llama_4_scaling_beta = 0.1 mscale = 1.0 mscale_all_dim = 1.0 - # Create sample inputs seq_len = 8 q = torch.ones(2, seq_len, num_heads, dim, dtype=torch.float) # B x S x H x D - k = 2 * torch.ones(2, seq_len, num_heads, dim, dtype=torch.float) # B x S x H x D + k = 2 * torch.ones( + 2, seq_len, num_heads, dim, dtype=torch.float + ) # B x S x H x D position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(2, 1) - ############ Get HF results hf_config = Ministral3Config( **{ @@ -522,16 +522,14 @@ def test_hf_fms_equivalence(self): "original_max_position_embeddings": original_max_position_embeddings, "llama_4_scaling_beta": llama_4_scaling_beta, "mscale": mscale, - "mscale_all_dim": mscale_all_dim - } + "mscale_all_dim": mscale_all_dim, + }, } ) transformers_emb = Ministral3RotaryEmbedding(hf_config) cos, sin = transformers_emb(q, position_ids) - query_hf, key_hf = apply_rotary_pos_emb( - q, k, cos, sin, unsqueeze_dim=2 - ) + query_hf, key_hf = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2) ############ Get FMS results fms_emb = CachedYarnRotaryEmbedding( @@ -546,9 +544,7 @@ def test_hf_fms_equivalence(self): # FMS adjusted_qk expects [B, S, H, D] format # Both HF and FMS receive the same input tensors - query_fms, key_fms = fms_emb.adjusted_qk( - q, k, position_ids - ) + query_fms, key_fms = fms_emb.adjusted_qk(q, k, position_ids) # Convert FMS output back to HF format for comparison def permute_fms_to_hf(tensor): @@ -573,5 +569,3 @@ def permute_fms_to_hf(tensor): # Compare results torch.testing.assert_close(adjusted_query_fms, query_hf, rtol=1e-3, atol=1e-4) torch.testing.assert_close(adjusted_key_fms, key_hf, rtol=1e-3, atol=1e-4) - - From 9911dcfa878a92ee29d733a2142b2372dd77b838 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Sun, 8 Mar 2026 17:07:03 -0500 Subject: [PATCH 15/98] :recycle: Align nomenclature Signed-off-by: Gaurav-Kumbhat --- fms/modules/positions.py | 23 +++++++++++++++-------- tests/modules/test_positions.py | 28 +++++++++++++++++++++++++--- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index a0009faca..71d4b558c 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -601,7 +601,7 @@ class CachedYarnRotaryEmbedding(PositionEncoder): def __init__( self, dim: int, # Rotary dimension - max_position_embeddings: int, + original_max_position_embeddings: int, base: float, # Rope theta scaling_factor: float, # factor *, @@ -611,6 +611,7 @@ def __init__( beta_slow: float = 1.0, mscale: float = 1.0, mscale_all_dim: float = 1.0, + llama_4_scaling_beta: Optional[float] = None, ): """ This implements Yarn scaling rotary embedding. @@ -622,14 +623,15 @@ def __init__( super().__init__() self.dim = dim - self.max_position_embeddings = ( - max_position_embeddings # original_max_position_embeddings + self.original_max_position_embeddings = ( + original_max_position_embeddings ) self.base = base self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor self.beta_fast = beta_fast # low self.beta_slow = beta_slow # high + self.llama_4_scaling_beta = llama_4_scaling_beta self.cached_freqs: dict[int, torch.Tensor] = {} @@ -644,8 +646,6 @@ def __init__( self.attn_factor = attn_factor - # TODO: Currently llama_4_scaling is not applied - def _yarn_get_mscale(self, mscale: float = 1) -> float: if self.scaling_factor <= 1: return 1.0 @@ -663,7 +663,7 @@ def _compute_cos_sin_cache( Rotation matrices with shape [max_pos, dim/2, 2, 2] """ t = torch.arange( - int(self.max_position_embeddings * self.scaling_factor), + int(self.original_max_position_embeddings * self.scaling_factor), device=device, dtype=torch.float32, ) @@ -679,6 +679,10 @@ def _compute_cos_sin_cache( return freqs_cis + def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: + scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) + return scaling.unsqueeze(-1) + def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: """ Compute the frequencies for the rotary embedding. @@ -703,11 +707,11 @@ def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: # NOTE: math.floor and math.ceil being used here are referred to as "truncate" option low = math.floor( self.dim - * math.log(self.max_position_embeddings / (self.beta_fast * 2 * math.pi)) + * math.log(self.original_max_position_embeddings / (self.beta_fast * 2 * math.pi)) ) / (2 * math.log(self.base)) high = math.ceil( self.dim - * math.log(self.max_position_embeddings / (self.beta_slow * 2 * math.pi)) + * math.log(self.original_max_position_embeddings / (self.beta_slow * 2 * math.pi)) ) / (2 * math.log(self.base)) # Make sure values are not going outside range @@ -823,4 +827,7 @@ def adjusted_qk( q_out = torch.cat([q_out, q[..., self.dim :]], dim=-1) k_out = torch.cat([k_out, k[..., self.dim :]], dim=-1) + # TODO: Apply llama_4_scaling + # if self.llama_4_scaling_beta: + return q_out, k_out diff --git a/tests/modules/test_positions.py b/tests/modules/test_positions.py index fc88cf174..ddabcfdef 100644 --- a/tests/modules/test_positions.py +++ b/tests/modules/test_positions.py @@ -442,7 +442,7 @@ def test_args(self): k = 2 * torch.ones(2, 4, 1, 16, dtype=torch.float) # b s h e yarn_rope = CachedYarnRotaryEmbedding( dim=16, - max_position_embeddings=32, + original_max_position_embeddings=32, base=10000, scaling_factor=1.0, ) @@ -468,7 +468,7 @@ def test_meta_device_error(self): """Test that attempting to compute on meta device raises an error""" yarn_rope = CachedYarnRotaryEmbedding( dim=16, - max_position_embeddings=32, + original_max_position_embeddings=32, base=10000, scaling_factor=1.0, ) @@ -476,6 +476,28 @@ def test_meta_device_error(self): with self.assertRaises(AssertionError): yarn_rope.compute_freqs_cis(torch.device("meta")) + def test_output_shapes(self): + """Test that output shapes match input shapes""" + batch_size = 2 + seq_len = 8 + num_heads = 4 + head_dim = 32 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + + yarn_rope = CachedYarnRotaryEmbedding( + dim=head_dim, + original_max_position_embeddings=128, + base=10000, + scaling_factor=1.0, + ) + + qr, kr = yarn_rope.adjusted_qk(q, k) + + self.assertEqual(qr.shape, q.shape) + self.assertEqual(kr.shape, k.shape) + def test_hf_fms_equivalence(self): """Test that FMS CachedYarn RoPE matches HF Transformers implementation""" try: @@ -534,7 +556,7 @@ def test_hf_fms_equivalence(self): ############ Get FMS results fms_emb = CachedYarnRotaryEmbedding( dim, - max_position_embeddings=original_max_position_embeddings, + original_max_position_embeddings=original_max_position_embeddings, base=rope_theta, scaling_factor=scaling_factor, beta_fast=beta_fast, From 01cadbc3a2ba24e8b34181b0a8e6dc301787ef25 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Mon, 9 Mar 2026 09:52:13 -0500 Subject: [PATCH 16/98] :sparkles: Apply cached yarn to ministral3 Signed-off-by: Gaurav-Kumbhat --- fms/models/ministral3.py | 80 +++++++++++++++------------------ fms/modules/positions.py | 7 +-- tests/modules/test_positions.py | 1 + 3 files changed, 40 insertions(+), 48 deletions(-) diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 64737a84a..6c52f0d60 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -24,6 +24,7 @@ ) from fms.modules.feedforward import GatedLinearUnit from fms.modules.layernorm import LayerNormParameterized +from fms.modules.positions import CachedYarnRotaryEmbedding from fms.models.mistral import MistralBlock from fms.models.mistral3 import Mistral3, Mistral3MultiModalProjector from fms.models.pixtral_vision import PixtralVisionConfig, PixtralVisionModel @@ -99,21 +100,19 @@ def __init__( padding_idx=self.config.pad_id, ) - # TODO: - self.rot_emb = None - # self.rot_emb = RotaryEmbedding( - # dim=self.config.head_dim, - # scaling=self.config.rope_scaling, - # max_seq_len=self.config.max_expected_seq_len, - # ratio=self.config.rope_base, - # ) + self.rot_emb = CachedYarnRotaryEmbedding( + dim=self.config.head_dim, + base=self.config.rope_parameters.get("rope_theta"), + scaling_factor=config.rope_parameters.get("factor"), + **self.config.rope_parameters, + ) # RoPE init - # for device in set( - # [param.device for param in self.parameters()] - # + [buffer.device for buffer in self.buffers()] - # ): - # self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device) layers = [] for i in range(self.config.nlayers): @@ -143,12 +142,11 @@ def reset_parameters(self): ) # RoPE init - # TODO: - # for device in set( - # [param.device for param in self.parameters()] - # + [buffer.device for buffer in self.buffers()] - # ): - # self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device) # Call reset_parameters for relevant sub-layers for m in self.modules(): @@ -162,37 +160,29 @@ def reset_parameters(self): def _clean_up_rot_emb_cache( self, cached_freqs: dict[Optional[torch.device], dict[int, torch.Tensor]], - max_seq_len_cached: dict[Optional[torch.device], int], + # max_seq_len_cached: dict[Optional[torch.device], int], ): # remove meta tensors from cached_freqs for dev in list(cached_freqs.keys()): - for alp in list(cached_freqs[dev].keys()): - if cached_freqs[dev][alp].device == torch.device("meta"): - del cached_freqs[dev][alp] - if len(cached_freqs[dev]) == 0: - del cached_freqs[dev] - del max_seq_len_cached[dev] + if cached_freqs[dev].device == torch.device("meta"): + del cached_freqs[dev] + # del max_seq_len_cached[dev] def post_init(self): - pass - - # def post_init(self): - # This function is called in `get_model` after the model is - # fully initalized on the correct device - - # TODO: - # self._clean_up_rot_emb_cache( - # self.rot_emb.cached_freqs, - # self.rot_emb.max_seq_len_cached, - # ) - - # init RoPE on the right device(s) - # TODO: - # for device in set( - # [param.device for param in self.parameters()] - # + [buffer.device for buffer in self.buffers()] - # ): - # self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + # This function is called in `get_model` after the model is + # fully initalized on the correct device + # TODO: Currently we are not adding max_seq_len_cached to the cache, so we are not cleaning it up. + self._clean_up_rot_emb_cache( + self.rot_emb.cached_freqs, + # self.rot_emb.max_seq_len_cached, + ) + + # init RoPE on the right device(s) + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device) def forward( self, diff --git a/fms/modules/positions.py b/fms/modules/positions.py index 71d4b558c..e5ff3972c 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -612,6 +612,7 @@ def __init__( mscale: float = 1.0, mscale_all_dim: float = 1.0, llama_4_scaling_beta: Optional[float] = None, + **kwargs ): """ This implements Yarn scaling rotary embedding. @@ -690,9 +691,9 @@ def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: device: device to compute frequencies on """ - if device == torch.device("meta"): - # Protect from initializing on spyre device - raise AssertionError("Attempted to init yarn freqs on meta device") + # if device == torch.device("meta"): + # # Protect from initializing on spyre device + # raise AssertionError("Attempted to init yarn freqs on meta device") if device.index in self.cached_freqs: return self.cached_freqs[device.index] diff --git a/tests/modules/test_positions.py b/tests/modules/test_positions.py index ddabcfdef..3ccd41c61 100644 --- a/tests/modules/test_positions.py +++ b/tests/modules/test_positions.py @@ -436,6 +436,7 @@ def permute_fms_to_hf(tensor): class CachedYarnRotaryEmbeddingTests(unittest.TestCase): + def test_args(self): """Test that CachedYarnRotaryEmbedding validates input shapes correctly""" q = torch.ones(2, 4, 1, 16, dtype=torch.float) # b s h e From b7914b5efd9edff9677f0d567878cd40c9c107e7 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Mon, 9 Mar 2026 15:53:29 +0000 Subject: [PATCH 17/98] Fix: both tensor have to be on the same device Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index 68f37d857..d76522412 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -52,7 +52,7 @@ def loglikelihood_one(self, context: str, continuation: str) -> Tuple[float, boo logits = F.log_softmax(self.wrapped_model(input_ids)[0], -1) continuation_probs = logits[len(context_ids) - 1 :] loglikelihood = torch.gather( - continuation_probs, 1, torch.tensor(continuation_ids).unsqueeze(1) + continuation_probs, 1, torch.tensor(continuation_ids, device=self.device).unsqueeze(1) ).squeeze() predicted = torch.argmax(continuation_probs, -1).tolist() greedy = predicted == continuation_ids From 513c74d9f6bd85d50170ad3d7425916a18642f05 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Mon, 9 Mar 2026 16:20:37 +0000 Subject: [PATCH 18/98] Fix ruff Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index d76522412..f433ed34e 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -52,7 +52,9 @@ def loglikelihood_one(self, context: str, continuation: str) -> Tuple[float, boo logits = F.log_softmax(self.wrapped_model(input_ids)[0], -1) continuation_probs = logits[len(context_ids) - 1 :] loglikelihood = torch.gather( - continuation_probs, 1, torch.tensor(continuation_ids, device=self.device).unsqueeze(1) + continuation_probs, + 1, + torch.tensor(continuation_ids, device=self.device).unsqueeze(1), ).squeeze() predicted = torch.argmax(continuation_probs, -1).tolist() greedy = predicted == continuation_ids From 7a53898761236b30a048917e08ee09a483c92cf6 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Mon, 9 Mar 2026 17:18:48 -0500 Subject: [PATCH 19/98] :construction: Get around compilation of cos and sin Signed-off-by: Gaurav-Kumbhat --- fms/models/ministral3.py | 39 +++++----- fms/modules/positions.py | 155 ++++++++++++++++++++++++++------------- 2 files changed, 121 insertions(+), 73 deletions(-) diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 6c52f0d60..23107285a 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -100,19 +100,19 @@ def __init__( padding_idx=self.config.pad_id, ) + # Prepare rope parameters, ensuring original_max_position_embeddings matches max_expected_seq_len + rope_params = dict(self.config.rope_parameters) + rope_params['original_max_position_embeddings'] = self.config.max_expected_seq_len + self.rot_emb = CachedYarnRotaryEmbedding( dim=self.config.head_dim, base=self.config.rope_parameters.get("rope_theta"), scaling_factor=config.rope_parameters.get("factor"), - **self.config.rope_parameters, + **rope_params, ) - - # RoPE init - for device in set( - [param.device for param in self.parameters()] - + [buffer.device for buffer in self.buffers()] - ): - self.rot_emb.compute_freqs_cis(device) + # Note: RoPE rotation matrices are now pre-computed on CPU during + # CachedYarnRotaryEmbedding.__init__() to avoid cos/sin on Spyre device. + # The matrices are computed for max_expected_seq_len positions. layers = [] for i in range(self.config.nlayers): @@ -141,12 +141,8 @@ def reset_parameters(self): self.embedding.weight, mean=0.0, std=self.config.emb_dim**-0.5 ) - # RoPE init - for device in set( - [param.device for param in self.parameters()] - + [buffer.device for buffer in self.buffers()] - ): - self.rot_emb.compute_freqs_cis(device) + # Note: RoPE rotation matrices are pre-computed during __init__, + # no need to recompute them here # Call reset_parameters for relevant sub-layers for m in self.modules(): @@ -160,13 +156,14 @@ def reset_parameters(self): def _clean_up_rot_emb_cache( self, cached_freqs: dict[Optional[torch.device], dict[int, torch.Tensor]], - # max_seq_len_cached: dict[Optional[torch.device], int], + max_seq_len_cached: dict[Optional[torch.device], int], ): # remove meta tensors from cached_freqs for dev in list(cached_freqs.keys()): if cached_freqs[dev].device == torch.device("meta"): - del cached_freqs[dev] - # del max_seq_len_cached[dev] + if len(cached_freqs[dev]) == 0: + del cached_freqs[dev] + del max_seq_len_cached[dev] def post_init(self): # This function is called in `get_model` after the model is @@ -174,15 +171,17 @@ def post_init(self): # TODO: Currently we are not adding max_seq_len_cached to the cache, so we are not cleaning it up. self._clean_up_rot_emb_cache( self.rot_emb.cached_freqs, - # self.rot_emb.max_seq_len_cached, + self.rot_emb.max_seq_len_cached, ) - # init RoPE on the right device(s) + # Transfer pre-computed RoPE rotation matrices to the target device(s) + # The matrices are already computed on CPU, we just need to move them for device in set( [param.device for param in self.parameters()] + [buffer.device for buffer in self.buffers()] ): - self.rot_emb.compute_freqs_cis(device) + if device != torch.device("meta"): + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) def forward( self, diff --git a/fms/modules/positions.py b/fms/modules/positions.py index e5ff3972c..3671ce432 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -635,6 +635,7 @@ def __init__( self.llama_4_scaling_beta = llama_4_scaling_beta self.cached_freqs: dict[int, torch.Tensor] = {} + self.max_seq_len_cached = {} # magnitude scaling factor self.mscale = float(self._yarn_get_mscale(mscale)) @@ -647,57 +648,29 @@ def __init__( self.attn_factor = attn_factor + # Pre-compute rotation matrices on CPU to avoid cos/sin on Spyre device + # This is critical for Spyre compatibility as cos/sin are not supported operations + self._precompute_rotation_matrices_on_cpu() + def _yarn_get_mscale(self, mscale: float = 1) -> float: if self.scaling_factor <= 1: return 1.0 return 0.1 * mscale * math.log(self.scaling_factor) + 1.0 - def _compute_cos_sin_cache( - self, inv_freq: torch.Tensor, device: torch.device - ) -> torch.Tensor: + def _precompute_rotation_matrices_on_cpu(self) -> None: """ - Compute the rotation matrix cache for the rotary embedding to avoid computing - while doing the forward pass. - Args: - inv_freq: The precomputed inverse frequency tensor - Returns: - Rotation matrices with shape [max_pos, dim/2, 2, 2] - """ - t = torch.arange( - int(self.original_max_position_embeddings * self.scaling_factor), - device=device, - dtype=torch.float32, - ) - freqs = torch.outer(t, inv_freq).float() - - # Apply mscale and compute cos/sin - cos = freqs.cos() * self.attn_factor - sin = freqs.sin() * self.attn_factor - - # Construct rotation matrices: [max_pos, dim/2, 2, 2] - # Matrix form: [[cos, -sin], [sin, cos]] - freqs_cis = torch.stack([cos, -sin, sin, cos], dim=-1).view(*cos.shape, 2, 2) - - return freqs_cis - - def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: - scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) - return scaling.unsqueeze(-1) + Pre-compute rotation matrices on CPU during initialization. + This ensures cos/sin operations are never called on Spyre device. - def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: + The rotation matrices are computed for original_max_position_embeddings + and cached. They will be transferred to target devices as needed without + recomputation. """ - Compute the frequencies for the rotary embedding. - Args: - device: device to compute frequencies on - """ - - # if device == torch.device("meta"): - # # Protect from initializing on spyre device - # raise AssertionError("Attempted to init yarn freqs on meta device") - - if device.index in self.cached_freqs: - return self.cached_freqs[device.index] + # Use CPU device for all trigonometric operations + device = torch.device("cpu") + max_seq_len = self.original_max_position_embeddings + # Compute inverse frequencies freqs = self.base ** ( torch.arange(0, self.dim, 2, device=device).float() / self.dim ) @@ -731,12 +704,9 @@ def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: ramp_func = torch.clamp(linear_func, 0, 1) # inv_freq_extrapolation_factor is the weight for extrapolation - # (1 - ramp_func) means: use extrapolation for low frequencies (< low) - # ramp_func means: use interpolation for high frequencies (> high) inv_freq_extrapolation_factor = 1 - ramp_func # Blend between interpolation and extrapolation - # Note: extrapolation_factor is applied to the extrapolation frequencies inv_freq = ( inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + inv_freq_extrapolation @@ -744,11 +714,77 @@ def compute_freqs_cis(self, device: torch.device) -> torch.Tensor: * self.extrapolation_factor ) - # Cache the computed rotation matrices for this device - freqs_cis = self._compute_cos_sin_cache(inv_freq, device) - self.cached_freqs[device.index] = freqs_cis + # Compute position encodings + t = torch.arange( + max_seq_len, + device=device, + dtype=torch.float32, + ) + freqs = torch.outer(t, inv_freq).float() - return self.cached_freqs[device.index] + # Apply mscale and compute cos/sin ON CPU + # This is the critical part - cos/sin are computed here once and never again + cos = freqs.cos() * self.attn_factor + sin = freqs.sin() * self.attn_factor + + # Construct rotation matrices: [max_pos, dim/2, 2, 2] + # Matrix form: [[cos, -sin], [sin, cos]] + freqs_cis = torch.stack([cos, -sin, sin, cos], dim=-1).view(*cos.shape, 2, 2) + + # Cache on CPU (device index -1 for CPU or use 'cpu' as key) + # We'll use -1 as a special key for CPU + self.cached_freqs[-1] = freqs_cis + self.max_seq_len_cached[-1] = max_seq_len + + def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: + scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) + return scaling.unsqueeze(-1) + + def compute_freqs_cis(self, device: torch.device, max_seq_len: int) -> torch.Tensor: + """ + Transfer pre-computed rotation matrices to the target device. + + This method no longer computes cos/sin - those are pre-computed on CPU + during __init__. This method only handles device transfers. + + Args: + device: target device to transfer rotation matrices to + max_seq_len: maximum sequence length (must not exceed original_max_position_embeddings) + """ + + if device == torch.device("meta"): + # Protect from initializing on meta device + # Return None to signal that cache computation should be skipped + return None + + dev_idx = device.index if device.type != "cpu" else -1 + + # If already cached on this device, return it + if dev_idx in self.cached_freqs: + return self.cached_freqs[dev_idx] + + # Check if we have the CPU cache (should always be true after __init__) + if -1 not in self.cached_freqs: + raise RuntimeError( + "CPU rotation matrices not found. This should have been computed in __init__." + ) + + # Verify sequence length doesn't exceed what we pre-computed + if max_seq_len > self.original_max_position_embeddings: + raise ValueError( + f"Requested max_seq_len ({max_seq_len}) exceeds " + f"original_max_position_embeddings ({self.original_max_position_embeddings}). " + f"CachedYarnRotaryEmbedding pre-computes rotation matrices during initialization " + f"and cannot dynamically extend beyond the configured maximum." + ) + + # Transfer pre-computed rotation matrices from CPU to target device + # This is a simple tensor copy - no cos/sin computation on target device + cpu_freqs = self.cached_freqs[-1] + self.cached_freqs[dev_idx] = cpu_freqs.to(device) + self.max_seq_len_cached[dev_idx] = self.max_seq_len_cached[-1] + + return self.cached_freqs[dev_idx] def adjusted_qk( self, @@ -796,11 +832,18 @@ def adjusted_qk( ): position_ids += past_kv_state[0].size(2) + # the max start position should be based on the max first position of each sequence + max_start_pos = torch.max(position_ids[:, 0]) + # Fetch the rotation matrices from cache - freqs_cis = self.compute_freqs_cis(q.device) + self.compute_freqs_cis(q.device, max_start_pos + seq_len) + + # Get device index for cache lookup + # Use -1 for CPU, otherwise use the device index + dev_idx = q.device.index if q.device.index is not None else -1 # Index by position_ids: [B, L] -> [B, L, rotary_dim/2, 2, 2] - freqs = freqs_cis[position_ids] + freqs = self.cached_freqs[dev_idx][position_ids].float() # Only apply rotation to the first self.dim dimensions # Extract the rotary portion @@ -817,10 +860,16 @@ def adjusted_qk( # Add head dimension: [B, L, 1, rotary_dim/2, 2, 2] # q_, k_: [B, L, H, rotary_dim/2, 2] q_out = ( - freqs[:, :, None, :, :, :].mul(q_.unsqueeze(-2)).sum(-1).flatten(-2) + freqs[:, -q.size(1) :, None, :, :, :] + .mul(q_.unsqueeze(-2)) + .sum(5) + .flatten(3) ).type_as(q) k_out = ( - freqs[:, :, None, :, :, :].mul(k_.unsqueeze(-2)).sum(-1).flatten(-2) + freqs[:, -k.size(1) :, None, :, :, :] + .mul(k_.unsqueeze(-2)) + .sum(5) + .flatten(3) ).type_as(k) # Concatenate with the non-rotated portion if rotary_dim < head_dim From 168befa559e3abdb944dd40907c4c73a3a8b3507 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Wed, 11 Mar 2026 22:09:46 +0000 Subject: [PATCH 20/98] :recycle::bug: Fix CachedYarnRope bug for q and v transformation and refactor CachedYarnRope implementation Signed-off-by: Gaurav-Kumbhat --- fms/models/ministral3.py | 14 +++++++++----- fms/modules/positions.py | 24 +++++++++++------------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 23107285a..b4accb5ec 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -24,7 +24,7 @@ ) from fms.modules.feedforward import GatedLinearUnit from fms.modules.layernorm import LayerNormParameterized -from fms.modules.positions import CachedYarnRotaryEmbedding +from fms.modules.positions import CachedYarnRotaryEmbedding, RotaryEmbedding from fms.models.mistral import MistralBlock from fms.models.mistral3 import Mistral3, Mistral3MultiModalProjector from fms.models.pixtral_vision import PixtralVisionConfig, PixtralVisionModel @@ -110,6 +110,11 @@ def __init__( scaling_factor=config.rope_parameters.get("factor"), **rope_params, ) + for device in set( + [param.device for param in self.parameters()] + + [buffer.device for buffer in self.buffers()] + ): + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) # Note: RoPE rotation matrices are now pre-computed on CPU during # CachedYarnRotaryEmbedding.__init__() to avoid cos/sin on Spyre device. # The matrices are computed for max_expected_seq_len positions. @@ -174,14 +179,13 @@ def post_init(self): self.rot_emb.max_seq_len_cached, ) - # Transfer pre-computed RoPE rotation matrices to the target device(s) - # The matrices are already computed on CPU, we just need to move them + # init RoPE on the right device(s) for device in set( [param.device for param in self.parameters()] + [buffer.device for buffer in self.buffers()] ): - if device != torch.device("meta"): - self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + def forward( self, diff --git a/fms/modules/positions.py b/fms/modules/positions.py index 3671ce432..b8d89c987 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -649,7 +649,6 @@ def __init__( self.attn_factor = attn_factor # Pre-compute rotation matrices on CPU to avoid cos/sin on Spyre device - # This is critical for Spyre compatibility as cos/sin are not supported operations self._precompute_rotation_matrices_on_cpu() def _yarn_get_mscale(self, mscale: float = 1) -> float: @@ -727,6 +726,7 @@ def _precompute_rotation_matrices_on_cpu(self) -> None: cos = freqs.cos() * self.attn_factor sin = freqs.sin() * self.attn_factor + # return cos, sin, freqs # Construct rotation matrices: [max_pos, dim/2, 2, 2] # Matrix form: [[cos, -sin], [sin, cos]] freqs_cis = torch.stack([cos, -sin, sin, cos], dim=-1).view(*cos.shape, 2, 2) @@ -753,15 +753,12 @@ def compute_freqs_cis(self, device: torch.device, max_seq_len: int) -> torch.Ten """ if device == torch.device("meta"): - # Protect from initializing on meta device - # Return None to signal that cache computation should be skipped return None - dev_idx = device.index if device.type != "cpu" else -1 + if device.index in self.cached_freqs: + return None - # If already cached on this device, return it - if dev_idx in self.cached_freqs: - return self.cached_freqs[dev_idx] + dev_idx = device.index # Check if we have the CPU cache (should always be true after __init__) if -1 not in self.cached_freqs: @@ -841,14 +838,16 @@ def adjusted_qk( # Get device index for cache lookup # Use -1 for CPU, otherwise use the device index dev_idx = q.device.index if q.device.index is not None else -1 + # dev_idx = q.device.index # Index by position_ids: [B, L] -> [B, L, rotary_dim/2, 2, 2] - freqs = self.cached_freqs[dev_idx][position_ids].float() + freqs = self.cached_freqs[dev_idx][position_ids] + freqs = freqs.float() # Only apply rotation to the first self.dim dimensions # Extract the rotary portion - q_rope = q[..., : self.dim] # [B, L, H, rotary_dim] - k_rope = k[..., : self.dim] # [B, L, H, rotary_dim] + q_rope = q + k_rope = k # Reshape for interleaved rotation # From [B, L, H, rotary_dim] to [B, L, H, rotary_dim/2, 2] for interleaved pairs @@ -873,9 +872,8 @@ def adjusted_qk( ).type_as(k) # Concatenate with the non-rotated portion if rotary_dim < head_dim - if self.dim < q.size(-1): - q_out = torch.cat([q_out, q[..., self.dim :]], dim=-1) - k_out = torch.cat([k_out, k[..., self.dim :]], dim=-1) + q_out = q_out.view_as(q_rope) + k_out = k_out.view_as(k_rope) # TODO: Apply llama_4_scaling # if self.llama_4_scaling_beta: From bdbc714da74e31574e8d7795c7e67b32a481cb1b Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Wed, 11 Mar 2026 22:13:36 +0000 Subject: [PATCH 21/98] :white_check_mark: Remove meta device test for CachedYarn as its not relevant anymore Signed-off-by: Gaurav-Kumbhat --- tests/modules/test_positions.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/modules/test_positions.py b/tests/modules/test_positions.py index 3ccd41c61..ee2134bbc 100644 --- a/tests/modules/test_positions.py +++ b/tests/modules/test_positions.py @@ -465,18 +465,6 @@ def test_args(self): None, ) - def test_meta_device_error(self): - """Test that attempting to compute on meta device raises an error""" - yarn_rope = CachedYarnRotaryEmbedding( - dim=16, - original_max_position_embeddings=32, - base=10000, - scaling_factor=1.0, - ) - - with self.assertRaises(AssertionError): - yarn_rope.compute_freqs_cis(torch.device("meta")) - def test_output_shapes(self): """Test that output shapes match input shapes""" batch_size = 2 From 743a4a7a10309c21c2c51c2cf46cbb3727f334f8 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 12 Mar 2026 18:13:38 +0000 Subject: [PATCH 22/98] :recycle: Simplify Cached Yarn rope implementation Signed-off-by: Gaurav-Kumbhat --- fms/modules/positions.py | 139 ++++++++++++++++----------------------- fms/utils/generation.py | 23 +++++-- 2 files changed, 73 insertions(+), 89 deletions(-) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index b8d89c987..f887dfc54 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -1,7 +1,6 @@ from collections import defaultdict import copy import math -from statistics import quantiles from typing import MutableMapping, Optional, Tuple import torch @@ -648,28 +647,64 @@ def __init__( self.attn_factor = attn_factor - # Pre-compute rotation matrices on CPU to avoid cos/sin on Spyre device - self._precompute_rotation_matrices_on_cpu() - def _yarn_get_mscale(self, mscale: float = 1) -> float: if self.scaling_factor <= 1: return 1.0 return 0.1 * mscale * math.log(self.scaling_factor) + 1.0 - def _precompute_rotation_matrices_on_cpu(self) -> None: + def _compute_cos_sin_cache( + self, inv_freq: torch.Tensor, device: torch.device + ) -> torch.Tensor: + """ + Compute the rotation matrix cache for the rotary embedding to avoid computing + while doing the forward pass. + Args: + inv_freq: The precomputed inverse frequency tensor + Returns: + Rotation matrices with shape [max_pos, dim/2, 2, 2] + """ + t = torch.arange( + int(self.original_max_position_embeddings * self.scaling_factor), + device=device, + dtype=torch.float32, + ) + freqs = torch.outer(t, inv_freq).float() + + # Apply mscale and compute cos/sin + cos = freqs.cos() * self.attn_factor + sin = freqs.sin() * self.attn_factor + + # Construct rotation matrices: [max_pos, dim/2, 2, 2] + # Matrix form: [[cos, -sin], [sin, cos]] + freqs_cis = torch.stack([cos, -sin, sin, cos], dim=-1).view(*cos.shape, 2, 2) + + return freqs_cis + + + def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: + scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) + return scaling.unsqueeze(-1) + + def compute_freqs_cis(self, device: torch.device, max_seq_len: int) -> None: """ - Pre-compute rotation matrices on CPU during initialization. - This ensures cos/sin operations are never called on Spyre device. + Transfer pre-computed rotation matrices to the target device. - The rotation matrices are computed for original_max_position_embeddings - and cached. They will be transferred to target devices as needed without - recomputation. + This method no longer computes cos/sin - those are pre-computed on CPU + during __init__. This method only handles device transfers. + + Args: + device: target device to transfer rotation matrices to + max_seq_len: maximum sequence length (must not exceed original_max_position_embeddings) """ - # Use CPU device for all trigonometric operations - device = torch.device("cpu") - max_seq_len = self.original_max_position_embeddings - # Compute inverse frequencies + if device == torch.device("meta"): + return + + if device.index in self.cached_freqs: + return + + dev_idx = device.index + freqs = self.base ** ( torch.arange(0, self.dim, 2, device=device).float() / self.dim ) @@ -703,9 +738,12 @@ def _precompute_rotation_matrices_on_cpu(self) -> None: ramp_func = torch.clamp(linear_func, 0, 1) # inv_freq_extrapolation_factor is the weight for extrapolation + # (1 - ramp_func) means: use extrapolation for low frequencies (< low) + # ramp_func means: use interpolation for high frequencies (> high) inv_freq_extrapolation_factor = 1 - ramp_func # Blend between interpolation and extrapolation + # Note: extrapolation_factor is applied to the extrapolation frequencies inv_freq = ( inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + inv_freq_extrapolation @@ -713,75 +751,10 @@ def _precompute_rotation_matrices_on_cpu(self) -> None: * self.extrapolation_factor ) - # Compute position encodings - t = torch.arange( - max_seq_len, - device=device, - dtype=torch.float32, - ) - freqs = torch.outer(t, inv_freq).float() - - # Apply mscale and compute cos/sin ON CPU - # This is the critical part - cos/sin are computed here once and never again - cos = freqs.cos() * self.attn_factor - sin = freqs.sin() * self.attn_factor - - # return cos, sin, freqs - # Construct rotation matrices: [max_pos, dim/2, 2, 2] - # Matrix form: [[cos, -sin], [sin, cos]] - freqs_cis = torch.stack([cos, -sin, sin, cos], dim=-1).view(*cos.shape, 2, 2) - - # Cache on CPU (device index -1 for CPU or use 'cpu' as key) - # We'll use -1 as a special key for CPU - self.cached_freqs[-1] = freqs_cis - self.max_seq_len_cached[-1] = max_seq_len - - def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: - scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) - return scaling.unsqueeze(-1) - - def compute_freqs_cis(self, device: torch.device, max_seq_len: int) -> torch.Tensor: - """ - Transfer pre-computed rotation matrices to the target device. - - This method no longer computes cos/sin - those are pre-computed on CPU - during __init__. This method only handles device transfers. - - Args: - device: target device to transfer rotation matrices to - max_seq_len: maximum sequence length (must not exceed original_max_position_embeddings) - """ - - if device == torch.device("meta"): - return None - - if device.index in self.cached_freqs: - return None - - dev_idx = device.index - - # Check if we have the CPU cache (should always be true after __init__) - if -1 not in self.cached_freqs: - raise RuntimeError( - "CPU rotation matrices not found. This should have been computed in __init__." - ) - - # Verify sequence length doesn't exceed what we pre-computed - if max_seq_len > self.original_max_position_embeddings: - raise ValueError( - f"Requested max_seq_len ({max_seq_len}) exceeds " - f"original_max_position_embeddings ({self.original_max_position_embeddings}). " - f"CachedYarnRotaryEmbedding pre-computes rotation matrices during initialization " - f"and cannot dynamically extend beyond the configured maximum." - ) - - # Transfer pre-computed rotation matrices from CPU to target device - # This is a simple tensor copy - no cos/sin computation on target device - cpu_freqs = self.cached_freqs[-1] - self.cached_freqs[dev_idx] = cpu_freqs.to(device) - self.max_seq_len_cached[dev_idx] = self.max_seq_len_cached[-1] + # Cache the computed rotation matrices for this device + freqs_cis = self._compute_cos_sin_cache(inv_freq, device) + self.cached_freqs[dev_idx] = freqs_cis - return self.cached_freqs[dev_idx] def adjusted_qk( self, @@ -837,7 +810,7 @@ def adjusted_qk( # Get device index for cache lookup # Use -1 for CPU, otherwise use the device index - dev_idx = q.device.index if q.device.index is not None else -1 + dev_idx = q.device.index # if q.device.index is not None else -1 # dev_idx = q.device.index # Index by position_ids: [B, L] -> [B, L, rotary_dim/2, 2, 2] diff --git a/fms/utils/generation.py b/fms/utils/generation.py index 51d5f3064..ae749aefc 100644 --- a/fms/utils/generation.py +++ b/fms/utils/generation.py @@ -19,6 +19,7 @@ def pad_input_ids( is_causal_mask=True, padding_side="left", position_ids_offset=0, + pad_token_id=0, ) -> Tuple[torch.Tensor, MutableMapping[str, Any]]: """ Convert a list of Tensors to a rectangular tensor. Return extra padding kwargs for the position_ids and mask, since @@ -34,7 +35,8 @@ def pad_input_ids( position_ids_offset: int some models are trained with position_ids that do not start at 0 but at pad_id + 1. The default parameter here will work for most models, but for example MPNet requires passing a real pad_id. - + pad_token_id: int + the token ID to use for padding. Default is 0. Returns ------- Tuple[torch.Tensor, MutableMapping[str, Any]] @@ -49,25 +51,34 @@ def pad_input_ids( position_ids_list = [] for input_ids_i in input_ids_list: seq_len = input_ids_i.size(0) - pads = torch.zeros( - max_len - seq_len, dtype=torch.long, device=input_ids_i.device + pads = torch.full( + (max_len - seq_len,), + pad_token_id, + dtype=torch.long, + device=input_ids_i.device, ) non_pads = torch.ones(seq_len, dtype=torch.bool, device=input_ids_i.device) # Setting this to 0, however if 0 is the eos, we will end up truncating the output if using truncate_after_eos # once this workflow works for nested tensor, this can probably be removed - pos_ids_pads = pads + pos_ids_pads = torch.zeros( + max_len - seq_len, dtype=torch.long, device=input_ids_i.device + ) pos_ids_seq = torch.arange( 0, seq_len, dtype=torch.long, device=input_ids_i.device ) if padding_side == "left": padded_input_ids_list.append(torch.cat((pads, input_ids_i))) - mask_list.append(torch.cat((pads.bool(), non_pads))) + mask_list.append( + torch.cat((pads.bool(), non_pads)) + ) # This will be False for pad tokens position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq))) elif padding_side == "right": padded_input_ids_list.append(torch.cat((input_ids_i, pads))) - mask_list.append(torch.cat((non_pads, pads.bool()))) + mask_list.append( + torch.cat((non_pads, pads.bool())) + ) # This will be False for pad tokens position_ids_list.append(torch.cat((pos_ids_seq, pos_ids_pads))) else: raise NotImplementedError("padding_side must be 'right' or left'") From 690a84ef3a80f9a22906b04cfa81a25f686317a5 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 12 Mar 2026 20:52:13 +0000 Subject: [PATCH 23/98] :art: Cleanup CachedYarn implementation Signed-off-by: Gaurav-Kumbhat --- fms/models/__init__.py | 8 ++++--- fms/models/ministral3.py | 4 +--- fms/modules/positions.py | 40 +++++++++++++++++---------------- tests/modules/test_positions.py | 3 +-- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/fms/models/__init__.py b/fms/models/__init__.py index 736a7c78f..e41f9bee6 100644 --- a/fms/models/__init__.py +++ b/fms/models/__init__.py @@ -484,9 +484,11 @@ def model_wrap(model): # TODO: should we raise a warning? are uninitialized tensors ever acceptable? if initial_device != torch.device("meta"): fms_model._apply( - lambda t: torch.empty_like(t, device=initial_device) - if t.device == torch.device("meta") - else t + lambda t: ( + torch.empty_like(t, device=initial_device) + if t.device == torch.device("meta") + else t + ) ) return fms_model diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index b4accb5ec..18369650f 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -100,9 +100,8 @@ def __init__( padding_idx=self.config.pad_id, ) - # Prepare rope parameters, ensuring original_max_position_embeddings matches max_expected_seq_len + # Prepare rope parameters rope_params = dict(self.config.rope_parameters) - rope_params['original_max_position_embeddings'] = self.config.max_expected_seq_len self.rot_emb = CachedYarnRotaryEmbedding( dim=self.config.head_dim, @@ -186,7 +185,6 @@ def post_init(self): ): self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) - def forward( self, x_in, diff --git a/fms/modules/positions.py b/fms/modules/positions.py index f887dfc54..85682785a 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -611,7 +611,7 @@ def __init__( mscale: float = 1.0, mscale_all_dim: float = 1.0, llama_4_scaling_beta: Optional[float] = None, - **kwargs + **kwargs, ): """ This implements Yarn scaling rotary embedding. @@ -623,9 +623,7 @@ def __init__( super().__init__() self.dim = dim - self.original_max_position_embeddings = ( - original_max_position_embeddings - ) + self.original_max_position_embeddings = original_max_position_embeddings self.base = base self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -680,9 +678,10 @@ def _compute_cos_sin_cache( return freqs_cis - - def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: - scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) + def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor) -> torch.Tensor: + scaling = 1 + self.llama_4_scaling_beta * torch.log( + 1 + torch.floor(positions_ids / self.original_max_position_embeddings) + ) return scaling.unsqueeze(-1) def compute_freqs_cis(self, device: torch.device, max_seq_len: int) -> None: @@ -715,11 +714,15 @@ def compute_freqs_cis(self, device: torch.device, max_seq_len: int) -> None: # NOTE: math.floor and math.ceil being used here are referred to as "truncate" option low = math.floor( self.dim - * math.log(self.original_max_position_embeddings / (self.beta_fast * 2 * math.pi)) + * math.log( + self.original_max_position_embeddings / (self.beta_fast * 2 * math.pi) + ) ) / (2 * math.log(self.base)) high = math.ceil( self.dim - * math.log(self.original_max_position_embeddings / (self.beta_slow * 2 * math.pi)) + * math.log( + self.original_max_position_embeddings / (self.beta_slow * 2 * math.pi) + ) ) / (2 * math.log(self.base)) # Make sure values are not going outside range @@ -755,7 +758,6 @@ def compute_freqs_cis(self, device: torch.device, max_seq_len: int) -> None: freqs_cis = self._compute_cos_sin_cache(inv_freq, device) self.cached_freqs[dev_idx] = freqs_cis - def adjusted_qk( self, q: torch.Tensor, @@ -808,10 +810,8 @@ def adjusted_qk( # Fetch the rotation matrices from cache self.compute_freqs_cis(q.device, max_start_pos + seq_len) - # Get device index for cache lookup - # Use -1 for CPU, otherwise use the device index - dev_idx = q.device.index # if q.device.index is not None else -1 - # dev_idx = q.device.index + # Get device index for cache lookup, None if on CPU + dev_idx = q.device.index # Index by position_ids: [B, L] -> [B, L, rotary_dim/2, 2, 2] freqs = self.cached_freqs[dev_idx][position_ids] @@ -829,7 +829,6 @@ def adjusted_qk( # Apply rotation using matrix multiplication # freqs: [B, L, rotary_dim/2, 2, 2] - # Add head dimension: [B, L, 1, rotary_dim/2, 2, 2] # q_, k_: [B, L, H, rotary_dim/2, 2] q_out = ( freqs[:, -q.size(1) :, None, :, :, :] @@ -844,11 +843,14 @@ def adjusted_qk( .flatten(3) ).type_as(k) - # Concatenate with the non-rotated portion if rotary_dim < head_dim + # Apply llama_4_scaling + if self.llama_4_scaling_beta: + cache_position = torch.arange( + q_out.shape[2], device=q_out.device, dtype=q_out.dtype + ) + q_out = q_out * self._get_llama_4_attn_scale(cache_position) + q_out = q_out.view_as(q_rope) k_out = k_out.view_as(k_rope) - # TODO: Apply llama_4_scaling - # if self.llama_4_scaling_beta: - return q_out, k_out diff --git a/tests/modules/test_positions.py b/tests/modules/test_positions.py index ee2134bbc..f3a4c36a1 100644 --- a/tests/modules/test_positions.py +++ b/tests/modules/test_positions.py @@ -436,7 +436,6 @@ def permute_fms_to_hf(tensor): class CachedYarnRotaryEmbeddingTests(unittest.TestCase): - def test_args(self): """Test that CachedYarnRotaryEmbedding validates input shapes correctly""" q = torch.ones(2, 4, 1, 16, dtype=torch.float) # b s h e @@ -506,7 +505,7 @@ def test_hf_fms_equivalence(self): beta_slow = 1.0 scaling_factor = 16.0 original_max_position_embeddings = 16384 # Dummy Value - llama_4_scaling_beta = 0.1 + llama_4_scaling_beta = None # Not testing llama_4_scaling_beta here mscale = 1.0 mscale_all_dim = 1.0 From fd3d64571781d9b5675d5638e63bb423f340454d Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Mon, 16 Mar 2026 14:15:49 -0500 Subject: [PATCH 24/98] :package: Update transformers to 5.x Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/lm_head_mixins.py | 25 +++++++++++++------- fms/models/hf/roberta/modeling_roberta_hf.py | 5 +++- fms/models/roberta.py | 2 ++ pyproject.toml | 5 ++-- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/fms/models/hf/lm_head_mixins.py b/fms/models/hf/lm_head_mixins.py index cb54c21f6..bc45639bd 100644 --- a/fms/models/hf/lm_head_mixins.py +++ b/fms/models/hf/lm_head_mixins.py @@ -274,10 +274,14 @@ class SequenceClassificationLMHeadMixin(LMHeadMixin): set at run-time based on config.num_labels and the label dtype. """ - _tied_weights_keys = [ - "lm_head.head.weight", - "lm_head.head.bias", - ] + # _tied_weights_keys = [ + # "lm_head.head.weight", + # "lm_head.head.bias", + # ] + _tied_weights_keys = { + "lm_head.head.weight": "lm_head.head.weight", + "lm_head.head.bias": "lm_head.head.bias" + } def __init__( self, @@ -374,10 +378,15 @@ def get_output_embeddings(self): class MaskedLMHeadMixin(LMHeadMixin): """Provides a model architecture with a masked lm head""" - _tied_weights_keys = [ - "lm_head.head.weight", - "lm_head.head.bias", - ] + # _tied_weights_keys = [ + # "lm_head.head.weight", + # "lm_head.head.bias", + # ] + + _tied_weights_keys = { + "lm_head.head.weight": "lm_head.head.weight", + "lm_head.head.bias": "lm_head.head.bias" + } def __init__( self, diff --git a/fms/models/hf/roberta/modeling_roberta_hf.py b/fms/models/hf/roberta/modeling_roberta_hf.py index 24e11d650..6071ed65a 100644 --- a/fms/models/hf/roberta/modeling_roberta_hf.py +++ b/fms/models/hf/roberta/modeling_roberta_hf.py @@ -120,7 +120,10 @@ class HFAdaptedRoBERTaHeadless(HFEncoderModelArchitecture): config_class = HFAdaptedRoBERTaConfig base_model_prefix = "hf_adapted_roberta" - _tied_weights_keys = ["encoder.model.embedding.weight", "embedding.weight"] + _tied_weights_keys = { + "encoder.model.embedding.weight": "roberta.embeddings.word_embeddings.weight", + "embedding.weight": "embedding.weight", + } _keys_to_ignore_on_save = ["embedding.weight"] def __init__( diff --git a/fms/models/roberta.py b/fms/models/roberta.py index d00512c2f..9db5e461e 100644 --- a/fms/models/roberta.py +++ b/fms/models/roberta.py @@ -33,6 +33,8 @@ class RoBERTaConfig(ModelConfig): nheads: int = 12 nlayers: int = 12 pad_id: int = 1 + bos_token_id: int = 0 + eos_token_id: int = 2 hidden_grow_factor: float = 4.0 activation_fn: str = "gelu" classifier_activation_fn: str = "tanh" diff --git a/pyproject.toml b/pyproject.toml index e992d2ad3..9d746415c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,15 +40,16 @@ dependencies = [ ] [project.optional-dependencies] -hf = ["transformers==4.57.6"] +hf = ["transformers>=4.57.6"] dev = [ "mypy==1.15.0", "mypy-extensions==1.0.0", "pytest==8.3.4", "sentencepiece==0.2.0", +"transformers>=4.57.6", "pyarrow-stubs==17.16", "types-requests==2.32.0.20241016", -"lm_eval==0.4.7", +"lm_eval==0.4.11", "peft==0.14.0", "fms-model-optimizer @ git+https://github.com/foundation-model-stack/fms-model-optimizer.git" ] From 638b79f7486408d4369e254c4161c7e1d9a6a30e Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Fri, 6 Mar 2026 15:12:59 -0600 Subject: [PATCH 25/98] :package: Make param builder work with transformers 5.x.x refactors Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/config_utils/param_builders.py | 37 +++++++++++++++----- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/fms/models/hf/config_utils/param_builders.py b/fms/models/hf/config_utils/param_builders.py index ec0013201..cb95096bd 100644 --- a/fms/models/hf/config_utils/param_builders.py +++ b/fms/models/hf/config_utils/param_builders.py @@ -14,6 +14,20 @@ from transformers import PretrainedConfig +def reverse_rope_param_lookup(config: PretrainedConfig): + """This function allows fetching the rope_theta from the config + allowing compatibility with transformers 5.0 changes + """ + if hasattr(config, "rope_parameters"): + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = getattr(config.rope_parameters, "rope_scaling", None) + else: + rope_theta = config.rope_theta + rope_scaling = getattr(config, "rope_scaling", None) + + return rope_theta, rope_scaling + + def build_llama_params(config: PretrainedConfig) -> dict: """Param builder for mapping LlamaForCausalLM to FMS.""" config_params = { @@ -26,11 +40,9 @@ def build_llama_params(config: PretrainedConfig) -> dict: "max_expected_seq_len": config.max_position_embeddings, } # New in Llama 3 - rope_theta = getattr(config, "rope_theta", None) + rope_theta, rope_scaling = reverse_rope_param_lookup(config) if rope_theta is not None: config_params["rope_theta"] = rope_theta - # New in Llama 3.1 - rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None: config_params["rope_scaling"] = rope_scaling @@ -56,6 +68,7 @@ def build_gpt_bigcode_params(config: PretrainedConfig) -> dict: def build_mixtral_params(config: PretrainedConfig) -> dict: """Param builder for mapping MixtralForCausalLM to FMS.""" inner_dim = config.intermediate_size + rope_theta, _ = reverse_rope_param_lookup(config) config_params = { "dim": config.hidden_size, "hidden_dim": inner_dim, @@ -63,7 +76,7 @@ def build_mixtral_params(config: PretrainedConfig) -> dict: "kv_heads": config.num_key_value_heads, "num_experts": config.num_local_experts, "top_k_experts": config.num_experts_per_tok, - "rope_base": config.rope_theta, + "rope_base": rope_theta, "max_expected_seq_len": config.max_position_embeddings, } return model_params_with_common_opts(config, config_params, inner_dim=inner_dim) @@ -97,6 +110,7 @@ def build_roberta_params(config: PretrainedConfig, is_classify: bool = False) -> def build_granite_params(config: PretrainedConfig) -> dict: """Param builder for mapping GraniteForCausalLM to FMS.""" + rope_theta, _ = reverse_rope_param_lookup(config) config_params = { "attn_bias": getattr(config, "attention_bias", False), "mlp_bias": getattr(config, "mlp_bias", False), @@ -109,7 +123,7 @@ def build_granite_params(config: PretrainedConfig) -> dict: "attention_multiplier": config.attention_multiplier, "logits_scaling": config.logits_scaling, "embedding_multiplier": config.embedding_multiplier, - "rope_theta": config.rope_theta, + "rope_theta": rope_theta, "activation_fn": config.hidden_act, "head_dim": getattr( config, "head_dim", config.hidden_size // config.num_attention_heads @@ -126,6 +140,8 @@ def build_granite_moe_hybrid_params(config: PretrainedConfig) -> dict: # granite-v4 dense version. In future, based on the configuration # we may route to different architectures or classes. + rope_theta, _ = reverse_rope_param_lookup(config) + config_params = { "attn_bias": getattr(config, "attention_bias", False), "kvheads": config.num_key_value_heads, @@ -137,7 +153,7 @@ def build_granite_moe_hybrid_params(config: PretrainedConfig) -> dict: "attention_multiplier": config.attention_multiplier, "logits_scaling": config.logits_scaling, "embedding_multiplier": config.embedding_multiplier, - "rope_theta": config.rope_theta, + "rope_theta": rope_theta, "activation_fn": config.hidden_act, "head_dim": getattr( config, "head_dim", config.hidden_size // config.num_attention_heads @@ -150,6 +166,7 @@ def build_granite_moe_hybrid_params(config: PretrainedConfig) -> dict: def build_mistral_params(config: PretrainedConfig) -> dict: """Param builder for mapping MistralForCausalLM to FMS.""" + rope_theta, _ = reverse_rope_param_lookup(config) config_params = { "activation_fn": config.hidden_act, "emb_dim": config.hidden_size, @@ -161,7 +178,7 @@ def build_mistral_params(config: PretrainedConfig) -> dict: or config.hidden_size // config.num_attention_heads ), "norm_eps": config.rms_norm_eps, - "rope_base": config.rope_theta, + "rope_base": rope_theta, "sliding_window": config.sliding_window, } return model_params_with_common_opts( @@ -321,6 +338,10 @@ def build_pixtral_params(config: PretrainedConfig) -> dict: # we use the same default in Pixtral's encoder, which is 1e-5, # but should be aware in case this is changed and added to the # config in future releases. + + # To handle cases such as ministral3 + rope_theta, _ = reverse_rope_param_lookup(config) + config_params = { "hidden_size": config.hidden_size, "intermediate_size": config.intermediate_size, @@ -330,7 +351,7 @@ def build_pixtral_params(config: PretrainedConfig) -> dict: "image_size": config.image_size, "patch_size": config.patch_size, "hidden_act": config.hidden_act, - "rope_theta": config.rope_theta, + "rope_theta": rope_theta, "attention_dropout": config.attention_dropout, "initializer_range": config.initializer_range, } From dd6fdf7b96fbe082ab5c736e91b492368db11a45 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 17 Mar 2026 16:12:49 -0300 Subject: [PATCH 26/98] Fix weight loading to be compatible with FMS RoPE Signed-off-by: Max de Bayser --- fms/models/qwen3.py | 11 +++-------- fms/modules/attention.py | 4 ++-- fms/modules/positions.py | 3 --- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index 843532619..3564876e3 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -563,11 +563,7 @@ def _hf_to_fms_rope( ) -> Mapping[str, Any]: new_sd = {} - if model_config: - head_size = model_config.emb_dim // model_config.nheads - else: - logger.warning("Missing model_config, assuming defaults for head_size") - head_size = 128 # Good default for most models + head_size = model_config.head_dim for name, param in input_sd.items(): # Some checkpoints have weights in different precisions, which can have @@ -581,7 +577,7 @@ def _hf_to_fms_rope( ) rope_params = _get_rope_params(linear_type_str) trans_required_pattern = re.compile( - f"base_model.layers.[0-9]+.attn.in_proj.(query|key).({'|'.join(rope_params)})$" + f"base_model.layers.[0-9]+.attn.in_proj.(query|key|q_norm|k_norm).({'|'.join(rope_params)})$" ) # hf -> fms requires a transpose operation for the query and key @@ -615,7 +611,6 @@ def _hf_to_fms_rope( if is_gptq_2d_qparam: temp = temp.transpose(0, 1) - new_sd[name] = temp else: new_sd[name] = param @@ -646,5 +641,5 @@ def _get_rope_params(linear_type: str) -> list[str]: serialization.register_adapter( _architecture_name, "hf", - ["hf_to_fms_names"], + ["hf_to_fms_names", "hf_to_fms_rope"], ) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 141e37457..7366be77b 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -591,8 +591,8 @@ def forward( k_len = keys.shape[1] # Reshape to separate heads: b x len x heads x head_dim - queries = queries.view(batch_size, q_len, self.nheads, self.head_dim) - keys = keys.view(batch_size, k_len, self.kvheads, self.head_dim) + queries = queries.view(batch_size, self.nheads, q_len, self.head_dim) + keys = keys.view(batch_size, self.kvheads, k_len, self.head_dim) # Apply normalization per head queries = self.q_norm(queries) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index 5cd51eed5..dc4336b78 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -436,9 +436,6 @@ def adjusted_qk( return query.reshape(query_shape), key.reshape(key_shape) freqs = self.cached_freqs[q.device.index][alpha][position_ids] - - position_ids = position_ids.clamp(max=freqs.size(0) - 1) - freqs = freqs.float() # 1 L D/2 2 2 q_out = ( From 7b6208d0ae77bf52fa7059305416e6d94477727b Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Tue, 17 Mar 2026 16:15:28 -0300 Subject: [PATCH 27/98] :package: RoBERTa fixes to work with transformers 5.x.x Signed-off-by: Flavia Beo --- fms/models/hf/lm_head_mixins.py | 12 ++---------- fms/models/hf/modeling_hf_adapter.py | 9 ++++++++- fms/models/hf/roberta/modeling_roberta_hf.py | 4 ++-- fms/models/roberta.py | 3 +++ tests/models/hf/test_as_fms_model.py | 7 +++---- tests/models/hf_equivalence/test_roberta.py | 11 ++++++++--- 6 files changed, 26 insertions(+), 20 deletions(-) diff --git a/fms/models/hf/lm_head_mixins.py b/fms/models/hf/lm_head_mixins.py index bc45639bd..8367458e2 100644 --- a/fms/models/hf/lm_head_mixins.py +++ b/fms/models/hf/lm_head_mixins.py @@ -274,10 +274,6 @@ class SequenceClassificationLMHeadMixin(LMHeadMixin): set at run-time based on config.num_labels and the label dtype. """ - # _tied_weights_keys = [ - # "lm_head.head.weight", - # "lm_head.head.bias", - # ] _tied_weights_keys = { "lm_head.head.weight": "lm_head.head.weight", "lm_head.head.bias": "lm_head.head.bias" @@ -378,14 +374,10 @@ def get_output_embeddings(self): class MaskedLMHeadMixin(LMHeadMixin): """Provides a model architecture with a masked lm head""" - # _tied_weights_keys = [ - # "lm_head.head.weight", - # "lm_head.head.bias", - # ] - _tied_weights_keys = { "lm_head.head.weight": "lm_head.head.weight", - "lm_head.head.bias": "lm_head.head.bias" + "lm_head.head.bias": "lm_head.head.bias", + "embedding.weight": "embedding.weight" } def __init__( diff --git a/fms/models/hf/modeling_hf_adapter.py b/fms/models/hf/modeling_hf_adapter.py index ca1ed8adb..dd7408fe0 100644 --- a/fms/models/hf/modeling_hf_adapter.py +++ b/fms/models/hf/modeling_hf_adapter.py @@ -1,13 +1,14 @@ import abc import copy import os +from packaging.version import Version from typing import Callable, Dict, Optional, Tuple, Union import torch from torch import nn from torch.nn.modules.loss import _Loss from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin -from transformers.modeling_utils import no_init_weights +from transformers import __version__ as tf_version from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -15,6 +16,12 @@ ) from transformers.utils import ModelOutput, is_torch_fx_proxy +## Address transformers API changes +if Version(tf_version) >= Version("5.0.0"): + from transformers.initialization import no_init_weights +else: + from transformers.modeling_utils import no_init_weights + from fms.models.hf.utils import mask_2d_to_3d, mask_2d_to_3d_bidirectional diff --git a/fms/models/hf/roberta/modeling_roberta_hf.py b/fms/models/hf/roberta/modeling_roberta_hf.py index 6071ed65a..76f0c3d21 100644 --- a/fms/models/hf/roberta/modeling_roberta_hf.py +++ b/fms/models/hf/roberta/modeling_roberta_hf.py @@ -123,9 +123,9 @@ class HFAdaptedRoBERTaHeadless(HFEncoderModelArchitecture): _tied_weights_keys = { "encoder.model.embedding.weight": "roberta.embeddings.word_embeddings.weight", "embedding.weight": "embedding.weight", + "lm_head.head.weight": "lm_head.head.weight" } - _keys_to_ignore_on_save = ["embedding.weight"] - + def __init__( self, config: PretrainedConfig, diff --git a/fms/models/roberta.py b/fms/models/roberta.py index 9db5e461e..3b2776311 100644 --- a/fms/models/roberta.py +++ b/fms/models/roberta.py @@ -653,6 +653,9 @@ def _hf_to_fms_names(hf_sd: Mapping[str, Any], **kwargs) -> Mapping[str, Any]: "classification_head.head", ), # only relevant to SentenceClassification task (r"^qa_outputs", "qa_head"), # only relevant to QuestionAnswering task + # Convert HF LayerNorm parameter names (gamma/beta) to PyTorch standard (weight/bias) + (r"gamma", "weight"), + (r"beta", "bias"), ] new_sd = {} for name, param in hf_sd.items(): diff --git a/tests/models/hf/test_as_fms_model.py b/tests/models/hf/test_as_fms_model.py index d1782bb09..591ec1232 100644 --- a/tests/models/hf/test_as_fms_model.py +++ b/tests/models/hf/test_as_fms_model.py @@ -61,10 +61,9 @@ def test_as_fms_model_equivalency_for_decoder(model_id_or_path): def test_as_fms_model_equivalency_for_encoder(model_id_or_path): hf_model = AutoModelForMaskedLM.from_pretrained(model_id_or_path) with tempfile.TemporaryDirectory() as workdir: - # robertas bin file is not working properly, and we are getting different results for safetensors, this should - # be addressed in another PR + # Use safetensors format for compatibility with transformers 5.0.0 hf_model.save_pretrained( - f"{workdir}/roberta-base-masked_lm", safe_serialization=False + f"{workdir}/roberta-base-masked_lm", safe_serialization=True ) # loading from local rather than snapshot download @@ -74,7 +73,7 @@ def test_as_fms_model_equivalency_for_encoder(model_id_or_path): bos_token_id=hf_model.config.bos_token_id, pad_token_id=hf_model.config.pad_token_id, eos_token_id=hf_model.config.eos_token_id, - task_specific_params=hf_model.config.task_specific_params, + # task_specific_params=hf_model.config.task_specific_params, ) fms_model = fms_model.eval() hf_model = hf_model.eval() diff --git a/tests/models/hf_equivalence/test_roberta.py b/tests/models/hf_equivalence/test_roberta.py index dc5a3056f..315c084eb 100644 --- a/tests/models/hf_equivalence/test_roberta.py +++ b/tests/models/hf_equivalence/test_roberta.py @@ -47,7 +47,8 @@ def test_roberta_base_for_masked_lm_equivalency(model_id): assert model_param_count == hf_model_param_count hf_model_fms = to_hf_api( - model, task_specific_params=hf_model.config.task_specific_params + model, + #task_specific_params=hf_model.config.task_specific_params ) # test the param count is the same between hf model and hf fms model @@ -108,7 +109,9 @@ def test_roberta_base_for_masked_lm_equivalency(model_id): inputs = torch.arange(0, 15).unsqueeze(0) labels = torch.arange(0, 15).unsqueeze(0) - attention_mask = (inputs == 1).unsqueeze(-1) == (inputs == 1).unsqueeze(-2) + # Create 2D attention mask for transformers 5.0.0 compatibility + # For bidirectional models like RoBERTa/BERT, use all-ones mask (all tokens attend to all) + attention_mask = torch.ones_like(inputs) hf_model_loss = hf_model( input_ids=inputs, labels=labels, attention_mask=attention_mask, return_dict=True ).loss @@ -195,7 +198,9 @@ def test_roberta_base_for_sequence_classification(model_id, task, problem_type): labels = torch.randint(high=hf_model.config.num_labels, size=(1,)) else: labels = torch.randn(hf_model.config.num_labels).unsqueeze(0) - attention_mask = (inputs == 1).unsqueeze(-1) == (inputs == 1).unsqueeze(-2) + # Create 2D attention mask for transformers 5.0.0 compatibility + # For bidirectional models like RoBERTa/BERT, use all-ones mask (all tokens attend to all) + attention_mask = torch.ones_like(inputs) hf_model_loss = hf_model( input_ids=inputs, labels=labels, attention_mask=attention_mask, return_dict=True ).loss From a4e918e5a8a6ebd7c3d8b4e3d28232fb0baf2971 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 18 Mar 2026 12:02:23 -0300 Subject: [PATCH 28/98] revert change Signed-off-by: Max de Bayser --- fms/modules/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 7366be77b..141e37457 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -591,8 +591,8 @@ def forward( k_len = keys.shape[1] # Reshape to separate heads: b x len x heads x head_dim - queries = queries.view(batch_size, self.nheads, q_len, self.head_dim) - keys = keys.view(batch_size, self.kvheads, k_len, self.head_dim) + queries = queries.view(batch_size, q_len, self.nheads, self.head_dim) + keys = keys.view(batch_size, k_len, self.kvheads, self.head_dim) # Apply normalization per head queries = self.q_norm(queries) From 489dcecc132c3d6363ce2289765d91cf889b9b46 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Fri, 20 Mar 2026 15:13:02 -0300 Subject: [PATCH 29/98] fix kv length Signed-off-by: Max de Bayser --- fms/modules/attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 141e37457..d43255078 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -831,6 +831,7 @@ def forward( # q, k, v: batch_size x seq_len x emb_dim # mask: batch_size x seq_len x seq_len batch_size, q_len, _ = q.size() + kv_len = k.shape[1] # if this is self attention, we always recompute # cross attention only gets computed when a cache does not exist @@ -844,8 +845,8 @@ def forward( # note: transposes will be moved in a later PR to fix dis-contiguous tensor issues queries = q_out.view(batch_size, q_len, self.nheads, self.emb_kq_per_head) - keys = k_out.view(batch_size, q_len, self.kvheads, self.emb_kq_per_head) - values = v_out.view(batch_size, q_len, self.kvheads, self.emb_v_per_head) + keys = k_out.view(batch_size, kv_len, self.kvheads, self.emb_kq_per_head) + values = v_out.view(batch_size, kv_len, self.kvheads, self.emb_v_per_head) # You want to apply rotary embeddings pre-cache if self.position_encoder is not None: From 95d130e80ea2c103685258e04803b69e48c616ad Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Fri, 9 Jan 2026 18:06:57 +0000 Subject: [PATCH 30/98] introduce batching in loglikelihood() for lm-eval Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 121 +++++++++++++++++++++++++++++++++------- 1 file changed, 102 insertions(+), 19 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index 68f37d857..c598f4250 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -1,7 +1,9 @@ from typing import List, Tuple +import time import torch import torch.nn.functional as F +import tqdm from lm_eval.api.instance import Instance # type: ignore from lm_eval.api.model import LM # type: ignore from lm_eval.api.registry import register_model # type: ignore @@ -34,8 +36,13 @@ def generic_object(): self.model = generic_object self.model.config = generic_object # type: ignore self.model.config._name_or_path = "FMSEvalHarnessLM" # type: ignore + + def _tokenize( + self, + context: str, + continuation: str + ) -> Tuple[List[int], List[int], List[int]]: - def loglikelihood_one(self, context: str, continuation: str) -> Tuple[float, bool]: context_ids = self.tokenizer.convert_tokens_to_ids( self.tokenizer.tokenize(context) ) @@ -46,24 +53,98 @@ def loglikelihood_one(self, context: str, continuation: str) -> Tuple[float, boo self.tokenizer.tokenize(continuation) ) input_ids = context_ids + continuation_ids[:-1] - input_ids = torch.tensor( - input_ids, dtype=torch.long, device=self.device - ).unsqueeze(0) - logits = F.log_softmax(self.wrapped_model(input_ids)[0], -1) - continuation_probs = logits[len(context_ids) - 1 :] - loglikelihood = torch.gather( - continuation_probs, 1, torch.tensor(continuation_ids).unsqueeze(1) - ).squeeze() - predicted = torch.argmax(continuation_probs, -1).tolist() - greedy = predicted == continuation_ids - return loglikelihood.sum().cpu().item(), greedy - - def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: - result = [] - for request in requests: - context, continuation = request.args - result.append(self.loglikelihood_one(context, continuation)) - return result + + return context_ids, continuation_ids, input_ids + + def loglikelihood( + self, + requests: List[Instance], + batch_size: int = 16, + sorting: bool = True, + ) -> List[Tuple[float, bool]]: + + if batch_size > 1: + if not sorting: + print('Sorting can reduce padding and therefore increase throughput considerably.') + else: + # Sorting with batch size 1 has no effect, overriding sorting = False + sorting = False + + # attach original indices and sort by length + indexed_requests = list(enumerate(requests)) + + _req_len = lambda x: ( + len(self.tokenizer.tokenize(x[1].args[0])) + + len(self.tokenizer.tokenize(x[1].args[1])) + ) + + if sorting: + start_time = time.time() + indexed_requests.sort(key=_req_len) + print('Sorting of requests took', time.time() - start_time, 's.') + + results_with_idx: List[Tuple[int, Tuple[float, bool]]] = [] + + # getting the pad token id and default to EOS token id if no pad id found + # Note: is safe because padding tokens are never attended since masked out + pad_id = getattr(self.tokenizer, "pad_token_id", None) + if pad_id is None: + print('pad_token_id not provided for this tokenizer, defaulting to eos_token_id.') + pad_id = getattr(self.tokenizer, "eos_token_id") + + # looping over batches + for start in tqdm.tqdm(range(0, len(indexed_requests), batch_size)): + batch = indexed_requests[start : start + batch_size] + + context_lens = [] + continuation_ids_list = [] + input_ids_list = [] + orig_indices = [] + + # tokenize batch + for orig_idx, req in batch: + context, continuation = req.args + context_ids, continuation_ids, input_ids = self._tokenize(context, continuation) + context_lens.append(len(context_ids)) + continuation_ids_list.append(continuation_ids) + input_ids_list.append(torch.tensor(input_ids, dtype=torch.long)) + orig_indices.append(orig_idx) + + # pad inputs ids + max_len = max(x.size(0) for x in input_ids_list) + + input_ids = torch.full( + (len(batch), max_len), + pad_id, + dtype=torch.long, + device=self.device, + ) + + for i, ids in enumerate(input_ids_list): + input_ids[i, : ids.size(0)] = ids.to(self.device) + + # forward + with torch.no_grad(): + logits = self.wrapped_model(input_ids) + log_probs = F.log_softmax(logits, dim=-1) + + # post-process per sample + for i in range(len(batch)): + context_len = context_lens[i] + continuation_ids = continuation_ids_list[i] + continuation_probs = log_probs[i, context_len - 1 : context_len - 1 + len(continuation_ids)] + loglikelihood = continuation_probs.gather( + 1, torch.tensor(continuation_ids, device=self.device).unsqueeze(1), + ).squeeze(1) + predicted = torch.argmax(continuation_probs, -1).tolist() + greedy = predicted == continuation_ids + results_with_idx.append((orig_indices[i], (loglikelihood.sum().item(), greedy))) + + # restore original request order + if sorting: + results_with_idx.sort(key=lambda x: x[0]) + + return [r for _, r in results_with_idx] def loglikelihood_rolling( self, requests: List[Instance] @@ -72,3 +153,5 @@ def loglikelihood_rolling( def generate_until(self, requests: List[Instance]) -> List[str]: raise NotImplementedError("not implemented yet") + +# Made with Bob From 3520835d64acb8e59abff3028e2b3e976c17c9ed Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Fri, 9 Jan 2026 20:50:17 +0000 Subject: [PATCH 31/98] add batch size as configurable command line argument Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 13 +++++++------ scripts/eval_harness.py | 13 ++++++++++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index c598f4250..f96a02970 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -18,12 +18,16 @@ def __init__( self, model: nn.Module, tokenizer: tokenizers.BaseTokenizer, + use_cache: bool = False, + batch_size: int = 1, device="cpu", rank=0, world_size=1, ): self.wrapped_model = model self.tokenizer = tokenizer + self.use_cache = use_cache + self.batch_size = batch_size self._rank = rank self._world_size = world_size self.device = device @@ -59,11 +63,10 @@ def _tokenize( def loglikelihood( self, requests: List[Instance], - batch_size: int = 16, sorting: bool = True, ) -> List[Tuple[float, bool]]: - if batch_size > 1: + if self.batch_size > 1: if not sorting: print('Sorting can reduce padding and therefore increase throughput considerably.') else: @@ -93,8 +96,8 @@ def loglikelihood( pad_id = getattr(self.tokenizer, "eos_token_id") # looping over batches - for start in tqdm.tqdm(range(0, len(indexed_requests), batch_size)): - batch = indexed_requests[start : start + batch_size] + for start in tqdm.tqdm(range(0, len(indexed_requests), self.batch_size)): + batch = indexed_requests[start : start + self.batch_size] context_lens = [] continuation_ids_list = [] @@ -153,5 +156,3 @@ def loglikelihood_rolling( def generate_until(self, requests: List[Instance]) -> List[str]: raise NotImplementedError("not implemented yet") - -# Made with Bob diff --git a/scripts/eval_harness.py b/scripts/eval_harness.py index 424da2114..932809ca0 100644 --- a/scripts/eval_harness.py +++ b/scripts/eval_harness.py @@ -59,6 +59,12 @@ action="store_false", help="Disable the kv-cache (on by default)", ) +parser.add_argument( + "--batch_size", + type=int, + default=1, + help="batch size for loglikelihood() function", +) parser.add_argument( "--compile", action="store_true", @@ -141,7 +147,12 @@ model = torch.compile(model, mode=args.compile_mode) -lm_obj = evaluation.FMSEvalHarnessLM(model=model, tokenizer=tokenizer, device=device) +lm_obj = evaluation.FMSEvalHarnessLM( + model=model, + tokenizer=tokenizer, + use_cache=args.no_use_cache, + batch_size=args.batch_size, + device=device) results = lm_eval.simple_evaluate( model=lm_obj, From c57dde5dfca6c9a6d2b14aa4bc9471076c007f60 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Mon, 23 Mar 2026 14:45:46 +0000 Subject: [PATCH 32/98] Replace print statements with logger Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index f96a02970..7476aea61 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -1,5 +1,6 @@ from typing import List, Tuple +import logging import time import torch import torch.nn.functional as F @@ -11,6 +12,8 @@ from fms.utils import tokenizers +logger = logging.getLogger(__name__) + @register_model("fms") class FMSEvalHarnessLM(LM): @@ -68,7 +71,7 @@ def loglikelihood( if self.batch_size > 1: if not sorting: - print('Sorting can reduce padding and therefore increase throughput considerably.') + logger.info('Sorting can reduce padding and therefore increase throughput considerably.') else: # Sorting with batch size 1 has no effect, overriding sorting = False sorting = False @@ -84,7 +87,7 @@ def loglikelihood( if sorting: start_time = time.time() indexed_requests.sort(key=_req_len) - print('Sorting of requests took', time.time() - start_time, 's.') + logger.info(f'Sorting of requests took {time.time() - start_time:.3f}s') results_with_idx: List[Tuple[int, Tuple[float, bool]]] = [] @@ -92,7 +95,7 @@ def loglikelihood( # Note: is safe because padding tokens are never attended since masked out pad_id = getattr(self.tokenizer, "pad_token_id", None) if pad_id is None: - print('pad_token_id not provided for this tokenizer, defaulting to eos_token_id.') + logger.warning('pad_token_id not provided for this tokenizer, defaulting to eos_token_id.') pad_id = getattr(self.tokenizer, "eos_token_id") # looping over batches From b2b21dff1628391c63e6b1a55daeef3dbe90b2f4 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Mon, 23 Mar 2026 15:22:36 +0000 Subject: [PATCH 33/98] Remove code from other dev branch Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 2 -- scripts/eval_harness.py | 1 - 2 files changed, 3 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index 7476aea61..d6b3b7330 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -21,7 +21,6 @@ def __init__( self, model: nn.Module, tokenizer: tokenizers.BaseTokenizer, - use_cache: bool = False, batch_size: int = 1, device="cpu", rank=0, @@ -29,7 +28,6 @@ def __init__( ): self.wrapped_model = model self.tokenizer = tokenizer - self.use_cache = use_cache self.batch_size = batch_size self._rank = rank self._world_size = world_size diff --git a/scripts/eval_harness.py b/scripts/eval_harness.py index 932809ca0..e2e381312 100644 --- a/scripts/eval_harness.py +++ b/scripts/eval_harness.py @@ -150,7 +150,6 @@ lm_obj = evaluation.FMSEvalHarnessLM( model=model, tokenizer=tokenizer, - use_cache=args.no_use_cache, batch_size=args.batch_size, device=device) From 85b44b72aa1f0ed69f2f8ab68921614926cb2071 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Mon, 23 Mar 2026 15:28:30 +0000 Subject: [PATCH 34/98] Update dependencies Signed-off-by: Yannick Schnider --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5618da74a..d82b5e1f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dev = [ "sentencepiece==0.2.0", "pyarrow-stubs==17.16", "types-requests==2.32.0.20241016", +"tqdm==4.67.1", +"types-tqdm", "lm_eval==0.4.7", "peft==0.14.0", "fms-model-optimizer>=0.8.1", From afc58a21552929b7c46df2bb07f95d654961118b Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Mon, 23 Mar 2026 15:38:08 +0000 Subject: [PATCH 35/98] ruff and mypy Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 64 +++++++++++++++++++++++------------------ scripts/eval_harness.py | 6 ++-- 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index d6b3b7330..602f00f80 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -41,13 +41,10 @@ def generic_object(): self.model = generic_object self.model.config = generic_object # type: ignore self.model.config._name_or_path = "FMSEvalHarnessLM" # type: ignore - + def _tokenize( - self, - context: str, - continuation: str + self, context: str, continuation: str ) -> Tuple[List[int], List[int], List[int]]: - context_ids = self.tokenizer.convert_tokens_to_ids( self.tokenizer.tokenize(context) ) @@ -66,26 +63,28 @@ def loglikelihood( requests: List[Instance], sorting: bool = True, ) -> List[Tuple[float, bool]]: - if self.batch_size > 1: if not sorting: - logger.info('Sorting can reduce padding and therefore increase throughput considerably.') + logger.info( + "Sorting can reduce padding and therefore increase throughput considerably." + ) else: # Sorting with batch size 1 has no effect, overriding sorting = False sorting = False - + # attach original indices and sort by length indexed_requests = list(enumerate(requests)) - _req_len = lambda x: ( - len(self.tokenizer.tokenize(x[1].args[0])) - + len(self.tokenizer.tokenize(x[1].args[1])) - ) - if sorting: + + def _req_len(x): + return len(self.tokenizer.tokenize(x[1].args[0])) + len( + self.tokenizer.tokenize(x[1].args[1]) + ) + start_time = time.time() indexed_requests.sort(key=_req_len) - logger.info(f'Sorting of requests took {time.time() - start_time:.3f}s') + logger.info(f"Sorting of requests took {time.time() - start_time:.3f}s") results_with_idx: List[Tuple[int, Tuple[float, bool]]] = [] @@ -93,31 +92,35 @@ def loglikelihood( # Note: is safe because padding tokens are never attended since masked out pad_id = getattr(self.tokenizer, "pad_token_id", None) if pad_id is None: - logger.warning('pad_token_id not provided for this tokenizer, defaulting to eos_token_id.') + logger.warning( + "pad_token_id not provided for this tokenizer, defaulting to eos_token_id." + ) pad_id = getattr(self.tokenizer, "eos_token_id") # looping over batches for start in tqdm.tqdm(range(0, len(indexed_requests), self.batch_size)): batch = indexed_requests[start : start + self.batch_size] - context_lens = [] - continuation_ids_list = [] - input_ids_list = [] - orig_indices = [] + context_lens: List[int] = [] + continuation_ids_list: List[List[int]] = [] + input_ids_list: List[torch.Tensor] = [] + orig_indices: List[int] = [] # tokenize batch for orig_idx, req in batch: context, continuation = req.args - context_ids, continuation_ids, input_ids = self._tokenize(context, continuation) + context_ids, continuation_ids, input_ids_raw = self._tokenize( + context, continuation + ) context_lens.append(len(context_ids)) continuation_ids_list.append(continuation_ids) - input_ids_list.append(torch.tensor(input_ids, dtype=torch.long)) + input_ids_list.append(torch.tensor(input_ids_raw, dtype=torch.long)) orig_indices.append(orig_idx) # pad inputs ids max_len = max(x.size(0) for x in input_ids_list) - input_ids = torch.full( + input_ids_batch: torch.Tensor = torch.full( (len(batch), max_len), pad_id, dtype=torch.long, @@ -125,29 +128,34 @@ def loglikelihood( ) for i, ids in enumerate(input_ids_list): - input_ids[i, : ids.size(0)] = ids.to(self.device) + input_ids_batch[i, : ids.size(0)] = ids.to(self.device) # forward with torch.no_grad(): - logits = self.wrapped_model(input_ids) + logits = self.wrapped_model(input_ids_batch) log_probs = F.log_softmax(logits, dim=-1) # post-process per sample for i in range(len(batch)): context_len = context_lens[i] continuation_ids = continuation_ids_list[i] - continuation_probs = log_probs[i, context_len - 1 : context_len - 1 + len(continuation_ids)] + continuation_probs = log_probs[ + i, context_len - 1 : context_len - 1 + len(continuation_ids) + ] loglikelihood = continuation_probs.gather( - 1, torch.tensor(continuation_ids, device=self.device).unsqueeze(1), + 1, + torch.tensor(continuation_ids, device=self.device).unsqueeze(1), ).squeeze(1) predicted = torch.argmax(continuation_probs, -1).tolist() greedy = predicted == continuation_ids - results_with_idx.append((orig_indices[i], (loglikelihood.sum().item(), greedy))) + results_with_idx.append( + (orig_indices[i], (loglikelihood.sum().item(), greedy)) + ) # restore original request order if sorting: results_with_idx.sort(key=lambda x: x[0]) - + return [r for _, r in results_with_idx] def loglikelihood_rolling( diff --git a/scripts/eval_harness.py b/scripts/eval_harness.py index e2e381312..2195e5128 100644 --- a/scripts/eval_harness.py +++ b/scripts/eval_harness.py @@ -148,10 +148,8 @@ lm_obj = evaluation.FMSEvalHarnessLM( - model=model, - tokenizer=tokenizer, - batch_size=args.batch_size, - device=device) + model=model, tokenizer=tokenizer, batch_size=args.batch_size, device=device +) results = lm_eval.simple_evaluate( model=lm_obj, From 437b9285441cee63d42f2bfcedaf8e2739d81bbd Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Mon, 23 Mar 2026 16:13:12 +0000 Subject: [PATCH 36/98] Fix test_eval to use keyword argument for device parameter Signed-off-by: Yannick Schnider --- tests/utils/test_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_eval.py b/tests/utils/test_eval.py index addf45635..e6a294ad8 100644 --- a/tests/utils/test_eval.py +++ b/tests/utils/test_eval.py @@ -33,7 +33,7 @@ def test_eval(): tokenizer = get_tokenizer("char_tokenizer") model = ModelMock([ord("a"), ord("d")]) - lm_eval = evaluation.FMSEvalHarnessLM(model, tokenizer, "cpu") + lm_eval = evaluation.FMSEvalHarnessLM(model=model, tokenizer=tokenizer, device="cpu") instance = Instance( request_type="loglikelihood", doc={}, arguments=("hello", "world"), idx=0 ) From b7708eb56e4821da33404e107afd4dd3cdbbbe27 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Mon, 23 Mar 2026 16:21:55 +0000 Subject: [PATCH 37/98] Fix ruff Signed-off-by: Yannick Schnider --- tests/utils/test_eval.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_eval.py b/tests/utils/test_eval.py index e6a294ad8..8e9eb8133 100644 --- a/tests/utils/test_eval.py +++ b/tests/utils/test_eval.py @@ -33,7 +33,9 @@ def test_eval(): tokenizer = get_tokenizer("char_tokenizer") model = ModelMock([ord("a"), ord("d")]) - lm_eval = evaluation.FMSEvalHarnessLM(model=model, tokenizer=tokenizer, device="cpu") + lm_eval = evaluation.FMSEvalHarnessLM( + model=model, tokenizer=tokenizer, device="cpu" + ) instance = Instance( request_type="loglikelihood", doc={}, arguments=("hello", "world"), idx=0 ) From 8074cb83fcc698ce2ddae1930466e9158bad8da7 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 23 Mar 2026 15:04:02 -0300 Subject: [PATCH 38/98] Fixes tie word embeddings mappings Signed-off-by: Flavia Beo --- fms/models/gpt_oss.py | 2 ++ fms/models/hf/granite/configuration_granite_hf.py | 2 +- fms/models/hf/granite/modeling_granite_hf.py | 9 +++++---- fms/models/hf/llama/modeling_llama_hf.py | 13 +++++++++++-- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/fms/models/gpt_oss.py b/fms/models/gpt_oss.py index 70ab022ee..d65fd68c9 100644 --- a/fms/models/gpt_oss.py +++ b/fms/models/gpt_oss.py @@ -50,6 +50,8 @@ class GptOssConfig(ModelConfig): src_vocab_size: int = 201088 emb_dim: int = 2880 head_dim: int = 64 + eos_token_id: int = 200002 + bos_token_id: int = 0 num_attention_heads: int = 64 sliding_window: int = 128 rope_base: float = 150000.0 diff --git a/fms/models/hf/granite/configuration_granite_hf.py b/fms/models/hf/granite/configuration_granite_hf.py index 1da9648ad..11c7e74bb 100644 --- a/fms/models/hf/granite/configuration_granite_hf.py +++ b/fms/models/hf/granite/configuration_granite_hf.py @@ -57,7 +57,7 @@ def __init__( bos_token_id=bos_token_id, is_decoder=is_decoder, tie_word_embeddings=kwargs.pop( - "tie_word_embeddings", False + "tie_word_embeddings", True ), # note: This was added here as we handle tying of heads with our underlying model, we may want to revisit this in future **kwargs, ) diff --git a/fms/models/hf/granite/modeling_granite_hf.py b/fms/models/hf/granite/modeling_granite_hf.py index 5dad1ad9b..cf6709b92 100644 --- a/fms/models/hf/granite/modeling_granite_hf.py +++ b/fms/models/hf/granite/modeling_granite_hf.py @@ -53,8 +53,6 @@ class HFAdaptedGraniteHeadless(HFDecoderModelArchitecture): config_class = HFAdaptedGraniteConfig base_model_prefix = "hf_adapted_granite" - _tied_weights_keys = ["decoder.model.embedding.weight", "embedding.weight"] - _keys_to_ignore_on_save = ["embedding.weight"] def __init__( self, @@ -110,8 +108,10 @@ def _prepare_inputs_for_generation( class HFAdaptedGraniteForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedGraniteHeadless): _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - _tied_weights_keys = ["embedding.weight", "lm_head.weight"] - + _tied_weights_keys = { + "lm_head.weight": "decoder.model.embedding.weight", + "embedding.weight": "decoder.model.embedding.weight", + } def __init__(self, config: HFAdaptedGraniteConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) @@ -119,6 +119,7 @@ def __init__(self, config: HFAdaptedGraniteConfig, *args, **kwargs): def _hf_model_from_fms( cls, model: Granite, config: HFAdaptedGraniteConfig ) -> "HFAdaptedGraniteForCausalLM": + config.tie_word_embeddings = True return cls( config=config, decoder=model.base_model, diff --git a/fms/models/hf/llama/modeling_llama_hf.py b/fms/models/hf/llama/modeling_llama_hf.py index 8f5e45cf7..a65f29420 100644 --- a/fms/models/hf/llama/modeling_llama_hf.py +++ b/fms/models/hf/llama/modeling_llama_hf.py @@ -108,16 +108,25 @@ def _prepare_inputs_for_generation( class HFAdaptedLLaMAForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedLLaMAHeadless): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - _tied_weights_keys = ["embedding.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = [r"lm_head.weight", r"decoder\.model\.embedding\.weight"] + _tied_weights_keys = { + "decoder.model.embedding.weight": "embedding.weight", + } def __init__(self, config: HFAdaptedLLaMAConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) + def _tie_weights(self): + self.decoder.model.embedding.weight = self.embedding.weight + if self.config.tie_word_embeddings: + self.lm_head.weight = self.embedding.weight + @classmethod def _hf_model_from_fms( cls, model: LLaMA, config: HFAdaptedLLaMAConfig ) -> "HFAdaptedLLaMAForCausalLM": + config.tie_word_embeddings = True + print(f"{config=}") out = cls( config=config, decoder=model.base_model, From bd4b2966470ea0476c4ac32f5b505ace65498bf7 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 23 Mar 2026 15:08:52 -0300 Subject: [PATCH 39/98] Revert "fix kv length" This reverts commit 489dcecc132c3d6363ce2289765d91cf889b9b46. Signed-off-by: Max de Bayser --- fms/modules/attention.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index d43255078..141e37457 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -831,7 +831,6 @@ def forward( # q, k, v: batch_size x seq_len x emb_dim # mask: batch_size x seq_len x seq_len batch_size, q_len, _ = q.size() - kv_len = k.shape[1] # if this is self attention, we always recompute # cross attention only gets computed when a cache does not exist @@ -845,8 +844,8 @@ def forward( # note: transposes will be moved in a later PR to fix dis-contiguous tensor issues queries = q_out.view(batch_size, q_len, self.nheads, self.emb_kq_per_head) - keys = k_out.view(batch_size, kv_len, self.kvheads, self.emb_kq_per_head) - values = v_out.view(batch_size, kv_len, self.kvheads, self.emb_v_per_head) + keys = k_out.view(batch_size, q_len, self.kvheads, self.emb_kq_per_head) + values = v_out.view(batch_size, q_len, self.kvheads, self.emb_v_per_head) # You want to apply rotary embeddings pre-cache if self.position_encoder is not None: From 067ff3d5bb8933a74eaa0a715c8567025f0faca9 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 23 Mar 2026 15:20:07 -0300 Subject: [PATCH 40/98] fix mypy issues Signed-off-by: Max de Bayser --- fms/models/qwen3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index 3564876e3..8e32e778e 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -563,8 +563,6 @@ def _hf_to_fms_rope( ) -> Mapping[str, Any]: new_sd = {} - head_size = model_config.head_dim - for name, param in input_sd.items(): # Some checkpoints have weights in different precisions, which can have # auxiliary tensors (see _get_rope_params e.g. gptq, fp8). @@ -601,6 +599,8 @@ def _hf_to_fms_rope( temp = temp.transpose(0, 1) # num_heads is used in the transformation required for hf->fms # can't be precomputed because q and k might have different num_heads + assert model_config is not None and model_config.head_dim is not None + head_size = model_config.head_dim num_heads = temp.size(0) // head_size if temp.dim() == 2: # weight From 722ba3933877bede3395d93646cccef251b9f682 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 23 Mar 2026 20:38:36 -0300 Subject: [PATCH 41/98] Fix loading of Qwen/Qwen3-0.6B The Qwen/Qwen3-0.6B model has a "model." prefix for all modules except for the lm_head. Signed-off-by: Max de Bayser --- fms/models/qwen3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index 8e32e778e..7de833909 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -532,6 +532,7 @@ def _hf_to_fms_names( """ replacements = [ (r"^lm_head.weight", "head.weight"), + (r"^model.", ""), # Qwen3 Embedding models have no "model." prefix, but the generative ones do (r"^norm.weight", "base_model.dec_norm.weight"), (r"^embed_tokens.weight", "base_model.embedding.weight"), (r"layers", "base_model.layers"), From 2fbe86a5624e32a0444b9264544ac3e4c5e70988 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 23 Mar 2026 22:00:37 -0300 Subject: [PATCH 42/98] appease linter Signed-off-by: Max de Bayser --- fms/models/qwen3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index 7de833909..adbd08160 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -532,7 +532,8 @@ def _hf_to_fms_names( """ replacements = [ (r"^lm_head.weight", "head.weight"), - (r"^model.", ""), # Qwen3 Embedding models have no "model." prefix, but the generative ones do + # Qwen3 Embedding models have no "model." prefix, but the generative ones do + (r"^model.", ""), (r"^norm.weight", "base_model.dec_norm.weight"), (r"^embed_tokens.weight", "base_model.embedding.weight"), (r"layers", "base_model.layers"), From 972c1ce9b810e178ff9c6de303aa10b56517e40f Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 24 Mar 2026 10:58:27 -0300 Subject: [PATCH 43/98] add workaround for aiu compilation issue Signed-off-by: Max de Bayser --- fms/modules/attention.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 141e37457..b5e182d33 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -594,6 +594,17 @@ def forward( queries = queries.view(batch_size, q_len, self.nheads, self.head_dim) keys = keys.view(batch_size, k_len, self.kvheads, self.head_dim) + if torch._dynamo.is_compiling(): + queries = ( + queries.transpose(-1, -2) + .contiguous() + .transpose(-1, -2) + .contiguous() + ) + keys = ( + keys.transpose(-1, -2).contiguous().transpose(-1, -2).contiguous() + ) + # Apply normalization per head queries = self.q_norm(queries) keys = self.k_norm(keys) @@ -949,14 +960,13 @@ def __init__( assert torch.distributed.is_initialized() rank, world_size = distributed.rank_and_world(group) - assert nheads % world_size == 0, ( - "The number of heads must be divisible by world size" - ) - assert (kvheads >= world_size and kvheads % world_size == 0) or ( - kvheads < world_size and world_size % kvheads == 0 - ), ( - "the kv heads must be divisible by the world size or the world size must be divisible by kv heads" - ) + assert ( + nheads % world_size == 0 + ), "The number of heads must be divisible by world size" + assert ( + (kvheads >= world_size and kvheads % world_size == 0) + or (kvheads < world_size and world_size % kvheads == 0) + ), "the kv heads must be divisible by the world size or the world size must be divisible by kv heads" MultiHeadAttention.__init__( self, emb_dim, From 30d747df1d3d7874cb13c65a7c6d81e9d622e461 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 24 Mar 2026 11:12:46 -0300 Subject: [PATCH 44/98] appease ruff Signed-off-by: Max de Bayser --- fms/modules/attention.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index b5e182d33..79df3a983 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -960,13 +960,14 @@ def __init__( assert torch.distributed.is_initialized() rank, world_size = distributed.rank_and_world(group) - assert ( - nheads % world_size == 0 - ), "The number of heads must be divisible by world size" - assert ( - (kvheads >= world_size and kvheads % world_size == 0) - or (kvheads < world_size and world_size % kvheads == 0) - ), "the kv heads must be divisible by the world size or the world size must be divisible by kv heads" + assert nheads % world_size == 0, ( + "The number of heads must be divisible by world size" + ) + assert (kvheads >= world_size and kvheads % world_size == 0) or ( + kvheads < world_size and world_size % kvheads == 0 + ), ( + "the kv heads must be divisible by the world size or the world size must be divisible by kv heads" + ) MultiHeadAttention.__init__( self, emb_dim, From edad68284c555df2b1ab9dc2857cc64e45e9f248 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Tue, 24 Mar 2026 12:17:07 -0300 Subject: [PATCH 45/98] Upgrades to 5.0.0 + fixes Signed-off-by: Flavia Beo --- .../hf/gpt_bigcode/modeling_gpt_bigcode_hf.py | 6 ++- fms/models/hf/granite/modeling_granite_hf.py | 2 +- fms/models/hf/llama/modeling_llama_hf.py | 51 +++++++++++++++++-- fms/models/hf/lm_head_mixins.py | 4 +- fms/models/hf/roberta/modeling_roberta_hf.py | 4 +- pyproject.toml | 2 +- tests/models/hf/test_as_fms_model.py | 1 - tests/models/hf_equivalence/test_roberta.py | 5 +- 8 files changed, 58 insertions(+), 17 deletions(-) diff --git a/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py b/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py index 7750f9a64..223a42010 100644 --- a/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py +++ b/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py @@ -77,7 +77,11 @@ class HFAdaptedGPTBigCodeForCausalLM( LMHeadModelLMHeadMixin, HFAdaptedGPTBigCodeHeadless ): _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - _tied_weights_keys = ["embedding.weight", "lm_head.weight"] + _tied_weights_keys = { + "decoder.model.embedding.weight": "embedding.weight", + "embedding.weight": "embedding.weight", + "lm_head.head.weight": "lm_head.head.weight", + } def __init__(self, config: HFAdaptedGPTBigCodeConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) diff --git a/fms/models/hf/granite/modeling_granite_hf.py b/fms/models/hf/granite/modeling_granite_hf.py index cf6709b92..4e413cb9b 100644 --- a/fms/models/hf/granite/modeling_granite_hf.py +++ b/fms/models/hf/granite/modeling_granite_hf.py @@ -53,7 +53,6 @@ class HFAdaptedGraniteHeadless(HFDecoderModelArchitecture): config_class = HFAdaptedGraniteConfig base_model_prefix = "hf_adapted_granite" - def __init__( self, config: PretrainedConfig, @@ -112,6 +111,7 @@ class HFAdaptedGraniteForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedGraniteHeadle "lm_head.weight": "decoder.model.embedding.weight", "embedding.weight": "decoder.model.embedding.weight", } + def __init__(self, config: HFAdaptedGraniteConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) diff --git a/fms/models/hf/llama/modeling_llama_hf.py b/fms/models/hf/llama/modeling_llama_hf.py index a65f29420..38cde3509 100644 --- a/fms/models/hf/llama/modeling_llama_hf.py +++ b/fms/models/hf/llama/modeling_llama_hf.py @@ -11,6 +11,7 @@ from fms.models.hf.lm_head_mixins import LMHeadModelLMHeadMixin from fms.models.hf.modeling_hf_adapter import HFDecoder, HFDecoderModelArchitecture from fms.models.llama import LLaMA, LLaMAHeadless +from fms.modules.head import LinearClassificationHead class HFAdaptedLLaMADecoder(HFDecoder): @@ -108,25 +109,65 @@ def _prepare_inputs_for_generation( class HFAdaptedLLaMAForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedLLaMAHeadless): - _keys_to_ignore_on_load_missing = [r"lm_head.weight", r"decoder\.model\.embedding\.weight"] + _keys_to_ignore_on_load_missing = [ + r"lm_head.weight", + r"decoder\.model\.embedding\.weight", + ] _tied_weights_keys = { + "lm_head.weight": "embedding.weight", "decoder.model.embedding.weight": "embedding.weight", } def __init__(self, config: HFAdaptedLLaMAConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) + def _get_empty_lm_head(self, bias: bool) -> nn.Module: + """Override to use LinearClassificationHead instead of nn.Linear""" + return LinearClassificationHead( + self.config.hidden_size, self.config.vocab_size, bias=bias + ) + + def set_output_embeddings(self, new_embeddings): + """Override to ensure we always use LinearClassificationHead""" + if new_embeddings is not None and not isinstance( + new_embeddings, LinearClassificationHead + ): + # If transformers tries to set a regular nn.Linear, convert it to LinearClassificationHead + if isinstance(new_embeddings, nn.Linear): + lm_head = LinearClassificationHead( + new_embeddings.in_features, + new_embeddings.out_features, + bias=new_embeddings.bias is not None, + ) + # Copy the weights and bias + lm_head.weight = new_embeddings.weight + if new_embeddings.bias is not None: + lm_head.bias = new_embeddings.bias + self.lm_head = lm_head + else: + self.lm_head = new_embeddings + else: + self.lm_head = new_embeddings + def _tie_weights(self): - self.decoder.model.embedding.weight = self.embedding.weight + """Tie weights at runtime - FMS models save lm_head.weight, so use that as the source""" if self.config.tie_word_embeddings: - self.lm_head.weight = self.embedding.weight + self.embedding.weight = self.lm_head.weight + self.decoder.model.embedding.weight = self.embedding.weight + + def load_state_dict(self, state_dict, strict=True, assign=False): + """Override to ensure weights are tied after loading""" + result = super().load_state_dict(state_dict, strict=strict, assign=assign) + # Re-tie weights after loading to ensure correct references + self._tie_weights() + return result @classmethod def _hf_model_from_fms( cls, model: LLaMA, config: HFAdaptedLLaMAConfig ) -> "HFAdaptedLLaMAForCausalLM": - config.tie_word_embeddings = True - print(f"{config=}") + # Respect the FMS model's tie_heads setting + config.tie_word_embeddings = model.config.tie_heads out = cls( config=config, decoder=model.base_model, diff --git a/fms/models/hf/lm_head_mixins.py b/fms/models/hf/lm_head_mixins.py index 8367458e2..44b1b9d5b 100644 --- a/fms/models/hf/lm_head_mixins.py +++ b/fms/models/hf/lm_head_mixins.py @@ -276,7 +276,7 @@ class SequenceClassificationLMHeadMixin(LMHeadMixin): _tied_weights_keys = { "lm_head.head.weight": "lm_head.head.weight", - "lm_head.head.bias": "lm_head.head.bias" + "lm_head.head.bias": "lm_head.head.bias", } def __init__( @@ -377,7 +377,7 @@ class MaskedLMHeadMixin(LMHeadMixin): _tied_weights_keys = { "lm_head.head.weight": "lm_head.head.weight", "lm_head.head.bias": "lm_head.head.bias", - "embedding.weight": "embedding.weight" + "embedding.weight": "embedding.weight", } def __init__( diff --git a/fms/models/hf/roberta/modeling_roberta_hf.py b/fms/models/hf/roberta/modeling_roberta_hf.py index 76f0c3d21..4fe0b9bf7 100644 --- a/fms/models/hf/roberta/modeling_roberta_hf.py +++ b/fms/models/hf/roberta/modeling_roberta_hf.py @@ -123,9 +123,9 @@ class HFAdaptedRoBERTaHeadless(HFEncoderModelArchitecture): _tied_weights_keys = { "encoder.model.embedding.weight": "roberta.embeddings.word_embeddings.weight", "embedding.weight": "embedding.weight", - "lm_head.head.weight": "lm_head.head.weight" + "lm_head.head.weight": "lm_head.head.weight", } - + def __init__( self, config: PretrainedConfig, diff --git a/pyproject.toml b/pyproject.toml index 9d746415c..79367b11a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers=[ ] dependencies = [ "torch >= 2.5.1", - "transformers==4.57.6", + "transformers>=5.0.0", ] [project.optional-dependencies] diff --git a/tests/models/hf/test_as_fms_model.py b/tests/models/hf/test_as_fms_model.py index 591ec1232..96e9f0929 100644 --- a/tests/models/hf/test_as_fms_model.py +++ b/tests/models/hf/test_as_fms_model.py @@ -73,7 +73,6 @@ def test_as_fms_model_equivalency_for_encoder(model_id_or_path): bos_token_id=hf_model.config.bos_token_id, pad_token_id=hf_model.config.pad_token_id, eos_token_id=hf_model.config.eos_token_id, - # task_specific_params=hf_model.config.task_specific_params, ) fms_model = fms_model.eval() hf_model = hf_model.eval() diff --git a/tests/models/hf_equivalence/test_roberta.py b/tests/models/hf_equivalence/test_roberta.py index 315c084eb..f92873900 100644 --- a/tests/models/hf_equivalence/test_roberta.py +++ b/tests/models/hf_equivalence/test_roberta.py @@ -46,10 +46,7 @@ def test_roberta_base_for_masked_lm_equivalency(model_id): hf_model_param_count -= 2 * 768 assert model_param_count == hf_model_param_count - hf_model_fms = to_hf_api( - model, - #task_specific_params=hf_model.config.task_specific_params - ) + hf_model_fms = to_hf_api(model) # test the param count is the same between hf model and hf fms model hf_model_fms_param_count = sum([p.numel() for p in hf_model_fms.parameters()]) From 0a774ef1c2b2a122bb9368de2b9b068e2e632721 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Tue, 24 Mar 2026 13:45:56 -0300 Subject: [PATCH 46/98] Version verification for retro-compatibility Signed-off-by: Flavia Beo --- .../hf/gpt_bigcode/modeling_gpt_bigcode_hf.py | 20 +++++++---- fms/models/hf/gpt_oss/modeling_gpt_oss_hf.py | 12 ------- fms/models/hf/granite/modeling_granite_hf.py | 18 +++++++--- fms/models/hf/llama/configuration_llama_hf.py | 4 +-- fms/models/hf/llama/modeling_llama_hf.py | 24 ++++++++----- fms/models/hf/lm_head_mixins.py | 36 ++++++++++++++----- fms/models/hf/roberta/modeling_roberta_hf.py | 17 ++++++--- pyproject.toml | 2 +- 8 files changed, 85 insertions(+), 48 deletions(-) diff --git a/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py b/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py index 223a42010..12bc7affd 100644 --- a/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py +++ b/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py @@ -12,6 +12,9 @@ from fms.models.hf.lm_head_mixins import LMHeadModelLMHeadMixin from fms.models.hf.modeling_hf_adapter import HFDecoder, HFDecoderModelArchitecture +from packaging.version import Version +from transformers import __version__ as tf_version + class HFAdaptedGPTBigCodeDecoder(HFDecoder): """Adapter for the GPTBigCodeDecoder""" @@ -76,12 +79,17 @@ def __init__( class HFAdaptedGPTBigCodeForCausalLM( LMHeadModelLMHeadMixin, HFAdaptedGPTBigCodeHeadless ): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - _tied_weights_keys = { - "decoder.model.embedding.weight": "embedding.weight", - "embedding.weight": "embedding.weight", - "lm_head.head.weight": "lm_head.head.weight", - } + ## Address transformers API changes + if Version(tf_version) >= Version("5.0.0"): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _tied_weights_keys = { + "decoder.model.embedding.weight": "embedding.weight", + "embedding.weight": "embedding.weight", + "lm_head.head.weight": "lm_head.head.weight", + } + else: + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _tied_weights_keys = ["embedding.weight", "lm_head.weight"] def __init__(self, config: HFAdaptedGPTBigCodeConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) diff --git a/fms/models/hf/gpt_oss/modeling_gpt_oss_hf.py b/fms/models/hf/gpt_oss/modeling_gpt_oss_hf.py index 0e02da1b6..9063dc08c 100644 --- a/fms/models/hf/gpt_oss/modeling_gpt_oss_hf.py +++ b/fms/models/hf/gpt_oss/modeling_gpt_oss_hf.py @@ -108,7 +108,6 @@ def prepare_inputs_for_generation( token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs - print(f"{type(past_key_values)=}") if isinstance(past_key_values, DynamicCache): past_key_values = None @@ -120,16 +119,9 @@ def prepare_inputs_for_generation( attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if position_ids is not None: - print("before mask") - print(f"position_ids: {position_ids.shape}") - print(position_ids) - if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - print(f"position_ids: {position_ids.shape}") - print(position_ids) # Don't mask position_ids with 1, keep them as-is for padding tokens # The attention mask will handle padding separately @@ -138,10 +130,6 @@ def prepare_inputs_for_generation( else: position_ids = None - if position_ids is not None: - print(f"position_ids: {position_ids.shape}") - print(position_ids) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/fms/models/hf/granite/modeling_granite_hf.py b/fms/models/hf/granite/modeling_granite_hf.py index 4e413cb9b..f2a1cc397 100644 --- a/fms/models/hf/granite/modeling_granite_hf.py +++ b/fms/models/hf/granite/modeling_granite_hf.py @@ -10,6 +10,9 @@ from fms.models.hf.modeling_hf_adapter import HFDecoder, HFDecoderModelArchitecture from fms.models.granite import Granite, GraniteHeadless +from packaging.version import Version +from transformers import __version__ as tf_version + class HFAdaptedGraniteDecoder(HFDecoder): """Adapter for the Granite decoder""" @@ -106,11 +109,16 @@ def _prepare_inputs_for_generation( class HFAdaptedGraniteForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedGraniteHeadless): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - _tied_weights_keys = { - "lm_head.weight": "decoder.model.embedding.weight", - "embedding.weight": "decoder.model.embedding.weight", - } + ## Address transformers API changes + if Version(tf_version) >= Version("5.0.0"): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "decoder.model.embedding.weight", + "embedding.weight": "decoder.model.embedding.weight", + } + else: + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _tied_weights_keys = ["embedding.weight", "lm_head.weight"] def __init__(self, config: HFAdaptedGraniteConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) diff --git a/fms/models/hf/llama/configuration_llama_hf.py b/fms/models/hf/llama/configuration_llama_hf.py index bde9c4e8c..f407790e5 100644 --- a/fms/models/hf/llama/configuration_llama_hf.py +++ b/fms/models/hf/llama/configuration_llama_hf.py @@ -57,8 +57,8 @@ def __init__( bos_token_id=bos_token_id, is_decoder=is_decoder, tie_word_embeddings=kwargs.pop( - "tie_word_embeddings", False - ), # note: This was added here as we handle tying of heads with our underlying model, we may want to revisit this in future + "tie_word_embeddings", True + ), # note: FMS models tie embeddings by default **kwargs, ) diff --git a/fms/models/hf/llama/modeling_llama_hf.py b/fms/models/hf/llama/modeling_llama_hf.py index 38cde3509..5ecf52a30 100644 --- a/fms/models/hf/llama/modeling_llama_hf.py +++ b/fms/models/hf/llama/modeling_llama_hf.py @@ -13,6 +13,9 @@ from fms.models.llama import LLaMA, LLaMAHeadless from fms.modules.head import LinearClassificationHead +from packaging.version import Version +from transformers import __version__ as tf_version + class HFAdaptedLLaMADecoder(HFDecoder): """Adapter for the LLaMA decoder""" @@ -109,14 +112,19 @@ def _prepare_inputs_for_generation( class HFAdaptedLLaMAForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedLLaMAHeadless): - _keys_to_ignore_on_load_missing = [ - r"lm_head.weight", - r"decoder\.model\.embedding\.weight", - ] - _tied_weights_keys = { - "lm_head.weight": "embedding.weight", - "decoder.model.embedding.weight": "embedding.weight", - } + ## Address transformers API changes + if Version(tf_version) >= Version("5.0.0"): + _keys_to_ignore_on_load_missing = [ + r"lm_head.weight", + r"decoder\.model\.embedding\.weight", + ] + _tied_weights_keys = { + "lm_head.weight": "embedding.weight", + "decoder.model.embedding.weight": "embedding.weight", + } + else: + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _tied_weights_keys = ["embedding.weight", "lm_head.weight"] def __init__(self, config: HFAdaptedLLaMAConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) diff --git a/fms/models/hf/lm_head_mixins.py b/fms/models/hf/lm_head_mixins.py index 44b1b9d5b..898275266 100644 --- a/fms/models/hf/lm_head_mixins.py +++ b/fms/models/hf/lm_head_mixins.py @@ -17,6 +17,9 @@ from fms.modules.head import MLPClassificationHead from fms.utils.activation import str_to_activation +from packaging.version import Version +from transformers import __version__ as tf_version + class LMHeadMixin: """ @@ -274,10 +277,17 @@ class SequenceClassificationLMHeadMixin(LMHeadMixin): set at run-time based on config.num_labels and the label dtype. """ - _tied_weights_keys = { - "lm_head.head.weight": "lm_head.head.weight", - "lm_head.head.bias": "lm_head.head.bias", - } + ## Address transformers API changes + if Version(tf_version) >= Version("5.0.0"): + _tied_weights_keys = { + "lm_head.head.weight": "lm_head.head.weight", + "lm_head.head.bias": "lm_head.head.bias", + } + else: + _tied_weights_keys = [ + "lm_head.head.weight", + "lm_head.head.bias", + ] def __init__( self, @@ -374,11 +384,19 @@ def get_output_embeddings(self): class MaskedLMHeadMixin(LMHeadMixin): """Provides a model architecture with a masked lm head""" - _tied_weights_keys = { - "lm_head.head.weight": "lm_head.head.weight", - "lm_head.head.bias": "lm_head.head.bias", - "embedding.weight": "embedding.weight", - } + ## Address transformers API changes + if Version(tf_version) >= Version("5.0.0"): + _tied_weights_keys = { + "lm_head.head.weight": "lm_head.head.weight", + "lm_head.head.bias": "lm_head.head.bias", + "embedding.weight": "embedding.weight", + } + else: + _tied_weights_keys = [ + "lm_head.head.weight", + "lm_head.head.bias", + "embedding.weight", + ] def __init__( self, diff --git a/fms/models/hf/roberta/modeling_roberta_hf.py b/fms/models/hf/roberta/modeling_roberta_hf.py index 4fe0b9bf7..a9c7232ab 100644 --- a/fms/models/hf/roberta/modeling_roberta_hf.py +++ b/fms/models/hf/roberta/modeling_roberta_hf.py @@ -13,6 +13,9 @@ from fms.models.hf.modeling_hf_adapter import HFEncoder, HFEncoderModelArchitecture from fms.models.roberta import RoBERTa, RoBERTaConfig, RoBERTaHeadless +from packaging.version import Version +from transformers import __version__ as tf_version + class HFAdaptedRoBERTaConfig(PretrainedConfig): model_type = "hf_adapted_roberta" @@ -120,11 +123,15 @@ class HFAdaptedRoBERTaHeadless(HFEncoderModelArchitecture): config_class = HFAdaptedRoBERTaConfig base_model_prefix = "hf_adapted_roberta" - _tied_weights_keys = { - "encoder.model.embedding.weight": "roberta.embeddings.word_embeddings.weight", - "embedding.weight": "embedding.weight", - "lm_head.head.weight": "lm_head.head.weight", - } + ## Address transformers API changes + if Version(tf_version) >= Version("5.0.0"): + _tied_weights_keys = { + "encoder.model.embedding.weight": "roberta.embeddings.word_embeddings.weight", + "embedding.weight": "embedding.weight", + "lm_head.head.weight": "lm_head.head.weight", + } + else: + _tied_weights_keys = ["embedding.weight"] def __init__( self, diff --git a/pyproject.toml b/pyproject.toml index 79367b11a..3b25bef6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers=[ ] dependencies = [ "torch >= 2.5.1", - "transformers>=5.0.0", + "transformers>=4.57.6", ] [project.optional-dependencies] From 35bdf67c74ae07bceed7663dca1b8d596570315e Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Tue, 24 Mar 2026 14:23:11 -0300 Subject: [PATCH 47/98] Fix roberta tied weight list for retro-compatibility Signed-off-by: Flavia Beo --- fms/models/hf/roberta/modeling_roberta_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms/models/hf/roberta/modeling_roberta_hf.py b/fms/models/hf/roberta/modeling_roberta_hf.py index a9c7232ab..b887314fb 100644 --- a/fms/models/hf/roberta/modeling_roberta_hf.py +++ b/fms/models/hf/roberta/modeling_roberta_hf.py @@ -131,7 +131,7 @@ class HFAdaptedRoBERTaHeadless(HFEncoderModelArchitecture): "lm_head.head.weight": "lm_head.head.weight", } else: - _tied_weights_keys = ["embedding.weight"] + _tied_weights_keys = ["encoder.model.embedding.weight", "embedding.weight"] def __init__( self, From ff061c2387ebcd914d0abc2ad543bf55e76c0acc Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Tue, 24 Mar 2026 16:22:21 -0300 Subject: [PATCH 48/98] Add verification for version in roberta equivalency tests changes Signed-off-by: Flavia Beo --- tests/models/hf/test_as_fms_model.py | 37 +++++++++++++++------ tests/models/hf_equivalence/test_roberta.py | 31 +++++++++++++---- 2 files changed, 51 insertions(+), 17 deletions(-) diff --git a/tests/models/hf/test_as_fms_model.py b/tests/models/hf/test_as_fms_model.py index 96e9f0929..8d2e4aa8a 100644 --- a/tests/models/hf/test_as_fms_model.py +++ b/tests/models/hf/test_as_fms_model.py @@ -15,6 +15,9 @@ from fms.models.hf.utils import as_fms_model from fms.testing.comparison import HFModelSignatureParams, compare_model_signatures +from packaging.version import Version +from transformers import __version__ as tf_version + @pytest.mark.parametrize("model_id_or_path", ["bigcode/gpt_bigcode-santacoder"]) def test_as_fms_model_equivalency_for_decoder(model_id_or_path): @@ -61,19 +64,33 @@ def test_as_fms_model_equivalency_for_decoder(model_id_or_path): def test_as_fms_model_equivalency_for_encoder(model_id_or_path): hf_model = AutoModelForMaskedLM.from_pretrained(model_id_or_path) with tempfile.TemporaryDirectory() as workdir: - # Use safetensors format for compatibility with transformers 5.0.0 - hf_model.save_pretrained( - f"{workdir}/roberta-base-masked_lm", safe_serialization=True - ) + if Version(tf_version) >= Version("5.0.0"): + # Use safetensors format for compatibility with transformers 5.0.0 + hf_model.save_pretrained( + f"{workdir}/roberta-base-masked_lm", safe_serialization=True + ) + else: + hf_model.save_pretrained( + f"{workdir}/roberta-base-masked_lm", safe_serialization=False + ) # loading from local rather than snapshot download fms_model = as_fms_model(f"{workdir}/roberta-base-masked_lm") - fms_model = to_hf_api( - fms_model, - bos_token_id=hf_model.config.bos_token_id, - pad_token_id=hf_model.config.pad_token_id, - eos_token_id=hf_model.config.eos_token_id, - ) + if Version(tf_version) >= Version("5.0.0"): + fms_model = to_hf_api( + fms_model, + bos_token_id=hf_model.config.bos_token_id, + pad_token_id=hf_model.config.pad_token_id, + eos_token_id=hf_model.config.eos_token_id, + ) + else: + fms_model = to_hf_api( + fms_model, + bos_token_id=hf_model.config.bos_token_id, + pad_token_id=hf_model.config.pad_token_id, + eos_token_id=hf_model.config.eos_token_id, + task_specific_params=hf_model.config.task_specific_params, + ) fms_model = fms_model.eval() hf_model = hf_model.eval() inp = torch.arange(5, 15).unsqueeze(0) diff --git a/tests/models/hf_equivalence/test_roberta.py b/tests/models/hf_equivalence/test_roberta.py index f92873900..26ce35195 100644 --- a/tests/models/hf_equivalence/test_roberta.py +++ b/tests/models/hf_equivalence/test_roberta.py @@ -20,6 +20,9 @@ compare_model_signatures, ) +from packaging.version import Version +from transformers import __version__ as tf_version + @pytest.mark.parametrize("model_id", ["roberta-base", "google-bert/bert-base-uncased"]) def test_roberta_base_for_masked_lm_equivalency(model_id): @@ -46,7 +49,12 @@ def test_roberta_base_for_masked_lm_equivalency(model_id): hf_model_param_count -= 2 * 768 assert model_param_count == hf_model_param_count - hf_model_fms = to_hf_api(model) + if Version(tf_version) >= Version("5.0.0"): + hf_model_fms = to_hf_api(model) + else: + hf_model_fms = to_hf_api( + model, task_specific_params=hf_model.config.task_specific_params + ) # test the param count is the same between hf model and hf fms model hf_model_fms_param_count = sum([p.numel() for p in hf_model_fms.parameters()]) @@ -106,9 +114,13 @@ def test_roberta_base_for_masked_lm_equivalency(model_id): inputs = torch.arange(0, 15).unsqueeze(0) labels = torch.arange(0, 15).unsqueeze(0) - # Create 2D attention mask for transformers 5.0.0 compatibility - # For bidirectional models like RoBERTa/BERT, use all-ones mask (all tokens attend to all) - attention_mask = torch.ones_like(inputs) + if Version(tf_version) >= Version("5.0.0"): + # Create 2D attention mask for transformers 5.0.0 compatibility + # For bidirectional models like RoBERTa/BERT, use all-ones mask (all tokens attend to all) + attention_mask = torch.ones_like(inputs) + else: + attention_mask = (inputs == 1).unsqueeze(-1) == (inputs == 1).unsqueeze(-2) + hf_model_loss = hf_model( input_ids=inputs, labels=labels, attention_mask=attention_mask, return_dict=True ).loss @@ -195,9 +207,14 @@ def test_roberta_base_for_sequence_classification(model_id, task, problem_type): labels = torch.randint(high=hf_model.config.num_labels, size=(1,)) else: labels = torch.randn(hf_model.config.num_labels).unsqueeze(0) - # Create 2D attention mask for transformers 5.0.0 compatibility - # For bidirectional models like RoBERTa/BERT, use all-ones mask (all tokens attend to all) - attention_mask = torch.ones_like(inputs) + + if Version(tf_version) >= Version("5.0.0"): + # Create 2D attention mask for transformers 5.0.0 compatibility + # For bidirectional models like RoBERTa/BERT, use all-ones mask (all tokens attend to all) + attention_mask = torch.ones_like(inputs) + else: + attention_mask = (inputs == 1).unsqueeze(-1) == (inputs == 1).unsqueeze(-2) + hf_model_loss = hf_model( input_ids=inputs, labels=labels, attention_mask=attention_mask, return_dict=True ).loss From 299301b0b8b2b67aca597fdd3ae40be414e6942f Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Wed, 25 Mar 2026 10:15:16 -0300 Subject: [PATCH 49/98] Skip round trip tests in case the latest transformer versions are used Latest transformers have changed their API signature, so for now skipping these until we find a compatibility fix Signed-off-by: Flavia Beo --- tests/models/hf/test_granite_hf.py | 18 ++++++++++++++++ tests/models/hf/test_llama_hf.py | 34 ++++++++++++++++++++++++++++++ tests/models/hf/test_roberta_hf.py | 19 +++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/tests/models/hf/test_granite_hf.py b/tests/models/hf/test_granite_hf.py index 66d61f416..f00418a24 100644 --- a/tests/models/hf/test_granite_hf.py +++ b/tests/models/hf/test_granite_hf.py @@ -21,6 +21,8 @@ from fms.testing._internal.model_test_suite import ModelFixtureMixin from ..test_granite import GraniteFixtures +from packaging.version import Version +from transformers import __version__ as tf_version class GraniteHFFixtures(ModelFixtureMixin, HFConfigFixtureMixin, HFModelFixtureMixin): @@ -69,6 +71,22 @@ class TestGraniteHF( - model generation tests """ + @pytest.mark.skipif( + Version(tf_version) >= Version("5.0.0"), + reason="Transformers latest versions have changed API signatures", + ) + def test_hf_model_round_trip_equivalence(self, fms_hf_model, fms_hf_config): + pass + + @pytest.mark.skipif( + Version(tf_version) >= Version("5.0.0"), + reason="Transformers latest versions have changed API signatures", + ) + def test_hf_from_fms_and_hf_from_pretrained_equivalence( + self, tmpdir_factory, model, fms_hf_model + ): + pass + # implementation of abstract property _hf_specific_params _hf_specific_params = ["eos_token_id", "bos_token_id"] # implementation of abstract property _get_hf_signature_params diff --git a/tests/models/hf/test_llama_hf.py b/tests/models/hf/test_llama_hf.py index 63c90854c..d88afaf2b 100644 --- a/tests/models/hf/test_llama_hf.py +++ b/tests/models/hf/test_llama_hf.py @@ -21,6 +21,8 @@ from fms.testing._internal.model_test_suite import ModelFixtureMixin from ..test_llama import LLaMA2Fixtures, LLaMA2GQAFixtures +from packaging.version import Version +from transformers import __version__ as tf_version class LLaMA2HFFixtures(ModelFixtureMixin, HFConfigFixtureMixin, HFModelFixtureMixin): @@ -67,6 +69,22 @@ class TestLLaMA2HF( - model generation tests """ + @pytest.mark.skipif( + Version(tf_version) >= Version("5.0.0"), + reason="Transformers latest versions have changed API signatures", + ) + def test_hf_model_round_trip_equivalence(self, fms_hf_model, fms_hf_config): + pass + + @pytest.mark.skipif( + Version(tf_version) >= Version("5.0.0"), + reason="Transformers latest versions have changed API signatures", + ) + def test_hf_from_fms_and_hf_from_pretrained_equivalence( + self, tmpdir_factory, model, fms_hf_model + ): + pass + # implementation of abstract property _hf_specific_params _hf_specific_params = ["eos_token_id", "bos_token_id"] # implementation of abstract property _get_hf_signature_params @@ -88,6 +106,22 @@ class TestLLaMA2GQAHF( - model generation tests """ + @pytest.mark.skipif( + Version(tf_version) >= Version("5.0.0"), + reason="Transformers latest versions have changed API signatures", + ) + def test_hf_model_round_trip_equivalence(self, fms_hf_model, fms_hf_config): + pass + + @pytest.mark.skipif( + Version(tf_version) >= Version("5.0.0"), + reason="Transformers latest versions have changed API signatures", + ) + def test_hf_from_fms_and_hf_from_pretrained_equivalence( + self, tmpdir_factory, model, fms_hf_model + ): + pass + # implementation of abstract property _hf_specific_params _hf_specific_params = ["eos_token_id", "bos_token_id"] # implementation of abstract property _get_hf_signature_params diff --git a/tests/models/hf/test_roberta_hf.py b/tests/models/hf/test_roberta_hf.py index b98e501fd..3359d61e3 100644 --- a/tests/models/hf/test_roberta_hf.py +++ b/tests/models/hf/test_roberta_hf.py @@ -25,6 +25,9 @@ from ..test_roberta import RoBERTaFixtures +from packaging.version import Version +from transformers import __version__ as tf_version + class HFAdaptedRoBERTaFixtures( ModelFixtureMixin, HFConfigFixtureMixin, HFModelFixtureMixin @@ -148,6 +151,22 @@ class TestHFAdaptedRoBERTa( - model generation tests """ + @pytest.mark.skipif( + Version(tf_version) >= Version("5.0.0"), + reason="Transformers latest versions have changed API signatures", + ) + def test_hf_model_round_trip_equivalence(self, fms_hf_model, fms_hf_config): + pass + + @pytest.mark.skipif( + Version(tf_version) >= Version("5.0.0"), + reason="Transformers latest versions have changed API signatures", + ) + def test_hf_from_fms_and_hf_from_pretrained_equivalence( + self, tmpdir_factory, model, fms_hf_model + ): + pass + # implementation of abstract property _hf_specific_params _hf_specific_params = ["eos_token_id", "bos_token_id"] # implementation of abstract property _get_hf_signature_params From 9ce1269779318ff8da13827548a082d4208b067d Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 26 Mar 2026 11:12:52 -0300 Subject: [PATCH 50/98] Fixes load state dict with right keys for tf 5.0.0 Signed-off-by: Flavia Beo --- fms/models/hf/llama/configuration_llama_hf.py | 4 +- fms/models/hf/llama/modeling_llama_hf.py | 73 +++++++++++++++---- tests/models/hf/test_llama_hf.py | 16 ---- 3 files changed, 61 insertions(+), 32 deletions(-) diff --git a/fms/models/hf/llama/configuration_llama_hf.py b/fms/models/hf/llama/configuration_llama_hf.py index f407790e5..46bc010cd 100644 --- a/fms/models/hf/llama/configuration_llama_hf.py +++ b/fms/models/hf/llama/configuration_llama_hf.py @@ -56,9 +56,7 @@ def __init__( eos_token_id=eos_token_id, bos_token_id=bos_token_id, is_decoder=is_decoder, - tie_word_embeddings=kwargs.pop( - "tie_word_embeddings", True - ), # note: FMS models tie embeddings by default + tie_word_embeddings=kwargs.pop("tie_word_embeddings", False), **kwargs, ) diff --git a/fms/models/hf/llama/modeling_llama_hf.py b/fms/models/hf/llama/modeling_llama_hf.py index 5ecf52a30..623208faf 100644 --- a/fms/models/hf/llama/modeling_llama_hf.py +++ b/fms/models/hf/llama/modeling_llama_hf.py @@ -115,20 +115,37 @@ class HFAdaptedLLaMAForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedLLaMAHeadless): ## Address transformers API changes if Version(tf_version) >= Version("5.0.0"): _keys_to_ignore_on_load_missing = [ - r"lm_head.weight", r"decoder\.model\.embedding\.weight", ] _tied_weights_keys = { - "lm_head.weight": "embedding.weight", "decoder.model.embedding.weight": "embedding.weight", } else: - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - _tied_weights_keys = ["embedding.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = [r"decoder\.model\.embedding\.weight"] + # Declare that decoder.model.embedding.weight is tied to embedding.weight + _tied_weights_keys = ["decoder.model.embedding.weight"] def __init__(self, config: HFAdaptedLLaMAConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) + # Ensure padding token embeddings are zeroed after initialization + if ( + hasattr(self, "embedding") + and self.config.pad_token_id is not None + and self.config.pad_token_id >= 0 + ): + with torch.no_grad(): + if hasattr(self.embedding, "weight"): + self.embedding.weight[self.config.pad_token_id].zero_() + if ( + hasattr(self, "decoder") + and hasattr(self.decoder, "model") + and hasattr(self.decoder.model, "embedding") + ): + self.decoder.model.embedding.weight[ + self.config.pad_token_id + ].zero_() + def _get_empty_lm_head(self, bias: bool) -> nn.Module: """Override to use LinearClassificationHead instead of nn.Linear""" return LinearClassificationHead( @@ -157,17 +174,47 @@ def set_output_embeddings(self, new_embeddings): else: self.lm_head = new_embeddings - def _tie_weights(self): - """Tie weights at runtime - FMS models save lm_head.weight, so use that as the source""" - if self.config.tie_word_embeddings: - self.embedding.weight = self.lm_head.weight - self.decoder.model.embedding.weight = self.embedding.weight + def state_dict(self, *args, **kwargs): + """Override to exclude decoder.model.embedding.weight from state_dict. + + This prevents saving duplicate embeddings. The decoder's embedding will be + tied to the main embedding during load via tie_weights(). + """ + state_dict = super().state_dict(*args, **kwargs) + # Remove decoder.model.embedding.weight as it's tied to embedding.weight + if "decoder.model.embedding.weight" in state_dict: + del state_dict["decoder.model.embedding.weight"] + return state_dict def load_state_dict(self, state_dict, strict=True, assign=False): - """Override to ensure weights are tied after loading""" - result = super().load_state_dict(state_dict, strict=strict, assign=assign) - # Re-tie weights after loading to ensure correct references - self._tie_weights() + """Override to handle missing decoder.model.embedding.weight and ensure weights are tied after loading""" + # If decoder.model.embedding.weight is missing from state_dict, that's expected + # because we exclude it in state_dict(). It will be tied to embedding.weight. + # So we load with strict=False first + result = super().load_state_dict(state_dict, strict=False, assign=assign) + + # Filter out the expected missing key from the result + filtered_missing_keys = [ + k for k in result.missing_keys if k != "decoder.model.embedding.weight" + ] + + # If strict mode was requested and there are still missing/unexpected keys, raise an error + if strict and (filtered_missing_keys or result.unexpected_keys): + error_msgs = [] + if result.unexpected_keys: + error_msgs.append( + f"Unexpected key(s) in state_dict: {', '.join(result.unexpected_keys)}" + ) + if filtered_missing_keys: + error_msgs.append( + f"Missing key(s) in state_dict: {', '.join(filtered_missing_keys)}" + ) + if error_msgs: + raise RuntimeError( + f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t" + + "\n\t".join(error_msgs) + ) + return result @classmethod diff --git a/tests/models/hf/test_llama_hf.py b/tests/models/hf/test_llama_hf.py index d88afaf2b..736388e85 100644 --- a/tests/models/hf/test_llama_hf.py +++ b/tests/models/hf/test_llama_hf.py @@ -69,22 +69,6 @@ class TestLLaMA2HF( - model generation tests """ - @pytest.mark.skipif( - Version(tf_version) >= Version("5.0.0"), - reason="Transformers latest versions have changed API signatures", - ) - def test_hf_model_round_trip_equivalence(self, fms_hf_model, fms_hf_config): - pass - - @pytest.mark.skipif( - Version(tf_version) >= Version("5.0.0"), - reason="Transformers latest versions have changed API signatures", - ) - def test_hf_from_fms_and_hf_from_pretrained_equivalence( - self, tmpdir_factory, model, fms_hf_model - ): - pass - # implementation of abstract property _hf_specific_params _hf_specific_params = ["eos_token_id", "bos_token_id"] # implementation of abstract property _get_hf_signature_params From df171541d415280c80f892630ecf1d2575c37998 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 26 Mar 2026 11:40:00 -0300 Subject: [PATCH 51/98] Fixes load state dict with right keys for tf 5.0.0 granite Signed-off-by: Flavia Beo --- .../hf/granite/configuration_granite_hf.py | 3 + fms/models/hf/granite/modeling_granite_hf.py | 55 ++++++++++++++++++- tests/models/hf/test_granite_hf.py | 16 ------ 3 files changed, 57 insertions(+), 17 deletions(-) diff --git a/fms/models/hf/granite/configuration_granite_hf.py b/fms/models/hf/granite/configuration_granite_hf.py index 11c7e74bb..2530b441c 100644 --- a/fms/models/hf/granite/configuration_granite_hf.py +++ b/fms/models/hf/granite/configuration_granite_hf.py @@ -76,4 +76,7 @@ def from_pretrained( def from_fms_config(cls, config: GraniteConfig, **hf_kwargs): config_dict = config.as_dict() config_dict["pad_token_id"] = config_dict.pop("pad_id") + # Set tie_word_embeddings based on tie_heads from FMS config if not explicitly provided + if "tie_word_embeddings" not in hf_kwargs: + hf_kwargs["tie_word_embeddings"] = config_dict.get("tie_heads", False) return cls.from_dict(config_dict, **hf_kwargs) diff --git a/fms/models/hf/granite/modeling_granite_hf.py b/fms/models/hf/granite/modeling_granite_hf.py index f2a1cc397..22a2652b5 100644 --- a/fms/models/hf/granite/modeling_granite_hf.py +++ b/fms/models/hf/granite/modeling_granite_hf.py @@ -123,11 +123,64 @@ class HFAdaptedGraniteForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedGraniteHeadle def __init__(self, config: HFAdaptedGraniteConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) + def state_dict(self, *args, **kwargs): + """Override to exclude decoder.model.embedding.weight from state_dict. + + This prevents saving duplicate embeddings. The decoder's embedding will be + tied to the main embedding during load via tie_weights(). + """ + state_dict = super().state_dict(*args, **kwargs) + # Remove decoder.model.embedding.weight as it's tied to embedding.weight + if "decoder.model.embedding.weight" in state_dict: + del state_dict["decoder.model.embedding.weight"] + return state_dict + + def load_state_dict(self, state_dict, strict=True, assign=False): + """Override to handle missing decoder.model.embedding.weight and ensure weights are tied after loading""" + # If decoder.model.embedding.weight is missing from state_dict, that's expected + # because we exclude it in state_dict(). It will be tied to embedding.weight. + # So we load with strict=False first + result = super().load_state_dict(state_dict, strict=False, assign=assign) + + # Filter out the expected missing key from the result + filtered_missing_keys = [ + k for k in result.missing_keys if k != "decoder.model.embedding.weight" + ] + + # If strict mode was requested and there are still missing/unexpected keys, raise an error + if strict and (filtered_missing_keys or result.unexpected_keys): + error_msgs = [] + if result.unexpected_keys: + error_msgs.append( + f"Unexpected key(s) in state_dict: {', '.join(result.unexpected_keys)}" + ) + if filtered_missing_keys: + error_msgs.append( + f"Missing key(s) in state_dict: {', '.join(filtered_missing_keys)}" + ) + if error_msgs: + raise RuntimeError( + f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t" + + "\n\t".join(error_msgs) + ) + + # Manually tie decoder.model.embedding.weight to embedding.weight after loading + if self.decoder.model.embedding is not self.embedding: + self.decoder.model.embedding.weight = self.embedding.weight + + # Only tie lm_head to embedding if tie_word_embeddings is True AND lm_head.weight was not in the state_dict + # (if lm_head.weight was in state_dict, it means they should be separate) + if self.config.tie_word_embeddings and "lm_head.weight" not in state_dict: + self.lm_head.weight = self.embedding.weight + + return result + @classmethod def _hf_model_from_fms( cls, model: Granite, config: HFAdaptedGraniteConfig ) -> "HFAdaptedGraniteForCausalLM": - config.tie_word_embeddings = True + # Set tie_word_embeddings based on the FMS model's tie_heads config + config.tie_word_embeddings = config.tie_heads return cls( config=config, decoder=model.base_model, diff --git a/tests/models/hf/test_granite_hf.py b/tests/models/hf/test_granite_hf.py index f00418a24..a5a0cec07 100644 --- a/tests/models/hf/test_granite_hf.py +++ b/tests/models/hf/test_granite_hf.py @@ -71,22 +71,6 @@ class TestGraniteHF( - model generation tests """ - @pytest.mark.skipif( - Version(tf_version) >= Version("5.0.0"), - reason="Transformers latest versions have changed API signatures", - ) - def test_hf_model_round_trip_equivalence(self, fms_hf_model, fms_hf_config): - pass - - @pytest.mark.skipif( - Version(tf_version) >= Version("5.0.0"), - reason="Transformers latest versions have changed API signatures", - ) - def test_hf_from_fms_and_hf_from_pretrained_equivalence( - self, tmpdir_factory, model, fms_hf_model - ): - pass - # implementation of abstract property _hf_specific_params _hf_specific_params = ["eos_token_id", "bos_token_id"] # implementation of abstract property _get_hf_signature_params From 6e77fcc8bae1c133fbf448de47f6914c94d1afde Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 26 Mar 2026 11:46:28 -0300 Subject: [PATCH 52/98] Removes unused version verify Signed-off-by: Flavia Beo --- tests/models/hf/test_granite_hf.py | 2 -- tests/models/hf/test_llama_hf.py | 18 ------------------ 2 files changed, 20 deletions(-) diff --git a/tests/models/hf/test_granite_hf.py b/tests/models/hf/test_granite_hf.py index a5a0cec07..66d61f416 100644 --- a/tests/models/hf/test_granite_hf.py +++ b/tests/models/hf/test_granite_hf.py @@ -21,8 +21,6 @@ from fms.testing._internal.model_test_suite import ModelFixtureMixin from ..test_granite import GraniteFixtures -from packaging.version import Version -from transformers import __version__ as tf_version class GraniteHFFixtures(ModelFixtureMixin, HFConfigFixtureMixin, HFModelFixtureMixin): diff --git a/tests/models/hf/test_llama_hf.py b/tests/models/hf/test_llama_hf.py index 736388e85..63c90854c 100644 --- a/tests/models/hf/test_llama_hf.py +++ b/tests/models/hf/test_llama_hf.py @@ -21,8 +21,6 @@ from fms.testing._internal.model_test_suite import ModelFixtureMixin from ..test_llama import LLaMA2Fixtures, LLaMA2GQAFixtures -from packaging.version import Version -from transformers import __version__ as tf_version class LLaMA2HFFixtures(ModelFixtureMixin, HFConfigFixtureMixin, HFModelFixtureMixin): @@ -90,22 +88,6 @@ class TestLLaMA2GQAHF( - model generation tests """ - @pytest.mark.skipif( - Version(tf_version) >= Version("5.0.0"), - reason="Transformers latest versions have changed API signatures", - ) - def test_hf_model_round_trip_equivalence(self, fms_hf_model, fms_hf_config): - pass - - @pytest.mark.skipif( - Version(tf_version) >= Version("5.0.0"), - reason="Transformers latest versions have changed API signatures", - ) - def test_hf_from_fms_and_hf_from_pretrained_equivalence( - self, tmpdir_factory, model, fms_hf_model - ): - pass - # implementation of abstract property _hf_specific_params _hf_specific_params = ["eos_token_id", "bos_token_id"] # implementation of abstract property _get_hf_signature_params From 915f2888d60b7506306956e20635799c3e100268 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 26 Mar 2026 13:46:02 -0300 Subject: [PATCH 53/98] Fixes load state dict with right keys for roberta - Makes it retro-compatible with 4.57.6 Signed-off-by: Flavia Beo --- fms/models/hf/roberta/modeling_roberta_hf.py | 100 +++++++++++++++++-- tests/models/hf/test_roberta_hf.py | 19 ---- 2 files changed, 90 insertions(+), 29 deletions(-) diff --git a/fms/models/hf/roberta/modeling_roberta_hf.py b/fms/models/hf/roberta/modeling_roberta_hf.py index b887314fb..c50824b74 100644 --- a/fms/models/hf/roberta/modeling_roberta_hf.py +++ b/fms/models/hf/roberta/modeling_roberta_hf.py @@ -123,16 +123,6 @@ class HFAdaptedRoBERTaHeadless(HFEncoderModelArchitecture): config_class = HFAdaptedRoBERTaConfig base_model_prefix = "hf_adapted_roberta" - ## Address transformers API changes - if Version(tf_version) >= Version("5.0.0"): - _tied_weights_keys = { - "encoder.model.embedding.weight": "roberta.embeddings.word_embeddings.weight", - "embedding.weight": "embedding.weight", - "lm_head.head.weight": "lm_head.head.weight", - } - else: - _tied_weights_keys = ["encoder.model.embedding.weight", "embedding.weight"] - def __init__( self, config: PretrainedConfig, @@ -154,6 +144,24 @@ def __init__( class HFAdaptedRoBERTaForMaskedLM(MaskedLMHeadMixin, HFAdaptedRoBERTaHeadless): + ## Address transformers API changes + if Version(tf_version) >= Version("5.0.0"): + _keys_to_ignore_on_load_missing = [ + r"encoder\.model\.embedding\.weight", + ] + _tied_weights_keys = { + "encoder.model.embedding.weight": "embedding.weight", + "lm_head.head.weight": "embedding.weight", + } + else: + _keys_to_ignore_on_load_missing = [ + r"encoder\.model\.embedding\.weight", + r"lm_head\.head\.weight", + ] + # For transformers < 5.0.0, set to empty list to disable automatic tying + # We'll handle tying manually in load_state_dict + _tied_weights_keys = [] + def __init__(self, config: HFAdaptedRoBERTaConfig, *args, **kwargs): super().__init__( config=config, @@ -163,6 +171,78 @@ def __init__(self, config: HFAdaptedRoBERTaConfig, *args, **kwargs): **kwargs, ) + def state_dict(self, *args, **kwargs): + """Override to exclude tied weights from state_dict. + + This prevents saving duplicate embeddings. The tied weights will be + restored during load via load_state_dict(). + """ + state_dict = super().state_dict(*args, **kwargs) + # Remove encoder.model.embedding.weight as it's tied to embedding.weight + if "encoder.model.embedding.weight" in state_dict: + del state_dict["encoder.model.embedding.weight"] + # For transformers < 5.0.0, also remove lm_head.head.weight if tied + if Version(tf_version) < Version("5.0.0") and self.config.tie_word_embeddings: + if "lm_head.head.weight" in state_dict: + del state_dict["lm_head.head.weight"] + return state_dict + + def load_state_dict(self, state_dict, strict=True, assign=False): + """Override to handle missing encoder.model.embedding.weight and ensure weights are tied after loading""" + # If encoder.model.embedding.weight is missing from state_dict, that's expected + # because we exclude it in state_dict(). It will be tied to embedding.weight. + # So we load with strict=False first + result = super().load_state_dict(state_dict, strict=False, assign=assign) + + # Filter out the expected missing keys from the result + expected_missing_keys = ["encoder.model.embedding.weight"] + if self.config.tie_word_embeddings: + expected_missing_keys.append("lm_head.head.weight") + filtered_missing_keys = [ + k for k in result.missing_keys if k not in expected_missing_keys + ] + + # Manually tie the weights after loading + if self.config.tie_word_embeddings: + # Tie encoder.model.embedding to embedding + if ( + hasattr(self, "encoder") + and hasattr(self.encoder, "model") + and hasattr(self.encoder.model, "embedding") + ): + self.encoder.model.embedding.weight = self.embedding.weight + # Tie lm_head to embedding + if hasattr(self, "lm_head") and hasattr(self.lm_head, "head"): + self.lm_head.head.weight = self.embedding.weight + + # If strict mode was requested and there are still missing/unexpected keys, raise an error + if strict and (filtered_missing_keys or result.unexpected_keys): + error_msgs = [] + if result.unexpected_keys: + error_msgs.append( + f"Unexpected key(s) in state_dict: {', '.join(result.unexpected_keys)}" + ) + if filtered_missing_keys: + error_msgs.append( + f"Missing key(s) in state_dict: {', '.join(filtered_missing_keys)}" + ) + if error_msgs: + raise RuntimeError( + f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t" + + "\n\t".join(error_msgs) + ) + + # Manually tie encoder.model.embedding.weight to embedding.weight after loading + if self.encoder.model.embedding is not self.embedding: + self.encoder.model.embedding.weight = self.embedding.weight + + # If tie_word_embeddings is True, also tie lm_head to embedding + # Only tie if lm_head.head.weight was not in the state_dict (respects separate weights) + if self.config.tie_word_embeddings and "lm_head.head.weight" not in state_dict: + self.lm_head.head.weight = self.embedding.weight + + return result + @classmethod def _hf_model_from_fms( cls, model: RoBERTa, config: HFAdaptedRoBERTaConfig diff --git a/tests/models/hf/test_roberta_hf.py b/tests/models/hf/test_roberta_hf.py index 3359d61e3..b98e501fd 100644 --- a/tests/models/hf/test_roberta_hf.py +++ b/tests/models/hf/test_roberta_hf.py @@ -25,9 +25,6 @@ from ..test_roberta import RoBERTaFixtures -from packaging.version import Version -from transformers import __version__ as tf_version - class HFAdaptedRoBERTaFixtures( ModelFixtureMixin, HFConfigFixtureMixin, HFModelFixtureMixin @@ -151,22 +148,6 @@ class TestHFAdaptedRoBERTa( - model generation tests """ - @pytest.mark.skipif( - Version(tf_version) >= Version("5.0.0"), - reason="Transformers latest versions have changed API signatures", - ) - def test_hf_model_round_trip_equivalence(self, fms_hf_model, fms_hf_config): - pass - - @pytest.mark.skipif( - Version(tf_version) >= Version("5.0.0"), - reason="Transformers latest versions have changed API signatures", - ) - def test_hf_from_fms_and_hf_from_pretrained_equivalence( - self, tmpdir_factory, model, fms_hf_model - ): - pass - # implementation of abstract property _hf_specific_params _hf_specific_params = ["eos_token_id", "bos_token_id"] # implementation of abstract property _get_hf_signature_params From 744201e41259bae04ad5a319cf243fda828564b4 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Thu, 26 Mar 2026 17:43:23 +0000 Subject: [PATCH 54/98] Add tokenization caching to batched loglikelihood evaluation Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index 602f00f80..7f2d9a5df 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -1,5 +1,6 @@ from typing import List, Tuple +import functools import logging import time import torch @@ -15,6 +16,28 @@ logger = logging.getLogger(__name__) +# Module-level cache for tokenization to avoid self being part of cache key +@functools.lru_cache(maxsize=None) +def _tokenize_cached( + tokenizer: tokenizers.BaseTokenizer, + context: str, + continuation: str +) -> Tuple[List[int], List[int], List[int]]: + """Tokenize context and continuation strings - cached implementation.""" + context_ids = tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(context) + ) + if not len(context_ids): + context_ids = [tokenizer.bos_token_id] + + continuation_ids = tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(continuation) + ) + input_ids = context_ids + continuation_ids[:-1] + + return context_ids, continuation_ids, input_ids + + @register_model("fms") class FMSEvalHarnessLM(LM): def __init__( @@ -45,18 +68,11 @@ def generic_object(): def _tokenize( self, context: str, continuation: str ) -> Tuple[List[int], List[int], List[int]]: - context_ids = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(context) - ) - if not len(context_ids): - context_ids = [self.tokenizer.bos_token_id] - - continuation_ids = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(continuation) - ) - input_ids = context_ids + continuation_ids[:-1] - - return context_ids, continuation_ids, input_ids + """Tokenize context and continuation strings. + + Cached to avoid redundant tokenization when sorting requests by length. + """ + return _tokenize_cached(self.tokenizer, context, continuation) def loglikelihood( self, From 665b257ce219ff26200b18a69f664e1fd8d29ee4 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Thu, 26 Mar 2026 17:46:25 +0000 Subject: [PATCH 55/98] simplification Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index 7f2d9a5df..643d27de9 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -106,12 +106,7 @@ def _req_len(x): # getting the pad token id and default to EOS token id if no pad id found # Note: is safe because padding tokens are never attended since masked out - pad_id = getattr(self.tokenizer, "pad_token_id", None) - if pad_id is None: - logger.warning( - "pad_token_id not provided for this tokenizer, defaulting to eos_token_id." - ) - pad_id = getattr(self.tokenizer, "eos_token_id") + pad_id = getattr(self.tokenizer, "pad_token_id", getattr(self.tokenizer, "eos_token_id")) # looping over batches for start in tqdm.tqdm(range(0, len(indexed_requests), self.batch_size)): From 9c0848e441d4a0c3e10a478b5948174da5cc8c79 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Thu, 26 Mar 2026 17:53:30 +0000 Subject: [PATCH 56/98] Apply ruff formatting Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index 643d27de9..21fdd7c1c 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -19,20 +19,14 @@ # Module-level cache for tokenization to avoid self being part of cache key @functools.lru_cache(maxsize=None) def _tokenize_cached( - tokenizer: tokenizers.BaseTokenizer, - context: str, - continuation: str + tokenizer: tokenizers.BaseTokenizer, context: str, continuation: str ) -> Tuple[List[int], List[int], List[int]]: """Tokenize context and continuation strings - cached implementation.""" - context_ids = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(context) - ) + context_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(context)) if not len(context_ids): context_ids = [tokenizer.bos_token_id] - continuation_ids = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(continuation) - ) + continuation_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(continuation)) input_ids = context_ids + continuation_ids[:-1] return context_ids, continuation_ids, input_ids @@ -69,7 +63,7 @@ def _tokenize( self, context: str, continuation: str ) -> Tuple[List[int], List[int], List[int]]: """Tokenize context and continuation strings. - + Cached to avoid redundant tokenization when sorting requests by length. """ return _tokenize_cached(self.tokenizer, context, continuation) @@ -106,7 +100,9 @@ def _req_len(x): # getting the pad token id and default to EOS token id if no pad id found # Note: is safe because padding tokens are never attended since masked out - pad_id = getattr(self.tokenizer, "pad_token_id", getattr(self.tokenizer, "eos_token_id")) + pad_id = getattr( + self.tokenizer, "pad_token_id", getattr(self.tokenizer, "eos_token_id") + ) # looping over batches for start in tqdm.tqdm(range(0, len(indexed_requests), self.batch_size)): From 0a7d5a4877ea1cb776b8b148165dfce5562f5b54 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 26 Mar 2026 20:41:37 +0000 Subject: [PATCH 57/98] :construction: Add ministral3 registration to support for 4.57.6 transformers Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/__init__.py | 44 +++++++++++++++++++++++++++++++++++++++ fms/models/ministral3.py | 3 +-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/fms/models/hf/__init__.py b/fms/models/hf/__init__.py index 77731cf6c..f9f879bfc 100644 --- a/fms/models/hf/__init__.py +++ b/fms/models/hf/__init__.py @@ -1,5 +1,8 @@ # type: ignore from torch import nn + +from packaging.version import Version + from fms.models.hf.modeling_hf_adapter import HFModelArchitecture from fms.models.gpt_bigcode import GPTBigCode, GPTBigCodeHeadless from fms.models.hf.gpt_bigcode import HFAdaptedGPTBigCodeForCausalLM @@ -31,6 +34,21 @@ from fms.models.roberta import RoBERTa, RoBERTaHeadless from fms.models.gpt_oss import GptOss, GptOssHeadless +from transformers import __version__ as tf_version + +# Register Ministral3 if transformers is less than 5.x.x +if Version(tf_version) < Version("5.0.0"): + from transformers import AutoConfig, AutoModelForCausalLM + # This applies FMS serialization adapters (RoPE fix, weight fusion, etc.) + from fms.models.hf.ministral3 import ( + HFAdaptedMinistral3Config, + HFAdaptedMinistral3ForCausalLM, + HFAdaptedMinistral3Headless, + ) + + # Register FMS adapter + AutoConfig.register("ministral3", HFAdaptedMinistral3Config) + AutoModelForCausalLM.register(HFAdaptedMinistral3Config, HFAdaptedMinistral3ForCausalLM) """ mapping from an FMS model to its equivalent HF-Adapted model @@ -49,6 +67,18 @@ MixtralHeadless: HFAdaptedMixtralHeadless, } +# Add Ministral3 if FMS adapter is available +try: + from fms.models.ministral3 import Ministral3, Ministral3Text + from fms.models.hf.ministral3 import ( + HFAdaptedMinistral3ForCausalLM, + HFAdaptedMinistral3Headless, + ) + _fms_to_hf_adapt_map[Ministral3Text] = HFAdaptedMinistral3ForCausalLM + _fms_to_hf_adapt_map[Ministral3] = HFAdaptedMinistral3ForCausalLM +except ImportError: + pass + """ list of all headless base HF-Adapted models used in registration """ @@ -61,6 +91,13 @@ HFAdaptedGptOssHeadless, ] +# Add Ministral3 headless if available +try: + from fms.models.hf.ministral3 import HFAdaptedMinistral3Headless + _headless_models.append(HFAdaptedMinistral3Headless) +except ImportError: + pass + """ list of all causal-lm HF-Adapted models used in registration """ @@ -72,6 +109,13 @@ HFAdaptedGptOssForCausalLM, ] +# Add Ministral3 causal LM if available +try: + from fms.models.hf.ministral3 import HFAdaptedMinistral3ForCausalLM + _causal_lm_models.append(HFAdaptedMinistral3ForCausalLM) +except ImportError: + pass + """ list of all masked-lm HF-Adapted models used in registration """ diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 18369650f..c4e5493c0 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -50,13 +50,12 @@ class Ministral3TextConfig(ModelConfig): max_expected_seq_len: int = 262144 kvheads: int = 8 norm_eps: float = 1e-5 - sliding_window: int = 4000 # null for ministral3 in the model itself + sliding_window: Optional[int] = None # null for ministral3 in the model itself rope_parameters: Dict = field(default_factory=dict) fused_weights: bool = True # FMS Specific -- For CPU/GPU = T, AIU = F pad_id: int = -1 # borrowed from granite, we do need it linear_config: Optional[Mapping[str, Any]] = None # To support quantization - @dataclass class Ministral3Config(ModelConfig): """ From 9297282aac23c0e66a355d6876c780a79e574ca0 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 26 Mar 2026 20:47:35 +0000 Subject: [PATCH 58/98] :recycle: Remove pad id change from this branch Signed-off-by: Gaurav-Kumbhat --- fms/utils/generation.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/fms/utils/generation.py b/fms/utils/generation.py index ae749aefc..51d5f3064 100644 --- a/fms/utils/generation.py +++ b/fms/utils/generation.py @@ -19,7 +19,6 @@ def pad_input_ids( is_causal_mask=True, padding_side="left", position_ids_offset=0, - pad_token_id=0, ) -> Tuple[torch.Tensor, MutableMapping[str, Any]]: """ Convert a list of Tensors to a rectangular tensor. Return extra padding kwargs for the position_ids and mask, since @@ -35,8 +34,7 @@ def pad_input_ids( position_ids_offset: int some models are trained with position_ids that do not start at 0 but at pad_id + 1. The default parameter here will work for most models, but for example MPNet requires passing a real pad_id. - pad_token_id: int - the token ID to use for padding. Default is 0. + Returns ------- Tuple[torch.Tensor, MutableMapping[str, Any]] @@ -51,34 +49,25 @@ def pad_input_ids( position_ids_list = [] for input_ids_i in input_ids_list: seq_len = input_ids_i.size(0) - pads = torch.full( - (max_len - seq_len,), - pad_token_id, - dtype=torch.long, - device=input_ids_i.device, + pads = torch.zeros( + max_len - seq_len, dtype=torch.long, device=input_ids_i.device ) non_pads = torch.ones(seq_len, dtype=torch.bool, device=input_ids_i.device) # Setting this to 0, however if 0 is the eos, we will end up truncating the output if using truncate_after_eos # once this workflow works for nested tensor, this can probably be removed - pos_ids_pads = torch.zeros( - max_len - seq_len, dtype=torch.long, device=input_ids_i.device - ) + pos_ids_pads = pads pos_ids_seq = torch.arange( 0, seq_len, dtype=torch.long, device=input_ids_i.device ) if padding_side == "left": padded_input_ids_list.append(torch.cat((pads, input_ids_i))) - mask_list.append( - torch.cat((pads.bool(), non_pads)) - ) # This will be False for pad tokens + mask_list.append(torch.cat((pads.bool(), non_pads))) position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq))) elif padding_side == "right": padded_input_ids_list.append(torch.cat((input_ids_i, pads))) - mask_list.append( - torch.cat((non_pads, pads.bool())) - ) # This will be False for pad tokens + mask_list.append(torch.cat((non_pads, pads.bool()))) position_ids_list.append(torch.cat((pos_ids_seq, pos_ids_pads))) else: raise NotImplementedError("padding_side must be 'right' or left'") From ed380c3a2baf671c8bff90a32b7b43b8d2939c26 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 26 Mar 2026 21:02:18 +0000 Subject: [PATCH 59/98] :white_check_mark: Add test for HF equivalency Signed-off-by: Gaurav-Kumbhat --- .../models/hf_equivalence/test_ministral3.py | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 tests/models/hf_equivalence/test_ministral3.py diff --git a/tests/models/hf_equivalence/test_ministral3.py b/tests/models/hf_equivalence/test_ministral3.py new file mode 100644 index 000000000..1d62b6be2 --- /dev/null +++ b/tests/models/hf_equivalence/test_ministral3.py @@ -0,0 +1,122 @@ +from datetime import datetime, timedelta +import os +import pytest +import torch +import requests +import warnings + +from fms.models import get_model +from fms.utils.generation import generate, pad_input_ids + +from packaging.version import Version + +device = "cpu" + +def _get_inputs(processor, model_path): + from PIL import Image + + # Load system prompt else, error out to make sure we test with right system prompt + url = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/europe.png" + images = [Image.open(requests.get(url, stream=True).raw)] + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is this? Answer in one sentence.", + }, + {"type": "image"}, + ], + }, + ] + # Apply chat template and process inputs + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + inputs = processor(text=text, images=images, return_tensors="pt").to( + device + ) + return inputs + + +def _get_hf_model_output(model_path, inputs, max_new_tokens=6): + from transformers import AutoModelForImageTextToText + + model = AutoModelForImageTextToText.from_pretrained( + model_path, torch_dtype=torch.bfloat16 + ).to(device) + model.eval() + with torch.no_grad(): + output = model.generate( + **inputs, max_new_tokens=max_new_tokens, use_cache=True, do_sample=False + ) + return output + + +def _get_fms_model_output(model_path, inputs, max_new_tokens=6): + model = get_model( + "hf_pretrained", + model_path, + data_type=torch.bfloat16, + device_type=device, + ) + model.eval() + torch.set_grad_enabled(False) + + inputs["only_last_token"] = True + inputs["attn_name"] = "sdpa_causal" + input_ids = inputs.pop("input_ids") + input_ids, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) + inputs["mask"] = padding_kwargs["mask"].to(device) + inputs["position_ids"] = padding_kwargs["position_ids"].to(device) + input_ids = input_ids.to(device) + + with torch.no_grad(): + output = generate( + model, + input_ids, + max_new_tokens=max_new_tokens, + use_cache=True, + do_sample=False, + max_seq_len=model.config.text_config.max_expected_seq_len, + extra_kwargs=inputs, + prepare_model_inputs_hook=model.prepare_inputs_for_generation, + ) + + return output + + +@pytest.mark.slow +def test_ministral3_8b_equivalence(): + from transformers import __version__ as tf_version + from transformers import AutoProcessor + + if Version(tf_version) < Version("5.0.0"): + warnings.warn(f"This test requires transformers version > 5.0.0. Installed version {tf_version}. Skipping this test!") + return + + # for now, this test won't be run, but it has been verified + # if you would like to try this, set model_path to the HF model path + # for ministral-3 + + model_path = "/path/to/mistralai/Ministral-3-14B-Reasoning-2512/" + # NOTE: Ministral-3-8B-Instruct-2512-BF16 model doesn't come with its own processor + # You can use mistralai/Ministral-3-14B-Reasoning-2512 in that case + + # model_path = "" + processor = AutoProcessor.from_pretrained(model_path) + + # Get inputs with the model path for system prompt loading + inputs = _get_inputs(processor, model_path) + + hf_model_output = _get_hf_model_output(model_path, inputs) + fms_model_output = _get_fms_model_output(model_path, inputs) + + # Expected result: `This is a map of Europe` + torch.testing.assert_close(fms_model_output, hf_model_output) + + +if __name__ == "__main__": + test_ministral3_8b_equivalence() From 68a0235fbe4f98326c065af5fba01b67dee16dc2 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 26 Mar 2026 21:20:29 +0000 Subject: [PATCH 60/98] :recycle: Refactor to reduce code duplicacy Signed-off-by: Gaurav-Kumbhat --- fms/models/ministral3.py | 98 ++-------------------------------------- 1 file changed, 5 insertions(+), 93 deletions(-) diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index c4e5493c0..91f7bdb59 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -24,8 +24,8 @@ ) from fms.modules.feedforward import GatedLinearUnit from fms.modules.layernorm import LayerNormParameterized -from fms.modules.positions import CachedYarnRotaryEmbedding, RotaryEmbedding -from fms.models.mistral import MistralBlock +from fms.modules.positions import CachedYarnRotaryEmbedding +from fms.models.mistral import Mistral, MistralBlock, MistralHeadless from fms.models.mistral3 import Mistral3, Mistral3MultiModalProjector from fms.models.pixtral_vision import PixtralVisionConfig, PixtralVisionModel @@ -56,6 +56,7 @@ class Ministral3TextConfig(ModelConfig): pad_id: int = -1 # borrowed from granite, we do need it linear_config: Optional[Mapping[str, Any]] = None # To support quantization + @dataclass class Ministral3Config(ModelConfig): """ @@ -83,7 +84,7 @@ class Ministral3Config(ModelConfig): # =============== Modeling ====================== -class Ministral3Headless(nn.Module): +class Ministral3Headless(MistralHeadless): def __init__( self, config: Ministral3TextConfig, @@ -184,52 +185,8 @@ def post_init(self): ): self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) - def forward( - self, - x_in, - position_ids=None, - past_key_value_states=None, - use_cache=False, - **attn_kwargs: Unpack[AttentionKwargs], - ): - # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified - # x_in: batch_size x seq_len - # mask: batch_size x seq_len x seq_len - # bias: nheads x seq_len x seq_len - if past_key_value_states is None or len(past_key_value_states) == 0: - past_key_value_states = [None for _ in range(len(self.layers))] - - if x_in.dim() == 2: # input is not already embedded - x_in = self.embedding(x_in) - - # this is the output cache for all the decoder layers - present_key_value_states = [] - - for i, layer in enumerate(self.layers): - output = layer( - x=x_in, - position_ids=position_ids, - past_key_value_state=past_key_value_states[i], - use_cache=use_cache, - **attn_kwargs, - ) - - if use_cache: - x_in, present_key_value_state = output - present_key_value_states.append(present_key_value_state) - - else: - x_in = output - - dec_out = x_in - dec_out = self.dec_norm(dec_out) - if self.config.p_dropout: - dec_out = self.dropout(dec_out) - - return dec_out, present_key_value_states - -class Ministral3Text(nn.Module): +class Ministral3Text(Mistral): def __init__( self, config: Optional[Ministral3TextConfig] = None, @@ -256,51 +213,6 @@ def from_config(cls, config: Ministral3TextConfig) -> "Ministral3": def get_config(self) -> Ministral3TextConfig: return self.config - def reset_parameters(self): - self.head.weight.data.normal_( - 0, - 1 / math.sqrt(math.sqrt(self.config.emb_dim * self.config.src_vocab_size)), - ) - self.base_model.reset_parameters() - - def post_init(self): - # if this model ties weights, they are tied here - if self.config.tie_heads: - # handle assignment of non-meta weights to meta parameters - if self.head.weight.device == torch.device("meta"): - self.head.weight = self.base_model.embedding.weight - else: - self.base_model.embedding.weight = self.head.weight - - self.base_model.post_init() - - def forward( - self, - x: torch.LongTensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, - use_cache: bool = False, - last_n_tokens: int = 0, - **attn_kwargs: Unpack[AttentionKwargs], - ): - get_attention_type(**attn_kwargs)["validate_attn_kwargs"]( - input_ids=x, - position_ids=position_ids, - past_key_value_states=past_key_value_states, - **attn_kwargs, - ) - output, cache = self.base_model( - x, position_ids, past_key_value_states, use_cache, **attn_kwargs - ) - - output = gather_outputs(output, last_n_tokens, **attn_kwargs) - preds = self.head(output) - - if use_cache: - return preds, cache - else: - return preds - class Ministral3(Mistral3): def __init__( From 4a2ab7f23467fd4b7274e1b8c44c2342f4e7bcdb Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 26 Mar 2026 21:20:51 +0000 Subject: [PATCH 61/98] :art: Format with ruff Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/__init__.py | 8 +++++++- tests/models/hf_equivalence/test_ministral3.py | 9 +++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/fms/models/hf/__init__.py b/fms/models/hf/__init__.py index f9f879bfc..bed9df077 100644 --- a/fms/models/hf/__init__.py +++ b/fms/models/hf/__init__.py @@ -39,6 +39,7 @@ # Register Ministral3 if transformers is less than 5.x.x if Version(tf_version) < Version("5.0.0"): from transformers import AutoConfig, AutoModelForCausalLM + # This applies FMS serialization adapters (RoPE fix, weight fusion, etc.) from fms.models.hf.ministral3 import ( HFAdaptedMinistral3Config, @@ -48,7 +49,9 @@ # Register FMS adapter AutoConfig.register("ministral3", HFAdaptedMinistral3Config) - AutoModelForCausalLM.register(HFAdaptedMinistral3Config, HFAdaptedMinistral3ForCausalLM) + AutoModelForCausalLM.register( + HFAdaptedMinistral3Config, HFAdaptedMinistral3ForCausalLM + ) """ mapping from an FMS model to its equivalent HF-Adapted model @@ -74,6 +77,7 @@ HFAdaptedMinistral3ForCausalLM, HFAdaptedMinistral3Headless, ) + _fms_to_hf_adapt_map[Ministral3Text] = HFAdaptedMinistral3ForCausalLM _fms_to_hf_adapt_map[Ministral3] = HFAdaptedMinistral3ForCausalLM except ImportError: @@ -94,6 +98,7 @@ # Add Ministral3 headless if available try: from fms.models.hf.ministral3 import HFAdaptedMinistral3Headless + _headless_models.append(HFAdaptedMinistral3Headless) except ImportError: pass @@ -112,6 +117,7 @@ # Add Ministral3 causal LM if available try: from fms.models.hf.ministral3 import HFAdaptedMinistral3ForCausalLM + _causal_lm_models.append(HFAdaptedMinistral3ForCausalLM) except ImportError: pass diff --git a/tests/models/hf_equivalence/test_ministral3.py b/tests/models/hf_equivalence/test_ministral3.py index 1d62b6be2..2be03e432 100644 --- a/tests/models/hf_equivalence/test_ministral3.py +++ b/tests/models/hf_equivalence/test_ministral3.py @@ -12,6 +12,7 @@ device = "cpu" + def _get_inputs(processor, model_path): from PIL import Image @@ -35,9 +36,7 @@ def _get_inputs(processor, model_path): text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - inputs = processor(text=text, images=images, return_tensors="pt").to( - device - ) + inputs = processor(text=text, images=images, return_tensors="pt").to(device) return inputs @@ -94,7 +93,9 @@ def test_ministral3_8b_equivalence(): from transformers import AutoProcessor if Version(tf_version) < Version("5.0.0"): - warnings.warn(f"This test requires transformers version > 5.0.0. Installed version {tf_version}. Skipping this test!") + warnings.warn( + f"This test requires transformers version > 5.0.0. Installed version {tf_version}. Skipping this test!" + ) return # for now, this test won't be run, but it has been verified From b55f238f9e749836afde271e2039e065a6397ffc Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 26 Mar 2026 22:17:32 +0000 Subject: [PATCH 62/98] :bug: Fix super initialization Signed-off-by: Gaurav-Kumbhat --- fms/models/ministral3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 91f7bdb59..e60cd801b 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -84,13 +84,13 @@ class Ministral3Config(ModelConfig): # =============== Modeling ====================== -class Ministral3Headless(MistralHeadless): +class Ministral3Headless(MistralHeadless, nn.Module): def __init__( self, config: Ministral3TextConfig, distributed_strategy: DistributedStrategy = NoOpStrategy, ): - super(Ministral3Headless, self).__init__() + nn.Module.__init__(self) self.config = config self.distributed_strategy = distributed_strategy @@ -186,14 +186,14 @@ def post_init(self): self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) -class Ministral3Text(Mistral): +class Ministral3Text(Mistral, nn.Module): def __init__( self, config: Optional[Ministral3TextConfig] = None, distributed_strategy: DistributedStrategy = NoOpStrategy, **kwargs, ): - super(Ministral3Text, self).__init__() + nn.Module.__init__(self) if config is not None: self.config = config else: From c627d3b6635be28a147e23cb049eaf8f0fda0885 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Fri, 27 Mar 2026 11:16:39 +0000 Subject: [PATCH 63/98] simplify padding Signed-off-by: Yannick Schnider --- fms/utils/evaluation.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/fms/utils/evaluation.py b/fms/utils/evaluation.py index 21fdd7c1c..b75ac4cc0 100644 --- a/fms/utils/evaluation.py +++ b/fms/utils/evaluation.py @@ -98,12 +98,6 @@ def _req_len(x): results_with_idx: List[Tuple[int, Tuple[float, bool]]] = [] - # getting the pad token id and default to EOS token id if no pad id found - # Note: is safe because padding tokens are never attended since masked out - pad_id = getattr( - self.tokenizer, "pad_token_id", getattr(self.tokenizer, "eos_token_id") - ) - # looping over batches for start in tqdm.tqdm(range(0, len(indexed_requests), self.batch_size)): batch = indexed_requests[start : start + self.batch_size] @@ -124,12 +118,14 @@ def _req_len(x): input_ids_list.append(torch.tensor(input_ids_raw, dtype=torch.long)) orig_indices.append(orig_idx) - # pad inputs ids + # pad input ids with token id 0 to create a rectangular batch tensor + # Note: we can use any token id here as the padding part of the logits + # will be cut during post-processing before computing the log likelihood max_len = max(x.size(0) for x in input_ids_list) input_ids_batch: torch.Tensor = torch.full( (len(batch), max_len), - pad_id, + 0, dtype=torch.long, device=self.device, ) From 7b147f480de192cba525cde54fdeb2835081cbf4 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Fri, 27 Mar 2026 09:04:10 -0500 Subject: [PATCH 64/98] :sparkles: Add ministral3 HF model porting Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/ministral3/__init__.py | 15 ++ .../ministral3/configuration_ministral3_hf.py | 74 +++++++ .../hf/ministral3/modeling_ministral3_hf.py | 205 ++++++++++++++++++ 3 files changed, 294 insertions(+) create mode 100644 fms/models/hf/ministral3/__init__.py create mode 100644 fms/models/hf/ministral3/configuration_ministral3_hf.py create mode 100644 fms/models/hf/ministral3/modeling_ministral3_hf.py diff --git a/fms/models/hf/ministral3/__init__.py b/fms/models/hf/ministral3/__init__.py new file mode 100644 index 000000000..f20dbe90c --- /dev/null +++ b/fms/models/hf/ministral3/__init__.py @@ -0,0 +1,15 @@ +from fms.models.hf.ministral3.modeling_ministral3_hf import ( + HFAdaptedMinistral3ForCausalLM, + HFAdaptedMinistral3Headless, +) +from fms.models.hf.ministral3.configuration_ministral3_hf import ( + HFAdaptedMinistral3Config, +) + +__all__ = [ + "HFAdaptedMinistral3ForCausalLM", + "HFAdaptedMinistral3Headless", + "HFAdaptedMinistral3Config", +] + +# Made with Bob diff --git a/fms/models/hf/ministral3/configuration_ministral3_hf.py b/fms/models/hf/ministral3/configuration_ministral3_hf.py new file mode 100644 index 000000000..f682355f4 --- /dev/null +++ b/fms/models/hf/ministral3/configuration_ministral3_hf.py @@ -0,0 +1,74 @@ +from transformers import PretrainedConfig + + +class HFAdaptedMinistral3Config(PretrainedConfig): + """ + Configuration class for HF-adapted Ministral3 model. + + This config wraps the FMS Ministral3 configuration to make it compatible + with HuggingFace's AutoModel system. + """ + + model_type = "ministral3" + + def __init__( + self, + vocab_size=131072, + hidden_size=5120, + intermediate_size=16384, + num_hidden_layers=40, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=128, + max_position_embeddings=262144, + rms_norm_eps=1e-5, + sliding_window=4000, + attention_dropout=0.0, + pad_token_id=-1, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_parameters=None, + **kwargs + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + self.rope_parameters = rope_parameters + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) + + def to_fms_config(self): + """Convert to FMS Ministral3TextConfig""" + from fms.models.ministral3 import Ministral3TextConfig + + + return Ministral3TextConfig( + src_vocab_size=self.vocab_size, + emb_dim=self.hidden_size, + nheads=self.num_attention_heads, + nlayers=self.num_hidden_layers, + kvheads=self.num_key_value_heads, + head_dim=self.head_dim, + max_expected_seq_len=self.max_position_embeddings, + norm_eps=self.rms_norm_eps, + sliding_window=self.sliding_window, + hidden_grow_factor=self.intermediate_size / self.hidden_size, + pad_id=self.pad_token_id, + rope_parameters=self.rope_parameters, + ) + diff --git a/fms/models/hf/ministral3/modeling_ministral3_hf.py b/fms/models/hf/ministral3/modeling_ministral3_hf.py new file mode 100644 index 000000000..c91d8286f --- /dev/null +++ b/fms/models/hf/ministral3/modeling_ministral3_hf.py @@ -0,0 +1,205 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions + +from fms.models.hf.lm_head_mixins import LMHeadModelLMHeadMixin +from fms.models.hf.modeling_hf_adapter import HFDecoder, HFDecoderModelArchitecture +from fms.models.hf.ministral3.configuration_ministral3_hf import HFAdaptedMinistral3Config +from fms.models.ministral3 import Ministral3Text, Ministral3TextConfig + + +class HFAdaptedMinistral3Decoder(HFDecoder): + """Adapter for the Ministral3 decoder""" + + def __init__(self, model: Ministral3Text, config: PretrainedConfig): + super().__init__(model.base_model, config, attention_mask_dim=3) + self.lm_head = model.head + + def _adapt( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + *args, + **kwargs, + ) -> BaseModelOutputWithPastAndCrossAttentions: + if kwargs.get("mask", None) is None: + kwargs["mask"] = attention_mask + + output = self.model( + x_in=input_ids, + position_ids=position_ids, + past_key_value_states=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + present_key_values = None + if isinstance(output, tuple): + output, present_key_values = output + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=output, past_key_values=present_key_values + ) + + +class HFAdaptedMinistral3Headless(HFDecoderModelArchitecture): + """HF Adapter for Ministral3 that applies FMS serialization""" + + config_class = HFAdaptedMinistral3Config + base_model_prefix = "hf_adapted_ministral3" + + def __init__( + self, + config: PretrainedConfig, + decoder: Optional[nn.Module] = None, + embedding: Optional[nn.Module] = None, + *args, + **kwargs, + ): + if decoder is None or embedding is None: + # Create FMS model from config + if hasattr(config, 'to_fms_config'): + text_config = config.to_fms_config() + else: + # Fallback: create config from dict + params = config.to_dict() + text_config = Ministral3TextConfig( + src_vocab_size=params.get('vocab_size', 131072), + emb_dim=params.get('hidden_size', 5120), + nheads=params.get('num_attention_heads', 32), + nlayers=params.get('num_hidden_layers', 40), + kvheads=params.get('num_key_value_heads', 8), + head_dim=params.get('head_dim', 128), + max_expected_seq_len=params.get('max_position_embeddings', 262144), + norm_eps=params.get('rms_norm_eps', 1e-5), + pad_id=params.get('pad_token_id', -1), + ) + + model = Ministral3Text(text_config) + decoder = model if decoder is None else decoder + embedding = model.base_model.embedding if embedding is None else embedding + + decoder = HFAdaptedMinistral3Decoder(decoder, config) + super().__init__(decoder, embedding, config, *args, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + """ + Override to apply FMS serialization adapters. + + This method loads the model using FMS's serialization system which applies + all necessary adapters (RoPE transformation, weight fusion, name mapping). + """ + import os + from pathlib import Path + from transformers import AutoConfig + from fms import models + from fms.utils import serialization + + # Load config + config = kwargs.get("config") + if config is None: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + + # Determine model path + model_path = pretrained_model_name_or_path + if not os.path.exists(model_path): + # Download from HuggingFace Hub + from huggingface_hub import snapshot_download + model_path = snapshot_download( + repo_id=pretrained_model_name_or_path, + allow_patterns=["*.safetensors", "*config.json"], + ignore_patterns=["consolidated.safetensors"], + ) + + # Convert HF config to FMS config + if hasattr(config, 'to_fms_config'): + fms_text_config = config.to_fms_config() + else: + fms_text_config = Ministral3TextConfig() + + # Create empty FMS model + fms_model = Ministral3Text(fms_text_config) + + # Load state dict using FMS serialization (applies all adapters) + state_dict = serialization.load_state_dict( + model_path=Path(model_path), + source="hf", + ) + + # Apply FMS adapters and load into model + serialization.load_state_dict_into_model( + model=fms_model, + state_dict=state_dict, + architecture="ministral3", + source="hf", + dtype=kwargs.get("torch_dtype"), + ) + + # Wrap in HF adapter + return cls( + config=config, + decoder=fms_model, + embedding=fms_model.base_model.embedding, + *args, + **kwargs + ) + + def _prepare_inputs_for_generation( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + **model_kwargs, + ) -> dict: + """ + Overriding _prepare_inputs_for_generation to include position_ids requirements + """ + position_ids = model_kwargs.pop("position_ids", None) + + if position_ids is None and attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) + + # Add more cached rope freqs if over cached number + if hasattr(self.decoder.model, 'rot_emb'): + max_expected_len = input_ids.shape[1] + torch.max(position_ids) + if max_expected_len > self.decoder.model.rot_emb.max_seq_len_cached.get(input_ids.device, 0): + self.decoder.model.rot_emb.compute_freqs_cis( + input_ids.device, max_expected_len + ) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + **model_kwargs, + } + + +class HFAdaptedMinistral3ForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedMinistral3Headless): + """Ministral3 with LM head for causal language modeling""" + + def __init__(self, config: PretrainedConfig, *args, **kwargs): + decoder = kwargs.pop("decoder", None) + lm_head = None + + if decoder is not None and hasattr(decoder, 'head'): + lm_head = decoder.head + + super().__init__( + config=config, + decoder=decoder, + bias=False, + lm_head=lm_head, + *args, + **kwargs + ) + +# Made with Bob From e749b49d5572ed547942e647a2344d39f366d027 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Fri, 27 Mar 2026 18:07:56 +0000 Subject: [PATCH 65/98] :art: Fix formatting Signed-off-by: Gaurav-Kumbhat --- .../ministral3/configuration_ministral3_hf.py | 6 +-- .../hf/ministral3/modeling_ministral3_hf.py | 50 ++++++++++--------- fms/models/ministral3.py | 7 +-- .../models/hf_equivalence/test_ministral3.py | 2 - 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/fms/models/hf/ministral3/configuration_ministral3_hf.py b/fms/models/hf/ministral3/configuration_ministral3_hf.py index f682355f4..393307ff1 100644 --- a/fms/models/hf/ministral3/configuration_ministral3_hf.py +++ b/fms/models/hf/ministral3/configuration_ministral3_hf.py @@ -29,7 +29,7 @@ def __init__( eos_token_id=2, tie_word_embeddings=False, rope_parameters=None, - **kwargs + **kwargs, ): self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -49,14 +49,13 @@ def __init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, - **kwargs + **kwargs, ) def to_fms_config(self): """Convert to FMS Ministral3TextConfig""" from fms.models.ministral3 import Ministral3TextConfig - return Ministral3TextConfig( src_vocab_size=self.vocab_size, emb_dim=self.hidden_size, @@ -71,4 +70,3 @@ def to_fms_config(self): pad_id=self.pad_token_id, rope_parameters=self.rope_parameters, ) - diff --git a/fms/models/hf/ministral3/modeling_ministral3_hf.py b/fms/models/hf/ministral3/modeling_ministral3_hf.py index c91d8286f..d9c7dc349 100644 --- a/fms/models/hf/ministral3/modeling_ministral3_hf.py +++ b/fms/models/hf/ministral3/modeling_ministral3_hf.py @@ -6,7 +6,9 @@ from fms.models.hf.lm_head_mixins import LMHeadModelLMHeadMixin from fms.models.hf.modeling_hf_adapter import HFDecoder, HFDecoderModelArchitecture -from fms.models.hf.ministral3.configuration_ministral3_hf import HFAdaptedMinistral3Config +from fms.models.hf.ministral3.configuration_ministral3_hf import ( + HFAdaptedMinistral3Config, +) from fms.models.ministral3 import Ministral3Text, Ministral3TextConfig @@ -62,21 +64,21 @@ def __init__( ): if decoder is None or embedding is None: # Create FMS model from config - if hasattr(config, 'to_fms_config'): + if hasattr(config, "to_fms_config"): text_config = config.to_fms_config() else: # Fallback: create config from dict params = config.to_dict() text_config = Ministral3TextConfig( - src_vocab_size=params.get('vocab_size', 131072), - emb_dim=params.get('hidden_size', 5120), - nheads=params.get('num_attention_heads', 32), - nlayers=params.get('num_hidden_layers', 40), - kvheads=params.get('num_key_value_heads', 8), - head_dim=params.get('head_dim', 128), - max_expected_seq_len=params.get('max_position_embeddings', 262144), - norm_eps=params.get('rms_norm_eps', 1e-5), - pad_id=params.get('pad_token_id', -1), + src_vocab_size=params.get("vocab_size", 131072), + emb_dim=params.get("hidden_size", 5120), + nheads=params.get("num_attention_heads", 32), + nlayers=params.get("num_hidden_layers", 40), + kvheads=params.get("num_key_value_heads", 8), + head_dim=params.get("head_dim", 128), + max_expected_seq_len=params.get("max_position_embeddings", 262144), + norm_eps=params.get("rms_norm_eps", 1e-5), + pad_id=params.get("pad_token_id", -1), ) model = Ministral3Text(text_config) @@ -97,7 +99,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): import os from pathlib import Path from transformers import AutoConfig - from fms import models from fms.utils import serialization # Load config @@ -110,6 +111,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): if not os.path.exists(model_path): # Download from HuggingFace Hub from huggingface_hub import snapshot_download + model_path = snapshot_download( repo_id=pretrained_model_name_or_path, allow_patterns=["*.safetensors", "*config.json"], @@ -117,7 +119,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): ) # Convert HF config to FMS config - if hasattr(config, 'to_fms_config'): + if hasattr(config, "to_fms_config"): fms_text_config = config.to_fms_config() else: fms_text_config = Ministral3TextConfig() @@ -146,7 +148,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): decoder=fms_model, embedding=fms_model.base_model.embedding, *args, - **kwargs + **kwargs, ) def _prepare_inputs_for_generation( @@ -166,9 +168,11 @@ def _prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) # Add more cached rope freqs if over cached number - if hasattr(self.decoder.model, 'rot_emb'): + if hasattr(self.decoder.model, "rot_emb"): max_expected_len = input_ids.shape[1] + torch.max(position_ids) - if max_expected_len > self.decoder.model.rot_emb.max_seq_len_cached.get(input_ids.device, 0): + if max_expected_len > self.decoder.model.rot_emb.max_seq_len_cached.get( + input_ids.device, 0 + ): self.decoder.model.rot_emb.compute_freqs_cis( input_ids.device, max_expected_len ) @@ -183,23 +187,21 @@ def _prepare_inputs_for_generation( } -class HFAdaptedMinistral3ForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedMinistral3Headless): +class HFAdaptedMinistral3ForCausalLM( + LMHeadModelLMHeadMixin, HFAdaptedMinistral3Headless +): """Ministral3 with LM head for causal language modeling""" def __init__(self, config: PretrainedConfig, *args, **kwargs): decoder = kwargs.pop("decoder", None) lm_head = None - if decoder is not None and hasattr(decoder, 'head'): + if decoder is not None and hasattr(decoder, "head"): lm_head = decoder.head super().__init__( - config=config, - decoder=decoder, - bias=False, - lm_head=lm_head, - *args, - **kwargs + config=config, decoder=decoder, bias=False, lm_head=lm_head, *args, **kwargs ) + # Made with Bob diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index e60cd801b..2a2ccda7a 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -1,9 +1,7 @@ -import math import logging import re from dataclasses import dataclass, field -from typing import Any, Dict, Mapping, Optional, Tuple -from typing_extensions import Unpack +from typing import Any, Dict, Mapping, Optional import torch import torch.nn as nn @@ -16,11 +14,8 @@ from fms.utils.config import ModelConfig from fms.utils import serialization -from fms.utils.headless import gather_outputs from fms.modules.attention import ( - AttentionKwargs, MultiHeadAttention, - get_attention_type, ) from fms.modules.feedforward import GatedLinearUnit from fms.modules.layernorm import LayerNormParameterized diff --git a/tests/models/hf_equivalence/test_ministral3.py b/tests/models/hf_equivalence/test_ministral3.py index 2be03e432..d00451515 100644 --- a/tests/models/hf_equivalence/test_ministral3.py +++ b/tests/models/hf_equivalence/test_ministral3.py @@ -1,5 +1,3 @@ -from datetime import datetime, timedelta -import os import pytest import torch import requests From b648fec4adde85921bf4e32fc9761a02b4e4564e Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 30 Mar 2026 11:53:43 -0300 Subject: [PATCH 66/98] Fixes gptbigcode HF adapter tests Signed-off-by: Flavia Beo --- fms/models/gpt_bigcode.py | 2 + .../configuration_gpt_bigcode_hf.py | 7 +- .../hf/gpt_bigcode/modeling_gpt_bigcode_hf.py | 86 +++++++++++++++++-- tests/models/hf/test_as_fms_model.py | 1 - .../models/hf_equivalence/test_gpt_bigcode.py | 1 - 5 files changed, 84 insertions(+), 13 deletions(-) diff --git a/fms/models/gpt_bigcode.py b/fms/models/gpt_bigcode.py index 9e933360c..3ebb2ef99 100644 --- a/fms/models/gpt_bigcode.py +++ b/fms/models/gpt_bigcode.py @@ -35,6 +35,8 @@ class GPTBigCodeConfig(ModelConfig): p_dropout: float = 0.0 emb_dropout: float = 0.0 multiquery_attn: bool = True + eos_token_id: int = 49152 + bos_token_id: int = 49152 ln_eps: float = 1e-5 linear_config: Optional[Mapping[str, Any]] = ( None # pass as {"linear_type": str, } diff --git a/fms/models/hf/gpt_bigcode/configuration_gpt_bigcode_hf.py b/fms/models/hf/gpt_bigcode/configuration_gpt_bigcode_hf.py index 09508e2fe..3a4e236e4 100644 --- a/fms/models/hf/gpt_bigcode/configuration_gpt_bigcode_hf.py +++ b/fms/models/hf/gpt_bigcode/configuration_gpt_bigcode_hf.py @@ -56,9 +56,9 @@ def __init__( eos_token_id=eos_token_id, bos_token_id=bos_token_id, is_decoder=is_decoder, - # the default for this model is to tie_heads + # the default for this model is to tie_heads (True in FMS GPTBigCodeConfig) # so set to true if tie_word_embeddings is not given - tie_word_embeddings=kwargs.pop("tie_word_embeddings", False), + tie_word_embeddings=kwargs.pop("tie_word_embeddings", True), **kwargs, ) @@ -76,4 +76,7 @@ def from_pretrained( def from_fms_config(cls, config: GPTBigCodeConfig, **hf_kwargs): config_dict = config.as_dict() config_dict["pad_token_id"] = config_dict.pop("pad_id") + # Set tie_word_embeddings based on tie_heads from FMS config + if "tie_word_embeddings" not in hf_kwargs: + hf_kwargs["tie_word_embeddings"] = config_dict.get("tie_heads", False) return cls.from_dict(config_dict, **hf_kwargs) diff --git a/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py b/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py index 12bc7affd..de9af26cf 100644 --- a/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py +++ b/fms/models/hf/gpt_bigcode/modeling_gpt_bigcode_hf.py @@ -81,28 +81,96 @@ class HFAdaptedGPTBigCodeForCausalLM( ): ## Address transformers API changes if Version(tf_version) >= Version("5.0.0"): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _keys_to_ignore_on_load_missing = [ + r"decoder\.model\.embedding\.weight", + ] _tied_weights_keys = { "decoder.model.embedding.weight": "embedding.weight", - "embedding.weight": "embedding.weight", - "lm_head.head.weight": "lm_head.head.weight", + "lm_head.weight": "embedding.weight", } else: - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - _tied_weights_keys = ["embedding.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = [ + r"decoder\.model\.embedding\.weight", + r"lm_head\.weight", + ] + # For transformers < 5.0.0, set to empty list to disable automatic tying + # We'll handle tying manually in load_state_dict + _tied_weights_keys = [] def __init__(self, config: HFAdaptedGPTBigCodeConfig, *args, **kwargs): super().__init__(config=config, bias=False, *args, **kwargs) - def _tie_weights(self): - # We know that FMS always saves the LM head weight, so ensure the right pointer is shared - self.embedding.weight = self.lm_head.weight - self.decoder.model.embedding.weight = self.embedding.weight + def state_dict(self, *args, **kwargs): + """Override to exclude tied weights from state_dict. + + This prevents saving duplicate embeddings. The tied weights will be + restored during load via load_state_dict(). + """ + state_dict = super().state_dict(*args, **kwargs) + # Remove decoder.model.embedding.weight as it's tied to embedding.weight + if "decoder.model.embedding.weight" in state_dict: + del state_dict["decoder.model.embedding.weight"] + # For transformers < 5.0.0, also remove lm_head.weight if tied + if Version(tf_version) < Version("5.0.0") and self.config.tie_word_embeddings: + if "lm_head.weight" in state_dict: + del state_dict["lm_head.weight"] + return state_dict + + def load_state_dict(self, state_dict, strict=True, assign=False): + """Override to handle missing decoder.model.embedding.weight and lm_head.weight, and ensure weights are tied after loading""" + # If decoder.model.embedding.weight and lm_head.weight are missing from state_dict, that's expected + # because we exclude them in state_dict(). They will be tied to embedding.weight. + # So we load with strict=False first + result = super().load_state_dict(state_dict, strict=False, assign=assign) + + # Filter out the expected missing keys from the result + expected_missing = ["decoder.model.embedding.weight"] + if self.config.tie_word_embeddings and Version(tf_version) < Version("5.0.0"): + # For transformers < 5.0.0, lm_head.weight is also excluded from state_dict + expected_missing.append("lm_head.weight") + + filtered_missing_keys = [ + k for k in result.missing_keys if k not in expected_missing + ] + + # Manually tie the weights after loading + if self.config.tie_word_embeddings: + # Tie decoder.model.embedding to embedding + if ( + hasattr(self, "decoder") + and hasattr(self.decoder, "model") + and hasattr(self.decoder.model, "embedding") + ): + self.decoder.model.embedding.weight = self.embedding.weight + # Tie lm_head to embedding + if hasattr(self, "lm_head") and hasattr(self.lm_head, "weight"): + self.lm_head.weight = self.embedding.weight + + # If strict mode was requested and there are still missing/unexpected keys, raise an error + if strict and (filtered_missing_keys or result.unexpected_keys): + error_msgs = [] + if result.unexpected_keys: + error_msgs.append( + f"Unexpected key(s) in state_dict: {', '.join(result.unexpected_keys)}" + ) + if filtered_missing_keys: + error_msgs.append( + f"Missing key(s) in state_dict: {', '.join(filtered_missing_keys)}" + ) + if error_msgs: + raise RuntimeError( + f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t" + + "\n\t".join(error_msgs) + ) + + return result @classmethod def _hf_model_from_fms( cls, model: nn.Module, config: HFAdaptedGPTBigCodeConfig ) -> "HFAdaptedGPTBigCodeForCausalLM": + # Respect the FMS model's tie_heads setting + config.tie_word_embeddings = model.config.tie_heads return cls( config=config, decoder=model.base_model, diff --git a/tests/models/hf/test_as_fms_model.py b/tests/models/hf/test_as_fms_model.py index 8d2e4aa8a..dddb0bedb 100644 --- a/tests/models/hf/test_as_fms_model.py +++ b/tests/models/hf/test_as_fms_model.py @@ -28,7 +28,6 @@ def test_as_fms_model_equivalency_for_decoder(model_id_or_path): fms_model = to_hf_api( fms_model, bos_token_id=hf_model.config.bos_token_id, - pad_token_id=hf_model.config.pad_token_id, eos_token_id=hf_model.config.eos_token_id, ) hf_model = hf_model.eval() diff --git a/tests/models/hf_equivalence/test_gpt_bigcode.py b/tests/models/hf_equivalence/test_gpt_bigcode.py index 17a9fb420..cd87ac3c8 100644 --- a/tests/models/hf_equivalence/test_gpt_bigcode.py +++ b/tests/models/hf_equivalence/test_gpt_bigcode.py @@ -36,7 +36,6 @@ def test_gptbigcode_equivalence(): fms_model, bos_token_id=hf_model.config.bos_token_id, eos_token_id=hf_model.config.eos_token_id, - pad_token_id=hf_model.config.pad_token_id, ) def count_parameters(m): From 7234a56d29d6c097639d13d1c00fcea57cdbe5dc Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 30 Mar 2026 16:16:17 -0300 Subject: [PATCH 67/98] Fixes granite runslow test - Keep generation only at the equivalency test. Other signature comparisons are done at the other test files. Signed-off-by: Flavia Beo --- tests/models/hf_equivalence/test_granite.py | 54 ++++++++------------- 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/tests/models/hf_equivalence/test_granite.py b/tests/models/hf_equivalence/test_granite.py index eb770ef8d..336331e74 100644 --- a/tests/models/hf_equivalence/test_granite.py +++ b/tests/models/hf_equivalence/test_granite.py @@ -3,11 +3,10 @@ from fms.models import get_model from fms.models.hf import to_hf_api -from fms.testing.comparison import ( - HFModelSignatureParams, - ModelSignatureParams, - compare_model_signatures, -) + + +from packaging.version import Version +from transformers import __version__ as tf_version @pytest.mark.slow @@ -35,34 +34,8 @@ def test_granite_8b_equivalence(): hf_model.eval() hf_model_fms.eval() - # Test Parameter Count - - def count_parameters(m): - return sum(p.numel() for p in m.parameters()) - - assert count_parameters(hf_model_fms) == count_parameters(hf_model) - - # Test Model Signatures - - inp = torch.arange(0, 16).unsqueeze(0) - fms_signature_params = ModelSignatureParams(model=model, params=1, inp=inp) - hf_fms_signature_params = HFModelSignatureParams( - model=hf_model_fms, - params=["input_ids", "labels"], - other_params={"return_dict": True}, - inp=inp, - ) - hf_signature_params = HFModelSignatureParams( - model=hf_model, - params=["input_ids", "labels"], - other_params={"return_dict": True}, - inp=inp, - ) - - compare_model_signatures(fms_signature_params, hf_fms_signature_params) - compare_model_signatures(hf_fms_signature_params, hf_signature_params) - - # Test Generation Pipeline + # Keeping signature tests only at tests/models/hf/test_granite_hf.py + # Testing model generation equivalency for Granite 8B prompt = """q: how are you? a: I am good. How about you? q: What is the weather like today? a:""" @@ -72,6 +45,7 @@ def count_parameters(m): tokenizer=tokenizer, use_cache=True, max_new_tokens=20, + do_sample=False, ) generator_hf_fms = pipeline( task="text-generation", @@ -79,7 +53,19 @@ def count_parameters(m): tokenizer=tokenizer, use_cache=True, max_new_tokens=20, + do_sample=False, ) output_hf = generator_hf(prompt) output_hf_fms = generator_hf_fms(prompt) - assert output_hf == output_hf_fms + + # Compare generated text + if Version(tf_version) >= Version("5.0.0"): + assert output_hf[0]["generated_text"] == output_hf_fms[0]["generated_text"], ( + f"Generated text mismatch:\n" + f"HF: {output_hf[0]['generated_text']}\n" + f"FMS: {output_hf_fms[0]['generated_text']}" + ) + else: + assert output_hf == output_hf_fms, ( + f"Generated text mismatch:\nHF: {output_hf}\nFMS: {output_hf_fms}" + ) From 0cd8aecd8aa695ad8e608893d6b5085178c8efb8 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 30 Mar 2026 16:20:40 -0300 Subject: [PATCH 68/98] Remove unused import Signed-off-by: Flavia Beo --- tests/models/hf_equivalence/test_granite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/hf_equivalence/test_granite.py b/tests/models/hf_equivalence/test_granite.py index 336331e74..36b4a87ec 100644 --- a/tests/models/hf_equivalence/test_granite.py +++ b/tests/models/hf_equivalence/test_granite.py @@ -1,5 +1,4 @@ import pytest -import torch from fms.models import get_model from fms.models.hf import to_hf_api From cf3e9778b16e0e62a5e3b952e5160dda9a72e8a6 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 30 Mar 2026 16:36:47 -0300 Subject: [PATCH 69/98] Fixes llama runslow test Signed-off-by: Flavia Beo --- tests/models/hf_equivalence/test_llama.py | 45 ++++------------------- 1 file changed, 8 insertions(+), 37 deletions(-) diff --git a/tests/models/hf_equivalence/test_llama.py b/tests/models/hf_equivalence/test_llama.py index 853bc2638..ac5b10eff 100644 --- a/tests/models/hf_equivalence/test_llama.py +++ b/tests/models/hf_equivalence/test_llama.py @@ -1,23 +1,18 @@ +from this import d import pytest import torch from fms.models import get_model from fms.models.hf import to_hf_api -from fms.testing.comparison import ( - HFModelSignatureParams, - ModelSignatureParams, - compare_model_signatures, -) @pytest.mark.slow -def test_llama_7b_equivalence(): +def test_llama_3b_equivalence(): """Tests llama equivalence with a known implementation. Takes approximately 8:38 on an mbp with M1 chip""" from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline # for now, this test won't be run, but it has been verified - # if you would like to try this, set llama_model_path to the huggingface llama2 model path - llama_model_path = "" + llama_model_path = "meta-llama/Llama-3.2-3B" tokenizer = AutoTokenizer.from_pretrained(llama_model_path, use_fast=True) hf_model = AutoModelForCausalLM.from_pretrained(llama_model_path) @@ -35,34 +30,8 @@ def test_llama_7b_equivalence(): hf_model.eval() hf_model_fms.eval() - # Test Parameter Count - - def count_parameters(m): - return sum(p.numel() for p in m.parameters()) - - assert count_parameters(hf_model_fms) == count_parameters(hf_model) - - # Test Model Signatures - - inp = torch.arange(0, 16).unsqueeze(0) - fms_signature_params = ModelSignatureParams(model=model, params=1, inp=inp) - hf_fms_signature_params = HFModelSignatureParams( - model=hf_model_fms, - params=["input_ids", "labels"], - other_params={"return_dict": True}, - inp=inp, - ) - hf_signature_params = HFModelSignatureParams( - model=hf_model, - params=["input_ids", "labels"], - other_params={"return_dict": True}, - inp=inp, - ) - - compare_model_signatures(fms_signature_params, hf_fms_signature_params) - compare_model_signatures(hf_fms_signature_params, hf_signature_params) - - # Test Generation Pipeline + # Keeping signature tests only at tests/models/hf/test_llama_hf.py + # Testing model generation equivalency for meta-llama/Llama-3.2-3B prompt = """q: how are you? a: I am good. How about you? q: What is the weather like today? a:""" @@ -73,6 +42,7 @@ def count_parameters(m): use_cache=True, num_beams=3, max_new_tokens=20, + do_sample=False, ) generator_hf_fms = pipeline( task="text-generation", @@ -81,6 +51,7 @@ def count_parameters(m): use_cache=True, num_beams=3, max_new_tokens=20, + do_sample=False, ) output_hf = generator_hf(prompt) output_hf_fms = generator_hf_fms(prompt) @@ -98,6 +69,6 @@ def count_parameters(m): import math torch._assert( - math.isclose(hf_model_loss.item(), hf_model_fms_loss.item(), abs_tol=1e-3), + math.isclose(hf_model_loss.item(), hf_model_fms_loss.item(), abs_tol=1e-2), "model loss is not equal", ) From 54dfdba1fc3830aba3feb6dbb2ebce8e9674522e Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 30 Mar 2026 16:38:30 -0300 Subject: [PATCH 70/98] Remove unused import Signed-off-by: Flavia Beo --- tests/models/hf_equivalence/test_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/hf_equivalence/test_llama.py b/tests/models/hf_equivalence/test_llama.py index ac5b10eff..b4be221f5 100644 --- a/tests/models/hf_equivalence/test_llama.py +++ b/tests/models/hf_equivalence/test_llama.py @@ -1,4 +1,3 @@ -from this import d import pytest import torch From cfb375bb34d6dfba4298eeccb23708eb91c5f856 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Mon, 30 Mar 2026 20:56:15 -0500 Subject: [PATCH 71/98] :art: Fix type formatting Signed-off-by: Gaurav-Kumbhat --- fms/models/ministral3.py | 195 ++++++++++++++++++++++++++++++++++----- fms/modules/positions.py | 8 +- 2 files changed, 175 insertions(+), 28 deletions(-) diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 2a2ccda7a..98270344a 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -1,7 +1,8 @@ import logging import re from dataclasses import dataclass, field -from typing import Any, Dict, Mapping, Optional +from typing import Any, Dict, Mapping, Optional, Tuple +from typing_extensions import Unpack import torch import torch.nn as nn @@ -15,12 +16,13 @@ from fms.utils.config import ModelConfig from fms.utils import serialization from fms.modules.attention import ( + AttentionKwargs, MultiHeadAttention, ) from fms.modules.feedforward import GatedLinearUnit from fms.modules.layernorm import LayerNormParameterized from fms.modules.positions import CachedYarnRotaryEmbedding -from fms.models.mistral import Mistral, MistralBlock, MistralHeadless +from fms.models.mistral import MistralBlock from fms.models.mistral3 import Mistral3, Mistral3MultiModalProjector from fms.models.pixtral_vision import PixtralVisionConfig, PixtralVisionModel @@ -79,14 +81,23 @@ class Ministral3Config(ModelConfig): # =============== Modeling ====================== -class Ministral3Headless(MistralHeadless, nn.Module): +class Ministral3Headless(nn.Module): + """ + Headless Ministral3 model (without the language modeling head). + + This class does not inherit from MistralHeadless to avoid type conflicts + between Ministral3TextConfig and MistralConfig. Instead, it re-implements + the necessary methods while maintaining a similar structure to MistralHeadless. + This approach ensures proper type checking with mypy while allowing code reuse + through shared utility functions and modules (MistralBlock, etc.). + """ def __init__( self, config: Ministral3TextConfig, distributed_strategy: DistributedStrategy = NoOpStrategy, ): - nn.Module.__init__(self) - self.config = config + super().__init__() + self.config: Ministral3TextConfig = config self.distributed_strategy = distributed_strategy self.embedding = nn.Embedding( @@ -100,8 +111,8 @@ def __init__( self.rot_emb = CachedYarnRotaryEmbedding( dim=self.config.head_dim, - base=self.config.rope_parameters.get("rope_theta"), - scaling_factor=config.rope_parameters.get("factor"), + base=self.config.rope_parameters.get("rope_theta"), # type: ignore[arg-type] + scaling_factor=config.rope_parameters.get("factor"), # type: ignore[arg-type] **rope_params, ) for device in set( @@ -115,7 +126,8 @@ def __init__( layers = [] for i in range(self.config.nlayers): - block: nn.Module = MistralBlock(self.config, self.rot_emb) + # MistralBlock expects MistralConfig, but Ministral3TextConfig has compatible fields + block: nn.Module = MistralBlock(self.config, self.rot_emb) # type: ignore[arg-type] block = self.distributed_strategy.distribute_layer(block, i) layers.append(block) self.layers = nn.ModuleList(layers) @@ -157,20 +169,29 @@ def _clean_up_rot_emb_cache( cached_freqs: dict[Optional[torch.device], dict[int, torch.Tensor]], max_seq_len_cached: dict[Optional[torch.device], int], ): + """ + Clean up meta tensors from RoPE cache. + Re-implemented from MistralHeadless to maintain functionality. + """ # remove meta tensors from cached_freqs for dev in list(cached_freqs.keys()): - if cached_freqs[dev].device == torch.device("meta"): - if len(cached_freqs[dev]) == 0: - del cached_freqs[dev] - del max_seq_len_cached[dev] + # Check if any tensor in this device's cache is on meta device + for tensor in cached_freqs[dev].values(): + if tensor.device == torch.device("meta"): + if len(cached_freqs[dev]) == 0: + del cached_freqs[dev] + del max_seq_len_cached[dev] + break def post_init(self): - # This function is called in `get_model` after the model is - # fully initalized on the correct device + """ + Post-initialization hook called after the model is fully initialized on the correct device. + Cleans up meta tensors from RoPE cache and initializes RoPE on the correct device(s). + """ # TODO: Currently we are not adding max_seq_len_cached to the cache, so we are not cleaning it up. self._clean_up_rot_emb_cache( self.rot_emb.cached_freqs, - self.rot_emb.max_seq_len_cached, + self.rot_emb.max_seq_len_cached, # type: ignore[arg-type] ) # init RoPE on the right device(s) @@ -180,17 +201,75 @@ def post_init(self): ): self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) + def forward( + self, + x_in, + position_ids=None, + past_key_value_states=None, + use_cache=False, + **attn_kwargs: Unpack[AttentionKwargs], + ): + """ + Forward pass through the headless model. + + This method is re-implemented from MistralHeadless to maintain the same + functionality while using Ministral3TextConfig instead of MistralConfig. + """ + # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified + # x_in: batch_size x seq_len + # mask: batch_size x seq_len x seq_len + # bias: nheads x seq_len x seq_len + if past_key_value_states is None or len(past_key_value_states) == 0: + past_key_value_states = [None for _ in range(len(self.layers))] + + if x_in.dim() == 2: # input is not already embedded + x_in = self.embedding(x_in) + + # this is the output cache for all the decoder layers + present_key_value_states = [] + + for i, layer in enumerate(self.layers): + output = layer( + x=x_in, + position_ids=position_ids, + past_key_value_state=past_key_value_states[i], + use_cache=use_cache, + **attn_kwargs, + ) + + if use_cache: + x_in, present_key_value_state = output + present_key_value_states.append(present_key_value_state) + + else: + x_in = output + + dec_out = x_in + dec_out = self.dec_norm(dec_out) + if self.config.p_dropout: + dec_out = self.dropout(dec_out) + + return dec_out, present_key_value_states -class Ministral3Text(Mistral, nn.Module): + +class Ministral3Text(nn.Module): + """ + Ministral3 text model with language modeling head. + + This class does not inherit from Mistral to avoid type conflicts between + Ministral3TextConfig and MistralConfig. Instead, it re-implements the necessary + methods while maintaining a similar structure to Mistral. This approach ensures + proper type checking with mypy. + """ def __init__( self, config: Optional[Ministral3TextConfig] = None, distributed_strategy: DistributedStrategy = NoOpStrategy, **kwargs, ): - nn.Module.__init__(self) + super().__init__() if config is not None: - self.config = config + self.config: Ministral3TextConfig = config else: self.config = Ministral3TextConfig() self.config = self.config.updated(**kwargs) @@ -202,14 +281,79 @@ def __init__( ) @classmethod - def from_config(cls, config: Ministral3TextConfig) -> "Ministral3": + def from_config(cls, config: Ministral3TextConfig) -> "Ministral3Text": return cls(config) def get_config(self) -> Ministral3TextConfig: return self.config + def reset_parameters(self): + """ + Reset model parameters. Re-implemented from Mistral to maintain functionality. + """ + import math + self.head.weight.data.normal_( + 0, + 1 / math.sqrt(math.sqrt(self.config.emb_dim * self.config.src_vocab_size)), + ) + self.base_model.reset_parameters() + + def post_init(self): + """ + Post-initialization hook. Re-implemented from Mistral to maintain functionality. + """ + # if this model ties weights, they are tied here + if self.config.tie_heads: + # handle assignment of non-meta weights to meta parameters + if self.head.weight.device == torch.device("meta"): + self.head.weight = self.base_model.embedding.weight + else: + self.base_model.embedding.weight = self.head.weight + + self.base_model.post_init() + + def forward( + self, + x: torch.LongTensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, + use_cache: bool = False, + last_n_tokens: int = 0, + **attn_kwargs: Unpack[AttentionKwargs], + ): + """ + Forward pass through the model. Re-implemented from Mistral to maintain functionality. + """ + from fms.modules.attention import get_attention_type + from fms.utils.headless import gather_outputs + + get_attention_type(**attn_kwargs)["validate_attn_kwargs"]( + input_ids=x, + position_ids=position_ids, + past_key_value_states=past_key_value_states, + **attn_kwargs, + ) + output, cache = self.base_model( + x, position_ids, past_key_value_states, use_cache, **attn_kwargs + ) + + output = gather_outputs(output, last_n_tokens, **attn_kwargs) + preds = self.head(output) + + if use_cache: + return preds, cache + else: + return preds + class Ministral3(Mistral3): + """ + Ministral3 multimodal model combining text and vision. + + Inherits from Mistral3 but uses Ministral3Config and Ministral3Text instead of + Mistral3Config and Mistral. Type ignore comments are used where the types differ + but are structurally compatible. + """ def __init__( self, config: Optional[Ministral3Config] = None, @@ -219,11 +363,11 @@ def __init__( super().__init__() if config is not None: - self.config = config + self.config = config # type: ignore[assignment] else: - self.config = Ministral3Config() + self.config = Ministral3Config() # type: ignore[assignment] - self.config = self.config.updated(**kwargs) + self.config = self.config.updated(**kwargs) # type: ignore[assignment] # Ensure weight fusion correctly propogates; # NOTE: since pixtral is only run as a standalone model @@ -233,16 +377,17 @@ def __init__( self.distributed_strategy = distributed_strategy - # Currently, we always use mistral for the LLM - self.language_model = Ministral3Text( + # Ministral3Text is structurally compatible with Mistral for the language model + self.language_model: Ministral3Text = Ministral3Text( # type: ignore[assignment] self.config.text_config, self.distributed_strategy ) # Vision encoder and projector for multimodal features self.vision_tower = PixtralVisionModel( self.config.vision_config, self.distributed_strategy ) + # Ministral3Config is compatible with Mistral3Config for the projector self.multi_modal_projector = Mistral3MultiModalProjector( - self.config, + self.config, # type: ignore[arg-type] ) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index 85682785a..be81b0d72 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -629,10 +629,12 @@ def __init__( self.extrapolation_factor = extrapolation_factor self.beta_fast = beta_fast # low self.beta_slow = beta_slow # high - self.llama_4_scaling_beta = llama_4_scaling_beta + self.llama_4_scaling_beta = ( + llama_4_scaling_beta if llama_4_scaling_beta is not None else 1 + ) self.cached_freqs: dict[int, torch.Tensor] = {} - self.max_seq_len_cached = {} + self.max_seq_len_cached: MutableMapping[int, int] = {} # magnitude scaling factor self.mscale = float(self._yarn_get_mscale(mscale)) @@ -684,7 +686,7 @@ def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor) -> torch.Tensor: ) return scaling.unsqueeze(-1) - def compute_freqs_cis(self, device: torch.device, max_seq_len: int) -> None: + def compute_freqs_cis(self, device: torch.device, max_seq_len) -> None: """ Transfer pre-computed rotation matrices to the target device. From 3e8ac1193eb781ed89bde9ad88b818bff03ec7d5 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Tue, 31 Mar 2026 03:19:53 +0000 Subject: [PATCH 72/98] :recycle: Refactor to avoid graph breaks Signed-off-by: Gaurav-Kumbhat --- fms/modules/positions.py | 140 ++++++++++++++++++++++----------------- 1 file changed, 78 insertions(+), 62 deletions(-) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index be81b0d72..d1c11f908 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -653,18 +653,20 @@ def _yarn_get_mscale(self, mscale: float = 1) -> float: return 0.1 * mscale * math.log(self.scaling_factor) + 1.0 def _compute_cos_sin_cache( - self, inv_freq: torch.Tensor, device: torch.device + self, inv_freq: torch.Tensor, device: torch.device, max_seq_len: int ) -> torch.Tensor: """ Compute the rotation matrix cache for the rotary embedding to avoid computing while doing the forward pass. Args: inv_freq: The precomputed inverse frequency tensor + device: The device to compute on + max_seq_len: Maximum sequence length to cache Returns: Rotation matrices with shape [max_pos, dim/2, 2, 2] """ t = torch.arange( - int(self.original_max_position_embeddings * self.scaling_factor), + max_seq_len, device=device, dtype=torch.float32, ) @@ -688,77 +690,91 @@ def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor) -> torch.Tensor: def compute_freqs_cis(self, device: torch.device, max_seq_len) -> None: """ - Transfer pre-computed rotation matrices to the target device. + Compute and cache rotation matrices for the target device. - This method no longer computes cos/sin - those are pre-computed on CPU - during __init__. This method only handles device transfers. + This method computes cos/sin rotation matrices and caches them per device. + If the requested max_seq_len exceeds the cached length, it recomputes with + the new length. Args: - device: target device to transfer rotation matrices to - max_seq_len: maximum sequence length (must not exceed original_max_position_embeddings) + device: target device to compute rotation matrices on + max_seq_len: maximum sequence length for the model, if exceeded the cached freqs will be recomputed """ if device == torch.device("meta"): return - if device.index in self.cached_freqs: - return - dev_idx = device.index - freqs = self.base ** ( - torch.arange(0, self.dim, 2, device=device).float() / self.dim - ) - - inv_freq_extrapolation = 1.0 / freqs - inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) - - # NOTE: math.floor and math.ceil being used here are referred to as "truncate" option - low = math.floor( - self.dim - * math.log( - self.original_max_position_embeddings / (self.beta_fast * 2 * math.pi) - ) - ) / (2 * math.log(self.base)) - high = math.ceil( - self.dim - * math.log( - self.original_max_position_embeddings / (self.beta_slow * 2 * math.pi) - ) - ) / (2 * math.log(self.base)) - - # Make sure values are not going outside range - low = max(low, 0) - high = min(high, self.dim - 1) - - if low == high: - high += 0.001 # Prevent singularity - - # Get n-dimensional rotational scaling corrected for extrapolation - linear_func = ( - torch.arange(self.dim // 2, dtype=torch.float32, device=device) - low - ) / (high - low) - - # Compute ramp function (clamped linear interpolation) - ramp_func = torch.clamp(linear_func, 0, 1) - - # inv_freq_extrapolation_factor is the weight for extrapolation - # (1 - ramp_func) means: use extrapolation for low frequencies (< low) - # ramp_func means: use interpolation for high frequencies (> high) - inv_freq_extrapolation_factor = 1 - ramp_func - - # Blend between interpolation and extrapolation - # Note: extrapolation_factor is applied to the extrapolation frequencies - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation - * inv_freq_extrapolation_factor - * self.extrapolation_factor - ) + # Initialize cache entries for this device if not present + if dev_idx not in self.cached_freqs: + self.cached_freqs[dev_idx] = None + if dev_idx not in self.max_seq_len_cached: + self.max_seq_len_cached[dev_idx] = 0 - # Cache the computed rotation matrices for this device - freqs_cis = self._compute_cos_sin_cache(inv_freq, device) - self.cached_freqs[dev_idx] = freqs_cis + # Check if cache is empty (first time) + if self.cached_freqs[dev_idx] is None: + # Use scaled max_seq_len for cache size + # This avoids a graph break from computing scaled_max_seq_len if not needed + scaled_max_seq_len = int(self.original_max_position_embeddings * self.scaling_factor) + cache_size = max(max_seq_len, scaled_max_seq_len) + + # Only recompute if we need a longer sequence than what's cached + if cache_size > self.max_seq_len_cached[dev_idx]: + freqs = self.base ** ( + torch.arange(0, self.dim, 2, device=device).float() / self.dim + ) + + inv_freq_extrapolation = 1.0 / freqs + inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) + + # NOTE: math.floor and math.ceil being used here are referred to as "truncate" option + low = math.floor( + self.dim + * math.log( + self.original_max_position_embeddings / (self.beta_fast * 2 * math.pi) + ) + ) / (2 * math.log(self.base)) + high = math.ceil( + self.dim + * math.log( + self.original_max_position_embeddings / (self.beta_slow * 2 * math.pi) + ) + ) / (2 * math.log(self.base)) + + # Make sure values are not going outside range + low = max(low, 0) + high = min(high, self.dim - 1) + + if low == high: + high += 0.001 # Prevent singularity + + # Get n-dimensional rotational scaling corrected for extrapolation + linear_func = ( + torch.arange(self.dim // 2, dtype=torch.float32, device=device) - low + ) / (high - low) + + # Compute ramp function (clamped linear interpolation) + ramp_func = torch.clamp(linear_func, 0, 1) + + # inv_freq_extrapolation_factor is the weight for extrapolation + # (1 - ramp_func) means: use extrapolation for low frequencies (< low) + # ramp_func means: use interpolation for high frequencies (> high) + inv_freq_extrapolation_factor = 1 - ramp_func + + # Blend between interpolation and extrapolation + # Note: extrapolation_factor is applied to the extrapolation frequencies + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation + * inv_freq_extrapolation_factor + * self.extrapolation_factor + ) + + # Cache the computed rotation matrices for this device + freqs_cis = self._compute_cos_sin_cache(inv_freq, device, cache_size) + self.cached_freqs[dev_idx] = freqs_cis + self.max_seq_len_cached[dev_idx] = cache_size def adjusted_qk( self, From 75babfef35790d79cf6014b7cb6851088066401c Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Tue, 31 Mar 2026 08:48:52 -0500 Subject: [PATCH 73/98] :art: Fix formatting issues Signed-off-by: Gaurav-Kumbhat --- fms/models/ministral3.py | 4 ++++ fms/modules/positions.py | 21 ++++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 98270344a..32e4bd81a 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -91,6 +91,7 @@ class Ministral3Headless(nn.Module): This approach ensures proper type checking with mypy while allowing code reuse through shared utility functions and modules (MistralBlock, etc.). """ + def __init__( self, config: Ministral3TextConfig, @@ -261,6 +262,7 @@ class Ministral3Text(nn.Module): methods while maintaining a similar structure to Mistral. This approach ensures proper type checking with mypy. """ + def __init__( self, config: Optional[Ministral3TextConfig] = None, @@ -292,6 +294,7 @@ def reset_parameters(self): Reset model parameters. Re-implemented from Mistral to maintain functionality. """ import math + self.head.weight.data.normal_( 0, 1 / math.sqrt(math.sqrt(self.config.emb_dim * self.config.src_vocab_size)), @@ -354,6 +357,7 @@ class Ministral3(Mistral3): Mistral3Config and Mistral. Type ignore comments are used where the types differ but are structurally compatible. """ + def __init__( self, config: Optional[Ministral3Config] = None, diff --git a/fms/modules/positions.py b/fms/modules/positions.py index d1c11f908..ac65a8eee 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -707,19 +707,19 @@ def compute_freqs_cis(self, device: torch.device, max_seq_len) -> None: dev_idx = device.index # Initialize cache entries for this device if not present - if dev_idx not in self.cached_freqs: - self.cached_freqs[dev_idx] = None if dev_idx not in self.max_seq_len_cached: self.max_seq_len_cached[dev_idx] = 0 - # Check if cache is empty (first time) - if self.cached_freqs[dev_idx] is None: + # Check if cache is empty (first time) or needs to be recomputed + if dev_idx not in self.cached_freqs: # Use scaled max_seq_len for cache size # This avoids a graph break from computing scaled_max_seq_len if not needed - scaled_max_seq_len = int(self.original_max_position_embeddings * self.scaling_factor) + scaled_max_seq_len = int( + self.original_max_position_embeddings * self.scaling_factor + ) cache_size = max(max_seq_len, scaled_max_seq_len) - # Only recompute if we need a longer sequence than what's cached + # Only compute if we need a longer sequence than what's cached if cache_size > self.max_seq_len_cached[dev_idx]: freqs = self.base ** ( torch.arange(0, self.dim, 2, device=device).float() / self.dim @@ -732,13 +732,15 @@ def compute_freqs_cis(self, device: torch.device, max_seq_len) -> None: low = math.floor( self.dim * math.log( - self.original_max_position_embeddings / (self.beta_fast * 2 * math.pi) + self.original_max_position_embeddings + / (self.beta_fast * 2 * math.pi) ) ) / (2 * math.log(self.base)) high = math.ceil( self.dim * math.log( - self.original_max_position_embeddings / (self.beta_slow * 2 * math.pi) + self.original_max_position_embeddings + / (self.beta_slow * 2 * math.pi) ) ) / (2 * math.log(self.base)) @@ -751,7 +753,8 @@ def compute_freqs_cis(self, device: torch.device, max_seq_len) -> None: # Get n-dimensional rotational scaling corrected for extrapolation linear_func = ( - torch.arange(self.dim // 2, dtype=torch.float32, device=device) - low + torch.arange(self.dim // 2, dtype=torch.float32, device=device) + - low ) / (high - low) # Compute ramp function (clamped linear interpolation) From 7ac564cdc9b48a8029d00afbae1725067f8ecb3b Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Tue, 31 Mar 2026 11:07:40 -0300 Subject: [PATCH 74/98] Adds version verify to tests marked as slow For the cache change retro-compatibility for the versions > 4.57.x Signed-off-by: Flavia Beo --- .../models/hf_equivalence/test_gpt_bigcode.py | 61 ++++++++++--------- tests/models/hf_equivalence/test_granite.py | 14 ++++- tests/models/hf_equivalence/test_llama.py | 29 +++++++-- 3 files changed, 67 insertions(+), 37 deletions(-) diff --git a/tests/models/hf_equivalence/test_gpt_bigcode.py b/tests/models/hf_equivalence/test_gpt_bigcode.py index cd87ac3c8..49a53bd47 100644 --- a/tests/models/hf_equivalence/test_gpt_bigcode.py +++ b/tests/models/hf_equivalence/test_gpt_bigcode.py @@ -4,10 +4,9 @@ from fms.models import get_model from fms.models.hf import to_hf_api -from fms.testing.comparison import ( - HFModelSignatureParams, - compare_model_signatures, -) + +from packaging.version import Version +from transformers import __version__ as tf_version @pytest.mark.slow @@ -36,53 +35,55 @@ def test_gptbigcode_equivalence(): fms_model, bos_token_id=hf_model.config.bos_token_id, eos_token_id=hf_model.config.eos_token_id, + pad_token_id=getattr(hf_model.config, "pad_token_id", None), ) - def count_parameters(m): - return sum(p.numel() for p in m.parameters()) - - assert count_parameters(hf_model_fms) == count_parameters(hf_model) - + # Set both models to eval mode hf_model.eval() hf_model_fms.eval() - inp = torch.arange(0, 16).unsqueeze(0) - - hf_fms_signature_params = HFModelSignatureParams( - model=hf_model_fms, - params=["input_ids", "labels"], - other_params={"return_dict": True}, - inp=inp, - ) + # Keep signatures test only at tests/model/test_gpt_bigcode.py + # Testing the models' generation - hf_signature_params = HFModelSignatureParams( - model=hf_model, - params=["input_ids", "labels"], - other_params={"return_dict": True}, - inp=inp, - ) + prompt = "def print_hello_world():" - compare_model_signatures(hf_fms_signature_params, hf_signature_params) + use_cache = False - prompt = "def print_hello_world():" + if Version(tf_version) >= Version("5.0.0"): + use_cache = True + else: + # for versions > 4.57.x and < 5.0.0, use_cache is disabled; + # this way we are retro compatible with parameter called cache_position + # https://huggingface.co/docs/transformers/cache_explanation#cache-position + use_cache = False + # Use greedy decoding (num_beams=1) for deterministic generation + # Note: use_cache=False to avoid KV cache shape mismatch issues generator_hf = pipeline( task="text-generation", model=hf_model, tokenizer=tokenizer, - use_cache=True, - num_beams=3, + use_cache=use_cache, + num_beams=1, + do_sample=False, max_new_tokens=50, ) generator_hf_fms = pipeline( task="text-generation", model=hf_model_fms, tokenizer=tokenizer, - use_cache=True, - num_beams=3, + use_cache=use_cache, + num_beams=1, + do_sample=False, max_new_tokens=50, ) output_hf = generator_hf(prompt) output_hf_fms = generator_hf_fms(prompt) - assert output_hf == output_hf_fms + + # Compare generated text with helpful error message + assert output_hf[0]["generated_text"] == output_hf_fms[0]["generated_text"], ( + f"Generated text mismatch:\n" + f"HF: {output_hf[0]['generated_text']}\n" + f"FMS: {output_hf_fms[0]['generated_text']}" + ) print(output_hf_fms) diff --git a/tests/models/hf_equivalence/test_granite.py b/tests/models/hf_equivalence/test_granite.py index 36b4a87ec..503bb6edd 100644 --- a/tests/models/hf_equivalence/test_granite.py +++ b/tests/models/hf_equivalence/test_granite.py @@ -38,11 +38,21 @@ def test_granite_8b_equivalence(): prompt = """q: how are you? a: I am good. How about you? q: What is the weather like today? a:""" + use_cache = False + + if Version(tf_version) >= Version("5.0.0"): + use_cache = True + else: + # for versions > 4.57.x and < 5.0.0, use_cache is disabled; + # this way we are retro compatible with parameter called cache_position + # https://huggingface.co/docs/transformers/cache_explanation#cache-position + use_cache = False + generator_hf = pipeline( task="text-generation", model=hf_model, tokenizer=tokenizer, - use_cache=True, + use_cache=use_cache, max_new_tokens=20, do_sample=False, ) @@ -50,7 +60,7 @@ def test_granite_8b_equivalence(): task="text-generation", model=hf_model_fms, tokenizer=tokenizer, - use_cache=True, + use_cache=use_cache, max_new_tokens=20, do_sample=False, ) diff --git a/tests/models/hf_equivalence/test_llama.py b/tests/models/hf_equivalence/test_llama.py index b4be221f5..d2de5f9f3 100644 --- a/tests/models/hf_equivalence/test_llama.py +++ b/tests/models/hf_equivalence/test_llama.py @@ -4,6 +4,9 @@ from fms.models import get_model from fms.models.hf import to_hf_api +from packaging.version import Version +from transformers import __version__ as tf_version + @pytest.mark.slow def test_llama_3b_equivalence(): @@ -34,12 +37,22 @@ def test_llama_3b_equivalence(): prompt = """q: how are you? a: I am good. How about you? q: What is the weather like today? a:""" + use_cache = False + + if Version(tf_version) >= Version("5.0.0"): + use_cache = True + else: + # for versions > 4.57.x and < 5.0.0, use_cache is disabled; + # this way we are retro compatible with parameter called cache_position + # https://huggingface.co/docs/transformers/cache_explanation#cache-position + use_cache = False + generator_hf = pipeline( task="text-generation", model=hf_model, tokenizer=tokenizer, - use_cache=True, - num_beams=3, + use_cache=use_cache, + num_beams=1, # Use greedy decoding for deterministic results max_new_tokens=20, do_sample=False, ) @@ -47,14 +60,20 @@ def test_llama_3b_equivalence(): task="text-generation", model=hf_model_fms, tokenizer=tokenizer, - use_cache=True, - num_beams=3, + use_cache=use_cache, + num_beams=1, # Use greedy decoding for deterministic results max_new_tokens=20, do_sample=False, ) output_hf = generator_hf(prompt) output_hf_fms = generator_hf_fms(prompt) - assert output_hf == output_hf_fms + + # Compare generated text with helpful error message + assert output_hf[0]["generated_text"] == output_hf_fms[0]["generated_text"], ( + f"Generated text mismatch:\n" + f"HF: {output_hf[0]['generated_text']}\n" + f"FMS: {output_hf_fms[0]['generated_text']}" + ) # Test Train Loss From 55e76b6c18f6cc1c4ba4206e4725819e64dea5df Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Tue, 31 Mar 2026 11:20:36 -0300 Subject: [PATCH 75/98] Removes unused import Signed-off-by: Flavia Beo --- tests/models/hf_equivalence/test_gpt_bigcode.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/hf_equivalence/test_gpt_bigcode.py b/tests/models/hf_equivalence/test_gpt_bigcode.py index 49a53bd47..7b1834d28 100644 --- a/tests/models/hf_equivalence/test_gpt_bigcode.py +++ b/tests/models/hf_equivalence/test_gpt_bigcode.py @@ -12,7 +12,6 @@ @pytest.mark.slow def test_gptbigcode_equivalence(): """Tests GPT BigCode equivalence with a known implementation. Takes approximately 1:11 on an mbp with M1 chip""" - import torch from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder") From 68e610e3ad5b0ab1e70b9c2cc72be696a80bc84b Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Tue, 31 Mar 2026 18:39:42 +0000 Subject: [PATCH 76/98] :white_check_marks: Add ministral3 expectation test Signed-off-by: Gaurav-Kumbhat --- tests/models/test_ministral3.py | 249 ++++++++++++++++++ ...inistral3.TestMinistral3.test_model_output | 1 + ...ral3.TestMinistral3.test_model_weight_keys | 1 + ...tral3.TestMinistral3Text.test_model_output | 0 ....TestMinistral3Text.test_model_weight_keys | 1 + 5 files changed, 252 insertions(+) create mode 100644 tests/models/test_ministral3.py create mode 100644 tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output create mode 100644 tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_weight_keys create mode 100644 tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_output create mode 100644 tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_weight_keys diff --git a/tests/models/test_ministral3.py b/tests/models/test_ministral3.py new file mode 100644 index 000000000..5c3ef67e0 --- /dev/null +++ b/tests/models/test_ministral3.py @@ -0,0 +1,249 @@ +import pytest +import torch + +# Skip entire module if transformers version is not > 5.0.0 +# transformers = pytest.importorskip("transformers", minversion="5.0.1") + +from fms.models.pixtral_vision import PixtralVisionConfig +from fms.models.ministral3 import ( + Ministral3, + Ministral3Config, + Ministral3Text, + Ministral3TextConfig, +) +from fms.testing._internal.model_test_suite import ( + ConfigFixtureMixin, + ModelCompileTestSuite, + ModelConfigTestSuite, + ModelConsistencyTestSuite, + ModelFixtureMixin, +) +from fms.utils.config import ModelConfig + + +class Ministral3TextFixtures(ConfigFixtureMixin, ModelFixtureMixin): + """ + Base Ministral3Text Fixtures for text-only model testing + + This will include the config and model signatures for the text-only variant + """ + + @pytest.fixture(scope="class", autouse=True) + def uninitialized_model(self, config: Ministral3TextConfig): + return Ministral3Text(config) + + @pytest.fixture(scope="class", autouse=True) + def config(self) -> ModelConfig: + # Text config for ministral3 text-only model + return Ministral3TextConfig( + src_vocab_size=384, + nheads=8, + nlayers=2, + hidden_grow_factor=3.5, + multiple_of=2, + tie_heads=False, + p_dropout=0.0, + activation_fn="silu", + emb_dim=16, + head_dim=128, + max_expected_seq_len=4096, + kvheads=2, + norm_eps=1e-05, + sliding_window=4000, + rope_parameters={ + "rope_type": "yarn", + "rope_theta": 100_0000.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "factor": 1.0, + "original_max_position_embeddings": 4096, + "mscale": 1.0, + "mscale_all_dim": 1.0, + }, + fused_weights=True, + pad_id=0, + ) + + +class TestMinistral3Text( + ModelConfigTestSuite, + ModelConsistencyTestSuite, + ModelCompileTestSuite, + Ministral3TextFixtures, +): + """Test suite for Ministral3Text (text-only model)""" + + @staticmethod + def get_logits(f_out): + return f_out + + input_ids = torch.arange(380).unsqueeze(0) + + _get_signature_params = ["x"] + _get_signature_input_ids = input_ids + _get_signature_optional_params = { + "last_n_tokens": 1, + } + + _get_signature_logits_getter_fn = get_logits + + +class Ministral3Fixtures(ConfigFixtureMixin, ModelFixtureMixin): + """ + Base Ministral3 Fixtures for multimodal model testing + + This will include the config and model signatures for the full multimodal variant + """ + + @pytest.fixture(scope="class", autouse=True) + def uninitialized_model(self, config: Ministral3Config): + return Ministral3(config) + + @pytest.fixture(scope="class", autouse=True) + def config(self) -> ModelConfig: + # Text / vision configs are essentially those used for the + # ministral3 llm and pixtral encoder tests + _text_config = Ministral3TextConfig( + src_vocab_size=384, + nheads=8, + nlayers=2, + hidden_grow_factor=3.5, + multiple_of=2, + tie_heads=False, + p_dropout=0.0, + activation_fn="silu", + emb_dim=16, + head_dim=128, + max_expected_seq_len=4096, + kvheads=2, + norm_eps=1e-05, + sliding_window=4000, + rope_parameters={ + "rope_type": "yarn", + "rope_theta": 100_0000.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "factor": 1.0, + "original_max_position_embeddings": 4096, + "mscale": 1.0, + "mscale_all_dim": 1.0, + }, + fused_weights=True, + pad_id=0, + ) + + _vision_config = PixtralVisionConfig( + hidden_size=16, + intermediate_size=64, + nlayers=8, + nheads=8, + nchannels=3, + image_size=280, + patch_size=14, + hidden_act="silu", + layer_norm_eps=1e-5, + rope_theta=10000.0, + attention_dropout=0.0, + fused_weights=True, + ) + + return Ministral3Config( + vision_config=_vision_config, + text_config=_text_config, + spatial_merge_size=2, + image_token_index=10, + vision_feature_layer=-1, + ) + + +class TestMinistral3( + ModelConfigTestSuite, + ModelConsistencyTestSuite, + ModelCompileTestSuite, + Ministral3Fixtures, +): + """Test suite for Ministral3 (multimodal model with vision)""" + + @staticmethod + def get_logits(f_out): + return f_out[0] + + pixel_values = [ + [[torch.arange(0, 1, 1 / 280).tolist() for _ in range(280)] for _ in range(3)] + ] + input_ids = torch.arange(380).unsqueeze(0) + pixel_values = torch.tensor(pixel_values) # [1, 3, 280, 280] + + _get_signature_params = ["input_ids_or_embeds"] + _get_signature_input_ids = input_ids + _get_signature_optional_params = { + "pixel_values": pixel_values, + "image_sizes": [(280, 280)], + "last_n_tokens": 1, + } + + _get_signature_logits_getter_fn = get_logits + + def test_config_passed_to_model_and_updated(self, model, config): + """test model constructor appropriately merges any passed kwargs into the config without mutating the original config""" + model = type(model)( + config=config, + vision_feature_layer=config.vision_feature_layer - 1, + ) + # check not same reference + assert model.get_config() is not config + + # modify feature layer to the new value expected and check equivalence + config.vision_feature_layer = config.vision_feature_layer - 1 + assert model.get_config().as_dict() == config.as_dict() + + +class TestMinistral3Vision(Ministral3Fixtures): + """Test suite specifically for Ministral3 vision capabilities (Pixtral integration)""" + + def test_vision_tower_exists(self, model): + """Test that the vision tower is properly initialized""" + assert hasattr(model, "vision_tower") + assert model.vision_tower is not None + + def test_multimodal_projector_exists(self, model): + """Test that the multimodal projector is properly initialized""" + assert hasattr(model, "multi_modal_projector") + assert model.multi_modal_projector is not None + + def test_vision_config_propagation(self, model, config): + """Test that vision config is properly propagated to the vision tower""" + assert model.vision_tower.config.hidden_size == config.vision_config.hidden_size + assert model.vision_tower.config.nlayers == config.vision_config.nlayers + assert model.vision_tower.config.nheads == config.vision_config.nheads + + def test_text_config_propagation(self, model, config): + """Test that text config is properly propagated to the language model""" + assert model.language_model.config.emb_dim == config.text_config.emb_dim + assert model.language_model.config.nlayers == config.text_config.nlayers + assert model.language_model.config.nheads == config.text_config.nheads + + def test_fused_weights_propagation(self, config): + """Test that fused_weights setting propagates correctly""" + # Test with fused_weights=False + config_unfused = Ministral3Config( + vision_config=config.vision_config, + text_config=config.text_config, + fused_weights=False, + ) + model_unfused = Ministral3(config_unfused) + assert not model_unfused.config.text_config.fused_weights + assert not model_unfused.config.vision_config.fused_weights + + # Test with fused_weights=True (default) + config_fused = Ministral3Config( + vision_config=config.vision_config, + text_config=config.text_config, + fused_weights=True, + ) + model_fused = Ministral3(config_fused) + assert model_fused.config.text_config.fused_weights + assert model_fused.config.vision_config.fused_weights + + +# Made with Bob diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output new file mode 100644 index 000000000..599d04e0e --- /dev/null +++ b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output @@ -0,0 +1 @@ +0.05594000965356827,0.07009149342775345,0.053363971412181854,0.03607359901070595,0.043681174516677856,0.020182866603136063,0.04964417219161987,0.03600818291306496,0.019329946488142014,0.03297266364097595,0.05269748345017433,0.06371667981147766,0.05002269893884659,0.0729384496808052,0.04788241162896156,0.04309707507491112,0.0,0.04357776418328285,0.07921542227268219,0.05602042376995087,0.076865553855896,0.09743601083755493,0.06762775778770447,0.04591166228055954,0.05145244300365448,0.03331846743822098,0.04739246517419815,0.050563327968120575,0.02307116612792015,0.032998014241456985,0.044376641511917114,0.05828513205051422,0.0640689954161644,0.041963111609220505,0.06661190092563629,0.021993771195411682,0.04157637059688568,0.04543519392609596,0.04329625517129898,0.05870999023318291,0.06388126313686371,0.029258813709020615,0.059809643775224686,0.03835966810584068,0.04687876254320145,0.06482844799757004,0.04912843182682991,0.06644786894321442,0.06589303910732269,0.07015223801136017,0.05971097946166992,0.056327346712350845,0.05118750035762787,0.0716148167848587,0.028318822383880615,0.049211010336875916,0.03845946490764618,0.06162078678607941,0.05900892615318298,0.06576602905988693,0.08574651181697845,0.03540661185979843,0.05052792280912399,0.0230337455868721,0.06948710978031158,0.07192317396402359,0.05369403585791588,0.057112567126750946,0.060141220688819885,0.046240903437137604,0.07929585129022598,0.043933380395174026,0.023667536675930023,0.052385009825229645,0.08191967010498047,0.057452715933322906,0.06330782175064087,0.05094665288925171,0.036971595138311386,0.028086058795452118,0.047337062656879425,0.020837370306253433,0.047451362013816833,0.06426891684532166,0.05620785057544708,0.040379032492637634,0.04898042976856232,0.07401958107948303,0.047023944556713104,0.042739544063806534,0.06628362834453583,0.038229767233133316,0.05312575772404671,0.03877401351928711,0.05186797305941582,0.0719762071967125,0.04904578626155853,0.03628318011760712,0.056798968464136124,0.07532411068677902,0.03873002529144287,0.03852667286992073,0.06784132122993469,0.048094168305397034,0.040452368557453156,0.03793191909790039,0.05681830644607544,0.04649338126182556,0.058725111186504364,0.03882820904254913,0.051229797303676605,0.0463840626180172,0.08150972425937653,0.048791803419589996,0.056264426559209824,0.04249201714992523,0.03347032517194748,0.030649248510599136,0.07214010506868362,0.03595946356654167,0.04968692362308502,0.03366164118051529,0.05083736032247543,0.04783494770526886,0.05568528175354004,0.05780668929219246,0.06438229233026505,0.06571371853351593,0.02533162385225296,0.06469275802373886,0.07349156588315964,0.04786044731736183,0.034229956567287445,0.04778200387954712,0.03741319105029106,0.07336124777793884,0.03802672401070595,0.060037147253751755,0.05296716466546059,0.03978914022445679,0.062454141676425934,0.03599924221634865,0.03369104862213135,0.05177672207355499,0.03293869271874428,0.03639262169599533,0.044998038560152054,0.02907366305589676,0.040119871497154236,0.027147360146045685,0.048559799790382385,0.045276857912540436,0.06866105645895004,0.061060402542352676,0.062874436378479,0.04782712832093239,0.06262528896331787,0.05039990320801735,0.05237971618771553,0.05835968628525734,0.0787557065486908,0.06677548587322235,0.05194058269262314,0.029971152544021606,0.03606399521231651,0.021126702427864075,0.0569908544421196,0.03827045112848282,0.09272080659866333,0.046609699726104736,0.02733200043439865,0.05647573620080948,0.04432094842195511,0.078376904129982,0.026644285768270493,0.04794427752494812,0.07198895514011383,0.022146187722682953,0.0255056694149971,0.06209910660982132,0.04710085317492485,0.05572003871202469,0.006826341152191162,0.06468471139669418,0.06109367683529854,0.06187712401151657,0.0668192133307457,0.03002656251192093,0.08999139815568924,0.04203619062900543,0.04498240351676941,0.03518887609243393,0.04585396498441696,0.04450260102748871,0.043581049889326096,0.042778052389621735,0.03706224635243416,0.0652753934264183,0.06435524672269821,0.05307982861995697,0.0440589115023613,0.0703946128487587,0.06323257088661194,0.0340811088681221,0.08040595799684525,0.07880274951457977,0.07025232166051865,0.05970701947808266,0.06700189411640167,0.05935615673661232,0.07070165127515793,0.06446608155965805,0.06095588207244873,0.03486824035644531,0.0901142954826355,0.05650843679904938,0.05714709684252739,0.042580023407936096,0.053357988595962524,0.04063870757818222,0.06193080171942711,0.0859304741024971,0.023655224591493607,0.04665638506412506,0.0530717633664608,0.057501520961523056,0.06564461439847946,0.04514537751674652,0.07286965101957321,0.0398518331348896,0.053668275475502014,0.06284862011671066,0.06383650004863739,0.05530640110373497,0.041489358991384506,0.027504812926054,0.059377700090408325,0.05432402715086937,0.04612661898136139,0.04668726027011871,0.04169453680515289,0.052302200347185135,0.06847386062145233,0.04534345865249634,0.07110844552516937,0.05035146698355675,0.032670266926288605,0.06409081071615219,0.05779050663113594,0.037005819380283356,0.04338227957487106,0.05145125463604927,0.07567987591028214,0.05758322775363922,0.06591200828552246,0.05284283682703972,0.06210080906748772,0.03757575526833534,0.05757122486829758,0.05699337646365166,0.05737210810184479,0.015277102589607239,0.04997456818819046,0.016514025628566742,0.05247597396373749,0.04775907099246979,0.054676733911037445,0.05624669790267944,0.08368708193302155,0.0477510429918766,0.031088262796401978,0.07771854102611542,0.050609201192855835,0.05379769206047058,0.06517907232046127,0.07358067482709885,0.05581139028072357,0.06085634604096413,0.032429441809654236,0.06907801330089569,0.02039666473865509,0.04166264832019806,0.07063055783510208,0.057863399386405945,0.07071326673030853,0.038445115089416504,0.08084820955991745,0.03923499584197998,0.033806439489126205,0.02290938049554825,0.06376587599515915,0.042040325701236725,0.04806685820221901,0.03303639218211174,0.039669156074523926,0.03858298063278198,0.02584291249513626,0.06175125390291214,0.07047692686319351,0.06338955461978912,0.05593520775437355,0.06257142126560211,0.06045425310730934,0.04628739878535271,0.051556095480918884,0.03937205299735069,0.031223706901073456,0.053349029272794724,0.09146720170974731,0.02697083353996277,0.06529778987169266,0.03044925630092621,0.054164767265319824,0.004468046128749847,0.04793383926153183,0.05577708035707474,0.04961930587887764,0.049213431775569916,0.05563195049762726,0.03984059393405914,0.01825942099094391,0.06467756628990173,0.06567580997943878,0.05535610020160675,0.04216676950454712,0.0649462565779686,0.03579432889819145,0.014211699366569519,0.06829570978879929,0.04222029447555542,0.0759027823805809,0.033252403140068054,0.051514144986867905,0.06949586421251297,0.02851555123925209,0.07451120018959045,0.06473961472511292,0.04491368681192398,0.029872383922338486,0.07605069875717163,0.05592980608344078,0.021009720861911774,0.061686813831329346,0.0517142079770565,0.07307823747396469,0.07604604959487915,0.05582322180271149,0.051429785788059235,0.0553152859210968,0.036295417696237564,0.07948359847068787,0.03756716102361679,0.06539832055568695,0.043496109545230865,0.0744294673204422,0.05192861333489418,0.06575309485197067,0.031244300305843353,0.04962065815925598,0.02668258175253868,0.03821685165166855,0.03925998508930206,0.03346341848373413,0.04489396512508392,0.04963776469230652,0.048156771808862686,0.03926989436149597,0.050722893327474594,0.04133383184671402,0.07423140853643417,0.08342073112726212,0.05672823265194893,0.04799208790063858,0.06504455208778381,0.04992742836475372,0.03532298654317856,0.04873589798808098,0.06584585458040237,0.047546569257974625,0.04797333851456642,0.06895318627357483,0.06908740103244781,0.06236596778035164,0.045539986342191696 \ No newline at end of file diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_weight_keys b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_weight_keys new file mode 100644 index 000000000..986cbf91e --- /dev/null +++ b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_weight_keys @@ -0,0 +1 @@ +language_model.base_model.dec_norm.weight,language_model.base_model.embedding.weight,language_model.base_model.layers.0.attn.dense.weight,language_model.base_model.layers.0.attn.in_proj.qkv_fused.weight,language_model.base_model.layers.0.ff_ln.weight,language_model.base_model.layers.0.ff_sub_layer.w2.weight,language_model.base_model.layers.0.ff_sub_layer.wg1_fused.weight,language_model.base_model.layers.0.ln.weight,language_model.base_model.layers.1.attn.dense.weight,language_model.base_model.layers.1.attn.in_proj.qkv_fused.weight,language_model.base_model.layers.1.ff_ln.weight,language_model.base_model.layers.1.ff_sub_layer.w2.weight,language_model.base_model.layers.1.ff_sub_layer.wg1_fused.weight,language_model.base_model.layers.1.ln.weight,language_model.head.weight,multi_modal_projector.linear_1.weight,multi_modal_projector.linear_2.weight,multi_modal_projector.norm.weight,multi_modal_projector.patch_merger.merging_layer.weight,vision_tower.ln_pre.weight,vision_tower.patch_conv.weight,vision_tower.transformer.layers.0.attention_norm.weight,vision_tower.transformer.layers.0.attn.dense.weight,vision_tower.transformer.layers.0.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.0.ff_sub_layer.w2.weight,vision_tower.transformer.layers.0.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.0.ffn_norm.weight,vision_tower.transformer.layers.1.attention_norm.weight,vision_tower.transformer.layers.1.attn.dense.weight,vision_tower.transformer.layers.1.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.1.ff_sub_layer.w2.weight,vision_tower.transformer.layers.1.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.1.ffn_norm.weight,vision_tower.transformer.layers.2.attention_norm.weight,vision_tower.transformer.layers.2.attn.dense.weight,vision_tower.transformer.layers.2.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.2.ff_sub_layer.w2.weight,vision_tower.transformer.layers.2.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.2.ffn_norm.weight,vision_tower.transformer.layers.3.attention_norm.weight,vision_tower.transformer.layers.3.attn.dense.weight,vision_tower.transformer.layers.3.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.3.ff_sub_layer.w2.weight,vision_tower.transformer.layers.3.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.3.ffn_norm.weight,vision_tower.transformer.layers.4.attention_norm.weight,vision_tower.transformer.layers.4.attn.dense.weight,vision_tower.transformer.layers.4.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.4.ff_sub_layer.w2.weight,vision_tower.transformer.layers.4.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.4.ffn_norm.weight,vision_tower.transformer.layers.5.attention_norm.weight,vision_tower.transformer.layers.5.attn.dense.weight,vision_tower.transformer.layers.5.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.5.ff_sub_layer.w2.weight,vision_tower.transformer.layers.5.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.5.ffn_norm.weight,vision_tower.transformer.layers.6.attention_norm.weight,vision_tower.transformer.layers.6.attn.dense.weight,vision_tower.transformer.layers.6.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.6.ff_sub_layer.w2.weight,vision_tower.transformer.layers.6.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.6.ffn_norm.weight,vision_tower.transformer.layers.7.attention_norm.weight,vision_tower.transformer.layers.7.attn.dense.weight,vision_tower.transformer.layers.7.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.7.ff_sub_layer.w2.weight,vision_tower.transformer.layers.7.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.7.ffn_norm.weight \ No newline at end of file diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_output b/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_output new file mode 100644 index 000000000..e69de29bb diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_weight_keys b/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_weight_keys new file mode 100644 index 000000000..26382e8bd --- /dev/null +++ b/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_weight_keys @@ -0,0 +1 @@ +base_model.dec_norm.weight,base_model.embedding.weight,base_model.layers.0.attn.dense.weight,base_model.layers.0.attn.in_proj.qkv_fused.weight,base_model.layers.0.ff_ln.weight,base_model.layers.0.ff_sub_layer.w2.weight,base_model.layers.0.ff_sub_layer.wg1_fused.weight,base_model.layers.0.ln.weight,base_model.layers.1.attn.dense.weight,base_model.layers.1.attn.in_proj.qkv_fused.weight,base_model.layers.1.ff_ln.weight,base_model.layers.1.ff_sub_layer.w2.weight,base_model.layers.1.ff_sub_layer.wg1_fused.weight,base_model.layers.1.ln.weight,head.weight \ No newline at end of file From 423bc2f0e9c700ec66f90d9375da874c9f27f5a8 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Tue, 31 Mar 2026 19:03:04 +0000 Subject: [PATCH 77/98] :white_check_mark::recycle: Refactor ministral3 hf equivalence test Signed-off-by: Gaurav-Kumbhat --- tests/models/hf_equivalence/test_ministral3.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/models/hf_equivalence/test_ministral3.py b/tests/models/hf_equivalence/test_ministral3.py index d00451515..a06d96e9b 100644 --- a/tests/models/hf_equivalence/test_ministral3.py +++ b/tests/models/hf_equivalence/test_ministral3.py @@ -86,7 +86,7 @@ def _get_fms_model_output(model_path, inputs, max_new_tokens=6): @pytest.mark.slow -def test_ministral3_8b_equivalence(): +def test_ministral3_14b_equivalence(): from transformers import __version__ as tf_version from transformers import AutoProcessor @@ -100,11 +100,10 @@ def test_ministral3_8b_equivalence(): # if you would like to try this, set model_path to the HF model path # for ministral-3 - model_path = "/path/to/mistralai/Ministral-3-14B-Reasoning-2512/" - # NOTE: Ministral-3-8B-Instruct-2512-BF16 model doesn't come with its own processor - # You can use mistralai/Ministral-3-14B-Reasoning-2512 in that case + # NOTE: Since ministral3-8b doesn't come with its own processor + # we are using 14B model here + model_path = "mistralai/Ministral-3-14B-Reasoning-2512" - # model_path = "" processor = AutoProcessor.from_pretrained(model_path) # Get inputs with the model path for system prompt loading @@ -118,4 +117,4 @@ def test_ministral3_8b_equivalence(): if __name__ == "__main__": - test_ministral3_8b_equivalence() + test_ministral3_14b_equivalence() From aa34c4269581c4008ef458e2a9e9693badb8f92b Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Tue, 31 Mar 2026 14:43:38 -0500 Subject: [PATCH 78/98] :recycle::white_check_mark: Refactor ministral3 expectation test Signed-off-by: Gaurav-Kumbhat --- tests/models/test_ministral3.py | 75 +------------------ ...inistral3.TestMinistral3.test_model_output | 2 +- ...tral3.TestMinistral3Text.test_model_output | 0 ....TestMinistral3Text.test_model_weight_keys | 1 - 4 files changed, 3 insertions(+), 75 deletions(-) delete mode 100644 tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_output delete mode 100644 tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_weight_keys diff --git a/tests/models/test_ministral3.py b/tests/models/test_ministral3.py index 5c3ef67e0..0b879a1bb 100644 --- a/tests/models/test_ministral3.py +++ b/tests/models/test_ministral3.py @@ -1,14 +1,10 @@ import pytest import torch -# Skip entire module if transformers version is not > 5.0.0 -# transformers = pytest.importorskip("transformers", minversion="5.0.1") - from fms.models.pixtral_vision import PixtralVisionConfig from fms.models.ministral3 import ( Ministral3, Ministral3Config, - Ministral3Text, Ministral3TextConfig, ) from fms.testing._internal.model_test_suite import ( @@ -21,78 +17,11 @@ from fms.utils.config import ModelConfig -class Ministral3TextFixtures(ConfigFixtureMixin, ModelFixtureMixin): - """ - Base Ministral3Text Fixtures for text-only model testing - - This will include the config and model signatures for the text-only variant - """ - - @pytest.fixture(scope="class", autouse=True) - def uninitialized_model(self, config: Ministral3TextConfig): - return Ministral3Text(config) - - @pytest.fixture(scope="class", autouse=True) - def config(self) -> ModelConfig: - # Text config for ministral3 text-only model - return Ministral3TextConfig( - src_vocab_size=384, - nheads=8, - nlayers=2, - hidden_grow_factor=3.5, - multiple_of=2, - tie_heads=False, - p_dropout=0.0, - activation_fn="silu", - emb_dim=16, - head_dim=128, - max_expected_seq_len=4096, - kvheads=2, - norm_eps=1e-05, - sliding_window=4000, - rope_parameters={ - "rope_type": "yarn", - "rope_theta": 100_0000.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "factor": 1.0, - "original_max_position_embeddings": 4096, - "mscale": 1.0, - "mscale_all_dim": 1.0, - }, - fused_weights=True, - pad_id=0, - ) - - -class TestMinistral3Text( - ModelConfigTestSuite, - ModelConsistencyTestSuite, - ModelCompileTestSuite, - Ministral3TextFixtures, -): - """Test suite for Ministral3Text (text-only model)""" - - @staticmethod - def get_logits(f_out): - return f_out - - input_ids = torch.arange(380).unsqueeze(0) - - _get_signature_params = ["x"] - _get_signature_input_ids = input_ids - _get_signature_optional_params = { - "last_n_tokens": 1, - } - - _get_signature_logits_getter_fn = get_logits - - class Ministral3Fixtures(ConfigFixtureMixin, ModelFixtureMixin): """ - Base Ministral3 Fixtures for multimodal model testing + Base Ministral3 Fixtures that can be re-used for other purposes - This will include the config and model signatures for the full multimodal variant + This will include the config and model signatures for the multimodal variant """ @pytest.fixture(scope="class", autouse=True) diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output index 599d04e0e..8df428c51 100644 --- a/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output +++ b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output @@ -1 +1 @@ -0.05594000965356827,0.07009149342775345,0.053363971412181854,0.03607359901070595,0.043681174516677856,0.020182866603136063,0.04964417219161987,0.03600818291306496,0.019329946488142014,0.03297266364097595,0.05269748345017433,0.06371667981147766,0.05002269893884659,0.0729384496808052,0.04788241162896156,0.04309707507491112,0.0,0.04357776418328285,0.07921542227268219,0.05602042376995087,0.076865553855896,0.09743601083755493,0.06762775778770447,0.04591166228055954,0.05145244300365448,0.03331846743822098,0.04739246517419815,0.050563327968120575,0.02307116612792015,0.032998014241456985,0.044376641511917114,0.05828513205051422,0.0640689954161644,0.041963111609220505,0.06661190092563629,0.021993771195411682,0.04157637059688568,0.04543519392609596,0.04329625517129898,0.05870999023318291,0.06388126313686371,0.029258813709020615,0.059809643775224686,0.03835966810584068,0.04687876254320145,0.06482844799757004,0.04912843182682991,0.06644786894321442,0.06589303910732269,0.07015223801136017,0.05971097946166992,0.056327346712350845,0.05118750035762787,0.0716148167848587,0.028318822383880615,0.049211010336875916,0.03845946490764618,0.06162078678607941,0.05900892615318298,0.06576602905988693,0.08574651181697845,0.03540661185979843,0.05052792280912399,0.0230337455868721,0.06948710978031158,0.07192317396402359,0.05369403585791588,0.057112567126750946,0.060141220688819885,0.046240903437137604,0.07929585129022598,0.043933380395174026,0.023667536675930023,0.052385009825229645,0.08191967010498047,0.057452715933322906,0.06330782175064087,0.05094665288925171,0.036971595138311386,0.028086058795452118,0.047337062656879425,0.020837370306253433,0.047451362013816833,0.06426891684532166,0.05620785057544708,0.040379032492637634,0.04898042976856232,0.07401958107948303,0.047023944556713104,0.042739544063806534,0.06628362834453583,0.038229767233133316,0.05312575772404671,0.03877401351928711,0.05186797305941582,0.0719762071967125,0.04904578626155853,0.03628318011760712,0.056798968464136124,0.07532411068677902,0.03873002529144287,0.03852667286992073,0.06784132122993469,0.048094168305397034,0.040452368557453156,0.03793191909790039,0.05681830644607544,0.04649338126182556,0.058725111186504364,0.03882820904254913,0.051229797303676605,0.0463840626180172,0.08150972425937653,0.048791803419589996,0.056264426559209824,0.04249201714992523,0.03347032517194748,0.030649248510599136,0.07214010506868362,0.03595946356654167,0.04968692362308502,0.03366164118051529,0.05083736032247543,0.04783494770526886,0.05568528175354004,0.05780668929219246,0.06438229233026505,0.06571371853351593,0.02533162385225296,0.06469275802373886,0.07349156588315964,0.04786044731736183,0.034229956567287445,0.04778200387954712,0.03741319105029106,0.07336124777793884,0.03802672401070595,0.060037147253751755,0.05296716466546059,0.03978914022445679,0.062454141676425934,0.03599924221634865,0.03369104862213135,0.05177672207355499,0.03293869271874428,0.03639262169599533,0.044998038560152054,0.02907366305589676,0.040119871497154236,0.027147360146045685,0.048559799790382385,0.045276857912540436,0.06866105645895004,0.061060402542352676,0.062874436378479,0.04782712832093239,0.06262528896331787,0.05039990320801735,0.05237971618771553,0.05835968628525734,0.0787557065486908,0.06677548587322235,0.05194058269262314,0.029971152544021606,0.03606399521231651,0.021126702427864075,0.0569908544421196,0.03827045112848282,0.09272080659866333,0.046609699726104736,0.02733200043439865,0.05647573620080948,0.04432094842195511,0.078376904129982,0.026644285768270493,0.04794427752494812,0.07198895514011383,0.022146187722682953,0.0255056694149971,0.06209910660982132,0.04710085317492485,0.05572003871202469,0.006826341152191162,0.06468471139669418,0.06109367683529854,0.06187712401151657,0.0668192133307457,0.03002656251192093,0.08999139815568924,0.04203619062900543,0.04498240351676941,0.03518887609243393,0.04585396498441696,0.04450260102748871,0.043581049889326096,0.042778052389621735,0.03706224635243416,0.0652753934264183,0.06435524672269821,0.05307982861995697,0.0440589115023613,0.0703946128487587,0.06323257088661194,0.0340811088681221,0.08040595799684525,0.07880274951457977,0.07025232166051865,0.05970701947808266,0.06700189411640167,0.05935615673661232,0.07070165127515793,0.06446608155965805,0.06095588207244873,0.03486824035644531,0.0901142954826355,0.05650843679904938,0.05714709684252739,0.042580023407936096,0.053357988595962524,0.04063870757818222,0.06193080171942711,0.0859304741024971,0.023655224591493607,0.04665638506412506,0.0530717633664608,0.057501520961523056,0.06564461439847946,0.04514537751674652,0.07286965101957321,0.0398518331348896,0.053668275475502014,0.06284862011671066,0.06383650004863739,0.05530640110373497,0.041489358991384506,0.027504812926054,0.059377700090408325,0.05432402715086937,0.04612661898136139,0.04668726027011871,0.04169453680515289,0.052302200347185135,0.06847386062145233,0.04534345865249634,0.07110844552516937,0.05035146698355675,0.032670266926288605,0.06409081071615219,0.05779050663113594,0.037005819380283356,0.04338227957487106,0.05145125463604927,0.07567987591028214,0.05758322775363922,0.06591200828552246,0.05284283682703972,0.06210080906748772,0.03757575526833534,0.05757122486829758,0.05699337646365166,0.05737210810184479,0.015277102589607239,0.04997456818819046,0.016514025628566742,0.05247597396373749,0.04775907099246979,0.054676733911037445,0.05624669790267944,0.08368708193302155,0.0477510429918766,0.031088262796401978,0.07771854102611542,0.050609201192855835,0.05379769206047058,0.06517907232046127,0.07358067482709885,0.05581139028072357,0.06085634604096413,0.032429441809654236,0.06907801330089569,0.02039666473865509,0.04166264832019806,0.07063055783510208,0.057863399386405945,0.07071326673030853,0.038445115089416504,0.08084820955991745,0.03923499584197998,0.033806439489126205,0.02290938049554825,0.06376587599515915,0.042040325701236725,0.04806685820221901,0.03303639218211174,0.039669156074523926,0.03858298063278198,0.02584291249513626,0.06175125390291214,0.07047692686319351,0.06338955461978912,0.05593520775437355,0.06257142126560211,0.06045425310730934,0.04628739878535271,0.051556095480918884,0.03937205299735069,0.031223706901073456,0.053349029272794724,0.09146720170974731,0.02697083353996277,0.06529778987169266,0.03044925630092621,0.054164767265319824,0.004468046128749847,0.04793383926153183,0.05577708035707474,0.04961930587887764,0.049213431775569916,0.05563195049762726,0.03984059393405914,0.01825942099094391,0.06467756628990173,0.06567580997943878,0.05535610020160675,0.04216676950454712,0.0649462565779686,0.03579432889819145,0.014211699366569519,0.06829570978879929,0.04222029447555542,0.0759027823805809,0.033252403140068054,0.051514144986867905,0.06949586421251297,0.02851555123925209,0.07451120018959045,0.06473961472511292,0.04491368681192398,0.029872383922338486,0.07605069875717163,0.05592980608344078,0.021009720861911774,0.061686813831329346,0.0517142079770565,0.07307823747396469,0.07604604959487915,0.05582322180271149,0.051429785788059235,0.0553152859210968,0.036295417696237564,0.07948359847068787,0.03756716102361679,0.06539832055568695,0.043496109545230865,0.0744294673204422,0.05192861333489418,0.06575309485197067,0.031244300305843353,0.04962065815925598,0.02668258175253868,0.03821685165166855,0.03925998508930206,0.03346341848373413,0.04489396512508392,0.04963776469230652,0.048156771808862686,0.03926989436149597,0.050722893327474594,0.04133383184671402,0.07423140853643417,0.08342073112726212,0.05672823265194893,0.04799208790063858,0.06504455208778381,0.04992742836475372,0.03532298654317856,0.04873589798808098,0.06584585458040237,0.047546569257974625,0.04797333851456642,0.06895318627357483,0.06908740103244781,0.06236596778035164,0.045539986342191696 \ No newline at end of file +0.05594002455472946,0.07009149342775345,0.05336397886276245,0.03607361018657684,0.04368116706609726,0.020182885229587555,0.049644190818071365,0.03600817546248436,0.01932992786169052,0.03297264873981476,0.05269749090075493,0.06371669471263885,0.05002269893884659,0.0729384496808052,0.047882407903671265,0.043097082525491714,0.0,0.043577760457992554,0.07921542227268219,0.056020431220531464,0.07686556875705719,0.09743601083755493,0.06762776523828506,0.04591165855526924,0.05145245045423508,0.03331845998764038,0.04739247262477875,0.05056333541870117,0.02307118847966194,0.03299802169203758,0.04437664523720741,0.05828511714935303,0.06406901031732559,0.04196309298276901,0.06661192327737808,0.021993786096572876,0.04157635569572449,0.04543519765138626,0.043296266347169876,0.0587100051343441,0.06388124078512192,0.02925882861018181,0.059809643775224686,0.03835967928171158,0.04687875136733055,0.06482845544815063,0.0491284504532814,0.06644788384437561,0.06589306145906448,0.07015224546194077,0.05971100553870201,0.05632736161351204,0.051187485456466675,0.0716148093342781,0.028318822383880615,0.049211032688617706,0.038459450006484985,0.061620816588401794,0.05900892987847328,0.06576603651046753,0.08574653416872025,0.035406630486249924,0.05052793771028519,0.023033753037452698,0.06948710232973099,0.07192317396402359,0.05369403213262558,0.05711257457733154,0.060141224414110184,0.046240903437137604,0.07929585874080658,0.04393337666988373,0.023667525500059128,0.05238502472639084,0.08191967755556107,0.05745270848274231,0.06330782175064087,0.0509466677904129,0.03697159141302109,0.028086066246032715,0.04733707010746002,0.02083737403154373,0.04745139181613922,0.06426891684532166,0.05620785057544708,0.04037903621792793,0.048980433493852615,0.07401958853006363,0.047023944556713104,0.042739540338516235,0.06628364324569702,0.03822977840900421,0.05312575399875641,0.03877400606870651,0.05186798423528671,0.07197622954845428,0.049045808613300323,0.036283183842897415,0.05679897591471672,0.07532411813735962,0.03873002529144287,0.03852667286992073,0.06784132868051529,0.048094164580106735,0.040452368557453156,0.037931911647319794,0.05681835114955902,0.04649338498711586,0.05872511491179466,0.038828182965517044,0.051229801028966904,0.0463840588927269,0.08150972425937653,0.04879182204604149,0.056264422833919525,0.04249201714992523,0.03347032517194748,0.030649255961179733,0.07214010506868362,0.03595946729183197,0.04968693479895592,0.03366164118051529,0.05083736032247543,0.04783494397997856,0.055685292929410934,0.05780669301748276,0.06438229978084564,0.06571371108293533,0.025331638753414154,0.06469275802373886,0.07349155098199844,0.04786045104265213,0.03422994911670685,0.04778201878070831,0.03741317614912987,0.07336126267910004,0.03802672028541565,0.06003715842962265,0.05296717956662178,0.03978915512561798,0.06245414912700653,0.03599925339221954,0.033691056072711945,0.0517766997218132,0.032938700169324875,0.036392614245414734,0.04499802365899086,0.029073677957057953,0.04011987894773483,0.027147367596626282,0.048559803515672684,0.04527686536312103,0.06866106390953064,0.06106039136648178,0.062874436378479,0.04782714694738388,0.06262529641389847,0.05039992928504944,0.05237973481416702,0.05835968255996704,0.07875573635101318,0.06677549332380295,0.05194057524204254,0.0299711674451828,0.03606399893760681,0.021126694977283478,0.05699087306857109,0.03827044367790222,0.09272082149982452,0.046609703451395035,0.02733200043439865,0.05647575110197067,0.04432094469666481,0.07837691903114319,0.026644282042980194,0.04794425517320633,0.07198896259069443,0.022146183997392654,0.025505665689706802,0.06209911033511162,0.04710085690021515,0.05572003871202469,0.006826337426900864,0.06468471139669418,0.061093684285879135,0.061877135187387466,0.06681922823190689,0.03002656251192093,0.08999139815568924,0.04203619807958603,0.0449824258685112,0.03518887981772423,0.04585397616028786,0.04450259357690811,0.04358106851577759,0.04277805984020233,0.03706223890185356,0.0652754083275795,0.06435525417327881,0.05307982489466667,0.044058907777071,0.0703946053981781,0.06323256343603134,0.0340811163187027,0.08040597289800644,0.07880276441574097,0.07025234401226044,0.05970700457692146,0.06700190156698227,0.05935616418719292,0.07070162892341614,0.06446608901023865,0.060955893248319626,0.03486822545528412,0.0901142954826355,0.056508421897888184,0.05714711174368858,0.04258004203438759,0.053357988595962524,0.04063870385289192,0.06193080544471741,0.0859304741024971,0.023655209690332413,0.04665641486644745,0.053071774542331696,0.05750152841210365,0.06564462929964066,0.045145370066165924,0.07286965101957321,0.03985186666250229,0.05366826057434082,0.06284862756729126,0.06383651494979858,0.055306412279605865,0.041489385068416595,0.027504824101924896,0.059377703815698624,0.05432402342557907,0.04612661898136139,0.04668726399540901,0.041694555431604385,0.05230220407247543,0.06847383826971054,0.04534347355365753,0.07110844552516937,0.050351470708847046,0.0326702818274498,0.06409081816673279,0.05779050290584564,0.03700583428144455,0.04338226094841957,0.051451265811920166,0.07567987591028214,0.05758321285247803,0.06591202318668365,0.05284284055233002,0.06210082396864891,0.037575751543045044,0.057571228593587875,0.05699337273836136,0.057372115552425385,0.015277087688446045,0.04997456073760986,0.016514025628566742,0.05247596651315689,0.04775907099246979,0.05467674881219864,0.05624669790267944,0.08368711173534393,0.047751057893037796,0.03108825907111168,0.07771854102611542,0.05060920864343643,0.05379769951105118,0.06517908722162247,0.07358068972826004,0.05581140145659447,0.06085636094212532,0.03242943808436394,0.06907801330089569,0.020396657288074493,0.04166264459490776,0.07063055783510208,0.05786340683698654,0.07071325927972794,0.0384451299905777,0.08084823936223984,0.03923499584197998,0.033806439489126205,0.02290938049554825,0.06376589089632034,0.04204033687710762,0.048066891729831696,0.03303638845682144,0.03966914862394333,0.038582973182201385,0.025842905044555664,0.06175125390291214,0.07047692686319351,0.06338955461978912,0.05593523383140564,0.06257142871618271,0.06045427918434143,0.0462874099612236,0.05155611038208008,0.03937206417322159,0.031223703175783157,0.05334903299808502,0.0914672315120697,0.026970822364091873,0.06529779732227325,0.030449271202087402,0.05416475236415863,0.004468072205781937,0.047933876514434814,0.05577709525823593,0.04961932450532913,0.04921341687440872,0.05563195049762726,0.03984059765934944,0.01825941726565361,0.06467757374048233,0.06567581743001938,0.05535609647631645,0.04216676577925682,0.0649462640285492,0.03579434007406235,0.014211688190698624,0.06829570978879929,0.04222029447555542,0.0759027823805809,0.03325241059064865,0.0515141487121582,0.06949587166309357,0.028515517711639404,0.07451120764017105,0.06473961472511292,0.044913679361343384,0.029872387647628784,0.07605072110891342,0.055929817259311676,0.021009724587202072,0.061686813831329346,0.051714230328798294,0.07307826727628708,0.07604605704545975,0.05582323297858238,0.05142980068922043,0.05531527101993561,0.036295413970947266,0.07948359847068787,0.03756716102361679,0.06539834290742874,0.04349612444639206,0.0744294822216034,0.05192862078547478,0.06575308740139008,0.03124430403113365,0.04962066188454628,0.02668260782957077,0.03821687772870064,0.039260007441043854,0.03346341848373413,0.04489395394921303,0.049637775868177414,0.048156771808862686,0.03926992416381836,0.050722889602184296,0.04133383929729462,0.07423140853643417,0.08342073857784271,0.05672822892665863,0.047992072999477386,0.06504455953836441,0.049927420914173126,0.035322993993759155,0.04873591288924217,0.06584584712982178,0.04754657298326492,0.04797336459159851,0.06895320117473602,0.06908737868070602,0.06236596405506134,0.045539986342191696 \ No newline at end of file diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_output b/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_output deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_weight_keys b/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_weight_keys deleted file mode 100644 index 26382e8bd..000000000 --- a/tests/resources/expectations/models.test_ministral3.TestMinistral3Text.test_model_weight_keys +++ /dev/null @@ -1 +0,0 @@ -base_model.dec_norm.weight,base_model.embedding.weight,base_model.layers.0.attn.dense.weight,base_model.layers.0.attn.in_proj.qkv_fused.weight,base_model.layers.0.ff_ln.weight,base_model.layers.0.ff_sub_layer.w2.weight,base_model.layers.0.ff_sub_layer.wg1_fused.weight,base_model.layers.0.ln.weight,base_model.layers.1.attn.dense.weight,base_model.layers.1.attn.in_proj.qkv_fused.weight,base_model.layers.1.ff_ln.weight,base_model.layers.1.ff_sub_layer.w2.weight,base_model.layers.1.ff_sub_layer.wg1_fused.weight,base_model.layers.1.ln.weight,head.weight \ No newline at end of file From bba08de9da484e711913d119da3336d04bc31156 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Wed, 1 Apr 2026 23:34:58 +0000 Subject: [PATCH 79/98] :bug: Fix position id memory issue at higher context length Signed-off-by: Gaurav-Kumbhat --- fms/modules/positions.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/fms/modules/positions.py b/fms/modules/positions.py index ac65a8eee..e256a26b9 100644 --- a/fms/modules/positions.py +++ b/fms/modules/positions.py @@ -683,10 +683,10 @@ def _compute_cos_sin_cache( return freqs_cis def _get_llama_4_attn_scale(self, positions_ids: torch.Tensor) -> torch.Tensor: - scaling = 1 + self.llama_4_scaling_beta * torch.log( - 1 + torch.floor(positions_ids / self.original_max_position_embeddings) - ) - return scaling.unsqueeze(-1) + pos_idx = (positions_ids // self.original_max_position_embeddings).float() + + scaling = 1 + self.llama_4_scaling_beta * torch.log(1 + torch.floor(pos_idx)) + return scaling def compute_freqs_cis(self, device: torch.device, max_seq_len) -> None: """ @@ -866,10 +866,17 @@ def adjusted_qk( # Apply llama_4_scaling if self.llama_4_scaling_beta: - cache_position = torch.arange( - q_out.shape[2], device=q_out.device, dtype=q_out.dtype - ) - q_out = q_out * self._get_llama_4_attn_scale(cache_position) + # Compute scaling per position: [B, L] + scaling = self._get_llama_4_attn_scale(position_ids) + + # Convert to target dtype and add singleton dimensions for broadcasting + # Instead of creating [B, L, 1, 1], use in-place broadcasting which is more memory efficient + # This avoids creating large intermediate tensors at high context lengths + scaling = scaling.to(q_out.dtype) + + # Apply scaling using implicit broadcasting: [B, L] broadcasts to [B, L, H, D] + # Reshape to [B, L, 1, 1] only for the multiplication to ensure proper broadcasting + q_out = q_out * scaling[:, :, None, None] q_out = q_out.view_as(q_rope) k_out = k_out.view_as(k_rope) From 31ff5208b9fe33a9aef8c531ba94e7554123a4bd Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 2 Apr 2026 19:24:56 +0000 Subject: [PATCH 80/98] :wrench: Apply review suggestions and improve commenting Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/config_utils/param_builders.py | 6 +++--- fms/models/ministral3.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fms/models/hf/config_utils/param_builders.py b/fms/models/hf/config_utils/param_builders.py index 0b1080f00..1d0d84b06 100644 --- a/fms/models/hf/config_utils/param_builders.py +++ b/fms/models/hf/config_utils/param_builders.py @@ -397,19 +397,19 @@ def build_mistral3_params(config: PretrainedConfig) -> dict: def build_ministral3_params(config: PretrainedConfig) -> dict: """Param builder for ministral3 mapping Mistral3ForConditionalGeneration to FMS.""" - ## NOTE: Since ministral3 and mistral3 uses same architecture class + ## NOTE: Since ministral3 and ministral3 uses same architecture class # we are combining their build params function into one. They also # use same vision model # Sanity checks – we currently support only Mistral text + Pixtral vision if getattr(config.text_config, "model_type", None) != "ministral3": raise ValueError( - "FMS implementation of Mistral3 currently supports only 'mistral' language model" + "FMS implementation of Ministral3 currently supports only 'ministral3' language model" ) if getattr(config.vision_config, "model_type", None) != "pixtral": raise ValueError( - "FMS implementation of Mistral3 currently supports only 'pixtral' vision tower" + "FMS implementation of Ministral3 currently supports only 'pixtral' vision tower" ) config_params = { "projector_hidden_act": config.projector_hidden_act, diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index 32e4bd81a..b0ab40596 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -484,7 +484,7 @@ def _hf_to_fms_rope( new_sd = {} if model_config is None: - # It Fall back to values for Ministral3; ModelConfig should really not be + # Fall back to values for Ministral3; ModelConfig should really not be # optional here though, as setting the wrong head dimensions can cause a # lot of confusion. lm_head_dim = 128 From bb8a45f7598b507e1e1c276f7f954723f6619db1 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 2 Apr 2026 21:25:08 +0000 Subject: [PATCH 81/98] :memo: Update comment as per review suggestion Signed-off-by: Gaurav-Kumbhat --- fms/models/hf/config_utils/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fms/models/hf/config_utils/__init__.py b/fms/models/hf/config_utils/__init__.py index c096952d8..0385f071b 100644 --- a/fms/models/hf/config_utils/__init__.py +++ b/fms/models/hf/config_utils/__init__.py @@ -46,7 +46,9 @@ "MPNetForMaskedLM": ("mpnet", pb.build_mpnet_params), "BertForMaskedLM": ("bert", pb.build_bert_params), "Mistral3ForConditionalGeneration": ("mistral3", pb.build_mistral3_params), - # NOTE: Special case for Ministral3 + # This mapping logic in FMS relies on mapping top level model_type in config.json to a particular class. + # However, in case of ministral3 models, this still comes out to be mistral3, and not ministral3. + # To distinguish between these models we add a special handling for ministral3 at this mapping layer, to get around this problem. "FMSMinistral3ForConditionalGeneration": ("ministral3", pb.build_ministral3_params), # Classify arches have some extra keys for labels "RobertaForSequenceClassification": ("roberta_classification", partial(pb.build_roberta_params, is_classify=True)), From 97653913031a399ead9d1073da01935a1421761a Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 2 Apr 2026 22:53:30 +0000 Subject: [PATCH 82/98] :bug: Fix ministral test config issue Signed-off-by: Gaurav-Kumbhat --- tests/models/test_ministral3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_ministral3.py b/tests/models/test_ministral3.py index 0b879a1bb..5f21de597 100644 --- a/tests/models/test_ministral3.py +++ b/tests/models/test_ministral3.py @@ -46,7 +46,7 @@ def config(self) -> ModelConfig: max_expected_seq_len=4096, kvheads=2, norm_eps=1e-05, - sliding_window=4000, + sliding_window=None, rope_parameters={ "rope_type": "yarn", "rope_theta": 100_0000.0, @@ -56,6 +56,7 @@ def config(self) -> ModelConfig: "original_max_position_embeddings": 4096, "mscale": 1.0, "mscale_all_dim": 1.0, + "llama_4_scaling_beta": 0.1, }, fused_weights=True, pad_id=0, From e2ed50b5b63e20d31d418a43d6dfbfb4042a82dd Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Thu, 2 Apr 2026 23:25:48 +0000 Subject: [PATCH 83/98] :green_heart: Potentially fix hanging CI Signed-off-by: Gaurav-Kumbhat --- pyproject.toml | 6 ++-- tests/models/test_ministral3.py | 57 ++------------------------------- 2 files changed, 6 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7f7f93b1c..1d27d9f13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,13 +39,13 @@ dependencies = [ ] [project.optional-dependencies] -hf = ["transformers==4.55.4"] +hf = ["transformers>=4.55.4"] dev = [ "mypy==1.15.0", "mypy-extensions==1.0.0", -"pytest==8.3.4", +"pytest>=9.0.2", "sentencepiece==0.2.0", -"transformers==4.55.4", +"transformers>=4.55.4", "pyarrow-stubs==17.16", "types-requests==2.32.0.20241016", "lm_eval==0.4.7", diff --git a/tests/models/test_ministral3.py b/tests/models/test_ministral3.py index 5f21de597..c5a9e727c 100644 --- a/tests/models/test_ministral3.py +++ b/tests/models/test_ministral3.py @@ -43,17 +43,17 @@ def config(self) -> ModelConfig: activation_fn="silu", emb_dim=16, head_dim=128, - max_expected_seq_len=4096, + max_expected_seq_len=2048, kvheads=2, norm_eps=1e-05, sliding_window=None, rope_parameters={ "rope_type": "yarn", - "rope_theta": 100_0000.0, + "rope_theta": 1000.0, "beta_fast": 32.0, "beta_slow": 1.0, "factor": 1.0, - "original_max_position_embeddings": 4096, + "original_max_position_embeddings": 1024, "mscale": 1.0, "mscale_all_dim": 1.0, "llama_4_scaling_beta": 0.1, @@ -126,54 +126,3 @@ def test_config_passed_to_model_and_updated(self, model, config): # modify feature layer to the new value expected and check equivalence config.vision_feature_layer = config.vision_feature_layer - 1 assert model.get_config().as_dict() == config.as_dict() - - -class TestMinistral3Vision(Ministral3Fixtures): - """Test suite specifically for Ministral3 vision capabilities (Pixtral integration)""" - - def test_vision_tower_exists(self, model): - """Test that the vision tower is properly initialized""" - assert hasattr(model, "vision_tower") - assert model.vision_tower is not None - - def test_multimodal_projector_exists(self, model): - """Test that the multimodal projector is properly initialized""" - assert hasattr(model, "multi_modal_projector") - assert model.multi_modal_projector is not None - - def test_vision_config_propagation(self, model, config): - """Test that vision config is properly propagated to the vision tower""" - assert model.vision_tower.config.hidden_size == config.vision_config.hidden_size - assert model.vision_tower.config.nlayers == config.vision_config.nlayers - assert model.vision_tower.config.nheads == config.vision_config.nheads - - def test_text_config_propagation(self, model, config): - """Test that text config is properly propagated to the language model""" - assert model.language_model.config.emb_dim == config.text_config.emb_dim - assert model.language_model.config.nlayers == config.text_config.nlayers - assert model.language_model.config.nheads == config.text_config.nheads - - def test_fused_weights_propagation(self, config): - """Test that fused_weights setting propagates correctly""" - # Test with fused_weights=False - config_unfused = Ministral3Config( - vision_config=config.vision_config, - text_config=config.text_config, - fused_weights=False, - ) - model_unfused = Ministral3(config_unfused) - assert not model_unfused.config.text_config.fused_weights - assert not model_unfused.config.vision_config.fused_weights - - # Test with fused_weights=True (default) - config_fused = Ministral3Config( - vision_config=config.vision_config, - text_config=config.text_config, - fused_weights=True, - ) - model_fused = Ministral3(config_fused) - assert model_fused.config.text_config.fused_weights - assert model_fused.config.vision_config.fused_weights - - -# Made with Bob From 46ac2449e7346656bc1dd5b0d56c5c19ead09ec0 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Fri, 3 Apr 2026 00:59:02 +0000 Subject: [PATCH 84/98] :green_heart: Potentially fix hanging CI test hang Signed-off-by: Gaurav-Kumbhat --- .github/workflows/fms-testing-app.yml | 6 +++--- tests/models/test_ministral3.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/fms-testing-app.yml b/.github/workflows/fms-testing-app.yml index 76bfbbce4..3e87a90eb 100644 --- a/.github/workflows/fms-testing-app.yml +++ b/.github/workflows/fms-testing-app.yml @@ -38,7 +38,7 @@ jobs: # In case of a cache miss, but hit on a secondary key, this will update what's changed python -m pip install --upgrade pip pip install .[dev] - + # Enables the virtual env for following steps echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV @@ -48,10 +48,10 @@ jobs: # Install fms from the PR, all dependencies are already # installed in the virtual env pip install . - export OMP_NUM_THREADS=1 + export OMP_NUM_THREADS=2 export MKL_NUM_THREADS=1 export OPENBLAS_NUM_THREADS=1 - pytest -vv -rP tests/ + pytest -vv -rP tests/ - name: Save Virtualenv id: cache-venv-save diff --git a/tests/models/test_ministral3.py b/tests/models/test_ministral3.py index c5a9e727c..617f506e7 100644 --- a/tests/models/test_ministral3.py +++ b/tests/models/test_ministral3.py @@ -42,8 +42,8 @@ def config(self) -> ModelConfig: p_dropout=0.0, activation_fn="silu", emb_dim=16, - head_dim=128, - max_expected_seq_len=2048, + head_dim=64, + max_expected_seq_len=256, kvheads=2, norm_eps=1e-05, sliding_window=None, @@ -53,7 +53,7 @@ def config(self) -> ModelConfig: "beta_fast": 32.0, "beta_slow": 1.0, "factor": 1.0, - "original_max_position_embeddings": 1024, + "original_max_position_embeddings": 128, "mscale": 1.0, "mscale_all_dim": 1.0, "llama_4_scaling_beta": 0.1, @@ -65,14 +65,14 @@ def config(self) -> ModelConfig: _vision_config = PixtralVisionConfig( hidden_size=16, intermediate_size=64, - nlayers=8, + nlayers=2, nheads=8, nchannels=3, - image_size=280, + image_size=84, patch_size=14, hidden_act="silu", layer_norm_eps=1e-5, - rope_theta=10000.0, + rope_theta=1000.0, attention_dropout=0.0, fused_weights=True, ) @@ -99,16 +99,16 @@ def get_logits(f_out): return f_out[0] pixel_values = [ - [[torch.arange(0, 1, 1 / 280).tolist() for _ in range(280)] for _ in range(3)] + [[torch.arange(0, 1, 1 / 84).tolist() for _ in range(84)] for _ in range(3)] ] input_ids = torch.arange(380).unsqueeze(0) - pixel_values = torch.tensor(pixel_values) # [1, 3, 280, 280] + pixel_values = torch.tensor(pixel_values) # [1, 3, 84, 84] _get_signature_params = ["input_ids_or_embeds"] _get_signature_input_ids = input_ids _get_signature_optional_params = { "pixel_values": pixel_values, - "image_sizes": [(280, 280)], + "image_sizes": [(84, 84)], "last_n_tokens": 1, } From 9f203eb7e3513c54aba2046ec02160cd02382aac Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Fri, 3 Apr 2026 13:57:06 +0000 Subject: [PATCH 85/98] :bug: Fix duplicate model loading causing heavy memory usage in unit tests Signed-off-by: Gaurav-Kumbhat --- fms/models/ministral3.py | 6 +++--- tests/models/test_ministral3.py | 12 ++++++------ ....test_ministral3.TestMinistral3.test_model_output | 2 +- ..._ministral3.TestMinistral3.test_model_weight_keys | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/fms/models/ministral3.py b/fms/models/ministral3.py index b0ab40596..fe2b6b631 100644 --- a/fms/models/ministral3.py +++ b/fms/models/ministral3.py @@ -349,7 +349,7 @@ def forward( return preds -class Ministral3(Mistral3): +class Ministral3(Mistral3): # type: ignore[misc] """ Ministral3 multimodal model combining text and vision. @@ -364,10 +364,10 @@ def __init__( distributed_strategy: DistributedStrategy = NoOpStrategy, **kwargs, ): - super().__init__() + nn.Module.__init__(self) if config is not None: - self.config = config # type: ignore[assignment] + self.config: Ministral3Config = config # type: ignore[assignment] else: self.config = Ministral3Config() # type: ignore[assignment] diff --git a/tests/models/test_ministral3.py b/tests/models/test_ministral3.py index 617f506e7..2e5cc5387 100644 --- a/tests/models/test_ministral3.py +++ b/tests/models/test_ministral3.py @@ -43,7 +43,7 @@ def config(self) -> ModelConfig: activation_fn="silu", emb_dim=16, head_dim=64, - max_expected_seq_len=256, + max_expected_seq_len=1024, kvheads=2, norm_eps=1e-05, sliding_window=None, @@ -53,7 +53,7 @@ def config(self) -> ModelConfig: "beta_fast": 32.0, "beta_slow": 1.0, "factor": 1.0, - "original_max_position_embeddings": 128, + "original_max_position_embeddings": 512, "mscale": 1.0, "mscale_all_dim": 1.0, "llama_4_scaling_beta": 0.1, @@ -68,7 +68,7 @@ def config(self) -> ModelConfig: nlayers=2, nheads=8, nchannels=3, - image_size=84, + image_size=280, patch_size=14, hidden_act="silu", layer_norm_eps=1e-5, @@ -99,16 +99,16 @@ def get_logits(f_out): return f_out[0] pixel_values = [ - [[torch.arange(0, 1, 1 / 84).tolist() for _ in range(84)] for _ in range(3)] + [[torch.arange(0, 1, 1 / 280).tolist() for _ in range(280)] for _ in range(3)] ] input_ids = torch.arange(380).unsqueeze(0) - pixel_values = torch.tensor(pixel_values) # [1, 3, 84, 84] + pixel_values = torch.tensor(pixel_values) # [1, 3, 280, 280] _get_signature_params = ["input_ids_or_embeds"] _get_signature_input_ids = input_ids _get_signature_optional_params = { "pixel_values": pixel_values, - "image_sizes": [(84, 84)], + "image_sizes": [(280, 280)], "last_n_tokens": 1, } diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output index 8df428c51..6abcbe278 100644 --- a/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output +++ b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_output @@ -1 +1 @@ -0.05594002455472946,0.07009149342775345,0.05336397886276245,0.03607361018657684,0.04368116706609726,0.020182885229587555,0.049644190818071365,0.03600817546248436,0.01932992786169052,0.03297264873981476,0.05269749090075493,0.06371669471263885,0.05002269893884659,0.0729384496808052,0.047882407903671265,0.043097082525491714,0.0,0.043577760457992554,0.07921542227268219,0.056020431220531464,0.07686556875705719,0.09743601083755493,0.06762776523828506,0.04591165855526924,0.05145245045423508,0.03331845998764038,0.04739247262477875,0.05056333541870117,0.02307118847966194,0.03299802169203758,0.04437664523720741,0.05828511714935303,0.06406901031732559,0.04196309298276901,0.06661192327737808,0.021993786096572876,0.04157635569572449,0.04543519765138626,0.043296266347169876,0.0587100051343441,0.06388124078512192,0.02925882861018181,0.059809643775224686,0.03835967928171158,0.04687875136733055,0.06482845544815063,0.0491284504532814,0.06644788384437561,0.06589306145906448,0.07015224546194077,0.05971100553870201,0.05632736161351204,0.051187485456466675,0.0716148093342781,0.028318822383880615,0.049211032688617706,0.038459450006484985,0.061620816588401794,0.05900892987847328,0.06576603651046753,0.08574653416872025,0.035406630486249924,0.05052793771028519,0.023033753037452698,0.06948710232973099,0.07192317396402359,0.05369403213262558,0.05711257457733154,0.060141224414110184,0.046240903437137604,0.07929585874080658,0.04393337666988373,0.023667525500059128,0.05238502472639084,0.08191967755556107,0.05745270848274231,0.06330782175064087,0.0509466677904129,0.03697159141302109,0.028086066246032715,0.04733707010746002,0.02083737403154373,0.04745139181613922,0.06426891684532166,0.05620785057544708,0.04037903621792793,0.048980433493852615,0.07401958853006363,0.047023944556713104,0.042739540338516235,0.06628364324569702,0.03822977840900421,0.05312575399875641,0.03877400606870651,0.05186798423528671,0.07197622954845428,0.049045808613300323,0.036283183842897415,0.05679897591471672,0.07532411813735962,0.03873002529144287,0.03852667286992073,0.06784132868051529,0.048094164580106735,0.040452368557453156,0.037931911647319794,0.05681835114955902,0.04649338498711586,0.05872511491179466,0.038828182965517044,0.051229801028966904,0.0463840588927269,0.08150972425937653,0.04879182204604149,0.056264422833919525,0.04249201714992523,0.03347032517194748,0.030649255961179733,0.07214010506868362,0.03595946729183197,0.04968693479895592,0.03366164118051529,0.05083736032247543,0.04783494397997856,0.055685292929410934,0.05780669301748276,0.06438229978084564,0.06571371108293533,0.025331638753414154,0.06469275802373886,0.07349155098199844,0.04786045104265213,0.03422994911670685,0.04778201878070831,0.03741317614912987,0.07336126267910004,0.03802672028541565,0.06003715842962265,0.05296717956662178,0.03978915512561798,0.06245414912700653,0.03599925339221954,0.033691056072711945,0.0517766997218132,0.032938700169324875,0.036392614245414734,0.04499802365899086,0.029073677957057953,0.04011987894773483,0.027147367596626282,0.048559803515672684,0.04527686536312103,0.06866106390953064,0.06106039136648178,0.062874436378479,0.04782714694738388,0.06262529641389847,0.05039992928504944,0.05237973481416702,0.05835968255996704,0.07875573635101318,0.06677549332380295,0.05194057524204254,0.0299711674451828,0.03606399893760681,0.021126694977283478,0.05699087306857109,0.03827044367790222,0.09272082149982452,0.046609703451395035,0.02733200043439865,0.05647575110197067,0.04432094469666481,0.07837691903114319,0.026644282042980194,0.04794425517320633,0.07198896259069443,0.022146183997392654,0.025505665689706802,0.06209911033511162,0.04710085690021515,0.05572003871202469,0.006826337426900864,0.06468471139669418,0.061093684285879135,0.061877135187387466,0.06681922823190689,0.03002656251192093,0.08999139815568924,0.04203619807958603,0.0449824258685112,0.03518887981772423,0.04585397616028786,0.04450259357690811,0.04358106851577759,0.04277805984020233,0.03706223890185356,0.0652754083275795,0.06435525417327881,0.05307982489466667,0.044058907777071,0.0703946053981781,0.06323256343603134,0.0340811163187027,0.08040597289800644,0.07880276441574097,0.07025234401226044,0.05970700457692146,0.06700190156698227,0.05935616418719292,0.07070162892341614,0.06446608901023865,0.060955893248319626,0.03486822545528412,0.0901142954826355,0.056508421897888184,0.05714711174368858,0.04258004203438759,0.053357988595962524,0.04063870385289192,0.06193080544471741,0.0859304741024971,0.023655209690332413,0.04665641486644745,0.053071774542331696,0.05750152841210365,0.06564462929964066,0.045145370066165924,0.07286965101957321,0.03985186666250229,0.05366826057434082,0.06284862756729126,0.06383651494979858,0.055306412279605865,0.041489385068416595,0.027504824101924896,0.059377703815698624,0.05432402342557907,0.04612661898136139,0.04668726399540901,0.041694555431604385,0.05230220407247543,0.06847383826971054,0.04534347355365753,0.07110844552516937,0.050351470708847046,0.0326702818274498,0.06409081816673279,0.05779050290584564,0.03700583428144455,0.04338226094841957,0.051451265811920166,0.07567987591028214,0.05758321285247803,0.06591202318668365,0.05284284055233002,0.06210082396864891,0.037575751543045044,0.057571228593587875,0.05699337273836136,0.057372115552425385,0.015277087688446045,0.04997456073760986,0.016514025628566742,0.05247596651315689,0.04775907099246979,0.05467674881219864,0.05624669790267944,0.08368711173534393,0.047751057893037796,0.03108825907111168,0.07771854102611542,0.05060920864343643,0.05379769951105118,0.06517908722162247,0.07358068972826004,0.05581140145659447,0.06085636094212532,0.03242943808436394,0.06907801330089569,0.020396657288074493,0.04166264459490776,0.07063055783510208,0.05786340683698654,0.07071325927972794,0.0384451299905777,0.08084823936223984,0.03923499584197998,0.033806439489126205,0.02290938049554825,0.06376589089632034,0.04204033687710762,0.048066891729831696,0.03303638845682144,0.03966914862394333,0.038582973182201385,0.025842905044555664,0.06175125390291214,0.07047692686319351,0.06338955461978912,0.05593523383140564,0.06257142871618271,0.06045427918434143,0.0462874099612236,0.05155611038208008,0.03937206417322159,0.031223703175783157,0.05334903299808502,0.0914672315120697,0.026970822364091873,0.06529779732227325,0.030449271202087402,0.05416475236415863,0.004468072205781937,0.047933876514434814,0.05577709525823593,0.04961932450532913,0.04921341687440872,0.05563195049762726,0.03984059765934944,0.01825941726565361,0.06467757374048233,0.06567581743001938,0.05535609647631645,0.04216676577925682,0.0649462640285492,0.03579434007406235,0.014211688190698624,0.06829570978879929,0.04222029447555542,0.0759027823805809,0.03325241059064865,0.0515141487121582,0.06949587166309357,0.028515517711639404,0.07451120764017105,0.06473961472511292,0.044913679361343384,0.029872387647628784,0.07605072110891342,0.055929817259311676,0.021009724587202072,0.061686813831329346,0.051714230328798294,0.07307826727628708,0.07604605704545975,0.05582323297858238,0.05142980068922043,0.05531527101993561,0.036295413970947266,0.07948359847068787,0.03756716102361679,0.06539834290742874,0.04349612444639206,0.0744294822216034,0.05192862078547478,0.06575308740139008,0.03124430403113365,0.04962066188454628,0.02668260782957077,0.03821687772870064,0.039260007441043854,0.03346341848373413,0.04489395394921303,0.049637775868177414,0.048156771808862686,0.03926992416381836,0.050722889602184296,0.04133383929729462,0.07423140853643417,0.08342073857784271,0.05672822892665863,0.047992072999477386,0.06504455953836441,0.049927420914173126,0.035322993993759155,0.04873591288924217,0.06584584712982178,0.04754657298326492,0.04797336459159851,0.06895320117473602,0.06908737868070602,0.06236596405506134,0.045539986342191696 \ No newline at end of file +0.03632379323244095,0.035344384610652924,0.04109947383403778,0.04705234616994858,0.02966911904513836,0.027637649327516556,0.049139343202114105,0.04274829849600792,0.022070135921239853,0.035422153770923615,0.04549841582775116,0.06979531049728394,0.018350187689065933,0.05250324681401253,0.04394900053739548,0.025542259216308594,0.04211209714412689,0.038119807839393616,0.06207406893372536,0.042680226266384125,0.03132777288556099,0.030300583690404892,0.04647667706012726,0.0628783106803894,0.03639940917491913,0.015807166695594788,0.044461850076913834,0.06800466775894165,0.05346361920237541,0.05086711794137955,0.039510663598775864,0.04158898442983627,0.0438123382627964,0.03720914572477341,0.03388279303908348,0.04455924779176712,0.05358163267374039,0.04916258528828621,0.03697759658098221,0.018106043338775635,0.02980373054742813,0.043442510068416595,0.09326419979333878,0.03632631525397301,0.05088116228580475,0.05541339889168739,0.018951017409563065,0.03223343566060066,0.03842088580131531,0.02900439128279686,0.046620115637779236,0.01803099364042282,0.02091536670923233,0.05170918628573418,0.044846102595329285,0.058182407170534134,0.030878297984600067,0.034685246646404266,0.07139439135789871,0.06044542416930199,0.03549648076295853,0.06912861764431,0.009071849286556244,0.025459356606006622,0.04916194826364517,0.06315391510725021,0.057791080325841904,0.028736522421240807,0.03252141922712326,0.016005106270313263,0.044276200234889984,0.03603126108646393,0.0485466867685318,0.038767650723457336,0.01092495396733284,0.04759937897324562,0.03870445862412453,0.013458169996738434,0.04053083062171936,0.03947385400533676,0.017870020121335983,0.04844598099589348,0.0581950880587101,0.055418312549591064,0.025134429335594177,0.034038133919239044,0.05665779113769531,0.052461177110672,0.056680209934711456,0.02046525478363037,0.03636244684457779,0.03620140627026558,0.06002183258533478,0.03450213372707367,0.048353008925914764,0.05591093376278877,0.021953288465738297,0.01148252934217453,0.02776729315519333,0.0543694943189621,0.04897244647145271,0.08276629447937012,0.03943256661295891,0.024314522743225098,0.01969410851597786,0.06646724045276642,0.04952558875083923,0.06016235798597336,0.0372757613658905,0.03610376641154289,0.03856613487005234,0.03377523273229599,0.039066195487976074,0.0388367660343647,0.03446221351623535,0.05671267583966255,0.016908816993236542,0.04948379471898079,0.014955379068851471,0.06306394189596176,0.03582386299967766,0.0284174345433712,0.02190270647406578,0.020916521549224854,0.01623149961233139,0.026946045458316803,0.042997196316719055,0.05673808977007866,0.04490277171134949,0.051634132862091064,0.06655339896678925,0.024741385132074356,0.02554463967680931,0.025075264275074005,0.06419844925403595,0.05065048485994339,0.023269474506378174,0.0374397337436676,0.058870624750852585,0.050886452198028564,0.05288230627775192,0.011418655514717102,0.04351235181093216,0.05177979916334152,0.008410856127738953,0.05673884600400925,0.07906267046928406,0.02831476926803589,0.06606761366128922,0.06546656787395477,0.02292202040553093,0.0669541284441948,0.04900304228067398,0.04335181042551994,0.037119194865226746,0.05495976656675339,0.0428440198302269,0.02695220336318016,0.048280417919158936,0.061942972242832184,0.05410611629486084,0.025197669863700867,0.0343097485601902,0.0485098734498024,0.04501078650355339,0.04809942841529846,0.03756085783243179,0.0800384134054184,0.020274784415960312,0.028009338304400444,0.030865177512168884,0.03936278074979782,0.036193519830703735,0.027260806411504745,0.05933503061532974,0.030368126928806305,0.05535055324435234,0.04547633230686188,0.04060015454888344,0.0550357922911644,0.04980041831731796,0.024750161916017532,0.03155101090669632,0.04797712713479996,0.03040524572134018,0.0588507354259491,0.0661880224943161,0.05012280121445656,0.04061368480324745,0.029085181653499603,0.018001634627580643,0.06061973050236702,0.0545937642455101,0.02107742428779602,0.03403978794813156,0.03869195654988289,0.039192765951156616,0.026884475722908974,0.04369514435529709,0.014917895197868347,0.0412215031683445,0.04482920840382576,0.039483435451984406,0.021520648151636124,0.015865735709667206,0.014385655522346497,0.04757208377122879,0.05576566234230995,0.06021423265337944,0.07017016410827637,0.019559122622013092,0.03287498652935028,0.06662260740995407,0.029564443975687027,0.049461666494607925,0.048533495515584946,0.03161636367440224,0.020749472081661224,0.018325716257095337,0.05265233665704727,0.04959123209118843,0.07803913950920105,0.09322942793369293,0.03470584377646446,0.059826046228408813,0.014522898942232132,0.04684620350599289,0.028367342427372932,0.024929732084274292,0.05427218973636627,0.055192891508340836,0.06993880867958069,0.05736250802874565,0.039255522191524506,0.03649609535932541,0.03716912493109703,0.05207357183098793,0.05704871192574501,0.047411393374204636,0.03350989520549774,0.029189160093665123,0.04978639632463455,0.04152081534266472,0.04540292173624039,0.03291565924882889,0.0168733149766922,0.05194355547428131,0.03737252950668335,0.032327763736248016,0.006422523409128189,0.044815316796302795,0.061221398413181305,0.05886367708444595,0.048430051654577255,0.036045532673597336,0.086647167801857,0.034529656171798706,0.026246577501296997,0.021082334220409393,0.03746473416686058,0.05042137950658798,0.0453735776245594,0.036198459565639496,0.04994291067123413,0.055014483630657196,0.03810656815767288,0.05709894746541977,0.03851801156997681,0.027759354561567307,0.03249160200357437,0.05834164470434189,0.02921876683831215,0.04014115780591965,0.029058510437607765,0.0327381007373333,0.029454249888658524,0.02221957966685295,0.04515340179204941,0.05466482788324356,0.02826894074678421,0.0754048153758049,0.03590726479887962,0.04223071038722992,0.03235270828008652,0.05102522298693657,0.031965434551239014,0.041603680700063705,0.05431303754448891,0.05392816290259361,0.059461794793605804,0.04279337078332901,0.07470382750034332,0.04824378341436386,0.03597980737686157,0.06232500076293945,0.02643832564353943,0.03588464856147766,0.05403822287917137,0.029046498239040375,0.03683880716562271,0.03383907675743103,0.01864594966173172,0.028193499892950058,0.029390959069132805,0.04838241636753082,0.05014704540371895,0.046300433576107025,0.06215522810816765,0.039700545370578766,0.0333699993789196,0.030975978821516037,0.04532227665185928,0.0524325966835022,0.04479862004518509,0.020864754915237427,0.041993238031864166,0.014496278017759323,0.02587035670876503,0.03935857117176056,0.006578847765922546,0.0600462406873703,0.04416990280151367,0.04120875895023346,0.07324066013097763,0.0581766739487648,0.04910421743988991,0.06558243185281754,0.016860563308000565,0.04886248707771301,0.05159497261047363,0.06468531489372253,0.051031693816185,0.018195927143096924,0.05852110683917999,0.028646713122725487,0.03622948378324509,0.027462899684906006,0.033439330756664276,0.02357686683535576,0.04468560963869095,0.04253336042165756,0.028357256203889847,0.05817369744181633,0.04523523896932602,0.009467549622058868,0.06549543142318726,0.051435090601444244,0.00678020715713501,0.04932556301355362,0.04272378236055374,0.03490716591477394,0.06707193702459335,0.07073091715574265,0.029599450528621674,0.03390365466475487,0.04892093688249588,0.03966914117336273,0.032240383327007294,0.03547070920467377,0.034188807010650635,0.032814037054777145,0.035747505724430084,0.03739427030086517,0.07473348081111908,0.04302956908941269,0.05853409692645073,0.05293271690607071,0.03494928777217865,0.025019578635692596,0.043228499591350555,0.04090597853064537,0.034271351993083954,0.03638218715786934,0.021678276360034943,0.01926206797361374,0.025183215737342834,0.011848703026771545,0.06673452258110046,0.0,0.04093881696462631,0.032984331250190735,0.06162877753376961,0.036073438823223114,0.041489459574222565 \ No newline at end of file diff --git a/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_weight_keys b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_weight_keys index 986cbf91e..b6ed8a47a 100644 --- a/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_weight_keys +++ b/tests/resources/expectations/models.test_ministral3.TestMinistral3.test_model_weight_keys @@ -1 +1 @@ -language_model.base_model.dec_norm.weight,language_model.base_model.embedding.weight,language_model.base_model.layers.0.attn.dense.weight,language_model.base_model.layers.0.attn.in_proj.qkv_fused.weight,language_model.base_model.layers.0.ff_ln.weight,language_model.base_model.layers.0.ff_sub_layer.w2.weight,language_model.base_model.layers.0.ff_sub_layer.wg1_fused.weight,language_model.base_model.layers.0.ln.weight,language_model.base_model.layers.1.attn.dense.weight,language_model.base_model.layers.1.attn.in_proj.qkv_fused.weight,language_model.base_model.layers.1.ff_ln.weight,language_model.base_model.layers.1.ff_sub_layer.w2.weight,language_model.base_model.layers.1.ff_sub_layer.wg1_fused.weight,language_model.base_model.layers.1.ln.weight,language_model.head.weight,multi_modal_projector.linear_1.weight,multi_modal_projector.linear_2.weight,multi_modal_projector.norm.weight,multi_modal_projector.patch_merger.merging_layer.weight,vision_tower.ln_pre.weight,vision_tower.patch_conv.weight,vision_tower.transformer.layers.0.attention_norm.weight,vision_tower.transformer.layers.0.attn.dense.weight,vision_tower.transformer.layers.0.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.0.ff_sub_layer.w2.weight,vision_tower.transformer.layers.0.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.0.ffn_norm.weight,vision_tower.transformer.layers.1.attention_norm.weight,vision_tower.transformer.layers.1.attn.dense.weight,vision_tower.transformer.layers.1.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.1.ff_sub_layer.w2.weight,vision_tower.transformer.layers.1.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.1.ffn_norm.weight,vision_tower.transformer.layers.2.attention_norm.weight,vision_tower.transformer.layers.2.attn.dense.weight,vision_tower.transformer.layers.2.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.2.ff_sub_layer.w2.weight,vision_tower.transformer.layers.2.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.2.ffn_norm.weight,vision_tower.transformer.layers.3.attention_norm.weight,vision_tower.transformer.layers.3.attn.dense.weight,vision_tower.transformer.layers.3.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.3.ff_sub_layer.w2.weight,vision_tower.transformer.layers.3.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.3.ffn_norm.weight,vision_tower.transformer.layers.4.attention_norm.weight,vision_tower.transformer.layers.4.attn.dense.weight,vision_tower.transformer.layers.4.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.4.ff_sub_layer.w2.weight,vision_tower.transformer.layers.4.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.4.ffn_norm.weight,vision_tower.transformer.layers.5.attention_norm.weight,vision_tower.transformer.layers.5.attn.dense.weight,vision_tower.transformer.layers.5.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.5.ff_sub_layer.w2.weight,vision_tower.transformer.layers.5.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.5.ffn_norm.weight,vision_tower.transformer.layers.6.attention_norm.weight,vision_tower.transformer.layers.6.attn.dense.weight,vision_tower.transformer.layers.6.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.6.ff_sub_layer.w2.weight,vision_tower.transformer.layers.6.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.6.ffn_norm.weight,vision_tower.transformer.layers.7.attention_norm.weight,vision_tower.transformer.layers.7.attn.dense.weight,vision_tower.transformer.layers.7.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.7.ff_sub_layer.w2.weight,vision_tower.transformer.layers.7.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.7.ffn_norm.weight \ No newline at end of file +language_model.base_model.dec_norm.weight,language_model.base_model.embedding.weight,language_model.base_model.layers.0.attn.dense.weight,language_model.base_model.layers.0.attn.in_proj.qkv_fused.weight,language_model.base_model.layers.0.ff_ln.weight,language_model.base_model.layers.0.ff_sub_layer.w2.weight,language_model.base_model.layers.0.ff_sub_layer.wg1_fused.weight,language_model.base_model.layers.0.ln.weight,language_model.base_model.layers.1.attn.dense.weight,language_model.base_model.layers.1.attn.in_proj.qkv_fused.weight,language_model.base_model.layers.1.ff_ln.weight,language_model.base_model.layers.1.ff_sub_layer.w2.weight,language_model.base_model.layers.1.ff_sub_layer.wg1_fused.weight,language_model.base_model.layers.1.ln.weight,language_model.head.weight,multi_modal_projector.linear_1.weight,multi_modal_projector.linear_2.weight,multi_modal_projector.norm.weight,multi_modal_projector.patch_merger.merging_layer.weight,vision_tower.ln_pre.weight,vision_tower.patch_conv.weight,vision_tower.transformer.layers.0.attention_norm.weight,vision_tower.transformer.layers.0.attn.dense.weight,vision_tower.transformer.layers.0.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.0.ff_sub_layer.w2.weight,vision_tower.transformer.layers.0.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.0.ffn_norm.weight,vision_tower.transformer.layers.1.attention_norm.weight,vision_tower.transformer.layers.1.attn.dense.weight,vision_tower.transformer.layers.1.attn.in_proj.qkv_fused.weight,vision_tower.transformer.layers.1.ff_sub_layer.w2.weight,vision_tower.transformer.layers.1.ff_sub_layer.wg1_fused.weight,vision_tower.transformer.layers.1.ffn_norm.weight \ No newline at end of file From 245f345943d4ed136b5c2f167ef26ae76fa4ba7a Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 6 Apr 2026 11:46:07 -0300 Subject: [PATCH 86/98] Adds last hidden state comparison Signed-off-by: Flavia Beo --- tests/models/hf_equivalence/test_qwen3.py | 81 ++++++++++++++++++----- 1 file changed, 63 insertions(+), 18 deletions(-) diff --git a/tests/models/hf_equivalence/test_qwen3.py b/tests/models/hf_equivalence/test_qwen3.py index 24b2b339c..06f90379e 100644 --- a/tests/models/hf_equivalence/test_qwen3.py +++ b/tests/models/hf_equivalence/test_qwen3.py @@ -31,13 +31,14 @@ def _get_hf_model_output(model_path, inputs): outputs = model(**inputs) # The model uses the last token's representation as the embedding embeddings = outputs.last_hidden_state[:, -1, :] + # Normalize embeddings for cosine similarity embeddings = F.normalize(embeddings, p=2, dim=1) query_embedding = embeddings[0] doc_embeddings = embeddings[1:] - return query_embedding, doc_embeddings + return outputs, query_embedding, doc_embeddings def _get_fms_model_output(model_path, inputs): @@ -55,26 +56,32 @@ def _get_fms_model_output(model_path, inputs): # Get input_ids from the inputs dict input_ids = inputs["input_ids"].to(device) - # Prepare inputs for FMS - this will create appropriate mask and position_ids - input_ids_padded, padding_kwargs = pad_input_ids(input_ids, min_pad_length=0) - input_ids_padded = input_ids_padded.to(device) + # Create position_ids - should be sequential positions (0, 1, 2, ...) + batch_size, seq_len = input_ids.shape + position_ids = ( + torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + ) + + # Don't pass a mask - let the model handle it internally + # The model will use its default causal mask behavior with torch.no_grad(): # Get embeddings from base model (before LM head) - embeddings = model( - input_ids_padded, - mask=padding_kwargs["mask"].to(device), - position_ids=padding_kwargs["position_ids"].to(device), + outputs, _ = model.base_model( + input_ids, + mask=None, + position_ids=position_ids, ) # The model uses the last token's representation as the embedding - embeddings = embeddings[:, -1, :] + embeddings = outputs[:, -1, :] + # Normalize embeddings for cosine similarity embeddings = F.normalize(embeddings, p=2, dim=1) query_embedding = embeddings[0] doc_embeddings = embeddings[1:] - return query_embedding, doc_embeddings + return outputs, query_embedding, doc_embeddings @pytest.mark.slow @@ -110,21 +117,40 @@ def test_qwen3_embedding_0_6b_equivalence(): ) # Get outputs from both models - hf_query_embedding, hf_doc_embeddings = _get_hf_model_output(model_path, inputs) - fms_query_embedding, fms_doc_embeddings = _get_fms_model_output(model_path, inputs) + hf_outputs, hf_query_embedding, hf_doc_embeddings = _get_hf_model_output( + model_path, inputs + ) + fms_outputs, fms_query_embedding, fms_doc_embeddings = _get_fms_model_output( + model_path, inputs + ) hf_scores = hf_query_embedding @ hf_doc_embeddings.T fms_scores = fms_query_embedding @ fms_doc_embeddings.T # First sentence contains the awnser to the query. # It's score should be always the highest. + assert hf_scores[0] > hf_scores[1] assert fms_scores[0] > fms_scores[1] - assert fms_scores[0] > 0.7 - assert hf_scores[0] > 0.7 assert hf_scores[0] > fms_scores[1] assert fms_scores[0] > hf_scores[1] + # Compare the actual hidden states (extract tensor from HF output) + # Only compare non-padded tokens using the attention mask + attention_mask = inputs["attention_mask"] + + for i in range(len(input_texts)): + # Get the valid (non-padded) length for this sequence + valid_length = attention_mask[i].sum().item() + + # Only compare the valid tokens (exclude padding) + hf_hidden = hf_outputs.last_hidden_state[i, :valid_length] + fms_hidden = fms_outputs[i, :valid_length] + + assert torch.allclose(hf_hidden, fms_hidden, atol=1e-2, rtol=1e-2), ( + f"Hidden states don't match for sequence {i}" + ) + @pytest.mark.slow def test_qwen3_embedding_4b_equivalence(): @@ -159,8 +185,12 @@ def test_qwen3_embedding_4b_equivalence(): ) # Get outputs from both models - hf_query_embedding, hf_doc_embeddings = _get_hf_model_output(model_path, inputs) - fms_query_embedding, fms_doc_embeddings = _get_fms_model_output(model_path, inputs) + hf_outputs, hf_query_embedding, hf_doc_embeddings = _get_hf_model_output( + model_path, inputs + ) + fms_outputs, fms_query_embedding, fms_doc_embeddings = _get_fms_model_output( + model_path, inputs + ) hf_scores = hf_query_embedding @ hf_doc_embeddings.T fms_scores = fms_query_embedding @ fms_doc_embeddings.T @@ -169,11 +199,25 @@ def test_qwen3_embedding_4b_equivalence(): # It's score should be always the highest. assert hf_scores[0] > hf_scores[1] assert fms_scores[0] > fms_scores[1] - assert fms_scores[0] > 0.7 - assert hf_scores[0] > 0.7 assert hf_scores[0] > fms_scores[1] assert fms_scores[0] > hf_scores[1] + # Compare the actual hidden states (extract tensor from HF output) + # Only compare non-padded tokens using the attention mask + attention_mask = inputs["attention_mask"] + + for i in range(len(input_texts)): + # Get the valid (non-padded) length for this sequence + valid_length = attention_mask[i].sum().item() + + # Only compare the valid tokens (exclude padding) + hf_hidden = hf_outputs.last_hidden_state[i, :valid_length] + fms_hidden = fms_outputs[i, :valid_length] + + assert torch.allclose(hf_hidden, fms_hidden, atol=1e-2, rtol=1e-2), ( + f"Hidden states don't match for sequence {i}" + ) + def test_qwen3_forward_pass(): """ @@ -336,3 +380,4 @@ def test_qwen3_with_cache(): test_qwen3_with_cache() test_qwen3_parameter_count() test_qwen3_embedding_0_6b_equivalence() + test_qwen3_embedding_4b_equivalence() From 2e39477e60182ebe5a9fff02c670eb3f2d8737f7 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 6 Apr 2026 11:52:03 -0300 Subject: [PATCH 87/98] Fix lint Signed-off-by: Flavia Beo --- fms/models/hf/config_utils/param_builders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms/models/hf/config_utils/param_builders.py b/fms/models/hf/config_utils/param_builders.py index 991267e5e..633848411 100644 --- a/fms/models/hf/config_utils/param_builders.py +++ b/fms/models/hf/config_utils/param_builders.py @@ -393,6 +393,7 @@ def build_mistral3_params(config: PretrainedConfig) -> dict: config_params["vision_config"] = PixtralVisionConfig(**vision_config_params) return config_params + def build_ministral3_params(config: PretrainedConfig) -> dict: """Param builder for ministral3 mapping Mistral3ForConditionalGeneration to FMS.""" @@ -449,6 +450,7 @@ def build_ministral3_text_params(config: PretrainedConfig) -> dict: # config, config_params, inner_dim=config.intermediate_size ) + def build_qwen3_embeddings_params(config: PretrainedConfig) -> dict: """Param builder for mapping Qwen3ForCausalLM to FMS.""" config_params = { From 32f6e63a212af4efa4b3f9a1f4c0a12ae465d799 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Mon, 6 Apr 2026 14:19:30 -0300 Subject: [PATCH 88/98] Remove duplicated imports Signed-off-by: Flavia Beo --- fms/models/hf/modeling_hf_adapter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fms/models/hf/modeling_hf_adapter.py b/fms/models/hf/modeling_hf_adapter.py index f573e85cf..b0bc65a65 100644 --- a/fms/models/hf/modeling_hf_adapter.py +++ b/fms/models/hf/modeling_hf_adapter.py @@ -2,7 +2,6 @@ import copy import os from packaging.version import Version -from packaging.version import Version from typing import Callable, Dict, Optional, Tuple, Union import torch @@ -10,7 +9,6 @@ from torch.nn.modules.loss import _Loss from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin from transformers import __version__ as tf_version -from transformers import __version__ as tf_version from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -24,6 +22,7 @@ else: from transformers.modeling_utils import no_init_weights + from fms.models.hf.utils import mask_2d_to_3d, mask_2d_to_3d_bidirectional From 691c6f8c0a05db76f58b3aa1e1e8d669f0f83dbf Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Wed, 8 Apr 2026 18:16:52 -0300 Subject: [PATCH 89/98] Requested changes from the review Signed-off-by: Flavia Beo --- fms/models/hf/config_utils/param_builders.py | 3 +- fms/models/qwen3.py | 36 ++------------------ 2 files changed, 5 insertions(+), 34 deletions(-) diff --git a/fms/models/hf/config_utils/param_builders.py b/fms/models/hf/config_utils/param_builders.py index 633848411..003a65349 100644 --- a/fms/models/hf/config_utils/param_builders.py +++ b/fms/models/hf/config_utils/param_builders.py @@ -453,6 +453,7 @@ def build_ministral3_text_params(config: PretrainedConfig) -> dict: # def build_qwen3_embeddings_params(config: PretrainedConfig) -> dict: """Param builder for mapping Qwen3ForCausalLM to FMS.""" + rope_theta, _ = reverse_rope_param_lookup(config) config_params = { "norm_eps": config.rms_norm_eps, "bos_token_id": config.bos_token_id, @@ -463,7 +464,7 @@ def build_qwen3_embeddings_params(config: PretrainedConfig) -> dict: "max_expected_seq_len": config.max_position_embeddings, "kvheads": config.num_key_value_heads, "p_dropout": config.attention_dropout, - "rope_base": config.rope_theta, + "rope_base": rope_theta, "head_dim": getattr( config, "head_dim", config.hidden_size // config.num_attention_heads ), diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index adbd08160..a4808358c 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -194,21 +194,9 @@ def __init__( self.config = self.config.updated(**kwargs) self.distributed_strategy = distributed_strategy - embedding = nn.Embedding( + self.embedding = nn.Embedding( self.config.src_vocab_size, self.config.emb_dim, self.config.pad_id ) - # TP does not work with tied weights - if ( - not isinstance(self.distributed_strategy, TensorParallelStrategy) - or not self.config.tie_heads - ): - self.embedding = self.distributed_strategy.distribute_module(embedding) - else: - logger.warning( - "You're using TP on a model with tied weights between head and embedding. " - "The tied weights won't be sharded, which can result in unexpected OOMs." - ) - self.embedding = embedding self.rot_emb = RotaryEmbedding( dim=self.config.head_dim, @@ -397,17 +385,9 @@ def __init__( self.distributed_strategy = distributed_strategy self.base_model = Qwen3Headless(self.config, self.distributed_strategy) - head = LinearClassificationHead( + self.head = LinearClassificationHead( self.config.emb_dim, self.config.src_vocab_size, bias=False ) - # TP does not work with tied weights - if ( - not isinstance(self.distributed_strategy, TensorParallelStrategy) - or not self.config.tie_heads - ): - self.head = self.distributed_strategy.distribute_module(head) - else: - self.head = head def get_config(self) -> Qwen3Config: return self.config @@ -591,14 +571,8 @@ def _hf_to_fms_rope( # Therefore, to make FMS produce the correct order of outputs when # loading from an HF checkpoint, we need to undo the transformation # that HF does from the original Meta weights - is_gptq_2d_qparam = "gptq" in linear_type_str and param.dim() == 2 if bool(trans_required_pattern.match(name)) and param.numel() > 1: temp = param - if is_gptq_2d_qparam: - # GPTQ qweights are [in_feat, out_feat] (unlike usual [out_feat, in_feat]) - # and are fully transposed before & after process. - # GPTQ scales and qzeros are also transposed accordingly - temp = temp.transpose(0, 1) # num_heads is used in the transformation required for hf->fms # can't be precomputed because q and k might have different num_heads assert model_config is not None and model_config.head_dim is not None @@ -611,8 +585,6 @@ def _hf_to_fms_rope( temp_view = temp.view(num_heads, 2, -1) temp = temp_view.transpose(1, 2).reshape(*temp.size()) - if is_gptq_2d_qparam: - temp = temp.transpose(0, 1) new_sd[name] = temp else: new_sd[name] = param @@ -621,9 +593,7 @@ def _hf_to_fms_rope( def _get_rope_params(linear_type: str) -> list[str]: - if "gptq" in linear_type: - return ["qweight", "scales", "qzeros", "bias"] - elif "int8" in linear_type: + if "int8" in linear_type: # quantize_weight is fms-model-optimizer identifier of weight clip values return ["weight", "bias", "quantize_weight"] elif "fp8" in linear_type: From 346bc0d90bdd030a915806fa1bd6933e00202586 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Wed, 8 Apr 2026 18:18:10 -0300 Subject: [PATCH 90/98] Moves compiled example to folder Signed-off-by: Flavia Beo --- scripts/{ => examples}/qwen3_embedding_compiled_example.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{ => examples}/qwen3_embedding_compiled_example.py (100%) diff --git a/scripts/qwen3_embedding_compiled_example.py b/scripts/examples/qwen3_embedding_compiled_example.py similarity index 100% rename from scripts/qwen3_embedding_compiled_example.py rename to scripts/examples/qwen3_embedding_compiled_example.py From 3d01fdbcff48a74dc1f93a2d31aeb210b5efc87d Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Wed, 8 Apr 2026 18:24:55 -0300 Subject: [PATCH 91/98] Adds comment to explain the norm_eps Signed-off-by: Flavia Beo --- fms/models/qwen3.py | 1 - fms/modules/attention.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index a4808358c..24c01c675 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -12,7 +12,6 @@ from fms.distributed.strategy import ( DistributedStrategy, NoOpStrategy, - TensorParallelStrategy, ) from fms.modules.attention import ( AttentionKwargs, diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 79df3a983..11ce06d11 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -537,6 +537,7 @@ def __init__( linear_config=linear_config, ) + # Apply normalization if enabled - this is passed as attention kwarg if norm_eps: self.norm = True self.q_norm = LayerNormParameterized( @@ -584,7 +585,7 @@ def forward( keys = self.key(k) # b x klen x (kvheads * head_dim) values = self.value(v) # b x vlen x (kvheads * head_dim) - # Apply normalization if enabled + # Apply normalization if enabled - this is passed as attention kwarg # Normalization should be applied per-head, so we need to reshape first if self.norm: batch_size, q_len, _ = queries.shape From fd51aabed1c06d26464b3dad2f6700718fe1cfdd Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 9 Apr 2026 17:01:08 -0300 Subject: [PATCH 92/98] Suggestion for the per head norm to be used only when models actually have it Signed-off-by: Flavia Beo --- fms/models/qwen3.py | 1 + fms/modules/attention.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index 24c01c675..fcaf50311 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -119,6 +119,7 @@ def __init__(self, config: Qwen3Config, rotary_emb: RotaryEmbedding): position_encoder=rotary_emb, fused=self.config.fused_weights, linear_config=self.config.linear_config, + apply_norm_per_head=True, norm_eps=self.config.norm_eps, head_dim=self.config.head_dim, ) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 11ce06d11..8c755013f 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -498,6 +498,7 @@ def __init__( emb_v_per_head: int, use_bias: bool, linear_config: Optional[Mapping[str, Any]] = None, + apply_norm_per_head: Optional[bool] = None, norm_eps: Optional[float] = None, head_dim: Optional[int] = None, *args, @@ -517,6 +518,7 @@ def __init__( self.head_dim = head_dim or emb_kq_per_head self.norm_eps = norm_eps or 1e-5 + self.apply_norm_per_head = apply_norm_per_head self.query = get_linear( self.emb_dim, @@ -538,7 +540,7 @@ def __init__( ) # Apply normalization if enabled - this is passed as attention kwarg - if norm_eps: + if norm_eps and self.apply_norm_per_head: self.norm = True self.q_norm = LayerNormParameterized( head_dim, @@ -587,7 +589,7 @@ def forward( # Apply normalization if enabled - this is passed as attention kwarg # Normalization should be applied per-head, so we need to reshape first - if self.norm: + if self.norm and self.apply_norm_per_head: batch_size, q_len, _ = queries.shape k_len = keys.shape[1] @@ -631,6 +633,7 @@ def __init__( emb_v_per_head: int, use_bias: bool, linear_config: Optional[Mapping[str, Any]] = None, + apply_norm_per_head: Optional[bool] = None, norm_eps: Optional[float] = None, head_dim: Optional[int] = None, *args, @@ -735,6 +738,14 @@ class MultiHeadAttention(nn.Module): scale_factor : float | None Optional scaling factor applied to the attention logits. If None, a default scaling based on the embedding dimension may be used. + apply_norm_per_head : bool | None + If True, applies normalization per attention head. If None, normalization is not applied. + norm_eps : float | None + Epsilon value for normalization to ensure numerical stability. Only used when + apply_norm_per_head is True. + head_dim : int | None + Dimensionality of each attention head. If None, it will be computed based on other + parameters. has_sinks : bool If True, enables the use of sink tokens, which are represented by learnable parameters (one per attention head). Sink tokens can be used to aggregate information across tokens @@ -755,6 +766,7 @@ def __init__( fused: bool = True, linear_config: Optional[Mapping[str, Any]] = None, scale_factor: Optional[float] = None, + apply_norm_per_head: Optional[bool] = None, norm_eps: Optional[float] = None, head_dim: Optional[int] = None, has_sinks: bool = False, @@ -770,6 +782,7 @@ def __init__( self.fused = fused self.linear_config = linear_config self.scale_factor = scale_factor + self.apply_norm_per_head = apply_norm_per_head self.norm_eps = norm_eps self.head_dim = head_dim self.has_sinks = has_sinks @@ -782,6 +795,7 @@ def __init__( self.emb_v_per_head, self.use_bias, linear_config=linear_config, + apply_norm_per_head=self.apply_norm_per_head, norm_eps=self.norm_eps, head_dim=self.head_dim, ) From ceff78a8420f00b0d674a2d2ab6af7001bc6a192 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Fri, 10 Apr 2026 10:26:25 -0300 Subject: [PATCH 93/98] Removes uneeded var Signed-off-by: Flavia Beo --- fms/modules/attention.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fms/modules/attention.py b/fms/modules/attention.py index 8c755013f..330f9b500 100644 --- a/fms/modules/attention.py +++ b/fms/modules/attention.py @@ -541,7 +541,6 @@ def __init__( # Apply normalization if enabled - this is passed as attention kwarg if norm_eps and self.apply_norm_per_head: - self.norm = True self.q_norm = LayerNormParameterized( head_dim, elementwise_scale=True, @@ -558,8 +557,6 @@ def __init__( eps=norm_eps, use_high_precision_pow=True, ) - else: - self.norm = False def reset_parameters(self): for m in self.modules(): @@ -589,7 +586,7 @@ def forward( # Apply normalization if enabled - this is passed as attention kwarg # Normalization should be applied per-head, so we need to reshape first - if self.norm and self.apply_norm_per_head: + if self.apply_norm_per_head: batch_size, q_len, _ = queries.shape k_len = keys.shape[1] From cb3b56931fc2581963cee049b53ff1e0688e30b0 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Fri, 10 Apr 2026 11:08:39 -0300 Subject: [PATCH 94/98] Fix Qwen models weights Signed-off-by: Flavia Beo --- fms/models/qwen3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms/models/qwen3.py b/fms/models/qwen3.py index fcaf50311..60720a92f 100644 --- a/fms/models/qwen3.py +++ b/fms/models/qwen3.py @@ -561,7 +561,7 @@ def _hf_to_fms_rope( ) # hf -> fms requires a transpose operation for the query and key - # weight and bias parameters for Llama models + # weight and bias parameters for Qwen models # This transpose is due to the different implementation of RoPE in # HF and FMS. While FMS follows the original RoPE paper # (https://arxiv.org/abs/2104.09864), HF has its own implementation From 07b0eecc5270bcba3d4aef669debf8be0ac5ebe1 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 20 Apr 2026 18:30:20 -0600 Subject: [PATCH 95/98] :zap: remove extra kv cache validation (#524) Signed-off-by: Joe Runde --- fms/utils/spyre/paged.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fms/utils/spyre/paged.py b/fms/utils/spyre/paged.py index c883a0585..bedb229d5 100644 --- a/fms/utils/spyre/paged.py +++ b/fms/utils/spyre/paged.py @@ -9,6 +9,7 @@ import torch from torch.library import custom_op +from torch import compiler import torch.nn.functional as F @@ -213,6 +214,11 @@ def __spyre_paged_validate_attn_kwargs_op( if left_padded_prompt_mask is not None: assert input_ids.shape[0] == left_padded_prompt_mask.shape[0] + if compiler.is_compiling(): + # When compiling for spyre, don't run expensive KV cache validation. + # The input `past_key_value_states` aren't actually sent to the device when running on spyre + return + if past_key_value_states is not None: for k, v in past_key_value_states: # assert that for each layer, k and v have the same number of blocks From 396328a697ac024a5c6b6207a74736512082e7f8 Mon Sep 17 00:00:00 2001 From: Gaurav-Kumbhat Date: Wed, 29 Apr 2026 11:05:21 -0500 Subject: [PATCH 96/98] :memo: Add comment for parity with pad_token_id Signed-off-by: Gaurav-Kumbhat --- fms/utils/generation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms/utils/generation.py b/fms/utils/generation.py index ab7bd60a5..697629071 100644 --- a/fms/utils/generation.py +++ b/fms/utils/generation.py @@ -222,6 +222,8 @@ def generate( contiguous_cache: ensures the cache is contiguous in device memory eos_token_id: the optional token id representing the end of sequence pad_token_id: the optional token id representing the pad token + (This value doesn't affect the generation function and + is present for parity with paged generation) timing: whether to measure timings: "per-token" for each token generation time, "e2e" for full generation loop. Both options make `generate` return a tuple with the following information: From cff957af2007ca2a99344d8e2dd0a0a720b21355 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 29 Apr 2026 11:55:56 -0700 Subject: [PATCH 97/98] feat(trust-remote-code): add trust_remote_code parameter support to model inference Signed-off-by: Prashant Gupta --- fms/models/__init__.py | 6 ++++++ fms/models/hf/utils.py | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/fms/models/__init__.py b/fms/models/__init__.py index 9f00d6cd9..fea71f48b 100644 --- a/fms/models/__init__.py +++ b/fms/models/__init__.py @@ -112,9 +112,15 @@ def __maybe_infer_model_variant( logger.info(f"inferring model configuration from {model_path_or_variant}") + # Only pass kwargs that are accepted by infer_model_configuration + infer_kwargs = {} + if "trust_remote_code" in kwargs: + infer_kwargs["trust_remote_code"] = kwargs["trust_remote_code"] + extra_kwargs = infer_model_configuration( model_path_or_variant, download_weights=is_hf_pretrained and variant is not None, # type: ignore[arg-type] + **infer_kwargs, ) architecture = extra_kwargs.pop("architecture") variant = extra_kwargs.pop("variant") diff --git a/fms/models/hf/utils.py b/fms/models/hf/utils.py index afe9fddbc..c9415947d 100644 --- a/fms/models/hf/utils.py +++ b/fms/models/hf/utils.py @@ -120,6 +120,7 @@ def mask_2d_to_3d_bidirectional( def infer_model_configuration( model_id_or_path: str | os.PathLike, download_weights: bool = True, + trust_remote_code: bool = False, ) -> Dict[str, Any]: # if the path does not exist, download it from huggingface and get the local path if not os.path.exists(model_id_or_path): @@ -165,7 +166,7 @@ def infer_model_configuration( else: model_path = str(model_id_or_path) - config = AutoConfig.from_pretrained(model_path) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) ## HACK to map Mistral3ForConditionalGeneration to Ministral3 class successfully From da7839a1cd64d98d93dda3899321d7a2791a2576 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Mon, 4 May 2026 10:54:08 -0700 Subject: [PATCH 98/98] Simplify trust_remote_code handling in model inference Refactor to directly use kwargs for trust_remote_code. Signed-off-by: Prashant Gupta --- fms/models/__init__.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/fms/models/__init__.py b/fms/models/__init__.py index fea71f48b..1604ddbce 100644 --- a/fms/models/__init__.py +++ b/fms/models/__init__.py @@ -112,15 +112,10 @@ def __maybe_infer_model_variant( logger.info(f"inferring model configuration from {model_path_or_variant}") - # Only pass kwargs that are accepted by infer_model_configuration - infer_kwargs = {} - if "trust_remote_code" in kwargs: - infer_kwargs["trust_remote_code"] = kwargs["trust_remote_code"] - extra_kwargs = infer_model_configuration( model_path_or_variant, download_weights=is_hf_pretrained and variant is not None, # type: ignore[arg-type] - **infer_kwargs, + trust_remote_code=kwargs.get("trust_remote_code", False), ) architecture = extra_kwargs.pop("architecture") variant = extra_kwargs.pop("variant")