From 65263112cdc6d9e4ea20d9636eb2200376a4bee2 Mon Sep 17 00:00:00 2001 From: Adam <62897873+YeIIcw@users.noreply.github.com> Date: Thu, 25 Sep 2025 16:07:53 -0400 Subject: [PATCH] Stream MCP responses instead of buffering --- .vscode/settings.json | 3 + libs/python/mcp-server/mcp_server/server.py | 160 ++++++++--- tests/test_mcp_server_streming.py | 303 ++++++++++++++++++++ 3 files changed, 429 insertions(+), 37 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 tests/test_mcp_server_streming.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..7e68766a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python-envs.pythonProjects": [] +} \ No newline at end of file diff --git a/libs/python/mcp-server/mcp_server/server.py b/libs/python/mcp-server/mcp_server/server.py index 73996d5e..66aaceba 100644 --- a/libs/python/mcp-server/mcp_server/server.py +++ b/libs/python/mcp-server/mcp_server/server.py @@ -1,5 +1,6 @@ import asyncio import base64 +import inspect import logging import os import sys @@ -40,11 +41,66 @@ except ImportError as e: # Global computer instance for reuse global_computer = None - def get_env_bool(key: str, default: bool = False) -> bool: """Get boolean value from environment variable.""" return os.getenv(key, str(default)).lower() in ("true", "1", "yes") +async def _maybe_call_ctx_method(ctx: Context, method_name: str, *args, **kwargs) -> None: + """Call a context helper if it exists, awaiting the result when necessary.""" + + method = getattr(ctx, method_name, None) + if not callable(method): + return + + result = method(*args, **kwargs) + if inspect.isawaitable(result): + await result + + +def _normalise_message_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]: + """Normalise message content to a list of structured parts.""" + + if isinstance(content, list): + return content + + if content is None: + return [] + + return [{"type": "output_text", "text": str(content)}] + + +def _extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str: + """Extract textual content for inclusion in the aggregated result string.""" + + if isinstance(content, str): + return content + + texts: List[str] = [] + for part in content or []: + if not isinstance(part, dict): + continue + if part.get("type") in {"output_text", "text"} and part.get("text"): + texts.append(str(part["text"])) + return "\n".join(texts) + + +def _serialise_tool_content(content: Any) -> str: + """Convert tool outputs into a string for aggregation.""" + + if isinstance(content, str): + return content + if isinstance(content, list): + texts: List[str] = [] + for part in content: + if isinstance(part, dict) and part.get("type") in {"output_text", "text"} and part.get("text"): + texts.append(str(part["text"])) + if texts: + return "\n".join(texts) + if content is None: + return "" + return str(content) + + def serve() -> FastMCP: """Create and configure the MCP server.""" @@ -110,7 +166,7 @@ def serve() -> FastMCP: messages = [{"role": "user", "content": task}] # Collect all results - full_result = "" + aggregated_messages: List[str] = [] async for result in agent.run(messages): logger.info(f"Agent processing step") ctx.info(f"Agent processing step") @@ -119,37 +175,61 @@ def serve() -> FastMCP: outputs = result.get("output", []) for output in outputs: output_type = output.get("type") + if output_type == "message": - logger.debug(f"Message: {output}") - content = output.get("content", []) - for content_part in content: - if content_part.get("text"): - full_result += f"Message: {content_part.get('text', '')}\n" - elif output_type == "tool_use": - logger.debug(f"Tool use: {output}") - tool_name = output.get("name", "") - full_result += f"Tool: {tool_name}\n" - elif output_type == "tool_result": - logger.debug(f"Tool result: {output}") - result_content = output.get("content", "") - if isinstance(result_content, list): - for item in result_content: - if item.get("type") == "text": - full_result += f"Result: {item.get('text', '')}\n" - else: - full_result += f"Result: {result_content}\n" + logger.debug("Streaming assistant message: %s", output) + content = _normalise_message_content(output.get("content")) + aggregated_text = _extract_text_from_content(content) + if aggregated_text: + aggregated_messages.append(aggregated_text) + await _maybe_call_ctx_method( + ctx, + "yield_message", + role=output.get("role", "assistant"), + content=content, + ) - # Add separator between steps - full_result += "\n" + "-" * 20 + "\n" + elif output_type in {"tool_use", "computer_call", "function_call"}: + logger.debug("Streaming tool call: %s", output) + call_id = output.get("id") or output.get("call_id") + tool_name = output.get("name") or output.get("action", {}).get("type") + tool_input = output.get("input") or output.get("arguments") or output.get("action") + if call_id: + await _maybe_call_ctx_method( + ctx, + "yield_tool_call", + name=tool_name, + call_id=call_id, + input=tool_input, + ) + + elif output_type in {"tool_result", "computer_call_output", "function_call_output"}: + logger.debug("Streaming tool output: %s", output) + call_id = output.get("call_id") or output.get("id") + content = output.get("content") or output.get("output") + aggregated_text = _serialise_tool_content(content) + if aggregated_text: + aggregated_messages.append(aggregated_text) + if call_id: + await _maybe_call_ctx_method( + ctx, + "yield_tool_output", + call_id=call_id, + output=content, + is_error=output.get("status") == "failed" or output.get("is_error", False), + ) + + logger.info("CUA task completed successfully") + ctx.info("CUA task completed successfully") + + screenshot_image = Image( + format="png", + data=await global_computer.interface.screenshot() + ) - logger.info(f"CUA task completed successfully") - ctx.info(f"CUA task completed successfully") return ( - full_result or "Task completed with no text output.", - Image( - format="png", - data=await global_computer.interface.screenshot() - ) + "\n".join(aggregated_messages).strip() or "Task completed with no text output.", + screenshot_image, ) except Exception as e: @@ -173,7 +253,7 @@ def serve() -> FastMCP: ) @server.tool() - async def run_multi_cua_tasks(ctx: Context, tasks: List[str]) -> List: + async def run_multi_cua_tasks(ctx: Context, tasks: List[str]) -> List[Tuple[str, Image]]: """ Run multiple CUA tasks in a MacOS VM in sequence and return the combined results. @@ -184,14 +264,20 @@ def serve() -> FastMCP: Returns: Combined results from all tasks """ - results = [] + total_tasks = len(tasks) + if total_tasks == 0: + ctx.report_progress(1.0) + return [] + + results: List[Tuple[str, Image]] = [] for i, task in enumerate(tasks): - logger.info(f"Running task {i+1}/{len(tasks)}: {task}") - ctx.info(f"Running task {i+1}/{len(tasks)}: {task}") - - ctx.report_progress(i / len(tasks)) - results.extend(await run_cua_task(ctx, task)) - ctx.report_progress((i + 1) / len(tasks)) + logger.info(f"Running task {i+1}/{total_tasks}: {task}") + ctx.info(f"Running task {i+1}/{total_tasks}: {task}") + + ctx.report_progress(i / total_tasks) + task_result = await run_cua_task(ctx, task) + results.append(task_result) + ctx.report_progress((i + 1) / total_tasks) return results diff --git a/tests/test_mcp_server_streming.py b/tests/test_mcp_server_streming.py new file mode 100644 index 00000000..ed84cbfe --- /dev/null +++ b/tests/test_mcp_server_streming.py @@ -0,0 +1,303 @@ +import asyncio +import importlib.util +import sys +import types +from pathlib import Path + +import pytest + + +def _install_stub_module(name: str, module: types.ModuleType, registry: dict[str, types.ModuleType | None]) -> None: + registry[name] = sys.modules.get(name) + sys.modules[name] = module + + +@pytest.fixture +def server_module(): + stubbed_modules: dict[str, types.ModuleType | None] = {} + + # Stub MCP Context primitives + mcp_module = types.ModuleType("mcp") + mcp_module.__path__ = [] # mark as package + + mcp_server_module = types.ModuleType("mcp.server") + mcp_server_module.__path__ = [] + + fastmcp_module = types.ModuleType("mcp.server.fastmcp") + + class _StubContext: + async def yield_message(self, *args, **kwargs): + return None + + async def yield_tool_call(self, *args, **kwargs): + return None + + async def yield_tool_output(self, *args, **kwargs): + return None + + def report_progress(self, *_args, **_kwargs): + return None + + def info(self, *_args, **_kwargs): + return None + + def error(self, *_args, **_kwargs): + return None + + class _StubImage: + def __init__(self, format: str, data: bytes): + self.format = format + self.data = data + + class _StubFastMCP: + def __init__(self, name: str): + self.name = name + self._tools: dict[str, types.FunctionType] = {} + + def tool(self, *args, **kwargs): + def decorator(func): + self._tools[func.__name__] = func + return func + + return decorator + + def run(self): + return None + + fastmcp_module.Context = _StubContext + fastmcp_module.FastMCP = _StubFastMCP + fastmcp_module.Image = _StubImage + + _install_stub_module("mcp", mcp_module, stubbed_modules) + _install_stub_module("mcp.server", mcp_server_module, stubbed_modules) + _install_stub_module("mcp.server.fastmcp", fastmcp_module, stubbed_modules) + + # Stub Computer module to avoid heavy dependencies + computer_module = types.ModuleType("computer") + + class _StubInterface: + async def screenshot(self) -> bytes: # pragma: no cover - default stub + return b"" + + class _StubComputer: + def __init__(self, *args, **kwargs): + self.interface = _StubInterface() + + async def run(self): # pragma: no cover - default stub + return None + + class _StubVMProviderType: + CLOUD = "cloud" + LOCAL = "local" + + computer_module.Computer = _StubComputer + computer_module.VMProviderType = _StubVMProviderType + + _install_stub_module("computer", computer_module, stubbed_modules) + + # Stub agent module so server can import ComputerAgent + agent_module = types.ModuleType("agent") + + class _StubComputerAgent: + def __init__(self, *args, **kwargs): + pass + + async def run(self, *_args, **_kwargs): # pragma: no cover - default stub + if False: # pragma: no cover + yield {} + return + + agent_module.ComputerAgent = _StubComputerAgent + + _install_stub_module("agent", agent_module, stubbed_modules) + + module_name = "mcp_server_server_under_test" + module_path = Path("libs/python/mcp-server/mcp_server/server.py").resolve() + spec = importlib.util.spec_from_file_location(module_name, module_path) + server_module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(server_module) + + server_instance = getattr(server_module, "server", None) + if server_instance is not None and hasattr(server_instance, "_tools"): + for name, func in server_instance._tools.items(): + setattr(server_module, name, func) + + try: + yield server_module + finally: + sys.modules.pop(module_name, None) + for name, original in stubbed_modules.items(): + if original is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = original + + +class FakeContext: + def __init__(self) -> None: + self.events: list[tuple] = [] + self.progress_updates: list[float] = [] + + def info(self, message: str) -> None: + self.events.append(("info", message)) + + def error(self, message: str) -> None: + self.events.append(("error", message)) + + def report_progress(self, value: float) -> None: + self.progress_updates.append(value) + + async def yield_message(self, *, role: str, content): + timestamp = asyncio.get_running_loop().time() + self.events.append(("message", role, content, timestamp)) + + async def yield_tool_call(self, *, name: str | None, call_id: str, input): + timestamp = asyncio.get_running_loop().time() + self.events.append(("tool_call", name, call_id, input, timestamp)) + + async def yield_tool_output(self, *, call_id: str, output, is_error: bool = False): + timestamp = asyncio.get_running_loop().time() + self.events.append(("tool_output", call_id, output, is_error, timestamp)) + + +def test_run_cua_task_streams_partial_results(server_module): + async def _run_test(): + class FakeAgent: + script = [] + + def __init__(self, *args, **kwargs): + pass + + async def run(self, messages): # type: ignore[override] + for factory, delay in type(self).script: + yield factory(messages) + if delay: + await asyncio.sleep(delay) + + FakeAgent.script = [ + ( + lambda _messages: { + "output": [ + { + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "First chunk"} + ], + } + ] + }, + 0.0, + ), + ( + lambda _messages: { + "output": [ + { + "type": "tool_use", + "id": "call_1", + "name": "computer", + "input": {"action": "click"}, + }, + { + "type": "computer_call_output", + "call_id": "call_1", + "output": [ + {"type": "text", "text": "Tool completed"} + ], + }, + ] + }, + 0.05, + ), + ] + + class FakeInterface: + def __init__(self) -> None: + self.calls = 0 + + async def screenshot(self) -> bytes: + self.calls += 1 + return b"final-image" + + fake_interface = FakeInterface() + server_module.global_computer = types.SimpleNamespace(interface=fake_interface) + server_module.ComputerAgent = FakeAgent # type: ignore[assignment] + + ctx = FakeContext() + task = asyncio.create_task(server_module.run_cua_task(ctx, "open settings")) + + await asyncio.sleep(0.01) + assert not task.done(), "Task should still be running to simulate long operation" + message_events = [event for event in ctx.events if event[0] == "message"] + assert message_events, "Expected message event before task completion" + + text_result, image = await task + + assert "First chunk" in text_result + assert "Tool completed" in text_result + assert image.data == b"final-image" + assert fake_interface.calls == 1 + + tool_call_events = [event for event in ctx.events if event[0] == "tool_call"] + tool_output_events = [event for event in ctx.events if event[0] == "tool_output"] + assert tool_call_events and tool_output_events + assert tool_call_events[0][2] == "call_1" + assert tool_output_events[0][1] == "call_1" + + asyncio.run(_run_test()) + + +def test_run_multi_cua_tasks_reports_progress(server_module, monkeypatch): + async def _run_test(): + class FakeAgent: + script = [] + + def __init__(self, *args, **kwargs): + pass + + async def run(self, messages): # type: ignore[override] + for factory, delay in type(self).script: + yield factory(messages) + if delay: + await asyncio.sleep(delay) + + FakeAgent.script = [ + ( + lambda messages: { + "output": [ + { + "type": "message", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": f"Result for {messages[0].get('content')}", + } + ], + } + ] + }, + 0.0, + ) + ] + + server_module.ComputerAgent = FakeAgent # type: ignore[assignment] + + class FakeInterface: + async def screenshot(self) -> bytes: + return b"progress-image" + + server_module.global_computer = types.SimpleNamespace(interface=FakeInterface()) + + ctx = FakeContext() + + results = await server_module.run_multi_cua_tasks(ctx, ["a", "b", "c"]) + + assert len(results) == 3 + assert results[0][0] == "Result for a" + assert ctx.progress_updates[0] == pytest.approx(0.0) + assert ctx.progress_updates[-1] == pytest.approx(1.0) + assert len(ctx.progress_updates) == 6 + + asyncio.run(_run_test()) \ No newline at end of file