webgpu: Skip MatMulNBitsMlpFusion when kernel unavailable#29089
Conversation
MatMulNBitsMlpFusion previously attempted to fuse gated MLP patterns unconditionally. When the MatMulNBitsMlp kernel was not registered on the WebGPU EP (e.g., in certain build configurations), the fusion would create fused nodes that would fail at inference time with 'kernel not found' errors. Fix: Add a kernel availability check at transformer registration time. Query the WebGPU EP kernel registry for MatMulNBitsMlp and pass the result to the fusion constructor. ApplyImpl now returns early when the kernel is unavailable, preventing the fusion from creating nodes that cannot execute. Changes: - Add ExecutionProviders* param to GenerateTransformers - Query KernelRegistry::TryFindKernel at registration time - Add has_matmul_nbits_mlp_kernel_ member to MatMulNBitsMlpFusion - Skip fusion in ApplyImpl when kernel unavailable - Add test for kernel-unavailable case Verified: Model loads and runs correctly with WebGPU EP. Fusion is enabled when kernel is available, skipped when unavailable.
Review of PR #29089 — webgpu: Skip MatMulNBitsMlpFusion when kernel unavailableVerdict: approve, with two design suggestions and one minor warning-clean nit. The bug is real (transformer/EP plugin version skew → fused graph + missing kernel → hard failure at first Why the fix is correctI verified the Passing And the lookup tuple matches the registration site at 86/86 checks green, single commit, well-bounded blast radius (1 transformer, 1 plumbing parameter, 1 new test). Two design suggestions1. The same bug exists for
|
|
Thanks for the detailed feedback. I updated the PR accordingly.
Rationale for leaving GroupQueryAttentionPreNormFusion unchanged:
Also addressed the small cleanup point:
I need to think through the right way to add checking for the GroupQueryAttentionPreNormFusion path, and will handle that in a follow-up PR. |
Re-review of PR #29089 —
|
| Prior suggestion | Status in 1298643 |
|---|---|
#1 — Cover MatMulNBitsQkvFusion too |
Done. Added a second TryFindKernel for "MatMulNBitsQkv" and a matching if (has_matmul_nbits_qkv_kernel) gate at the registration site. |
| #2 — Refactor to "skip registration entirely" instead of internal early-return | Done. The bool has_matmul_nbits_mlp_kernel_ field is gone, the ApplyImpl early-return in matmul_nbits_mlp_fusion.cc:250-253 is gone, the ctor returns to its original one-arg signature, and the gating moves entirely to the call site. |
#3 — [[maybe_unused]] const ExecutionProviders* execution_providers |
Done. graph_transformer_utils.cc:217. |
#4 — Drop the has_matmul_nbits_mlp_kernel = false default footgun |
Done implicitly — the parameter no longer exists. |
GroupQueryAttentionPreNormFusion |
Deferred with sound rationale. GQA support is gated on optional-input shape contracts (q_norm_weight, k_norm_weight) that a flat TryFindKernel query can't represent, so a blanket registration-time gate would risk false-negative-disabling fusion on EPs that do support the relevant input layout. Will be handled in a follow-up with a narrower capability check. |
Net effect on the diff:
// graph_transformer_utils.cc — registration site
bool has_matmul_nbits_mlp_kernel = false;
bool has_matmul_nbits_qkv_kernel = false;
if (execution_providers != nullptr) {
const auto* webgpu_ep = execution_providers->Get(onnxruntime::kWebGpuExecutionProvider);
if (webgpu_ep != nullptr) {
auto registry = webgpu_ep->GetKernelRegistry();
if (registry) {
has_matmul_nbits_mlp_kernel = registry->TryFindKernel(
onnxruntime::kWebGpuExecutionProvider, "MatMulNBitsMlp", kMSDomain, 1,
KernelRegistry::TypeConstraintMap{}, logger, nullptr).IsOK();
has_matmul_nbits_qkv_kernel = registry->TryFindKernel(
onnxruntime::kWebGpuExecutionProvider, "MatMulNBitsQkv", kMSDomain, 1,
KernelRegistry::TypeConstraintMap{}, logger, nullptr).IsOK();
}
}
}
if (has_matmul_nbits_mlp_kernel) {
transformers.emplace_back(std::make_unique<MatMulNBitsMlpFusion>(
InlinedHashSet<std::string_view>{onnxruntime::kWebGpuExecutionProvider}));
}
if (has_matmul_nbits_qkv_kernel) {
transformers.emplace_back(std::make_unique<MatMulNBitsQkvFusion>(
InlinedHashSet<std::string_view>{onnxruntime::kWebGpuExecutionProvider}));
}MatMulNBitsMlpFusionclass is unchanged from main; tests reverted to original ctor form (no, truesuffix anywhere). No dead artifacts.- The transformer registration order is preserved:
GroupQueryAttentionPreNormFusion→MatMulNBitsMlpFusion→MatMulNBitsQkvFusion. No reordering side effects to worry about. - Per-test
MatMulNBitsMlpFusionSkipsWhenKernelUnavailablewas correctly removed — the skip path no longer lives in the fusion class. The equivalent at theGenerateTransformerslevel (verifying the fusion isn't registered without the kernel) would be a different test scaffold; not blocking, but a one-lineGTEST_SKIP() << "no integration test for GenerateTransformers gating; relied on by phi4-graph-prune model-level test"style would be nice for future grep-ability. Optional.
Sanity of the MatMulNBitsQkv lookup
MatMulNBitsQkv registers at webgpu_contrib_kernels.cc:34 as (kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBitsQkv) — matches the lookup tuple ("MatMulNBitsQkv", kMSDomain, 1) exactly. Empty TypeConstraintMap{} makes MatchKernelDefTypes vacuously true (only VerifyVersion gates the result), same well-trodden pattern as the MLP query.
GroupQueryAttentionPreNormFusion deferral — agree
The rationale checks out:
GroupQueryAttentionis a long-standing op inkMSDomainopset 1; any WebGPU EP plugin that exposes a GQA kernel will passTryFindKernel("GroupQueryAttention", kMSDomain, 1, ...).IsOK().- But the PreNorm variant produces a node whose runtime correctness depends on the kernel supporting the q_norm_weight / k_norm_weight optional input slots, which the kernel registry doesn't model. A blanket presence check would say "yes, GQA is supported" and let the fusion proceed; the fused node would then fail at runtime on older kernels that don't honor those inputs.
- A future check would need to either (a) introduce a versioning bump on the GQA kernel and key off SinceVersion, or (b) probe the EP for a capability bit. Both are out-of-scope for this PR.
Sensible to defer.
One purely optional nit (do not block)
Two TryFindKernel calls now share identical scaffolding (same provider, domain=kMSDomain, version=1, empty TypeConstraintMap, logger, null out). If a third op ever joins the list, a 3-line lambda is the cheapest reshape:
auto has_webgpu_kernel = [&](std::string_view op_type) {
return registry->TryFindKernel(onnxruntime::kWebGpuExecutionProvider, op_type, kMSDomain, 1,
KernelRegistry::TypeConstraintMap{}, logger, nullptr).IsOK();
};
has_matmul_nbits_mlp_kernel = has_webgpu_kernel("MatMulNBitsMlp");
has_matmul_nbits_qkv_kernel = has_webgpu_kernel("MatMulNBitsQkv");Two call sites is borderline — fine either way. Wouldn't touch it just for this.
Bottom line
Land it. The follow-up commit took the cleaner shape (registration-site gating), generalized to the sibling fusion, fixed the warning issue, and deferred the GroupQueryAttention case with a principled reason instead of forcing it. Re-approve.
Summary
Motivation
MatMulNBitsMlpFusion is a WebGPU-specific optimization that fuses gated MLP patterns into a single MatMulNBitsMlp kernel operation. However, it was previously registered unconditionally without verifying that the MatMulNBitsMlp kernel actually exists at runtime.
This created a mismatch scenario: The fusion optimizer could be compiled into ORT, but the kernel might not be registered on the WebGPU EP, leading to "kernel not found" failures at inference time.
Realistic Scenario: Version Mismatch
Consider a common deployment scenario:
This can occur when optimizer versions and EP plugin versions diverge in production deployments.
The Fix
Rather than unconditionally assuming the kernel exists, the fix queries the WebGPU EP's kernel registry at graph optimization time (session initialization). If the kernel is not found, fusion is skipped entirely, preventing "kernel not found" errors and allowing graceful fallback to unfused (but functional) operators.
Test plan