
Add Sapiens2 Model #45919

Draft
guarin wants to merge 54 commits into huggingface:main from guarin:add-sapiens2

Conversation

@guarin
Member

@guarin commented May 12, 2026

What does this PR do?

  • Adds the new Sapiens2 model from Meta

There is an open PR for the original Sapiens model (v1) from 2024: #33167. I started from scratch for v2 since it supersedes the old version.

Sapiens2 repo: https://github.com/facebookresearch/sapiens2

TODO before merge

  • Drop cv2 dependency?
  • Re-use pose pre- and post-processing from ViTPose where possible
  • Update docs
  • Tidy up all docstrings
  • Once config is settled, create PR to hub with model and processor configs

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@molbap left a comment

Nice! Left a small initial review

Comment on lines +130 to +131
if not config.use_mask_token:
    del self.mask_token
Contributor

do we need a conditional here?

Member Author

@guarin May 13, 2026

DINOv3 always has a mask token. Sapiens2 was pretrained with a mask token, but the checkpoints were uploaded without it (probably the EMA model). I had to add the conditional to handle the checkpoints without a mask token. If someone wanted to continue pretraining Sapiens2, they would need to set use_mask_token=True. Without the conditional, I get a warning about missing mask tokens in the checkpoint.
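For illustration, a minimal sketch of the flag in use (the checkpoint path is a placeholder, exact loading behavior assumed):

from transformers import Sapiens2Config, Sapiens2Model  # classes added in this PR

# Hypothetical: re-enable the mask token to continue pretraining.
config = Sapiens2Config.from_pretrained("<sapiens2-checkpoint>", use_mask_token=True)
model = Sapiens2Model(config)
# Loading the released weights into this model warns about the missing
# mask_token entry, which is expected when starting a new pretraining run.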

Contributor

A warning about the missing mask token is OK imo (a bit annoying indeed). My point here is that the deletion is conditional in the modular file, which surprised me, but it just copies it over. A del statement in a modular file entirely deletes an attribute in the expanded modeling file, else

Member Author

Is it possible that a part of your comment is missing?

Comment on lines +134 to +135
if bool_masked_pos is not None and not self.config.use_mask_token:
    raise ValueError("bool_masked_pos requires use_mask_token=True in the config")
Contributor

same question here, is it something that can happen?

Contributor

To fill before merge; it's also nice to add some usage examples here, possibly linking to documentation images that we can host on the hub

Member Author

Will do 👍🏼

Member Author

Added

)


# TODO(guarin): Double check if we cannot inherit attribute docstrings from parent class.
Contributor

unfortunately, not at the moment 😬


model_type = "sapiens2"

# TODO(guarin): This is needed to load the original checkpoints but makes unit tests fail.
Contributor

ah, how so?

Member Author

>                   file_pointer = safe_open(file, framework="pt", device="cpu")
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E                   FileNotFoundError: No such file or directory: /var/folders/c8/0vwzz_7s429ggn8376hs9q9c0000gn/T/tmptyicczri/sapiens2_0.4b_pretrain.safetensors

In a lot of inherited tests:

FAILED tests/models/sapiens2/test_modeling_sapiens2.py::Sapiens2ModelTest::test_bc_torch_dtype - FileNotFoundError: No such file or directory: /var/folders/c8/0vwzz_7s429ggn8376hs9q9c0000gn/T/tmp6xiqbxf5/sapiens2_0.4b_pretr...
FAILED tests/models/sapiens2/test_modeling_sapiens2.py::Sapiens2ModelTest::test_can_load_from_already_mapped_keys - FileNotFoundError: No such file or directory: /var/folders/c8/0vwzz_7s429ggn8376hs9q9c0000gn/T/tmpl91m5jdt/sapiens2_0.4b_pretr...
FAILED tests/models/sapiens2/test_modeling_sapiens2.py::Sapiens2ModelTest::test_can_use_safetensors - FileNotFoundError: No such file or directory: /var/folders/c8/0vwzz_7s429ggn8376hs9q9c0000gn/T/tmp9q3b8ixl/sapiens2_0.4b_pretr...
FAILED tests/models/sapiens2/test_modeling_sapiens2.py::Sapiens2ModelTest::test_correct_missing_keys - FileNotFoundError: No such file or directory: /var/folders/c8/0vwzz_7s429ggn8376hs9q9c0000gn/T/tmpij68xy5b/sapiens2_0.4b_pretr...
FAILED tests/models/sapiens2/test_modeling_sapiens2.py::Sapiens2ModelTest::test_eager_matches_sdpa_inference_00_fp16_pad_left_sdpa_kernels - FileNotFoundError: No such file or directory: /var/folders/c8/0vwzz_7s429ggn8376hs9q9c0000gn/T/tmpbda3izxx/sapiens2_0.4b_pretr...
...

I didn't have time to look into it in detail yet. It might be that the test saves and then reloads the model. When reloading, it tries to find the weights under sapiens2_0.4b_pretrain.safetensors, but they are probably saved to model.safetensors instead.

Member Author

@guarin May 13, 2026

Relevant part of the stack trace:

    def test_sdpa_can_dispatch_non_composite_models(self):
        """
        Tests if non-composite models dispatch correctly on SDPA/eager when requested so when loading the model.
        This tests only by looking at layer names, as usually SDPA layers are called "SDPAAttention".
        """
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")
    
        if not self.all_model_classes[0]._supports_sdpa or self._is_composite:
            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
    
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)
    
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
>               model_sdpa = model_class.from_pretrained(tmpdirname)

Looking at model.save_pretrained, I couldn't find any mention of config.transformers_weights, which probably means the model is saved to a default path instead.

I also couldn't find any usage of transformers_weights in any other model config, so I guess the proper fix is to leave it as None and make a PR to the hub instead.

Contributor

Yes, there's no usage here because we have very few model releases that are as scarce in standard files as this one, but it might not be a bad precedent. QQ to @ArthurZucker on this - WDYT would suit better given the context: a model release with a single file named whatever.safetensors, no config.json, no index, nothing.

  1. Use transformers_weights in the new default config and make sure save/load works?
  2. Open a PR to the original repo?

If 1) is too complicated, we can use 2) and refer to the git branch of the PR (I mean model = AutoModel.from_pretrained("org/model-name", revision="refs/pr/1")).

Member Author

There is also no preprocessor config, so AutoImageProcessor doesn't work either.

Collaborator

Hey! If it only has a safetensors file, it's safe to assume zero libraries depend on it -> let's open a PR and try to get it merged by reaching out to the authors! (We can most probably push the final format of the weights as well, like we often do.)

Contributor

seems like a good draft overall!

Comment on lines +907 to +934
"sapiens2": [
WeightRenaming(r"^cls_token$", r"embeddings.cls_token"),
WeightRenaming(r"^storage_tokens$", r"embeddings.register_tokens"),
WeightRenaming(r"^patch_embed\.projection\.", r"embeddings.patch_embeddings."),
WeightRenaming(r"^rope_embed\.", r"rope_embeddings."),
WeightRenaming(r"blocks\.(\d+)\.", r"model.layer.\1."),
WeightRenaming(r"attn\.proj\.", r"attention.o_proj."),
WeightRenaming(r"attn\.wq\.", r"attention.q_proj."),
WeightRenaming(r"attn\.wk\.", r"attention.k_proj."),
WeightRenaming(r"attn\.wv\.", r"attention.v_proj."),
WeightRenaming(r"attn\.q_norm\.", r"attention.q_norm."),
WeightRenaming(r"attn\.k_norm\.", r"attention.k_norm."),
WeightRenaming(r"attn\.gamma\.weight$", r"layer_scale1.lambda1"),
WeightRenaming(r"ffn\.w3\.", r"mlp.down_proj."),
WeightRenaming(r"\.ln1\.", r".norm1."),
WeightRenaming(r"\.ln2\.", r".norm2."),
WeightRenaming(r"^ln1\.", r"norm."),
WeightConverter(
source_patterns=r"ffn\.w12\.weight",
target_patterns=[r"mlp.gate_proj.weight", r"mlp.up_proj.weight"],
operations=[Chunk(dim=0)],
),
WeightConverter(
source_patterns=r"ffn\.w12\.bias",
target_patterns=[r"mlp.gate_proj.bias", r"mlp.up_proj.bias"],
operations=[Chunk(dim=0)],
),
],
Contributor

good usage here
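For readers unfamiliar with the converter, a small sketch of what Chunk(dim=0) does to the fused ffn.w12 weight (shapes illustrative):

import torch

hidden_size, intermediate_size = 1024, 4096
w12 = torch.randn(2 * intermediate_size, hidden_size)  # fused gate+up weight in the original checkpoint
gate_proj_weight, up_proj_weight = w12.chunk(2, dim=0)  # split along dim 0 into the two target tensors
assert gate_proj_weight.shape == up_proj_weight.shape == (intermediate_size, hidden_size)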


@slow
def test_inference_no_head(self):
    # TODO(guarin): cleanup. transformers_weights required because original checkpoints are called "sapiens2_0.4b_pretrain.safetensors" instead of "model.safetensors"
Contributor

should be solved with default config values, right?

Member Author

Yes, I think so. Once the configs are on the hub, AutoImageProcessor should also load correctly; it fails for now.


@require_torch
@require_vision
class Sapiens2ModelIntegrationTest(unittest.TestCase):
Contributor

Nice to check that the model behaves as expected - FYI, for most generative models we also try to include a full end-to-end test with generation (model.generate()), if relevant

with torch.no_grad():
    outputs = model(**inputs)

# verify the last hidden states
Contributor

do these come from the original implem?

Member Author

@guarin May 13, 2026

Yes, however the tests do not always pass. Sapiens2 runs everything in bf16 by default. When I convert the model and input image to bf16 and use logits from the original model in bf16, the tests pass. For this I also had to use bf16 in the rope embed, following the original code. If I run in fp32 and compare against fp32 logits from the original model, I get some differences, even if I keep rope in bf16 as in the original repo. Not yet sure where the diff comes from.

In general, if the original model is in bf16, should we also load and test it in bf16? Or do we prefer to load in fp32 and adjust the tests accordingly?

Contributor

In general we prefer to make the test conditions match the original implementation/environment. For RoPE there's sometimes some hidden upcasting... if you have the full implementation on both sides and can't track down the origin, dumping the outputs to JSON with https://github.com/huggingface/transformers/blob/main/src/transformers/model_debugging_utils.py can help
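For reference, a sketch of how that helper can be used (assuming the model_addition_debugger_context API from that file; check the module for the exact signature):

from transformers.model_debugging_utils import model_addition_debugger_context

# Dumps per-layer inputs/outputs to JSON so the two implementations can be diffed.
with model_addition_debugger_context(model, debug_path="sapiens2_debug"):
    outputs = model(**inputs)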

Member Author

@guarin May 13, 2026

So it is fine if we use bf16 in the tests but then load the model in fp32 by default? In the meantime I'll try to figure out where the fp32 diff comes from.

Also, regarding rope: is it OK if I keep it fixed to bf16 as in Sapiens2, or would you prefer the more flexible implementation following DINOv3? If I go with fixed bf16 I'll also have to add a custom apply_rotary_pos_emb implementation with the dtype casting.

Member Author

I have now updated the rope implementation to exactly match the original code. The tests now pass for bf16 and fp32. I had to slightly relax the tolerance from 1e-4 to 1e-3 because of one value that is slightly different.

Member Author

Tracked it down; for a perfect match I had to change a couple of things:

  • Generate expected logits with
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.backends.fp32_precision = "ieee"

as is the default in the tests

  • Update image loading to use torchvision decode_image
  • Skip the image processor and use torchvision.transforms.v2 instead (see the sketch below). I believe the difference between the two comes from the norm+rescale fusing in the ImageProcessor, which gives slightly different values

Given that these are all FP precision issues and would likely cause tests to fail on different architectures/torch versions, I propose keeping the 1e-3 tolerance.
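A sketch of the unfused preprocessing described in the last bullet (image path, size, and normalization constants are placeholders):

import torch
from torchvision.io import decode_image
from torchvision.transforms import v2

image = decode_image("person.jpg")  # uint8 CHW tensor, matching the original repo's loading
transform = v2.Compose([
    v2.Resize((1024, 768)),
    v2.ToDtype(torch.float32, scale=True),  # rescale to [0, 1] first...
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ...then normalize, no fused rescale+norm
])
pixel_values = transform(image).unsqueeze(0)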

Comment on lines +207 to +208
def test_output_hidden_states(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
Member Author

test_output_hidden_states is defined on the tester instead of the test class for DINOv3. I moved it to the test class for Sapiens2. Might merit a follow-up PR for DINOv3.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, sapiens2

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45919&sha=6c32cb

Member

@yonigozlan left a comment

Very nice!! Looks great for a first draft. As you said, there's a bit of standardization to do with the ViTPose image processor. As for the cv2 requirement, there's no real equivalent for INTER_AREA in PIL or torchvision, so we might have no choice but to keep it.
Most of the comments are nit-picking; the overall structure looks great. One other structural nit: I usually prefer to put the image processor code at the top of the modular file, but it's not a big deal.

Comment on lines +302 to +304
if self.use_qk_norm:
    self.q_norm = nn.RMSNorm(self.head_dim, eps=config.layer_norm_eps)
    self.k_norm = nn.RMSNorm(self.head_dim, eps=config.layer_norm_eps)
Member

Nit: we can use a ternary with q/k_norm set to identity when use_qk_norm is False, so we don't need to set an attr use_qk_norm, and have only one path in forward (see internvl)
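Something like this, as a sketch of the suggested pattern (not the final code):

self.q_norm = nn.RMSNorm(self.head_dim, eps=config.layer_norm_eps) if config.use_qk_norm else nn.Identity()
self.k_norm = nn.RMSNorm(self.head_dim, eps=config.layer_norm_eps) if config.use_qk_norm else nn.Identity()
# forward always calls self.q_norm / self.k_norm; nn.Identity() is a no-op when qk norm is disabled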

Comment on lines +292 to +294
self.num_kv_heads = (
    self.num_heads if config.layer_types[layer_idx] == "full_attention" else config.num_key_value_heads
)
Member

Let's define self.num_key_value_groups and use repeat_kv in the eager method instead (like in gemma4, for example); we can take eager_attention_forward from a model other than dinov3_vit.

We can probably inherit Sapiens2Attention from another model as well, at least partially for the init
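Roughly, as a sketch of the library-wide GQA pattern (names taken from this diff):

self.num_key_value_groups = self.num_heads // self.num_kv_heads

# in eager_attention_forward, expand the KV heads to match the query heads:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)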

Comment on lines +196 to +243
class Sapiens2RopePositionEmbedding(nn.Module):
    periods: torch.Tensor

    def __init__(self, config: Sapiens2Config):
        super().__init__()

        self.patch_size = config.patch_size
        self.pos_embed_shift = config.pos_embed_shift
        self.pos_embed_jitter = config.pos_embed_jitter
        self.pos_embed_rescale = config.pos_embed_rescale
        self.base = config.rope_theta
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.pos_embed_dtype = getattr(torch, config.pos_embed_dtype)

        periods = self.base ** (
            2 * torch.arange(self.head_dim // 4, dtype=self.pos_embed_dtype) / (self.head_dim // 2)
        )
        self.register_buffer("periods", periods, persistent=True)  # persistent=True to match original checkpoints

    def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        _, _, height, width = pixel_values.shape
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size

        device = pixel_values.device
        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):
            patch_coords = get_patches_center_coordinates(
                num_patches_h, num_patches_w, dtype=self.pos_embed_dtype, device=device
            )
            if self.training:
                patch_coords = augment_patches_center_coordinates(
                    patch_coords,
                    shift=self.pos_embed_shift,
                    jitter=self.pos_embed_jitter,
                    rescale=self.pos_embed_rescale,
                )

            # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim)
            angles = 2 * math.pi * patch_coords[:, :, None] / self.periods[None, None, :].to(self.pos_embed_dtype)
            angles = angles.flatten(1, 2)
            angles = angles.tile(2)

            cos = torch.cos(angles)
            sin = torch.sin(angles)

        return cos, sin
Member

Can't we fully reuse DINOv3ViTRopePositionEmbedding here?
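If the implementations really do match, the modular file could reduce to something like this sketch (assuming no behavioral differences remain):

class Sapiens2RopePositionEmbedding(DINOv3ViTRopePositionEmbedding):
    pass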

periods = self.base ** (
    2 * torch.arange(self.head_dim // 4, dtype=self.pos_embed_dtype) / (self.head_dim // 2)
)
self.register_buffer("periods", periods, persistent=True)  # persistent=True to match original checkpoints
Member

We usually name this inv_freq; also, there is no need to set persistent to True to match the original checkpoint. If we don't end up pushing new checkpoints, we can change the _keys_to_ignore_on_load_missing attr of the model instead
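E.g. something like this sketch (the attribute exists on PreTrainedModel subclasses; the exact key pattern is assumed):

class Sapiens2PreTrainedModel(PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"rope_embeddings\.periods"]  # buffer is recomputed in __init__, safe to miss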

Comment on lines +391 to +401
class Sapiens2ConvLayer(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int = 1, bias: bool = True, activation: str = "silu"
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias)
        self.norm = nn.InstanceNorm2d(out_channels)
        self.activation = ACT2FN[activation]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.activation(self.norm(self.conv(hidden_states)))
Member

Nit: we can inherit from e.g. BeitConvLayer and change only the norm and the init signature
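Sketched below; heavily hedged, since the exact BeitConvLayer init signature would need checking:

class Sapiens2ConvLayer(BeitConvLayer):
    def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation="silu"):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, bias=bias, activation=activation)
        self.norm = nn.InstanceNorm2d(out_channels)  # the only change vs. the parent: swap in InstanceNorm2d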

Comment on lines +566 to +567
class Sapiens2PoseHead(Sapiens2SegmentationHead):
    pass
Member

I don't think we need this; let's just use Sapiens2SegmentationHead in Sapiens2ForPoseEstimation directly


ALLOWED_LAYER_TYPES = (
    "full_attention",
    "grouped_query_attention",  # used in Sapiens2
Member

I don't think we need this; we can just have num_key_value_groups=1 for full attention
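In other words, a sketch of the suggestion:

# GQA vs. full attention falls out of the head counts; no new layer type needed.
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
# == 1 when num_key_value_heads == num_attention_heads, i.e. plain full attention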
