[Torch] Add aten._scaled_mm_v2 op support and FX import plumbing #4587
manaalmj wants to merge 2 commits into
Adds Torch dialect support and FX importer plumbing for aten._scaled_mm_v2
Change-Id: Ifb21243209d36ccb555302b2e9a33c78c1cfbc65
Lallapallooza
left a comment
Thanks for the patch; a few questions.
    getSingleTensorTypeFromList(getScaleA());
    FailureOr<BaseTensorType> scaleBType =
    getSingleTensorTypeFromList(getScaleB());
    if (failed(scaleAType) || failed(scaleBType))
scale_a/scale_b are Tensor[], but this verifier only validates the case where both operands are literal one-tensor prim.ListConstructs. A literal empty list or literal two-element list makes getSingleTensorTypeFromList fail, and line 6617 returns success(), so the scale dtype/shape checks below are skipped instead of rejecting a statically known unsupported list shape. Non-literal lists can remain unverifiable, but known literal lengths should not bypass the verifier.
Can we reject unsupported literal list lengths, or add explicit handling for PyTorch's two-level NV recipe ([BlockWise1x16, TensorWise]), which checks scale_[ab][1] as a scalar f32 tensor?
Done. Non-literal lists remain unverifiable, but statically known empty/mismatched/two-level literal lists no longer bypass verification.
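The three outcomes discussed in this thread (accept, reject, leave unverified) can be sketched as a small model. This is only an illustration of the rule, not the verifier's C++ code: `Verdict`, `check_scale_list`, and `supported_lengths` are hypothetical names, and passing `None` stands in for a list that is not built by a statically visible list constructor.

```python
from enum import Enum
from typing import Optional, Sequence


class Verdict(Enum):
    OK = "ok"
    REJECT = "reject"
    UNVERIFIABLE = "unverifiable"


def check_scale_list(
    literal_elems: Optional[Sequence[object]],
    supported_lengths: Sequence[int] = (1,),
) -> Verdict:
    """Classify a scale list by what is statically known about it.

    Assumption: `literal_elems` is None when the list is non-literal, and
    otherwise holds the elements of the literal list.
    """
    if literal_elems is None:
        # Nothing is known at verification time, so do not reject.
        return Verdict.UNVERIFIABLE
    if len(literal_elems) not in supported_lengths:
        # The length is statically known and unsupported: reject it
        # instead of silently skipping the later scale checks.
        return Verdict.REJECT
    return Verdict.OK
```

Supporting the two-level form would then amount to passing `supported_lengths=(1, 2)` and adding the checks for the second element.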
    " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
    " return %0 : !torch.list<int>\n"
    " }\n"
    " func.func @\"__torch_mlir_shape_fn.aten._scaled_mm_v2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<list<int>>, %arg6: !torch.list<int>, %arg7: !torch.list<int>, %arg8: !torch.optional<list<int>>, %arg9: !torch.optional<int>, %arg10: !torch.list<int>, %arg11: !torch.bool) -> !torch.list<int> {\n"
nit: check_generated_sources.sh reports stale generated sources for me locally. It may just be my local environment, but please double-check, rebase on the latest main, and use the project's pinned requirements.
Done. Rebased on latest upstream main, used the project pinned requirements, and reran check_generated_sources.sh. It passes locally with no stale generated sources.
Change-Id: I8a1cad3cf6ebb60bab655a7df7715a117e7c8972
@sahas3 Can you please review as well?
    int64_t mat2K = mat2Shape[0];
    int64_t n = mat2Shape[1];

    bool selfIsFp4 =
The K adjustment doubles the contracting dimension independently for whichever operand is FP4, but PyTorch v2 meta only applies the packed-K multiplier when both operands are FP4.
As a result, mixed FP4/FP8 inputs with the same raw K, which PyTorch meta accepts, fail this verifier. For example, self=[128,64] FP4 with mat2=[64,128] FP8 is rejected because 128 != 64. Can we make the verifier and abstract shape function use the same mixed-dtype rule?
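The difference between the two rules can be sketched as follows. This models the comment's description only and is not PyTorch's or torch-mlir's code; `logical_k`, `k_dims_match`, and `k_dims_match_per_operand` are hypothetical names, and the factor of 2 is the doubling described above.

```python
from typing import Sequence


def logical_k(raw_k: int, self_is_fp4: bool, mat2_is_fp4: bool) -> int:
    """Contracting dimension after the packed-FP4 adjustment.

    Assumption (from the review comment): the multiplier applies only
    when BOTH operands are FP4; mixed FP4/FP8 inputs keep their raw K.
    """
    multiplier = 2 if (self_is_fp4 and mat2_is_fp4) else 1
    return raw_k * multiplier


def k_dims_match(
    self_shape: Sequence[int],
    mat2_shape: Sequence[int],
    self_is_fp4: bool,
    mat2_is_fp4: bool,
) -> bool:
    """Compare self[M, K] with mat2[K, N] using one shared multiplier."""
    self_k = logical_k(self_shape[1], self_is_fp4, mat2_is_fp4)
    mat2_k = logical_k(mat2_shape[0], self_is_fp4, mat2_is_fp4)
    return self_k == mat2_k


def k_dims_match_per_operand(
    self_shape: Sequence[int],
    mat2_shape: Sequence[int],
    self_is_fp4: bool,
    mat2_is_fp4: bool,
) -> bool:
    """The behaviour being questioned: each operand is doubled on its own."""
    self_k = self_shape[1] * (2 if self_is_fp4 else 1)
    mat2_k = mat2_shape[0] * (2 if mat2_is_fp4 else 1)
    return self_k == mat2_k
```

With self=[128,64] FP4 and mat2=[64,128] FP8, the shared-multiplier rule compares 64 with 64, while the per-operand rule compares 128 with 64.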
    return values;
    }

    static constexpr int64_t kScaledMmV2TensorWise = 0;
The verifier omits ScalingType.BlockWise1x16 (value 2), but PyTorch exposes that enum and accepts both the single-level NV recipe [BlockWise1x16] and the two-level [BlockWise1x16, TensorWise] form for _scaled_mm_v2.
https://github.com/pytorch/pytorch/blob/449aa5b695056c4c14c3134909de5ad1a3078cc8/aten/src/ATen/BlasBackend.h#L34-L43
https://github.com/pytorch/pytorch/blob/449aa5b695056c4c14c3134909de5ad1a3078cc8/torch/_meta_registrations.py#L7167-L7185
https://github.com/pytorch/pytorch/blob/449aa5b695056c4c14c3134909de5ad1a3078cc8/torch/_meta_registrations.py#L7354-L7391
    return emitOpError("expected mat2 to have an FP8 or FP4 dtype, but got ")
    << mat2Type.getDtype();

    if (!selfType.hasSizes() || !mat2Type.hasSizes())
This bailout is before the verifier looks at statically constructed scale_*, recipe_*, and swizzle_* lists. With unranked self/mat2, invalid metadata such as empty lists, mismatched scale_a/recipe_a lengths, or unsupported recipe values verifies successfully even though the same literals are rejected once the tensors are ranked. There is a similar later bailout at areAllSizesKnown() that still allows ranked-dynamic cases to skip scale checks that do not need M/N/K, such as tensorwise scale dtype and scalar-ness. Can we split the verifier so list/recipe/swizzle validation and dtype/numel checks run whenever those operands are statically visible, and only defer checks that actually need concrete matrix dimensions?
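One way to picture the requested split is a two-stage check in which only the second stage depends on concrete sizes. The sketch below is a simplified model, not the actual verifier: `verify_metadata_first`, its parameters, and the error strings are hypothetical, `None` stands for anything not statically visible, and the default recipe set {0, 2} uses only the TensorWise and BlockWise1x16 values mentioned in this thread.

```python
from typing import Callable, List, Optional, Sequence


def verify_metadata_first(
    recipe: Optional[Sequence[int]],
    scale_count: Optional[int],
    sizes: Optional[Sequence[Optional[int]]],
    dim_checks: Callable[[Sequence[int]], List[str]],
    supported_recipes: frozenset = frozenset({0, 2}),
) -> List[str]:
    """Return a list of error messages; an empty list means 'verified'."""
    errors: List[str] = []

    # Stage 1: list/recipe checks that need no matrix dimensions.
    # These run even when the tensors are unranked or dynamic.
    if recipe is not None:
        if len(recipe) == 0:
            errors.append("recipe list must not be empty")
        if scale_count is not None and scale_count != len(recipe):
            errors.append("scale and recipe list lengths differ")
        for value in recipe:
            if value not in supported_recipes:
                errors.append(f"unsupported recipe value {value}")

    # Stage 2: defer only the checks that actually need concrete M/N/K.
    if sizes is None or any(dim is None for dim in sizes):
        return errors
    return errors + dim_checks([dim for dim in sizes if dim is not None])
```

In this arrangement an empty literal recipe list is reported even when `sizes` is `None`, which is the unranked case described above.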
    "rank 2, but got ranks ")
    << scaleAShape.size() << " and " << scaleBShape.size();

    int64_t kBlocks128 = logicalK / 128;
PyTorch's v2 meta computes L4 with round_up(K / 128, 4). Can we align this with the Python meta contract?
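For reference, the padding arithmetic can be written out as below. This is a sketch under assumptions: `ceil_div`, `round_up`, and `padded_k_blocks` are hypothetical helper names, and reading `K / 128` as a ceiling division is an assumption that the comment does not state.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer division rounding up (for positive b)."""
    return -(-a // b)


def round_up(x: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= x."""
    return ceil_div(x, multiple) * multiple


def padded_k_blocks(k: int) -> int:
    """Number of 128-wide K blocks, padded to a multiple of 4.

    Assumption: K / 128 is treated as a ceiling division here.
    """
    return round_up(ceil_div(k, 128), 4)
```

For K = 256, plain `K / 128` gives 2 blocks while the padded value is 4; the two agree only when the block count is already a multiple of 4, as for K = 512.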
    # CHECK: %[[FALSE:.*]] = torch.constant.bool false
    # CHECK: %[[MM:.*]] = torch.aten._scaled_mm_v2 %arg0, %arg1, %[[SCALE_A]], %[[RECIPE_A]], %[[SWIZZLE_A]], %[[SCALE_B]], %[[RECIPE_B]], %[[SWIZZLE_B]], %[[NONE]], %[[OUT_DTYPE]], %[[CONTRACTION]], %[[FALSE]]
    # CHECK: return %[[MM]]
    def test_import_scaled_mm_v2_block_scaled_fp4():
Can we add FX importer coverage for the missing v2 paths: RowWise, one f32 blockwise pair, NV recipe 2, and one non-default call with bias/use_fast_accum/contraction_dim?
Adds Torch dialect support and FX importer plumbing for MXFP4 aten._scaled_mm_v2, plus TorchToTosa lowering to tosa.matmul_t_block_scaled. The lowering keeps packed FP4 data as ui8 in Torch IR and exposes logical f4E2M1FN tensors at the TOSA boundary.