Adding siglip vision model #405
Signed-off-by: Sahil Suneja <suneja@us.ibm.com>
PR ready for review @JRosenkranz
JRosenkranz left a comment
Has this been tested for equivalence with Hugging Face? We should try loading the model using get_model with an "hf" source. Can you follow something similar to this as a reference: https://github.com/foundation-model-stack/foundation-model-stack/blob/main/tests/models/hf_equivalence/test_granite.py
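The shape of such an equivalence check can be sketched as follows. This is only an illustration under stated assumptions: it does not load SigLIP or call get_model, and `ReferenceMLP`, `PortedMLP`, and `rename_keys` are hypothetical stand-ins for the HF model, the FMS port, and the checkpoint adapter. The idea is to copy the reference weights through a key-renaming step, run both models on the same input, and compare the outputs within a tolerance.

```python
import torch
from torch import nn


class ReferenceMLP(nn.Module):
    """Hypothetical stand-in for the reference (HF) implementation."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(nn.functional.gelu(self.fc1(x)))


class PortedMLP(nn.Module):
    """Hypothetical stand-in for the ported implementation (different key names)."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden)
        self.w2 = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(nn.functional.gelu(self.w1(x)))


def rename_keys(state_dict: dict) -> dict:
    """Map reference parameter names onto the ported model's names."""
    mapping = {"fc1.": "w1.", "fc2.": "w2."}
    renamed = {}
    for key, value in state_dict.items():
        for old, new in mapping.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = value
    return renamed


torch.manual_seed(0)
reference = ReferenceMLP(8, 16).eval()
ported = PortedMLP(8, 16).eval()
ported.load_state_dict(rename_keys(reference.state_dict()))

x = torch.randn(2, 5, 8)
with torch.no_grad():
    ref_out = reference(x)
    port_out = ported(x)

# Raises AssertionError if the two implementations disagree beyond tolerance.
torch.testing.assert_close(port_out, ref_out)
```

In a real test the two toy modules would be replaced by the HF model and the model returned by get_model, with the same preprocessed image batch fed to both.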
self.attention = torch.nn.MultiheadAttention(
    config.hidden_size, config.num_attention_heads, batch_first=True
)
We can and probably should move this to use our MultiHeadAttention block, as that will enable quantization and any other AIU-specific behavior we might need.
I was also curious why the HF implementation uses two different attention mechanisms: their own SiglipAttention implementation in SiglipEncoderLayer as well as torch.nn.MultiheadAttention here. So I kept it as is.
I do use our MultiHeadAttention block in SiglipEncoderLayer's self.self_attn, line 114.
Will try substituting with our own block to see if the output stays the same.
Got the following error when using our MultiHeadAttention block. Maybe there is a reason the HF implementation also uses two variants of MHA calls: one in SiglipEncoderLayer (where we are able to use our MultiHeadAttention impl) and vanilla torch.nn.MultiheadAttention in SiglipMultiheadAttentionPoolingHead.
  File "/gpfs/suneja/foundation-model-stack/scripts/inference_siglip_vision.py", line 243, in <module>
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/gpfs/suneja/conda-envs/env1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/conda-envs/env1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/foundation-model-stack/fms/models/siglip_vision.py", line 319, in forward
    pooler_output = self.head(hidden_states) if self.use_head else None
                    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/conda-envs/env1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/conda-envs/env1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/foundation-model-stack/fms/models/siglip_vision.py", line 259, in forward
    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/conda-envs/env1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/conda-envs/env1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/foundation-model-stack/fms/modules/attention.py", line 641, in forward
    q_out, k_out, v_out = self.in_proj(q, k, v)
                          ^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/conda-envs/env1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/conda-envs/env1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/suneja/foundation-model-stack/fms/modules/attention.py", line 512, in forward
    raise ValueError("q, k, and v must be the same or k and v must be None")
ValueError: q, k, and v must be the same or k and v must be None
self.use_head = (
    True
    if not hasattr(self.config, "vision_use_head")
    else self.config.vision_use_head
I don't see a vision_use_head as part of the above ModelConfig. Is it supposed to be there?
Hmm, I borrowed this from the HF implementation, to keep the Siglip implementation more general beyond its use in granite-vision, which does not explicitly set that config. I guess I can set vision_use_head to True as default, or drop this altogether for now.
If we see a requirement for vision_use_head in the future, I would add it to the config with True as the default. Otherwise I would remove it.
Removed for now
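For reference, the "add it to the config with True as the default" option from this thread could look roughly like the sketch below. `VisionConfigSketch`, `LegacyConfig`, and `resolve_use_head` are hypothetical names, not the PR's ModelConfig, and the field values are illustrative. It also shows that `getattr` with a default is equivalent to the `hasattr` conditional in the diff above.

```python
from dataclasses import dataclass


@dataclass
class VisionConfigSketch:
    """Hypothetical config; field values are illustrative only."""

    hidden_size: int = 64
    num_attention_heads: int = 4
    vision_use_head: bool = True  # explicit field, defaulting to True


class LegacyConfig:
    """A config object that predates the field and does not define it."""

    hidden_size = 64


def resolve_use_head(config) -> bool:
    # Same result as:
    #   True if not hasattr(config, "vision_use_head") else config.vision_use_head
    return getattr(config, "vision_use_head", True)


use_head_default = resolve_use_head(VisionConfigSketch())
use_head_disabled = resolve_use_head(VisionConfigSketch(vision_use_head=False))
use_head_legacy = resolve_use_head(LegacyConfig())
```

With the field declared on the config, the fallback is only needed for older config objects; otherwise `config.vision_use_head` can be read directly.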
    SiglipFixtures,
):
    @staticmethod
    def get_last_hidden_state(f_out):
do we want a case where vision_use_head=True?
Hmm, that wouldn't change anything. It would set self.use_head to True in siglip.py:301, which is the default case anyway when vision_use_head is not set. Also, I'll just drop it for now, as per one of the previous comments.
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))

# HF implementation uses PT MHA here, as opposed to SiglipAttention as in the SiglipEncoderLayer
self.attention = torch.nn.MultiheadAttention(
Thinking mostly of AIU support for this: do we want to use our attention implementation instead?
I tried using our MultiHeadAttention implementation, but it led to the error I shared before. Does using the torch.nn variant lead to issues on AIU?
Ah, I see what's happening: this uses cross-attention, which our Attention class cannot do (we removed it because no models used cross-attention anymore after everyone moved from encoder-decoder architectures to decoders). Keep it as torch.nn.MultiheadAttention for now then, although I'm 99% sure this will fail on the AIU.
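To make the cross-attention point concrete, here is a minimal sketch of the pooling pattern, assuming torch is installed. The sizes are illustrative rather than SigLIP's real dimensions, and this is not the PR's module. The query is a single learned probe token while the keys and values are the patch embeddings, so q differs from k and v, which is exactly the case the ValueError above rejects.

```python
import torch

torch.manual_seed(0)
hidden_size, num_heads = 64, 4  # illustrative sizes only
batch, num_patches = 2, 10

# One learned query token, shared across the batch.
probe = torch.nn.Parameter(torch.randn(1, 1, hidden_size))
attention = torch.nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

# Encoder output: one embedding per image patch.
hidden_state = torch.randn(batch, num_patches, hidden_size)

# Cross-attention: query (length 1) attends over keys/values (length num_patches).
query = probe.repeat(batch, 1, 1)
pooled, weights = attention(query, hidden_state, hidden_state)

print(pooled.shape)   # torch.Size([2, 1, 64])
print(weights.shape)  # torch.Size([2, 1, 10])
```

The output has one vector per image regardless of the number of patches, which is why a self-attention-only block (where q, k, and v must be the same tensor) cannot be dropped in here.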
new_sd = input_sd
if has_fused_weights:
    new_sd = serialization._mlp_glu_unfused_to_fused_adapter_step(
The MLP layer isn't GLU, so weight fusion with this helper won't work; there should be another one that fuses the MLP instead.
Thanks! Will fix
@ani300 There doesn't seem to be one specifically for mlp in utils/serialization.py
Oh, never mind, that's what happens when you review code after 11pm. There's no fusion possible for the MLP. Just remove the GLU portion here; the code should be:
new_sd = input_sd
if has_fused_weights:
    new_sd = serialization._attn_unfused_to_fused_step(new_sd)
return new_sd
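As background on what an attention fusion step does, here is a rough sketch assuming torch is installed. `fuse_qkv` and the key names (`attn.query.weight`, `attn.qkv_fused.weight`, and so on) are hypothetical and are not the actual behavior or naming of `serialization._attn_unfused_to_fused_step`. The separate q, k, and v projection weights are concatenated into one matrix, while the plain (non-GLU) MLP weights pass through untouched because there is nothing to fuse them with.

```python
import torch


def fuse_qkv(state_dict: dict, prefix: str = "attn.") -> dict:
    """Concatenate separate q/k/v weights into one fused weight (hypothetical keys)."""
    names = [f"{prefix}query.weight", f"{prefix}key.weight", f"{prefix}value.weight"]
    if not all(name in state_dict for name in names):
        return dict(state_dict)  # nothing to fuse
    fused = {k: v for k, v in state_dict.items() if k not in names}
    # Stack along the output dimension: (3 * dim, dim).
    fused[f"{prefix}qkv_fused.weight"] = torch.cat(
        [state_dict[name] for name in names], dim=0
    )
    return fused


torch.manual_seed(0)
dim = 8
unfused = {
    "attn.query.weight": torch.randn(dim, dim),
    "attn.key.weight": torch.randn(dim, dim),
    "attn.value.weight": torch.randn(dim, dim),
    "mlp.w1.weight": torch.randn(4 * dim, dim),  # plain MLP: left as-is
}
fused = fuse_qkv(unfused)

# One matmul with the fused weight, then split back into q, k, v.
x = torch.randn(3, dim)
q, k, v = (x @ fused["attn.qkv_fused.weight"].T).split(dim, dim=-1)
```

A single fused projection gives the same q, k, and v as three separate ones, which is why the adapter only needs to rewrite the state dict keys and shapes.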
Np! Thanks for reviewing!
Thanks! Will take care of the others, but regarding our MHA impl, I run into this error.
Needed to enable granite-vision = siglip vision (this PR) + llava-next vision-language connector (PR #420) + granite-3.1-2b-instruct (already supported)