mirror of
https://github.com/trycua/computer.git
synced 2025-12-31 02:19:58 -06:00
Stream MCP responses instead of buffering
This commit is contained in:
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"python-envs.pythonProjects": []
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -40,11 +41,66 @@ except ImportError as e:
|
||||
# Global computer instance for reuse
|
||||
global_computer = None
|
||||
|
||||
|
||||
def get_env_bool(key: str, default: bool = False) -> bool:
|
||||
"""Get boolean value from environment variable."""
|
||||
return os.getenv(key, str(default)).lower() in ("true", "1", "yes")
|
||||
|
||||
async def _maybe_call_ctx_method(ctx: Context, method_name: str, *args, **kwargs) -> None:
|
||||
"""Call a context helper if it exists, awaiting the result when necessary."""
|
||||
|
||||
method = getattr(ctx, method_name, None)
|
||||
if not callable(method):
|
||||
return
|
||||
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
|
||||
def _normalise_message_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||
"""Normalise message content to a list of structured parts."""
|
||||
|
||||
if isinstance(content, list):
|
||||
return content
|
||||
|
||||
if content is None:
|
||||
return []
|
||||
|
||||
return [{"type": "output_text", "text": str(content)}]
|
||||
|
||||
|
||||
def _extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
"""Extract textual content for inclusion in the aggregated result string."""
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
texts: List[str] = []
|
||||
for part in content or []:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
if part.get("type") in {"output_text", "text"} and part.get("text"):
|
||||
texts.append(str(part["text"]))
|
||||
return "\n".join(texts)
|
||||
|
||||
|
||||
def _serialise_tool_content(content: Any) -> str:
|
||||
"""Convert tool outputs into a string for aggregation."""
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
texts: List[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") in {"output_text", "text"} and part.get("text"):
|
||||
texts.append(str(part["text"]))
|
||||
if texts:
|
||||
return "\n".join(texts)
|
||||
if content is None:
|
||||
return ""
|
||||
return str(content)
|
||||
|
||||
|
||||
|
||||
def serve() -> FastMCP:
|
||||
"""Create and configure the MCP server."""
|
||||
@@ -110,7 +166,7 @@ def serve() -> FastMCP:
|
||||
messages = [{"role": "user", "content": task}]
|
||||
|
||||
# Collect all results
|
||||
full_result = ""
|
||||
aggregated_messages: List[str] = []
|
||||
async for result in agent.run(messages):
|
||||
logger.info(f"Agent processing step")
|
||||
ctx.info(f"Agent processing step")
|
||||
@@ -119,37 +175,61 @@ def serve() -> FastMCP:
|
||||
outputs = result.get("output", [])
|
||||
for output in outputs:
|
||||
output_type = output.get("type")
|
||||
|
||||
if output_type == "message":
|
||||
logger.debug(f"Message: {output}")
|
||||
content = output.get("content", [])
|
||||
for content_part in content:
|
||||
if content_part.get("text"):
|
||||
full_result += f"Message: {content_part.get('text', '')}\n"
|
||||
elif output_type == "tool_use":
|
||||
logger.debug(f"Tool use: {output}")
|
||||
tool_name = output.get("name", "")
|
||||
full_result += f"Tool: {tool_name}\n"
|
||||
elif output_type == "tool_result":
|
||||
logger.debug(f"Tool result: {output}")
|
||||
result_content = output.get("content", "")
|
||||
if isinstance(result_content, list):
|
||||
for item in result_content:
|
||||
if item.get("type") == "text":
|
||||
full_result += f"Result: {item.get('text', '')}\n"
|
||||
else:
|
||||
full_result += f"Result: {result_content}\n"
|
||||
logger.debug("Streaming assistant message: %s", output)
|
||||
content = _normalise_message_content(output.get("content"))
|
||||
aggregated_text = _extract_text_from_content(content)
|
||||
if aggregated_text:
|
||||
aggregated_messages.append(aggregated_text)
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_message",
|
||||
role=output.get("role", "assistant"),
|
||||
content=content,
|
||||
)
|
||||
|
||||
# Add separator between steps
|
||||
full_result += "\n" + "-" * 20 + "\n"
|
||||
elif output_type in {"tool_use", "computer_call", "function_call"}:
|
||||
logger.debug("Streaming tool call: %s", output)
|
||||
call_id = output.get("id") or output.get("call_id")
|
||||
tool_name = output.get("name") or output.get("action", {}).get("type")
|
||||
tool_input = output.get("input") or output.get("arguments") or output.get("action")
|
||||
if call_id:
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_tool_call",
|
||||
name=tool_name,
|
||||
call_id=call_id,
|
||||
input=tool_input,
|
||||
)
|
||||
|
||||
elif output_type in {"tool_result", "computer_call_output", "function_call_output"}:
|
||||
logger.debug("Streaming tool output: %s", output)
|
||||
call_id = output.get("call_id") or output.get("id")
|
||||
content = output.get("content") or output.get("output")
|
||||
aggregated_text = _serialise_tool_content(content)
|
||||
if aggregated_text:
|
||||
aggregated_messages.append(aggregated_text)
|
||||
if call_id:
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_tool_output",
|
||||
call_id=call_id,
|
||||
output=content,
|
||||
is_error=output.get("status") == "failed" or output.get("is_error", False),
|
||||
)
|
||||
|
||||
logger.info("CUA task completed successfully")
|
||||
ctx.info("CUA task completed successfully")
|
||||
|
||||
screenshot_image = Image(
|
||||
format="png",
|
||||
data=await global_computer.interface.screenshot()
|
||||
)
|
||||
|
||||
logger.info(f"CUA task completed successfully")
|
||||
ctx.info(f"CUA task completed successfully")
|
||||
return (
|
||||
full_result or "Task completed with no text output.",
|
||||
Image(
|
||||
format="png",
|
||||
data=await global_computer.interface.screenshot()
|
||||
)
|
||||
"\n".join(aggregated_messages).strip() or "Task completed with no text output.",
|
||||
screenshot_image,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -173,7 +253,7 @@ def serve() -> FastMCP:
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def run_multi_cua_tasks(ctx: Context, tasks: List[str]) -> List:
|
||||
async def run_multi_cua_tasks(ctx: Context, tasks: List[str]) -> List[Tuple[str, Image]]:
|
||||
"""
|
||||
Run multiple CUA tasks in a MacOS VM in sequence and return the combined results.
|
||||
|
||||
@@ -184,14 +264,20 @@ def serve() -> FastMCP:
|
||||
Returns:
|
||||
Combined results from all tasks
|
||||
"""
|
||||
results = []
|
||||
total_tasks = len(tasks)
|
||||
if total_tasks == 0:
|
||||
ctx.report_progress(1.0)
|
||||
return []
|
||||
|
||||
results: List[Tuple[str, Image]] = []
|
||||
for i, task in enumerate(tasks):
|
||||
logger.info(f"Running task {i+1}/{len(tasks)}: {task}")
|
||||
ctx.info(f"Running task {i+1}/{len(tasks)}: {task}")
|
||||
|
||||
ctx.report_progress(i / len(tasks))
|
||||
results.extend(await run_cua_task(ctx, task))
|
||||
ctx.report_progress((i + 1) / len(tasks))
|
||||
logger.info(f"Running task {i+1}/{total_tasks}: {task}")
|
||||
ctx.info(f"Running task {i+1}/{total_tasks}: {task}")
|
||||
|
||||
ctx.report_progress(i / total_tasks)
|
||||
task_result = await run_cua_task(ctx, task)
|
||||
results.append(task_result)
|
||||
ctx.report_progress((i + 1) / total_tasks)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
303
tests/test_mcp_server_streming.py
Normal file
303
tests/test_mcp_server_streming.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _install_stub_module(name: str, module: types.ModuleType, registry: dict[str, types.ModuleType | None]) -> None:
|
||||
registry[name] = sys.modules.get(name)
|
||||
sys.modules[name] = module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_module():
|
||||
stubbed_modules: dict[str, types.ModuleType | None] = {}
|
||||
|
||||
# Stub MCP Context primitives
|
||||
mcp_module = types.ModuleType("mcp")
|
||||
mcp_module.__path__ = [] # mark as package
|
||||
|
||||
mcp_server_module = types.ModuleType("mcp.server")
|
||||
mcp_server_module.__path__ = []
|
||||
|
||||
fastmcp_module = types.ModuleType("mcp.server.fastmcp")
|
||||
|
||||
class _StubContext:
|
||||
async def yield_message(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def yield_tool_call(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def yield_tool_output(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def report_progress(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
def info(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
def error(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
class _StubImage:
|
||||
def __init__(self, format: str, data: bytes):
|
||||
self.format = format
|
||||
self.data = data
|
||||
|
||||
class _StubFastMCP:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self._tools: dict[str, types.FunctionType] = {}
|
||||
|
||||
def tool(self, *args, **kwargs):
|
||||
def decorator(func):
|
||||
self._tools[func.__name__] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def run(self):
|
||||
return None
|
||||
|
||||
fastmcp_module.Context = _StubContext
|
||||
fastmcp_module.FastMCP = _StubFastMCP
|
||||
fastmcp_module.Image = _StubImage
|
||||
|
||||
_install_stub_module("mcp", mcp_module, stubbed_modules)
|
||||
_install_stub_module("mcp.server", mcp_server_module, stubbed_modules)
|
||||
_install_stub_module("mcp.server.fastmcp", fastmcp_module, stubbed_modules)
|
||||
|
||||
# Stub Computer module to avoid heavy dependencies
|
||||
computer_module = types.ModuleType("computer")
|
||||
|
||||
class _StubInterface:
|
||||
async def screenshot(self) -> bytes: # pragma: no cover - default stub
|
||||
return b""
|
||||
|
||||
class _StubComputer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.interface = _StubInterface()
|
||||
|
||||
async def run(self): # pragma: no cover - default stub
|
||||
return None
|
||||
|
||||
class _StubVMProviderType:
|
||||
CLOUD = "cloud"
|
||||
LOCAL = "local"
|
||||
|
||||
computer_module.Computer = _StubComputer
|
||||
computer_module.VMProviderType = _StubVMProviderType
|
||||
|
||||
_install_stub_module("computer", computer_module, stubbed_modules)
|
||||
|
||||
# Stub agent module so server can import ComputerAgent
|
||||
agent_module = types.ModuleType("agent")
|
||||
|
||||
class _StubComputerAgent:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def run(self, *_args, **_kwargs): # pragma: no cover - default stub
|
||||
if False: # pragma: no cover
|
||||
yield {}
|
||||
return
|
||||
|
||||
agent_module.ComputerAgent = _StubComputerAgent
|
||||
|
||||
_install_stub_module("agent", agent_module, stubbed_modules)
|
||||
|
||||
module_name = "mcp_server_server_under_test"
|
||||
module_path = Path("libs/python/mcp-server/mcp_server/server.py").resolve()
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
server_module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(server_module)
|
||||
|
||||
server_instance = getattr(server_module, "server", None)
|
||||
if server_instance is not None and hasattr(server_instance, "_tools"):
|
||||
for name, func in server_instance._tools.items():
|
||||
setattr(server_module, name, func)
|
||||
|
||||
try:
|
||||
yield server_module
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
for name, original in stubbed_modules.items():
|
||||
if original is None:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = original
|
||||
|
||||
|
||||
class FakeContext:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[tuple] = []
|
||||
self.progress_updates: list[float] = []
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
self.events.append(("info", message))
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
self.events.append(("error", message))
|
||||
|
||||
def report_progress(self, value: float) -> None:
|
||||
self.progress_updates.append(value)
|
||||
|
||||
async def yield_message(self, *, role: str, content):
|
||||
timestamp = asyncio.get_running_loop().time()
|
||||
self.events.append(("message", role, content, timestamp))
|
||||
|
||||
async def yield_tool_call(self, *, name: str | None, call_id: str, input):
|
||||
timestamp = asyncio.get_running_loop().time()
|
||||
self.events.append(("tool_call", name, call_id, input, timestamp))
|
||||
|
||||
async def yield_tool_output(self, *, call_id: str, output, is_error: bool = False):
|
||||
timestamp = asyncio.get_running_loop().time()
|
||||
self.events.append(("tool_output", call_id, output, is_error, timestamp))
|
||||
|
||||
|
||||
def test_run_cua_task_streams_partial_results(server_module):
|
||||
async def _run_test():
|
||||
class FakeAgent:
|
||||
script = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def run(self, messages): # type: ignore[override]
|
||||
for factory, delay in type(self).script:
|
||||
yield factory(messages)
|
||||
if delay:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
FakeAgent.script = [
|
||||
(
|
||||
lambda _messages: {
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "output_text", "text": "First chunk"}
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
0.0,
|
||||
),
|
||||
(
|
||||
lambda _messages: {
|
||||
"output": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_1",
|
||||
"name": "computer",
|
||||
"input": {"action": "click"},
|
||||
},
|
||||
{
|
||||
"type": "computer_call_output",
|
||||
"call_id": "call_1",
|
||||
"output": [
|
||||
{"type": "text", "text": "Tool completed"}
|
||||
],
|
||||
},
|
||||
]
|
||||
},
|
||||
0.05,
|
||||
),
|
||||
]
|
||||
|
||||
class FakeInterface:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def screenshot(self) -> bytes:
|
||||
self.calls += 1
|
||||
return b"final-image"
|
||||
|
||||
fake_interface = FakeInterface()
|
||||
server_module.global_computer = types.SimpleNamespace(interface=fake_interface)
|
||||
server_module.ComputerAgent = FakeAgent # type: ignore[assignment]
|
||||
|
||||
ctx = FakeContext()
|
||||
task = asyncio.create_task(server_module.run_cua_task(ctx, "open settings"))
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
assert not task.done(), "Task should still be running to simulate long operation"
|
||||
message_events = [event for event in ctx.events if event[0] == "message"]
|
||||
assert message_events, "Expected message event before task completion"
|
||||
|
||||
text_result, image = await task
|
||||
|
||||
assert "First chunk" in text_result
|
||||
assert "Tool completed" in text_result
|
||||
assert image.data == b"final-image"
|
||||
assert fake_interface.calls == 1
|
||||
|
||||
tool_call_events = [event for event in ctx.events if event[0] == "tool_call"]
|
||||
tool_output_events = [event for event in ctx.events if event[0] == "tool_output"]
|
||||
assert tool_call_events and tool_output_events
|
||||
assert tool_call_events[0][2] == "call_1"
|
||||
assert tool_output_events[0][1] == "call_1"
|
||||
|
||||
asyncio.run(_run_test())
|
||||
|
||||
|
||||
def test_run_multi_cua_tasks_reports_progress(server_module, monkeypatch):
|
||||
async def _run_test():
|
||||
class FakeAgent:
|
||||
script = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def run(self, messages): # type: ignore[override]
|
||||
for factory, delay in type(self).script:
|
||||
yield factory(messages)
|
||||
if delay:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
FakeAgent.script = [
|
||||
(
|
||||
lambda messages: {
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": f"Result for {messages[0].get('content')}",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
0.0,
|
||||
)
|
||||
]
|
||||
|
||||
server_module.ComputerAgent = FakeAgent # type: ignore[assignment]
|
||||
|
||||
class FakeInterface:
|
||||
async def screenshot(self) -> bytes:
|
||||
return b"progress-image"
|
||||
|
||||
server_module.global_computer = types.SimpleNamespace(interface=FakeInterface())
|
||||
|
||||
ctx = FakeContext()
|
||||
|
||||
results = await server_module.run_multi_cua_tasks(ctx, ["a", "b", "c"])
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0][0] == "Result for a"
|
||||
assert ctx.progress_updates[0] == pytest.approx(0.0)
|
||||
assert ctx.progress_updates[-1] == pytest.approx(1.0)
|
||||
assert len(ctx.progress_updates) == 6
|
||||
|
||||
asyncio.run(_run_test())
|
||||
Reference in New Issue
Block a user