Fix Shape→Gather→TopK regression: preserve rank-1 single-element index output in data propagation#29084
Fix Shape→Gather→TopK regression: preserve rank-1 single-element index output in data propagation#29084titaiwangms wants to merge 3 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
This PR fixes a shape-inference data-propagation regression where Shape → Gather([-1]) could incorrectly drop a rank-1 single-element output to a scalar, causing valid models that feed TopK’s K input to fail load-time shape inference. The change updates Gather’s custom propagation to route based on index rank, makes scalar-only elementwise propagation (Add/Sub/Mul/Div) tolerant to rank-1 [1] single-element values, and adds targeted regression tests + shared helpers.
Changes:
- Teach
Graph::SaveShapeValuesFromDataPropagation’s initializer-reader to also report initializer rank (num dims), enabling rank-based routing in custom propagation. - Preserve rank for single-element propagated values via a new shared helper and update Gather/Add/Sub/Mul/Div propagation accordingly.
- Add shape-inference regression tests and new testdata model generators for the affected patterns.
Reviewed changes
Copilot reviewed 23 out of 26 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/core/graph/graph.cc | Extend initializer-value reader to also report initializer rank (num dims). |
| onnxruntime/core/graph/data_propagation/custom_data_propagation.h | Introduce GetInitializedInputValuesFunc typedef carrying values + rank. |
| onnxruntime/core/graph/data_propagation/custom_data_propagation.cc | Thread the new initializer-reader signature through custom propagation factory. |
| onnxruntime/core/graph/data_propagation/data_propagation_value_utils.h | New shared helpers for reading/writing single-element propagated values while preserving rank. |
| onnxruntime/core/graph/data_propagation/gather_op_data_propagation.{h,cc} | Fix Gather propagation to preserve rank-1 [1] vs scalar based on indices rank. |
| onnxruntime/core/graph/data_propagation/add_op_data_propagation.{h,cc} | Allow propagation through Add when operands are scalar or rank-1 [1]. |
| onnxruntime/core/graph/data_propagation/sub_op_data_propagation.{h,cc} | Allow propagation through Sub when operands are scalar or rank-1 [1]. |
| onnxruntime/core/graph/data_propagation/mul_op_data_propagation.{h,cc} | Allow propagation through Mul when operands are scalar or rank-1 [1]. |
| onnxruntime/core/graph/data_propagation/div_op_data_propagation.{h,cc} | Allow propagation through Div when operands are scalar or rank-1 [1] and guard div-by-zero. |
| onnxruntime/core/graph/data_propagation/squeeze_op_data_propagation.{h,cc} | Update signature usage for initializer-reader (axes). |
| onnxruntime/core/graph/data_propagation/unsqueeze_op_data_propagation.{h,cc} | Update signature usage for initializer-reader (axes). |
| onnxruntime/core/graph/data_propagation/size_op_data_propagation.h | Update signature type in ctor to match base class. |
| onnxruntime/test/framework/shape_inference_test.cc | Add regression + guard tests covering Gather→TopK, Gather→Mul→TopK, and multi-element non-collapse. |
| onnxruntime/test/testdata/test_shape_data_propagation_gather_topk.py | Generator for the core Gather→TopK regression model. |
| onnxruntime/test/testdata/test_shape_data_propagation_gather_mul_topk.py | Generator for the Gather→Mul→TopK chain regression model. |
| onnxruntime/test/testdata/test_shape_data_propagation_shape_mul_constantofshape.py | Generator for multi-element guard model (no scalar collapse). |
Review summary — multi-model review team (readability · code · critical · deep · integration)Solid, well-contained fix. The shared Findings below, deduplicated and prioritized. None are hard blockers, but two Major1. Rank ≥ 2 single-element index is collapsed to rank-1 ( 2. Cross-module ripple into Unsqueeze/Squeeze data propagation ( Minor3. Signed-integer overflow (UB) in Add/Sub/Mul/Div. 4. 5. Reader/writer channel precedence. 6. Doc/behavior drift on 7. Test coverage gap. End-to-end rank-preservation is exercised only via the 8. Readability nits
PraiseSourcing the index rank from the same Reviewed by a 5-model team (Claude Sonnet · GPT-5.3-Codex · GPT-5.5 · Claude Opus · Gemini 3.1 Pro). Findings are advisory. |
1f862c2 to
14865d9
Compare
|
Thanks for the thorough review! Pushed an update (now at 14865d9) addressing the feedback. Summary of changes:
Full Intentionally deferred (happy to fold either in if you'd prefer):
Unrelated CI note: the |
14865d9 to
90ff823
Compare
|
Pushed a small follow-up (now at Root cause: the failure was an AddressSanitizer allocator footprint effect, not a build break or a test assertion failure — the full ASan test binary exhausted ASan's 8 GB Change in this commit: trims the four data-propagation opt-level test loops from five levels back to No production code changed from the previous revision. Full |
| effective_num_dims = (indices_shape != nullptr) ? indices_shape->dim_size() : -1; | ||
| } | ||
|
|
||
| switch (ClassifySingleValueRank(effective_num_dims)) { |
There was a problem hiding this comment.
nit: This enum and helper are only consumed by the switch in gather_op_data_propagation.cc, where it just maps num_dims to a boolean is_rank1 with a decline for ≥ 2. This can be replaced with a simple if/else inline in the Gather propagation:
if (effective_num_dims == 0) {
SetSinglePropagatedShapeValue(output_def_, dim.dim_value(), /*is_rank1=*/false);
} else if (effective_num_dims == 1) {
SetSinglePropagatedShapeValue(output_def_, dim.dim_value(), /*is_rank1=*/true);
}
// else: rank >= 2 or unknown — declineIt makes code much simpler.
There was a problem hiding this comment.
Done — thanks. The SingleValueRank enum and its ClassifySingleValueRank helper are removed; Gather now routes inline exactly as suggested in gather_op_data_propagation.cc (head 81c60a2): effective_num_dims == 0 sets a scalar, == 1 sets a rank-1 [1] value, and rank >= 2 or unknown is left unset to fall back to ONNX data propagation. effective_num_dims is the index rank, sourced from the same constant initializer the index value came from. Agreed it reads much more simply this way.
| // predicate ShouldDeclineUnsqueezeSingleValue (see ShouldDeclineUnsqueezeSingleValueTest); | ||
| // full-suite-green is only the no-regression backstop. Multi-element shape vectors are | ||
| // unaffected. | ||
| if (ShouldDeclineUnsqueezeSingleValue(tensor_shape_proto.dim_size())) { |
There was a problem hiding this comment.
nit: this is Unsqueeze-specific logic ("should Unsqueeze decline propagating this value?") and is only called from unsqueeze_op_data_propagation.cc. It belongs alongside the Unsqueeze implementation, not in the shared utils.
There was a problem hiding this comment.
Done — thanks. ShouldDeclineUnsqueezeSingleValue is removed from the shared data_propagation_value_utils.h; the Unsqueeze-specific decline now lives inline next to the Unsqueeze implementation in unsqueeze_op_data_propagation.cc (head 81c60a2), as a direct if (tensor_shape_proto.dim_size() == 1) return Status::OK(); guard with a comment explaining that any Unsqueeze would lift a rank-1 single-element value to rank >= 2. Agreed it belongs alongside the Unsqueeze op rather than in shared utils.
GatherOpDataPropagation::infer() guarded on indices.size() == 1 (element
count) and called SetInferredShapeScalarValue() unconditionally. That guard
is true for BOTH a 0-D scalar index and a 1-D single-element index, so the
1-D case had its rank dropped: Graph::getInputData() then emitted a 0-D
(dimensionless) TensorProto for the propagated value.
For the common Shape -> Gather([-1]) -> TopK exporter pattern this produced a
0-D K initializer, which ONNX TopK shape inference correctly rejects ("K
input must be a one-dimensional tensor of size 1.") at Graph::Resolve time.
The model is spec-valid (a 1-D Gather index yields a rank-1 Gather output),
so this was an ORT rank-preservation bug. It reproduces even at
ORT_DISABLE_ALL, where constant folding never runs, confirming the cause is
shape-inference data propagation rather than constant folding.
Changes:
- Gather: route by the index rank instead of element count. A 0-D scalar
index stores a scalar value; a 1-D single-element index stores a rank-1
value, so getInputData() emits a TensorProto with dims=[1] and downstream
TopK sees a valid 1-D size-1 K; a rank >= 2 index (or an index whose rank is
unknown) declines and falls back to ONNX data propagation, because the
single-value channel cannot faithfully represent a rank >= 2 Gather output.
The rank routing is inlined next to the Gather logic (a small if/else on the
index rank). The index rank is sourced from the same constant initializer the
index value comes from (via get_initialized_input_values now reporting the
initializer rank), instead of a second, independently resolved NodeArg shape
-- removing a potential source-of-truth drift.
- Elementwise consumers (Add/Sub/Mul/Div): previously scalar-only, they would
silently stop propagating once an operand became a rank-1 value (e.g. a
Shape -> Gather(1-D idx) -> Mul -> TopK chain), since the custom propagation
result replaces ONNX's rank-correct fallback. They now accept a single
element carried as either a rank-0 scalar or a rank-1 [1] value and keep the
output rank consistent with ONNX broadcasting (rank-1 if any operand is
rank-1, else scalar), so such chains keep propagating end to end. Div also
guards against division by zero.
- Unsqueeze: decline rather than propagate a dubious value when unsqueezing a
single-element (scalar-like, rank-1 [1]) value, whose result is a rank >= 2
tensor the values channel cannot faithfully represent (it would otherwise
fabricate a misleading [1, value]). The decline is a single inlined check on
the value's element count, next to the Unsqueeze logic. Multi-element shape
vectors are unaffected. Squeeze is left unchanged -- it already converts a
rank-1 [1] value to the correct scalar.
- Add shared helpers (data_propagation_value_utils.h) for reading and writing a
single-element shape value while preserving its rank, used by Gather and the
elementwise ops so producers and consumers cannot disagree on rank. The
reader declines a rank-1 multi-element value (it must never collapse to
element[0]), so a multi-element value cannot be mistaken for a single one.
The setter is correct-by-construction: it populates exactly one channel and
clears the other, so the scalar-first reader and the values-first
getInputData() can never disagree on rank even if the output carried a stale
value from a prior pass. The Gather and Unsqueeze rank decisions are kept
inline next to each op rather than shared, since the two ops route on
different quantities (index rank vs. value element count).
Tests: add ShapeInferenceV2Test.GatherToTopKRankPreservationTest and
GatherMulToTopKRankPreservationTest (with fixtures and generators) that load
the model at the disabled, basic, and all-optimization levels and assert the
rank-1 K is preserved and propagated through the chain, plus
GatherSqueezeRangeRankPreservationTest, an
observable end-to-end lock that a Shape -> Gather([-1]) -> Squeeze -> Range
chain resolves Range's length to the concrete propagated K (locking Squeeze's
correct-scalar behavior through a real downstream consumer). Add unit tests
pinning the pure read/write helpers directly: SinglePropagatedShapeValueGuardTest
(reader behavior on each channel) and
SetSinglePropagatedShapeValueKeepsSingleChannelTest (the setter clears the
opposite channel). Cover the two decline branches with real end-to-end tests
(rather than pure-predicate unit tests) that exercise them through a
rank-lowering Squeeze: GatherRank2IndexDeclineTest (a rank-2 Gather index must
decline -- relaxing it concretizes a Range length that must stay symbolic) and
UnsqueezeSingleValueDeclineTest (unsqueezing a single-element value must decline
-- relaxing it fabricates a non-scalar that makes the model fail to load). Both
run at ORT_DISABLE_ALL so constant folding does not mask the data-propagation
decline. Add ShapeMulMultiElementNoScalarCollapseTest, an end-to-end
check that a Shape -> Mul -> ConstantOfShape multi-element chain still resolves
to its full rank-2 shape. The data-propagation regression tests run at the
disabled, basic, and all-optimization levels (data propagation executes in the
pre-optimization Graph::Resolve pass, so it is independent of the
graph-optimization level). Existing scalar-index data-prop fixtures continue to
exercise the scalar path unchanged.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwangms <titaiwang@microsoft.com>
90ff823 to
839dee9
Compare
Replace the full Ort::Session in GatherRank2IndexDeclineTest and UnsqueezeSingleValueDeclineTest with onnxruntime::Model::Load (which runs Graph::Resolve) plus direct output NodeArg shape inspection. Data propagation runs unconditionally inside Graph::InferAndVerifyTypeMatch during Resolve, and no session-level optimizer/constant-folding runs, so Resolve alone exercises the decline branch in isolation -- with none of the session arena / execution-provider / kernel-registry allocations. This trims each decline test's single-process AddressSanitizer footprint (#29139) while preserving end-to-end, orthogonal discrimination of the rank-preservation fix (#29072): relaxing the Gather decline concretizes the Range length to 2000; relaxing the Unsqueeze decline makes Range's input non-scalar so Resolve returns a non-OK status and the load fails. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: titaiwangms <titaiwang@microsoft.com>
The two decline tests dereference output_shape->dim(0) immediately after checking its dim_size, so the rank guard must halt on failure: convert EXPECT_EQ to ASSERT_EQ at both sites (GatherRank2IndexDeclineTest and UnsqueezeSingleValueDeclineTest). A theoretical rank-0 regression would otherwise out-of-bounds access the protobuf repeated field instead of failing cleanly; this also matches the surrounding ASSERT-guarded preconditions. Test-only, no assertion-semantics change. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: titaiwangms <titaiwang@microsoft.com>
Summary
A spec-valid
Shape → Gather(1-D index [-1]) → TopKmodel fails to load since ORT 1.25.0 with:The model is valid: a rank-1 (single-element) Gather index produces a rank-1 Gather output, so the value feeding TopK's
Kinput is a 1-D size-1 tensor — exactly what TopK requires. The failure was an ORT rank-preservation bug in shape-inference data propagation, not a problem with the model.Root cause.
GatherOpDataPropagation::infer()routed by element count rather than index rank: it guarded onindices.size() == 1, which is true for both a 0-D scalar index and a 1-D single-element index, and then unconditionally calledSetInferredShapeScalarValue(). That dropped the rank of the spec-valid 1-D size-1 case, soGraph::getInputData()emitted a 0-D (dimensionless) propagated value. ONNX TopK shape inference then correctly rejected the 0-DK. This path was introduced by #26269 (partial data propagation to enhance shape inference).This reproduces even at
GraphOptimizationLevel.ORT_DISABLE_ALL, where constant folding never runs — confirming the cause is data propagation in shape inference, not constant folding (#26345 was an earlier mis-attribution; see the corrected analysis).Fixes the regression reported in #29072. Corrected root-cause analysis: #29072 (comment)
The fix
getInputData()emits aTensorProtowithdims=[1]and downstream TopK sees a valid 1-D size-1K. The index rank is taken from the same constant initializer the index value comes from (viaget_initialized_input_valuesnow reporting the initializer rank), rather than a second, independently-resolvedNodeArgshape — removing a potential source-of-truth drift (EDGE Remove vsts test runner in cmake file #2).Shape → Gather(1-D idx) → Mul → TopKchain), because the custom-propagation result replaces ONNX's rank-correct fallback. They now accept a single element carried as either a rank-0 scalar or a rank-1[1]value and keep the output rank consistent with ONNX broadcasting (rank-1 if any operand is rank-1, else scalar), so such chains keep propagating end-to-end. Div additionally guards against division by zero.data_propagation_value_utils.h). Centralizes reading/writing a single-element shape value while preserving its rank, used by both the Gather producer and the elementwise consumers so they cannot disagree on rank. The reader declines a rank-1 multi-element value (it must never collapse toelement[0]), so a multi-element value can never be mistaken for a single one.Testing
Five
ShapeInferenceV2Testcases (with fixtures + generators), all loading the model at every optimization level (includingORT_DISABLE_ALL):GatherToTopKRankPreservationTest— the coreShape → Gather([-1]) → TopKregression; asserts the rank-1Kis preserved.GatherMulToTopKRankPreservationTest— the… → Gather(1-D idx) → Mul → TopKchain; asserts propagation survives the elementwise op.SinglePropagatedShapeValueGuardTest— a direct unit test pinning the shared reader's behavior on each channel (scalar, rank-1 single-element, rank-1 multi-element, symbolic, empty). Mutation-proven: relaxing thedim_size()==1guard makes this test fail, restoring it makes it pass — so the guard the whole fix hinges on is test-locked.ShapeMulMultiElementNoScalarCollapseTest— end-to-end check that a multi-elementShape → Mul → ConstantOfShapechain still resolves to its full rank-2 shape (no bogus scalar collapse).PartialDataPropagationTest— pre-existing scalar-index coverage, unchanged.Full
onnxruntime_test_allsuite passes (0 failures) on top of the currentmain(opset-27 / ONNX 1.22.0 integration). The constant-folding memory path (#26345) is untouched — the diff is confined todata_propagation/, a smallgraph.ccchange, and tests.Follow-ups (intentionally out of scope for this PR)
[1,1]) to decline rather than route as rank-1 — needs its own discriminating unit test; pathological/non-exporter, worst case is degraded inference rather than a crash.DCO
Commit is DCO signed-off.