
Fix AttributeError in AssistantToTargetTranslator.unmap_input_ids with cross-vocab models#45320

Merged
zucchini-nlp merged 4 commits into huggingface:main from Regata3010:fix-cross-vocab-assisted-generation
Apr 10, 2026

Conversation

@Regata3010 (Contributor) commented Apr 8, 2026

What does this PR do?

Fixes a crash in assisted generation when using model pairs with different vocabulary sizes but the same tokenizer family (e.g., Qwen2.5-7B + Qwen2.5-0.5B).

map_input_embeddings is only initialized when len(self._suppress_input_ids) > 0 (line 723), but unmap_input_ids() checked only self.assistant_prune_lm_head. This caused an AttributeError when the assistant vocab is a subset of the target vocab (so there are no suppressed IDs) but assistant_prune_lm_head is enabled.

The fix adds the same len(self._suppress_input_ids) > 0 guard to unmap_input_ids(), matching the initialization condition.
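The shape of the bug and of the guard can be sketched with a minimal stand-in class. The names (`map_input_embeddings`, `_suppress_input_ids`, `assistant_prune_lm_head`, `unmap_input_ids`) mirror this PR's description, but the class body below is a simplified illustration, not the actual transformers code:

```python
class AssistantToTargetTranslatorSketch:
    """Simplified stand-in for the real translator, for illustration only."""

    def __init__(self, suppress_input_ids, assistant_prune_lm_head):
        self._suppress_input_ids = suppress_input_ids
        self.assistant_prune_lm_head = assistant_prune_lm_head
        # Mirrors the initialization guard described at line 723: the
        # mapping attribute is only created when there are IDs to suppress.
        if len(self._suppress_input_ids) > 0:
            self.map_input_embeddings = {i: i for i in suppress_input_ids}

    def unmap_input_ids_buggy(self, input_ids):
        # BEFORE the fix: checks only the pruning flag, so this raises
        # AttributeError when _suppress_input_ids is empty.
        if self.assistant_prune_lm_head:
            return [self.map_input_embeddings.get(i, i) for i in input_ids]
        return input_ids

    def unmap_input_ids_fixed(self, input_ids):
        # AFTER the fix: the same len(...) > 0 guard as __init__, so the
        # attribute is only touched when it was actually created.
        if self.assistant_prune_lm_head and len(self._suppress_input_ids) > 0:
            return [self.map_input_embeddings.get(i, i) for i in input_ids]
        return input_ids


# Empty suppress list + pruning enabled reproduces the crash condition:
translator = AssistantToTargetTranslatorSketch([], assistant_prune_lm_head=True)
try:
    translator.unmap_input_ids_buggy([1, 2, 3])
except AttributeError as e:
    print("buggy path:", e)
print("fixed path:", translator.unmap_input_ids_fixed([1, 2, 3]))
```

With a non-empty suppress list both paths behave identically, which is why the one-line guard is a safe, behavior-preserving fix.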

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer

target = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="auto")
draft = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")

input_ids = tokenizer.encode("Hello world", return_tensors="pt").to("cuda")
output = target.generate(
    input_ids,
    assistant_model=draft,
    tokenizer=tokenizer,
    assistant_tokenizer=tokenizer,
    max_new_tokens=32,
)

Before fix:
AttributeError: 'AssistantToTargetTranslator' object has no attribute 'map_input_embeddings'

After fix: generates correctly.
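The trigger condition ("assistant vocab is a subset of the target vocab, so no suppressed IDs") can be shown with plain Python sets. The vocab sizes below are made up for illustration, not the real Qwen2.5 vocab sizes, and "suppressed IDs" is computed naively as assistant IDs absent from the target vocab, following the PR description:

```python
# Illustrative sizes only: a smaller draft-model vocab that is a strict
# subset of a larger target-model vocab, as in the Qwen2.5-0.5B/7B pair.
assistant_vocab = set(range(1000))
target_vocab = set(range(1500))

# Assistant tokens the target cannot represent: none, because the
# assistant vocab is a strict subset of the target vocab.
suppress_input_ids = [i for i in assistant_vocab if i not in target_vocab]
print(len(suppress_input_ids))  # 0 -> map_input_embeddings is never created
```

This is exactly the case where the initialization guard skips creating `map_input_embeddings` while `assistant_prune_lm_head` can still be enabled, producing the AttributeError above.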
 
Linked issue

Fixes #45307

AI disclosure

AI assistance was used in identifying the root cause. The fix was reviewed and validated by the submitter. The one-line change mirrors the existing guard at line 723.

@Regata3010 Regata3010 mentioned this pull request Apr 8, 2026
@zucchini-nlp (Member) left a comment


Thanks! Can you add a comment on why we check both (same tokenizer base vocab but different vocab sizes, etc)

The failing test is just flaky, dw about that

@zucchini-nlp zucchini-nlp enabled auto-merge April 10, 2026 09:03
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp zucchini-nlp added this pull request to the merge queue Apr 10, 2026
Merged via the queue into huggingface:main with commit c43f15c Apr 10, 2026
28 checks passed
@Regata3010 (Contributor, Author)

Thanks for the review. Learned a lot doing this.

sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026
…h cross-vocab models (huggingface#45320)

* Fix AssistantToTargetTranslator crash with cross-vocab models

* Add comment explaining cross-vocab guard in unmap_input_ids

* Add Comment Explaining Cross-Vocab guard in unmap_input_ids

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>