Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,16 @@ def stamp(data: dict[str, Any], opts: CallOptions) -> None:
meta[PROTOCOL_VERSION_META_KEY] = protocol_version
meta[CLIENT_INFO_META_KEY] = client_info
meta[CLIENT_CAPABILITIES_META_KEY] = capabilities
opts["cancel_on_abandon"] = False
# `cancel_on_abandon` stays at the dispatcher default (True): the
# courtesy `notifications/cancelled` is the abandon signal. On the
# stream transports it is the 2026 wire's cancellation spelling; the
# streamable-HTTP transport translates it into aborting the request's
# own POST instead of writing it (the 2026 HTTP wire has no
# client-to-server notifications - closing the stream is the signal).
# The negotiation methods still opt out, mirroring `_preconnect_stamp`:
# the spec forbids cancelling them.
if data["method"] in ("initialize", "server/discover"):
opts["cancel_on_abandon"] = False
headers = opts.setdefault("headers", {})
headers[MCP_PROTOCOL_VERSION_HEADER] = protocol_version
headers[MCP_METHOD_HEADER] = data["method"]
Expand Down
111 changes: 95 additions & 16 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
RequestId,
jsonrpc_message_adapter,
)
from mcp_types.version import MODERN_PROTOCOL_VERSIONS
from pydantic import ValidationError

from mcp.client._transport import TransportStreams
from mcp.shared._compat import resync_tracer
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER
from mcp.shared.jsonrpc_dispatcher import cancelled_request_id_from_params
from mcp.shared.message import ClientMessageMetadata, SessionMessage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,6 +72,19 @@ class RequestContext:
read_stream_writer: StreamWriter


@dataclass(slots=True)
class _InFlightPost:
    """A request POST in flight: its abort scope and the era it was sent under.

    `modern` is the negotiated-version cache as of this request's dequeue, so a
    later cancel frame is interpreted under the era the request actually ran
    with, not whatever the cache says by then.

    The transport registers one instance per request id in its
    `_in_flight_posts` mapping before the POST task is spawned.
    """

    # Cancel scope the request's POST runs inside; cancelling it aborts that POST.
    scope: anyio.CancelScope
    # True when the cached `MCP-Protocol-Version` value was one of
    # `MODERN_PROTOCOL_VERSIONS` at the moment this request was dequeued.
    modern: bool


class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""

Expand All @@ -81,21 +96,28 @@ def __init__(self, url: str) -> None:
"""
self.url = url
self.session_id: str | None = None
# Captured from each stamped POST's metadata. Reused on outbound HTTP that carries
# no per-message header (transport-internal GET/DELETE, and dispatcher-written
# response/error/cancel POSTs that bypass the session's stamp). Cleared when an
# `initialize` POST goes out so a probe-stamped value cannot leak onto the handshake.
# Captured from each stamped message's metadata, synchronously in the
# post_writer loop so the cache always reflects wire order (a POST task's
# scheduling is arbitrary). Reused on outbound HTTP that carries no
# per-message header (transport-internal GET/DELETE, and dispatcher-written
# response/error POSTs that bypass the session's stamp), and consulted by
# `_consume_modern_cancellation`. Cleared when an `initialize` message is
# dequeued so a probe-stamped value cannot leak onto the handshake.
self._protocol_version_header: str | None = None
# Every request's POST runs inside one of these so an outbound
# `notifications/cancelled` at 2026 can abort it; see
# `_consume_modern_cancellation`. Keys are verbatim-typed ("1" is not 1).
self._in_flight_posts: dict[RequestId, _InFlightPost] = {}

def _prepare_headers(self) -> dict[str, str]:
"""Build MCP-specific request headers for any outbound HTTP request.

These are merged with the ``httpx.AsyncClient`` defaults (these take
precedence). The cached ``MCP-Protocol-Version`` is included whenever
present so messages that don't pass through the session's stamp —
response/error/cancel POSTs, transport-internal GET/DELETE — still
carry the negotiated version. Per-message headers are layered on top
by the caller.
response/error POSTs, legacy cancel frames, transport-internal
GET/DELETE — still carry the negotiated version. Per-message headers
are layered on top by the caller.
"""
headers: dict[str, str] = {
"accept": "application/json, text/event-stream",
Expand Down Expand Up @@ -245,19 +267,57 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
await event_source.response.aclose()
break

def _consume_modern_cancellation(self, session_message: SessionMessage) -> bool:
    """Intercept an outbound `notifications/cancelled`; True means "do not POST".

    Over streamable HTTP the 2026 wire defines no client-to-server
    notifications: closing a request's response stream IS the cancellation
    signal. The dispatcher nevertheless emits the courtesy frame as its
    abandon signal, so this transport translates it here.

    Outcomes:

    - Not a `notifications/cancelled` notification: returns False, the
      message is sent as usual.
    - The named request has a registered in-flight POST: the era recorded
      on that POST decides. At 2026 its scope is cancelled and the frame is
      swallowed (True); below 2026 the frame itself is the signal, so it is
      left to be POSTed (False).
    - No registered POST (or no request id could be read from the params):
      the cached negotiated version decides. At 2026 the frame is swallowed
      even though unmatched, so a late cancel racing the response cannot
      leak onto the wire.
    """
    outbound = session_message.message
    if not isinstance(outbound, JSONRPCNotification) or outbound.method != "notifications/cancelled":
        return False

    cancelled_id = cancelled_request_id_from_params(outbound.params)
    in_flight = None if cancelled_id is None else self._in_flight_posts.get(cancelled_id)

    if in_flight is None:
        # Nothing to consult: fall back to the cached negotiated version.
        return self._protocol_version_header in MODERN_PROTOCOL_VERSIONS

    if in_flight.modern:
        logger.debug("aborting in-flight POST for cancelled request %r", cancelled_id)
        in_flight.scope.cancel()
        return True

    # Pre-2026 request: the cancel frame must actually reach the server.
    return False

async def _run_request_post(
    self,
    post_fn: Callable[[], Awaitable[None]],
    post: _InFlightPost,
    request_id: RequestId,
) -> None:
    """Await `post_fn` inside `post.scope`, then drop this POST's registry entry.

    Running under the scope is what lets `_consume_modern_cancellation`
    abort the POST by cancelling that scope. The cleanup runs whether the
    POST completed, was aborted, or raised.
    """
    try:
        with post.scope:
            await post_fn()
    finally:
        # Remove the entry only if it is still *this* POST. A reused id may
        # already have a successor registered while this task unwinds;
        # deleting by key alone would evict the live entry and leave the
        # new POST unabortable.
        registered = self._in_flight_posts.get(request_id)
        if registered is post:
            del self._in_flight_posts[request_id]

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
if is_initialization:
# `initialize` is the negotiation, not a "subsequent request" — discard any
# probe-stamped value so the discover→fallback path can't leak it onto the handshake.
self._protocol_version_header = None
headers = self._prepare_headers()
if ctx.metadata is not None and ctx.metadata.headers is not None:
headers.update(ctx.metadata.headers)
if MCP_PROTOCOL_VERSION_HEADER in ctx.metadata.headers:
self._protocol_version_header = ctx.metadata.headers[MCP_PROTOCOL_VERSION_HEADER]

async with ctx.client.stream(
"POST",
Expand Down Expand Up @@ -302,7 +362,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
await ctx.read_stream_writer.send(session_message)
return

if is_initialization:
if self._is_initialization_request(message):
self._maybe_extract_session_id_from_response(response)

# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
Expand Down Expand Up @@ -455,6 +515,8 @@ async def post_writer(

async def _handle_message(session_message: SessionMessage) -> None:
message = session_message.message
if self._consume_modern_cancellation(session_message):
return
metadata = (
session_message.metadata
if isinstance(session_message.metadata, ClientMessageMetadata)
Expand All @@ -470,6 +532,15 @@ async def _handle_message(session_message: SessionMessage) -> None:
if self._is_initialized_notification(message):
start_get_stream()

if self._is_initialization_request(message):
# `initialize` is the negotiation, not a "subsequent request" — discard any
# probe-stamped value so the discover→fallback path can't leak it onto the handshake.
self._protocol_version_header = None
elif metadata is not None and metadata.headers is not None:
stamped_version = metadata.headers.get(MCP_PROTOCOL_VERSION_HEADER)
if stamped_version is not None:
self._protocol_version_header = stamped_version

ctx = RequestContext(
client=client,
session_id=self.session_id,
Expand All @@ -486,7 +557,15 @@ async def handle_request_async():

# If this is a request, start a new task to handle it
if isinstance(message, JSONRPCRequest):
tg.start_soon(handle_request_async)
# Register the abort scope before the spawn: the next
# message through this loop can already be the abandon
# signal for this id, ahead of the task ever running.
post = _InFlightPost(
scope=anyio.CancelScope(),
modern=self._protocol_version_header in MODERN_PROTOCOL_VERSIONS,
)
self._in_flight_posts[message.id] = post
tg.start_soon(self._run_request_post, handle_request_async, post, message.id)
else:
await handle_request_async()

Expand Down
33 changes: 28 additions & 5 deletions src/mcp/shared/direct_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pydantic import ValidationError

from mcp.shared._compat import resync_tracer
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT, coerce_request_id
from mcp.shared.exceptions import MCPError, NoBackChannelError
from mcp.shared.message import MessageMetadata
from mcp.shared.transport_context import TransportContext
Expand Down Expand Up @@ -56,7 +56,8 @@ class _DirectDispatchContext:
_back_request: _Request
_back_notify: _Notify
request_id: RequestId | None = None
"""A dispatcher-synthesized id for requests; `None` for notifications."""
"""The caller-supplied `CallOptions["request_id"]`, else a dispatcher-synthesized
id for requests; `None` for notifications."""
message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework
"""Always `None`: in-memory dispatch attaches no transport metadata."""
_on_progress: ProgressFnT | None = None
Expand Down Expand Up @@ -106,6 +107,7 @@ def __init__(self, transport_ctx: TransportContext, *, raise_handler_exceptions:
self._on_request: OnRequest | None = None
self._on_notify: OnNotify | None = None
self._next_id = 0
self._in_flight_ids: set[RequestId] = set()
self._ready = anyio.Event()
self._close_event = anyio.Event()
self._running = False
Expand Down Expand Up @@ -227,9 +229,28 @@ async def _dispatch_request(
# waiting on a peer whose run() has not started yet.
await self._wait_ready()
assert self._on_request is not None
# Synthesize an id: the DispatchContext contract reserves None for notifications.
self._next_id += 1
dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=self._next_id)
supplied_id = opts.get("request_id")
if supplied_id is not None:
request_id: RequestId = supplied_id
# Collisions use the same coerced domain as JSONRPCDispatcher's
# pending keys, so this in-memory stand-in raises for exactly
# the ids the wire dispatcher would; the context still sees
# the verbatim value.
in_flight_key = coerce_request_id(request_id)
if in_flight_key in self._in_flight_ids:
raise ValueError(f"request id {request_id!r} is already in flight")
else:
# Synthesize an id (the DispatchContext contract reserves None
# for notifications), minting past any key a supplied id
# occupies: the collision error is reserved for the caller
# who actually chose the id.
self._next_id += 1
while self._next_id in self._in_flight_ids:
self._next_id += 1
request_id = self._next_id
in_flight_key = request_id
self._in_flight_ids.add(in_flight_key)
dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=request_id)
try:
return await self._on_request(dctx, method, params)
except MCPError:
Expand All @@ -247,6 +268,8 @@ async def _dispatch_request(
raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e
logger.exception("request handler raised")
raise MCPError(code=INTERNAL_ERROR, message="Internal server error") from None
finally:
self._in_flight_ids.discard(in_flight_key)
except TimeoutError:
raise MCPError(
code=REQUEST_TIMEOUT,
Expand Down
27 changes: 27 additions & 0 deletions src/mcp/shared/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,26 @@
"OnRequest",
"Outbound",
"ProgressFnT",
"coerce_request_id",
]

TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True)


def coerce_request_id(request_id: RequestId) -> RequestId:
    """Coerce a stringified int request id back to int so a peer-echoed id still correlates.

    This is the collision/correlation domain dispatchers share: "7" and 7 are one
    id for correlation purposes, even where the wire carries the verbatim value.

    Only the canonical decimal spelling of an integer - exactly what `str(n)`
    produces, e.g. "7" or "-3" - is coerced. Bare `int()` is more permissive:
    it also accepts surrounding whitespace, a leading "+", leading zeros, "_"
    digit separators and non-ASCII digits. Accepting those would fold distinct
    wire ids such as "007", " 7" or "1_000" onto integer ids they do not equal,
    producing false collisions and mis-correlated responses.

    Non-string ids, and strings that are not a canonical integer, are returned
    unchanged.
    """
    if not isinstance(request_id, str):
        return request_id
    try:
        parsed = int(request_id)
    except ValueError:
        # Not numeric at all (or too long for int()): keep the verbatim string.
        return request_id
    # Round-trip check: only a string that is exactly an int's own spelling
    # shares that int's correlation key.
    return parsed if str(parsed) == request_id else request_id


class ProgressFnT(Protocol):
"""Callback invoked when a progress notification arrives for a pending request."""

Expand All @@ -51,6 +66,18 @@ class CallOptions(TypedDict, total=False):
All keys are optional. Dispatchers ignore keys they do not understand.
"""

request_id: RequestId
"""Send the request under this caller-supplied id instead of a dispatcher-minted one.

The peer sees the value verbatim ("7" stays a string). A value that collides
with one of the sender's own in-flight request ids raises `ValueError`.
Callers that need to know a request's id before its result arrives (a
`subscriptions/listen` stream is demultiplexed by it) mint their own ids
here; string ids that don't parse as integers can never collide with the
dispatcher's minted sequence. Per the class contract, dispatchers that
predate this key ignore it and mint as usual.
"""

timeout: float
"""Seconds to wait for a result before raising and sending `notifications/cancelled`."""

Expand Down
Loading
Loading