Adding llava_next model#420
Conversation
|
PR ready for review @JRosenkranz |
| position_ids=None, | ||
| past_key_value_states=None, | ||
| use_cache=False, | ||
| is_input_embedded=False, |
There was a problem hiding this comment.
we may want to think about how this could tie in with other models. Seems not specific to Granite.
There was a problem hiding this comment.
Yeah, this is not Granite-specific, but hasn't been needed so far with any other models.
So far only granite-vision via llava_next sends its inputs to granite in this form. So maybe good for now for Granite?
But you are right, pixtral would probably have something similar.
There was a problem hiding this comment.
For this param, I think we can get away with checking on the tensor.dim and removing the param entirely. If the size is 3, it would skip embedding, if the size is 2 would run the embedding. If we do this, we would add documentation for how x_in is handled. In either case, whether including a param or not, we will need some error handling, so don't really see a downside to this.
There was a problem hiding this comment.
Cool, will try it out!
| vision_feature_select_strategy: str = "full" | ||
| vision_feature_layer: list = field(default_factory=lambda: [-24, -20, -12, -1]) | ||
| image_grid_pinpoints: list = field(default_factory=lambda: _granite_3_2_2b_grid) | ||
| tie_word_embeddings: bool = True |
There was a problem hiding this comment.
typically we call this tie_heads, might make sense to make the same assumption
There was a problem hiding this comment.
will replace-- actually wasn't being used, so removed it.
| ) | ||
| self.config = config | ||
|
|
||
| # NOTE: HF doesn't do this |
There was a problem hiding this comment.
can you mention how hf initializes its weights
There was a problem hiding this comment.
HF doesn't have this particular initialization! Will update in comment
| image_sizes=None, | ||
| position_ids=None, | ||
| past_key_value_states=None, | ||
| inputs_embeds=None, |
There was a problem hiding this comment.
Should this match closer to granite with is_input_embedded?
There was a problem hiding this comment.
Hmm, this is passed with this particular key from an image processor.
Also, you are right, is_input_embedded is set to True, when inputs_embeds are sent to granite.forward() in line 399
| use_cache=False, | ||
| **attn_kwargs: Unpack[AttentionKwargs], | ||
| ): | ||
| if input_ids is None and inputs_embeds is None: |
There was a problem hiding this comment.
should this raise an exception?
There was a problem hiding this comment.
Had an exception here originally, but the tests complained for some reason. I can try and bring it back.
| "llava_next uses nested configs for vision and text model, which get flattened with config.as_dict()" | ||
| ) | ||
|
|
||
| def test_model_compile_no_graph_breaks(self, model): |
There was a problem hiding this comment.
Could we fix these as part of this PR?
There was a problem hiding this comment.
Let's do it in a follow-up PR.
Once this original implementation is in, others from the team will address the graph breaks issue-- one of the original motivation to support granite-vision in fms.
There was a problem hiding this comment.
I think part of landing the model to FMS is having this fixed, right? Why do we want to land a broken model in FMS in the first place to then fix it later? If there's a good reason to not wait on that we can make an exception
There was a problem hiding this comment.
Hmm, no reason, I thought we let the HF impl equivalence variant in first, so that others from the team can focus on modifying it to remove the graph break and add alternate ops for AIU. But we can work off of the branch and merge it all together if you prefer.
|
DCO is complaining, I don't know how to fix it. |
JRosenkranz
left a comment
There was a problem hiding this comment.
Could we add a script for this to show usage? and/or an HF equivalence test similar to https://github.com/foundation-model-stack/foundation-model-stack/blob/main/tests/models/hf_equivalence/test_llama.py
|
Yup, will add an example usage snippet. Will try and add the HF equivalence test as well. |
| return unpadded_tensor | ||
|
|
||
| # TODO: fix graph break in the HF impl here | ||
| def select_best_resolution(self, original_size, possible_resolutions: list): |
There was a problem hiding this comment.
I think if we're careful with tensor arithmetic, all of this can be rewritten to not have graph breaks. Let's set up a short discussion offline to figure it out.
There was a problem hiding this comment.
in case we can't figure it out, and given you already have a custom function hook for generate(), the call to image_size_to_num_patches can be moved there and outside of the forward() pass, which will remove the graph break, and then we change the inputs to forward from image_sizes to image_num_patches
There was a problem hiding this comment.
Yup, I believe so too. First pass was HF equivalence, followed by graph break removal, followed by alternative ops for AIU.
Good idea regarding the custom function hook fallback, will keep in mind!
There was a problem hiding this comment.
I moved image feature generation outside forward() to bypass unsupported op, and now no graph breaks either, as you predicted :)
|
|
||
| return best_fit | ||
|
|
||
| def image_size_to_num_patches(self, image_size, grid_pinpoints, patch_size: int): |
There was a problem hiding this comment.
this can be rewritten without a single loop
There was a problem hiding this comment.
can you rewrite this using modular arithmetic?
| ) | ||
| image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() | ||
| image_feature = image_feature.flatten(1, 2).flatten(2, 3) | ||
| image_feature = self.unpad_image(image_feature, image_sizes[image_idx]) |
There was a problem hiding this comment.
does this also cause a graph break?
There was a problem hiding this comment.
Interesting, I know this does, but the if comparison inside unpad_image might also cause one. Will verify with the next pass targeting graph break removal.
ani300
left a comment
There was a problem hiding this comment.
I have some questions about why are we merging with the graph break still there, and why there isn't some code to load from HF on the HF utils so we can use "hf_pretrained" on get_model.
|
Thanks for the review again! I left my responses with the individual comments |
|
Addressed comments except graph break-- TODO. |
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
| config_params["vision_feature_select_strategy"] = ( | ||
| config.vision_feature_select_strategy | ||
| ) | ||
| _, vision_config_params = _map_model_config("SiglipModel", config) |
There was a problem hiding this comment.
should this be config.vision_config?
There was a problem hiding this comment.
Nah, Siglip uses text + vision, and fms support is only for vision (since granite-vision only uses the vision encoder of Siglip, without it's text part, using granite for the latter). So sending in vision here is fine, processed appropriately in line 215
|
|
||
| # infer common params | ||
| if hasattr(config, "vocab_size"): | ||
| if infer_common_params: |
There was a problem hiding this comment.
even if infer_common_params is True, some models lack src_vocab_size right?
There was a problem hiding this comment.
That was only true for Siglip, but I took care of it in its section itself.
ani300
left a comment
There was a problem hiding this comment.
This is getting closer to being ready for merge as a first pass. A few more comments to address.
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
|
Addressed the latest comments, thanks! |
| past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, | ||
| use_cache: bool = False, | ||
| only_last_token: bool = False, | ||
| is_input_embedded: Optional[bool] = False, |
There was a problem hiding this comment.
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.Tensor, |
There was a problem hiding this comment.
to match other fms models, can we make this the inputs, and remove inputs_embeds
| if pixel_values is not None and pixel_values.size(0) > 0: | ||
| image_features = self.get_image_features( | ||
| pixel_values, | ||
| image_sizes, | ||
| ) | ||
|
|
||
| image_features = self.pack_image_features( | ||
| image_features, | ||
| image_sizes, | ||
| image_newline=self.image_newline, | ||
| ) | ||
|
|
||
| special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( | ||
| -1 | ||
| ) | ||
| special_image_mask = special_image_mask.expand_as(inputs_embeds).to( | ||
| inputs_embeds.device | ||
| ) | ||
| image_features = image_features.to( | ||
| inputs_embeds.device, inputs_embeds.dtype | ||
| ) | ||
| inputs_embeds = inputs_embeds.masked_scatter( | ||
| special_image_mask, image_features | ||
| ) |
There was a problem hiding this comment.
question: are pixel_values constant (same value) for all inferences ? if so, then the input_embeds can be precomputed. This will eliminate the "masked_scatter" in the forward method which is difficult to support in Spyre... we can talk offline, but I checked implementations and they are complicated (even if the op is decomposed).
There was a problem hiding this comment.
Great point, thanks @mserranos! Will work on it in the subsequent PR alongwith addressing the graph breaks.
There was a problem hiding this comment.
better if you put it in this current PR instead of a subsequent PR, could you elaborate about the graph break problem that you are seeing ?
There was a problem hiding this comment.
There might be more. We decided to fix them in a follow-up PR.
There was a problem hiding this comment.
@sahilsuneja1 if there are graph breaks in torch._dynamo then this PR is not useful for Spyre, I suggest you think about how to fix those breaks in this PR along with the "masked_scatter" replacement. We can talk offline, I would like to understand more this code. Have you tried torch 2.7.1 or the upcoming torch 2.8.0 to be released soon ?
There was a problem hiding this comment.
Ok.
This code is a first pass mapping of HF transformers impl to make granite-vision run on fms + gpu.
Haven't tested with torch 2.7.1 or 2.8.0.
Haven't added support for AIU yet-- that was supposed to be a follow-up PR.
There was a problem hiding this comment.
We can hold off if you want everything as part of this PR itself
There was a problem hiding this comment.
yes, I looked at the code where the graph break happens, it is part of the "pack_image_features", If the input_embeds are precomputed as I suggested then that graph break will disappear. I have not yet looked at the second potential graph break.
There was a problem hiding this comment.
I modified as per your and Antoni's suggestion and now no graph breaks and unsupported ops :) Thank you!
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
|
Addressed review comments round 3, thanks again! |
| input_ids: torch.Tensor, | ||
| position_ids: Optional[torch.Tensor] = None, | ||
| past_key_value_states: Optional[Tuple[torch.FloatTensor,]] = None, | ||
| inputs: Optional[torch.Tensor] = None, |
There was a problem hiding this comment.
Do we need a separate param for this? Could we just have input_ids be input_ids_or_embeds
| ) | ||
| from fms.utils.config import ModelConfig | ||
|
|
||
| _text_config = GraniteConfig( |
There was a problem hiding this comment.
Can this and the vision config be inside config in line 88?
ani300
left a comment
There was a problem hiding this comment.
one minor change in the interface and this is ready for merging
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
ani300
left a comment
There was a problem hiding this comment.
lgtm! will wait for @JRosenkranz to take a final look before I merge
|
bot:test |
* adding siglip vision support Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * import fix Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * update attn_kwargs Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * ruff Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * adding tests Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * ruff format Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * test update Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * addressing review comments Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * siglip updates post review foundation-model-stack#2 Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * adding fms/models/llava_next Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * adding fms/models/llava_next Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * update attn_kwargs Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * ruff Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * adding tests for llava_next Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * rebasing on siglip branch + addressing review comments Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * rebse atop siglip foundation-model-stack#2 Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * adding hf bs fms output equivalence test Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * moving imports to make build framework happy during testing Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * pytest.mark.slow Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * addressing review comments Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * mypy Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * adding HF config loading via hf_pretrained Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * percolate unfuse_weights to constituents Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * hf config loading check Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * recursive model config mapping Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * removing loops Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * fix graph break Signed-off-by: Sahil Suneja <suneja@us.ibm.com> * combining input_ids and inputs_embeds args Signed-off-by: Sahil Suneja <suneja@us.ibm.com> --------- Signed-off-by: Sahil Suneja <suneja@us.ibm.com> Co-authored-by: Antoni Viros <aviros@ibm.com> Signed-off-by: Niraj Kamal Karunanidhi <nirajkkamal@gmail.com>
This enables granite-vision in fms == siglip vision (PR #405 ) + llava-next vision-language connector (this PR) + granite-3.1-2b-instruct (already supported)
Please merge PR #405 first, which serves as the parent branch for this.