mirror of
https://github.com/trycua/computer.git
synced 2025-12-31 10:29:59 -06:00
Stream MCP responses instead of buffering
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -40,11 +41,66 @@ except ImportError as e:
|
||||
# Global computer instance for reuse
|
||||
global_computer = None
|
||||
|
||||
|
||||
def get_env_bool(key: str, default: bool = False) -> bool:
|
||||
"""Get boolean value from environment variable."""
|
||||
return os.getenv(key, str(default)).lower() in ("true", "1", "yes")
|
||||
|
||||
async def _maybe_call_ctx_method(ctx: Context, method_name: str, *args, **kwargs) -> None:
|
||||
"""Call a context helper if it exists, awaiting the result when necessary."""
|
||||
|
||||
method = getattr(ctx, method_name, None)
|
||||
if not callable(method):
|
||||
return
|
||||
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
|
||||
def _normalise_message_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||
"""Normalise message content to a list of structured parts."""
|
||||
|
||||
if isinstance(content, list):
|
||||
return content
|
||||
|
||||
if content is None:
|
||||
return []
|
||||
|
||||
return [{"type": "output_text", "text": str(content)}]
|
||||
|
||||
|
||||
def _extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
"""Extract textual content for inclusion in the aggregated result string."""
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
texts: List[str] = []
|
||||
for part in content or []:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
if part.get("type") in {"output_text", "text"} and part.get("text"):
|
||||
texts.append(str(part["text"]))
|
||||
return "\n".join(texts)
|
||||
|
||||
|
||||
def _serialise_tool_content(content: Any) -> str:
|
||||
"""Convert tool outputs into a string for aggregation."""
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
texts: List[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") in {"output_text", "text"} and part.get("text"):
|
||||
texts.append(str(part["text"]))
|
||||
if texts:
|
||||
return "\n".join(texts)
|
||||
if content is None:
|
||||
return ""
|
||||
return str(content)
|
||||
|
||||
|
||||
|
||||
def serve() -> FastMCP:
|
||||
"""Create and configure the MCP server."""
|
||||
@@ -110,7 +166,7 @@ def serve() -> FastMCP:
|
||||
messages = [{"role": "user", "content": task}]
|
||||
|
||||
# Collect all results
|
||||
full_result = ""
|
||||
aggregated_messages: List[str] = []
|
||||
async for result in agent.run(messages):
|
||||
logger.info(f"Agent processing step")
|
||||
ctx.info(f"Agent processing step")
|
||||
@@ -119,37 +175,61 @@ def serve() -> FastMCP:
|
||||
outputs = result.get("output", [])
|
||||
for output in outputs:
|
||||
output_type = output.get("type")
|
||||
|
||||
if output_type == "message":
|
||||
logger.debug(f"Message: {output}")
|
||||
content = output.get("content", [])
|
||||
for content_part in content:
|
||||
if content_part.get("text"):
|
||||
full_result += f"Message: {content_part.get('text', '')}\n"
|
||||
elif output_type == "tool_use":
|
||||
logger.debug(f"Tool use: {output}")
|
||||
tool_name = output.get("name", "")
|
||||
full_result += f"Tool: {tool_name}\n"
|
||||
elif output_type == "tool_result":
|
||||
logger.debug(f"Tool result: {output}")
|
||||
result_content = output.get("content", "")
|
||||
if isinstance(result_content, list):
|
||||
for item in result_content:
|
||||
if item.get("type") == "text":
|
||||
full_result += f"Result: {item.get('text', '')}\n"
|
||||
else:
|
||||
full_result += f"Result: {result_content}\n"
|
||||
logger.debug("Streaming assistant message: %s", output)
|
||||
content = _normalise_message_content(output.get("content"))
|
||||
aggregated_text = _extract_text_from_content(content)
|
||||
if aggregated_text:
|
||||
aggregated_messages.append(aggregated_text)
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_message",
|
||||
role=output.get("role", "assistant"),
|
||||
content=content,
|
||||
)
|
||||
|
||||
# Add separator between steps
|
||||
full_result += "\n" + "-" * 20 + "\n"
|
||||
elif output_type in {"tool_use", "computer_call", "function_call"}:
|
||||
logger.debug("Streaming tool call: %s", output)
|
||||
call_id = output.get("id") or output.get("call_id")
|
||||
tool_name = output.get("name") or output.get("action", {}).get("type")
|
||||
tool_input = output.get("input") or output.get("arguments") or output.get("action")
|
||||
if call_id:
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_tool_call",
|
||||
name=tool_name,
|
||||
call_id=call_id,
|
||||
input=tool_input,
|
||||
)
|
||||
|
||||
elif output_type in {"tool_result", "computer_call_output", "function_call_output"}:
|
||||
logger.debug("Streaming tool output: %s", output)
|
||||
call_id = output.get("call_id") or output.get("id")
|
||||
content = output.get("content") or output.get("output")
|
||||
aggregated_text = _serialise_tool_content(content)
|
||||
if aggregated_text:
|
||||
aggregated_messages.append(aggregated_text)
|
||||
if call_id:
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_tool_output",
|
||||
call_id=call_id,
|
||||
output=content,
|
||||
is_error=output.get("status") == "failed" or output.get("is_error", False),
|
||||
)
|
||||
|
||||
logger.info("CUA task completed successfully")
|
||||
ctx.info("CUA task completed successfully")
|
||||
|
||||
screenshot_image = Image(
|
||||
format="png",
|
||||
data=await global_computer.interface.screenshot()
|
||||
)
|
||||
|
||||
logger.info(f"CUA task completed successfully")
|
||||
ctx.info(f"CUA task completed successfully")
|
||||
return (
|
||||
full_result or "Task completed with no text output.",
|
||||
Image(
|
||||
format="png",
|
||||
data=await global_computer.interface.screenshot()
|
||||
)
|
||||
"\n".join(aggregated_messages).strip() or "Task completed with no text output.",
|
||||
screenshot_image,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -173,7 +253,7 @@ def serve() -> FastMCP:
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def run_multi_cua_tasks(ctx: Context, tasks: List[str]) -> List:
|
||||
async def run_multi_cua_tasks(ctx: Context, tasks: List[str]) -> List[Tuple[str, Image]]:
|
||||
"""
|
||||
Run multiple CUA tasks in a MacOS VM in sequence and return the combined results.
|
||||
|
||||
@@ -184,14 +264,20 @@ def serve() -> FastMCP:
|
||||
Returns:
|
||||
Combined results from all tasks
|
||||
"""
|
||||
results = []
|
||||
total_tasks = len(tasks)
|
||||
if total_tasks == 0:
|
||||
ctx.report_progress(1.0)
|
||||
return []
|
||||
|
||||
results: List[Tuple[str, Image]] = []
|
||||
for i, task in enumerate(tasks):
|
||||
logger.info(f"Running task {i+1}/{len(tasks)}: {task}")
|
||||
ctx.info(f"Running task {i+1}/{len(tasks)}: {task}")
|
||||
|
||||
ctx.report_progress(i / len(tasks))
|
||||
results.extend(await run_cua_task(ctx, task))
|
||||
ctx.report_progress((i + 1) / len(tasks))
|
||||
logger.info(f"Running task {i+1}/{total_tasks}: {task}")
|
||||
ctx.info(f"Running task {i+1}/{total_tasks}: {task}")
|
||||
|
||||
ctx.report_progress(i / total_tasks)
|
||||
task_result = await run_cua_task(ctx, task)
|
||||
results.append(task_result)
|
||||
ctx.report_progress((i + 1) / total_tasks)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user