Skip to content

Add Ministral 3B and 8B text-only model support#517

Open
raghukiran1224 wants to merge 2 commits into
mainfrom
add-ministral-3b-support
Open

Add Ministral 3B and 8B text-only model support#517
raghukiran1224 wants to merge 2 commits into
mainfrom
add-ministral-3b-support

Conversation

@raghukiran1224
Copy link
Copy Markdown
Contributor

@raghukiran1224 raghukiran1224 commented Mar 11, 2026

Summary

  • Registers Ministral 3B (3.4B params) and 8B (8.5B params) as new "3b" and "8b" variants of the existing mistral architecture — no new model classes needed
  • Adds YaRN RoPE support for mistral serialization (rope weight permutation skip, HF param name mapping)
  • Fixes _clean_up_rot_emb_cache to handle YaRN's (cos, sin) tuple cache format
  • Handles loading text tower weights from Mistral3 multimodal checkpoints (language_model. prefix strip, vision/projector key filtering)
  • Includes generation example script (scripts/run_ministral.py) supporting both variants via --variant 3b|8b

Test plan

  • All existing TestMistral (7B) tests pass unchanged
  • New TestMinistral3b tests pass (config round-trip, consistency, compile, weight keys)
  • New TestMinistral8b tests pass (config round-trip, consistency, compile, weight keys)
  • get_model("mistral", "3b") returns correct config (3072 emb_dim, 26 layers, 3.4B params)
  • get_model("mistral", "8b") returns correct config (4096 emb_dim, 34 layers, 8.5B params)
  • End-to-end generation from HF checkpoint (Ministral-3-3B-Instruct-2512-BF16) produces coherent output

🤖 Generated with Claude Code

Raghu Ganti and others added 2 commits March 11, 2026 11:08
Register Ministral 3B (3.4B params) as a new variant of the existing
mistral architecture. The text decoder is architecturally identical to
Mistral 7B (pre-norm, GQA, SwiGLU), differing in dimensions and RoPE.

Key changes:
- Add _3b_config with emb_dim=3072, nlayers=26, vocab=131072,
  tie_heads=True, YaRN RoPE (factor=16, 262K context), no sliding window
- Change sliding_window type to Optional[int] to support None
- Skip Q/K weight permutation for YaRN in _hf_to_fms_rope (YaRN uses
  rotate-half same as HF, so interleave-to-pair would corrupt weights)
- Fix _clean_up_rot_emb_cache to handle YaRN (cos, sin) tuple format
- Strip language_model. prefix and skip vision/projector keys in
  _hf_to_fms_names for loading text tower from multimodal checkpoints
- Map HF YaRN rope_scaling params in build_mistral_params
- Add test fixtures and test class for 3b variant with YaRN
- Add scripts/run_ministral_3b.py generation example

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add _8b_config (4096 emb_dim, 34 layers, 131K vocab, untied heads,
  YaRN RoPE factor=16, 262K context) and register as mistral 8b variant
- Add Ministral8bFixtures and TestMinistral8b test class
- Rename run_ministral_3b.py to run_ministral.py with --variant flag
  supporting both 3b and 8b

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@raghukiran1224 raghukiran1224 changed the title Add Ministral 3B text-only model support Add Ministral 3B and 8B text-only model support Mar 11, 2026
@raghukiran1224
Copy link
Copy Markdown
Contributor Author

@mudhakar — requesting your review on this PR (Ministral 3B + 8B text-only support for FMS).

Copy link
Copy Markdown
Collaborator

@kaoutar55 kaoutar55 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @raghukiran1224

Overall the design looks sound: it cleanly registers Ministral 3B and 8B as new mistral variants, adds YaRN handling in both config building and RoPE conversion, and accounts for loading the text tower from multimodal Mistral3 checkpoints by stripping language_model. and dropping vision/projector weights.

Reusing the existing mistral architecture rather than adding a parallel model class is the right call — the decoder is architecturally equivalent and only differs in dimensions and RoPE setup. The "3b" and "8b" variants are registered cleanly and the implementation stays localized.

The YaRN-specific handling is internally consistent: the param builder maps YaRN fields into FMS config, _clean_up_rot_emb_cache is updated to handle tuple caches, and _hf_to_fms_rope skips the usual Q/K permutation for YaRN since HF and FMS use the same rotate-half convention. The dedicated 3B and 8B tests look good too.

Changes I'd request before merge

1. Add a serialization regression test for the multimodal key filtering.

The new logic in _hf_to_fms_names — stripping language_model. and skipping vision_tower. / multi_modal_projector. keys — is central to this PR, but it doesn't appear to be directly exercised by the new tests. The visible test additions are config/consistency/compile fixtures. We need to add at least one test with a small fake HF-style state dict that covers all three cases: text tower keys that should survive and be remapped, and multimodal-only keys that should be dropped entirely.

Something like this would cover it:

def test_hf_to_fms_names_multimodal_checkpoint():
    fake_sd = {
        # text tower — should survive and be remapped
        "language_model.model.embed_tokens.weight": torch.zeros(1),
        "language_model.model.layers.0.self_attn.q_proj.weight": torch.zeros(1),
        "language_model.lm_head.weight": torch.zeros(1),
        # multimodal-only — should be dropped entirely
        "vision_tower.encoder.layers.0.weight": torch.zeros(1),
        "multi_modal_projector.linear.weight": torch.zeros(1),
    }

    result = _hf_to_fms_names(fake_sd)

    # vision and projector keys must not appear in output
    assert not any("vision_tower" in k for k in result)
    assert not any("multi_modal_projector" in k for k in result)

    # text tower keys must be remapped to FMS names
    assert "base_model.embedding.weight" in result
    assert "head.weight" in result
    assert "layers.0.attn.in_proj.query.weight" in result  # adjust to actual FMS name

The exact FMS key names in the assertions will need to match whatever _hf_to_fms_names actually produces for the layers.0.self_attn.q_proj path — you can run the function on a real or fake state dict to confirm those before committing.


2. Filter vision/projector keys after prefix normalization, not before.

Right now the code skips keys if the original name starts with vision_tower. or multi_modal_projector., and only then applies the ^language_model\. stripping. This works for the current checkpoint layout but breaks if future checkpoints nest these under language_model. or another prefix. Safer to normalize the prefix first, then decide whether the key belongs to the text tower.


3. _hf_to_fms_rope early return may be too broad.

The YaRN check returns the full state dict untouched:

if model_config and model_config.rope_scaling.get("rope_type") == "yarn":
    return input_sd

The permutation skip is correct for Q/K weights, but are there any other transforms later in this function that YaRN models still need? Worth a quick audit to make sure nothing else is getting skipped by accident.


Things worth mention but are not blocking the merge

  • ntk_alpha / ntk_beta naming: you're mapping HF's beta_slow → ntk_alpha and beta_fast → ntk_beta. That's a confusing inversion — a one-line comment explaining the mapping would be useful in the long term.

  • sliding_window=None downstream: changing the field to Optional[int] is right. Just worth confirming nothing downstream passes config.sliding_window directly to an attention kernel without a None guard. The consistency tests probably catch this.

  • Glob pattern in the run script has no fallback: if no files match *model*.safetensors (e.g. a future checkpoint uses different naming), the script fails silently. Need to add a small warning if the glob comes up empty.

  • Fallback tokenizer in run_ministral.py uses character ordinals, which is fine for a smoke test but can produce out-of-range IDs for Unicode prompts. A brief comment warning users would be enough.

  • Both commits are missing DCO sign-off, so the PR is currently blocked on that anyway.

@gkumbhat
Copy link
Copy Markdown
Collaborator

gkumbhat commented Apr 8, 2026

Ministral3 has been enabled via #518 and this PR can be closed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants