webgpu: Skip MatMulNBitsMlpFusion when kernel unavailable #29089

Merged
qjia7 merged 3 commits into microsoft:main from qjia7:webgpu-skip-matmul-nbits-mlp-fusion-without-kernel on Jun 18, 2026

qjia7 commented Jun 17, 2026

Contributor

Summary

  • Add kernel availability check at transformer registration time
  • Query WebGPU EP kernel registry for MatMulNBitsMlp before enabling fusion
  • Skip fusion gracefully when kernel is not registered
  • Add unit test for kernel-unavailable case

Motivation

MatMulNBitsMlpFusion is a WebGPU-specific optimization that fuses gated MLP patterns into a single MatMulNBitsMlp kernel operation. However, it was previously registered unconditionally without verifying that the MatMulNBitsMlp kernel actually exists at runtime.

This created a mismatch scenario: The fusion optimizer could be compiled into ORT, but the kernel might not be registered on the WebGPU EP, leading to "kernel not found" failures at inference time.

Realistic Scenario: Version Mismatch

Consider a common deployment scenario:

  • ORT is compiled with MatMulNBitsMlpFusion and MatMulNBitsMlp kernel support in WebGPU
  • The WebGPU plugin (deployed separately) is an older version that predates the MatMulNBitsMlp kernel implementation
  • At inference time:
    • The ORT optimizer successfully runs MatMulNBitsMlpFusion and creates fused nodes
    • The WebGPU plugin loads, but its kernel registry does not have MatMulNBitsMlp
    • Inference fails: "kernel 'MatMulNBitsMlp' not found on WebGpuExecutionProvider"

This can occur when optimizer versions and EP plugin versions diverge in production deployments.

The Fix

Rather than unconditionally assuming the kernel exists, the fix queries the WebGPU EP's kernel registry at graph optimization time (session initialization). If the kernel is not found, fusion is skipped entirely, preventing "kernel not found" errors and allowing graceful fallback to unfused (but functional) operators.
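
A minimal sketch of the registration-time check described above, using a toy registry rather than the real ORT types. `ToyKernelRegistry`, `Has`, and `BuildTransformerList` are hypothetical names introduced only for illustration; the op name `MatMulNBitsMlp`, the `com.microsoft` domain behind `kMSDomain`, and version 1 are the only parts taken from this PR.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical stand-in for an EP kernel registry (not the ORT KernelRegistry API):
// a set of (op_type, domain, since_version) entries.
struct ToyKernelRegistry {
  std::set<std::tuple<std::string, std::string, int>> kernels;

  bool Has(const std::string& op_type, const std::string& domain, int version) const {
    assert(version > 0);  // since-versions start at 1
    return kernels.count(std::make_tuple(op_type, domain, version)) > 0;
  }
};

// Builds the list of optimizer passes at session-initialization time.
// The fusion is added only when the kernel it would emit is registered,
// so a missing kernel yields an unfused (but runnable) graph.
std::vector<std::string> BuildTransformerList(const ToyKernelRegistry* registry) {
  std::vector<std::string> transformers;
  const bool has_mlp_kernel =
      registry != nullptr && registry->Has("MatMulNBitsMlp", "com.microsoft", 1);
  if (has_mlp_kernel) {
    transformers.push_back("MatMulNBitsMlpFusion");
  }
  return transformers;
}
```

With an older plugin whose registry lacks the entry, or with no registry at all, the returned list simply omits the fusion, which mirrors the graceful fallback the fix is after.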

Test plan

  • Build ORT with WebGPU EP succeeds
  • Model loads and infers correctly (phi4-graph-prune verified)
  • Fusion enabled when kernel available (40 MatMulNBitsMlp nodes fused)
  • New test MatMulNBitsMlpFusionSkipsWhenKernelUnavailable verifies graceful skip when kernel unavailable
  • Lintrunner reports no issues

MatMulNBitsMlpFusion previously attempted to fuse gated MLP patterns
unconditionally. When the MatMulNBitsMlp kernel was not registered on
the WebGPU EP (e.g., in certain build configurations), the fusion would
create fused nodes that would fail at inference time with 'kernel not
found' errors.

Fix: Add a kernel availability check at transformer registration time.
Query the WebGPU EP kernel registry for MatMulNBitsMlp and pass the
result to the fusion constructor. ApplyImpl now returns early when the
kernel is unavailable, preventing the fusion from creating nodes that
cannot execute.

Changes:
- Add ExecutionProviders* param to GenerateTransformers
- Query KernelRegistry::TryFindKernel at registration time
- Add has_matmul_nbits_mlp_kernel_ member to MatMulNBitsMlpFusion
- Skip fusion in ApplyImpl when kernel unavailable
- Add test for kernel-unavailable case

Verified: Model loads and runs correctly with WebGPU EP. Fusion is
enabled when kernel is available, skipped when unavailable.

qjia7 requested a review from hariharans29 on June 17, 2026, 07:22
@hariharans29

Member

Review of PR #29089 — webgpu: Skip MatMulNBitsMlpFusion when kernel unavailable

Verdict: approve, with two design suggestions and one minor warning-clean nit. The bug is real (transformer/EP plugin version skew → fused graph + missing kernel → hard failure at first Run()), the diagnosis is right, and the mechanism (the node-free KernelRegistry::TryFindKernel overload, called at session-init time) is the intended use of that API.

Why the fix is correct

I verified the node-free overload TryFindKernel(exec_provider, op_type, domain, version, type_constraints, logger, out) at kernel_registry.h:64. It is documented as "Find out whether a kernel is registered, without a node. This should be useful in graph optimizers, to check whether the node it is about to generate, is supported or not." That is exactly this use case.

Passing TypeConstraintMap{} is correct: MatchKernelDefTypes at kernel_registry.cc:104 iterates the constraint map, so an empty map vacuously matches and only VerifyVersion is actually applied. So the call answers "is some MatMulNBitsMlp kernel registered for WebGPU at SinceVersion 1?" without falsely rejecting on type-constraint shape — which is what you want.
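
The vacuous-match argument can be illustrated with a simplified matcher. This is only a sketch of the reasoning, not the code at kernel_registry.cc:104: `ToyConstraintMap`, `ToySupportedTypes`, and `ToyMatchTypes` are hypothetical, and the real type-matching logic is more involved.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Hypothetical simplified types; the real ORT structures are richer.
using ToyConstraintMap = std::map<std::string, std::string>;             // constraint name -> requested type
using ToySupportedTypes = std::map<std::string, std::set<std::string>>;  // constraint name -> accepted types

// Returns true when every requested constraint is accepted by the kernel.
// With an empty request the loop body never runs, so the function returns true.
bool ToyMatchTypes(const ToySupportedTypes& kernel, const ToyConstraintMap& requested) {
  for (const auto& entry : requested) {
    const auto it = kernel.find(entry.first);
    if (it == kernel.end() || it->second.count(entry.second) == 0) {
      return false;
    }
  }
  return true;
}

// Usage: an empty request matches; a request for an unsupported type does not.
void ToyMatchTypesExample() {
  ToySupportedTypes kernel;
  kernel["T1"] = {"float", "float16"};

  ToyConstraintMap unsupported;
  unsupported["T1"] = "double";

  assert(ToyMatchTypes(kernel, ToyConstraintMap{}));
  assert(!ToyMatchTypes(kernel, unsupported));
}
```

Because the loop runs over the requested constraints rather than the kernel's, an empty request cannot be rejected on types, leaving the version check as the only gate.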

And the lookup tuple matches the registration site at webgpu_contrib_kernels.cc:35: (kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsMlp). So the check finds the kernel in any build where it's actually present.

86/86 checks green, single commit, well-bounded blast radius (1 transformer, 1 plumbing parameter, 1 new test).

Two design suggestions

1. The same bug exists for MatMulNBitsQkvFusion (and GroupQueryAttentionPreNormFusion)

Right next door at graph_transformer_utils.cc:453:

transformers.emplace_back(std::make_unique<MatMulNBitsMlpFusion>(
    InlinedHashSet<std::string_view>{onnxruntime::kWebGpuExecutionProvider},
    has_matmul_nbits_mlp_kernel));               // ← fixed by this PR
transformers.emplace_back(std::make_unique<MatMulNBitsQkvFusion>(
    InlinedHashSet<std::string_view>{onnxruntime::kWebGpuExecutionProvider}));   // ← same bug

MatMulNBitsQkv is registered identically as (kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsQkv) at webgpu_contrib_kernels.cc:34, and the PR description's "ORT compiled with fusion, plugin WebGPU EP doesn't have the kernel" failure mode is exactly the same scenario for MatMulNBitsQkv. The visible diff context also shows a GroupQueryAttentionPreNormFusion line right above, which is also WebGPU-only and presumably has the same risk.

I'd land this PR for MatMulNBitsMlp alone and immediately follow up with a sibling PR for the others — or fold both into this one. The plumbing is already in place; it's a 1-line lookup + 1-arg per fusion. Up to you, but the partial fix invites the same bug report to recur from a different shader file.

2. Cleaner alternative: don't register the transformer at all

The current shape (constructor flag → store member → early-return in ApplyImpl → update every test to pass true) is more machinery than the behavior needs. The equivalent two-line version is just:

if (has_matmul_nbits_mlp_kernel) {
  transformers.emplace_back(std::make_unique<MatMulNBitsMlpFusion>(
      InlinedHashSet<std::string_view>{onnxruntime::kWebGpuExecutionProvider}));
}

This:

  • Reverts every test change in matmul_nbits_mlp_fusion_test.cc (no , true suffix needed anywhere).
  • Reverts the matmul_nbits_mlp_fusion.h and matmul_nbits_mlp_fusion.cc changes entirely.
  • Replaces the new MatMulNBitsMlpFusionSkipsWhenKernelUnavailable test with a GenerateTransformersExcludesMatMulNBitsMlpFusionWhenKernelUnavailable test that exercises the registration site directly (which is what's actually being tested).
  • Scales to suggestion 1 cleanly: each sibling fusion becomes another if (has_*) transformers.emplace_back(...).

The current shape pays for the same observable behavior with a member field, a constructor arg, a default value with a footgun (see #4 below), and ~30 lines of test diff. The transformer doesn't need to know the kernel exists if the registration site already filtered it out.

If there's a reason to keep the flag inside the transformer (e.g. you anticipate the transformer being constructed from somewhere other than GenerateTransformers and want defense-in-depth) I'd want to hear it. Otherwise the registration-site filter is the simpler primitive.

Minor: [[maybe_unused]] on the new parameter

graph_transformer_utils.cc:217:

[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
const ExecutionProviders* execution_providers) {

The only consumer of execution_providers in this function is the new kernel-check block, which sits inside the #if !defined(DISABLE_CONTRIB_OPS) region (the entire MatMulNBits block is gated on it — see graph_transformer_utils.cc:456). In a DISABLE_CONTRIB_OPS build the parameter is unreferenced and a warnings-as-errors toolchain will fail. Add [[maybe_unused]] to match the neighboring parameter:

[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
[[maybe_unused]] const ExecutionProviders* execution_providers) {
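
A self-contained sketch of why the attribute matters when the only use of a parameter sits behind a preprocessor guard. Everything here is hypothetical: `FEATURE_DISABLED` plays the role of DISABLE_CONTRIB_OPS, and `ToyProviders` and `GenerateToyTransformers` stand in for the real types and function.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the ExecutionProviders collection.
struct ToyProviders {
  bool has_webgpu = false;
};

// When FEATURE_DISABLED is defined, `providers` is never referenced in the body.
// [[maybe_unused]] keeps -Wunused-parameter quiet in that configuration,
// so a warnings-as-errors build still succeeds.
std::vector<std::string> GenerateToyTransformers(
    [[maybe_unused]] const ToyProviders* providers) {
  std::vector<std::string> transformers{"AlwaysOnPass"};
#if !defined(FEATURE_DISABLED)
  if (providers != nullptr && providers->has_webgpu) {
    transformers.push_back("MatMulNBitsMlpFusion");
  }
#endif
  assert(!transformers.empty());  // the always-on pass is registered in every configuration
  return transformers;
}
```

The attribute is harmless in the configuration where the parameter is used, which is why it can be applied unconditionally rather than wrapped in its own guard.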

Other things worth a one-line mention

  • Constructor default has_matmul_nbits_mlp_kernel = false is the safe direction (no false positives at runtime), but it's also a silent footgun: any future caller that forgets the second arg gets fusion disabled with no warning. Since every existing caller now passes it explicitly, the default is dead code, so just remove it. Goes away entirely if you take suggestion #2.
  • Decl/use distance: in graph_transformer_utils.cc:459-479 the bool has_matmul_nbits_mlp_kernel is set, then 16 lines later finally consumed by the MatMulNBitsMlpFusion ctor with GroupQueryAttentionPreNormFusion registration sandwiched between them. Easier to read if the bool computation lives immediately above its single use, but trivial.
  • webgpu_ep->GetKernelRegistry() returns a std::shared_ptr<KernelRegistry> per usual convention, and the PR correctly null-checks it with if (registry && registry->TryFindKernel(...).IsOK()). Defensive and right.
  • execution_providers lifetime: the lambda at inference_session.cc:4330-4334 captures &execution_providers_ (member of InferenceSession) and is invoked synchronously to construct the transformer list. No lifetime risk.
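
The default-argument footgun from the first bullet, sketched with two hypothetical classes (`ToyFusionWithDefault` and `ToyFusionExplicit` are illustrative only and do not reflect the actual MatMulNBitsMlpFusion constructor):

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical fusion pass whose availability flag defaults to false.
class ToyFusionWithDefault {
 public:
  explicit ToyFusionWithDefault(bool has_kernel = false) : has_kernel_(has_kernel) {}
  bool WillFuse() const { return has_kernel_; }

 private:
  bool has_kernel_;
};

// Same pass without the default: omitting the flag is now a compile error.
class ToyFusionExplicit {
 public:
  explicit ToyFusionExplicit(bool has_kernel) : has_kernel_(has_kernel) {}
  bool WillFuse() const { return has_kernel_; }

 private:
  bool has_kernel_;
};

static_assert(std::is_default_constructible<ToyFusionWithDefault>::value,
              "the defaulted flag lets callers omit it");
static_assert(!std::is_default_constructible<ToyFusionExplicit>::value,
              "without the default, callers must state the flag");

// A caller that forgets the argument silently gets fusion disabled.
void ForgetfulCallerExample() {
  ToyFusionWithDefault fusion;  // flag omitted: defaults to false
  assert(!fusion.WillFuse());
}
```

Removing the default turns the silent behavior change into a build failure at the call site, which is usually the cheaper place to find it.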

Bottom line

Land it. The diagnosis and the use of TryFindKernel are right. Strongly consider suggestion #2 (it's a strict simplification with the same observable behavior); please definitely either fold in suggestion #1 here or open the follow-up issue immediately so MatMulNBitsQkv doesn't ship the same regression on the next plugin/optimizer version skew.

@qjia7

qjia7 commented Jun 17, 2026

Contributor Author

Thanks for the detailed feedback. I updated the PR accordingly.

  1. Moved the kernel-availability check from fusion internals to transformer registration in GenerateTransformers, as suggested.
  2. Added the same registration-time gating for MatMulNBitsQkvFusion.
  3. Kept GroupQueryAttentionPreNormFusion unchanged in this PR scope.

Rationale for leaving GroupQueryAttentionPreNormFusion unchanged:

  • For MatMulNBitsMlp and MatMulNBitsQkv, support is a straightforward kernel-presence check.
  • For GroupQueryAttention, support depends on optional input-path behavior (for example q_norm_weight and k_norm_weight), and kernel registration alone does not encode a reliable input-count capability contract.
  • GroupQueryAttention support is currently validated through schema plus runtime behavior for the optional inputs, so I did not add a registration gate here to avoid false negatives or incorrect gating.

Also addressed the small cleanup point:

  • Added maybe_unused for the execution_providers parameter where needed.

I need to think through the right way to add checking for the GroupQueryAttentionPreNormFusion path, and will handle that in a follow-up PR.

@hariharans29

Member

Re-review of PR #29089 (webgpu: Skip MatMulNBitsMlpFusion when kernel unavailable)

Verdict: approve on 1298643. The follow-up commit cleanly addresses every prior item. 86/86 checks green, +47/−9 final, 4 files.

What changed in 1298643 (the "Address comments" pass)

| Prior suggestion | Status in 1298643 |
| --- | --- |
| #1: Cover MatMulNBitsQkvFusion too | Done. Added a second TryFindKernel for "MatMulNBitsQkv" and a matching if (has_matmul_nbits_qkv_kernel) gate at the registration site. |
| #2: Refactor to "skip registration entirely" instead of internal early-return | Done. The bool has_matmul_nbits_mlp_kernel_ field is gone, the ApplyImpl early-return in matmul_nbits_mlp_fusion.cc:250-253 is gone, the ctor returns to its original one-arg signature, and the gating moves entirely to the call site. |
| #3: [[maybe_unused]] const ExecutionProviders* execution_providers | Done. graph_transformer_utils.cc:217. |
| #4: Drop the has_matmul_nbits_mlp_kernel = false default footgun | Done implicitly: the parameter no longer exists. |
| GroupQueryAttentionPreNormFusion | Deferred with sound rationale. GQA support is gated on optional-input shape contracts (q_norm_weight, k_norm_weight) that a flat TryFindKernel query can't represent, so a blanket registration-time gate would risk false-negative-disabling fusion on EPs that do support the relevant input layout. Will be handled in a follow-up with a narrower capability check. |

Net effect on the diff:

// graph_transformer_utils.cc — registration site
bool has_matmul_nbits_mlp_kernel = false;
bool has_matmul_nbits_qkv_kernel = false;
if (execution_providers != nullptr) {
  const auto* webgpu_ep = execution_providers->Get(onnxruntime::kWebGpuExecutionProvider);
  if (webgpu_ep != nullptr) {
    auto registry = webgpu_ep->GetKernelRegistry();
    if (registry) {
      has_matmul_nbits_mlp_kernel = registry->TryFindKernel(
          onnxruntime::kWebGpuExecutionProvider, "MatMulNBitsMlp", kMSDomain, 1,
          KernelRegistry::TypeConstraintMap{}, logger, nullptr).IsOK();
      has_matmul_nbits_qkv_kernel = registry->TryFindKernel(
          onnxruntime::kWebGpuExecutionProvider, "MatMulNBitsQkv", kMSDomain, 1,
          KernelRegistry::TypeConstraintMap{}, logger, nullptr).IsOK();
    }
  }
}
if (has_matmul_nbits_mlp_kernel) {
  transformers.emplace_back(std::make_unique<MatMulNBitsMlpFusion>(
      InlinedHashSet<std::string_view>{onnxruntime::kWebGpuExecutionProvider}));
}
if (has_matmul_nbits_qkv_kernel) {
  transformers.emplace_back(std::make_unique<MatMulNBitsQkvFusion>(
      InlinedHashSet<std::string_view>{onnxruntime::kWebGpuExecutionProvider}));
}
  • MatMulNBitsMlpFusion class is unchanged from main; tests reverted to original ctor form (no , true suffix anywhere). No dead artifacts.
  • The transformer registration order is preserved: GroupQueryAttentionPreNormFusion → MatMulNBitsMlpFusion → MatMulNBitsQkvFusion. No reordering side effects to worry about.
  • Per-test MatMulNBitsMlpFusionSkipsWhenKernelUnavailable was correctly removed — the skip path no longer lives in the fusion class. The equivalent at the GenerateTransformers level (verifying the fusion isn't registered without the kernel) would be a different test scaffold; not blocking, but a one-line GTEST_SKIP() << "no integration test for GenerateTransformers gating; relied on by phi4-graph-prune model-level test" style would be nice for future grep-ability. Optional.

Sanity of the MatMulNBitsQkv lookup

MatMulNBitsQkv registers at webgpu_contrib_kernels.cc:34 as (kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsQkv) — matches the lookup tuple ("MatMulNBitsQkv", kMSDomain, 1) exactly. Empty TypeConstraintMap{} makes MatchKernelDefTypes vacuously true (only VerifyVersion gates the result), same well-trodden pattern as the MLP query.

GroupQueryAttentionPreNormFusion deferral — agree

The rationale checks out:

  • GroupQueryAttention is a long-standing op in kMSDomain opset 1; any WebGPU EP plugin that exposes a GQA kernel will pass TryFindKernel("GroupQueryAttention", kMSDomain, 1, ...).IsOK().
  • But the PreNorm variant produces a node whose runtime correctness depends on the kernel supporting the q_norm_weight / k_norm_weight optional input slots, which the kernel registry doesn't model. A blanket presence check would say "yes, GQA is supported" and let the fusion proceed; the fused node would then fail at runtime on older kernels that don't honor those inputs.
  • A future check would need to either (a) introduce a versioning bump on the GQA kernel and key off SinceVersion, or (b) probe the EP for a capability bit. Both are out-of-scope for this PR.

Sensible to defer.
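
To make option (a) concrete, here is a sketch contrasting a plain presence check with a version-keyed one. It is purely hypothetical: `ToyKernelEntry`, `HasKernel`, and `HasKernelAtLeast` are invented for illustration, and the version 2 used in the usage note below is an assumption, since no such version bump of the GroupQueryAttention kernel exists in this PR.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical registry entry: op type plus the version the kernel was registered at.
struct ToyKernelEntry {
  std::string op_type;
  int since_version;
};

// Presence check: true if any kernel with this op type is registered.
bool HasKernel(const std::vector<ToyKernelEntry>& registry, const std::string& op_type) {
  for (const auto& entry : registry) {
    if (entry.op_type == op_type) {
      return true;
    }
  }
  return false;
}

// Version-keyed check: true only if a kernel registered at or above
// min_version exists, which lets a gate distinguish newer kernels from older ones.
bool HasKernelAtLeast(const std::vector<ToyKernelEntry>& registry,
                      const std::string& op_type, int min_version) {
  assert(min_version > 0);
  for (const auto& entry : registry) {
    if (entry.op_type == op_type && entry.since_version >= min_version) {
      return true;
    }
  }
  return false;
}
```

If an older plugin registered the op at version 1 and a newer one at an assumed version 2, `HasKernel` would return true for both, while `HasKernelAtLeast(..., 2)` would accept only the newer plugin. That is the distinction a presence-only gate cannot draw.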

One purely optional nit (do not block)

Two TryFindKernel calls now share identical scaffolding (same provider, domain=kMSDomain, version=1, empty TypeConstraintMap, logger, null out). If a third op ever joins the list, a 3-line lambda is the cheapest reshape:

auto has_webgpu_kernel = [&](std::string_view op_type) {
  return registry->TryFindKernel(onnxruntime::kWebGpuExecutionProvider, op_type, kMSDomain, 1,
                                 KernelRegistry::TypeConstraintMap{}, logger, nullptr).IsOK();
};
has_matmul_nbits_mlp_kernel = has_webgpu_kernel("MatMulNBitsMlp");
has_matmul_nbits_qkv_kernel = has_webgpu_kernel("MatMulNBitsQkv");

Two call sites is borderline — fine either way. Wouldn't touch it just for this.

Bottom line

Land it. The follow-up commit took the cleaner shape (registration-site gating), generalized to the sibling fusion, fixed the warning issue, and deferred the GroupQueryAttention case with a principled reason instead of forcing it. Re-approve.

hariharans29 previously approved these changes Jun 17, 2026
qjia7 merged commit c172210 into microsoft:main Jun 18, 2026
86 checks passed