mirror of https://github.com/trycua/computer.git
synced 2026-01-01 02:50:15 -06:00

Merge branch 'main' into feat/cua-bench-submodules
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.37
current_version = 0.5.1
commit = True
tag = True
tag_name = agent-v{new_version}

@@ -51,7 +51,7 @@ async def main():

    # Create agent
    agent = ComputerAgent(
        model="anthropic/claude-3-5-sonnet-20241022",
        model="anthropic/claude-sonnet-4-5-20250929",
        tools=[computer],
        only_n_most_recent_images=3,
        trajectory_dir="trajectories",
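
For reference, a minimal sketch of how the updated snippet above fits into a complete program. The Computer() setup and the agent.run() iteration are assumptions based on the Agent SDK docs linked below, not part of this diff:

    import asyncio
    from agent import ComputerAgent
    from computer import Computer

    async def main():
        # Assumed Computer() configuration; see the Computer SDK docs for real options.
        async with Computer() as computer:
            agent = ComputerAgent(
                model="anthropic/claude-sonnet-4-5-20250929",
                tools=[computer],
                only_n_most_recent_images=3,
                trajectory_dir="trajectories",
            )
            # Assumption: agent.run() yields results as an async iterator.
            async for result in agent.run("Open a browser and search for 'trycua'"):
                print(result)

    if __name__ == "__main__":
        asyncio.run(main())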
@@ -78,7 +78,7 @@ if __name__ == "__main__":
- [Chat History](https://cua.ai/docs/agent-sdk/chat-history)
- [Callbacks](https://cua.ai/docs/agent-sdk/callbacks)
- [Custom Tools](https://cua.ai/docs/agent-sdk/custom-tools)
- [Custom Computer Handlers](https://cua.ai/docs/agent-sdk/custom-computer-handlers)
- [Custom Computer Handlers](https://cua.ai/docs/computer-sdk/custom-computer-handlers)
- [Prompt Caching](https://cua.ai/docs/agent-sdk/prompt-caching)
- [Usage Tracking](https://cua.ai/docs/agent-sdk/usage-tracking)
- [Benchmarks](https://cua.ai/docs/agent-sdk/benchmarks)

@@ -2,6 +2,7 @@
Adapters package for agent - Custom LLM adapters for LiteLLM
"""

from .cua_adapter import CUAAdapter
from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
from .human_adapter import HumanAdapter
from .mlxvlm_adapter import MLXVLMAdapter

@@ -10,4 +11,5 @@ __all__ = [
    "HuggingFaceLocalAdapter",
    "HumanAdapter",
    "MLXVLMAdapter",
    "CUAAdapter",
]
145 libs/python/agent/agent/adapters/cua_adapter.py Normal file
@@ -0,0 +1,145 @@
import os
from typing import Any, AsyncIterator, Iterator

from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm.types.utils import GenericStreamingChunk, ModelResponse


class CUAAdapter(CustomLLM):
    def __init__(self, base_url: str | None = None, api_key: str | None = None, **_: Any):
        super().__init__()
        self.base_url = base_url or os.environ.get("CUA_BASE_URL") or "https://inference.cua.ai/v1"
        self.api_key = (
            api_key or os.environ.get("CUA_INFERENCE_API_KEY") or os.environ.get("CUA_API_KEY")
        )

    def _normalize_model(self, model: str) -> str:
        # Accept either "cua/<model>" or raw "<model>"
        return model.split("/", 1)[1] if model and model.startswith("cua/") else model

    def completion(self, *args, **kwargs) -> ModelResponse:
        model = kwargs.get("model", "")
        api_base = kwargs.get("api_base") or self.base_url
        if "anthropic/" in model:
            model = f"anthropic/{self._normalize_model(model)}"
            api_base = api_base.removesuffix("/v1")
        else:
            model = f"openai/{self._normalize_model(model)}"

        params = {
            "model": model,
            "messages": kwargs.get("messages", []),
            "api_base": api_base,
            "api_key": kwargs.get("api_key") or self.api_key,
            "stream": False,
        }

        if "optional_params" in kwargs:
            params.update(kwargs["optional_params"])
            del kwargs["optional_params"]

        if "headers" in kwargs:
            params["headers"] = kwargs["headers"]
            del kwargs["headers"]

        # Print dropped parameters
        original_keys = set(kwargs.keys())
        used_keys = set(params.keys())  # Only these are extracted from kwargs
        ignored_keys = {
            "litellm_params",
            "client",
            "print_verbose",
            "acompletion",
            "timeout",
            "logging_obj",
            "encoding",
            "custom_prompt_dict",
            "model_response",
            "logger_fn",
        }
        dropped_keys = original_keys - used_keys - ignored_keys
        if dropped_keys:
            dropped_keyvals = {k: kwargs[k] for k in dropped_keys}
            # print(f"CUAAdapter.completion: Dropped parameters: {dropped_keyvals}")

        return completion(**params)  # type: ignore

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        model = kwargs.get("model", "")
        api_base = kwargs.get("api_base") or self.base_url
        if "anthropic/" in model:
            model = f"anthropic/{self._normalize_model(model)}"
            api_base = api_base.removesuffix("/v1")
        else:
            model = f"openai/{self._normalize_model(model)}"

        params = {
            "model": model,
            "messages": kwargs.get("messages", []),
            "api_base": api_base,
            "api_key": kwargs.get("api_key") or self.api_key,
            "stream": False,
        }

        if "optional_params" in kwargs:
            params.update(kwargs["optional_params"])
            del kwargs["optional_params"]

        if "headers" in kwargs:
            params["headers"] = kwargs["headers"]
            del kwargs["headers"]

        # Print dropped parameters
        original_keys = set(kwargs.keys())
        used_keys = set(params.keys())  # Only these are extracted from kwargs
        ignored_keys = {
            "litellm_params",
            "client",
            "print_verbose",
            "acompletion",
            "timeout",
            "logging_obj",
            "encoding",
            "custom_prompt_dict",
            "model_response",
            "logger_fn",
        }
        dropped_keys = original_keys - used_keys - ignored_keys
        if dropped_keys:
            dropped_keyvals = {k: kwargs[k] for k in dropped_keys}
            # print(f"CUAAdapter.acompletion: Dropped parameters: {dropped_keyvals}")

        response = await acompletion(**params)  # type: ignore

        return response

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        params = dict(kwargs)
        inner_model = self._normalize_model(params.get("model", ""))
        params.update(
            {
                "model": f"openai/{inner_model}",
                "api_base": self.base_url,
                "api_key": self.api_key,
                "stream": True,
            }
        )
        # Yield chunks directly from LiteLLM's streaming generator
        for chunk in completion(**params):  # type: ignore
            yield chunk  # type: ignore

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        params = dict(kwargs)
        inner_model = self._normalize_model(params.get("model", ""))
        params.update(
            {
                "model": f"openai/{inner_model}",
                "api_base": self.base_url,
                "api_key": self.api_key,
                "stream": True,
            }
        )
        stream = await acompletion(**params)  # type: ignore
        async for chunk in stream:  # type: ignore
            yield chunk  # type: ignore
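
Taken together with the provider-map registration in the agent.py hunk below, the adapter can be exercised directly; a minimal sketch, assuming CUA_API_KEY is set and using an illustrative model name:

    import litellm
    from agent.adapters import CUAAdapter

    # Register the handler under the "cua" provider prefix, mirroring agent.py below.
    litellm.custom_provider_map = [
        {"provider": "cua", "custom_handler": CUAAdapter()},
    ]

    # _normalize_model strips the "cua/" prefix; the request is then proxied
    # through https://inference.cua.ai/v1 (or CUA_BASE_URL if set).
    response = litellm.completion(
        model="cua/claude-sonnet-4-5-20250929",  # illustrative model name
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)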
@@ -23,11 +23,7 @@ import litellm
import litellm.utils
from litellm.responses.utils import Usage

from .adapters import (
    HuggingFaceLocalAdapter,
    HumanAdapter,
    MLXVLMAdapter,
)
from .adapters import CUAAdapter, HuggingFaceLocalAdapter, HumanAdapter, MLXVLMAdapter
from .callbacks import (
    BudgetManagerCallback,
    ImageRetentionCallback,

@@ -193,7 +189,7 @@ class ComputerAgent:
        Initialize ComputerAgent.

        Args:
            model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
            model: Model name (e.g., "claude-sonnet-4-5-20250929", "computer-use-preview", "omni+vertex_ai/gemini-pro")
            tools: List of tools (computer objects, decorated functions, etc.)
            custom_loop: Custom agent loop function to use instead of auto-selection
            only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.

@@ -241,13 +237,6 @@ class ComputerAgent:
        if self.instructions:
            self.callbacks.append(PromptInstructionsCallback(self.instructions))

        # Add telemetry callback if telemetry_enabled is set
        if self.telemetry_enabled:
            if isinstance(self.telemetry_enabled, bool):
                self.callbacks.append(TelemetryCallback(self))
            else:
                self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))

        # Add logging callback if verbosity is set
        if self.verbosity is not None:
            self.callbacks.append(LoggingCallback(level=self.verbosity))

@@ -278,10 +267,12 @@ class ComputerAgent:
        )
        human_adapter = HumanAdapter()
        mlx_adapter = MLXVLMAdapter()
        cua_adapter = CUAAdapter()
        litellm.custom_provider_map = [
            {"provider": "huggingface-local", "custom_handler": hf_adapter},
            {"provider": "human", "custom_handler": human_adapter},
            {"provider": "mlx", "custom_handler": mlx_adapter},
            {"provider": "cua", "custom_handler": cua_adapter},
        ]
        litellm.suppress_debug_info = True

@@ -299,6 +290,13 @@ class ComputerAgent:
        self.agent_loop = config_info.agent_class()
        self.agent_config_info = config_info

        # Add telemetry callback AFTER agent_loop is set so it can capture the correct agent_type
        if self.telemetry_enabled:
            if isinstance(self.telemetry_enabled, bool):
                self.callbacks.append(TelemetryCallback(self))
            else:
                self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))

        self.tool_schemas = []
        self.computer_handler = None

@@ -60,11 +60,14 @@ class TelemetryCallback(AsyncCallbackHandler):

    def _record_agent_initialization(self) -> None:
        """Record agent type/model and session initialization."""
        # Get the agent loop type (class name)
        agent_type = "unknown"
        if hasattr(self.agent, "agent_loop") and self.agent.agent_loop is not None:
            agent_type = type(self.agent.agent_loop).__name__

        agent_info = {
            "session_id": self.session_id,
            "agent_type": (
                self.agent.agent_loop.__name__ if hasattr(self.agent, "agent_loop") else "unknown"
            ),
            "agent_type": agent_type,
            "model": getattr(self.agent, "model", "unknown"),
            **SYSTEM_INFO,
        }

@@ -6,8 +6,8 @@ Usage:

Examples:
    python -m agent.cli openai/computer-use-preview
    python -m agent.cli anthropic/claude-3-5-sonnet-20241022
    python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
    python -m agent.cli anthropic/claude-sonnet-4-5-20250929
    python -m agent.cli omniparser+anthropic/claude-sonnet-4-5-20250929
"""

try:

@@ -232,15 +232,15 @@ async def main():
        epilog="""
Examples:
  python -m agent.cli openai/computer-use-preview
  python -m agent.cli anthropic/claude-3-5-sonnet-20241022
  python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
  python -m agent.cli anthropic/claude-sonnet-4-5-20250929
  python -m agent.cli omniparser+anthropic/claude-sonnet-4-5-20250929
  python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
        """,
    )

    parser.add_argument(
        "model",
        help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')",
        help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-sonnet-4-5-20250929')",
    )

    parser.add_argument(

@@ -1,36 +1,42 @@
"""
Agent loops for agent
"""

# Import the loops to register them
from . import (
    anthropic,
    composed_grounded,
    gemini,
    glm45v,
    gta1,
    holo,
    internvl,
    moondream3,
    omniparser,
    openai,
    opencua,
    qwen,
    uitars,
)

__all__ = [
    "anthropic",
    "openai",
    "uitars",
    "omniparser",
    "gta1",
    "composed_grounded",
    "glm45v",
    "opencua",
    "internvl",
    "holo",
    "moondream3",
    "gemini",
    "qwen",
]

"""
Agent loops for agent
"""

# Import the loops to register them
from . import (
    anthropic,
    composed_grounded,
    gelato,
    gemini,
    glm45v,
    gta1,
    holo,
    internvl,
    moondream3,
    omniparser,
    openai,
    opencua,
    generic_vlm,
    uiins,
    uitars,
    uitars2,
)

__all__ = [
    "anthropic",
    "openai",
    "uitars",
    "omniparser",
    "gta1",
    "composed_grounded",
    "glm45v",
    "opencua",
    "internvl",
    "holo",
    "moondream3",
    "gemini",
    "generic_vlm",
    "uiins",
    "gelato",
    "uitars2",
]

@@ -107,12 +107,9 @@ async def _prepare_tools_for_anthropic(tool_schemas: List[Dict[str, Any]], model
            function_schema = schema["function"]
            anthropic_tools.append(
                {
                    "type": "function",
                    "function": {
                        "name": function_schema["name"],
                        "description": function_schema.get("description", ""),
                        "parameters": function_schema.get("parameters", {}),
                    },
                    "name": function_schema["name"],
                    "description": function_schema.get("description", ""),
                    "input_schema": function_schema.get("parameters", {}),
                }
            )
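
The change above switches from OpenAI-style nested function schemas to Anthropic's flat custom-tool format; illustratively, a hypothetical screenshot tool now converts as follows:

    # Before (OpenAI-style, nested under "function"):
    {
        "type": "function",
        "function": {
            "name": "take_screenshot",
            "description": "Take a screenshot.",
            "parameters": {"type": "object", "properties": {}},
        },
    }

    # After (Anthropic-native, flat, with "parameters" renamed to "input_schema"):
    {
        "name": "take_screenshot",
        "description": "Take a screenshot.",
        "input_schema": {"type": "object", "properties": {}},
    }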
@@ -666,11 +663,25 @@ def _convert_completion_to_responses_items(response: Any) -> List[Dict[str, Any]
                if content_item.get("type") == "text":
                    responses_items.append(make_output_text_item(content_item.get("text", "")))
                elif content_item.get("type") == "tool_use":
                    # Convert tool use to computer call
                    # Check if this is a custom function tool or computer tool
                    tool_name = content_item.get("name", "computer")
                    tool_input = content_item.get("input", {})
                    action_type = tool_input.get("action")
                    call_id = content_item.get("id")

                    # Handle custom function tools (not computer tools)
                    if tool_name != "computer":
                        from ..responses import make_function_call_item

                        responses_items.append(
                            make_function_call_item(
                                function_name=tool_name, arguments=tool_input, call_id=call_id
                            )
                        )
                        continue

                    # Computer tool - process actions
                    action_type = tool_input.get("action")

                    # Action reference:
                    # https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/computer-use-tool#available-actions

@@ -868,6 +879,25 @@ def _convert_completion_to_responses_items(response: Any) -> List[Dict[str, Any]
        # Handle tool calls (alternative format)
        if hasattr(message, "tool_calls") and message.tool_calls:
            for tool_call in message.tool_calls:
                tool_name = tool_call.function.name

                # Handle custom function tools
                if tool_name != "computer":
                    from ..responses import make_function_call_item

                    # tool_call.function.arguments is a JSON string, need to parse it
                    try:
                        args_dict = json.loads(tool_call.function.arguments)
                    except json.JSONDecodeError:
                        args_dict = {}
                    responses_items.append(
                        make_function_call_item(
                            function_name=tool_name, arguments=args_dict, call_id=tool_call.id
                        )
                    )
                    continue

                # Handle computer tool
                if tool_call.function.name == "computer":
                    try:
                    try:
183 libs/python/agent/agent/loops/gelato.py Normal file
@@ -0,0 +1,183 @@
"""
Gelato agent loop implementation for click prediction using litellm.acompletion
Model: https://huggingface.co/mlfoundations/Gelato-30B-A3B
Code: https://github.com/mlfoundations/Gelato/tree/main
"""

import base64
import math
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import litellm
from PIL import Image

from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability

SYSTEM_PROMPT = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. For elements with area, return the center point.

Output the coordinate pair exactly:
(x,y)
"""


def extract_coordinates(raw_string):
    """
    Extract the coordinates from the raw string.
    Args:
        raw_string: str (e.g. "(100, 200)")
    Returns:
        x: float (e.g. 100.0)
        y: float (e.g. 200.0)
    """
    try:
        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
        return [tuple(map(int, match)) for match in matches][0]
    except:
        return 0, 0


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 8847360,
) -> Tuple[int, int]:
    """Smart resize function similar to qwen_vl_utils."""
    # Calculate the total pixels
    total_pixels = height * width

    # If already within bounds, return original dimensions
    if min_pixels <= total_pixels <= max_pixels:
        # Round to nearest factor
        new_height = (height // factor) * factor
        new_width = (width // factor) * factor
        return new_height, new_width

    # Calculate scaling factor
    if total_pixels > max_pixels:
        scale = (max_pixels / total_pixels) ** 0.5
    else:
        scale = (min_pixels / total_pixels) ** 0.5

    # Apply scaling
    new_height = int(height * scale)
    new_width = int(width * scale)

    # Round to nearest factor
    new_height = (new_height // factor) * factor
    new_width = (new_width // factor) * factor

    # Ensure minimum size
    new_height = max(new_height, factor)
    new_width = max(new_width, factor)

    return new_height, new_width


@register_agent(models=r".*Gelato.*")
class GelatoConfig(AsyncAgentConfig):
    """Gelato agent configuration implementing AsyncAgentConfig protocol for click prediction."""

    def __init__(self):
        self.current_model = None
        self.last_screenshot_b64 = None

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        raise NotImplementedError()

    async def predict_click(
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[float, float]]:
"""
|
||||
Predict click coordinates using UI-Ins model via litellm.acompletion.
|
||||
|
||||
Args:
|
||||
model: The UI-Ins model name
|
||||
image_b64: Base64 encoded image
|
||||
instruction: Instruction for where to click
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates or None if prediction fails
|
||||
"""
        # Decode base64 image
        image_data = base64.b64decode(image_b64)
        image = Image.open(BytesIO(image_data))
        width, height = image.width, image.height

        # Smart resize the image (similar to qwen_vl_utils)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=28,  # Default factor for Qwen models
            min_pixels=3136,
            max_pixels=4096 * 2160,
        )
        resized_image = image.resize((resized_width, resized_height))
        scale_x, scale_y = width / resized_width, height / resized_height

        # Convert resized image back to base64
        buffered = BytesIO()
        resized_image.save(buffered, format="PNG")
        resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()

        # Prepare system and user messages
        system_message = {
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT.strip()}],
        }

        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
                },
                {"type": "text", "text": instruction},
            ],
        }

        # Prepare API call kwargs
        api_kwargs = {
            "model": model,
            "messages": [system_message, user_message],
            "max_tokens": 2056,
            "temperature": 0.0,
            **kwargs,
        }

        # Use liteLLM acompletion
        response = await litellm.acompletion(**api_kwargs)

        # Extract response text
        output_text = response.choices[0].message.content  # type: ignore

        # Extract and rescale coordinates
        pred_x, pred_y = extract_coordinates(output_text)  # type: ignore
        pred_x *= scale_x
        pred_y *= scale_y

        return (math.floor(pred_x), math.floor(pred_y))

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["click"]
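
A minimal sketch of driving this loop's predict_click directly. The endpoint and "hosted_vllm/..." model string are assumptions for illustration; point them at wherever mlfoundations/Gelato-30B-A3B is actually served. Note the smart resize is usually just rounding: a 1920x1080 screenshot is within the pixel bounds and is only rounded down to 1904x1064 (multiples of 28) before prediction, then coordinates are rescaled back:

    import asyncio
    import base64

    from agent.loops.gelato import GelatoConfig

    async def main():
        image_b64 = base64.b64encode(open("screenshot.png", "rb").read()).decode()
        config = GelatoConfig()
        coords = await config.predict_click(
            model="hosted_vllm/mlfoundations/Gelato-30B-A3B",  # illustrative model string
            image_b64=image_b64,
            instruction="Click the Submit button",
            api_base="http://localhost:8000/v1",  # illustrative endpoint
        )
        print(coords)  # e.g. (412, 873) in original screenshot coordinates

    asyncio.run(main())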
@@ -20,6 +20,7 @@ from ..loops.base import AsyncAgentConfig
from ..responses import (
    convert_completion_messages_to_responses_items,
    convert_responses_items_to_completion_messages,
    make_reasoning_item,
)
from ..types import AgentCapability

@@ -233,8 +234,8 @@ def convert_qwen_tool_args_to_computer_action(args: Dict[str, Any]) -> Optional[
    return None


@register_agent(models=r"(?i).*qwen.*", priority=-1)
class Qwen3VlConfig(AsyncAgentConfig):
@register_agent(models=r"(?i).*", priority=-100)
class GenericVlmConfig(AsyncAgentConfig):
    async def predict_step(
        self,
        messages: List[Dict[str, Any]],

@@ -373,13 +374,23 @@ class Qwen3VlConfig(AsyncAgentConfig):
        if _on_usage:
            await _on_usage(usage)

        # Parse tool call from text; then convert to responses items via fake tool_calls
        # Extract response data
        resp_dict = response.model_dump()  # type: ignore
        choice = (resp_dict.get("choices") or [{}])[0]
        content_text = ((choice.get("message") or {}).get("content")) or ""
        tool_call = _parse_tool_call_from_text(content_text)
        message = choice.get("message") or {}
        content_text = message.get("content") or ""
        tool_calls_array = message.get("tool_calls") or []
        reasoning_text = message.get("reasoning") or ""

        output_items: List[Dict[str, Any]] = []

        # Add reasoning if present (Ollama Cloud format)
        if reasoning_text:
            output_items.append(make_reasoning_item(reasoning_text))

        # Priority 1: Try to parse tool call from content text (OpenRouter format)
        tool_call = _parse_tool_call_from_text(content_text)

        if tool_call and isinstance(tool_call, dict):
            fn_name = tool_call.get("name") or "computer"
            raw_args = tool_call.get("arguments") or {}

@@ -405,8 +416,50 @@ class Qwen3VlConfig(AsyncAgentConfig):
                ],
            }
            output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
        elif tool_calls_array:
            # Priority 2: Use tool_calls field if present (Ollama Cloud format)
            # Process and unnormalize coordinates in tool calls
            processed_tool_calls = []
            for tc in tool_calls_array:
                function = tc.get("function", {})
                fn_name = function.get("name", "computer")
                args_str = function.get("arguments", "{}")

                try:
                    args = json.loads(args_str)

                    # Unnormalize coordinates if present
                    if "coordinate" in args and last_rw is not None and last_rh is not None:
                        args = await _unnormalize_coordinate(args, (last_rw, last_rh))

                    # Convert Qwen format to Computer Calls format if this is a computer tool
                    if fn_name == "computer":
                        converted_action = convert_qwen_tool_args_to_computer_action(args)
                        if converted_action:
                            args = converted_action

                    processed_tool_calls.append(
                        {
                            "type": tc.get("type", "function"),
                            "id": tc.get("id", "call_0"),
                            "function": {
                                "name": fn_name,
                                "arguments": json.dumps(args),
                            },
                        }
                    )
                except json.JSONDecodeError:
                    # Keep original if parsing fails
                    processed_tool_calls.append(tc)

            fake_cm = {
                "role": "assistant",
                "content": content_text if content_text else "",
                "tool_calls": processed_tool_calls,
            }
            output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
        else:
            # Fallback: just return assistant text
            # No tool calls found in either format, return text response
            fake_cm = {"role": "assistant", "content": content_text}
            output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
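
For orientation, the two response shapes this branch logic distinguishes look roughly like the following (illustrative payloads, not taken from the diff):

    # Priority 1 (OpenRouter format): the tool call is serialized inside
    # message.content and recovered by _parse_tool_call_from_text, e.g.
    #   {"name": "computer", "arguments": {"action": "left_click", "coordinate": [500, 300]}}
    #
    # Priority 2 (Ollama Cloud format): a structured tool_calls array, with
    # optional reasoning text alongside it, e.g.
    #   message = {
    #       "reasoning": "I should click the button.",
    #       "tool_calls": [{"type": "function", "id": "call_0",
    #                       "function": {"name": "computer",
    #                                    "arguments": '{"action": "left_click", "coordinate": [500, 300]}'}}],
    #   }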
@@ -365,6 +365,22 @@ class OmniparserConfig(AsyncAgentConfig):
            **kwargs,
        }

        # Add Vertex AI specific parameters if using vertex_ai models
        if llm_model.startswith("vertex_ai/"):
            import os

            # Pass vertex_project and vertex_location to liteLLM
            if "vertex_project" not in api_kwargs:
                api_kwargs["vertex_project"] = os.getenv("GOOGLE_CLOUD_PROJECT")
            if "vertex_location" not in api_kwargs:
                api_kwargs["vertex_location"] = "global"

            # Pass through Gemini 3-specific parameters if provided
            if "thinking_level" in kwargs:
                api_kwargs["thinking_level"] = kwargs["thinking_level"]
            if "media_resolution" in kwargs:
                api_kwargs["media_resolution"] = kwargs["media_resolution"]

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)
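
With these defaults, an omni+vertex_ai run only needs the GCP project exported; a sketch using the model string format from the ComputerAgent docstring above:

    export GOOGLE_CLOUD_PROJECT=my-gcp-project     # read via os.getenv above
    python -m agent.cli omni+vertex_ai/gemini-pro  # vertex_location defaults to "global"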
175 libs/python/agent/agent/loops/uiins.py Normal file
@@ -0,0 +1,175 @@
"""
UI-Ins agent loop implementation for click prediction using litellm.acompletion
Paper: https://arxiv.org/pdf/2510.202861
Code: https://github.com/alibaba/UI-Ins
"""

import asyncio
import base64
import json
import math
import re
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm
from PIL import Image

from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools

SYSTEM_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.\n\n## Output Format\nReturn a json object with a reasoning process in tags, a function name and arguments within XML tags:\n```\n\n...\n\n\n{"name": "grounding", "arguments": }\n\n```\n represents the following item of the action space:\n## Action Space{"action": "click", "coordinate": [x, y]}\nYour task is to accurately locate a UI element based on the instruction. You should first analyze instruction in tags and finally output the function in tags.\n"""


def parse_coordinates(raw_string: str) -> tuple[int, int]:
    matches = re.findall(r"\[(\d+),\s*(\d+)\]", raw_string)
    if matches:
        return tuple(map(int, matches[0]))
    return -1, -1


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 8847360,
) -> Tuple[int, int]:
    """Smart resize function similar to qwen_vl_utils."""
    # Calculate the total pixels
    total_pixels = height * width

    # If already within bounds, return original dimensions
    if min_pixels <= total_pixels <= max_pixels:
        # Round to nearest factor
        new_height = (height // factor) * factor
        new_width = (width // factor) * factor
        return new_height, new_width

    # Calculate scaling factor
    if total_pixels > max_pixels:
        scale = (max_pixels / total_pixels) ** 0.5
    else:
        scale = (min_pixels / total_pixels) ** 0.5

    # Apply scaling
    new_height = int(height * scale)
    new_width = int(width * scale)

    # Round to nearest factor
    new_height = (new_height // factor) * factor
    new_width = (new_width // factor) * factor

    # Ensure minimum size
    new_height = max(new_height, factor)
    new_width = max(new_width, factor)

    return new_height, new_width


@register_agent(models=r".*UI-Ins.*")
class UIInsConfig(AsyncAgentConfig):
    """UI-Ins agent configuration implementing AsyncAgentConfig protocol for click prediction."""

    def __init__(self):
        self.current_model = None
        self.last_screenshot_b64 = None

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        raise NotImplementedError()

    async def predict_click(
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[float, float]]:
        """
        Predict click coordinates using the UI-Ins model via litellm.acompletion.

        Args:
            model: The UI-Ins model name
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            Tuple of (x, y) coordinates or None if prediction fails
        """
        # Decode base64 image
        image_data = base64.b64decode(image_b64)
        image = Image.open(BytesIO(image_data))
        width, height = image.width, image.height

        # Smart resize the image (similar to qwen_vl_utils)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=28,  # Default factor for Qwen models
            min_pixels=3136,
            max_pixels=4096 * 2160,
        )
        resized_image = image.resize((resized_width, resized_height))
        scale_x, scale_y = width / resized_width, height / resized_height

        # Convert resized image back to base64
        buffered = BytesIO()
        resized_image.save(buffered, format="PNG")
        resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()

        # Prepare system and user messages
        system_message = {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are a helpful assistant."},
                {"type": "text", "text": SYSTEM_PROMPT},
            ],
        }

        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
                },
                {"type": "text", "text": instruction},
            ],
        }

        # Prepare API call kwargs
        api_kwargs = {
            "model": model,
            "messages": [system_message, user_message],
            "max_tokens": 2056,
            "temperature": 0.0,
            **kwargs,
        }

        # Use liteLLM acompletion
        response = await litellm.acompletion(**api_kwargs)

        # Extract response text
        output_text = response.choices[0].message.content  # type: ignore

        # Extract and rescale coordinates
        pred_x, pred_y = parse_coordinates(output_text)  # type: ignore
        pred_x *= scale_x
        pred_y *= scale_y

        return (math.floor(pred_x), math.floor(pred_y))

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["click"]
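
For reference, the coordinate round trip this loop performs, with illustrative numbers (pure arithmetic, not from the diff): the model answers in resized-image space, e.g. "[952, 532]" on a 1904x1064 smart-resized view of a 1920x1080 screenshot; parse_coordinates returns (952, 532), and the final click is (floor(952 * 1920 / 1904), floor(532 * 1080 / 1064)) = (960, 540) in original screenshot coordinates.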
@@ -563,7 +563,7 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
    return litellm_messages


@register_agent(models=r"(?i).*ui-?tars.*")
@register_agent(models=r"(?i).*ui-?tars.*", priority=-1)
class UITARSConfig:
    """
    UITARS agent configuration using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.

951 libs/python/agent/agent/loops/uitars2.py Normal file
@@ -0,0 +1,951 @@
"""
|
||||
UITARS-2 agent loop implementation using LiteLLM.
|
||||
- Prepends a system prompt modeled after the training prompts in examples/seed_16_gui.ipynb
|
||||
- Converts Responses items -> completion messages
|
||||
- Calls litellm.acompletion
|
||||
- Parses <seed:tool_call> ... </seed:tool_call> outputs back into Responses items (computer actions)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
LiteLLMCompletionResponsesConfig,
|
||||
)
|
||||
|
||||
from ..decorators import register_agent
|
||||
from .omniparser import get_last_computer_call_output # type: ignore
|
||||
|
||||
try:
|
||||
from PIL import Image # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
Image = None # type: ignore
|
||||
from ..responses import (
|
||||
convert_responses_items_to_completion_messages,
|
||||
make_click_item,
|
||||
make_double_click_item,
|
||||
make_drag_item,
|
||||
make_function_call_item,
|
||||
make_keypress_item,
|
||||
make_move_item,
|
||||
make_output_text_item,
|
||||
make_reasoning_item,
|
||||
make_screenshot_item,
|
||||
make_scroll_item,
|
||||
make_type_item,
|
||||
make_wait_item,
|
||||
)
|
||||
from ..types import AgentCapability
|
||||
|
||||
TOOL_SCHEMAS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "open_computer",
|
||||
"parameters": {},
|
||||
"description": "Open computer.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "click",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"point": {
|
||||
"type": "string",
|
||||
"description": "Click coordinates. The format is: <point>x y</point>",
|
||||
}
|
||||
},
|
||||
"required": ["point"],
|
||||
},
|
||||
"description": "Mouse left single click action.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "left_double",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"point": {
|
||||
"type": "string",
|
||||
"description": "Click coordinates. The format is: <point>x y</point>",
|
||||
}
|
||||
},
|
||||
"required": ["point"],
|
||||
},
|
||||
"description": "Mouse left double click action.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "right_single",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"point": {
|
||||
"type": "string",
|
||||
"description": "Click coordinates. The format is: <point>x y</point>",
|
||||
}
|
||||
},
|
||||
"required": ["point"],
|
||||
},
|
||||
"description": "Mouse right single click action.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "scroll",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"point": {
|
||||
"type": "string",
|
||||
"description": "Scroll start position. If not specified, default to execute on the current mouse position. The format is: <point>x y</point>",
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"description": "Scroll direction.",
|
||||
"enum": ["up", "down", "left", "right"],
|
||||
},
|
||||
},
|
||||
"required": ["direction"],
|
||||
},
|
||||
"description": "Scroll action.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "move_to",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"point": {
|
||||
"type": "string",
|
||||
"description": "Target coordinates. The format is: <point>x y</point>",
|
||||
}
|
||||
},
|
||||
"required": ["point"],
|
||||
},
|
||||
"description": "Mouse move action.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "hotkey",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Hotkeys you want to press. Split keys with a space and use lowercase.",
|
||||
}
|
||||
},
|
||||
"required": ["key"],
|
||||
},
|
||||
"description": "Press hotkey.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "finished",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Provide the final answer or response to complete the task.",
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"description": "This function is used to indicate the completion of a task by providing the final answer or response.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "press",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key you want to press. Only one key can be pressed at one time.",
|
||||
}
|
||||
},
|
||||
"required": ["key"],
|
||||
},
|
||||
"description": "Press key.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "release",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key you want to release. Only one key can be released at one time.",
|
||||
}
|
||||
},
|
||||
"required": ["key"],
|
||||
},
|
||||
"description": "Release key.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "mouse_down",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"point": {
|
||||
"type": "string",
|
||||
"description": "Mouse down position. If not specified, default to execute on the current mouse position. The format is: <point>x y</point>",
|
||||
},
|
||||
"button": {
|
||||
"type": "string",
|
||||
"description": "Down button. Default to left.",
|
||||
"enum": ["left", "right"],
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"description": "Mouse down action.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "mouse_up",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"point": {
|
||||
"type": "string",
|
||||
"description": "Mouse up position. If not specified, default to execute on the current mouse position. The format is: <point>x y</point>",
|
||||
},
|
||||
"button": {
|
||||
"type": "string",
|
||||
"description": "Up button. Default to left.",
|
||||
"enum": ["left", "right"],
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"description": "Mouse up action.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "call_user",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Message or information displayed to the user to request their input, feedback, or guidance.",
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"description": "This function is used to interact with the user by displaying a message and requesting their input, feedback, or guidance.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "wait",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"time": {"type": "integer", "description": "Wait time in seconds."}},
|
||||
"required": [],
|
||||
},
|
||||
"description": "Wait for a while.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "drag",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"start_point": {
|
||||
"type": "string",
|
||||
"description": "Drag start point. The format is: <point>x y</point>",
|
||||
},
|
||||
"end_point": {
|
||||
"type": "string",
|
||||
"description": "Drag end point. The format is: <point>x y</point>",
|
||||
},
|
||||
},
|
||||
"required": ["start_point", "end_point"],
|
||||
},
|
||||
"description": "Mouse left button drag action.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "type",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Type content. If you want to submit your input, use \\n at the end of content.",
|
||||
}
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
"description": "Type content.",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "take_screenshot",
|
||||
"parameters": {},
|
||||
"description": "Take screenshot.",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _format_tool_schemas_json_lines(schemas: List[Dict[str, Any]]) -> str:
|
||||
# Nicely formatted: pretty JSON with indentation, separated by blank lines
|
||||
return "\n\n".join(json.dumps(s, ensure_ascii=False, indent=2) for s in schemas) + "\n\n"
|
||||
|
||||
|
||||
_PROMPT_PREFIX = (
|
||||
"You should begin by detailing the internal reasoning process, and then present the answer to the user. "
|
||||
"The reasoning process should be enclosed within <think_never_used_51bce0c785ca2f68081bfa7d91973934> "
|
||||
"</think_never_used_51bce0c785ca2f68081bfa7d91973934> tags, as follows:\n"
|
||||
"<think_never_used_51bce0c785ca2f68081bfa7d91973934> reasoning process here "
|
||||
"</think_never_used_51bce0c785ca2f68081bfa7d91973934> answer here.\n\n"
|
||||
"You have different modes of thinking:\n"
|
||||
"Unrestricted think mode: Engage in an internal thinking process with thorough reasoning and reflections. "
|
||||
"You have an unlimited budget for thinking tokens and can continue thinking until you fully solve the problem.\n"
|
||||
"Efficient think mode: Provide a concise internal thinking process with efficient reasoning and reflections. "
|
||||
"You don't have a strict token budget but be less verbose and more direct in your thinking.\n"
|
||||
"No think mode: Respond directly to the question without any internal reasoning process or extra thinking tokens. "
|
||||
"Still follow the template with the minimum required thinking tokens to justify the answer.\n"
|
||||
"Budgeted think mode: Limit your internal reasoning and reflections to stay within the specified token budget\n\n"
|
||||
"Based on the complexity of the problem, select the appropriate mode for reasoning among the provided options listed below.\n\n"
|
||||
"Provided Mode(s):\nEfficient think.\n\n"
|
||||
"You are provided with a task description, a history of previous actions, and corresponding screenshots. "
|
||||
"Your goal is to perform the next action to complete the task. "
|
||||
"If performing the same action multiple times results in a static screen with no changes, attempt a modified or alternative action.\n\n"
|
||||
"## Function Definition\n\n"
|
||||
"- You have access to the following functions:\n\n"
|
||||
)
|
||||
|
||||
_PROMPT_SUFFIX = (
|
||||
"- To call a function, use the following structure without any suffix:\n\n"
|
||||
"<gui_think> reasoning process </gui_think>\n"
|
||||
"<seed:tool_call><function=example_function_name><parameter=example_parameter_1>value_1</parameter>"
|
||||
"<parameter=example_parameter_2>multiline...\n</parameter></function></seed:tool_call>\n\n"
|
||||
"## Important Notes\n"
|
||||
"- Function calls must begin with <function= and end with </function>.\n"
|
||||
"- All required parameters must be explicitly provided.\n"
|
||||
"\n## Additional Notes\n"
|
||||
"- You can execute multiple actions within a single tool call. For example:\n"
|
||||
"<seed:tool_call><function=example_function_1><parameter=example_parameter_1>value_1</parameter><parameter=example_parameter_2>\n"
|
||||
"This is the value for the second parameter\nthat can span\nmultiple lines\n"
|
||||
"</parameter></function><function=example_function_2><parameter=example_parameter_3>value_4</parameter></function></seed:tool_call>"
|
||||
)
|
||||
|
||||
|
||||
SYSTEM_PROMPT = _PROMPT_PREFIX + _format_tool_schemas_json_lines(TOOL_SCHEMAS) + _PROMPT_SUFFIX
|
||||
|
||||
|
||||
def _extract_function_schemas_from_tools(
|
||||
tools: Optional[List[Dict[str, Any]]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
schemas: List[Dict[str, Any]] = []
|
||||
if not tools:
|
||||
return schemas
|
||||
for t in tools:
|
||||
if t.get("type") == "function":
|
||||
fn = t.get("function", {})
|
||||
name = fn.get("name")
|
||||
params = fn.get("parameters", {})
|
||||
desc = fn.get("description", "")
|
||||
if name:
|
||||
schemas.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
"description": desc,
|
||||
}
|
||||
)
|
||||
return schemas
|
||||
|
||||
|
||||
def _parse_seed_tool_calls(text: str) -> List[Dict[str, Any]]:
|
||||
"""Parse <seed:tool_call> blocks into a list of {function, parameters} dicts.
|
||||
Also captures optional <gui_think>...</gui_think> as reasoning.
|
||||
"""
|
||||
actions: List[Dict[str, Any]] = []
|
||||
if not text:
|
||||
return actions
|
||||
|
||||
# Extract reasoning if present
|
||||
reasoning_text = None
|
||||
think_match = re.search(r"<gui_think>([\s\S]*?)</gui_think>", text)
|
||||
if think_match:
|
||||
reasoning_text = think_match.group(1).strip()
|
||||
|
||||
# Iterate each seed tool_call block
|
||||
for block in re.finditer(r"<seed:tool_call>([\s\S]*?)</seed:tool_call>", text):
|
||||
content = block.group(1)
|
||||
# One or multiple <function=...>...</function> inside
|
||||
for fmatch in re.finditer(r"<function=([\w_]+)>([\s\S]*?)</function>", content):
|
||||
fname = fmatch.group(1)
|
||||
inner = fmatch.group(2)
|
||||
params: Dict[str, str] = {}
|
||||
for pmatch in re.finditer(r"<parameter=([\w_]+)>([\s\S]*?)</parameter>", inner):
|
||||
pname = pmatch.group(1)
|
||||
pval = pmatch.group(2).strip()
|
||||
params[pname] = pval
|
||||
actions.append({"function": fname, "parameters": params})
|
||||
|
||||
# If we have a global reasoning and at least one action, attach it to first
|
||||
if reasoning_text and actions:
|
||||
actions[0]["reasoning"] = reasoning_text
|
||||
elif reasoning_text:
|
||||
actions.append({"function": "reasoning", "parameters": {"content": reasoning_text}})
|
||||
|
||||
return actions
|
||||
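
# Illustrative example (not part of the original source): given the model output
#   "<gui_think> open the file menu </gui_think>"
#   "<seed:tool_call><function=click><parameter=point><point>500 300</point></parameter></function></seed:tool_call>"
# _parse_seed_tool_calls returns:
#   [{"function": "click",
#     "parameters": {"point": "<point>500 300</point>"},
#     "reasoning": "open the file menu"}]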


def _normalize_xy_to_uitars(x: int, y: int, width: int, height: int) -> Tuple[int, int]:
    width = max(1, int(width))
    height = max(1, int(height))
    nx = max(0, min(1000, int(round((x / width) * 1000))))
    ny = max(0, min(1000, int(round((y / height) * 1000))))
    return nx, ny


def _denormalize_xy_from_uitars(nx: float, ny: float, width: int, height: int) -> Tuple[int, int]:
    width = max(1, int(width))
    height = max(1, int(height))
    x = int(round((nx / 1000.0) * width))
    y = int(round((ny / 1000.0) * height))
    return x, y
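
# Illustrative round trip (not part of the original source): on a 1024x768 screen,
#   _normalize_xy_to_uitars(512, 384, 1024, 768)     -> (500, 500)
#   _denormalize_xy_from_uitars(500, 500, 1024, 768) -> (512, 384)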


def _map_computer_action_to_function(
    action: Dict[str, Any], width: int, height: int
) -> Optional[Dict[str, Any]]:
    """Map a computer action item to a UITARS function + parameters dict of strings.
    Returns dict like {"function": name, "parameters": {..}} or None if unknown.
    """
    atype = action.get("type") or action.get("action")
    if atype == "click":
        x, y = action.get("x"), action.get("y")
        btn = action.get("button", "left")
        if x is None or y is None:
            return None
        nx, ny = _normalize_xy_to_uitars(int(x), int(y), width, height)
        if btn == "right":
            return {
                "function": "right_single",
                "parameters": {"point": f"<point>{nx} {ny}</point>"},
            }
        return {"function": "click", "parameters": {"point": f"<point>{nx} {ny}</point>"}}
    if atype == "double_click":
        x, y = action.get("x"), action.get("y")
        if x is None or y is None:
            return None
        nx, ny = _normalize_xy_to_uitars(int(x), int(y), width, height)
        return {"function": "left_double", "parameters": {"point": f"<point>{nx} {ny}</point>"}}
    if atype == "move":
        x, y = action.get("x"), action.get("y")
        if x is None or y is None:
            return None
        nx, ny = _normalize_xy_to_uitars(int(x), int(y), width, height)
        return {"function": "move_to", "parameters": {"point": f"<point>{nx} {ny}</point>"}}
    if atype == "keypress":
        keys = action.get("keys", [])
        if isinstance(keys, list) and keys:
            if len(keys) == 1:
                return {"function": "press", "parameters": {"key": keys[0]}}
            else:
                return {"function": "hotkey", "parameters": {"key": " ".join(keys)}}
        return None
    if atype == "type":
        text = action.get("text", "")
        return {"function": "type", "parameters": {"content": text}}
    if atype == "scroll":
        x, y = action.get("x", 512), action.get("y", 512)
        nx, ny = _normalize_xy_to_uitars(int(x), int(y), width, height)
        sx, sy = action.get("scroll_x", 0), action.get("scroll_y", 0)
        # Our parser used positive sy for up
        direction = (
            "up"
            if sy and sy > 0
            else (
                "down"
                if sy and sy < 0
                else ("right" if sx and sx > 0 else ("left" if sx and sx < 0 else "down"))
            )
        )
        return {
            "function": "scroll",
            "parameters": {"direction": direction, "point": f"<point>{nx} {ny}</point>"},
        }
    if atype == "drag":
        path = action.get("path", [])
        if isinstance(path, list) and len(path) >= 2:
            sx, sy = path[0].get("x"), path[0].get("y")
            ex, ey = path[-1].get("x"), path[-1].get("y")
            if sx is None or sy is None or ex is None or ey is None:
                return None
            nsx, nsy = _normalize_xy_to_uitars(int(sx), int(sy), width, height)
            nex, ney = _normalize_xy_to_uitars(int(ex), int(ey), width, height)
            return {
                "function": "drag",
                "parameters": {
                    "start_point": f"<point>{nsx} {nsy}</point>",
                    "end_point": f"<point>{nex} {ney}</point>",
                },
            }
        return None
    if atype == "wait":
        return {"function": "wait", "parameters": {}}
    if atype == "screenshot":
        return {"function": "take_screenshot", "parameters": {}}
    # Fallback unknown
    return None
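
# Illustrative mapping (not part of the original source): on a 1024x768 screen,
#   _map_computer_action_to_function({"type": "click", "x": 512, "y": 384, "button": "left"}, 1024, 768)
#   -> {"function": "click", "parameters": {"point": "<point>500 500</point>"}}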
|
||||
|
||||
def _to_uitars_messages(
|
||||
messages: List[Dict[str, Any]], width: int, height: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert responses items into completion messages tailored for UI-TARS.
|
||||
|
||||
- User content is passed through similar to convert_responses_items_to_completion_messages
|
||||
- Assistant/tool history is rendered as text with <gui_think> and <seed:tool_call> blocks
|
||||
"""
|
||||
uitars_messages: List[Dict[str, Any]] = []
|
||||
|
||||
def flush_seed_block(pending_think: Optional[str], pending_functions: List[Dict[str, Any]]):
|
||||
if not pending_think and not pending_functions:
|
||||
return
|
||||
parts: List[str] = []
|
||||
if pending_think:
|
||||
parts.append(f"<gui_think> {pending_think} </gui_think>")
|
||||
if pending_functions:
|
||||
inner = []
|
||||
for f in pending_functions:
|
||||
fname = f["function"]
|
||||
params = f.get("parameters", {})
|
||||
param_blocks = []
|
||||
for k, v in params.items():
|
||||
param_blocks.append(f"<parameter={k}>{v}</parameter>")
|
||||
inner.append(f"<function={fname}>{''.join(param_blocks)}</function>")
|
||||
parts.append(f"<seed:tool_call>{''.join(inner)}</seed:tool_call>")
|
||||
uitars_messages.append({"role": "assistant", "content": "".join(parts)})
|
||||
|
||||
# Accumulators for a single assistant seed block
|
||||
pending_think: Optional[str] = None
|
||||
pending_functions: List[Dict[str, Any]] = []
|
||||
|
||||
for msg in messages:
|
||||
mtype = msg.get("type")
|
||||
role = msg.get("role")
|
||||
|
||||
# On any user message, flush current assistant block
|
||||
if role == "user" or mtype == "user":
|
||||
flush_seed_block(pending_think, pending_functions)
|
||||
pending_think, pending_functions = None, []
|
||||
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
completion_content = []
|
||||
for item in content:
|
||||
if item.get("type") == "input_image":
|
||||
completion_content.append(
|
||||
{"type": "image_url", "image_url": {"url": item.get("image_url")}}
|
||||
)
|
||||
elif item.get("type") in ("input_text", "text"):
|
||||
completion_content.append({"type": "text", "text": item.get("text")})
|
||||
uitars_messages.append({"role": "user", "content": completion_content})
|
||||
elif isinstance(content, str):
|
||||
uitars_messages.append({"role": "user", "content": content})
|
||||
continue
|
||||
|
||||
# Reasoning item
|
||||
if mtype == "reasoning":
|
||||
# Responses reasoning stores summary list
|
||||
summary = msg.get("summary", [])
|
||||
texts = [
|
||||
s.get("text", "")
|
||||
for s in summary
|
||||
if isinstance(s, dict) and s.get("type") == "summary_text"
|
||||
]
|
||||
if texts:
|
||||
pending_think = "\n".join([t for t in texts if t])
|
||||
continue
|
||||
|
||||
# Computer/tool calls -> map to functions
|
||||
if mtype == "computer_call":
|
||||
f = _map_computer_action_to_function(msg.get("action", {}), width, height)
|
||||
if f:
|
||||
pending_functions.append(f)
|
||||
continue
|
||||
if mtype == "function_call":
|
||||
# Include custom tools as-is
|
||||
name = msg.get("name")
|
||||
try:
|
||||
args_obj = json.loads(msg.get("arguments", "{}"))
|
||||
except json.JSONDecodeError:
|
||||
args_obj = {}
|
||||
# Ensure string values
|
||||
params = {k: (str(v) if not isinstance(v, str) else v) for k, v in args_obj.items()}
|
||||
pending_functions.append({"function": name, "parameters": params})
|
||||
continue
|
||||
|
||||
# If assistant message text is given, flush current block and add as plain assistant text
|
||||
if role == "assistant" or mtype == "message":
|
||||
flush_seed_block(pending_think, pending_functions)
|
||||
pending_think, pending_functions = None, []
|
||||
content = msg.get("content", [])
|
||||
if isinstance(content, list):
|
||||
texts = [
|
||||
c.get("text", "")
|
||||
for c in content
|
||||
if isinstance(c, dict) and c.get("type") in ("output_text", "text")
|
||||
]
|
||||
if texts:
|
||||
uitars_messages.append(
|
||||
{"role": "assistant", "content": "\n".join([t for t in texts if t])}
|
||||
)
|
||||
elif isinstance(content, str) and content:
|
||||
uitars_messages.append({"role": "assistant", "content": content})
|
||||
continue
|
||||
|
||||
# On outputs, flush pending assistant block and send outputs as user messages
|
||||
if mtype in ("function_call_output", "computer_call_output"):
|
||||
flush_seed_block(pending_think, pending_functions)
|
||||
pending_think, pending_functions = None, []
|
||||
output = msg.get("output")
|
||||
if isinstance(output, dict) and output.get("type") == "input_image":
|
||||
img_url = output.get("image_url")
|
||||
if img_url:
|
||||
uitars_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": img_url}},
|
||||
],
|
||||
}
|
||||
)
|
||||
elif isinstance(output, str):
|
||||
uitars_messages.append({"role": "user", "content": output})
|
||||
else:
|
||||
# Fallback stringify
|
||||
uitars_messages.append({"role": "user", "content": json.dumps(output)})
|
||||
continue
|
||||
|
||||
# Flush any remaining pending seed block
|
||||
flush_seed_block(pending_think, pending_functions)
|
||||
|
||||
return uitars_messages


def _to_response_items(
    actions: List[Dict[str, Any]],
    tool_names: Optional[set[str]] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
) -> List[Any]:
    """Map parsed actions into Responses items (computer actions + optional reasoning)."""
    items: List[Any] = []
    tool_names = tool_names or set()

    # Optional top-level reasoning attached to first
    if actions and actions[0].get("reasoning"):
        items.append(make_reasoning_item(actions[0]["reasoning"]))

    # Dimensions default
    w = int(width) if width else 1024
    h = int(height) if height else 768

    for a in actions:
        fn = a.get("function")
        params = a.get("parameters", {})
        if fn == "reasoning":
            items.append(make_reasoning_item(params.get("content", "")))
        elif fn in ("click", "left_double", "right_single"):
            # params.point is like: <point>x y</point> or plain "x y"
            point = params.get("point", "").strip()
            m = re.search(r"([\-\d\.]+)\s+([\-\d\.]+)", point)
            if not m:
                continue
            nx = float(m.group(1))
            ny = float(m.group(2))
            x, y = _denormalize_xy_from_uitars(nx, ny, w, h)
            if fn == "left_double":
                items.append(make_double_click_item(x, y))
            elif fn == "right_single":
                items.append(make_click_item(x, y, "right"))
            else:
                items.append(make_click_item(x, y, "left"))
        elif fn == "move_to":
            point = params.get("point", "").strip()
            m = re.search(r"([\-\d\.]+)\s+([\-\d\.]+)", point)
            if not m:
                continue
            nx = float(m.group(1))
            ny = float(m.group(2))
            x, y = _denormalize_xy_from_uitars(nx, ny, w, h)
            items.append(make_move_item(x, y))
        elif fn == "drag":
            sp = params.get("start_point", "").strip()
            ep = params.get("end_point", "").strip()
            ms = re.search(r"([\-\d\.]+)\s+([\-\d\.]+)", sp)
            me = re.search(r"([\-\d\.]+)\s+([\-\d\.]+)", ep)
            if not (ms and me):
                continue
            nsx, nsy = float(ms.group(1)), float(ms.group(2))
            nex, ney = float(me.group(1)), float(me.group(2))
            sx, sy = _denormalize_xy_from_uitars(nsx, nsy, w, h)
            ex, ey = _denormalize_xy_from_uitars(nex, ney, w, h)
            items.append(make_drag_item([{"x": sx, "y": sy}, {"x": ex, "y": ey}]))
        elif fn == "hotkey":
            key = params.get("key", "")
            keys = key.split()
            if keys:
                items.append(make_keypress_item(keys))
        elif fn == "press":
            key = params.get("key", "")
            if key:
                items.append(make_keypress_item([key]))
        elif fn == "type":
            content = params.get("content", "")
            items.append(make_type_item(content))
        elif fn == "scroll":
            # direction: up/down/left/right. Point optional
            direction = params.get("direction", "down").lower()
            point = params.get("point", "")
            m = re.search(r"([\-\d\.]+)\s+([\-\d\.]+)", point)
            if m:
                nx = float(m.group(1))
                ny = float(m.group(2))
                x, y = _denormalize_xy_from_uitars(nx, ny, w, h)
            else:
                x, y = _denormalize_xy_from_uitars(500.0, 500.0, w, h)
            dy = 5 if direction == "up" else -5
            dx = 5 if direction == "right" else (-5 if direction == "left" else 0)
            items.append(make_scroll_item(x, y, dx, dy))
        elif fn == "wait":
            items.append(make_wait_item())
        elif fn == "finished":
            content = params.get("content", "")
            items.append(make_output_text_item(content or "Task completed."))
            break
        elif fn == "take_screenshot":
            items.append(make_screenshot_item())
        elif fn == "open_computer":
            items.append(make_screenshot_item())
        else:
            # If this function name is present in provided tool schemas, emit function_call
            if fn in tool_names:
                # Parameters are already plain strings; pass them through as-is
                items.append(make_function_call_item(fn, params))
            else:
                # Unknown function -> surface as assistant text
                items.append(make_output_text_item(f"Unknown action: {fn} {params}"))

    return items
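A minimal illustration of this mapping, assuming UI-TARS points are normalized to a 0-1000 grid (which the _denormalize_xy_from_uitars calls above imply); the action dicts mirror what _parse_seed_tool_calls is assumed to emit:

# Hypothetical parsed actions:
actions = [
    {"function": "click", "parameters": {"point": "<point>500 500</point>"}},
    {"function": "type", "parameters": {"content": "hello"}},
    {"function": "finished", "parameters": {"content": "Typed the greeting."}},
]
# On a 1024x768 screen the click denormalizes to roughly (512, 384);
# "finished" becomes an output_text item and stops the loop via break.
items = _to_response_items(actions, tool_names=set(), width=1024, height=768)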


@register_agent(models=r"(?i).*ui-?tars-?2.*")
class UITARS2Config:
    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        use_prompt_caching: Optional[bool] = False,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Determine screen dimensions (prefer computer_handler, fallback to last screenshot)
        width: Optional[int] = None
        height: Optional[int] = None
        if computer_handler is not None and hasattr(computer_handler, "get_dimensions"):
            try:
                dims = await computer_handler.get_dimensions()  # type: ignore
                if isinstance(dims, (list, tuple)) and len(dims) == 2:
                    width, height = int(dims[0]), int(dims[1])
            except Exception:
                pass

        if width is None or height is None:
            try:
                last_out = get_last_computer_call_output(messages)  # type: ignore
                if last_out:
                    image_url = last_out.get("output", {}).get("image_url", "")
                    if image_url:
                        b64 = image_url.split(",")[-1]
                        img_bytes = base64.b64decode(b64)
                        if Image is not None:
                            img = Image.open(io.BytesIO(img_bytes))
                            width, height = img.size
            except Exception:
                pass

        if width is None or height is None:
            width, height = 1024, 768

        # Convert Responses items to UI-TARS style messages with <seed:tool_call> history
        completion_messages = _to_uitars_messages(messages, width, height)

        # Build dynamic system prompt by concatenating built-in schemas and provided function tools
        provided_fn_schemas = _extract_function_schemas_from_tools(tools)
        combined_schemas = (
            TOOL_SCHEMAS + provided_fn_schemas if provided_fn_schemas else TOOL_SCHEMAS
        )
        dynamic_system_prompt = (
            _PROMPT_PREFIX + _format_tool_schemas_json_lines(combined_schemas) + _PROMPT_SUFFIX
        )

        # Prepend system prompt (based on training prompts + provided tools)
        litellm_messages: List[Dict[str, Any]] = [
            {"role": "system", "content": dynamic_system_prompt},
        ]
        litellm_messages.extend(completion_messages)

        api_kwargs: Dict[str, Any] = {
            "model": model,
            "messages": litellm_messages,
            "max_retries": max_retries,
            "stream": stream,
            **{k: v for k, v in kwargs.items()},
        }
        if use_prompt_caching:
            api_kwargs["use_prompt_caching"] = use_prompt_caching

        if _on_api_start:
            await _on_api_start(api_kwargs)

        response = await litellm.acompletion(**api_kwargs)

        if _on_api_end:
            await _on_api_end(api_kwargs, response)

        usage = {
            **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(  # type: ignore
                response.usage
            ).model_dump(),
            "response_cost": response._hidden_params.get("response_cost", 0.0),
        }
        if _on_usage:
            await _on_usage(usage)

        # Extract text content (first choice)
        response_dict = response.model_dump()  # type: ignore
        content_text = ""
        choices = response_dict.get("choices", [])
        if choices:
            msg = choices[0].get("message", {})
            # message.content may be string or array; gather text pieces
            mc = msg.get("content")
            if isinstance(mc, str):
                content_text = mc
            elif isinstance(mc, list):
                parts = []
                for part in mc:
                    if isinstance(part, dict) and part.get("type") == "text":
                        parts.append(part.get("text", ""))
                content_text = "\n".join([p for p in parts if p])

        # Parse the seed tool calls and map to response items
        actions = _parse_seed_tool_calls(content_text)
        # Build set of tool names from provided tools to emit function_call items
        tool_names: set[str] = set()
        for s in provided_fn_schemas:
            name = s.get("name")
            if isinstance(name, str):
                tool_names.add(name)
        output_items = _to_response_items(actions, tool_names, width, height)

        return {"output": output_items, "usage": usage}

    def get_capabilities(self) -> List[AgentCapability]:
        return ["step"]

    async def predict_click(
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[int, int]]:
        """Predict a single click coordinate using a minimal prompt with a click tool.

        This sends the current screenshot and instruction, asking the model to
        output a click action in the form:
        Action: click(point='(x,y)')
        """
        # Minimal grounding-style prompt
        system_text = (
            "You are a GUI agent. Given the instruction, return a single action on the current screen.\n\n"
            "## Output Format\n\n"
            "Action: click(point='(x,y)')\n\n"
            "## User Instruction\n"
            f"{instruction}"
        )

        # Build messages with image
        litellm_messages: List[Dict[str, Any]] = [
            {"role": "system", "content": system_text},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Please return a single click action."},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                ],
            },
        ]

        api_kwargs: Dict[str, Any] = {
            "model": model,
            "messages": litellm_messages,
            "max_tokens": kwargs.get("max_tokens", 512),
            "temperature": kwargs.get("temperature", 0.0),
            "do_sample": kwargs.get("temperature", 0.0) > 0.0,
        }
        api_kwargs.update(
            {k: v for k, v in (kwargs or {}).items() if k not in ["max_tokens", "temperature"]}
        )

        response = await litellm.acompletion(**api_kwargs)
        # Extract response content
        response_dict = response.model_dump()  # type: ignore
        choices = response_dict.get("choices", [])
        if not choices:
            return None
        msg = choices[0].get("message", {})
        content_text = msg.get("content", "")
        if isinstance(content_text, list):
            text_parts = [
                p.get("text", "")
                for p in content_text
                if isinstance(p, dict) and p.get("type") == "text"
            ]
            content_text = "\n".join([t for t in text_parts if t])
        if not isinstance(content_text, str):
            return None

        # Parse coordinates
        # Pattern for click(point='(x,y)') or click(start_box='(x,y)')
        patterns = [
            r"click\(point='\((\d+),(\d+)\)'\)",
            r"click\((?:start_box|point)='\((\d+),(\d+)\)'\)",
        ]
        for pat in patterns:
            m = re.search(pat, content_text)
            if m:
                try:
                    x, y = int(m.group(1)), int(m.group(2))
                    return (x, y)
                except Exception:
                    pass
        return None
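A hedged usage sketch for the method above; the screenshot path is illustrative and the model id is simply one that matches the register_agent pattern on this class:

import asyncio, base64

async def demo():
    config = UITARS2Config()
    with open("screen.png", "rb") as f:  # hypothetical screenshot file
        image_b64 = base64.b64encode(f.read()).decode()
    coords = await config.predict_click(
        model="ui-tars-2",
        image_b64=image_b64,
        instruction="Click the Submit button",
    )
    print(coords)  # e.g. (512, 384), or None if no click action was parsed

asyncio.run(demo())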

@@ -22,14 +22,14 @@ async def test_http_endpoint():

    # Example 1: Simple text request
    simple_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "model": "anthropic/claude-sonnet-4-5-20250929",
        "input": "Tell me a three sentence bedtime story about a unicorn.",
        "env": {"ANTHROPIC_API_KEY": anthropic_api_key},
    }

    # Example 2: Multi-modal request with image
    multimodal_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "model": "anthropic/claude-sonnet-4-5-20250929",
        "input": [
            {
                "role": "user",
@@ -47,7 +47,7 @@ async def test_http_endpoint():

    # Example 3: Request with custom agent and computer kwargs
    custom_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "model": "anthropic/claude-sonnet-4-5-20250929",
        "input": "Take a screenshot and tell me what you see",
        "env": {"ANTHROPIC_API_KEY": anthropic_api_key},
    }
@@ -95,7 +95,7 @@ def curl_examples():
        """curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "model": "anthropic/claude-sonnet-4-5-20250929",
    "input": "Tell me a three sentence bedtime story about a unicorn."
  }'"""
    )
@@ -105,7 +105,7 @@ def curl_examples():
        """curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "model": "anthropic/claude-sonnet-4-5-20250929",
    "input": [
      {
        "role": "user",
@@ -126,7 +126,7 @@ def curl_examples():
        """curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "model": "anthropic/claude-sonnet-4-5-20250929",
    "input": "Take a screenshot and tell me what you see",
    "agent_kwargs": {
      "save_trajectory": true,
@@ -166,7 +166,7 @@ async def test_p2p_client():

    # Send a test request
    request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "model": "anthropic/claude-sonnet-4-5-20250929",
        "input": "Hello from P2P client!",
    }
    await connection.send(json.dumps(request))

@@ -442,7 +442,9 @@ def get_all_element_descriptions(responses_items: List[Dict[str, Any]]) -> List[

# Conversion functions between responses_items and completion messages formats
def convert_responses_items_to_completion_messages(
    messages: List[Dict[str, Any]], allow_images_in_tool_results: bool = True
    messages: List[Dict[str, Any]],
    allow_images_in_tool_results: bool = True,
    send_multiple_user_images_per_parallel_tool_results: bool = False,
) -> List[Dict[str, Any]]:
    """Convert responses_items message format to liteLLM completion format.

@@ -450,10 +452,11 @@ def convert_responses_items_to_completion_messages(
        messages: List of responses_items format messages
        allow_images_in_tool_results: If True, include images in tool role messages.
            If False, send tool message + separate user message with image.
        send_multiple_user_images_per_parallel_tool_results: If True, emit a separate
            user image message for every output in a run of parallel tool results,
            instead of only after the last one.
    """
    completion_messages = []

    for message in messages:
    for i, message in enumerate(messages):
        msg_type = message.get("type")
        role = message.get("role")

@@ -561,6 +564,14 @@ def convert_responses_items_to_completion_messages(
                    }
                )
            else:
                # Determine if the next message is also a tool call output
                next_type = None
                if i + 1 < len(messages):
                    next_msg = messages[i + 1]
                    next_type = next_msg.get("type")
                is_next_message_image_result = next_type in [
                    "computer_call_output",
                ]
                # Send tool message + separate user message with image (OpenAI compatible)
                completion_messages += [
                    {
@@ -574,6 +585,12 @@ def convert_responses_items_to_completion_messages(
                            {"type": "image_url", "image_url": {"url": output.get("image_url")}}
                        ],
                    },
                ] if send_multiple_user_images_per_parallel_tool_results or (not is_next_message_image_result) else [
                    {
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": "[Execution completed. See screenshot below]",
                    },
                ]
            else:
                # Handle text output as tool response
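In effect, for a run of adjacent computer_call_output items with the new flag left at its default, a sketch of the intended behavior (message shapes abbreviated, inferred from the hunk above):

# send_multiple_user_images_per_parallel_tool_results=False (default):
#   output A (next item is also an image result) ->
#       {"role": "tool", "tool_call_id": ..., "content": "[Execution completed. See screenshot below]"}
#   output B (last in the run) ->
#       {"role": "tool", ...} followed by a separate {"role": "user", "content": [<image>]}
# With the flag set to True, every output emits its own tool message plus user image.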

@@ -6,9 +6,9 @@ with an advanced UI for model selection and configuration.

Supported Agent Models:
- OpenAI: openai/computer-use-preview
- Anthropic: anthropic/claude-3-5-sonnet-20241022, anthropic/claude-3-7-sonnet-20250219
- Anthropic: anthropic/claude-sonnet-4-5-20250929, anthropic/claude-3-7-sonnet-20250219
- UI-TARS: huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
- Omniparser: omniparser+anthropic/claude-3-5-sonnet-20241022, omniparser+ollama_chat/gemma3
- Omniparser: omniparser+anthropic/claude-sonnet-4-5-20250929, omniparser+ollama_chat/gemma3

Requirements:
- Mac with Apple Silicon (M1/M2/M3/M4), Linux, or Windows
@@ -116,14 +116,12 @@ MODEL_MAPPINGS = {
        "Anthropic: Claude 4 Opus (20250514)": "anthropic/claude-opus-4-20250514",
        "Anthropic: Claude 4 Sonnet (20250514)": "anthropic/claude-sonnet-4-20250514",
        "Anthropic: Claude 3.7 Sonnet (20250219)": "anthropic/claude-3-7-sonnet-20250219",
        "Anthropic: Claude 3.5 Sonnet (20241022)": "anthropic/claude-3-5-sonnet-20241022",
    },
    "omni": {
        "default": "omniparser+openai/gpt-4o",
        "OMNI: OpenAI GPT-4o": "omniparser+openai/gpt-4o",
        "OMNI: OpenAI GPT-4o mini": "omniparser+openai/gpt-4o-mini",
        "OMNI: Claude 3.7 Sonnet (20250219)": "omniparser+anthropic/claude-3-7-sonnet-20250219",
        "OMNI: Claude 3.5 Sonnet (20241022)": "omniparser+anthropic/claude-3-5-sonnet-20241022",
    },
    "uitars": {
        "default": "huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B" if is_mac else "ui-tars",

@@ -44,13 +44,11 @@ def create_gradio_ui() -> gr.Blocks:
        "Anthropic: Claude 4 Opus (20250514)",
        "Anthropic: Claude 4 Sonnet (20250514)",
        "Anthropic: Claude 3.7 Sonnet (20250219)",
        "Anthropic: Claude 3.5 Sonnet (20241022)",
    ]
    omni_models = [
        "OMNI: OpenAI GPT-4o",
        "OMNI: OpenAI GPT-4o mini",
        "OMNI: Claude 3.7 Sonnet (20250219)",
        "OMNI: Claude 3.5 Sonnet (20241022)",
    ]

    # Check if API keys are available

@@ -102,7 +102,7 @@ async def main():
        # model="anthropic/claude-opus-4-20250514",
        # model="anthropic/claude-sonnet-4-20250514",
        # model="anthropic/claude-3-7-sonnet-20250219",
        # model="anthropic/claude-3-5-sonnet-20241022",
        # model="anthropic/claude-sonnet-4-5-20250929",
        # == UI-TARS ==
        # model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
        # TODO: add local mlx provider

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"

[project]
name = "cua-agent"
version = "0.4.37"
version = "0.5.1"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [
@@ -100,8 +100,6 @@ all = [
    "python-dotenv>=1.0.1",
    # cli requirements
    "yaspin>=3.1.0",
    # hud requirements
    "hud-python==0.4.52",
    # gemini requirements
    "google-genai>=1.41.0",
    # qwen requirements

@@ -24,7 +24,7 @@ def mock_litellm():
        "id": "chatcmpl-test123",
        "object": "chat.completion",
        "created": 1234567890,
        "model": kwargs.get("model", "anthropic/claude-3-5-sonnet-20241022"),
        "model": kwargs.get("model", "anthropic/claude-sonnet-4-5-20250929"),
        "choices": [
            {
                "index": 0,

@@ -18,18 +18,18 @@ class TestComputerAgentInitialization:
        """Test that agent can be initialized with a model string."""
        from agent import ComputerAgent

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
        agent = ComputerAgent(model="anthropic/claude-sonnet-4-5-20250929")

        assert agent is not None
        assert hasattr(agent, "model")
        assert agent.model == "anthropic/claude-3-5-sonnet-20241022"
        assert agent.model == "anthropic/claude-sonnet-4-5-20250929"

    @patch("agent.agent.litellm")
    def test_agent_initialization_with_tools(self, mock_litellm, disable_telemetry, mock_computer):
        """Test that agent can be initialized with tools."""
        from agent import ComputerAgent

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=[mock_computer])
        agent = ComputerAgent(model="anthropic/claude-sonnet-4-5-20250929", tools=[mock_computer])

        assert agent is not None
        assert hasattr(agent, "tools")
@@ -41,7 +41,7 @@ class TestComputerAgentInitialization:

        budget = 5.0
        agent = ComputerAgent(
            model="anthropic/claude-3-5-sonnet-20241022", max_trajectory_budget=budget
            model="anthropic/claude-sonnet-4-5-20250929", max_trajectory_budget=budget
        )

        assert agent is not None
@@ -79,7 +79,7 @@ class TestComputerAgentRun:

        mock_litellm.acompletion = AsyncMock(return_value=mock_response)

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
        agent = ComputerAgent(model="anthropic/claude-sonnet-4-5-20250929")

        # Run should return an async generator
        result_generator = agent.run(sample_messages)
@@ -92,7 +92,7 @@ class TestComputerAgentRun:
        """Test that agent has run method available."""
        from agent import ComputerAgent

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
        agent = ComputerAgent(model="anthropic/claude-sonnet-4-5-20250929")

        # Verify run method exists
        assert hasattr(agent, "run")
@@ -102,7 +102,7 @@ class TestComputerAgentRun:
        """Test that agent has agent_loop initialized."""
        from agent import ComputerAgent

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
        agent = ComputerAgent(model="anthropic/claude-sonnet-4-5-20250929")

        # Verify agent_loop is initialized
        assert hasattr(agent, "agent_loop")
@@ -132,7 +132,7 @@ class TestComputerAgentIntegration:
        """Test that agent can be initialized with Computer tool."""
        from agent import ComputerAgent

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=[mock_computer])
        agent = ComputerAgent(model="anthropic/claude-sonnet-4-5-20250929", tools=[mock_computer])

        # Verify agent accepted the tool
        assert agent is not None

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.29
current_version = 0.1.30
commit = True
tag = True
tag_name = computer-server-v{new_version}

@@ -36,11 +36,11 @@ pip install cua-computer-server

Refer to this notebook for a step-by-step guide on how to use the Computer-Use Server on the host system or VM:

- [Computer-Use Server](../../notebooks/computer_server_nb.ipynb)
- [Computer-Use Server](../../../notebooks/computer_server_nb.ipynb)

## Docs

- [Commands](https://cua.ai/docs/libraries/computer-server/Commands)
- [REST-API](https://cua.ai/docs/libraries/computer-server/REST-API)
- [WebSocket-API](https://cua.ai/docs/libraries/computer-server/WebSocket-API)
- [Index](https://cua.ai/docs/libraries/computer-server/index)
- [Index](https://cua.ai/docs/libraries/computer-server)

@@ -1287,7 +1287,15 @@ class MacOSAutomationHandler(BaseAutomationHandler):
        if not isinstance(screenshot, Image.Image):
            return {"success": False, "error": "Failed to capture screenshot"}

        # Resize image to reduce size (max width 1920, maintain aspect ratio)
        max_width = 1920
        if screenshot.width > max_width:
            ratio = max_width / screenshot.width
            new_height = int(screenshot.height * ratio)
            screenshot = screenshot.resize((max_width, new_height), Image.Resampling.LANCZOS)

        buffered = BytesIO()
        # Use PNG format with optimization to reduce file size
        screenshot.save(buffered, format="PNG", optimize=True)
        buffered.seek(0)
        image_data = base64.b64encode(buffered.getvalue()).decode()
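A quick numeric check of the resize branch above (dimensions are illustrative):

# A 3840x2160 capture: ratio = 1920 / 3840 = 0.5,
# new_height = int(2160 * 0.5) = 1080 -> resized to 1920x1080 before PNG encoding.
# Captures at or below 1920 px wide pass through unchanged.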

@@ -75,14 +75,23 @@ class Watchdog:
        Returns:
            WebSocket URI for the Computer API Server
        """
        ip_address = (
            "localhost"
            if not self.container_name
            else f"{self.container_name}.containers.cloud.trycua.com"
        )
        protocol = "wss" if self.container_name else "ws"
        port = "8443" if self.container_name else "8000"
        return f"{protocol}://{ip_address}:{port}/ws"
        if not self.container_name:
            return "ws://localhost:8000/ws"

        # Try .sandbox.cua.ai first, fallback to .containers.cloud.trycua.com
        return f"wss://{self.container_name}.sandbox.cua.ai:8443/ws"

    @property
    def ws_uri_fallback(self) -> str:
        """Get the fallback WebSocket URI using legacy hostname.

        Returns:
            Fallback WebSocket URI for the Computer API Server
        """
        if not self.container_name:
            return "ws://localhost:8000/ws"

        return f"wss://{self.container_name}.containers.cloud.trycua.com:8443/ws"

    async def ping(self) -> bool:
        """
@@ -91,11 +100,11 @@ class Watchdog:
        Returns:
            True if connection successful, False otherwise
        """
        try:
            # Create a simple ping message
            ping_message = {"command": "get_screen_size", "params": {}}
            # Create a simple ping message
            ping_message = {"command": "get_screen_size", "params": {}}

            # Try to connect to the WebSocket
            # Try primary URI first (.sandbox.cua.ai)
            try:
                async with websockets.connect(
                    self.ws_uri, max_size=1024 * 1024 * 10  # 10MB limit to match server
                ) as websocket:
@@ -105,13 +114,40 @@ class Watchdog:
                    # Wait for any response or just close
                    try:
                        response = await asyncio.wait_for(websocket.recv(), timeout=5)
                        logger.debug(f"Ping response received: {response[:100]}...")
                        logger.debug(f"Ping response received from primary URI: {response[:100]}...")
                        return True
                    except asyncio.TimeoutError:
                        return False
            except Exception as e:
                logger.warning(f"Ping failed: {e}")
                return False
                logger.debug(f"Primary URI ping failed: {e}")

                # Try fallback URI (.containers.cloud.trycua.com)
                if self.container_name:
                    try:
                        async with websockets.connect(
                            self.ws_uri_fallback,
                            max_size=1024 * 1024 * 10,  # 10MB limit to match server
                        ) as websocket:
                            # Send ping message
                            await websocket.send(json.dumps(ping_message))

                            # Wait for any response or just close
                            try:
                                response = await asyncio.wait_for(websocket.recv(), timeout=5)
                                logger.debug(
                                    f"Ping response received from fallback URI: {response[:100]}..."
                                )
                                return True
                            except asyncio.TimeoutError:
                                return False
                    except Exception as fallback_e:
                        logger.warning(
                            f"Both primary and fallback ping failed. Primary: {e}, Fallback: {fallback_e}"
                        )
                        return False
                else:
                    logger.warning(f"Ping failed: {e}")
                    return False
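For clarity, the URI pair the two properties above produce, as a sketch (the container name is illustrative):

# With container_name = "my-vm":
#   ws_uri          -> "wss://my-vm.sandbox.cua.ai:8443/ws"               (primary)
#   ws_uri_fallback -> "wss://my-vm.containers.cloud.trycua.com:8443/ws"  (legacy)
# With container_name unset, both resolve to "ws://localhost:8000/ws",
# and ping() skips the fallback attempt entirely.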

    def kill_processes_on_port(self, port: int) -> bool:
        """

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"

[project]
name = "cua-computer-server"
version = "0.1.29"
version = "0.1.30"

description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
authors = [

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.11
current_version = 0.4.17
commit = True
tag = True
tag_name = computer-v{new_version}

@@ -40,7 +40,7 @@ try:
    await computer.interface.right_click(300, 300)
    await computer.interface.double_click(400, 400)

    await computer.interface.type("Hello, World!")
    await computer.interface.type_text("Hello, World!")
    await computer.interface.press_key("enter")

    await computer.interface.set_clipboard("Test clipboard")
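The rename above (interface.type to interface.type_text) in use, as a short sketch against an already-connected Computer instance (the surrounding setup is assumed):

await computer.interface.type_text("Hello, World!")  # replaces interface.type(...)
await computer.interface.press_key("enter")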

@@ -107,13 +107,17 @@ class Computer:
        host: Host to use for VM provider connections (e.g. "localhost", "host.docker.internal")
        storage: Optional path for persistent VM storage (Lumier provider)
        ephemeral: Whether to use ephemeral storage
        api_key: Optional API key for cloud providers
        api_key: Optional API key for cloud providers (defaults to CUA_API_KEY environment variable)
        experiments: Optional list of experimental features to enable (e.g. ["app-use"])
    """

        self.logger = Logger("computer", verbosity)
        self.logger.info("Initializing Computer...")

        # Fall back to environment variable for api_key if not provided
        if api_key is None:
            api_key = os.environ.get("CUA_API_KEY")

        if not image:
            if os_type == "macos":
                image = "macos-sequoia-cua:latest"

@@ -31,21 +31,26 @@ class CloudProvider(BaseVMProvider):

    def __init__(
        self,
        api_key: str,
        api_key: Optional[str] = None,
        verbose: bool = False,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: API key for authentication
            api_key: API key for authentication (defaults to CUA_API_KEY environment variable)
            name: Name of the VM
            verbose: Enable verbose logging
        """
        assert api_key, "api_key required for CloudProvider"
        # Fall back to environment variable if api_key not provided
        if api_key is None:
            api_key = os.getenv("CUA_API_KEY")
        assert api_key, "api_key required for CloudProvider (provide via parameter or CUA_API_KEY environment variable)"
        self.api_key = api_key
        self.verbose = verbose
        self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
        # Host caching dictionary: {vm_name: host_string}
        self._host_cache: Dict[str, str] = {}
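With this change the key can come from the environment instead of the constructor. A sketch (the key value is a placeholder):

import os

os.environ["CUA_API_KEY"] = "sk-..."  # placeholder value
provider = CloudProvider()            # falls back to CUA_API_KEY; asserts if neither is set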

    @property
    def provider_type(self) -> VMProviderType:
@@ -60,12 +65,12 @@ class CloudProvider(BaseVMProvider):
    async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
        """Get VM information by querying the VM status endpoint.

        - Build hostname via get_ip(name) → "{name}.containers.cloud.trycua.com"
        - Build hostname via _get_host_for_vm(name) using cached host or fallback
        - Probe https://{hostname}:8443/status with a short timeout
        - If JSON contains a "status" field, return it; otherwise infer
        - Fallback to DNS resolve check to distinguish unknown vs not_found
        """
        hostname = await self.get_ip(name=name)
        hostname = await self._get_host_for_vm(name)

        # Try HTTPS probe to the computer-server status endpoint (8443)
        try:
@@ -118,8 +123,20 @@ class CloudProvider(BaseVMProvider):
            vm = dict(item) if isinstance(item, dict) else {}
            name = vm.get("name")
            password = vm.get("password")
            api_host = vm.get("host")  # Read host from API response

            if isinstance(name, str) and name:
                host = f"{name}.containers.cloud.trycua.com"
                # Use host from API if available, otherwise fallback to legacy format
                if isinstance(api_host, str) and api_host:
                    host = api_host
                    # Cache the host for this VM
                    self._host_cache[name] = host
                else:
                    # Legacy fallback
                    host = f"{name}.containers.cloud.trycua.com"
                    # Cache the legacy host
                    self._host_cache[name] = host

                # api_url: always set if missing
                if not vm.get("api_url"):
                    vm["api_url"] = f"https://{host}:8443"
@@ -227,15 +244,73 @@ class CloudProvider(BaseVMProvider):
            "message": "update_vm not supported by public API",
        }

    async def _get_host_for_vm(self, name: str) -> str:
        """
        Get the host for a VM, trying multiple approaches:
        1. Check cache first
        2. Try to refresh cache by calling list_vms
        3. Try .sandbox.cua.ai format
        4. Fallback to legacy .containers.cloud.trycua.com format

        Args:
            name: VM name

        Returns:
            Host string for the VM
        """
        # Check cache first
        if name in self._host_cache:
            return self._host_cache[name]

        # Try to refresh cache by calling list_vms
        try:
            await self.list_vms()
            # Check cache again after refresh
            if name in self._host_cache:
                return self._host_cache[name]
        except Exception as e:
            logger.warning(f"Failed to refresh VM list for host lookup: {e}")

        # Try .sandbox.cua.ai format first
        sandbox_host = f"{name}.sandbox.cua.ai"
        if await self._test_host_connectivity(sandbox_host):
            self._host_cache[name] = sandbox_host
            return sandbox_host

        # Fallback to legacy format
        legacy_host = f"{name}.containers.cloud.trycua.com"
        # Cache the legacy host
        self._host_cache[name] = legacy_host
        return legacy_host

    async def _test_host_connectivity(self, hostname: str) -> bool:
        """
        Test if a host is reachable by trying to connect to its status endpoint.

        Args:
            hostname: Host to test

        Returns:
            True if host is reachable, False otherwise
        """
        try:
            timeout = aiohttp.ClientTimeout(total=2)  # Short timeout for connectivity test
            async with aiohttp.ClientSession(timeout=timeout) as session:
                url = f"https://{hostname}:8443/status"
                async with session.get(url, allow_redirects=False) as resp:
                    # Any response (even error) means the host is reachable
                    return True
        except Exception:
            return False

    async def get_ip(
        self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2
    ) -> str:
        """
        Return the VM's IP address as '{container_name}.containers.cloud.trycua.com'.
        Uses the provided 'name' argument (the VM name requested by the caller),
        falling back to self.name only if 'name' is None.
        Retries up to 3 times with retry_delay seconds if hostname is not available.
        Return the VM's host address, trying to use cached host from API or falling back to legacy format.
        Uses the provided 'name' argument (the VM name requested by the caller).
        """
        if name is None:
            raise ValueError("VM name is required for CloudProvider.get_ip")
        return f"{name}.containers.cloud.trycua.com"

        return await self._get_host_for_vm(name)
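The resolution order get_ip now delegates to, as a sketch (the VM name is illustrative):

# await provider.get_ip(name="my-vm") tries, in order:
#   1. self._host_cache["my-vm"], if present
#   2. a cache refresh via list_vms() (which stores hosts reported by the API)
#   3. "my-vm.sandbox.cua.ai", if its :8443/status endpoint answers
#   4. "my-vm.containers.cloud.trycua.com" as the cached legacy fallback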

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"

[project]
name = "cua-computer"
version = "0.4.11"
version = "0.4.17"
description = "Computer-Use Interface (CUI) framework powering Cua"
readme = "README.md"
authors = [

@@ -44,13 +44,12 @@ class PostHogTelemetryClient:
    @classmethod
    def is_telemetry_enabled(cls) -> bool:
        """True if telemetry is currently active for this process."""
        return (
            # Legacy opt-out flag
            os.environ.get("CUA_TELEMETRY", "").lower() != "off"
            # Opt-in flag (defaults to enabled)
            and os.environ.get("CUA_TELEMETRY_ENABLED", "true").lower()
            in {"1", "true", "yes", "on"}
        )
        return os.environ.get("CUA_TELEMETRY_ENABLED", "true").lower() in {
            "1",
            "true",
            "yes",
            "on",
        }
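After this simplification only CUA_TELEMETRY_ENABLED is consulted. A sketch:

import os

os.environ["CUA_TELEMETRY_ENABLED"] = "false"
assert PostHogTelemetryClient.is_telemetry_enabled() is False
os.environ["CUA_TELEMETRY_ENABLED"] = "true"  # any of "1", "true", "yes", "on"
assert PostHogTelemetryClient.is_telemetry_enabled() is True
# The legacy CUA_TELEMETRY=off opt-out is no longer honored.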

    def _get_or_create_installation_id(self) -> str:
        """Get or create a unique installation ID that persists across runs.

@@ -24,15 +24,7 @@ class TestTelemetryEnabled:

        assert is_telemetry_enabled() is True

    def test_telemetry_disabled_with_legacy_flag(self, monkeypatch):
        """Test that telemetry can be disabled with legacy CUA_TELEMETRY=off."""
        monkeypatch.setenv("CUA_TELEMETRY", "off")

        from core.telemetry import is_telemetry_enabled

        assert is_telemetry_enabled() is False

    def test_telemetry_disabled_with_new_flag(self, monkeypatch):
    def test_telemetry_disabled_with_flag(self, monkeypatch):
        """Test that telemetry can be disabled with CUA_TELEMETRY_ENABLED=false."""
        monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")


@@ -133,7 +133,7 @@ await cleanup_session(ctx, "session-to-cleanup")

### Environment Variables

- `CUA_MODEL_NAME`: Model to use (default: `anthropic/claude-3-5-sonnet-20241022`)
- `CUA_MODEL_NAME`: Model to use (default: `anthropic/claude-sonnet-4-5-20250929`)
- `CUA_MAX_IMAGES`: Maximum images to keep (default: `3`)

### Session Manager Configuration

63
libs/python/mcp-server/QUICK_TEST_COMMANDS.sh
Executable file
@@ -0,0 +1,63 @@
#!/bin/bash
# Quick Test Commands for MCP Server Local Desktop Option
# Run these commands to test the implementation

set -e  # Exit on error

echo "======================================================================"
echo "Testing MCP Server Local Desktop Option"
echo "======================================================================"
echo ""

# Change to repo root
cd "$(dirname "$0")/.."

# Test 1: Quick Logic Test (No setup required)
echo "Test 1: Quick Logic Test (No setup required)"
echo "----------------------------------------------------------------------"
python tests/quick_test_local_option.py
echo ""

# Test 2: Automated Tests (Requires pytest and packages)
echo "Test 2: Automated Tests (Requires pytest and packages installed)"
echo "----------------------------------------------------------------------"
if command -v pytest &> /dev/null; then
    echo "Running pytest..."
    pytest tests/test_mcp_server_local_option.py -v || echo "Note: Some tests may require full setup"
else
    echo "⚠️  pytest not found. Install with: pip install pytest"
fi
echo ""

# Test 3: Existing MCP server tests
echo "Test 3: Existing MCP Server Tests"
echo "----------------------------------------------------------------------"
if command -v pytest &> /dev/null; then
    echo "Running existing session management tests..."
    pytest tests/test_mcp_server_session_management.py -v || echo "Note: Some tests may fail if dependencies are missing"
else
    echo "⚠️  pytest not found. Install with: pip install pytest"
fi
echo ""

# Summary
echo "======================================================================"
echo "Test Summary"
echo "======================================================================"
echo "✅ Quick logic test completed"
echo ""
echo "Next steps for comprehensive testing:"
echo "1. Install dependencies:"
echo "   pip install -e libs/python/core"
echo "   pip install -e libs/python/computer"
echo "   pip install -e libs/python/agent"
echo "   pip install -e libs/python/mcp-server"
echo "   pip install -e libs/python/computer-server"
echo ""
echo "2. For manual end-to-end testing, see:"
echo "   tests/MANUAL_TEST_LOCAL_OPTION.md"
echo ""
echo "3. For detailed testing info, see:"
echo "   tests/TESTING_SUMMARY.md"
echo ""

@@ -44,7 +44,7 @@ Add this to your MCP client configuration:
      "args": [
        "bash",
        "-lc",
        "export CUA_MODEL_NAME='anthropic/claude-3-5-sonnet-20241022'; ~/.cua/start_mcp_server.sh"
        "export CUA_MODEL_NAME='anthropic/claude-sonnet-4-5-20250929'; ~/.cua/start_mcp_server.sh"
      ]
    }
  }

@@ -156,7 +156,7 @@ def serve() -> FastMCP:

    try:
        # Get model name
        model_name = os.getenv("CUA_MODEL_NAME", "anthropic/claude-3-5-sonnet-20241022")
        model_name = os.getenv("CUA_MODEL_NAME", "anthropic/claude-sonnet-4-5-20250929")
        logger.info(f"Using model: {model_name}")

        # Create agent with the new v0.4.x API

@@ -10,6 +10,7 @@ This module provides:

import asyncio
import logging
import os
import time
import uuid
import weakref
@@ -57,7 +58,14 @@ class ComputerPool:
        logger.debug("Creating new computer instance")
        from computer import Computer

        computer = Computer(verbosity=logging.INFO)
        # Check if we should use host computer server
        use_host = os.getenv("CUA_USE_HOST_COMPUTER_SERVER", "false").lower() in (
            "true",
            "1",
            "yes",
        )

        computer = Computer(verbosity=logging.INFO, use_host_computer_server=use_host)
        await computer.run()
        self._in_use.add(computer)
        return computer
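A hedged sketch of exercising this toggle (the env var name comes from this hunk; the rest is illustrative):

import os

os.environ["CUA_USE_HOST_COMPUTER_SERVER"] = "true"
# Subsequent pool.acquire() calls construct
# Computer(verbosity=logging.INFO, use_host_computer_server=True),
# targeting a computer-server running on the host instead of a VM.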

244
libs/python/mcp-server/quick_test_local_option.py
Executable file
@@ -0,0 +1,244 @@
#!/usr/bin/env python3
"""
Quick test to verify the local desktop option logic without full setup.

This script tests the environment variable parsing and logic flow
without requiring VMs, computer-server, or MCP clients to be running.
"""

import os
import sys


def test_env_var_parsing():
    """Test that environment variable is parsed correctly."""
    print("Testing CUA_USE_HOST_COMPUTER_SERVER environment variable parsing...")
    print("-" * 60)

    test_cases = [
        # (env_value, expected_result, description)
        ("true", True, "lowercase 'true'"),
        ("True", True, "capitalized 'True'"),
        ("TRUE", True, "uppercase 'TRUE'"),
        ("1", True, "numeric '1'"),
        ("yes", True, "lowercase 'yes'"),
        ("Yes", True, "capitalized 'Yes'"),
        ("false", False, "lowercase 'false'"),
        ("False", False, "capitalized 'False'"),
        ("FALSE", False, "uppercase 'FALSE'"),
        ("0", False, "numeric '0'"),
        ("no", False, "lowercase 'no'"),
        ("", False, "empty string"),
        ("random", False, "random value"),
        (None, False, "not set (None)"),
    ]

    passed = 0
    failed = 0

    for env_value, expected, description in test_cases:
        # Simulate the logic from session_manager.py line 59
        if env_value is None:
            actual = os.getenv("CUA_USE_HOST_COMPUTER_SERVER", "false").lower() in (
                "true",
                "1",
                "yes",
            )
        else:
            os.environ["CUA_USE_HOST_COMPUTER_SERVER"] = env_value
            actual = os.getenv("CUA_USE_HOST_COMPUTER_SERVER", "false").lower() in (
                "true",
                "1",
                "yes",
            )

        status = "✓ PASS" if actual == expected else "✗ FAIL"
        if actual == expected:
            passed += 1
        else:
            failed += 1

        print(
            f"{status} | Value: {env_value!r:15} | Expected: {expected!s:5} | Got: {actual!s:5} | {description}"
        )

    # Clean up
    os.environ.pop("CUA_USE_HOST_COMPUTER_SERVER", None)

    print("-" * 60)
    print(f"Results: {passed} passed, {failed} failed")
    return failed == 0


def test_session_manager_logic():
    """Test the logic flow in session_manager.py without actual Computer creation."""
    print("\nTesting session_manager.py logic flow...")
    print("-" * 60)

    # Read the actual session_manager.py to verify the logic
    import pathlib

    session_manager_path = (
        pathlib.Path(__file__).parent.parent
        / "libs"
        / "python"
        / "mcp-server"
        / "mcp_server"
        / "session_manager.py"
    )

    if not session_manager_path.exists():
        print(f"✗ FAIL | session_manager.py not found at {session_manager_path}")
        return False

    content = session_manager_path.read_text()

    # Check for the key logic
    checks = [
        ('os.getenv("CUA_USE_HOST_COMPUTER_SERVER"', "Environment variable check present"),
        ("use_host_computer_server=use_host", "use_host_computer_server parameter passed"),
        ("Computer(", "Computer instantiation present"),
    ]

    all_checks_passed = True
    for check_str, description in checks:
        if check_str in content:
            print(f"✓ PASS | {description}")
        else:
            print(f"✗ FAIL | {description} - not found")
            all_checks_passed = False

    print("-" * 60)
    return all_checks_passed


def test_documentation_consistency():
    """Verify documentation mentions the new feature."""
    print("\nTesting documentation consistency...")
    print("-" * 60)

    import pathlib

    docs_to_check = [
        ("configuration.mdx", "CUA_USE_HOST_COMPUTER_SERVER"),
        ("usage.mdx", "Targeting Your Local Desktop"),
    ]

    docs_path = (
        pathlib.Path(__file__).parent.parent
        / "docs"
        / "content"
        / "docs"
        / "libraries"
        / "mcp-server"
    )

    all_docs_ok = True
    for doc_file, expected_content in docs_to_check:
        doc_path = docs_path / doc_file
        if not doc_path.exists():
            print(f"✗ FAIL | {doc_file} not found")
            all_docs_ok = False
            continue

        content = doc_path.read_text()
        if expected_content in content:
            print(f"✓ PASS | {doc_file} contains '{expected_content}'")
        else:
            print(f"✗ FAIL | {doc_file} missing '{expected_content}'")
            all_docs_ok = False

    print("-" * 60)
    return all_docs_ok


def print_usage_examples():
    """Print usage examples for both modes."""
    print("\n" + "=" * 60)
    print("USAGE EXAMPLES")
    print("=" * 60)

    print("\n1. DEFAULT MODE (VM):")
    print("-" * 60)
    print(
        """
{
  "mcpServers": {
    "cua-agent": {
      "command": "/bin/bash",
      "args": ["~/.cua/start_mcp_server.sh"],
      "env": {
        "CUA_MODEL_NAME": "anthropic/claude-sonnet-4-5-20250929"
      }
    }
  }
}

Note: CUA_USE_HOST_COMPUTER_SERVER is not set, so VM mode is used (safe).
"""
    )

    print("\n2. LOCAL DESKTOP MODE:")
    print("-" * 60)
    print(
        """
Step 1: Start computer-server locally:
    python -m computer_server

Step 2: Configure MCP client:
{
  "mcpServers": {
    "cua-agent": {
      "command": "/bin/bash",
      "args": ["~/.cua/start_mcp_server.sh"],
      "env": {
        "CUA_MODEL_NAME": "anthropic/claude-sonnet-4-5-20250929",
        "CUA_USE_HOST_COMPUTER_SERVER": "true"
      }
    }
  }
}

⚠️  WARNING: AI will have direct access to your desktop!
"""
    )


def main():
    """Run all quick tests."""
    print("=" * 60)
    print("QUICK TEST: MCP Server Local Desktop Option")
    print("=" * 60)
    print()

    results = []

    # Run tests
    results.append(("Environment Variable Parsing", test_env_var_parsing()))
    results.append(("Session Manager Logic", test_session_manager_logic()))
    results.append(("Documentation Consistency", test_documentation_consistency()))

    # Print summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for test_name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"{status} | {test_name}")

    all_passed = all(result for _, result in results)

    if all_passed:
        print("\n🎉 All quick tests passed!")
        print_usage_examples()
        print("\nNext steps:")
        print("1. Run full automated tests: pytest tests/test_mcp_server_local_option.py")
        print("2. Follow manual testing guide: tests/MANUAL_TEST_LOCAL_OPTION.md")
        return 0
    else:
        print("\n❌ Some tests failed. Please review the output above.")
        return 1


if __name__ == "__main__":
    sys.exit(main())

138
libs/python/mcp-server/test_mcp_server_local_option.py
Normal file
@@ -0,0 +1,138 @@
"""
Test script to verify MCP Server local desktop option works correctly.

This test verifies:
1. Default behavior: Computer uses VM
2. New behavior: Computer uses host when CUA_USE_HOST_COMPUTER_SERVER=true
"""

import asyncio
import os
import sys
from pathlib import Path

# Add the mcp-server module to path
mcp_server_path = Path(__file__).parent.parent / "libs" / "python" / "mcp-server"
sys.path.insert(0, str(mcp_server_path.parent.parent.parent / "libs" / "python"))

import pytest


@pytest.mark.asyncio
async def test_default_vm_mode():
    """Test that the default mode uses VM (not host computer server)."""
    # Ensure environment variable is not set or is false
    os.environ.pop("CUA_USE_HOST_COMPUTER_SERVER", None)

    from mcp_server.session_manager import ComputerPool

    pool = ComputerPool(max_size=1)

    try:
        computer = await pool.acquire()

        # Verify the computer was initialized
        assert computer is not None

        # Check that use_host_computer_server was set to False (default)
        # This should start a VM
        print("✓ Default mode: Computer initialized (VM mode expected)")

        await pool.release(computer)

    finally:
        await pool.shutdown()


@pytest.mark.asyncio
async def test_local_desktop_mode():
    """Test that setting CUA_USE_HOST_COMPUTER_SERVER=true uses host."""
    # Set environment variable to true
    os.environ["CUA_USE_HOST_COMPUTER_SERVER"] = "true"

    # Need to reload module to pick up new env var
    import importlib

    import mcp_server.session_manager
    from mcp_server.session_manager import ComputerPool

    importlib.reload(mcp_server.session_manager)

    pool = mcp_server.session_manager.ComputerPool(max_size=1)

    try:
        computer = await pool.acquire()

        # Verify the computer was initialized
        assert computer is not None

        # Check that use_host_computer_server was set to True
        print("✓ Local mode: Computer initialized (host mode expected)")

        await pool.release(computer)

    finally:
        await pool.shutdown()
        # Clean up env var
        os.environ.pop("CUA_USE_HOST_COMPUTER_SERVER", None)


@pytest.mark.asyncio
async def test_env_var_parsing():
    """Test that various values of CUA_USE_HOST_COMPUTER_SERVER are parsed correctly."""
    test_cases = [
        ("true", True),
        ("True", True),
        ("TRUE", True),
        ("1", True),
        ("yes", True),
        ("false", False),
        ("False", False),
        ("FALSE", False),
        ("0", False),
        ("no", False),
        ("", False),
        ("random", False),
    ]

    for value, expected in test_cases:
        os.environ["CUA_USE_HOST_COMPUTER_SERVER"] = value

        # Check parsing logic
        use_host = os.getenv("CUA_USE_HOST_COMPUTER_SERVER", "false").lower() in (
            "true",
            "1",
            "yes",
        )

        assert (
            use_host == expected
        ), f"Failed for value '{value}': expected {expected}, got {use_host}"
        print(f"✓ Env var '{value}' correctly parsed as {expected}")

    os.environ.pop("CUA_USE_HOST_COMPUTER_SERVER", None)


if __name__ == "__main__":
    print("Testing MCP Server Local Desktop Option")
    print("=" * 60)

    print("\n1. Testing environment variable parsing...")
    asyncio.run(test_env_var_parsing())

    print("\n2. Testing default VM mode...")
    try:
        asyncio.run(test_default_vm_mode())
    except Exception as e:
        print(f"✗ Default VM mode test failed: {e}")
        print("Note: This may require lume/VM setup to fully test")

    print("\n3. Testing local desktop mode...")
    try:
        asyncio.run(test_local_desktop_mode())
    except Exception as e:
        print(f"✗ Local desktop mode test failed: {e}")
        print("Note: This may require computer-server to be running locally")

    print("\n" + "=" * 60)
    print("Tests completed!")