Skip to content

Commit 53117cb

Browse files
authored
Make client-side cancellation work over the 2026 transports (#3046)
1 parent bf44027 commit 53117cb

14 files changed

Lines changed: 1213 additions & 102 deletions

src/mcp/client/session.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,16 @@ def stamp(data: dict[str, Any], opts: CallOptions) -> None:
9393
meta[PROTOCOL_VERSION_META_KEY] = protocol_version
9494
meta[CLIENT_INFO_META_KEY] = client_info
9595
meta[CLIENT_CAPABILITIES_META_KEY] = capabilities
96-
opts["cancel_on_abandon"] = False
96+
# `cancel_on_abandon` stays at the dispatcher default (True): the
97+
# courtesy `notifications/cancelled` is the abandon signal. On the
98+
# stream transports it is the 2026 wire's cancellation spelling; the
99+
# streamable-HTTP transport translates it into aborting the request's
100+
# own POST instead of writing it (the 2026 HTTP wire has no
101+
# client-to-server notifications - closing the stream is the signal).
102+
# The negotiation methods still opt out, mirroring `_preconnect_stamp`:
103+
# the spec forbids cancelling them.
104+
if data["method"] in ("initialize", "server/discover"):
105+
opts["cancel_on_abandon"] = False
97106
headers = opts.setdefault("headers", {})
98107
headers[MCP_PROTOCOL_VERSION_HEADER] = protocol_version
99108
headers[MCP_METHOD_HEADER] = data["method"]

src/mcp/client/streamable_http.py

Lines changed: 95 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,15 @@
2626
RequestId,
2727
jsonrpc_message_adapter,
2828
)
29+
from mcp_types.version import MODERN_PROTOCOL_VERSIONS
2930
from pydantic import ValidationError
3031

3132
from mcp.client._transport import TransportStreams
3233
from mcp.shared._compat import resync_tracer
3334
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
3435
from mcp.shared._httpx_utils import create_mcp_http_client
3536
from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER
37+
from mcp.shared.jsonrpc_dispatcher import cancelled_request_id_from_params
3638
from mcp.shared.message import ClientMessageMetadata, SessionMessage
3739

3840
logger = logging.getLogger(__name__)
@@ -70,6 +72,19 @@ class RequestContext:
7072
read_stream_writer: StreamWriter
7173

7274

75+
@dataclass(slots=True)
76+
class _InFlightPost:
77+
"""A request POST in flight: its abort scope and the era it was sent under.
78+
79+
`modern` is the negotiated-version cache as of this request's dequeue, so a
80+
later cancel frame is interpreted under the era the request actually ran
81+
with, not whatever the cache says by then.
82+
"""
83+
84+
scope: anyio.CancelScope
85+
modern: bool
86+
87+
7388
class StreamableHTTPTransport:
7489
"""StreamableHTTP client transport implementation."""
7590

@@ -81,21 +96,28 @@ def __init__(self, url: str) -> None:
8196
"""
8297
self.url = url
8398
self.session_id: str | None = None
84-
# Captured from each stamped POST's metadata. Reused on outbound HTTP that carries
85-
# no per-message header (transport-internal GET/DELETE, and dispatcher-written
86-
# response/error/cancel POSTs that bypass the session's stamp). Cleared when an
87-
# `initialize` POST goes out so a probe-stamped value cannot leak onto the handshake.
99+
# Captured from each stamped message's metadata, synchronously in the
100+
# post_writer loop so the cache always reflects wire order (a POST task's
101+
# scheduling is arbitrary). Reused on outbound HTTP that carries no
102+
# per-message header (transport-internal GET/DELETE, and dispatcher-written
103+
# response/error POSTs that bypass the session's stamp), and consulted by
104+
# `_consume_modern_cancellation`. Cleared when an `initialize` message is
105+
# dequeued so a probe-stamped value cannot leak onto the handshake.
88106
self._protocol_version_header: str | None = None
107+
# Every request's POST runs inside one of these so an outbound
108+
# `notifications/cancelled` at 2026 can abort it; see
109+
# `_consume_modern_cancellation`. Keys are verbatim-typed ("1" is not 1).
110+
self._in_flight_posts: dict[RequestId, _InFlightPost] = {}
89111

90112
def _prepare_headers(self) -> dict[str, str]:
91113
"""Build MCP-specific request headers for any outbound HTTP request.
92114
93115
These are merged with the ``httpx.AsyncClient`` defaults (these take
94116
precedence). The cached ``MCP-Protocol-Version`` is included whenever
95117
present so messages that don't pass through the session's stamp —
96-
response/error/cancel POSTs, transport-internal GET/DELETE — still
97-
carry the negotiated version. Per-message headers are layered on top
98-
by the caller.
118+
response/error POSTs, legacy cancel frames, transport-internal
119+
GET/DELETE — still carry the negotiated version. Per-message headers
120+
are layered on top by the caller.
99121
"""
100122
headers: dict[str, str] = {
101123
"accept": "application/json, text/event-stream",
@@ -245,19 +267,57 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
245267
await event_source.response.aclose()
246268
break
247269

270+
def _consume_modern_cancellation(self, session_message: SessionMessage) -> bool:
    """Decide whether an outbound `notifications/cancelled` must stay off the wire.

    Returns True when the caller must NOT POST the message, False when it
    should be sent as usual.

    Over streamable HTTP the 2026 wire has no client-to-server
    notifications; closing a request's response stream is the cancellation
    signal. The dispatcher nevertheless emits the courtesy frame when it
    abandons a request, and every such frame names one of our own request
    ids (the spec forbids cancelling a request the sender did not issue).
    This method translates the frame:

    * The named request still has a POST in flight: the era recorded on
      that POST decides. At 2026 the POST is aborted and the frame is
      swallowed; below 2026 the frame is POSTed, because there the frame
      is the signal and a disconnect explicitly is not.
    * No POST to consult: the cached negotiated version decides. At 2026
      the frame is swallowed even though nothing matched, so a late
      cancel racing the response cannot leak onto the wire.
    """
    frame = session_message.message
    if not isinstance(frame, JSONRPCNotification):
        return False
    if frame.method != "notifications/cancelled":
        return False

    cancelled_id = cancelled_request_id_from_params(frame.params)
    in_flight = None if cancelled_id is None else self._in_flight_posts.get(cancelled_id)

    if in_flight is None:
        # Nothing registered under that id: fall back to the cached version.
        return self._protocol_version_header in MODERN_PROTOCOL_VERSIONS

    if in_flight.modern:
        logger.debug("aborting in-flight POST for cancelled request %r", cancelled_id)
        in_flight.scope.cancel()
        return True

    # The request ran under a pre-2026 era: let the frame be POSTed.
    return False
297+
298+
async def _run_request_post(
    self,
    post_fn: Callable[[], Awaitable[None]],
    post: _InFlightPost,
    request_id: RequestId,
) -> None:
    """Await `post_fn` inside `post.scope`, then unregister `post`.

    Running the POST inside the scope is what lets
    `_consume_modern_cancellation` abort it by cancelling that scope. The
    registry entry is removed however the POST ends.
    """
    try:
        with post.scope:
            await post_fn()
    finally:
        registry = self._in_flight_posts
        # Compare by identity, not just by key: when an id is reused, a
        # successor may already be registered while this task unwinds.
        # Removing by key alone would evict that live entry and leave the
        # new POST unabortable.
        still_ours = registry.get(request_id) is post
        if still_ours:
            registry.pop(request_id)
314+
248315
async def _handle_post_request(self, ctx: RequestContext) -> None:
249316
"""Handle a POST request with response processing."""
250317
message = ctx.session_message.message
251-
is_initialization = self._is_initialization_request(message)
252-
if is_initialization:
253-
# `initialize` is the negotiation, not a "subsequent request" — discard any
254-
# probe-stamped value so the discover→fallback path can't leak it onto the handshake.
255-
self._protocol_version_header = None
256318
headers = self._prepare_headers()
257319
if ctx.metadata is not None and ctx.metadata.headers is not None:
258320
headers.update(ctx.metadata.headers)
259-
if MCP_PROTOCOL_VERSION_HEADER in ctx.metadata.headers:
260-
self._protocol_version_header = ctx.metadata.headers[MCP_PROTOCOL_VERSION_HEADER]
261321

262322
async with ctx.client.stream(
263323
"POST",
@@ -302,7 +362,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
302362
await ctx.read_stream_writer.send(session_message)
303363
return
304364

305-
if is_initialization:
365+
if self._is_initialization_request(message):
306366
self._maybe_extract_session_id_from_response(response)
307367

308368
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
@@ -455,6 +515,8 @@ async def post_writer(
455515

456516
async def _handle_message(session_message: SessionMessage) -> None:
457517
message = session_message.message
518+
if self._consume_modern_cancellation(session_message):
519+
return
458520
metadata = (
459521
session_message.metadata
460522
if isinstance(session_message.metadata, ClientMessageMetadata)
@@ -470,6 +532,15 @@ async def _handle_message(session_message: SessionMessage) -> None:
470532
if self._is_initialized_notification(message):
471533
start_get_stream()
472534

535+
if self._is_initialization_request(message):
536+
# `initialize` is the negotiation, not a "subsequent request" — discard any
537+
# probe-stamped value so the discover→fallback path can't leak it onto the handshake.
538+
self._protocol_version_header = None
539+
elif metadata is not None and metadata.headers is not None:
540+
stamped_version = metadata.headers.get(MCP_PROTOCOL_VERSION_HEADER)
541+
if stamped_version is not None:
542+
self._protocol_version_header = stamped_version
543+
473544
ctx = RequestContext(
474545
client=client,
475546
session_id=self.session_id,
@@ -486,7 +557,15 @@ async def handle_request_async():
486557

487558
# If this is a request, start a new task to handle it
488559
if isinstance(message, JSONRPCRequest):
489-
tg.start_soon(handle_request_async)
560+
# Register the abort scope before the spawn: the next
561+
# message through this loop can already be the abandon
562+
# signal for this id, ahead of the task ever running.
563+
post = _InFlightPost(
564+
scope=anyio.CancelScope(),
565+
modern=self._protocol_version_header in MODERN_PROTOCOL_VERSIONS,
566+
)
567+
self._in_flight_posts[message.id] = post
568+
tg.start_soon(self._run_request_post, handle_request_async, post, message.id)
490569
else:
491570
await handle_request_async()
492571

src/mcp/shared/direct_dispatcher.py

Lines changed: 28 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
from pydantic import ValidationError
2929

3030
from mcp.shared._compat import resync_tracer
31-
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT
31+
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT, coerce_request_id
3232
from mcp.shared.exceptions import MCPError, NoBackChannelError
3333
from mcp.shared.message import MessageMetadata
3434
from mcp.shared.transport_context import TransportContext
@@ -56,7 +56,8 @@ class _DirectDispatchContext:
5656
_back_request: _Request
5757
_back_notify: _Notify
5858
request_id: RequestId | None = None
59-
"""A dispatcher-synthesized id for requests; `None` for notifications."""
59+
"""The caller-supplied `CallOptions["request_id"]`, else a dispatcher-synthesized
60+
id for requests; `None` for notifications."""
6061
message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework
6162
"""Always `None`: in-memory dispatch attaches no transport metadata."""
6263
_on_progress: ProgressFnT | None = None
@@ -106,6 +107,7 @@ def __init__(self, transport_ctx: TransportContext, *, raise_handler_exceptions:
106107
self._on_request: OnRequest | None = None
107108
self._on_notify: OnNotify | None = None
108109
self._next_id = 0
110+
self._in_flight_ids: set[RequestId] = set()
109111
self._ready = anyio.Event()
110112
self._close_event = anyio.Event()
111113
self._running = False
@@ -227,9 +229,28 @@ async def _dispatch_request(
227229
# waiting on a peer whose run() has not started yet.
228230
await self._wait_ready()
229231
assert self._on_request is not None
230-
# Synthesize an id: the DispatchContext contract reserves None for notifications.
231-
self._next_id += 1
232-
dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=self._next_id)
232+
supplied_id = opts.get("request_id")
233+
if supplied_id is not None:
234+
request_id: RequestId = supplied_id
235+
# Collisions use the same coerced domain as JSONRPCDispatcher's
236+
# pending keys, so this in-memory stand-in raises for exactly
237+
# the ids the wire dispatcher would; the context still sees
238+
# the verbatim value.
239+
in_flight_key = coerce_request_id(request_id)
240+
if in_flight_key in self._in_flight_ids:
241+
raise ValueError(f"request id {request_id!r} is already in flight")
242+
else:
243+
# Synthesize an id (the DispatchContext contract reserves None
244+
# for notifications), minting past any key a supplied id
245+
# occupies: the collision error is reserved for the caller
246+
# who actually chose the id.
247+
self._next_id += 1
248+
while self._next_id in self._in_flight_ids:
249+
self._next_id += 1
250+
request_id = self._next_id
251+
in_flight_key = request_id
252+
self._in_flight_ids.add(in_flight_key)
253+
dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=request_id)
233254
try:
234255
return await self._on_request(dctx, method, params)
235256
except MCPError:
@@ -247,6 +268,8 @@ async def _dispatch_request(
247268
raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e
248269
logger.exception("request handler raised")
249270
raise MCPError(code=INTERNAL_ERROR, message="Internal server error") from None
271+
finally:
272+
self._in_flight_ids.discard(in_flight_key)
250273
except TimeoutError:
251274
raise MCPError(
252275
code=REQUEST_TIMEOUT,

src/mcp/shared/dispatcher.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,11 +34,26 @@
3434
"OnRequest",
3535
"Outbound",
3636
"ProgressFnT",
37+
"coerce_request_id",
3738
]
3839

3940
TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True)
4041

4142

43+
def coerce_request_id(request_id: RequestId) -> RequestId:
    """Map a string id that parses as an integer onto that integer.

    Dispatchers share this as their collision/correlation domain: "7" and 7
    count as the same id when correlating, so an id a peer echoes back in
    stringified form still matches (in line with the TS SDK), even where the
    wire carries the verbatim value.

    Non-string ids, and strings `int()` rejects, are returned unchanged.
    """
    if not isinstance(request_id, str):
        return request_id
    try:
        as_int = int(request_id)
    except ValueError:
        # Not an integer spelling: the string is its own key.
        return request_id
    return as_int
55+
56+
4257
class ProgressFnT(Protocol):
4358
"""Callback invoked when a progress notification arrives for a pending request."""
4459

@@ -51,6 +66,18 @@ class CallOptions(TypedDict, total=False):
5166
All keys are optional. Dispatchers ignore keys they do not understand.
5267
"""
5368

69+
request_id: RequestId
70+
"""Send the request under this caller-supplied id instead of a dispatcher-minted one.
71+
72+
The peer sees the value verbatim ("7" stays a string). A value that collides
73+
with one of the sender's own in-flight request ids raises `ValueError`.
74+
Callers that need to know a request's id before its result arrives (a
75+
`subscriptions/listen` stream is demultiplexed by it) mint their own ids
76+
here; string ids that don't parse as integers can never collide with the
77+
dispatcher's minted sequence. Per the class contract, dispatchers that
78+
predate this key ignore it and mint as usual.
79+
"""
80+
5481
timeout: float
5582
"""Seconds to wait for a result before raising and sending `notifications/cancelled`."""
5683

0 commit comments

Comments
 (0)