[pt2] Add indices dtype check to embedding meta registration#179754
XAheli wants to merge 1 commit into pytorch:main from
Conversation
🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/179754

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 06aabc7 with merge base 42e4e00 (BROKEN TRUNK: the following job failed but was also present on the merge base). 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "module: correctness (silent)"

@pytorchbot label "module: pt2-dispatcher"

@pytorchbot label "topic: fuzzer"

@pytorchbot label "topic: not user facing"

@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary:

## Context

PyTorch PR pytorch/pytorch#179754 (fixing pytorch/pytorch#178042) added a dtype validation check to the `aten.embedding` meta registration in `torch/_meta_registrations.py`:

```python
torch._check(
    indices.dtype in (torch.long, torch.int32),
    lambda: (
        "Expected tensor for argument #1 'indices' to have one of the following "
        f"scalar types: Long, Int; but got {indices.dtype} instead"
    ),
)
```

This aligns the meta function with the C++ implementation (`checkScalarTypes` in `Embedding.cpp`), which already enforced integer indices. Previously, no meta registration existed for `aten.embedding`, so FakeTensor tracing during `torch.export`/`torch.compile` silently accepted float indices, and AOTAutograd's DCE could remove the dead node before the C++ check ever fired.

## Problem

`test_batched_export_with_backprop` in `test_static_attention.py` creates example token inputs using `torch.zeros()` without specifying a dtype:

```python
# Before (defaults to torch.float32)
torch.zeros(batch_size, input_len)
torch.zeros(1, input_len)
```

During `torch.export.export()`, these float32 tensors flow into `self.tok_embeddings(tokens)` (an `nn.Embedding` layer in `llama_transformer.py`), which dispatches to `aten.embedding`. The new meta function dtype check rejects float32 indices, causing the export to fail.

Note that the actual backprop loop already uses integer indices correctly via `torch.randint(config.vocab_size, (batch_size, input_len))` — only the export-tracing example inputs were wrong.

## Fix

Add explicit `dtype=torch.long` to both `torch.zeros` calls used as token example inputs:

```python
# After
torch.zeros(batch_size, input_len, dtype=torch.long)
torch.zeros(1, input_len, dtype=torch.long)
```

Differential Revision: D101547370
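To make the control flow of the check concrete, here is a dependency-free sketch of the validation pattern described above. `_check` and the dtype strings are stand-ins for `torch._check`, `torch.long`, and `torch.int32`; the real meta registration operates on FakeTensor dtypes, not strings.

```python
# Stand-ins for torch.long / torch.int32 (illustrative only).
ALLOWED_INDEX_DTYPES = ("long", "int32")


def _check(cond, message_fn):
    # Mirrors the shape of torch._check: the message is built lazily,
    # and failure raises during tracing, before any DCE can run.
    if not cond:
        raise RuntimeError(message_fn())


def meta_embedding_indices_check(indices_dtype):
    # Sketch of the dtype guard added to the aten.embedding meta function.
    _check(
        indices_dtype in ALLOWED_INDEX_DTYPES,
        lambda: (
            "Expected tensor for argument #1 'indices' to have one of the "
            f"following scalar types: Long, Int; but got {indices_dtype} instead"
        ),
    )


# Integer index dtypes pass; a float dtype raises at trace time.
meta_embedding_indices_check("long")
try:
    meta_embedding_indices_check("float32")
except RuntimeError as e:
    print("rejected:", e)
```

The lazy `lambda` message matches how `torch._check` avoids paying the string-formatting cost on the success path.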
Fixes #178042
The `aten.embedding` meta function was missing the indices dtype check that exists in C++ (`checkScalarTypes` in `Embedding.cpp`). During compile, FakeTensor tracing passes the invalid op through without error, and then AOTAutograd's DCE removes the dead node — so the C++ check is never reached.

Added `torch._check` for indices dtype in the meta function so the error fires during tracing, before DCE runs.

Test: `test_embedding_float_indices_error` in `test/nn/test_embedding.py` — covers eager, `aot_eager`, `inductor`.

Co-authored-with: Claude
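The user-facing behavior the check aligns with can be seen in eager mode, where the C++ `checkScalarTypes` already rejects float indices. A minimal repro sketch (assumes a local PyTorch install; the embedding sizes and input shapes are illustrative, not taken from the test):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

# Integer indices are accepted.
out = emb(torch.zeros(2, 3, dtype=torch.long))
print(out.shape)  # torch.Size([2, 3, 4])

# Float indices (the torch.zeros default) are rejected by the C++ check
# in eager mode; after this PR the meta registration raises the same
# error during torch.compile / torch.export tracing as well.
try:
    emb(torch.zeros(2, 3))
except RuntimeError as e:
    print("rejected:", e)
```

With the meta-registration check in place, both eager and traced paths fail at the same point instead of the traced path silently dropping the op.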
cc @bdhirsh @penguinwu @bobrenjc93 @aorenste