diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5c1459dff..c161a2866 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -17,9 +17,10 @@ async def run_server(): ``` """ +import os import sys from contextlib import asynccontextmanager -from io import TextIOWrapper +from io import TextIOWrapper, UnsupportedOperation import anyio import anyio.lowlevel @@ -29,6 +30,33 @@ async def run_server(): from mcp.shared.message import SessionMessage +def _wrap_stdin() -> tuple[anyio.AsyncFile[str], bool]: + """Wrap stdin as UTF-8 text without closing process stdio on exit.""" + try: + stdin_fd = os.dup(sys.stdin.fileno()) + except (AttributeError, OSError, UnsupportedOperation): + # Some tests and embedders replace sys.stdin with fileno-less in-memory + # streams. Reusing the caller-provided wrapper avoids closing it when the + # transport exits. + return anyio.wrap_file(sys.stdin), False + + stdin_buffer = os.fdopen(stdin_fd, "rb", closefd=True) + return anyio.wrap_file(TextIOWrapper(stdin_buffer, encoding="utf-8", errors="replace")), True + + +def _wrap_stdout() -> tuple[anyio.AsyncFile[str], bool]: + """Wrap stdout as UTF-8 text without closing process stdio on exit.""" + try: + stdout_fd = os.dup(sys.stdout.fileno()) + except (AttributeError, OSError, UnsupportedOperation): + # Match the fileno-less stdin fallback for in-memory test streams and + # embedders that provide file-like stdout objects. + return anyio.wrap_file(sys.stdout), False + + stdout_buffer = os.fdopen(stdout_fd, "wb", closefd=True) + return anyio.wrap_file(TextIOWrapper(stdout_buffer, encoding="utf-8")), True + + @asynccontextmanager async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): """Server transport for stdio: this communicates with an MCP client by reading @@ -38,10 +66,12 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. # standard process handles. Encoding of stdin/stdout as text streams on # python is platform-dependent (Windows is particularly problematic), so we # re-wrap the underlying binary stream to ensure UTF-8. + close_stdin = False + close_stdout = False if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) + stdin, close_stdin = _wrap_stdin() if not stdout: - stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) + stdout, close_stdout = _wrap_stdout() read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) write_stream, write_stream_reader = create_context_streams[SessionMessage](0) @@ -71,7 +101,13 @@ async def stdout_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg: - tg.start_soon(stdin_reader) - tg.start_soon(stdout_writer) - yield read_stream, write_stream + try: + async with anyio.create_task_group() as tg: + tg.start_soon(stdin_reader) + tg.start_soon(stdout_writer) + yield read_stream, write_stream + finally: + if close_stdin: + await stdin.aclose() + if close_stdout: + await stdout.aclose() diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 054a157b3..84eb1fb50 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,7 +1,8 @@ import io import sys +import tempfile import threading -from collections.abc import AsyncIterator +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from io import TextIOWrapper @@ -67,33 +68,79 @@ async def test_stdio_server_round_trips_messages_over_injected_streams() -> None @pytest.mark.anyio -async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> None: - """Non-UTF-8 stdin bytes surface as an in-stream exception without killing the stream. +async def test_stdio_server_supports_fileno_less_standard_streams(monkeypatch: pytest.MonkeyPatch) -> None: + """The default path supports in-memory stdio replacements without fileno().""" + request = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + raw_stdin = io.BytesIO(request.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdout = io.BytesIO() - Invalid bytes are replaced with U+FFFD, fail JSON parsing, and arrive as an in-stream - exception; subsequent valid messages are still processed. - """ - # \xff\xfe are invalid UTF-8 start bytes. - valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") - - # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that - # stdio_server()'s default path wraps it with errors='replace'. - monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) - monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) + test_stdin = TextIOWrapper(raw_stdin, encoding="utf-8") + test_stdout = TextIOWrapper(raw_stdout, encoding="utf-8") + monkeypatch.setattr(sys, "stdin", test_stdin) + monkeypatch.setattr(sys, "stdout", test_stdout) with anyio.fail_after(5): async with stdio_server() as (read_stream, write_stream): await write_stream.aclose() async with read_stream: # pragma: no branch - # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream - first = await read_stream.receive() - assert isinstance(first, Exception) + message = await read_stream.receive() + assert isinstance(message, SessionMessage) + assert message.message == request + + assert not test_stdin.closed + assert not test_stdout.closed + test_stdin.seek(0) + assert test_stdin.readline() == request.model_dump_json(by_alias=True, exclude_none=True) + "\n" + test_stdout.write("stdio still open") + test_stdout.flush() + - # Second line: valid message still comes through - second = await read_stream.receive() - assert isinstance(second, SessionMessage) - assert second.message == valid +@pytest.mark.anyio +async def test_stdio_server_invalid_utf8() -> None: + """Non-UTF-8 stdin bytes surface as an in-stream exception without killing the stream. + + Invalid bytes are replaced with U+FFFD, fail JSON parsing, and arrive as an in-stream + exception; subsequent valid messages are still processed. + """ + # \xff\xfe are invalid UTF-8 start bytes. + valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + raw_stdin = tempfile.TemporaryFile() + raw_stdin.write(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdin.seek(0) + raw_stdout = tempfile.TemporaryFile() + + # Replace sys.stdin/stdout with wrappers backed by real file descriptors so + # stdio_server()'s default path can duplicate them without closing the + # original process-level streams. + original_stdin = sys.stdin + original_stdout = sys.stdout + test_stdin = TextIOWrapper(raw_stdin, encoding="utf-8") + test_stdout = TextIOWrapper(raw_stdout, encoding="utf-8") + sys.stdin = test_stdin + sys.stdout = test_stdout + + try: + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream + first = await read_stream.receive() + assert isinstance(first, Exception) + + # Second line: valid message still comes through + second = await read_stream.receive() + assert isinstance(second, SessionMessage) + assert second.message == valid + + assert not sys.stdin.closed + assert not sys.stdout.closed + sys.stdout.write("stdio still open") + finally: + sys.stdin = original_stdin + sys.stdout = original_stdout + test_stdin.close() + test_stdout.close() class _KeepOpenBytesIO(io.BytesIO): @@ -151,7 +198,7 @@ def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatc events: list[str] = [] @asynccontextmanager - async def lifespan(server: MCPServer) -> AsyncIterator[None]: + async def lifespan(server: MCPServer) -> AsyncGenerator[None, None]: events.append("setup") try: yield