[TRTLLM-12557][feat] WideEP FT: add AlltoAll watchdog (1a.4) #15524
chienchunhung wants to merge 6 commits into
Eliminates the infinite-spin AlltoAll hang that turns a single GPU failure in a Wide-EP group into a 5-minute HangDetector fire + full restart.

The dispatch and combine kernels now take a uint64[2] bitmask of currently-alive EP ranks; dead ranks are skipped on every completion-flag write/wait, peer recv_counter store, EPLB stats write, and per-token routing decision (dead-targeted slots collapse to the same -1 sentinel combine already uses for duplicates).

The mask is optional on both torch ops; omitting it (or passing all-ones) produces bit-identical output to the pre-change kernel. kMaxRanks is bumped 64 -> 128 to cover NVL72 with headroom; kRankMaskWords = 2 names the kernel ABI explicitly.

Tests cover (a) all-ones mask matches no-mask bit-for-bit, and (b) one rank masked dead -> surviving ranks complete dispatch+combine without hang, dead-targeted topk slots dropped, in tests/unittest/_torch/multi_gpu/test_moe_a2a_rank_mask.py.

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
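To make the mask semantics in the commit message concrete, here is a minimal host-side sketch in Python. It assumes rank `r` maps to bit `r % 64` of word `r // 64`, which the description does not state. The helper names (`make_active_rank_mask`, `is_rank_active`, `drop_dead_targets`) are hypothetical and are not the kernel or torch-op API.

```python
from collections.abc import Iterable

K_MAX_RANKS = 128        # kMaxRanks after the 64 -> 128 bump
K_RANK_MASK_WORDS = 2    # kRankMaskWords: two 64-bit words
WORD_BITS = 64


def make_active_rank_mask(ep_size: int, dead_ranks: Iterable[int] = ()) -> list[int]:
    """Build a two-word mask with one bit set per alive rank (assumed bit layout)."""
    if not 0 < ep_size <= K_MAX_RANKS:
        raise ValueError(f"ep_size must be in 1..{K_MAX_RANKS}, got {ep_size}")
    dead = set(dead_ranks)
    words = [0] * K_RANK_MASK_WORDS
    for rank in range(ep_size):
        if rank not in dead:
            words[rank // WORD_BITS] |= 1 << (rank % WORD_BITS)
    return words


def is_rank_active(mask: list[int], rank: int) -> bool:
    """Return True if the bit for `rank` is set."""
    return bool((mask[rank // WORD_BITS] >> (rank % WORD_BITS)) & 1)


def drop_dead_targets(target_ranks: list[int], mask: list[int]) -> list[int]:
    """Replace dead-targeted slots with the -1 sentinel, keeping slot positions."""
    return [r if is_rank_active(mask, r) else -1 for r in target_ranks]
```

With a 72-rank group and ranks 3 and 70 marked dead, `drop_dead_targets([0, 3, 70, 71], mask)` keeps slots 0 and 71 and turns the other two into `-1`, which mirrors the "dead-targeted slots collapse to -1" behaviour described above.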
48cce16 to e5824e0
/bot run
PR_Github #55085 [ run ] triggered by Bot. Commit:
ccb7fd3 to f62ed04
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
f62ed04 to 26639ec
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
4ddb47d to df71666
/bot run
PR_Github #55091 [ run ] triggered by Bot. Commit:
PR_Github #55085 [ run ] completed with state
PR_Github #55091 [ run ] completed with state
📝 Walkthrough
Adds
Changes
Wide-EP Fault Tolerance: Active Rank Mask + AlltoAll Watchdog
Sequence Diagram(s)
sequenceDiagram
participant FusedMoE
participant MoeAlltoAll
participant moe_a2a_dispatch
participant CUDAKernel as moeA2ADispatchKernel
participant AlltoAllWatchdog
FusedMoE->>MoeAlltoAll: dispatch(tokens, active_rank_mask)
MoeAlltoAll->>MoeAlltoAll: _get_active_rank_mask_tensor(active_rank_mask)
MoeAlltoAll->>moe_a2a_dispatch: dispatch(tokens, ..., active_rank_mask)
moe_a2a_dispatch->>CUDAKernel: launch with kernel_ptrs.active_rank_mask
CUDAKernel->>CUDAKernel: is_rank_active() → skip routing/sync for dead ranks
CUDAKernel-->>moe_a2a_dispatch: dispatched tokens
moe_a2a_dispatch-->>MoeAlltoAll: dispatched tokens
MoeAlltoAll->>AlltoAllWatchdog: watch("dispatch", expected_flag, active_mask)
AlltoAllWatchdog->>AlltoAllWatchdog: poll completion flags for active ranks
AlltoAllWatchdog-->>MoeAlltoAll: (async monitoring)
MoeAlltoAll-->>FusedMoE: dispatched tokens
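The watchdog leg of the diagram (queue a watch, poll completion flags for active ranks, report the ranks that never arrive) can be sketched in plain Python. This is a single-threaded illustration under stated assumptions, not the `AlltoAllWatchdog` implementation. `Watch`, `lagging_ranks`, and `drain` are hypothetical names, and treating a flag as complete once it is `>= expected_flag` is an assumption about how the generation counter behaves.

```python
import time
from collections import deque
from collections.abc import Callable, Sequence
from dataclasses import dataclass


@dataclass
class Watch:
    phase: str                    # e.g. "dispatch" or "combine"
    expected_flag: int            # value every active rank should reach
    active_ranks: Sequence[int]   # ranks whose mask bit is set
    deadline: float               # time.monotonic() value at which the watch expires


def lagging_ranks(flags: Sequence[int], watch: Watch) -> list[int]:
    """Active ranks whose completion flag has not reached the expected value."""
    return [r for r in watch.active_ranks if flags[r] < watch.expected_flag]


def drain(
    queue: deque,
    read_flags: Callable[[], Sequence[int]],
    poll_interval_s: float = 0.001,
) -> list[tuple[str, list[int]]]:
    """Process watches oldest-first; return (phase, timed-out ranks) for expired ones."""
    timeouts = []
    while queue:
        watch = queue[0]
        pending = lagging_ranks(read_flags(), watch)
        if not pending:
            queue.popleft()                      # all active ranks completed
        elif time.monotonic() >= watch.deadline:
            timeouts.append((watch.phase, pending))
            queue.popleft()                      # report and move on
        else:
            time.sleep(poll_interval_s)          # keep polling the head of the queue
    return timeouts
```

Dead ranks never appear in `active_ranks`, so a stale flag on a masked-out rank cannot trigger a timeout; only ranks that are expected to participate are reported.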
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~60 minutes
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 4
🧹 Nitpick comments (5)
tensorrt_llm/_torch/alltoall_watchdog.py (3)
122-147: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win
Reuse the pinned host buffer instead of allocating one per poll.
`_read_cuda_flags` allocates a fresh pinned `host_flags` tensor and a new `Event` on every read. When the watchdog is enabled, this runs for each queued phase (≈1-2 reads per dispatch/combine), so a busy decode step with many MoE layers triggers a large number of small `cudaHostAlloc` calls, which are not cheap. Since the common path already waits for the copy to finish before returning, the buffer can be reused; only retire-and-reallocate on the timeout path where the in-flight copy still owns it.
⚡ Sketch
         self._retired_copies: list[tuple[torch.Tensor, torch.cuda.Event]] = []
+        self._host_flags: torch.Tensor | None = None
         if workspace.device.type == "cuda":
             self._copy_stream = torch.cuda.Stream(device=workspace.device)

         self._prune_retired_copies()
-
-        host_flags = torch.empty(
-            (self._ep_size,),
-            dtype=torch.int32,
-            device="cpu",
-            pin_memory=prefer_pinned(),
-        )
+        if self._host_flags is None:
+            self._host_flags = torch.empty(
+                (self._ep_size,),
+                dtype=torch.int32,
+                device="cpu",
+                pin_memory=prefer_pinned(),
+            )
+        host_flags = self._host_flags
         event = torch.cuda.Event(blocking=False)
         ...
         if remaining_s <= 0:
             self._retired_copies.append((host_flags, event))
+            self._host_flags = None  # buffer still in flight; allocate fresh next time
             raise CompletionFlagReadTimeout(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/alltoall_watchdog.py` around lines 122 - 147, The _read_cuda_flags method allocates a new pinned host_flags tensor and a new Event on every call, which is inefficient when called repeatedly. Instead, create and store a reusable pinned host_flags tensor and Event as instance variables (initialized once, perhaps in __init__ or lazily on first use), and reuse them across calls to _read_cuda_flags. Only allocate and append new buffers and events to the _retired_copies list when a timeout occurs, since that is the only case where an in-flight copy still owns the buffer. This avoids repeated cudaHostAlloc calls on the common path where the copy completes before returning.
30-30: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Prefer PEP 604 / built-in generics over `Optional` and `Deque`.
This new module mixes `X | None` (e.g. lines 112, 202, 245) with `Optional[...]` (lines 178-179, 216-217) and `typing.Deque` (line 199). As per coding guidelines, use `|` syntax instead of `typing.Union`/`Optional` and prefer the built-in `deque`/collection types. `from __future__ import annotations` is already enabled, so the runtime cost is nil.
♻️ Suggested adjustments
-from typing import Callable, Deque, Mapping, Optional, Protocol, Sequence
+from collections.abc import Callable, Mapping, Sequence
+from typing import Protocol

-        self._queue: Deque[_CollectiveWatch] = deque()
+        self._queue: deque[_CollectiveWatch] = deque()

-        health: Optional[EPGroupHealthLike] = None,
-        on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None,
+        health: EPGroupHealthLike | None = None,
+        on_timeout: Callable[[AlltoAllWatchdogTimeout], None] | None = None,
(apply to the `from_workspace` signature as well)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/alltoall_watchdog.py` at line 30, The module uses inconsistent type annotation styles, mixing PEP 604 syntax (X | None) with Optional and typing.Deque. Remove Optional and Deque from the imports at the top of the file, add from collections import deque if not already present, and replace all instances of Optional[X] with X | None throughout the module (including in function signatures like from_workspace). Replace all usages of typing.Deque with the built-in deque type from collections. Since the module already has from __future__ import annotations enabled, these changes will have no runtime cost.
Source: Coding guidelines
424-430: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Consider narrowing `except BaseException` to `Exception`.
Catching `BaseException` here also swallows `KeyboardInterrupt`/`SystemExit`/`GeneratorExit`. They are unlikely in a daemon thread, but narrowing to `Exception` keeps the watchdog-failure visibility intent while letting genuine control-flow exceptions propagate. As per coding guidelines, limit `except` to the smallest set of errors possible.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/alltoall_watchdog.py` around lines 424 - 430, The except clause is catching BaseException which is too broad and will swallow control-flow exceptions like KeyboardInterrupt and SystemExit that should propagate in the watchdog. In the watchdog polling error handler where the exception is caught and stored in self._last_error, change the except clause from catching BaseException to catching Exception instead. Keep all the error handling logic the same (storing the error, clearing the queue, notifying via condition variable, and logging the error message) - only the exception type being caught needs to be narrowed.
Source: Coding guidelines
tests/unittest/_torch/multi_gpu/test_moe_a2a_rank_mask.py (1)
60-63: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Add return annotations to the new MPI test functions.
The fixture, worker functions, and test functions should declare their return types to match the repo’s Python typing guidance. As per coding guidelines, “Always annotate functions with return types.”
Proposed return annotations
 @pytest.fixture(autouse=True)
-def setup_test():
+def setup_test() -> None:
@@
 def _worker_all_active_matches_no_mask(
@@
-):
+) -> tuple[int, bool, bool]:
@@
 def _worker_one_rank_masked(
@@
-):
+) -> tuple[int, str, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
@@
-def test_all_active_mask_matches_no_mask(mpi_pool_executor, local_num_tokens, top_k):
+def test_all_active_mask_matches_no_mask(mpi_pool_executor, local_num_tokens, top_k) -> None:
@@
-def test_one_rank_masked_completes(mpi_pool_executor, dead_rank, local_num_tokens, top_k):
+def test_one_rank_masked_completes(
+    mpi_pool_executor,
+    dead_rank,
+    local_num_tokens,
+    top_k,
+) -> None:
Also applies to: 165-172, 211-219, 279-397
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/multi_gpu/test_moe_a2a_rank_mask.py` around lines 60 - 63, Add return type annotations to all test-related functions in this file to comply with the typing guidelines. For the setup_test fixture and other functions that don't explicitly return a value, add the return annotation -> None. For test functions (starting with test_) and worker functions, ensure they also have appropriate return type annotations (-> None for most test functions). Apply this change to the setup_test fixture shown in the diff and to all other test functions and fixtures referenced in the "Also applies to" section (lines 165-172, 211-219, 279-397).
Source: Coding guidelines
tests/unittest/_torch/modules/test_alltoall_watchdog.py (1)
17-19: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Annotate the remaining test helper contract.
`_wait_for` is shared by the timeout tests, so typing its predicate makes the polling contract explicit and keeps this new module aligned with the repo’s static-typing guidance. Consider typing the pytest fixture too while touching the signature. As per coding guidelines, “Always annotate functions” and prefer precise `Callable` argument types.
Proposed typing cleanup
 import threading
 import time
+from collections.abc import Callable
 from types import SimpleNamespace
@@
-def _wait_for(predicate, timeout_s: float = 1.0) -> None:
+def _wait_for(predicate: Callable[[], bool], timeout_s: float = 1.0) -> None:
@@
-def test_wide_ep_ft_options_create_shared_health_when_enabled(monkeypatch) -> None:
+def test_wide_ep_ft_options_create_shared_health_when_enabled(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
Also applies to: 72-78, 129-144
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/test_alltoall_watchdog.py` around lines 17 - 19, The _wait_for function and related pytest fixtures lack proper type annotations for their parameters and return types. Add type annotations to the _wait_for function, specifically typing its predicate parameter as a Callable that accepts and returns the appropriate types based on how it's used in the timeout tests. Additionally, add type annotations to the pytest fixtures referenced at lines 72-78 and 129-144 using appropriate return type hints to make the polling contract explicit and align with the repository's static typing guidance and coding guidelines that require all functions to be annotated with precise Callable types for function arguments.
Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h`:
- Around line 95-98: The combine operation's per-token load path reads peer
buffers without checking if those peers are still active according to
active_rank_mask, which can cause stale reads if a rank was active during
dispatch but became inactive by combine time. Add an explicit active_rank_mask
check before reading from peer buffers in the combine's per-token load
operation, or alternatively, enforce that both dispatch and combine phases use
the same mask object to guarantee consistency. The check should verify that a
peer rank's bit is set in active_rank_mask before allowing its buffer to be
read, similar to how topk_send_indices is currently used to skip dead-targeted
slots.
In `@tensorrt_llm/_torch/distributed/moe_alltoall.py`:
- Around line 236-243: The issue is that _watchdog_flag_generation is maintained
as a per-instance counter (initialized when reading the current flag value and
later incremented locally), but since completion flags come from shared
workspace state, each instance may have out-of-sync generation values. Instead
of using per-instance state for _watchdog_flag_generation, derive the
expected_flag value from a generation source that is shared with and
synchronized through the workspace (obtained from the AlltoAllWatchdog or
workspace object itself). This ensures all MoeAlltoAll instances reference
consistent generation values that match the actual shared workspace state,
preventing stale expected_flag values and false watchdog timeouts.
In `@tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py`:
- Around line 319-327: The watchdog flag generation tracking is unsafe because
`_watchdog_flag_generation` is stored as a per-instance counter while the
completion flags are tied to shared workspace state in `_WORKSPACES`. When
multiple `NVLinkOneSided` instances reuse the same workspace, their local
counters drift out of sync causing false timeout reports. Replace the
per-instance `_watchdog_flag_generation` counter with a workspace-shared
generation source that is synchronized for concurrent access. Update the
initialization of `_watchdog_flag_generation` and its usage in the watchdog
setup and flag increment logic to use this workspace-shared source instead of
maintaining a separate local counter.
In `@tests/unittest/_torch/multi_gpu/test_moe_a2a_rank_mask.py`:
- Around line 165-203: The worker function _worker_all_active_matches_no_mask
needs to explicitly return the MPI rank along with the boolean comparison
results, changing the return statement to include rank as the first element in
the tuple so that test assertions can verify correctness based on actual rank
rather than relying on enumerate index order. Additionally, add return type
annotations to _worker_all_active_matches_no_mask to indicate it returns a tuple
containing an int and two bools, to _worker_one_rank_masked to indicate it
returns a tuple with an int and additional elements, and to the test functions
test_all_active_mask_matches_no_mask and test_one_rank_masked_completes to
indicate they return None, following the Python annotation guideline.
---
Nitpick comments:
In `@tensorrt_llm/_torch/alltoall_watchdog.py`:
- Around line 122-147: The _read_cuda_flags method allocates a new pinned
host_flags tensor and a new Event on every call, which is inefficient when
called repeatedly. Instead, create and store a reusable pinned host_flags tensor
and Event as instance variables (initialized once, perhaps in __init__ or lazily
on first use), and reuse them across calls to _read_cuda_flags. Only allocate
and append new buffers and events to the _retired_copies list when a timeout
occurs, since that is the only case where an in-flight copy still owns the
buffer. This avoids repeated cudaHostAlloc calls on the common path where the
copy completes before returning.
- Line 30: The module uses inconsistent type annotation styles, mixing PEP 604
syntax (X | None) with Optional and typing.Deque. Remove Optional and Deque from
the imports at the top of the file, add from collections import deque if not
already present, and replace all instances of Optional[X] with X | None
throughout the module (including in function signatures like from_workspace).
Replace all usages of typing.Deque with the built-in deque type from
collections. Since the module already has from __future__ import annotations
enabled, these changes will have no runtime cost.
- Around line 424-430: The except clause is catching BaseException which is too
broad and will swallow control-flow exceptions like KeyboardInterrupt and
SystemExit that should propagate in the watchdog. In the watchdog polling error
handler where the exception is caught and stored in self._last_error, change the
except clause from catching BaseException to catching Exception instead. Keep
all the error handling logic the same (storing the error, clearing the queue,
notifying via condition variable, and logging the error message) - only the
exception type being caught needs to be narrowed.
In `@tests/unittest/_torch/modules/test_alltoall_watchdog.py`:
- Around line 17-19: The _wait_for function and related pytest fixtures lack
proper type annotations for their parameters and return types. Add type
annotations to the _wait_for function, specifically typing its predicate
parameter as a Callable that accepts and returns the appropriate types based on
how it's used in the timeout tests. Additionally, add type annotations to the
pytest fixtures referenced at lines 72-78 and 129-144 using appropriate return
type hints to make the polling contract explicit and align with the repository's
static typing guidance and coding guidelines that require all functions to be
annotated with precise Callable types for function arguments.
In `@tests/unittest/_torch/multi_gpu/test_moe_a2a_rank_mask.py`:
- Around line 60-63: Add return type annotations to all test-related functions
in this file to comply with the typing guidelines. For the setup_test fixture
and other functions that don't explicitly return a value, add the return
annotation -> None. For test functions (starting with test_) and worker
functions, ensure they also have appropriate return type annotations (-> None
for most test functions). Apply this change to the setup_test fixture shown in
the diff and to all other test functions and fixtures referenced in the "Also
applies to" section (lines 165-172, 211-219, 279-397).
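Two of the inline findings above ask for `_watchdog_flag_generation` to come from a workspace-shared, synchronized source rather than a per-instance counter. One possible shape is sketched below. `SharedFlagGeneration`, `generation_for`, and the integer workspace key are hypothetical names introduced for illustration; how the PR actually keys `_WORKSPACES` is not shown in the review.

```python
import threading


class SharedFlagGeneration:
    """Lock-protected generation counter shared by every user of one workspace."""

    def __init__(self, initial: int = 0) -> None:
        self._lock = threading.Lock()
        self._value = initial

    def next(self) -> int:
        """Advance the generation and return the new expected flag value."""
        with self._lock:
            self._value += 1
            return self._value


_GENERATIONS: dict[int, SharedFlagGeneration] = {}
_GENERATIONS_LOCK = threading.Lock()


def generation_for(workspace_key: int, initial: int = 0) -> SharedFlagGeneration:
    """Return the single counter for a workspace, creating it on first use."""
    with _GENERATIONS_LOCK:
        counter = _GENERATIONS.get(workspace_key)
        if counter is None:
            counter = SharedFlagGeneration(initial)
            _GENERATIONS[workspace_key] = counter
        return counter
```

Because every instance that reuses a workspace receives the same counter object, the expected flag values handed to the watchdog advance together instead of drifting apart per instance.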
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: cccddc86-43ba-4417-85ca-f4598b16e4b1
📒 Files selected for processing (13)
- cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
- cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
- cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
- tensorrt_llm/_torch/alltoall_watchdog.py
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
- tensorrt_llm/_torch/distributed/moe_alltoall.py
- tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py
- tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
- tensorrt_llm/_torch/modules/fused_moe/wide_ep_ft.py
- tests/unittest/_torch/modules/test_alltoall_watchdog.py
- tests/unittest/_torch/multi_gpu/test_moe_a2a_rank_mask.py
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
/bot run
PR_Github #55292 [ run ] triggered by Bot. Commit:
PR_Github #55292 [ run ] completed with state
Summary
- `AlltoAllWatchdog` host thread that polls dispatch/combine completion flags in FIFO order and reports timed-out ranks.
- `MoeAlltoAll` and `NVLinkOneSided`, including active-rank-mask forwarding from health.

Stack
- `WideEP-FT/1a.2-nvlink-kernel-mask`).

Validation
- `python -m compileall -q tensorrt_llm/_torch/alltoall_watchdog.py tensorrt_llm/_torch/distributed/moe_alltoall.py tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py tests/unittest/_torch/modules/test_alltoall_watchdog.py`
- `git diff --check`

Not run: pytest, because the available local Python runtimes do not have `torch` installed.

Summary by CodeRabbit
Release Notes
New Features
Tests