feat(idefics3): add FMS-native Idefics3 model and HF config mapping#503
toddllm wants to merge 13 commits into
- Add FMS-native idefics3 architecture and components
- Map HF Idefics3ForConditionalGeneration to idefics3 config
- Wire SigLIP/LLaMA plumbing for HF checkpoints
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
- Harden packing and connector behavior
- Shifted loss + pad/image masking
- Padding-aware generation + cached-decode mask handling
- Avoid is_causal in cached decode under SDPA
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
For CPU testing:
    ModelConsistencyTestSuite,
    Idefics3Fixtures,
):
    @staticmethod
One test fails when I run the suite in my environment:
FAILED tests/models/test_idefics3.py::TestIdefics3::test_model_output[False] - RuntimeError: Expected all tensors to be on the same device, but got index is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA__index_select)
My environment has a GPU, but I haven't configured a device for the tests. Maybe it would be a good idea to confirm at the test level that everything runs on the same device?
I will test on GPU. Thank you for looking into this, @flaviabeo!
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
@flaviabeo Nice catch with the GPU testing! I added some fixes. Reran GPU parity after the multi-image tolerance tweak: max_abs_diff=6.151e-05 on seed 42, which passes with atol=rtol=7e-5. Unit tests are fixed by skipping random init for non-float buffers (position_ids). Let me know if you see anything on your GPU.
@@ -7,6 +7,7 @@
# Used in Llava Next for Granite vision
from fms.models.siglip_vision import SiglipVisionConfig
from fms.models.granite import GraniteConfig
GraniteConfig is imported but never used. Should be removed.
    "layer_norm_eps": vision_cfg.layer_norm_eps,
    "attention_dropout": vision_cfg.attention_dropout,
}
use_navit = getattr(vision_cfg, "use_navit_position_buckets", None)
This handles multiple fallback cases for NaViT position buckets detection.
It would be good to add a comment explaining why these fallbacks are necessary.
image_size = config_params["vision_config"].image_size
patch_size = config_params["vision_config"].patch_size
patches_per_side = int(image_size) // int(patch_size)
These values are already ints in the config, so the casts are unnecessary.
for k, v in adapted_text.items():
    new_sd[f"text_model.{k}"] = v

# ---- connector
I see duplicated logic in the two if branches (lines 101-113 and 114-123). It could be refactored into something like this:
# Simplified version
for k, v in input_sd.items():
    key = k[len("model."):] if k.startswith("model.") else k
    if key.startswith("connector."):
        key = key[len("connector."):]
        if key.startswith("modality_projection.proj."):
            key = key[len("modality_projection.proj."):]
            new_sd[f"connector.proj.{key}"] = v
        elif key.startswith("modality_projection."):
            key = key[len("modality_projection."):]
            new_sd[f"connector.proj.{key}"] = v
        elif key == "modality_projection":
            new_sd["connector.proj.weight"] = v
        else:
            new_sd[f"connector.{key}"] = v
    base.embeddings, "reset_parameters"
):
    base.embeddings.reset_parameters()
if hasattr(base, "encoder") and hasattr(base.encoder, "reset_parameters"):
Complex nested hasattr checks. Consider adding a method to base classes instead:
def reset_parameters(self):
    # Let each component handle its own initialization
    if hasattr(self.vision_tower, 'reset_parameters'):
        self.vision_tower.reset_parameters()
    if hasattr(self.text_model, 'reset_parameters'):
        self.text_model.reset_parameters()
    nn.init.xavier_uniform_(self.connector.proj.weight)

    Returns a dict with pixel_values, pixel_attention_mask, and num_patches.
    """

    def to_pil(img_like):
The to_pil function handles many edge cases but has duplicated PIL import checks. Consider extracting to module level.
    **attn_kwargs: Unpack[AttentionKwargs],
):
    # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified
    if x_in is not None and inputs_embeds is not None:
Minor: Type hint for x_in should be Optional[torch.Tensor] for consistency.
    nb_patches_w, device=position_ids.device, dtype=position_ids.dtype
)

fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
Question: What's the purpose of * (1 - 1e-6)? This should be documented.
target_device = torch.device(device)

# Helper function to check if two devices are equivalent
The string comparison against 'cuda' should use dev1.type, or be more defensive.
- Do not apply NaViT-style bucketing when use_navit_position_buckets is disabled, even if patch_attention_mask is provided.
- Match HF bucketing semantics by using exact i/n fractional coords with bucketize(right=True).
- Add a regression test and update downstream expectations impacted by this fix.
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
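The bucketing rule named in this commit (exact i/n fractional coordinates fed to bucketize(right=True)) can be sketched in plain Python. This is only an illustration: the function name navit_position_ids is hypothetical, lists stand in for tensors, and bisect_right is used because it follows the same boundary rule as bucketize with right=True.

```python
from bisect import bisect_right


def navit_position_ids(nb_patches_h: int, nb_patches_w: int, patches_per_side: int) -> list[int]:
    """Map an (h, w) patch grid onto a patches_per_side x patches_per_side position table."""
    # Bucket boundaries at 1/N, 2/N, ..., (N-1)/N for a table with N buckets per side.
    boundaries = [k / patches_per_side for k in range(1, patches_per_side)]
    # Exact i/n fractional coordinates, with no epsilon scaling applied.
    frac_h = [i / nb_patches_h for i in range(nb_patches_h)]
    frac_w = [j / nb_patches_w for j in range(nb_patches_w)]
    # bisect_right returns b such that boundaries[b-1] <= x < boundaries[b],
    # which is the rule bucketize applies when right=True.
    bucket_h = [bisect_right(boundaries, x) for x in frac_h]
    bucket_w = [bisect_right(boundaries, x) for x in frac_w]
    # Row-major flattening into the position table.
    return [bh * patches_per_side + bw for bh in bucket_h for bw in bucket_w]


# A 2x2 image placed on a 4x4 table lands on every other row and column.
small_grid_ids = navit_position_ids(2, 2, 4)
```

With this rule, a coordinate that falls exactly on a boundary (0.5 in the 2x2 case) goes to the upper bucket, which is why the choice of right=True matters for parity.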
- Remove unused imports and document NaViT config fallbacks.
- Refactor HF connector key remapping to avoid duplication.
- Simplify reset_parameters checks.
- Hoist Pillow import/requirements handling in preprocessing.
- Minor typing/style tweaks for device-safe comparisons.
- Update SigLIP expected weight keys to include position_ids buffer.
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
@kaoutar55 @flaviabeo comments addressed in latest commits
alex-jw-brooks left a comment
Thanks for all the work! Some thoughts: it may also be better to wrap the components in one file, similar to llava next and the ongoing mistral3 (multimodal) work, for more consistent patterns across models in case things are refactored in the future. I'm a bit worried things will be really inconsistent for composite models, since there are currently no well-documented interfaces for how they should be called through FMS generate etc.
    "layer_norm_eps": vision_cfg.layer_norm_eps,
    "attention_dropout": vision_cfg.attention_dropout,
}
use_navit = getattr(vision_cfg, "use_navit_position_buckets", None)
When is use_navit_position_buckets expected to be a defined attribute on the HF config? I don't see it in the idefics3 config here.
IMO it may also be a good idea to separate build_siglip_vision_params and build_siglip_idefics3_vision_params for clarity as well
Good call. use_navit_position_buckets is not reliably present on HF Idefics3/SmolVLM vision configs, so we now treat it as optional and avoid assuming it exists.
I split the logic into:
- build_siglip_vision_params(...) for generic SigLIP defaults (explicit flag if present, else default False)
- build_siglip_idefics3_vision_params(...) for idefics3/smolvlm-scoped NaViT inference
build_idefics3_params(...) now uses the idefics3-specific builder. I also added unit coverage in tests/models/hf/test_param_builders.py for:
- default False when attr is absent
- smolvlm max_image_size heuristic enabling NaViT
- no heuristic leakage for non-idefics3 parent configs
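The three behaviors covered by those tests can be sketched as follows. This is a minimal illustration, not the PR's builders: the helper names are hypothetical, SimpleNamespace stands in for HF config objects, and the max_image_size check is a placeholder for the real heuristic, whose exact condition is not shown in this thread.

```python
from types import SimpleNamespace


def generic_navit_flag(vision_cfg) -> bool:
    """Generic SigLIP rule: honor an explicit flag if present, else default to False."""
    flag = getattr(vision_cfg, "use_navit_position_buckets", None)
    return bool(flag) if flag is not None else False


def idefics3_navit_flag(vision_cfg, parent_cfg) -> bool:
    """Idefics3/SmolVLM-scoped rule: an explicit flag wins, then a scoped heuristic."""
    explicit = getattr(vision_cfg, "use_navit_position_buckets", None)
    if explicit is not None:
        return bool(explicit)
    # Scope guard: the heuristic must not leak into other parent configs.
    if getattr(parent_cfg, "model_type", None) not in ("idefics3", "smolvlm"):
        return False
    # Placeholder heuristic (assumption): treat a max_image_size entry as the NaViT signal.
    return getattr(vision_cfg, "max_image_size", None) is not None


plain_cfg = SimpleNamespace()
smolvlm_vision = SimpleNamespace(max_image_size={"longest_edge": 512})
smolvlm_parent = SimpleNamespace(model_type="smolvlm")
other_parent = SimpleNamespace(model_type="llava_next")
```

Keeping the heuristic behind the parent-config check is what prevents it from affecting plain SigLIP users.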
    H, W: grid dimensions (e.g. 32 for 512/16)
    """
    if patch_embeds.shape[1] != H * W:
        raise ValueError(f"Expected {H * W} patches, got {patch_embeds.shape[1]}")
For the error messages here, can you either log the full shape or use negative indices? As is, this will get confusing if a batch dim of 1 is squeezed out somewhere, since it would log the wrong dim for the number of patches.
Updated this to make error messages robust to squeezed/broadcasted shapes and to include full tensor shapes.
Changes made:
- use stable indexing (shape[-2] for patch count) instead of fixed positional assumptions
- include full shape context in errors, e.g. x.shape=... / patch_embeds.shape=...
- include expected/actual patch counts and H/W/scale in the message
This should make debugging much clearer when an upstream batch/image dimension changes.
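A minimal sketch of such a check, assuming a (..., num_patches, hidden) layout. The function name check_patch_count is hypothetical, and a plain tuple stands in for a tensor shape so the example runs without torch.

```python
def check_patch_count(shape: tuple[int, ...], H: int, W: int, scale: int) -> None:
    """Validate the patch dimension using an index that survives a squeezed batch dim."""
    actual = shape[-2]  # patch count sits second from the end, with or without a batch dim
    expected = H * W
    if actual != expected:
        raise ValueError(
            f"Expected {expected} patches (H={H}, W={W}, scale={scale}), "
            f"got {actual}; patch_embeds.shape={tuple(shape)}"
        )


# Both layouts pass, because shape[-2] is 1024 in each case.
check_patch_count((1, 1024, 768), 32, 32, 2)
check_patch_count((1024, 768), 32, 32, 2)
```

With shape[1], the second call would have compared the hidden size (768) against the expected patch count and reported a misleading number.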
    self.out_features = text_hidden
    self.proj = nn.Linear(self.in_features, self.out_features, bias=False)

@staticmethod
Should not be a staticmethod and should just use self.scale
Agreed, updated.
pixel_unshuffle is now an instance method and uses self.scale directly (no @staticmethod and no external scale arg). This keeps connector behavior fully owned by the module config.
    return out

@torch.no_grad()
def generate(
Can you try to align this with what is done in llava next, which is to create the merged multimodal embeddings in prepare_inputs_for_generation, using kwargs that have optional values for stuff from the HF processor, e.g., pixel values etc?
I imagine this will likely be refactored in the future to follow a clearer pattern, but this will allow for better compatibility with the existing generate, e.g., by passing prepare_inputs_for_generation as a preprocessing hook, and also keep things more consistent with projects that may want to consume it, like vLLM spyre
Implemented this alignment with LlavaNext/shared generation flow.
Changes made:
- added prepare_inputs_for_generation(iteration, input_ids, kwargs) to build multimodal embeddings on prefill and avoid re-encoding images during cached decode
- switched Idefics3.generate() to use fms.utils.generation.generate(..., prepare_model_inputs_hook=self.prepare_inputs_for_generation)
- threaded optional processor kwargs (pixel_values, images, pixel_attention_mask) through the hook path
This keeps idefics3 compatible with the shared generate path and closer to the existing multimodal pattern used in LlavaNext.
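The prefill/decode split described above can be sketched with a toy stand-in. Everything here is illustrative: ToyMultimodalModel and _merge_embeddings are hypothetical, and the assumption that the hook returns an (input_ids, kwargs) pair is not confirmed anywhere in this thread.

```python
class ToyMultimodalModel:
    """Stand-in for a composite model; only the hook logic is illustrated."""

    def __init__(self):
        self.encode_calls = 0

    def _merge_embeddings(self, input_ids, pixel_values):
        # Placeholder for vision encoding plus packing into the text embeddings.
        self.encode_calls += 1
        return [("embed", tok) for tok in input_ids]

    def prepare_inputs_for_generation(self, iteration, input_ids, kwargs):
        kwargs = dict(kwargs)  # avoid mutating the caller's dict
        pixel_values = kwargs.get("pixel_values")
        if iteration == 0 and pixel_values is not None:
            # Prefill: encode images once and merge them into the text embeddings.
            kwargs["inputs_embeds"] = self._merge_embeddings(input_ids, pixel_values)
        else:
            # Cached decode: the image positions are already covered by the cache.
            kwargs.pop("inputs_embeds", None)
        return input_ids, kwargs


model = ToyMultimodalModel()
_, prefill_kwargs = model.prepare_inputs_for_generation(0, [1, 2, 3], {"pixel_values": "img"})
_, decode_kwargs = model.prepare_inputs_for_generation(1, [4], prefill_kwargs)
```

The counter on the toy model makes the intended property easy to check: the image encoder runs once, on prefill only.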
if hasattr(full_model, "model") and hasattr(full_model.model, "vision_model"):
    vision_model = full_model.model.vision_model
elif hasattr(full_model, "vision_model"):
    vision_model = full_model.vision_model
when is this fallback case expected?
Added clarification and narrowed the fallback expectations.
Expected/default layout is full_model.model.vision_model.
Fallback path (full_model.vision_model) is only for HF wrapper/layout variation where the nested model container is absent.
I documented this inline and include model type context in fallback logging/errors so it’s explicit when/why the fallback is taken.
    and "vision" not in name
    and "text" not in name
):
    # This is risky, better to rely on structure
Can you add some logs to some of these fallback cases?
Added logs on fallback paths.
Current behavior:
- log info when using non-default but expected layout fallback paths
- log context (HF model class/type) in fallback messages
- fail loudly (with explicit error) for risky connector extraction cases rather than silently scanning/selecting ambiguous params
This keeps normal/default path quiet while making fallback behavior diagnosable.
import torch


class SmolVLMPreprocessor:
Similar to some of the above comments on aligning with llava next for model structure and not having a custom .generate - I don't think that we should have an explicit wrapper for the HF processor (even a lazy one for conditional dependencies), since it may cause confusion with how the model is to be used with transformers
Agreed, removed the explicit wrapper class.
Changes made:
- deleted SmolVLMPreprocessor wrapper class
- kept a thin helper load_smolvlm_processor(...) that returns HF AutoProcessor directly
- kept load_smolvlm_preprocessor(...) only as a deprecated compatibility alias to avoid immediate breakage for older call sites
So the canonical interface is now the HF processor itself, aligned with your comment.
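A rough sketch of the helper-plus-alias shape described here. The two function names come from the reply above, but their signatures are assumptions, and the loader parameter is a hypothetical injection point added so the example runs without transformers installed.

```python
import warnings


def load_smolvlm_processor(model_id: str, loader=None):
    """Return the processor object itself, with no wrapper class around it."""
    if loader is None:
        # Imported lazily so this module stays importable without transformers.
        from transformers import AutoProcessor

        loader = AutoProcessor.from_pretrained
    return loader(model_id)


def load_smolvlm_preprocessor(*args, **kwargs):
    """Deprecated alias kept only so older call sites keep working."""
    warnings.warn(
        "load_smolvlm_preprocessor is deprecated; use load_smolvlm_processor",
        DeprecationWarning,
        stacklevel=2,
    )
    return load_smolvlm_processor(*args, **kwargs)


# Exercise the alias with a fake loader and record the warning it emits.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    processor = load_smolvlm_preprocessor("some/model", loader=lambda m: ("proc", m))
```

Because the alias only forwards its arguments, callers get the same object from either entry point.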
    "vision",
    "vision_model",
    "visual",
    "vision_encoder",
Is this degree of flexibility actually needed for idefics3, or is it written to generically cover common attribute names for vision towers in VLMs?
Tightened this to idefics3/smolvlm-relevant layouts rather than broad generic scanning.
Current order:
- checkpoint.model.vision_model (default)
- checkpoint.model.vision_tower (fallback, logged)
- top-level known attrs (vision_model, vision_tower, vision) (fallback, logged)
- dict-like wrapper keys only if values are actual nn.Module objects
So flexibility is now constrained to known/expected variants, with logs on non-default selections.
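The lookup order above might look roughly like this. It is a sketch under assumptions: find_vision_tower is a hypothetical name, SimpleNamespace stands in for checkpoint objects, and the dict-like wrapper step is omitted because it needs an nn.Module check.

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)


def find_vision_tower(checkpoint):
    """Resolve the vision tower from known layouts, logging every non-default choice."""
    inner = getattr(checkpoint, "model", None)
    default = getattr(inner, "vision_model", None)
    if default is not None:
        return default  # default layout: stay quiet
    fallback = getattr(inner, "vision_tower", None)
    if fallback is not None:
        logger.info("Fallback layout model.vision_tower on %s", type(checkpoint).__name__)
        return fallback
    for name in ("vision_model", "vision_tower", "vision"):
        candidate = getattr(checkpoint, name, None)
        if candidate is not None:
            logger.info("Fallback top-level attr %s on %s", name, type(checkpoint).__name__)
            return candidate
    # Fail loudly instead of scanning for ambiguous attributes.
    raise AttributeError(f"No known vision tower layout on {type(checkpoint).__name__}")


nested = SimpleNamespace(model=SimpleNamespace(vision_model="vm"))
flat = SimpleNamespace(vision_tower="vt")
```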
    elif kk == "modality_projection":
        new_sd["connector.proj.weight"] = v
    else:
        new_sd[f"connector.{kk}"] = v
A bit apprehensive about adding so much state dict key manipulation here since the serialization utils are so complex already, and other models use regexes for pattern replacement, e.g., here. 🤔 I am not a big fan of the regexes either, but might be better to do things consistently so that we don't see different behaviors for this model, since it will be hard to debug, especially if more adapter steps are added
@kaoutar55 @flaviabeo @gkumbhat any thoughts?
Aligned this with the existing regex-replacement pattern approach for consistency.
I replaced the branchy connector key remap logic with ordered regex replacement rules (including both model.connector.* and connector.* inputs), then gate on connector. outputs. This keeps behavior explicit while matching the style used elsewhere (e.g. llava_next) for easier debugging/maintenance.
Hi! Yes, agreed @alex-jw-brooks - better to standardize the way of doing these replacements. Incorrectly mapped weight keys can cause warnings, and bugs in the mappings can be hard to track down and fix.
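The ordered regex replacement discussed in this thread could be sketched as below. The rule list and the name remap_connector_keys are illustrative assumptions, not the PR's actual rules; they only reproduce the mappings named earlier in the review (modality_projection.* into connector.proj.*, with an optional leading model. prefix).

```python
import re

# Ordered rules: each one is applied in turn to the (possibly already rewritten) key.
CONNECTOR_RULES = [
    (re.compile(r"^model\."), ""),
    (re.compile(r"^connector\.modality_projection\.proj\."), "connector.proj."),
    (re.compile(r"^connector\.modality_projection\."), "connector.proj."),
    (re.compile(r"^connector\.modality_projection$"), "connector.proj.weight"),
]


def remap_connector_keys(input_sd: dict) -> dict:
    """Rewrite connector keys with ordered regex rules and keep only connector outputs."""
    new_sd = {}
    for key, value in input_sd.items():
        for pattern, replacement in CONNECTOR_RULES:
            key = pattern.sub(replacement, key)
        # Gate on the rewritten prefix so text and vision keys are left to other adapters.
        if key.startswith("connector."):
            new_sd[key] = value
    return new_sd


remapped = remap_connector_keys(
    {
        "model.connector.modality_projection.proj.weight": 1,
        "connector.other.bias": 2,
        "model.text_model.layers.0.weight": 3,
    }
)
```

Rule order matters: the more specific modality_projection.proj. pattern has to run before the generic modality_projection. one, otherwise the proj. segment would be duplicated in the output key.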
embeddings = patch_embeds.flatten(2).transpose(1, 2)
embeddings = embeddings + self.position_embedding(self.position_ids)

if (
Can you separate most of the NaViT handling (except the early return with no patch attention mask, which is the same as the normal case) into a separate function to keep forward more minimal? That will make it easier to compare with other models that use SigLIP without it, e.g., granite vision.
Done. I split most NaViT-specific logic out of SiglipVisionEmbeddings.forward into a helper (_navit_position_ids(...)).
forward now keeps the early return for the non-NaViT/no-mask path and only calls the helper for the NaViT bucketed position-id path. This keeps the main forward path minimal for non-NaViT siglip users.
I see some tests failing:
______ test_config_mapping[SiglipModel-google/siglip-so400m-patch14-384] _______
arch = 'SiglipModel', model_name = 'google/siglip-so400m-patch14-384'
    @pytest.mark.parametrize("arch,model_name", MODEL_CONFIG_MAP.items())
    def test_config_mapping(arch: str, model_name: str):
        from fms.models.hf.utils import infer_model_configuration

        cfg_filename = get_file_kwargs_filename(model_name)
        kwargs_path = os.path.join(FMS_CONFIGS_DIR, cfg_filename)
        if not os.path.isfile(kwargs_path):
            raise FileNotFoundError(
                f"Model {model_name} has no kwargs file; have you generated it?"
            )
        with open(kwargs_path, "r") as f:
            expected_kwargs = json.load(f)
        actual_kwargs = infer_model_configuration(model_name, download_weights=False)
        for k, actual_v in actual_kwargs.items():
>           expected_v = expected_kwargs[k]
E           KeyError: 'use_navit_position_buckets'

tests/utils/test_config_mapping.py:62: KeyError
----------------------------- Captured stderr call -----------------------------
Fetching 1 files: 0%| | 0/1 [00:00<?, ?it/s]
Fetching 1 files: 100%|██████████| 1/1 [00:00<00:00, 29.36it/s]
_ test_config_mapping[LlavaNextForConditionalGeneration-ibm-granite/granite-vision-3.2-2b] _
arch = 'LlavaNextForConditionalGeneration'
model_name = 'ibm-granite/granite-vision-3.2-2b'
This seems related to the use_navit_position_buckets change: either the model configuration is not being inferred correctly, or the args are getting lost somewhere between setup and the call, maybe?
- add build_siglip_idefics3_vision_params for idefics3/smolvlm-specific inference
- keep generic build_siglip_vision_params behavior explicit
- wire build_idefics3_params to the specialized builder
- add unit tests for default false, smolvlm enablement, and non-idefics3 no-leak behavior
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
- make connector pixel_unshuffle use self.scale and improve shape error messages
- document expected HF layout fallbacks and add non-default path logs
- tighten vision tower extraction to known idefics3/smolvlm layouts
- factor NaViT-specific position-id handling out of siglip forward
- normalize docstrings/comments to ASCII for terminal-friendly diffs
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
…ndling
- move idefics3 generation onto fms.utils.generation with prepare_inputs_for_generation
- thread optional image kwargs through prefill/decode flow
- support pixel_attention_mask in image encoding and packing path
- fix 5D image + singleton 4D mask broadcasting in _encode_images
- add regression coverage for multi-image singleton mask handling
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
- remove custom SmolVLMPreprocessor wrapper class in favor of HF AutoProcessor helper
- keep deprecated load_smolvlm_preprocessor alias with compatibility shim
- switch connector key remapping to ordered regex replacement rules
- add regression test to lock contiguous expanded pixel_attention_mask behavior
Signed-off-by: Todd Deshane <todd.deshane@ibm.com>
@alex-jw-brooks comments addressed; unit and parity tests re-run on CPU and GPU
Summary
Introduce an FMS-native idefics3 architecture and wire it to HF/SmolVLM configs and SigLIP vision, with minimal supporting plumbing and focused tests. Changes are localized to Idefics3 and shared utilities; other architectures are unaffected.
What changed
New FMS-style Idefics3 layout
- fms/models/idefics3.py as the public entrypoint (architecture registration + HF adapter registration).
- fms/models/idefics3_components/ for the implementation (config, core model, connector, packing, vision adapter, HF helpers).
- fms/models/idefics3/ package.
HF / SmolVLM plumbing
- _map_model_config in fms/models/hf/utils.py to map Idefics3ForConditionalGeneration to architecture="idefics3", building nested SiglipVisionConfig + LLaMAConfig and deriving connector_scale + image_span_len.
- source="hf" can adapt SmolVLM/Idefics3 state dicts into the FMS layout.
- fms/models/idefics3_components/hf_adapter.py, normalize connector weight keys so HF modality_projection.* maps into connector.proj.*, and support optional FMS SigLIP usage via FmsSiglipVisionWrapper.
Vision + text support
- fms/models/siglip_vision.py: strip leading "model." from SmolVLM checkpoints during HF→FMS name mapping.
- fms/models/llama.py: add an inputs_embeds path in LLaMAHeadless.forward so packed multimodal embeddings can drive the text backbone.
Config + serialization polish
- fms/utils/config.py: allow nested overrides via ModelConfig.updated() with dicts (e.g. text_config={...}).
- fms/utils/serialization.py: switch unused adapted-checkpoint key reporting from print to logger.warning.
Tests
Run on this branch:
python -m pytest -q tests/models/test_idefics3.py
This exercises the tiny nested Idefics3 config (SigLIP vision + LLaMA text) and checks both forward behavior and weight keys against the expectations below.
- fms/models/idefics3_components/fms_model.py + fms/models/idefics3.py (overall architecture + registration).
- fms/models/hf/utils.py::_map_model_config (HF config → FMS config mapping).
- fms/models/idefics3_components/hf_adapter.py + vision_adapter.py (SmolVLM/SigLIP wrapping and weight key normalization).
- tests/models/test_idefics3.py
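The nested-override behavior mentioned under "Config + serialization polish" (passing a dict such as text_config={...} to updated()) can be illustrated with plain dataclasses. This is a sketch of the idea only: TextConfig and CompositeConfig are hypothetical stand-ins, and the code is not FMS's ModelConfig implementation.

```python
from dataclasses import dataclass, field, is_dataclass, replace


@dataclass
class TextConfig:
    emb_dim: int = 64
    nlayers: int = 2


@dataclass
class CompositeConfig:
    text_config: TextConfig = field(default_factory=TextConfig)
    connector_scale: int = 2

    def updated(self, **overrides):
        """Return a copy; dict values are merged into nested dataclass fields."""
        changes = {}
        for name, value in overrides.items():
            current = getattr(self, name)
            if is_dataclass(current) and isinstance(value, dict):
                # Nested override: change only the listed sub-fields.
                changes[name] = replace(current, **value)
            else:
                changes[name] = value
        return replace(self, **changes)


base_cfg = CompositeConfig()
new_cfg = base_cfg.updated(text_config={"nlayers": 4}, connector_scale=3)
```

Merging rather than replacing the nested config means a caller can change one text-model field without restating the rest.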