From 08c7fa18e4fc1b52a5bc3f170f600d91ada45186 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=AF=E5=9F=BA=E9=AD=81?= <1412414664@qq.com> Date: Sun, 7 Jun 2026 18:18:39 +0800 Subject: [PATCH] fix: drain stdio responses after stdin EOF --- src/mcp/server/lowlevel/server.py | 4 ++ src/mcp/server/mcpserver/server.py | 1 + src/mcp/server/runner.py | 4 ++ src/mcp/shared/jsonrpc_dispatcher.py | 27 ++++++++++++ tests/server/test_cancel_handling.py | 65 ++++++++++++++++++++++++++++ tests/server/test_stdio.py | 57 +++++++++++++++++++++++- 6 files changed, 157 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 22ce0dca50..c0f288a0d0 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -436,6 +436,8 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, + drain_in_flight_on_read_eof: bool = False, + read_eof_response_drain_timeout: float = 5.0, ) -> None: """Serve a single connection over the given streams until the read side closes. @@ -448,6 +450,8 @@ async def run( self, read_stream, write_stream, + drain_in_flight_on_read_eof=drain_in_flight_on_read_eof, + read_eof_response_drain_timeout=read_eof_response_drain_timeout, lifespan_state=lifespan_context, init_options=initialization_options, raise_exceptions=raise_exceptions, diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2064bd60cd..4ac285fe6f 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -838,6 +838,7 @@ async def run_stdio_async(self) -> None: read_stream, write_stream, self._lowlevel_server.create_initialization_options(), + drain_in_flight_on_read_eof=True, ) async def run_sse_async( # pragma: no cover diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 6b64ce9c49..a4d3f10217 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -447,6 +447,8 @@ async def serve_loop( session_id: str | None = None, init_options: InitializationOptions | None = None, raise_exceptions: bool = False, + drain_in_flight_on_read_eof: bool = False, + read_eof_response_drain_timeout: float = 5.0, ) -> None: """Drive ``server`` in loop mode over a stream pair until the channel closes. @@ -460,6 +462,8 @@ async def serve_loop( read_stream, write_stream, raise_handler_exceptions=raise_exceptions, + drain_in_flight_on_read_eof=drain_in_flight_on_read_eof, + read_eof_response_drain_timeout=read_eof_response_drain_timeout, # Handle `initialize` inline so a client that pipelines it with the # next request (spec: SHOULD NOT, not MUST NOT) sees the initialized # state instead of failing the init-gate. diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 7fabafff65..aea940fdd9 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -239,6 +239,8 @@ def __init__( raise_handler_exceptions: bool = False, inline_methods: frozenset[str] = frozenset(), on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None, + drain_in_flight_on_read_eof: bool = False, + read_eof_response_drain_timeout: float = 5.0, ) -> None: """Wire a dispatcher over a transport's `SessionMessage` stream pair. @@ -264,12 +266,23 @@ def __init__( ) self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode self._raise_handler_exceptions = raise_handler_exceptions + self._drain_in_flight_on_read_eof = drain_in_flight_on_read_eof + self._read_eof_response_drain_timeout = read_eof_response_drain_timeout + # Request methods handled inline in the read loop (awaited before the + # next message is dequeued) instead of spawned concurrently. Use for + # methods whose side effects must be observable to the next message, + # e.g. `initialize`, so a pipelined follow-up sees the initialized state. + # Only suitable for handlers that complete quickly, since inline handling + # blocks dequeuing; a handler that awaits the peer (`send_raw_request`) + # while inline will deadlock because the parked read loop cannot dequeue + # the response. self._inline_methods = inline_methods self._on_stream_exception = on_stream_exception self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._responses_in_flight: set[RequestId] = set() self._tg: anyio.abc.TaskGroup | None = None self._running = False self._closed = False @@ -451,6 +464,12 @@ async def run( except anyio.ClosedResourceError: # Receive end closed under us (stateless SHTTP teardown); same as EOF. logger.debug("read stream closed by transport; treating as EOF") + if self._drain_in_flight_on_read_eof: + with anyio.move_on_after(self._read_eof_response_drain_timeout) as scope: + while self._in_flight or self._responses_in_flight: + await anyio.sleep(0) + if scope.cancelled_caught: + logger.debug("timed out draining in-flight responses after read EOF") # EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED. self._running = False self._closed = True @@ -722,16 +741,24 @@ async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None: + key = _coerce_id(request_id) + self._responses_in_flight.add(key) try: await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped result for %r: write stream closed", request_id) + finally: + self._responses_in_flight.discard(key) async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: + key = _coerce_id(request_id) + self._responses_in_flight.add(key) try: await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped error for %r: write stream closed", request_id) + finally: + self._responses_in_flight.discard(key) async def _final_write( self, diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 0744e63022..46d4d47582 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -172,6 +172,71 @@ async def run_server(): assert handler_cancelled.is_set() +@pytest.mark.anyio +async def test_server_cancels_in_flight_handlers_when_read_eof_drain_times_out(): + """A bounded read-EOF drain still cancels handlers that never finish.""" + handler_started = anyio.Event() + handler_cancelled = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_cancelled.set() + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run( + server_read, + server_write, + server.create_initialization_options(), + drain_in_flight_on_read_eof=True, + read_eof_response_drain_timeout=0.01, + ) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() + await to_server.send(SessionMessage(initialized)) + await to_server.send(SessionMessage(call_req)) + + await handler_started.wait() + await to_server.aclose() + + await server_run_returned.wait() + + assert handler_cancelled.is_set() + + @pytest.mark.anyio async def test_server_handles_transport_close_with_pending_server_to_client_requests(): """When the transport closes while handlers are blocked on server→client diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 054a157b3b..544a74412a 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,4 +1,5 @@ import io +import json import sys import threading from collections.abc import AsyncIterator @@ -7,11 +8,12 @@ import anyio import pytest +from anyio.lowlevel import checkpoint from mcp.server.mcpserver import MCPServer from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter +from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter @pytest.mark.anyio @@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={}) +def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """stdin EOF must not drop responses for requests the server already accepted.""" + server = MCPServer(name="DrainStdioServer") + + @server.tool() + async def slow_echo(text: str) -> str: + await checkpoint() + return text + + payload_lines = [ + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "stdio-replay", "version": "0.1"}, + }, + ).model_dump_json(by_alias=True, exclude_none=True), + JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}).model_dump_json( + by_alias=True, exclude_none=True + ), + JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="tools/call", + params={"name": "slow_echo", "arguments": {"text": "first"}}, + ).model_dump_json(by_alias=True, exclude_none=True), + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "slow_echo", "arguments": {"text": "second"}}, + ).model_dump_json(by_alias=True, exclude_none=True), + ] + stdin_bytes = io.BytesIO(("\n".join(payload_lines) + "\n").encode()) + captured = _KeepOpenBytesIO() + monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8")) + + _run_stdio_bounded(server) + + output = captured.getvalue().decode() + responses = [json.loads(line) for line in output.splitlines() if line] + + assert [response["id"] for response in responses] == [0, 1, 2] + assert responses[1]["result"]["content"][0]["text"] == "first" + assert responses[2]["result"]["content"][0]["text"] == "second" + + def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None: """Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`.