Add Ministral 3B and 8B text-only model support#517
Conversation
Register Ministral 3B (3.4B params) as a new variant of the existing mistral architecture. The text decoder is architecturally identical to Mistral 7B (pre-norm, GQA, SwiGLU), differing in dimensions and RoPE. Key changes: - Add _3b_config with emb_dim=3072, nlayers=26, vocab=131072, tie_heads=True, YaRN RoPE (factor=16, 262K context), no sliding window - Change sliding_window type to Optional[int] to support None - Skip Q/K weight permutation for YaRN in _hf_to_fms_rope (YaRN uses rotate-half same as HF, so interleave-to-pair would corrupt weights) - Fix _clean_up_rot_emb_cache to handle YaRN (cos, sin) tuple format - Strip language_model. prefix and skip vision/projector keys in _hf_to_fms_names for loading text tower from multimodal checkpoints - Map HF YaRN rope_scaling params in build_mistral_params - Add test fixtures and test class for 3b variant with YaRN - Add scripts/run_ministral_3b.py generation example Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add _8b_config (4096 emb_dim, 34 layers, 131K vocab, untied heads, YaRN RoPE factor=16, 262K context) and register as mistral 8b variant - Add Ministral8bFixtures and TestMinistral8b test class - Rename run_ministral_3b.py to run_ministral.py with --variant flag supporting both 3b and 8b Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
@mudhakar — requesting your review on this PR (Ministral 3B + 8B text-only support for FMS). |
kaoutar55
left a comment
There was a problem hiding this comment.
Overall the design looks sound: it cleanly registers Ministral 3B and 8B as new mistral variants, adds YaRN handling in both config building and RoPE conversion, and accounts for loading the text tower from multimodal Mistral3 checkpoints by stripping language_model. and dropping vision/projector weights.
Reusing the existing mistral architecture rather than adding a parallel model class is the right call — the decoder is architecturally equivalent and only differs in dimensions and RoPE setup. The "3b" and "8b" variants are registered cleanly and the implementation stays localized.
The YaRN-specific handling is internally consistent: the param builder maps YaRN fields into FMS config, _clean_up_rot_emb_cache is updated to handle tuple caches, and _hf_to_fms_rope skips the usual Q/K permutation for YaRN since HF and FMS use the same rotate-half convention. The dedicated 3B and 8B tests look good too.
Changes I'd request before merge
1. Add a serialization regression test for the multimodal key filtering.
The new logic in _hf_to_fms_names — stripping language_model. and skipping vision_tower. / multi_modal_projector. keys — is central to this PR, but it doesn't appear to be directly exercised by the new tests. The visible test additions are config/consistency/compile fixtures. We need to add at least one test with a small fake HF-style state dict that covers all three cases: text tower keys that should survive and be remapped, and multimodal-only keys that should be dropped entirely.
Something like this would cover it:
def test_hf_to_fms_names_multimodal_checkpoint():
fake_sd = {
# text tower — should survive and be remapped
"language_model.model.embed_tokens.weight": torch.zeros(1),
"language_model.model.layers.0.self_attn.q_proj.weight": torch.zeros(1),
"language_model.lm_head.weight": torch.zeros(1),
# multimodal-only — should be dropped entirely
"vision_tower.encoder.layers.0.weight": torch.zeros(1),
"multi_modal_projector.linear.weight": torch.zeros(1),
}
result = _hf_to_fms_names(fake_sd)
# vision and projector keys must not appear in output
assert not any("vision_tower" in k for k in result)
assert not any("multi_modal_projector" in k for k in result)
# text tower keys must be remapped to FMS names
assert "base_model.embedding.weight" in result
assert "head.weight" in result
assert "layers.0.attn.in_proj.query.weight" in result # adjust to actual FMS nameThe exact FMS key names in the assertions will need to match whatever _hf_to_fms_names actually produces for the layers.0.self_attn.q_proj path — you can run the function on a real or fake state dict to confirm those before committing.
2. Filter vision/projector keys after prefix normalization, not before.
Right now the code skips keys if the original name starts with vision_tower. or multi_modal_projector., and only then applies the ^language_model\. stripping. This works for the current checkpoint layout but breaks if future checkpoints nest these under language_model. or another prefix. Safer to normalize the prefix first, then decide whether the key belongs to the text tower.
3. _hf_to_fms_rope early return may be too broad.
The YaRN check returns the full state dict untouched:
if model_config and model_config.rope_scaling.get("rope_type") == "yarn":
return input_sdThe permutation skip is correct for Q/K weights, but are there any other transforms later in this function that YaRN models still need? Worth a quick audit to make sure nothing else is getting skipped by accident.
Things worth mention but are not blocking the merge
-
ntk_alpha/ntk_betanaming: you're mapping HF'sbeta_slow → ntk_alphaandbeta_fast → ntk_beta. That's a confusing inversion — a one-line comment explaining the mapping would be useful in the long term. -
sliding_window=Nonedownstream: changing the field toOptional[int]is right. Just worth confirming nothing downstream passesconfig.sliding_windowdirectly to an attention kernel without aNoneguard. The consistency tests probably catch this. -
Glob pattern in the run script has no fallback: if no files match
*model*.safetensors(e.g. a future checkpoint uses different naming), the script fails silently. Need to add a small warning if the glob comes up empty. -
Fallback tokenizer in
run_ministral.pyuses character ordinals, which is fine for a smoke test but can produce out-of-range IDs for Unicode prompts. A brief comment warning users would be enough. -
Both commits are missing DCO sign-off, so the PR is currently blocked on that anyway.
|
Ministral3 has been enabled via #518 and this PR can be closed |
Summary
"3b"and"8b"variants of the existingmistralarchitecture — no new model classes needed_clean_up_rot_emb_cacheto handle YaRN's(cos, sin)tuple cache formatlanguage_model.prefix strip, vision/projector key filtering)scripts/run_ministral.py) supporting both variants via--variant 3b|8bTest plan
TestMistral(7B) tests pass unchangedTestMinistral3btests pass (config round-trip, consistency, compile, weight keys)TestMinistral8btests pass (config round-trip, consistency, compile, weight keys)get_model("mistral", "3b")returns correct config (3072 emb_dim, 26 layers, 3.4B params)get_model("mistral", "8b")returns correct config (4096 emb_dim, 34 layers, 8.5B params)Ministral-3-3B-Instruct-2512-BF16) produces coherent output🤖 Generated with Claude Code