From 9eb10e3b903cb673baddfb8500bb58394f900276 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=86=AF=E5=9F=BA=E9=AD=81?= <1412414664@qq.com>
Date: Sun, 7 Jun 2026 18:18:39 +0800
Subject: [PATCH] fix: drain stdio responses after stdin EOF

When the read stream ends, the stdio server now waits (bounded by a
timeout) for in-flight request handlers to finish before cancelling
them. In this mode the session leaves the write stream open on read EOF
and closes it when the session exits instead.

The drain loop sleeps between checks rather than spinning on
checkpoint(), and the write stream is closed in a `finally` so it is
released even if the task group exit raises.

---
 src/mcp/server/lowlevel/server.py    | 27 ++++++++---
 src/mcp/server/mcpserver/server.py   |  1 +
 src/mcp/server/session.py            |  7 ++-
 src/mcp/shared/session.py            | 20 ++++++++-
 tests/server/test_cancel_handling.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 tests/server/test_stdio.py           | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 6 files changed, 168 insertions(+), 9 deletions(-)

diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py
index 37127c5621..df4d6cabbc 100644
--- a/src/mcp/server/lowlevel/server.py
+++ b/src/mcp/server/lowlevel/server.py
@@ -74,6 +74,10 @@ async def main():
 
 LifespanResultT = TypeVar("LifespanResultT", default=Any)
 
+STDIO_READ_EOF_RESPONSE_DRAIN_TIMEOUT = 5.0
+# Interval between in-flight checks while draining after read EOF.
+STDIO_READ_EOF_RESPONSE_DRAIN_POLL_INTERVAL = 0.01
+
 
 class NotificationOptions:
     def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False):
@@ -347,6 +351,8 @@ async def run(
         # the initialization lifecycle, but can do so with any available node
         # rather than requiring initialization for each connection.
         stateless: bool = False,
+        drain_in_flight_on_read_eof: bool = False,
+        read_eof_response_drain_timeout: float = STDIO_READ_EOF_RESPONSE_DRAIN_TIMEOUT,
     ):
         async with AsyncExitStack() as stack:
             lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -356,6 +362,7 @@ async def run(
                     write_stream,
                     initialization_options,
                     stateless=stateless,
+                    close_write_stream_on_read_eof=not drain_in_flight_on_read_eof,
                 )
             )
 
@@ -378,11 +385,21 @@ async def run(
                             raise_exceptions,
                         )
                 finally:
-                    # Transport closed: cancel in-flight handlers. Without this the
-                    # TG join waits for them, and when they eventually try to
-                    # respond they hit a closed write stream (the session's
-                    # _receive_loop closed it when the read stream ended).
-                    tg.cancel_scope.cancel()
+                    cancel_in_flight = True
+                    if drain_in_flight_on_read_eof:
+                        with anyio.move_on_after(read_eof_response_drain_timeout) as drain_scope:
+                            # Sleep between checks: a bare checkpoint() loop would
+                            # busy-spin the event loop for the whole drain window.
+                            while session.has_in_flight_requests:
+                                await anyio.sleep(STDIO_READ_EOF_RESPONSE_DRAIN_POLL_INTERVAL)
+                        cancel_in_flight = drain_scope.cancel_called
+
+                    # Transport closed or drain timed out: cancel in-flight handlers.
+                    # Without this the TG join can wait indefinitely, or handlers can
+                    # eventually try to respond through a write stream that the session
+                    # closed when the read stream ended.
+                    if cancel_in_flight:
+                        tg.cancel_scope.cancel()
 
     async def _handle_message(
         self,
diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py
index ec2365810e..4121db60b8 100644
--- a/src/mcp/server/mcpserver/server.py
+++ b/src/mcp/server/mcpserver/server.py
@@ -852,6 +852,7 @@ async def run_stdio_async(self) -> None:
                 read_stream,
                 write_stream,
                 self._lowlevel_server.create_initialization_options(),
+                drain_in_flight_on_read_eof=True,
             )
 
     async def run_sse_async(  # pragma: no cover
diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py
index 3fc7bbf0d3..67d423451e 100644
--- a/src/mcp/server/session.py
+++ b/src/mcp/server/session.py
@@ -80,8 +80,13 @@ def __init__(
         write_stream: WriteStream[SessionMessage],
         init_options: InitializationOptions,
         stateless: bool = False,
+        close_write_stream_on_read_eof: bool = True,
     ) -> None:
-        super().__init__(read_stream, write_stream)
+        super().__init__(
+            read_stream,
+            write_stream,
+            close_write_stream_on_read_eof=close_write_stream_on_read_eof,
+        )
         self._stateless = stateless
         self._initialization_state = (
             InitializationState.Initialized if stateless else InitializationState.NotInitialized
diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py
index ea5d8833bd..86a6ab81ba 100644
--- a/src/mcp/shared/session.py
+++ b/src/mcp/shared/session.py
@@ -189,12 +189,14 @@ def __init__(
         write_stream: WriteStream[SessionMessage],
         # If none, reading will never time out
         read_timeout_seconds: float | None = None,
+        close_write_stream_on_read_eof: bool = True,
     ) -> None:
         self._read_stream = read_stream
         self._write_stream = write_stream
         self._response_streams = {}
         self._request_id = 0
         self._session_read_timeout_seconds = read_timeout_seconds
+        self._close_write_stream_on_read_eof = close_write_stream_on_read_eof
         self._in_flight = {}
         self._progress_callbacks = {}
         self._exit_stack = AsyncExitStack()
@@ -216,7 +218,13 @@ async def __aexit__(
         # would be very surprising behavior), so make sure to cancel the tasks
         # in the task group.
         self._task_group.cancel_scope.cancel()
-        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+        try:
+            return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+        finally:
+            # When _receive_loop does not own the write stream, close it here,
+            # even if the task group exit raises.
+            if not self._close_write_stream_on_read_eof:
+                await self._write_stream.aclose()
 
     async def send_request(
         self,
@@ -331,7 +339,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
         raise NotImplementedError
 
     async def _receive_loop(self) -> None:
-        async with self._read_stream, self._write_stream:
+        async with AsyncExitStack() as stream_stack:
+            await stream_stack.enter_async_context(self._read_stream)
+            if self._close_write_stream_on_read_eof:
+                await stream_stack.enter_async_context(self._write_stream)
             try:
 
                 async def _handle_session_message(message: SessionMessage) -> None:
@@ -438,6 +449,11 @@ async def _handle_session_message(message: SessionMessage) -> None:
                         pass
                 self._response_streams.clear()
 
+    @property
+    def has_in_flight_requests(self) -> bool:
+        """Return whether client requests are still being handled."""
+        return bool(self._in_flight)
+
     def _normalize_request_id(self, response_id: RequestId) -> RequestId:
         """Normalize a response ID to match how request IDs are stored.
 
diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py
index cff5a37c15..17b43c9c98 100644
--- a/tests/server/test_cancel_handling.py
+++ b/tests/server/test_cancel_handling.py
@@ -172,6 +172,71 @@ async def run_server():
     assert handler_cancelled.is_set()
 
 
+@pytest.mark.anyio
+async def test_server_cancels_in_flight_handlers_when_read_eof_drain_times_out():
+    """A bounded read-EOF drain still cancels handlers that never finish."""
+    handler_started = anyio.Event()
+    handler_cancelled = anyio.Event()
+    server_run_returned = anyio.Event()
+
+    async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
+        handler_started.set()
+        try:
+            await anyio.sleep_forever()
+        finally:
+            handler_cancelled.set()
+        raise AssertionError  # pragma: no cover
+
+    server = Server("test", on_call_tool=handle_call_tool)
+
+    to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
+    server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
+
+    async def run_server():
+        await server.run(
+            server_read,
+            server_write,
+            server.create_initialization_options(),
+            drain_in_flight_on_read_eof=True,
+            read_eof_response_drain_timeout=0.01,
+        )
+        server_run_returned.set()
+
+    init_req = JSONRPCRequest(
+        jsonrpc="2.0",
+        id=1,
+        method="initialize",
+        params=InitializeRequestParams(
+            protocol_version=LATEST_PROTOCOL_VERSION,
+            capabilities=ClientCapabilities(),
+            client_info=Implementation(name="test", version="1.0"),
+        ).model_dump(by_alias=True, mode="json", exclude_none=True),
+    )
+    initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
+    call_req = JSONRPCRequest(
+        jsonrpc="2.0",
+        id=2,
+        method="tools/call",
+        params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
+    )
+
+    with anyio.fail_after(5):
+        async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
+            tg.start_soon(run_server)
+
+            await to_server.send(SessionMessage(init_req))
+            await from_server.receive()
+            await to_server.send(SessionMessage(initialized))
+            await to_server.send(SessionMessage(call_req))
+
+            await handler_started.wait()
+            await to_server.aclose()
+
+            await server_run_returned.wait()
+
+    assert handler_cancelled.is_set()
+
+
 @pytest.mark.anyio
 async def test_server_handles_transport_close_with_pending_server_to_client_requests():
     """When the transport closes while handlers are blocked on server→client
diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py
index 054a157b3b..544a74412a 100644
--- a/tests/server/test_stdio.py
+++ b/tests/server/test_stdio.py
@@ -1,4 +1,5 @@
 import io
+import json
 import sys
 import threading
 from collections.abc import AsyncIterator
@@ -7,11 +8,12 @@
 
 import anyio
 import pytest
+from anyio.lowlevel import checkpoint
 
 from mcp.server.mcpserver import MCPServer
 from mcp.server.stdio import stdio_server
 from mcp.shared.message import SessionMessage
-from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
+from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
 
 
 @pytest.mark.anyio
@@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke
     assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={})
 
 
+def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """stdin EOF must not drop responses for requests the server already accepted."""
+    server = MCPServer(name="DrainStdioServer")
+
+    @server.tool()
+    async def slow_echo(text: str) -> str:
+        await checkpoint()
+        return text
+
+    payload_lines = [
+        JSONRPCRequest(
+            jsonrpc="2.0",
+            id=0,
+            method="initialize",
+            params={
+                "protocolVersion": "2024-11-05",
+                "capabilities": {},
+                "clientInfo": {"name": "stdio-replay", "version": "0.1"},
+            },
+        ).model_dump_json(by_alias=True, exclude_none=True),
+        JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}).model_dump_json(
+            by_alias=True, exclude_none=True
+        ),
+        JSONRPCRequest(
+            jsonrpc="2.0",
+            id=1,
+            method="tools/call",
+            params={"name": "slow_echo", "arguments": {"text": "first"}},
+        ).model_dump_json(by_alias=True, exclude_none=True),
+        JSONRPCRequest(
+            jsonrpc="2.0",
+            id=2,
+            method="tools/call",
+            params={"name": "slow_echo", "arguments": {"text": "second"}},
+        ).model_dump_json(by_alias=True, exclude_none=True),
+    ]
+    stdin_bytes = io.BytesIO(("\n".join(payload_lines) + "\n").encode())
+    captured = _KeepOpenBytesIO()
+    monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8"))
+    monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8"))
+
+    _run_stdio_bounded(server)
+
+    output = captured.getvalue().decode()
+    responses = [json.loads(line) for line in output.splitlines() if line]
+
+    assert [response["id"] for response in responses] == [0, 1, 2]
+    assert responses[1]["result"]["content"][0]["text"] == "first"
+    assert responses[2]["result"]["content"][0]["text"] == "second"
+
+
 def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None:
     """Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`.
 