Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ async def run(
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
drain_in_flight_on_read_eof: bool = False,
read_eof_response_drain_timeout: float = 5.0,
) -> None:
"""Serve a single connection over the given streams until the read side closes.

Expand All @@ -448,6 +450,8 @@ async def run(
self,
read_stream,
write_stream,
drain_in_flight_on_read_eof=drain_in_flight_on_read_eof,
read_eof_response_drain_timeout=read_eof_response_drain_timeout,
lifespan_state=lifespan_context,
init_options=initialization_options,
raise_exceptions=raise_exceptions,
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,7 @@ async def run_stdio_async(self) -> None:
read_stream,
write_stream,
self._lowlevel_server.create_initialization_options(),
drain_in_flight_on_read_eof=True,
)

async def run_sse_async( # pragma: no cover
Expand Down
4 changes: 4 additions & 0 deletions src/mcp/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ async def serve_loop(
session_id: str | None = None,
init_options: InitializationOptions | None = None,
raise_exceptions: bool = False,
drain_in_flight_on_read_eof: bool = False,
read_eof_response_drain_timeout: float = 5.0,
) -> None:
"""Drive ``server`` in loop mode over a stream pair until the channel closes.

Expand All @@ -460,6 +462,8 @@ async def serve_loop(
read_stream,
write_stream,
raise_handler_exceptions=raise_exceptions,
drain_in_flight_on_read_eof=drain_in_flight_on_read_eof,
read_eof_response_drain_timeout=read_eof_response_drain_timeout,
# Handle `initialize` inline so a client that pipelines it with the
# next request (spec: SHOULD NOT, not MUST NOT) sees the initialized
# state instead of failing the init-gate.
Expand Down
27 changes: 27 additions & 0 deletions src/mcp/shared/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ def __init__(
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None,
drain_in_flight_on_read_eof: bool = False,
read_eof_response_drain_timeout: float = 5.0,
) -> None:
"""Wire a dispatcher over a transport's `SessionMessage` stream pair.

Expand All @@ -264,12 +266,23 @@ def __init__(
)
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
self._raise_handler_exceptions = raise_handler_exceptions
self._drain_in_flight_on_read_eof = drain_in_flight_on_read_eof
self._read_eof_response_drain_timeout = read_eof_response_drain_timeout
# Request methods handled inline in the read loop (awaited before the
# next message is dequeued) instead of spawned concurrently. Use for
# methods whose side effects must be observable to the next message,
# e.g. `initialize`, so a pipelined follow-up sees the initialized state.
# Only suitable for handlers that complete quickly, since inline handling
# blocks dequeuing; a handler that awaits the peer (`send_raw_request`)
# while inline will deadlock because the parked read loop cannot dequeue
# the response.
self._inline_methods = inline_methods
self._on_stream_exception = on_stream_exception

self._next_id = 0
self._pending: dict[RequestId, _Pending] = {}
self._in_flight: dict[RequestId, _InFlight[TransportT]] = {}
self._responses_in_flight: set[RequestId] = set()
self._tg: anyio.abc.TaskGroup | None = None
self._running = False
self._closed = False
Expand Down Expand Up @@ -451,6 +464,12 @@ async def run(
except anyio.ClosedResourceError:
# Receive end closed under us (stateless SHTTP teardown); same as EOF.
logger.debug("read stream closed by transport; treating as EOF")
if self._drain_in_flight_on_read_eof:
with anyio.move_on_after(self._read_eof_response_drain_timeout) as scope:
while self._in_flight or self._responses_in_flight:
await anyio.sleep(0)
if scope.cancelled_caught:
logger.debug("timed out draining in-flight responses after read EOF")
# EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED.
self._running = False
self._closed = True
Expand Down Expand Up @@ -722,16 +741,24 @@ async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None
await self._write_stream.send(SessionMessage(message=message, metadata=metadata))

async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None:
    """Send a success response for ``request_id``, tracking it while the write is pending.

    The coerced request id is kept in ``self._responses_in_flight`` for the
    duration of the write and removed afterwards, whether the write completed,
    was dropped, or raised. A write that fails because the write stream is
    broken or closed is logged at debug level and otherwise ignored.
    """
    tracked_id = _coerce_id(request_id)
    self._responses_in_flight.add(tracked_id)
    try:
        response = JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)
        await self._write(response)
    except (anyio.BrokenResourceError, anyio.ClosedResourceError):
        logger.debug("dropped result for %r: write stream closed", request_id)
    finally:
        # Always untrack, so a failed write cannot leave a stale entry behind.
        self._responses_in_flight.discard(tracked_id)

async def _write_error(self, request_id: RequestId, error: ErrorData) -> None:
    """Send an error response for ``request_id``, tracking it while the write is pending.

    The coerced request id is kept in ``self._responses_in_flight`` for the
    duration of the write and removed afterwards, whether the write completed,
    was dropped, or raised. A write that fails because the write stream is
    broken or closed is logged at debug level and otherwise ignored.
    """
    tracked_id = _coerce_id(request_id)
    self._responses_in_flight.add(tracked_id)
    try:
        failure = JSONRPCError(jsonrpc="2.0", id=request_id, error=error)
        await self._write(failure)
    except (anyio.BrokenResourceError, anyio.ClosedResourceError):
        logger.debug("dropped error for %r: write stream closed", request_id)
    finally:
        # Always untrack, so a failed write cannot leave a stale entry behind.
        self._responses_in_flight.discard(tracked_id)

async def _final_write(
self,
Expand Down
65 changes: 65 additions & 0 deletions tests/server/test_cancel_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,71 @@ async def run_server():
assert handler_cancelled.is_set()


@pytest.mark.anyio
async def test_server_cancels_in_flight_handlers_when_read_eof_drain_times_out():
    """A bounded read-EOF drain still cancels handlers that never finish.

    The server runs with ``drain_in_flight_on_read_eof=True`` and a 10 ms drain
    timeout while a ``tools/call`` handler sleeps forever. Once the sending end
    of the server's read stream is closed, ``server.run()`` must return and the
    handler's ``finally`` block must have run.
    """
    tool_entered = anyio.Event()
    tool_unwound = anyio.Event()
    run_finished = anyio.Event()

    async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
        tool_entered.set()
        try:
            await anyio.sleep_forever()
        finally:
            tool_unwound.set()
        # Not expected to be reached: the sleep above never returns normally.
        raise AssertionError  # pragma: no cover

    server = Server("test", on_call_tool=handle_call_tool)

    client_send, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
    server_write, client_receive = anyio.create_memory_object_stream[SessionMessage](10)

    async def run_server():
        init_options = server.create_initialization_options()
        await server.run(
            server_read,
            server_write,
            init_options,
            drain_in_flight_on_read_eof=True,
            read_eof_response_drain_timeout=0.01,
        )
        run_finished.set()

    initialize_params = InitializeRequestParams(
        protocol_version=LATEST_PROTOCOL_VERSION,
        capabilities=ClientCapabilities(),
        client_info=Implementation(name="test", version="1.0"),
    )
    init_req = JSONRPCRequest(
        jsonrpc="2.0",
        id=1,
        method="initialize",
        params=initialize_params.model_dump(by_alias=True, mode="json", exclude_none=True),
    )
    initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
    tool_params = CallToolRequestParams(name="slow", arguments={})
    call_req = JSONRPCRequest(
        jsonrpc="2.0",
        id=2,
        method="tools/call",
        params=tool_params.model_dump(by_alias=True, mode="json"),
    )

    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg, client_send, server_read, server_write, client_receive:
            tg.start_soon(run_server)

            # Handshake first, then issue the tool call whose handler never completes.
            await client_send.send(SessionMessage(init_req))
            await client_receive.receive()  # response to `initialize`
            await client_send.send(SessionMessage(initialized))
            await client_send.send(SessionMessage(call_req))

            await tool_entered.wait()
            # Close the sending end paired with the server's read stream.
            await client_send.aclose()

            await run_finished.wait()

    assert tool_unwound.is_set()


@pytest.mark.anyio
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
"""When the transport closes while handlers are blocked on server→client
Expand Down
57 changes: 56 additions & 1 deletion tests/server/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import json
import sys
import threading
from collections.abc import AsyncIterator
Expand All @@ -7,11 +8,12 @@

import anyio
import pytest
from anyio.lowlevel import checkpoint

from mcp.server.mcpserver import MCPServer
from mcp.server.stdio import stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter


@pytest.mark.anyio
Expand Down Expand Up @@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke
assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={})


def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """stdin EOF must not drop responses for requests the server already accepted.

    A complete session (``initialize``, the ``initialized`` notification and two
    ``tools/call`` requests) is supplied as a pre-filled stdin buffer. All three
    requests must be answered on stdout, in request order.
    """
    server = MCPServer(name="DrainStdioServer")

    @server.tool()
    async def slow_echo(text: str) -> str:
        # Yield to the event loop once before answering.
        await checkpoint()
        return text

    def tool_call(request_id: int, text: str) -> JSONRPCRequest:
        return JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            method="tools/call",
            params={"name": "slow_echo", "arguments": {"text": text}},
        )

    session = [
        JSONRPCRequest(
            jsonrpc="2.0",
            id=0,
            method="initialize",
            params={
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "stdio-replay", "version": "0.1"},
            },
        ),
        JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}),
        tool_call(1, "first"),
        tool_call(2, "second"),
    ]
    # One JSON document per line, each line newline-terminated.
    stdin_text = "".join(message.model_dump_json(by_alias=True, exclude_none=True) + "\n" for message in session)
    stdin_bytes = io.BytesIO(stdin_text.encode())
    captured = _KeepOpenBytesIO()
    monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8"))
    monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8"))

    _run_stdio_bounded(server)

    stdout_lines = captured.getvalue().decode().splitlines()
    responses = [json.loads(line) for line in stdout_lines if line]

    assert [response["id"] for response in responses] == [0, 1, 2]
    for response, expected_text in zip(responses[1:], ("first", "second")):
        assert response["result"]["content"][0]["text"] == expected_text


def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None:
"""Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`.

Expand Down
Loading