Fix default attribute values in shape inference#7602
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #7602 +/- ##
=======================================
Coverage 55.26% 55.27%
=======================================
Files 515 515
Lines 32420 32425 +5
Branches 2898 2898
=======================================
+ Hits 17918 17923 +5
Misses 13713 13713
Partials 789 789 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Thanks for fixing the bug! Could you signoff the commits? Thanks: https://github.com/onnx/onnx/pull/7602/checks?check_run_id=61303166766
|
Is there a way a test can be created in https://github.com/onnx/onnx/blob/main/onnx/test/cpp/shape_inference_test.cc or https://github.com/onnx/onnx/blob/main/onnx/test/shape_inference_test.py ? |
Signed-off-by: Matteo Salvarezza <matteos@wolfram.com>
Signed-off-by: Matteo Salvarezza <matteos@wolfram.com>
1bfe456 to
bea63a8
Compare
|
Added signoff and test. I had to add |
Signed-off-by: Matteo Salvarezza <matteos@wolfram.com>
Replaced base64 model loading with text format parsing for clarity and maintainability.
There was a problem hiding this comment.
Pull request overview
Fixes ONNX shape-inference handling of attribute defaults when an attribute is present in the protobuf but its scalar value field is unset (e.g., axis declared but i not set), aligning behavior with protobuf scalar defaults rather than schema defaults.
Changes:
- Update
getAttributehelpers to distinguish “attribute absent” vs “attribute present but scalar unset”, using protobuf defaults in the latter case. - Add a regression test covering the
Flatten(axis)scenario from issue #7573 using textproto parsing.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
onnx/defs/shape_inference.h |
Adjusts attribute lookup defaults to treat present-but-unset scalar fields as protobuf defaults. |
onnx/test/shape_inference_test.py |
Adds a regression test reproducing the protobuf-default vs schema-default mismatch. |
| inline int64_t getAttribute(const DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) { | ||
| const auto* attr_proto = ctx.getAttribute(attributeName); | ||
| if ((nullptr != attr_proto) && attr_proto->has_i()) | ||
| return attr_proto->i(); | ||
| return defaultValue; | ||
| else if (nullptr != attr_proto) | ||
| return 0; // protobuf default for integers | ||
| else | ||
| return defaultValue; |
There was a problem hiding this comment.
Same issue as the InferenceContext overload: returning 0 whenever attr_proto exists but has_i() is false ignores the attribute's declared type (or missing type), and can mask malformed attributes. It would be safer to only use the protobuf scalar default when attr_proto->type() is INT, and otherwise fall back to defaultValue or report an inference/type error.
| getAttribute(const InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) { | ||
| const auto* attr_proto = ctx.getAttribute(attributeName); | ||
| if ((nullptr != attr_proto) && attr_proto->has_s()) | ||
| return attr_proto->s(); | ||
| return defaultValue; | ||
| else if (nullptr != attr_proto) | ||
| return ""; // protobuf default for strings | ||
| else | ||
| return defaultValue; |
There was a problem hiding this comment.
Returning an empty string whenever the attribute exists but has_s() is false has the same caveat as the integer overload: it applies even if the attribute's type is unset/UNDEFINED or not STRING, which can silently override the schema default. Consider checking attr_proto->has_type() && attr_proto->type() == AttributeProto_AttributeType_STRING before using the protobuf default (otherwise keep defaultValue or fail).
This is an attempt to fix #7573
When using the default value for an attribute, there are two separate scenarios:
The previous logic was using the schema default in both cases, but in the second case one should use the protobuf default for the relevant data type instead.
This change fixes the example reported in the issue.