Adding llava_next model #420

Merged
JRosenkranz merged 31 commits into foundation-model-stack:main from sahilsuneja1:llava_next
Jul 23, 2025
@sahilsuneja1 (Contributor) commented May 29, 2025

This enables granite-vision in fms. granite-vision = siglip vision (PR #405) + llava-next vision-language connector (this PR) + granite-3.1-2b-instruct (already supported).

Please merge PR #405 first, which serves as the parent branch for this.

@sahilsuneja1 changed the title from "[DRAFT] Adding llava_next model" to "Adding llava_next model" on Jun 13, 2025
@sahilsuneja1
Contributor Author

PR ready for review @JRosenkranz

Comment thread fms/models/granite.py Outdated
position_ids=None,
past_key_value_states=None,
use_cache=False,
is_input_embedded=False,
Collaborator

we may want to think about how this could tie in with other models. Seems not specific to Granite.

Contributor Author

Yeah, this is not Granite-specific, but hasn't been needed so far with any other models.
So far only granite-vision via llava_next sends its inputs to granite in this form. So maybe good for now for Granite?
But you are right, pixtral would probably have something similar.

Collaborator

For this param, I think we can get away with checking tensor.dim() and removing the param entirely. If the input has 3 dimensions, we skip the embedding; if it has 2, we run the embedding. If we do this, we should document how x_in is handled. Either way, with or without the param, we will need some error handling, so I don't really see a downside to this.

Contributor Author

Cool, will try it out!

Contributor Author

Done
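
The dimension-based dispatch suggested in this thread could look roughly like the sketch below. It is only an illustration under stated assumptions: `TinyDecoder`, `vocab_size`, and `emb_dim` are hypothetical names, not the fms Granite implementation, and the error handling is one possible choice.

```python
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder that accepts token ids or embeddings."""

    def __init__(self, vocab_size: int = 16, emb_dim: int = 8):
        super().__init__()
        self.emb_dim = emb_dim
        self.embedding = nn.Embedding(vocab_size, emb_dim)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        if x_in.dim() == 2:
            # (batch, seq_len) of integer token ids: run the embedding lookup
            if x_in.is_floating_point():
                raise ValueError("2-D input must contain integer token ids")
            return self.embedding(x_in)
        if x_in.dim() == 3:
            # (batch, seq_len, emb_dim): already embedded, skip the lookup
            if x_in.size(-1) != self.emb_dim:
                raise ValueError(
                    f"expected last dim {self.emb_dim}, got {x_in.size(-1)}"
                )
            return x_in
        raise ValueError(f"expected a 2-D or 3-D input, got {x_in.dim()}-D")
```

With this shape check in place, the caller no longer needs a separate flag: passing ids or precomputed embeddings selects the path implicitly.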

Comment thread fms/models/llava_next.py Outdated
vision_feature_select_strategy: str = "full"
vision_feature_layer: list = field(default_factory=lambda: [-24, -20, -12, -1])
image_grid_pinpoints: list = field(default_factory=lambda: _granite_3_2_2b_grid)
tie_word_embeddings: bool = True
Collaborator

typically we call this tie_heads, might make sense to make the same assumption

Contributor Author

@sahilsuneja1 Jun 18, 2025

will replace-- actually wasn't being used, so removed it.

Comment thread fms/models/llava_next.py Outdated
)
self.config = config

# NOTE: HF doesn't do this
Collaborator

can you mention how hf initializes its weights

Contributor Author

@sahilsuneja1 Jun 18, 2025

HF doesn't have this particular initialization! Will update in comment

Comment thread fms/models/llava_next.py Outdated
image_sizes=None,
position_ids=None,
past_key_value_states=None,
inputs_embeds=None,
Collaborator

Should this match closer to granite with is_input_embedded?

Contributor Author

Hmm, this is passed with this particular key from an image processor.
Also, you are right, is_input_embedded is set to True, when inputs_embeds are sent to granite.forward() in line 399

Comment thread fms/models/llava_next.py Outdated
use_cache=False,
**attn_kwargs: Unpack[AttentionKwargs],
):
if input_ids is None and inputs_embeds is None:
Collaborator

should this raise an exception?

Contributor Author

Had an exception here originally, but the tests complained for some reason. I can try and bring it back.

Comment thread fms/models/llava_next.py
Comment thread tests/models/test_llava_next.py Outdated
"llava_next uses nested configs for vision and text model, which get flattened with config.as_dict()"
)

def test_model_compile_no_graph_breaks(self, model):
Collaborator

Could we fix these as part of this PR?

Contributor Author

Let's do it in a follow-up PR.
Once this original implementation is in, others from the team will address the graph breaks issue, which was one of the original motivations for supporting granite-vision in fms.

Collaborator

I think part of landing the model to FMS is having this fixed, right? Why do we want to land a broken model in FMS in the first place to then fix it later? If there's a good reason to not wait on that we can make an exception

Contributor Author

Hmm, no reason, I thought we let the HF impl equivalence variant in first, so that others from the team can focus on modifying it to remove the graph break and add alternate ops for AIU. But we can work off of the branch and merge it all together if you prefer.

@sahilsuneja1
Contributor Author

DCO is complaining, I don't know how to fix it.
Once you are satisfied with the review updates, I will create a new clean PR atop a fresh branch, and tag this one.

Collaborator

@JRosenkranz left a comment

Could we add a script for this to show usage? and/or an HF equivalence test similar to https://github.com/foundation-model-stack/foundation-model-stack/blob/main/tests/models/hf_equivalence/test_llama.py

@sahilsuneja1
Contributor Author

sahilsuneja1 commented Jun 25, 2025

Yup, will add an example usage snippet. Will try and add the HF equivalence test as well.
Update: Landed somewhere in the middle-- added tests/models/hf_equivalence/test_granite_vision.py which tests HF vs FMS output equivalence, but doesn't use the HF equivalence framework

Comment thread fms/models/hf/utils.py
Comment thread tests/models/hf_equivalence/test_granite_vision.py Outdated
Comment thread tests/models/hf_equivalence/test_granite_vision.py
Comment thread fms/models/llava_next.py Outdated
Comment thread fms/models/llava_next.py Outdated
return unpadded_tensor

# TODO: fix graph break in the HF impl here
def select_best_resolution(self, original_size, possible_resolutions: list):
Collaborator

I think if we're careful with tensor arithmetic, all of this can be rewritten to not have graph breaks. Let's set up a short discussion offline to figure it out.

Collaborator

@ani300 Jun 26, 2025

in case we can't figure it out, and given you already have a custom function hook for generate(), the call to image_size_to_num_patches can be moved there and outside of the forward() pass, which will remove the graph break, and then we change the inputs to forward from image_sizes to image_num_patches

Contributor Author

Yup, I believe so too. First pass was HF equivalence, followed by graph break removal, followed by alternative ops for AIU.

Good idea regarding the custom function hook fallback, will keep in mind!

Contributor Author

I moved image feature generation outside forward() to bypass the unsupported op, and now there are no graph breaks either, as you predicted :)
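
As a rough illustration of the fallback discussed in this thread, the resolution selection can run as plain Python before forward() is called, so the compiled graph only ever sees precomputed integers. The sketch below follows the general best-fit idea (keep the most image content, then waste the least padding). The helper `prepare_image_resolutions` and the example grid are hypothetical, and this is not necessarily the code that landed in the PR.

```python
from typing import List, Sequence, Tuple


def select_best_resolution(
    original_size: Tuple[int, int],
    possible_resolutions: Sequence[Tuple[int, int]],
) -> Tuple[int, int]:
    """Pick the (height, width) candidate that preserves the most image
    content and, on ties, wastes the least area. Pure Python, no tensors."""
    original_height, original_width = original_size
    best_fit = None
    max_effective = 0
    min_wasted = float("inf")
    for height, width in possible_resolutions:
        # Scale the image to fit inside the candidate, keeping aspect ratio
        scale = min(width / original_width, height / original_height)
        scaled_w = int(original_width * scale)
        scaled_h = int(original_height * scale)
        # Upscaling adds no information, so cap at the original area
        effective = min(scaled_w * scaled_h, original_width * original_height)
        wasted = height * width - effective
        if effective > max_effective or (
            effective == max_effective and wasted < min_wasted
        ):
            max_effective, min_wasted = effective, wasted
            best_fit = (height, width)
    if best_fit is None:
        raise ValueError("no usable resolution in possible_resolutions")
    return best_fit


def prepare_image_resolutions(
    image_sizes: Sequence[Tuple[int, int]],
    grid_pinpoints: Sequence[Tuple[int, int]],
) -> List[Tuple[int, int]]:
    """Hypothetical pre-forward hook: resolve every image once, outside the model."""
    return [select_best_resolution(tuple(s), grid_pinpoints) for s in image_sizes]
```

Calling `prepare_image_resolutions` from a generate-time hook keeps the data-dependent Python branching out of the traced forward pass.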

Comment thread fms/models/llava_next.py Outdated

return best_fit

def image_size_to_num_patches(self, image_size, grid_pinpoints, patch_size: int):
Collaborator

this can be rewritten without a single loop

Collaborator

can you rewrite this using modular arithmetic?

Contributor Author

Done!
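
For readers following along, here is one way the loop-free version might look. It assumes the loop being replaced counts one patch per `patch_size` tile of the selected resolution and then adds one for the base image. The function names are hypothetical, and the closed form uses ceiling division rather than whatever arithmetic the PR finally used.

```python
def num_patches_loop(height: int, width: int, patch_size: int) -> int:
    """Reference version: count tiles with nested loops, plus the base patch."""
    count = 0
    for _ in range(0, height, patch_size):
        for _ in range(0, width, patch_size):
            count += 1
    return count + 1


def num_patches_closed_form(height: int, width: int, patch_size: int) -> int:
    """Same count without loops: ceil(h / p) * ceil(w / p) + 1."""
    rows = -(-height // patch_size)  # ceiling division using floor division
    cols = -(-width // patch_size)
    return rows * cols + 1
```

Because the closed form involves only integer arithmetic on known sizes, it can be evaluated ahead of the forward pass along with the resolution selection.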

Comment thread fms/models/llava_next.py
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = self.unpad_image(image_feature, image_sizes[image_idx])
Collaborator

does this also cause a graph break?

Contributor Author

Interesting, I know this does, but the if comparison inside unpad_image might also cause one. Will verify with the next pass targeting graph break removal.

Collaborator

@ani300 left a comment

I have some questions about why we are merging with the graph break still there, and why there isn't code in the HF utils to load from HF so we can use "hf_pretrained" on get_model.

@sahilsuneja1
Contributor Author

Thanks for the review again! I left my responses with the individual comments

@sahilsuneja1
Contributor Author

Addressed comments except graph break-- TODO.

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
Comment thread fms/models/hf/utils.py Outdated
Comment thread fms/models/hf/utils.py
config_params["vision_feature_select_strategy"] = (
config.vision_feature_select_strategy
)
_, vision_config_params = _map_model_config("SiglipModel", config)
Collaborator

should this be config.vision_config?

Contributor Author

Nah, Siglip uses text + vision, and fms support is only for vision (since granite-vision only uses the vision encoder of Siglip, without its text part, using granite for the latter). So sending in vision here is fine, processed appropriately in line 215

Comment thread fms/models/hf/utils.py

# infer common params
if hasattr(config, "vocab_size"):
if infer_common_params:
Collaborator

even if infer_common_params is True, some models lack src_vocab_size right?

Contributor Author

That was only true for Siglip, but I took care of it in its section itself.

Collaborator

@ani300 left a comment

This is getting closer to being ready for merge as a first pass. A few more comments to address.

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
@sahilsuneja1
Contributor Author

Addressed the latest comments, thanks!

Comment thread fms/models/granite.py Outdated
past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None,
use_cache: bool = False,
only_last_token: bool = False,
is_input_embedded: Optional[bool] = False,
Collaborator

Contributor Author

Fixed

Comment thread fms/models/llava_next.py Outdated

def forward(
self,
input_ids: torch.Tensor,
Collaborator

to match other fms models, can we make this the inputs, and remove inputs_embeds

Contributor Author

Done

Comment thread fms/models/llava_next.py Outdated
Comment on lines +384 to +407
if pixel_values is not None and pixel_values.size(0) > 0:
    image_features = self.get_image_features(
        pixel_values,
        image_sizes,
    )

    image_features = self.pack_image_features(
        image_features,
        image_sizes,
        image_newline=self.image_newline,
    )

    special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
        -1
    )
    special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
        inputs_embeds.device
    )
    image_features = image_features.to(
        inputs_embeds.device, inputs_embeds.dtype
    )
    inputs_embeds = inputs_embeds.masked_scatter(
        special_image_mask, image_features
    )
Collaborator

@mserranos Jul 16, 2025

Question: are pixel_values constant (the same values) for all inferences? If so, inputs_embeds can be precomputed. This will eliminate the masked_scatter in the forward method, which is difficult to support on Spyre. We can talk offline, but I checked implementations and they are complicated (even if the op is decomposed).

Contributor Author

Great point, thanks @mserranos! Will work on it in the subsequent PR along with addressing the graph breaks.

Collaborator

@mserranos Jul 17, 2025

Better if you put it in this current PR instead of a subsequent PR. Could you elaborate on the graph break problem that you are seeing?

Contributor Author

There is one here, and potentially another one in this function.

Contributor Author

@sahilsuneja1 Jul 17, 2025

There might be more. We decided to fix them in a follow-up PR.

Collaborator

@sahilsuneja1 if there are graph breaks in torch._dynamo then this PR is not useful for Spyre. I suggest you think about how to fix those breaks in this PR, along with the masked_scatter replacement. We can talk offline; I would like to understand this code better. Have you tried torch 2.7.1 or the upcoming torch 2.8.0?

Contributor Author

Ok.
This code is a first pass mapping of HF transformers impl to make granite-vision run on fms + gpu.
Haven't tested with torch 2.7.1 or 2.8.0.
Haven't added support for AIU yet-- that was supposed to be a follow-up PR.

Contributor Author

We can hold off if you want everything as part of this PR itself

Collaborator

Yes, I looked at the code where the graph break happens; it is part of pack_image_features. If inputs_embeds is precomputed as I suggested, that graph break will disappear. I have not yet looked at the second potential graph break.

Contributor Author

I modified it as per your and Antoni's suggestions, and now there are no graph breaks or unsupported ops :) Thank you!
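
The precomputation idea from this thread can be sketched as a small helper that merges image features into the text embeddings once, before the decoder's forward pass, so masked_scatter never appears inside the compiled graph. This is a minimal sketch under stated assumptions: `merge_image_features` and its argument names are hypothetical, and the shapes (`image_features` as `(num_image_tokens, emb_dim)`) are assumed for illustration rather than taken from the merged code.

```python
import torch


def merge_image_features(
    input_ids: torch.Tensor,       # (batch, seq_len) token ids
    text_embeds: torch.Tensor,     # (batch, seq_len, emb_dim)
    image_features: torch.Tensor,  # (num_image_tokens, emb_dim), assumed layout
    image_token_index: int,
) -> torch.Tensor:
    """Replace the embeddings at image-token positions with image features."""
    mask = input_ids == image_token_index
    if int(mask.sum()) != image_features.shape[0]:
        raise ValueError("number of image tokens does not match image features")
    # Broadcast the per-token mask across the embedding dimension
    mask = mask.unsqueeze(-1).expand_as(text_embeds)
    image_features = image_features.to(text_embeds.device, text_embeds.dtype)
    # Features are copied, in order, into the masked positions
    return text_embeds.masked_scatter(mask, image_features)
```

The result can then be handed to the decoder as already-embedded input, which is the path that skips the embedding lookup.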

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
@sahilsuneja1
Contributor Author

Addressed review comments round 3, thanks again!

Comment thread fms/models/llava_next.py Outdated
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None,
inputs: Optional[torch.Tensor] = None,
Collaborator

Do we need a separate param for this? Could we just have input_ids be input_ids_or_embeds

Collaborator

agree with @JRosenkranz here

Contributor Author

👍

Comment thread tests/models/test_llava_next.py Outdated
)
from fms.utils.config import ModelConfig

_text_config = GraniteConfig(
Collaborator

Can this and the vision config be inside config in line 88?

Contributor Author

Done

Collaborator

@ani300 left a comment

one minor change in the interface and this is ready for merging

sahilsuneja1 and others added 2 commits July 22, 2025 19:57
Collaborator

@ani300 left a comment

lgtm! will wait for @JRosenkranz to take a final look before I merge

Collaborator

@JRosenkranz left a comment

lgtm

@JRosenkranz
Collaborator

bot:test
TEST_FILE=test_decoders.py MODEL_ID=ibm-granite/granite-3.3-8b-instruct BATCH_SIZE=8 SEQUENCE_LENGTH=64 USE_TINY_MODEL=1

@JRosenkranz JRosenkranz merged commit 6cdf7a3 into foundation-model-stack:main Jul 23, 2025
5 checks passed
@JRosenkranz JRosenkranz deleted the llava_next branch July 23, 2025 13:47
nirajkamal pushed a commit to nirajkamal/foundation-model-stack that referenced this pull request Sep 15, 2025
* adding siglip vision support

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* import fix

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* update attn_kwargs

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* ruff

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* adding tests

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* ruff format

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* test update

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* addressing review comments

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* siglip updates post  review foundation-model-stack#2

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* adding fms/models/llava_next

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* adding fms/models/llava_next

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* update attn_kwargs

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* ruff

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* adding tests for llava_next

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* rebasing on siglip branch + addressing review comments

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* rebse atop siglip foundation-model-stack#2

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* adding hf bs fms output equivalence test

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* moving imports to make build framework happy during testing

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* pytest.mark.slow

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* addressing review comments

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* mypy

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* adding HF config loading via hf_pretrained

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* percolate unfuse_weights to constituents

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* hf config loading check

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* recursive model config mapping

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* removing loops

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* fix graph break

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

* combining input_ids and inputs_embeds args

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>

---------

Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
Co-authored-by: Antoni Viros <aviros@ibm.com>
Signed-off-by: Niraj Kamal Karunanidhi <nirajkkamal@gmail.com>