Skip to content

[Torch] Add aten._scaled_mm_v2 op support and FX import plumbing#4587

Open
manaalmj wants to merge 2 commits into
llvm:mainfrom
manaalmj:mxfp4-scaled-mm-v2-to-tosa
Open

[Torch] Add aten._scaled_mm_v2 op support and FX import plumbing#4587
manaalmj wants to merge 2 commits into
llvm:mainfrom
manaalmj:mxfp4-scaled-mm-v2-to-tosa

Conversation

@manaalmj

Copy link
Copy Markdown
Collaborator

Adds Torch dialect support and FX importer plumbing for MXFP4 aten._scaled_mm_v2, plus TorchToTosa lowering to tosa.matmul_t_block_scaled. The lowering keeps packed FP4 data as ui8 in Torch IR and exposes logical f4E2M1FN tensors at the TOSA boundary.

@manaalmj manaalmj force-pushed the mxfp4-scaled-mm-v2-to-tosa branch 4 times, most recently from 508f976 to 6abb905 Compare June 1, 2026 22:18
@manaalmj manaalmj marked this pull request as draft June 2, 2026 11:36
@manaalmj manaalmj force-pushed the mxfp4-scaled-mm-v2-to-tosa branch 6 times, most recently from 8583d0f to 5bf477a Compare June 3, 2026 10:49
@manaalmj manaalmj marked this pull request as ready for review June 3, 2026 11:12
@manaalmj manaalmj marked this pull request as draft June 3, 2026 13:28
@manaalmj manaalmj force-pushed the mxfp4-scaled-mm-v2-to-tosa branch from 5bf477a to 415b744 Compare June 4, 2026 21:47
@manaalmj manaalmj changed the title [Torch][TorchToTosa] Add MXFP4 scaled_mm_v2 support [Torch] Add MXFP4 scaled_mm_v2 support Jun 4, 2026
@manaalmj manaalmj changed the title [Torch] Add MXFP4 scaled_mm_v2 support [Torch] Add aten._scaled_mm_v2 op support and FX import plumbing Jun 4, 2026
Adds Torch dialect support and FX importer plumbing for aten._scaled_mm_v2

Change-Id: Ifb21243209d36ccb555302b2e9a33c78c1cfbc65
@manaalmj manaalmj force-pushed the mxfp4-scaled-mm-v2-to-tosa branch from 415b744 to 16cdbb9 Compare June 5, 2026 07:40
@manaalmj manaalmj marked this pull request as ready for review June 5, 2026 08:02

@Lallapallooza Lallapallooza left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the patch, few questions.

Comment thread lib/Dialect/Torch/IR/TorchOps.cpp Outdated
Comment thread lib/Dialect/Torch/IR/TorchOps.cpp Outdated
Comment thread lib/Dialect/Torch/IR/TorchOps.cpp Outdated
Comment thread lib/Dialect/Torch/IR/TorchOps.cpp
Comment thread lib/Dialect/Torch/IR/TorchOps.cpp Outdated
getSingleTensorTypeFromList(getScaleA());
FailureOr<BaseTensorType> scaleBType =
getSingleTensorTypeFromList(getScaleB());
if (failed(scaleAType) || failed(scaleBType))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scale_a/scale_b are Tensor[], but this verifier only validates the case where both operands are literal one-tensor prim.ListConstructs. A literal empty list or literal two-element list makes getSingleTensorTypeFromList fail, and line 6617 returns success(), so the scale dtype/shape checks below are skipped instead of rejecting a statically known unsupported list shape. Non-literal lists can remain unverifiable, but known literal lengths should not bypass the verifier.

Can we reject unsupported literal list lengths, or add explicit handling for PyTorch's two-level NV recipe ([BlockWise1x16, TensorWise]), which checks scale_[ab][1] as a scalar f32 tensor?

https://github.com/pytorch/pytorch/blob/449aa5b695056c4c14c3134909de5ad1a3078cc8/torch/_meta_registrations.py#L7177-L7185

https://github.com/pytorch/pytorch/blob/449aa5b695056c4c14c3134909de5ad1a3078cc8/torch/_meta_registrations.py#L7371-L7386

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Non-literal lists remain unverifiable, but statically known empty/mismatched/two-level literal lists no longer bypass verification.

Comment thread test/python/fx_importer/scaled_mm_test.py Outdated
" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._scaled_mm_v2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<list<int>>, %arg6: !torch.list<int>, %arg7: !torch.list<int>, %arg8: !torch.optional<list<int>>, %arg9: !torch.optional<int>, %arg10: !torch.list<int>, %arg11: !torch.bool) -> !torch.list<int> {\n"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I have an issue locally, check_generated_sources.sh reports stale, maybe it just my local env, but please double check and rebase on latest main and use project pinned requirements

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Rebased on latest upstream main, used the project pinned requirements, and reran check_generated_sources.sh. It passes locally with no stale generated sources.

Change-Id: I8a1cad3cf6ebb60bab655a7df7715a117e7c8972
@manaalmj manaalmj requested review from Lallapallooza and sahas3 June 10, 2026 09:58
@manaalmj

Copy link
Copy Markdown
Collaborator Author

@sahas3 Can you please review as well?

int64_t mat2K = mat2Shape[0];
int64_t n = mat2Shape[1];

bool selfIsFp4 =

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The K adjustment doubles the contracting dimension independently for whichever operand is FP4, but PyTorch v2 meta only applies the packed-K multiplier when both operands are FP4.

That makes same-raw-K mixed FP4/FP8 inputs accepted by PyTorch meta, for example self=[128,64] FP4 with mat2=[64,128] FP8, fail this verifier as 128 != 64. Can we make the verifier and abstract shape function use the same mixed-dtype rule?

https://github.com/pytorch/pytorch/blob/449aa5b695056c4c14c3134909de5ad1a3078cc8/torch/_meta_registrations.py#L7089-L7092

return values;
}

static constexpr int64_t kScaledMmV2TensorWise = 0;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return emitOpError("expected mat2 to have an FP8 or FP4 dtype, but got ")
<< mat2Type.getDtype();

if (!selfType.hasSizes() || !mat2Type.hasSizes())

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bailout is before the verifier looks at statically constructed scale_*, recipe_*, and swizzle_* lists. With unranked self/mat2, invalid metadata such as empty lists, mismatched scale_a/recipe_a lengths, or unsupported recipe values verifies successfully even though the same literals are rejected once the tensors are ranked. There is a similar later bailout at areAllSizesKnown() that still allows ranked-dynamic cases to skip scale checks that do not need M/N/K, such as tensorwise scale dtype and scalar-ness. Can we split the verifier so list/recipe/swizzle validation and dtype/numel checks run whenever those operands are statically visible, and only defer checks that actually need concrete matrix dimensions?

"rank 2, but got ranks ")
<< scaleAShape.size() << " and " << scaleBShape.size();

int64_t kBlocks128 = logicalK / 128;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PyTorch's v2 meta computes L4 with round_up(K / 128, 4). Can we align this with the Python meta contract?

https://github.com/pytorch/pytorch/blob/449aa5b695056c4c14c3134909de5ad1a3078cc8/torch/_meta_registrations.py#L7264-L7325

# CHECK: %[[FALSE:.*]] = torch.constant.bool false
# CHECK: %[[MM:.*]] = torch.aten._scaled_mm_v2 %arg0, %arg1, %[[SCALE_A]], %[[RECIPE_A]], %[[SWIZZLE_A]], %[[SCALE_B]], %[[RECIPE_B]], %[[SWIZZLE_B]], %[[NONE]], %[[OUT_DTYPE]], %[[CONTRACTION]], %[[FALSE]]
# CHECK: return %[[MM]]
def test_import_scaled_mm_v2_block_scaled_fp4():

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add FX importer coverage for the missing v2 paths: RowWise, one f32 blockwise pair, NV recipe 2, and one non-default call with bias/use_fast_accum/contraction_dim?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants