Stream MCP responses instead of buffering

Adam
2025-09-25 16:07:53 -04:00
parent bdb8e56e10
commit 65263112cd
3 changed files with 429 additions and 37 deletions


@@ -1,5 +1,6 @@
 import asyncio
 import base64
+import inspect
 import logging
 import os
 import sys
@@ -40,11 +41,66 @@ except ImportError as e:
 # Global computer instance for reuse
 global_computer = None

 def get_env_bool(key: str, default: bool = False) -> bool:
     """Get boolean value from environment variable."""
     return os.getenv(key, str(default)).lower() in ("true", "1", "yes")

+async def _maybe_call_ctx_method(ctx: Context, method_name: str, *args, **kwargs) -> None:
+    """Call a context helper if it exists, awaiting the result when necessary."""
+    method = getattr(ctx, method_name, None)
+    if not callable(method):
+        return
+    result = method(*args, **kwargs)
+    if inspect.isawaitable(result):
+        await result
+
+
+def _normalise_message_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
+    """Normalise message content to a list of structured parts."""
+    if isinstance(content, list):
+        return content
+    if content is None:
+        return []
+    return [{"type": "output_text", "text": str(content)}]
+
+
+def _extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
+    """Extract textual content for inclusion in the aggregated result string."""
+    if isinstance(content, str):
+        return content
+    texts: List[str] = []
+    for part in content or []:
+        if not isinstance(part, dict):
+            continue
+        if part.get("type") in {"output_text", "text"} and part.get("text"):
+            texts.append(str(part["text"]))
+    return "\n".join(texts)
+
+
+def _serialise_tool_content(content: Any) -> str:
+    """Convert tool outputs into a string for aggregation."""
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        texts: List[str] = []
+        for part in content:
+            if isinstance(part, dict) and part.get("type") in {"output_text", "text"} and part.get("text"):
+                texts.append(str(part["text"]))
+        if texts:
+            return "\n".join(texts)
+    if content is None:
+        return ""
+    return str(content)
+
+
 def serve() -> FastMCP:
     """Create and configure the MCP server."""
@@ -110,7 +166,7 @@ def serve() -> FastMCP:
             messages = [{"role": "user", "content": task}]

             # Collect all results
-            full_result = ""
+            aggregated_messages: List[str] = []
             async for result in agent.run(messages):
                 logger.info(f"Agent processing step")
                 ctx.info(f"Agent processing step")
@@ -119,37 +175,61 @@ def serve() -> FastMCP:
                 outputs = result.get("output", [])
                 for output in outputs:
                     output_type = output.get("type")

                     if output_type == "message":
-                        logger.debug(f"Message: {output}")
-                        content = output.get("content", [])
-                        for content_part in content:
-                            if content_part.get("text"):
-                                full_result += f"Message: {content_part.get('text', '')}\n"
-                    elif output_type == "tool_use":
-                        logger.debug(f"Tool use: {output}")
-                        tool_name = output.get("name", "")
-                        full_result += f"Tool: {tool_name}\n"
-                    elif output_type == "tool_result":
-                        logger.debug(f"Tool result: {output}")
-                        result_content = output.get("content", "")
-                        if isinstance(result_content, list):
-                            for item in result_content:
-                                if item.get("type") == "text":
-                                    full_result += f"Result: {item.get('text', '')}\n"
-                        else:
-                            full_result += f"Result: {result_content}\n"
+                        logger.debug("Streaming assistant message: %s", output)
+                        content = _normalise_message_content(output.get("content"))
+                        aggregated_text = _extract_text_from_content(content)
+                        if aggregated_text:
+                            aggregated_messages.append(aggregated_text)
+                        await _maybe_call_ctx_method(
+                            ctx,
+                            "yield_message",
+                            role=output.get("role", "assistant"),
+                            content=content,
+                        )
-                # Add separator between steps
-                full_result += "\n" + "-" * 20 + "\n"
+                    elif output_type in {"tool_use", "computer_call", "function_call"}:
+                        logger.debug("Streaming tool call: %s", output)
+                        call_id = output.get("id") or output.get("call_id")
+                        tool_name = output.get("name") or output.get("action", {}).get("type")
+                        tool_input = output.get("input") or output.get("arguments") or output.get("action")
+                        if call_id:
+                            await _maybe_call_ctx_method(
+                                ctx,
+                                "yield_tool_call",
+                                name=tool_name,
+                                call_id=call_id,
+                                input=tool_input,
+                            )
+                    elif output_type in {"tool_result", "computer_call_output", "function_call_output"}:
+                        logger.debug("Streaming tool output: %s", output)
+                        call_id = output.get("call_id") or output.get("id")
+                        content = output.get("content") or output.get("output")
+                        aggregated_text = _serialise_tool_content(content)
+                        if aggregated_text:
+                            aggregated_messages.append(aggregated_text)
+                        if call_id:
+                            await _maybe_call_ctx_method(
+                                ctx,
+                                "yield_tool_output",
+                                call_id=call_id,
+                                output=content,
+                                is_error=output.get("status") == "failed" or output.get("is_error", False),
+                            )

+            logger.info("CUA task completed successfully")
+            ctx.info("CUA task completed successfully")
+            screenshot_image = Image(
+                format="png",
+                data=await global_computer.interface.screenshot()
+            )
-            logger.info(f"CUA task completed successfully")
-            ctx.info(f"CUA task completed successfully")
             return (
-                full_result or "Task completed with no text output.",
-                Image(
-                    format="png",
-                    data=await global_computer.interface.screenshot()
-                )
+                "\n".join(aggregated_messages).strip() or "Task completed with no text output.",
+                screenshot_image,
             )
         except Exception as e:
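
Note: the three branches route agent output items shaped roughly like the samples below. These dicts are illustrative only; the field names are inferred from the .get() calls in the diff, not from a published schema.

    # Illustrative agent output items, one per branch.
    message_item = {
        "type": "message",
        "role": "assistant",
        "content": [{"type": "output_text", "text": "Opening Safari..."}],
    }
    tool_call_item = {
        "type": "computer_call",
        "id": "call_1",
        "action": {"type": "click", "x": 100, "y": 200},
    }
    tool_output_item = {
        "type": "computer_call_output",
        "call_id": "call_1",
        "output": [{"type": "text", "text": "clicked at (100, 200)"}],
        "status": "completed",
    }

    # The aggregation path keeps only the text:
    #   _extract_text_from_content(message_item["content"])  -> "Opening Safari..."
    #   _serialise_tool_content(tool_output_item["output"])  -> "clicked at (100, 200)"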
@@ -173,7 +253,7 @@ def serve() -> FastMCP:
         )

     @server.tool()
-    async def run_multi_cua_tasks(ctx: Context, tasks: List[str]) -> List:
+    async def run_multi_cua_tasks(ctx: Context, tasks: List[str]) -> List[Tuple[str, Image]]:
         """
         Run multiple CUA tasks in a MacOS VM in sequence and return the combined results.
@@ -184,14 +264,20 @@ def serve() -> FastMCP:
         Returns:
             Combined results from all tasks
         """
-        results = []
+        total_tasks = len(tasks)
+        if total_tasks == 0:
+            ctx.report_progress(1.0)
+            return []
+
+        results: List[Tuple[str, Image]] = []
         for i, task in enumerate(tasks):
-            logger.info(f"Running task {i+1}/{len(tasks)}: {task}")
-            ctx.info(f"Running task {i+1}/{len(tasks)}: {task}")
-            ctx.report_progress(i / len(tasks))
-            results.extend(await run_cua_task(ctx, task))
-            ctx.report_progress((i + 1) / len(tasks))
+            logger.info(f"Running task {i+1}/{total_tasks}: {task}")
+            ctx.info(f"Running task {i+1}/{total_tasks}: {task}")
+            ctx.report_progress(i / total_tasks)
+            task_result = await run_cua_task(ctx, task)
+            results.append(task_result)
+            ctx.report_progress((i + 1) / total_tasks)
         return results
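
Note: progress now advances in 1/total_tasks steps around each task (0.0 before the first, 1.0 after the last), and an empty task list reports 1.0 immediately instead of dividing by zero. Switching from extend to append also fixes a flattening bug: each list entry is now the whole (text, screenshot) tuple from one task rather than its unpacked elements. A hedged caller-side sketch, assuming a configured Context:

    # Hypothetical usage (inside an async function); ctx setup depends on FastMCP.
    results = await run_multi_cua_tasks(ctx, ["open Safari", "write a note"])
    for text, screenshot in results:  # one (str, Image) tuple per task
        print(text)  # aggregated message and tool text for that task
        # `screenshot` is the FastMCP Image built from the final VM screenshot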