Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ async def main():
from typing import Any, Generic, cast

import anyio
from anyio.lowlevel import checkpoint
from opentelemetry.trace import SpanKind, StatusCode
from starlette.applications import Starlette
from starlette.middleware import Middleware
Expand Down Expand Up @@ -74,6 +75,8 @@ async def main():

LifespanResultT = TypeVar("LifespanResultT", default=Any)

STDIO_READ_EOF_RESPONSE_DRAIN_TIMEOUT = 5.0


class NotificationOptions:
def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False):
Expand Down Expand Up @@ -347,6 +350,8 @@ async def run(
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
stateless: bool = False,
drain_in_flight_on_read_eof: bool = False,
read_eof_response_drain_timeout: float = STDIO_READ_EOF_RESPONSE_DRAIN_TIMEOUT,
):
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
Expand All @@ -356,6 +361,7 @@ async def run(
write_stream,
initialization_options,
stateless=stateless,
close_write_stream_on_read_eof=not drain_in_flight_on_read_eof,
)
)

Expand All @@ -378,11 +384,19 @@ async def run(
raise_exceptions,
)
finally:
# Transport closed: cancel in-flight handlers. Without this the
# TG join waits for them, and when they eventually try to
# respond they hit a closed write stream (the session's
# _receive_loop closed it when the read stream ended).
tg.cancel_scope.cancel()
cancel_in_flight = True
if drain_in_flight_on_read_eof:
with anyio.move_on_after(read_eof_response_drain_timeout) as drain_scope:
while session.has_in_flight_requests:
await checkpoint()
cancel_in_flight = drain_scope.cancel_called

# Transport closed or drain timed out: cancel in-flight handlers.
# Without this the TG join can wait indefinitely, or handlers can
# eventually try to respond through a write stream that the session
# closed when the read stream ended.
if cancel_in_flight:
tg.cancel_scope.cancel()

async def _handle_message(
self,
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ async def run_stdio_async(self) -> None:
read_stream,
write_stream,
self._lowlevel_server.create_initialization_options(),
drain_in_flight_on_read_eof=True,
)

async def run_sse_async( # pragma: no cover
Expand Down
7 changes: 6 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@ def __init__(
write_stream: WriteStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
close_write_stream_on_read_eof: bool = True,
) -> None:
super().__init__(read_stream, write_stream)
super().__init__(
read_stream,
write_stream,
close_write_stream_on_read_eof=close_write_stream_on_read_eof,
)
self._stateless = stateless
self._initialization_state = (
InitializationState.Initialized if stateless else InitializationState.NotInitialized
Expand Down
17 changes: 15 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,14 @@ def __init__(
write_stream: WriteStream[SessionMessage],
# If none, reading will never time out
read_timeout_seconds: float | None = None,
close_write_stream_on_read_eof: bool = True,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._session_read_timeout_seconds = read_timeout_seconds
self._close_write_stream_on_read_eof = close_write_stream_on_read_eof
self._in_flight = {}
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()
Expand All @@ -216,7 +218,10 @@ async def __aexit__(
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
if not self._close_write_stream_on_read_eof:
await self._write_stream.aclose()
return result

async def send_request(
self,
Expand Down Expand Up @@ -331,7 +336,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
raise NotImplementedError

async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
async with AsyncExitStack() as stream_stack:
await stream_stack.enter_async_context(self._read_stream)
if self._close_write_stream_on_read_eof:
await stream_stack.enter_async_context(self._write_stream)
try:

async def _handle_session_message(message: SessionMessage) -> None:
Expand Down Expand Up @@ -438,6 +446,11 @@ async def _handle_session_message(message: SessionMessage) -> None:
pass
self._response_streams.clear()

@property
def has_in_flight_requests(self) -> bool:
"""Return whether client requests are still being handled."""
return bool(self._in_flight)

def _normalize_request_id(self, response_id: RequestId) -> RequestId:
"""Normalize a response ID to match how request IDs are stored.

Expand Down
65 changes: 65 additions & 0 deletions tests/server/test_cancel_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,71 @@ async def run_server():
assert handler_cancelled.is_set()


@pytest.mark.anyio
async def test_server_cancels_in_flight_handlers_when_read_eof_drain_times_out():
"""A bounded read-EOF drain still cancels handlers that never finish."""
handler_started = anyio.Event()
handler_cancelled = anyio.Event()
server_run_returned = anyio.Event()

async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
handler_started.set()
try:
await anyio.sleep_forever()
finally:
handler_cancelled.set()
raise AssertionError # pragma: no cover

server = Server("test", on_call_tool=handle_call_tool)

to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

async def run_server():
await server.run(
server_read,
server_write,
server.create_initialization_options(),
drain_in_flight_on_read_eof=True,
read_eof_response_drain_timeout=0.01,
)
server_run_returned.set()

init_req = JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=InitializeRequestParams(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(),
client_info=Implementation(name="test", version="1.0"),
).model_dump(by_alias=True, mode="json", exclude_none=True),
)
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
call_req = JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
)

with anyio.fail_after(5):
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
tg.start_soon(run_server)

await to_server.send(SessionMessage(init_req))
await from_server.receive()
await to_server.send(SessionMessage(initialized))
await to_server.send(SessionMessage(call_req))

await handler_started.wait()
await to_server.aclose()

await server_run_returned.wait()

assert handler_cancelled.is_set()


@pytest.mark.anyio
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
"""When the transport closes while handlers are blocked on server→client
Expand Down
57 changes: 56 additions & 1 deletion tests/server/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import json
import sys
import threading
from collections.abc import AsyncIterator
Expand All @@ -7,11 +8,12 @@

import anyio
import pytest
from anyio.lowlevel import checkpoint

from mcp.server.mcpserver import MCPServer
from mcp.server.stdio import stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter


@pytest.mark.anyio
Expand Down Expand Up @@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke
assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={})


def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""stdin EOF must not drop responses for requests the server already accepted."""
server = MCPServer(name="DrainStdioServer")

@server.tool()
async def slow_echo(text: str) -> str:
await checkpoint()
return text

payload_lines = [
JSONRPCRequest(
jsonrpc="2.0",
id=0,
method="initialize",
params={
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "stdio-replay", "version": "0.1"},
},
).model_dump_json(by_alias=True, exclude_none=True),
JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}).model_dump_json(
by_alias=True, exclude_none=True
),
JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="tools/call",
params={"name": "slow_echo", "arguments": {"text": "first"}},
).model_dump_json(by_alias=True, exclude_none=True),
JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params={"name": "slow_echo", "arguments": {"text": "second"}},
).model_dump_json(by_alias=True, exclude_none=True),
]
stdin_bytes = io.BytesIO(("\n".join(payload_lines) + "\n").encode())
captured = _KeepOpenBytesIO()
monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8"))
monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8"))

_run_stdio_bounded(server)

output = captured.getvalue().decode()
responses = [json.loads(line) for line in output.splitlines() if line]

assert [response["id"] for response in responses] == [0, 1, 2]
assert responses[1]["result"]["content"][0]["text"] == "first"
assert responses[2]["result"]["content"][0]["text"] == "second"


def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None:
"""Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`.

Expand Down
Loading