Mirror of https://github.com/trycua/computer.git (synced 2026-01-01 02:50:15 -06:00)

Commit: Merge upstream/main to resolve conflicts with trycua/cua
New file: libs/python/agent/.bumpversion.cfg (+10 lines)
@@ -0,0 +1,10 @@
[bumpversion]
current_version = 0.4.35
commit = True
tag = True
tag_name = agent-v{new_version}
message = Bump cua-agent to v{new_version}

[bumpversion:file:pyproject.toml]
search = version = "{current_version}"
replace = version = "{new_version}"
@@ -8,10 +8,11 @@
  </picture>
</div>

[](#)
[](#)
[](https://discord.com/invite/mVnXXpdE85)
[](https://pypi.org/project/cua-computer/)

</h1>
</div>

@@ -47,7 +48,7 @@ async def main():
        name=os.getenv("CUA_CONTAINER_NAME"),
        api_key=os.getenv("CUA_API_KEY")
    ) as computer:

        # Create agent
        agent = ComputerAgent(
            model="anthropic/claude-3-5-sonnet-20241022",
@@ -56,10 +57,10 @@ async def main():
            trajectory_dir="trajectories",
            max_trajectory_budget=5.0  # $5 budget limit
        )

        # Run agent
        messages = [{"role": "user", "content": "Take a screenshot and tell me what you see"}]

        async for result in agent.run(messages):
            for item in result["output"]:
                if item["type"] == "message":
@@ -84,4 +85,4 @@ if __name__ == "__main__":

## License

MIT License - see LICENSE file for details.
@@ -5,19 +5,13 @@ agent - Decorator-based Computer Use Agent with liteLLM integration

import logging
import sys

from .decorators import register_agent
from .agent import ComputerAgent
from .types import Messages, AgentResponse

# Import loops to register them
from . import loops
from .agent import ComputerAgent
from .decorators import register_agent
from .types import AgentResponse, Messages

__all__ = [
    "register_agent",
    "ComputerAgent",
    "Messages",
    "AgentResponse"
]
__all__ = ["register_agent", "ComputerAgent", "Messages", "AgentResponse"]

__version__ = "0.4.0"
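For orientation, this hunk only reorders imports and collapses `__all__`; the package's public surface is unchanged. A minimal sketch of consuming it (assuming the distribution installs the `agent` package, the same module `python -m agent.cli` refers to elsewhere in this commit):

```python
# Minimal consumption sketch; package name `agent` is an assumption.
import agent
from agent import AgentResponse, ComputerAgent, Messages, register_agent

print(agent.__version__)  # "0.4.0" in this commit
```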
@@ -5,8 +5,9 @@ Usage:
    python -m agent.cli <model_string>
"""

import sys
import asyncio
import sys

from .cli import main

if __name__ == "__main__":
@@ -2,27 +2,30 @@ import asyncio
import functools
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse

# Try to import HuggingFace dependencies
try:
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

from .models import load_model as load_model_handler


class HuggingFaceLocalAdapter(CustomLLM):
    """HuggingFace Local Adapter for running vision-language models locally."""

    def __init__(self, device: str = "auto", trust_remote_code: bool = False, **kwargs):
        """Initialize the adapter.

        Args:
            device: Device to load model on ("auto", "cuda", "cpu", etc.)
            trust_remote_code: Whether to trust remote code
@@ -34,129 +37,120 @@ class HuggingFaceLocalAdapter(CustomLLM):
        # Cache for model handlers keyed by model_name
        self._handlers: Dict[str, Any] = {}
        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool

    def _get_handler(self, model_name: str):
        """Get or create a model handler for the given model name."""
        if model_name not in self._handlers:
            self._handlers[model_name] = load_model_handler(model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code)
            self._handlers[model_name] = load_model_handler(
                model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code
            )
        return self._handlers[model_name]

    def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert OpenAI format messages to HuggingFace format.

        Args:
            messages: Messages in OpenAI format

        Returns:
            Messages in HuggingFace format
        """
        converted_messages = []

        for message in messages:
            converted_message = {
                "role": message["role"],
                "content": []
            }

            converted_message = {"role": message["role"], "content": []}

            content = message.get("content", [])
            if isinstance(content, str):
                # Simple text content
                converted_message["content"].append({
                    "type": "text",
                    "text": content
                })
                converted_message["content"].append({"type": "text", "text": content})
            elif isinstance(content, list):
                # Multi-modal content
                for item in content:
                    if item.get("type") == "text":
                        converted_message["content"].append({
                            "type": "text",
                            "text": item.get("text", "")
                        })
                        converted_message["content"].append(
                            {"type": "text", "text": item.get("text", "")}
                        )
                    elif item.get("type") == "image_url":
                        # Convert image_url format to image format
                        image_url = item.get("image_url", {}).get("url", "")
                        converted_message["content"].append({
                            "type": "image",
                            "image": image_url
                        })

                        converted_message["content"].append({"type": "image", "image": image_url})

            converted_messages.append(converted_message)

        return converted_messages

    def _generate(self, **kwargs) -> str:
        """Generate response using the local HuggingFace model.

        Args:
            **kwargs: Keyword arguments containing messages and model info

        Returns:
            Generated text response
        """
        if not HF_AVAILABLE:
            raise ImportError(
                "HuggingFace transformers dependencies not found. "
                "Please install with: pip install \"cua-agent[uitars-hf]\""
                'Please install with: pip install "cua-agent[uitars-hf]"'
            )

        # Extract messages and model from kwargs
        messages = kwargs.get('messages', [])
        model_name = kwargs.get('model', 'ByteDance-Seed/UI-TARS-1.5-7B')
        max_new_tokens = kwargs.get('max_tokens', 128)

        messages = kwargs.get("messages", [])
        model_name = kwargs.get("model", "ByteDance-Seed/UI-TARS-1.5-7B")
        max_new_tokens = kwargs.get("max_tokens", 128)

        # Warn about ignored kwargs
        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
        ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
        if ignored_kwargs:
            warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")

        # Convert messages to HuggingFace format
        hf_messages = self._convert_messages(messages)

        # Delegate to model handler
        handler = self._get_handler(model_name)
        generated_text = handler.generate(hf_messages, max_new_tokens=max_new_tokens)
        return generated_text

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        generated_text = self._generate(**kwargs)

        return completion(
            model=f"huggingface-local/{kwargs['model']}",
            mock_response=generated_text,
        )

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
            self._executor, functools.partial(self._generate, **kwargs)
        )

        return await acompletion(
            model=f"huggingface-local/{kwargs['model']}",
            mock_response=generated_text,
        )

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Returns:
            Iterator of GenericStreamingChunk
        """
        generated_text = self._generate(**kwargs)

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
@@ -165,22 +159,21 @@ class HuggingFaceLocalAdapter(CustomLLM):
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Returns:
            AsyncIterator of GenericStreamingChunk
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
            self._executor, functools.partial(self._generate, **kwargs)
        )

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
@@ -189,5 +182,5 @@ class HuggingFaceLocalAdapter(CustomLLM):
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk
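For context, a `CustomLLM` subclass like this is typically wired into LiteLLM through `litellm.custom_provider_map`, so that the `huggingface-local/...` model strings above route to the adapter. A hedged sketch (the provider prefix matches the `model=` strings in the diff; the import path and model name are assumptions):

```python
import litellm

from agent.adapters import HuggingFaceLocalAdapter  # import path assumed

adapter = HuggingFaceLocalAdapter(device="auto")

# Register the adapter so litellm routes "huggingface-local/..." models to it.
litellm.custom_provider_map = [
    {"provider": "huggingface-local", "custom_handler": adapter}
]

response = litellm.completion(
    model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
    messages=[{"role": "user", "content": "Describe this screenshot."}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```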
@@ -1,22 +1,23 @@
import os
import asyncio
import os
from typing import Any, AsyncIterator, Dict, Iterator, List

import requests
from typing import List, Dict, Any, Iterator, AsyncIterator
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse


class HumanAdapter(CustomLLM):
    """Human Adapter for human-in-the-loop completions.

    This adapter sends completion requests to a human completion server
    where humans can review and respond to AI requests.
    """

    def __init__(self, base_url: str | None = None, timeout: float = 300.0, **kwargs):
        """Initialize the human adapter.

        Args:
            base_url: Base URL for the human completion server.
                Defaults to HUMAN_BASE_URL environment variable or http://localhost:8002
@@ -24,60 +25,58 @@ class HumanAdapter(CustomLLM):
            **kwargs: Additional arguments
        """
        super().__init__()
        self.base_url = base_url or os.getenv('HUMAN_BASE_URL', 'http://localhost:8002')
        self.base_url = base_url or os.getenv("HUMAN_BASE_URL", "http://localhost:8002")
        self.timeout = timeout

        # Ensure base_url doesn't end with slash
        self.base_url = self.base_url.rstrip('/')
        self.base_url = self.base_url.rstrip("/")

    def _queue_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
        """Queue a completion request and return the call ID.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Call ID for tracking the request

        Raises:
            Exception: If queueing fails
        """
        try:
            response = requests.post(
                f"{self.base_url}/queue",
                json={"messages": messages, "model": model},
                timeout=10
                f"{self.base_url}/queue", json={"messages": messages, "model": model}, timeout=10
            )
            response.raise_for_status()
            return response.json()["id"]
        except requests.RequestException as e:
            raise Exception(f"Failed to queue completion request: {e}")

    def _wait_for_completion(self, call_id: str) -> Dict[str, Any]:
        """Wait for human to complete the call.

        Args:
            call_id: ID of the queued completion call

        Returns:
            Dict containing response and/or tool_calls

        Raises:
            TimeoutError: If timeout is exceeded
            Exception: If completion fails
        """
        import time

        start_time = time.time()

        while True:
            try:
                # Check status
                status_response = requests.get(f"{self.base_url}/status/{call_id}")
                status_response.raise_for_status()
                status_data = status_response.json()

                if status_data["status"] == "completed":
                    result = {}
                    if "response" in status_data and status_data["response"]:
@@ -88,38 +87,41 @@ class HumanAdapter(CustomLLM):
                elif status_data["status"] == "failed":
                    error_msg = status_data.get("error", "Unknown error")
                    raise Exception(f"Completion failed: {error_msg}")

                # Check timeout
                if time.time() - start_time > self.timeout:
                    raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")
                    raise TimeoutError(
                        f"Timeout waiting for human response after {self.timeout} seconds"
                    )

                # Wait before checking again
                time.sleep(1.0)

            except requests.RequestException as e:
                if time.time() - start_time > self.timeout:
                    raise TimeoutError(f"Timeout waiting for human response: {e}")
                # Continue trying if we haven't timed out
                time.sleep(1.0)

    async def _async_wait_for_completion(self, call_id: str) -> Dict[str, Any]:
        """Async version of wait_for_completion.

        Args:
            call_id: ID of the queued completion call

        Returns:
            Dict containing response and/or tool_calls

        Raises:
            TimeoutError: If timeout is exceeded
            Exception: If completion fails
        """
        import aiohttp
        import time

        import aiohttp

        start_time = time.time()

        async with aiohttp.ClientSession() as session:
            while True:
                try:
@@ -127,7 +129,7 @@ class HumanAdapter(CustomLLM):
                    async with session.get(f"{self.base_url}/status/{call_id}") as response:
                        response.raise_for_status()
                        status_data = await response.json()

                    if status_data["status"] == "completed":
                        result = {}
                        if "response" in status_data and status_data["response"]:
@@ -138,166 +140,158 @@ class HumanAdapter(CustomLLM):
                    elif status_data["status"] == "failed":
                        error_msg = status_data.get("error", "Unknown error")
                        raise Exception(f"Completion failed: {error_msg}")

                    # Check timeout
                    if time.time() - start_time > self.timeout:
                        raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")
                        raise TimeoutError(
                            f"Timeout waiting for human response after {self.timeout} seconds"
                        )

                    # Wait before checking again
                    await asyncio.sleep(1.0)

                except Exception as e:
                    if time.time() - start_time > self.timeout:
                        raise TimeoutError(f"Timeout waiting for human response: {e}")
                    # Continue trying if we haven't timed out
                    await asyncio.sleep(1.0)

    def _generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:
        """Generate a human response for the given messages.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Dict containing response and/or tool_calls
        """
        # Queue the completion request
        call_id = self._queue_completion(messages, model)

        # Wait for human response
        response = self._wait_for_completion(call_id)

        return response

    async def _async_generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:
    async def _async_generate_response(
        self, messages: List[Dict[str, Any]], model: str
    ) -> Dict[str, Any]:
        """Async version of _generate_response.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Dict containing response and/or tool_calls
        """
        # Queue the completion request (sync operation)
        call_id = self._queue_completion(messages, model)

        # Wait for human response (async)
        response = await self._async_wait_for_completion(call_id)

        return response

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')

        messages = kwargs.get("messages", [])
        model = kwargs.get("model", "human")

        # Generate human response
        human_response_data = self._generate_response(messages, model)

        # Create ModelResponse with proper structure
        from litellm.types.utils import ModelResponse, Choices, Message
        import uuid
        import time

        import uuid

        from litellm.types.utils import Choices, Message, ModelResponse

        # Create message content based on response type
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Tool calls response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", ""),
                tool_calls=human_response_data["tool_calls"]
                tool_calls=human_response_data["tool_calls"],
            )
        else:
            # Text response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", "")
            )

        choice = Choices(
            finish_reason="stop",
            index=0,
            message=message
        )

            message = Message(role="assistant", content=human_response_data.get("response", ""))

        choice = Choices(finish_reason="stop", index=0, message=message)

        result = ModelResponse(
            id=f"human-{uuid.uuid4()}",
            choices=[choice],
            created=int(time.time()),
            model=f"human/{model}",
            object="chat.completion"
            object="chat.completion",
        )

        return result

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')

        messages = kwargs.get("messages", [])
        model = kwargs.get("model", "human")

        # Generate human response
        human_response_data = await self._async_generate_response(messages, model)

        # Create ModelResponse with proper structure
        from litellm.types.utils import ModelResponse, Choices, Message
        import uuid
        import time

        import uuid

        from litellm.types.utils import Choices, Message, ModelResponse

        # Create message content based on response type
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Tool calls response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", ""),
                tool_calls=human_response_data["tool_calls"]
                tool_calls=human_response_data["tool_calls"],
            )
        else:
            # Text response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", "")
            )

        choice = Choices(
            finish_reason="stop",
            index=0,
            message=message
        )

            message = Message(role="assistant", content=human_response_data.get("response", ""))

        choice = Choices(finish_reason="stop", index=0, message=message)

        result = ModelResponse(
            id=f"human-{uuid.uuid4()}",
            choices=[choice],
            created=int(time.time()),
            model=f"human/{model}",
            object="chat.completion"
            object="chat.completion",
        )

        return result

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Yields:
            Streaming chunks with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')

        messages = kwargs.get("messages", [])
        model = kwargs.get("model", "human")

        # Generate human response
        human_response_data = self._generate_response(messages, model)

        import time

        # Handle tool calls vs text response
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Stream tool calls as a single chunk
@@ -319,22 +313,26 @@ class HumanAdapter(CustomLLM):
            "is_finished": True,
            "text": response_text,
            "tool_use": None,
            "usage": {"completion_tokens": len(response_text.split()), "prompt_tokens": 0, "total_tokens": len(response_text.split())},
            "usage": {
                "completion_tokens": len(response_text.split()),
                "prompt_tokens": 0,
                "total_tokens": len(response_text.split()),
            },
        }
        yield generic_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Yields:
            Streaming chunks with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')

        messages = kwargs.get("messages", [])
        model = kwargs.get("model", "human")

        # Generate human response
        human_response = await self._async_generate_response(messages, model)

        # Return as single streaming chunk
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
@@ -342,7 +340,11 @@ class HumanAdapter(CustomLLM):
            "is_finished": True,
            "text": human_response,
            "tool_use": None,
            "usage": {"completion_tokens": len(human_response.split()), "prompt_tokens": 0, "total_tokens": len(human_response.split())},
            "usage": {
                "completion_tokens": len(human_response.split()),
                "prompt_tokens": 0,
                "total_tokens": len(human_response.split()),
            },
        }

        yield generic_streaming_chunk
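The server protocol visible in this adapter is small: `POST {base_url}/queue` returns a call id, and `GET {base_url}/status/{id}` is polled once per second until `status` is `completed` or `failed`, subject to `timeout`. A hedged sketch of driving the adapter directly (import path assumed; requires a running human-completion server):

```python
import os

from agent.adapters import HumanAdapter  # import path assumed

# Falls back to HUMAN_BASE_URL or http://localhost:8002 when base_url is None.
adapter = HumanAdapter(base_url=os.getenv("HUMAN_BASE_URL"), timeout=300.0)

# completion() queues the request (POST /queue), then polls GET /status/{id}
# until a human responds or `timeout` seconds elapse.
result = adapter.completion(
    messages=[{"role": "user", "content": "Should I click the Submit button?"}],
    model="human",
)
print(result.choices[0].message.content)
```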
@@ -1,24 +1,26 @@
import asyncio
import functools
import warnings
import io
import base64
import functools
import io
import math
import re
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional, Tuple, cast
from PIL import Image
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, cast

from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from PIL import Image

# Try to import MLX dependencies
try:
    import mlx.core as mx
    from mlx_vlm import load, generate
    from mlx_vlm import generate, load
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config
    from transformers.tokenization_utils import PreTrainedTokenizer

    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False
@@ -29,20 +31,28 @@ MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: float, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
@@ -70,61 +80,62 @@ def smart_resize(

class MLXVLMAdapter(CustomLLM):
    """MLX VLM Adapter for running vision-language models locally using MLX."""

    def __init__(self, **kwargs):
        """Initialize the adapter.

        Args:
            **kwargs: Additional arguments
        """
        super().__init__()

        self.models = {}  # Cache for loaded models
        self.processors = {}  # Cache for loaded processors
        self.configs = {}  # Cache for loaded configs
        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool

    def _load_model_and_processor(self, model_name: str):
        """Load model and processor if not already cached.

        Args:
            model_name: Name of the model to load

        Returns:
            Tuple of (model, processor, config)
        """
        if not MLX_AVAILABLE:
            raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")

        if model_name not in self.models:
            # Load model and processor
            model_obj, processor = load(
                model_name,
                processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
                model_name, processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
            )
            config = load_config(model_name)

            # Cache them
            self.models[model_name] = model_obj
            self.processors[model_name] = processor
            self.configs[model_name] = config

        return self.models[model_name], self.processors[model_name], self.configs[model_name]

    def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
    def _process_coordinates(
        self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]
    ) -> str:
        """Process coordinates in box tokens based on image resizing using smart_resize approach.

        Args:
            text: Text containing box tokens
            original_size: Original image size (width, height)
            model_size: Model processed image size (width, height)

        Returns:
            Text with processed coordinates
        """
        # Find all box tokens
        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"

        def process_coords(match):
            model_x, model_y = int(match.group(1)), int(match.group(2))
            # Scale coordinates from model space to original image space
@@ -132,15 +143,20 @@ class MLXVLMAdapter(CustomLLM):
            new_x = int(model_x * original_size[0] / model_size[0])  # Width
            new_y = int(model_y * original_size[1] / model_size[1])  # Height
            return f"<|box_start|>({new_x},{new_y})<|box_end|>"

        return re.sub(box_pattern, process_coords, text)

    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Image.Image], Dict[int, Tuple[int, int]], Dict[int, Tuple[int, int]]]:
    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[
        List[Dict[str, Any]],
        List[Image.Image],
        Dict[int, Tuple[int, int]],
        Dict[int, Tuple[int, int]],
    ]:
        """Convert OpenAI format messages to MLX VLM format and extract images.

        Args:
            messages: Messages in OpenAI format

        Returns:
            Tuple of (processed_messages, images, original_sizes, model_sizes)
        """
@@ -149,13 +165,10 @@ class MLXVLMAdapter(CustomLLM):
        original_sizes = {}  # Track original sizes of images for coordinate mapping
        model_sizes = {}  # Track model processed sizes
        image_index = 0

        for message in messages:
            processed_message = {
                "role": message["role"],
                "content": []
            }

            processed_message = {"role": message["role"], "content": []}

            content = message.get("content", [])
            if isinstance(content, str):
                # Simple text content
@@ -165,164 +178,163 @@ class MLXVLMAdapter(CustomLLM):
                processed_content = []
                for item in content:
                    if item.get("type") == "text":
                        processed_content.append({
                            "type": "text",
                            "text": item.get("text", "")
                        })
                        processed_content.append({"type": "text", "text": item.get("text", "")})
                    elif item.get("type") == "image_url":
                        image_url = item.get("image_url", {}).get("url", "")
                        pil_image = None

                        if image_url.startswith("data:image/"):
                            # Extract base64 data
                            base64_data = image_url.split(',')[1]
                            base64_data = image_url.split(",")[1]
                            # Convert base64 to PIL Image
                            image_data = base64.b64decode(base64_data)
                            pil_image = Image.open(io.BytesIO(image_data))
                        else:
                            # Handle file path or URL
                            pil_image = Image.open(image_url)

                        # Store original image size for coordinate mapping
                        original_size = pil_image.size
                        original_sizes[image_index] = original_size

                        # Use smart_resize to determine model size
                        # Note: smart_resize expects (height, width) but PIL gives (width, height)
                        height, width = original_size[1], original_size[0]
                        new_height, new_width = smart_resize(height, width)
                        # Store model size in (width, height) format for consistent coordinate processing
                        model_sizes[image_index] = (new_width, new_height)

                        # Resize the image using the calculated dimensions from smart_resize
                        resized_image = pil_image.resize((new_width, new_height))
                        images.append(resized_image)

                        # Add image placeholder to content
                        processed_content.append({
                            "type": "image"
                        })

                        processed_content.append({"type": "image"})

                        image_index += 1

                processed_message["content"] = processed_content

            processed_messages.append(processed_message)

        return processed_messages, images, original_sizes, model_sizes

    def _generate(self, **kwargs) -> str:
        """Generate response using the local MLX VLM model.

        Args:
            **kwargs: Keyword arguments containing messages and model info

        Returns:
            Generated text response
        """
        messages = kwargs.get('messages', [])
        model_name = kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')
        max_tokens = kwargs.get('max_tokens', 128)

        messages = kwargs.get("messages", [])
        model_name = kwargs.get("model", "mlx-community/UI-TARS-1.5-7B-4bit")
        max_tokens = kwargs.get("max_tokens", 128)

        # Warn about ignored kwargs
        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
        ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
        if ignored_kwargs:
            warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")

        # Load model and processor
        model, processor, config = self._load_model_and_processor(model_name)

        # Convert messages and extract images
        processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)

        # Process user text input with box coordinates after image processing
        # Swap original_size and model_size arguments for inverse transformation
        for msg_idx, msg in enumerate(processed_messages):
            if msg.get("role") == "user" and isinstance(msg.get("content"), str):
                content = msg.get("content", "")
                if "<|box_start|>" in content and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
                if (
                    "<|box_start|>" in content
                    and original_sizes
                    and model_sizes
                    and 0 in original_sizes
                    and 0 in model_sizes
                ):
                    orig_size = original_sizes[0]
                    model_size = model_sizes[0]
                    # Swap arguments to perform inverse transformation for user input
                    processed_messages[msg_idx]["content"] = self._process_coordinates(content, model_size, orig_size)
                    processed_messages[msg_idx]["content"] = self._process_coordinates(
                        content, model_size, orig_size
                    )

        try:
            # Format prompt according to model requirements using the processor directly
            prompt = processor.apply_chat_template(
                processed_messages,
                tokenize=False,
                add_generation_prompt=True,
                return_tensors='pt'
                processed_messages, tokenize=False, add_generation_prompt=True, return_tensors="pt"
            )
            tokenizer = cast(PreTrainedTokenizer, processor)

            # Generate response
            text_content, usage = generate(
                model,
                tokenizer,
                str(prompt),
                images,  # type: ignore
                verbose=False,
                max_tokens=max_tokens
                max_tokens=max_tokens,
            )

        except Exception as e:
            raise RuntimeError(f"Error generating response: {str(e)}") from e

        # Process coordinates in the response back to original image space
        if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
            # Get original image size and model size (using the first image)
            orig_size = original_sizes[0]
            model_size = model_sizes[0]

            # Check if output contains box tokens that need processing
            if "<|box_start|>" in text_content:
                # Process coordinates from model space back to original image space
                text_content = self._process_coordinates(text_content, orig_size, model_size)

        return text_content

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        generated_text = self._generate(**kwargs)

        result = completion(
            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
            mock_response=generated_text,
        )
        return cast(ModelResponse, result)

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
            self._executor, functools.partial(self._generate, **kwargs)
        )

        result = await acompletion(
            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
            mock_response=generated_text,
        )
        return cast(ModelResponse, result)

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Returns:
            Iterator of GenericStreamingChunk
        """
        generated_text = self._generate(**kwargs)

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
@@ -331,22 +343,21 @@ class MLXVLMAdapter(CustomLLM):
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Returns:
            AsyncIterator of GenericStreamingChunk
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
            self._executor, functools.partial(self._generate, **kwargs)
        )

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
@@ -355,5 +366,5 @@ class MLXVLMAdapter(CustomLLM):
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk
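The box-token rescaling in `_process_coordinates` is plain linear interpolation between the model's resized frame and the original frame. A standalone sketch of the same arithmetic (sizes chosen for illustration; 1920x1080 smart-resizes to 1932x1092 with factor 28):

```python
import re

BOX_PATTERN = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"

def rescale_boxes(text: str, from_size: tuple[int, int], to_size: tuple[int, int]) -> str:
    """Map (x, y) box tokens from one (width, height) frame to another."""
    def scale(match: re.Match) -> str:
        x, y = int(match.group(1)), int(match.group(2))
        new_x = int(x * to_size[0] / from_size[0])
        new_y = int(y * to_size[1] / from_size[1])
        return f"<|box_start|>({new_x},{new_y})<|box_end|>"
    return re.sub(BOX_PATTERN, scale, text)

# A model working on the 1932x1092 resized frame points at (966, 546);
# mapped back to the original 1920x1080 screenshot that is (960, 540).
print(rescale_boxes("<|box_start|>(966,546)<|box_end|>", (1932, 1092), (1920, 1080)))
```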
@@ -2,32 +2,40 @@ from typing import Optional

try:
    from transformers import AutoConfig

    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

from .generic import GenericHFModel
from .internvl import InternVLModel
from .opencua import OpenCUAModel
from .qwen2_5_vl import Qwen2_5_VLModel
from .internvl import InternVLModel


def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
    """Factory function to load and return the right model handler instance.

    - If the underlying transformers config class matches OpenCUA, return OpenCUAModel
    - Otherwise, return GenericHFModel
    """
    if not HF_AVAILABLE:
        raise ImportError(
            "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
            'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
        )
    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    cls = cfg.__class__.__name__
    print(f"cls: {cls}")
    if "OpenCUA" in cls:
        return OpenCUAModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
        return OpenCUAModel(
            model_name=model_name, device=device, trust_remote_code=trust_remote_code
        )
    elif "Qwen2_5_VL" in cls:
        return Qwen2_5_VLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
        return Qwen2_5_VLModel(
            model_name=model_name, device=device, trust_remote_code=trust_remote_code
        )
    elif "InternVL" in cls:
        return InternVLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
        return InternVLModel(
            model_name=model_name, device=device, trust_remote_code=trust_remote_code
        )
    return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
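The factory dispatches on the transformers config class name, and every handler it can return exposes the same `generate(messages, max_new_tokens)` surface, so callers are uniform. A hedged usage sketch (import path assumed from the adapter's `from .models import load_model` above):

```python
from agent.adapters.models import load_model  # import path assumed

# Dispatch is by config class name: OpenCUA* -> OpenCUAModel,
# Qwen2_5_VL* -> Qwen2_5_VLModel, InternVL* -> InternVLModel,
# anything else falls back to GenericHFModel.
handler = load_model("ByteDance-Seed/UI-TARS-1.5-7B", device="auto")

messages = [
    {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
]
print(handler.generate(messages, max_new_tokens=128))
```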
@@ -1,9 +1,10 @@
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional

# Hugging Face imports are local to avoid hard dependency at module import
try:
    import torch  # type: ignore
    from transformers import AutoModel, AutoProcessor  # type: ignore

    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class GenericHFModel:
    Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
    """

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
    def __init__(
        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
    ) -> None:
        if not HF_AVAILABLE:
            raise ImportError(
                "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
                'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
            )
        self.model_name = model_name
        self.device = device
@@ -64,7 +67,7 @@ class GenericHFModel:
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Trim prompt tokens from output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        # Decode
        output_text = self.processor.batch_decode(
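The `out_ids[len(in_ids):]` slice (only restyled here with Black's space before the colon) strips the echoed prompt tokens from each generated sequence. The same idiom with plain lists standing in for the tensors:

```python
# Toy stand-ins for inputs.input_ids and model.generate() output.
input_ids = [[101, 2054, 2003], [101, 7592]]
generated_ids = [[101, 2054, 2003, 3000, 102], [101, 7592, 2088, 102]]

# Keep only the newly generated tail of each sequence.
trimmed = [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]
print(trimmed)  # [[3000, 102], [2088, 102]]
```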
@@ -1,19 +1,22 @@
from __future__ import annotations
from typing import List, Dict, Any, Optional

from typing import Any, Dict, List, Optional

# Hugging Face imports are local to avoid hard dependency at module import
try:
    import torch  # type: ignore
    from transformers import AutoModel, AutoTokenizer  # type: ignore
    # Attempt to import InternVL's model dependencies
    import einops as _  # type: ignore
    import timm as _  # type: ignore
    from PIL import Image  # type: ignore
    import torchvision.transforms as T  # type: ignore
    from torchvision.transforms.functional import InterpolationMode  # type: ignore
    import base64  # type: ignore
    from io import BytesIO  # type: ignore

    # Attempt to import InternVL's model dependencies
    import einops as _  # type: ignore
    import requests  # type: ignore
    import timm as _  # type: ignore
    import torch  # type: ignore
    import torchvision.transforms as T  # type: ignore
    from PIL import Image  # type: ignore
    from torchvision.transforms.functional import InterpolationMode  # type: ignore
    from transformers import AutoModel, AutoTokenizer  # type: ignore

    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False
@@ -25,10 +28,12 @@ class InternVLModel:
    Provides preprocessing to support multi-turn conversations with multiple images.
    """

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
    def __init__(
        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
    ) -> None:
        if not HF_AVAILABLE:
            raise ImportError(
                "InternVL dependencies not found. Install with: pip install \"cua-agent[internvl-hf]\""
                'InternVL dependencies not found. Install with: pip install "cua-agent[internvl-hf]"'
            )
        self.model_name = model_name
        self.device = device
@@ -60,16 +65,25 @@ class InternVLModel:

    def _build_transform(self, input_size: int) -> T.Compose:
        MEAN, STD = self.IMAGENET_MEAN, self.IMAGENET_STD
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD)
        ])
        transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=MEAN, std=STD),
            ]
        )
        return transform

    def _find_closest_aspect_ratio(self, aspect_ratio: float, target_ratios: List[tuple], width: int, height: int, image_size: int):
        best_ratio_diff = float('inf')
    def _find_closest_aspect_ratio(
        self,
        aspect_ratio: float,
        target_ratios: List[tuple],
        width: int,
        height: int,
        image_size: int,
    ):
        best_ratio_diff = float("inf")
        best_ratio = (1, 1)
        area = width * height
        for ratio in target_ratios:
@@ -83,17 +97,29 @@ class InternVLModel:
                best_ratio = ratio
        return best_ratio

    def _dynamic_preprocess(self, image: Image.Image, min_num: int = 1, max_num: int = 12, image_size: int = 448, use_thumbnail: bool = True) -> List[Image.Image]:
    def _dynamic_preprocess(
        self,
        image: Image.Image,
        min_num: int = 1,
        max_num: int = 12,
        image_size: int = 448,
        use_thumbnail: bool = True,
    ) -> List[Image.Image]:
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if i * j <= max_num and i * j >= min_num
        )
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        target_aspect_ratio = self._find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)
            aspect_ratio, target_ratios, orig_width, orig_height, image_size
        )

        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
@@ -106,7 +132,7 @@ class InternVLModel:
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
                ((i // (target_width // image_size)) + 1) * image_size,
            )
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
@@ -122,20 +148,24 @@ class InternVLModel:
            # data URL base64
            header, b64data = src.split(",", 1)
            img_bytes = base64.b64decode(b64data)
            return Image.open(BytesIO(img_bytes)).convert('RGB')
            return Image.open(BytesIO(img_bytes)).convert("RGB")
        if src.startswith("http://") or src.startswith("https://"):
            resp = requests.get(src, timeout=10)
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert('RGB')
            return Image.open(BytesIO(resp.content)).convert("RGB")
        # Assume local file path
        return Image.open(src).convert('RGB')
        return Image.open(src).convert("RGB")

    def _images_to_pixel_values(self, images: List[Image.Image], input_size: int = 448, max_num: int = 12):
    def _images_to_pixel_values(
        self, images: List[Image.Image], input_size: int = 448, max_num: int = 12
    ):
        transform = self._build_transform(input_size=input_size)
        pixel_values_list = []
        num_patches_list: List[int] = []
        for img in images:
            tiles = self._dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
            tiles = self._dynamic_preprocess(
                img, image_size=input_size, use_thumbnail=True, max_num=max_num
            )
            pv = [transform(tile) for tile in tiles]
            pv = torch.stack(pv)
            num_patches_list.append(pv.shape[0])
@@ -191,7 +221,9 @@ class InternVLModel:
            last_user_text_parts = parts_text or last_user_text_parts
        elif role == "assistant":
            # Only keep text content for history
            parts_text = [item.get("text", "") for item in content_items if item.get("type") == "text"]
            parts_text = [
                item.get("text", "") for item in content_items if item.get("type") == "text"
            ]
            text = "\n".join(parts_text).strip()
            if text:
                context_lines.append(f"Assistant: {text}")
@@ -200,7 +232,9 @@ class InternVLModel:
        pixel_values = None
        num_patches_list: List[int] = []
        if all_images:
            pixel_values, num_patches_list = self._images_to_pixel_values(all_images, input_size=448, max_num=12)
            pixel_values, num_patches_list = self._images_to_pixel_values(
                all_images, input_size=448, max_num=12
            )
        if pixel_values is not None:
            # Convert dtype/device as in docs
            pixel_values = pixel_values.to(torch.bfloat16)
@@ -246,7 +280,9 @@ class InternVLModel:
                    num_patches_list=num_patches_list,
                )
            else:
                response = self.model.chat(self.tokenizer, pixel_values, question, generation_config)
                response = self.model.chat(
                    self.tokenizer, pixel_values, question, generation_config
                )
        except Exception as e:
            # Fallback: return empty string to avoid crashing the adapter
            return ""
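The tiling math above picks the tile grid whose aspect ratio best matches the input image, then crops `image_size` squares from the resized image. A self-contained sketch of that grid selection (simplified: ties are broken by first match in area order, whereas the full method also compares areas):

```python
def closest_grid(width: int, height: int, min_num: int = 1, max_num: int = 12) -> tuple[int, int]:
    """Pick the (cols, rows) tile grid whose ratio best matches the image."""
    aspect = width / height
    ratios = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda r: r[0] * r[1],
    )
    return min(ratios, key=lambda r: abs(aspect - r[0] / r[1]))

# A 1920x1080 screenshot maps to a 2x1 grid of 448px tiles (plus a thumbnail).
print(closest_grid(1920, 1080))  # (2, 1)
```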
@@ -1,13 +1,18 @@
from typing import List, Dict, Any
import re
import base64
import re
from io import BytesIO
from typing import Any, Dict, List

try:
    import blobfile as _  # assert blobfile is installed
    import torch  # type: ignore
    from transformers import AutoTokenizer, AutoModel, AutoImageProcessor  # type: ignore
    from PIL import Image  # type: ignore
    import blobfile as _  # assert blobfile is installed
    from transformers import (  # type: ignore
        AutoImageProcessor,
        AutoModel,
        AutoTokenizer,
    )

    OPENCUA_AVAILABLE = True
except Exception:
    OPENCUA_AVAILABLE = False
@@ -16,10 +21,12 @@ except Exception:
class OpenCUAModel:
    """OpenCUA model handler using AutoTokenizer, AutoModel and AutoImageProcessor."""

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
    def __init__(
        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
    ) -> None:
        if not OPENCUA_AVAILABLE:
            raise ImportError(
                "OpenCUA requirements not found. Install with: pip install \"cua-agent[opencua-hf]\""
                'OpenCUA requirements not found. Install with: pip install "cua-agent[opencua-hf]"'
            )
        self.model_name = model_name
        self.device = device
@@ -56,7 +63,11 @@ class OpenCUAModel:
        return ""

    def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 512) -> str:
        assert self.model is not None and self.tokenizer is not None and self.image_processor is not None
        assert (
            self.model is not None
            and self.tokenizer is not None
            and self.image_processor is not None
        )

        # Tokenize text side using chat template
        input_ids = self.tokenizer.apply_chat_template(
@@ -74,7 +85,11 @@ class OpenCUAModel:
        pixel_values = torch.tensor(image_info["pixel_values"]).to(
            dtype=torch.bfloat16, device=self.model.device
        )
        grid_thws = torch.tensor(image_info["image_grid_thw"]) if "image_grid_thw" in image_info else None
        grid_thws = (
            torch.tensor(image_info["image_grid_thw"])
            if "image_grid_thw" in image_info
            else None
        )

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
@@ -1,9 +1,10 @@
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional

# Hugging Face imports are local to avoid hard dependency at module import
try:
    import torch  # type: ignore
    from transformers import AutoModelForImageTextToText, AutoProcessor  # type: ignore

    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class Qwen2_5_VLModel:
    Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
    """

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
    def __init__(
        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
    ) -> None:
        if not HF_AVAILABLE:
            raise ImportError(
                "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
                'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
            )
        self.model_name = model_name
        self.device = device
@@ -64,7 +67,7 @@ class Qwen2_5_VLModel:
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Trim prompt tokens from output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        # Decode
        output_text = self.processor.batch_decode(
@@ -3,76 +3,83 @@ ComputerAgent - Main agent class that selects and runs agent loops
 """
 
 import asyncio
-from pathlib import Path
-from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple
-
-from litellm.responses.utils import Usage
-
-from .types import (
-    Messages,
-    AgentCapability,
-    ToolError,
-    IllegalArgumentError
-)
-from .responses import make_tool_error_item, replace_failed_computer_calls_with_function_calls
-from .decorators import find_agent_config
+import inspect
 import json
+from pathlib import Path
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
+)
+
 import litellm
 import litellm.utils
-import inspect
+from litellm.responses.utils import Usage
 
 from .adapters import (
     HuggingFaceLocalAdapter,
     HumanAdapter,
     MLXVLMAdapter,
 )
 from .callbacks import (
-    ImageRetentionCallback,
-    LoggingCallback,
-    TrajectorySaverCallback,
     BudgetManagerCallback,
-    TelemetryCallback,
+    ImageRetentionCallback,
+    LoggingCallback,
     OperatorNormalizerCallback,
     PromptInstructionsCallback,
+    TelemetryCallback,
+    TrajectorySaverCallback,
 )
-from .computers import (
-    AsyncComputerHandler,
-    is_agent_computer,
-    make_computer_handler
-)
+from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
+from .decorators import find_agent_config
+from .responses import (
+    make_tool_error_item,
+    replace_failed_computer_calls_with_function_calls,
+)
+from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
 
 
 def assert_callable_with(f, *args, **kwargs):
-    """Check if function can be called with given arguments."""
-    try:
-        inspect.signature(f).bind(*args, **kwargs)
-        return True
-    except TypeError as e:
-        sig = inspect.signature(f)
-        raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
+    """Check if function can be called with given arguments."""
+    try:
+        inspect.signature(f).bind(*args, **kwargs)
+        return True
+    except TypeError as e:
+        sig = inspect.signature(f)
+        raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
 
 
 def get_json(obj: Any, max_depth: int = 10) -> Any:
     def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
         if seen is None:
             seen = set()
 
         # Use model_dump() if available
-        if hasattr(o, 'model_dump'):
+        if hasattr(o, "model_dump"):
             return o.model_dump()
 
         # Check depth limit
         if depth > max_depth:
             return f"<max_depth_exceeded:{max_depth}>"
 
         # Check for circular references using object id
         obj_id = id(o)
         if obj_id in seen:
             return f"<circular_reference:{type(o).__name__}>"
 
         # Handle Computer objects
-        if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
+        if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
             return f"<computer:{o.__class__.__name__}>"
 
         # Handle objects with __dict__
-        if hasattr(o, '__dict__'):
+        if hasattr(o, "__dict__"):
             seen.add(obj_id)
             try:
                 result = {}
@@ -84,7 +91,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 return result
             finally:
                 seen.discard(obj_id)
 
         # Handle common types that might contain nested objects
         elif isinstance(o, dict):
             seen.add(obj_id)
@@ -96,7 +103,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 }
             finally:
                 seen.discard(obj_id)
 
         elif isinstance(o, (list, tuple, set)):
             seen.add(obj_id)
             try:
@@ -107,32 +114,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 ]
             finally:
                 seen.discard(obj_id)
 
         # For basic types that json.dumps can handle
         elif isinstance(o, (str, int, float, bool)) or o is None:
             return o
 
         # Fallback to string representation
         else:
             return str(o)
 
     def remove_nones(obj: Any) -> Any:
         if isinstance(obj, dict):
             return {k: remove_nones(v) for k, v in obj.items() if v is not None}
         elif isinstance(obj, list):
             return [remove_nones(item) for item in obj if item is not None]
         return obj
 
     # Serialize with circular reference and depth protection
     serialized = custom_serializer(obj)
 
     # Convert to JSON string and back to ensure JSON compatibility
     json_str = json.dumps(serialized)
     parsed = json.loads(json_str)
 
     # Final cleanup of any remaining None values
     return remove_nones(parsed)
 
 
 def sanitize_message(msg: Any) -> Any:
     """Return a copy of the message with image_url omitted for computer_call_output messages."""
     if msg.get("type") == "computer_call_output":
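`get_json` above exists because agent items can contain pydantic models, computer handles, and self-referencing structures that plain `json.dumps` cannot serialize. A small illustration of the cycle guard (module path assumed from this file's location in the diff):

```python
# json.dumps would raise "Circular reference detected" on this input;
# get_json substitutes a marker string instead of recursing forever.
from agent.agent import get_json  # import path assumed

node = {"name": "loop"}
node["self"] = node
print(get_json(node))  # {'name': 'loop', 'self': '<circular_reference:dict>'}
```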
@@ -143,19 +151,24 @@ def sanitize_message(msg: Any) -> Any:
         return sanitized
     return msg
 
 
 def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
     call_ids = []
     for message in messages:
-        if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
+        if (
+            message.get("type") == "computer_call_output"
+            or message.get("type") == "function_call_output"
+        ):
             call_ids.append(message.get("call_id"))
     return call_ids
 
 
 class ComputerAgent:
     """
     Main agent class that automatically selects the appropriate agent loop
     based on the model and executes tool calls.
     """
 
     def __init__(
         self,
         model: str,
@@ -172,11 +185,11 @@ class ComputerAgent:
         max_trajectory_budget: Optional[float | dict] = None,
         telemetry_enabled: Optional[bool] = True,
         trust_remote_code: Optional[bool] = False,
-        **kwargs
+        **kwargs,
     ):
         """
         Initialize ComputerAgent.
 
         Args:
             model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
             tools: List of tools (computer objects, decorated functions, etc.)
@@ -193,11 +206,11 @@ class ComputerAgent:
             telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
             trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
             **kwargs: Additional arguments passed to the agent loop
-        """
+        """
         # If the loop is "human/human", we need to prefix a grounding model fallback
         if model in ["human/human", "human"]:
             model = "openai/computer-use-preview+human/human"
 
         self.model = model
         self.tools = tools or []
         self.custom_loop = custom_loop
@@ -236,34 +249,33 @@ class ComputerAgent:
         # Add image retention callback if only_n_most_recent_images is set
         if self.only_n_most_recent_images:
             self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
 
         # Add trajectory saver callback if trajectory_dir is set
         if self.trajectory_dir:
             if isinstance(self.trajectory_dir, dict):
                 self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
             elif isinstance(self.trajectory_dir, (str, Path)):
                 self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
 
         # Add budget manager if max_trajectory_budget is set
         if max_trajectory_budget:
             if isinstance(max_trajectory_budget, dict):
                 self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
             else:
                 self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
 
         # == Enable local model providers w/ LiteLLM ==
 
         # Register local model providers
         hf_adapter = HuggingFaceLocalAdapter(
-            device="auto",
-            trust_remote_code=self.trust_remote_code or False
+            device="auto", trust_remote_code=self.trust_remote_code or False
         )
         human_adapter = HumanAdapter()
         mlx_adapter = MLXVLMAdapter()
         litellm.custom_provider_map = [
             {"provider": "huggingface-local", "custom_handler": hf_adapter},
             {"provider": "human", "custom_handler": human_adapter},
-            {"provider": "mlx", "custom_handler": mlx_adapter}
+            {"provider": "mlx", "custom_handler": mlx_adapter},
         ]
         litellm.suppress_debug_info = True
 
@@ -280,16 +292,16 @@ class ComputerAgent:
         # Instantiate the agent config class
         self.agent_loop = config_info.agent_class()
         self.agent_config_info = config_info
 
         self.tool_schemas = []
         self.computer_handler = None
 
     async def _initialize_computers(self):
         """Initialize computer objects"""
         if not self.tool_schemas:
             # Process tools and create tool schemas
             self.tool_schemas = self._process_tools()
 
             # Find computer tool and create interface adapter
             computer_handler = None
             for schema in self.tool_schemas:
@@ -297,7 +309,7 @@ class ComputerAgent:
                     computer_handler = await make_computer_handler(schema["computer"])
                     break
             self.computer_handler = computer_handler
 
     def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
         """Process input messages and create schemas for the agent loop"""
         if isinstance(input, str):
@@ -307,69 +319,73 @@ class ComputerAgent:
     def _process_tools(self) -> List[Dict[str, Any]]:
         """Process tools and create schemas for the agent loop"""
         schemas = []
 
         for tool in self.tools:
             # Check if it's a computer object (has interface attribute)
             if is_agent_computer(tool):
                 # This is a computer tool - will be handled by agent loop
-                schemas.append({
-                    "type": "computer",
-                    "computer": tool
-                })
+                schemas.append({"type": "computer", "computer": tool})
             elif callable(tool):
                 # Use litellm.utils.function_to_dict to extract schema from docstring
                 try:
                     function_schema = litellm.utils.function_to_dict(tool)
-                    schemas.append({
-                        "type": "function",
-                        "function": function_schema
-                    })
+                    schemas.append({"type": "function", "function": function_schema})
                 except Exception as e:
                     print(f"Warning: Could not process tool {tool}: {e}")
             else:
                 print(f"Warning: Unknown tool type: {tool}")
 
         return schemas
 
     def _get_tool(self, name: str) -> Optional[Callable]:
         """Get a tool by name"""
         for tool in self.tools:
-            if hasattr(tool, '__name__') and tool.__name__ == name:
+            if hasattr(tool, "__name__") and tool.__name__ == name:
                 return tool
-            elif hasattr(tool, 'func') and tool.func.__name__ == name:
+            elif hasattr(tool, "func") and tool.func.__name__ == name:
                 return tool
         return None
 
     # ============================================================================
     # AGENT RUN LOOP LIFECYCLE HOOKS
     # ============================================================================
 
     async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
         """Initialize run tracking by calling callbacks."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_run_start'):
+            if hasattr(callback, "on_run_start"):
                 await callback.on_run_start(kwargs, old_items)
 
-    async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
+    async def _on_run_end(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> None:
         """Finalize run tracking by calling callbacks."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_run_end'):
+            if hasattr(callback, "on_run_end"):
                 await callback.on_run_end(kwargs, old_items, new_items)
 
-    async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
+    async def _on_run_continue(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> bool:
         """Check if run should continue by calling callbacks."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_run_continue'):
+            if hasattr(callback, "on_run_continue"):
                 should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
                 if not should_continue:
                     return False
         return True
 
     async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Prepare messages for the LLM call by applying callbacks."""
         result = messages
         for callback in self.callbacks:
-            if hasattr(callback, 'on_llm_start'):
+            if hasattr(callback, "on_llm_start"):
                 result = await callback.on_llm_start(result)
         return result
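Registering `litellm.custom_provider_map` in the constructor above is what lets provider-prefixed model strings dispatch to the local adapters instead of a hosted API. A hedged sketch, with the package path and model id purely as examples:

```python
from agent import ComputerAgent  # import path assumed

# "huggingface-local/<repo>" routes through the HuggingFaceLocalAdapter
# registered above; "human/..." and "mlx/..." dispatch the same way.
agent = ComputerAgent(model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B")
```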
@@ -377,82 +393,91 @@ class ComputerAgent:
         """Postprocess messages after the LLM call by applying callbacks."""
         result = messages
         for callback in self.callbacks:
-            if hasattr(callback, 'on_llm_end'):
+            if hasattr(callback, "on_llm_end"):
                 result = await callback.on_llm_end(result)
         return result
 
     async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
         """Called when responses are received."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_responses'):
+            if hasattr(callback, "on_responses"):
                 await callback.on_responses(get_json(kwargs), get_json(responses))
 
     async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
         """Called when a computer call is about to start."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_computer_call_start'):
+            if hasattr(callback, "on_computer_call_start"):
                 await callback.on_computer_call_start(get_json(item))
 
-    async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
+    async def _on_computer_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
         """Called when a computer call has completed."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_computer_call_end'):
+            if hasattr(callback, "on_computer_call_end"):
                 await callback.on_computer_call_end(get_json(item), get_json(result))
 
     async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
         """Called when a function call is about to start."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_function_call_start'):
+            if hasattr(callback, "on_function_call_start"):
                 await callback.on_function_call_start(get_json(item))
 
-    async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
+    async def _on_function_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
         """Called when a function call has completed."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_function_call_end'):
+            if hasattr(callback, "on_function_call_end"):
                 await callback.on_function_call_end(get_json(item), get_json(result))
 
     async def _on_text(self, item: Dict[str, Any]) -> None:
         """Called when a text message is encountered."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_text'):
+            if hasattr(callback, "on_text"):
                 await callback.on_text(get_json(item))
 
     async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
         """Called when an LLM API call is about to start."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_api_start'):
+            if hasattr(callback, "on_api_start"):
                 await callback.on_api_start(get_json(kwargs))
 
     async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
         """Called when an LLM API call has completed."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_api_end'):
+            if hasattr(callback, "on_api_end"):
                 await callback.on_api_end(get_json(kwargs), get_json(result))
 
     async def _on_usage(self, usage: Dict[str, Any]) -> None:
         """Called when usage information is received."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_usage'):
+            if hasattr(callback, "on_usage"):
                 await callback.on_usage(get_json(usage))
 
     async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
         """Called when a screenshot is taken."""
         for callback in self.callbacks:
-            if hasattr(callback, 'on_screenshot'):
+            if hasattr(callback, "on_screenshot"):
                 await callback.on_screenshot(screenshot, name)
 
     # ============================================================================
     # AGENT OUTPUT PROCESSING
     # ============================================================================
 
-    async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+    async def _handle_item(
+        self,
+        item: Any,
+        computer: Optional[AsyncComputerHandler] = None,
+        ignore_call_ids: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
         """Handle each item; may cause a computer action + screenshot."""
         call_id = item.get("call_id")
         if ignore_call_ids and call_id and call_id in ignore_call_ids:
             return []
 
         item_type = item.get("type", None)
 
         if item_type == "message":
             await self._on_text(item)
             # # Print messages
@@ -461,7 +486,7 @@ class ComputerAgent:
             #         if content_item.get("text"):
             #             print(content_item.get("text"))
             return []
 
         try:
             if item_type == "computer_call":
                 await self._on_computer_call_start(item)
@@ -472,14 +497,16 @@ class ComputerAgent:
                 action = item.get("action")
                 action_type = action.get("type")
                 if action_type is None:
-                    print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
+                    print(
+                        f"Action type cannot be `None`: action={action}, action_type={action_type}"
+                    )
                     return []
 
                 # Extract action arguments (all fields except 'type')
                 action_args = {k: v for k, v in action.items() if k != "type"}
 
                 # print(f"{action_type}({action_args})")
 
                 # Execute the computer action
                 computer_method = getattr(computer, action_type, None)
                 if computer_method:
@@ -487,13 +514,13 @@ class ComputerAgent:
                     await computer_method(**action_args)
                 else:
                     raise ToolError(f"Unknown computer action: {action_type}")
 
                 # Take screenshot after action
                 if self.screenshot_delay and self.screenshot_delay > 0:
                     await asyncio.sleep(self.screenshot_delay)
                 screenshot_base64 = await computer.screenshot()
                 await self._on_screenshot(screenshot_base64, "screenshot_after")
 
                 # Handle safety checks
                 pending_checks = item.get("pending_safety_checks", [])
                 acknowledged_checks = []
@@ -505,7 +532,7 @@ class ComputerAgent:
                 #         acknowledged_checks.append(check)
                 #     else:
                 #         raise ValueError(f"Safety check failed: {check_message}")
 
                 # Create call output
                 call_output = {
                     "type": "computer_call_output",
@@ -516,25 +543,25 @@ class ComputerAgent:
                         "image_url": f"data:image/png;base64,{screenshot_base64}",
                     },
                 }
 
                 # # Additional URL safety checks for browser environments
                 # if await computer.get_environment() == "browser":
                 #     current_url = await computer.get_current_url()
                 #     call_output["output"]["current_url"] = current_url
                 #     # TODO: implement a callback for URL safety checks
                 #     # check_blocklisted_url(current_url)
 
                 result = [call_output]
                 await self._on_computer_call_end(item, result)
                 return result
 
             if item_type == "function_call":
                 await self._on_function_call_start(item)
                 # Perform function call
                 function = self._get_tool(item.get("name"))
                 if not function:
-                    raise ToolError(f"Function {item.get("name")} not found")
+                    raise ToolError(f"Function {item.get('name')} not found")
 
                 args = json.loads(item.get("arguments"))
 
                 # Validate arguments before execution
@@ -545,14 +572,14 @@ class ComputerAgent:
                     result = await function(**args)
                 else:
                     result = await asyncio.to_thread(function, **args)
 
                 # Create function call output
                 call_output = {
                     "type": "function_call_output",
                     "call_id": item.get("call_id"),
                     "output": str(result),
                 }
 
                 result = [call_output]
                 await self._on_function_call_end(item, result)
                 return result
@@ -564,36 +591,35 @@ class ComputerAgent:
     # ============================================================================
     # MAIN AGENT LOOP
     # ============================================================================
 
     async def run(
-        self,
-        messages: Messages,
-        stream: bool = False,
-        **kwargs
+        self, messages: Messages, stream: bool = False, **kwargs
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """
         Run the agent with the given messages using Computer protocol handler pattern.
 
         Args:
             messages: List of message dictionaries
             stream: Whether to stream the response
             **kwargs: Additional arguments
 
         Returns:
             AsyncGenerator that yields response chunks
         """
         if not self.agent_config_info:
             raise ValueError("Agent configuration not found")
 
         capabilities = self.get_capabilities()
         if "step" not in capabilities:
-            raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
+            )
 
         await self._initialize_computers()
 
         # Merge kwargs
         merged_kwargs = {**self.kwargs, **kwargs}
 
         old_items = self._process_input(messages)
         new_items = []
@@ -603,7 +629,7 @@ class ComputerAgent:
             "stream": stream,
             "model": self.model,
             "agent_loop": self.agent_config_info.agent_class.__name__,
-            **merged_kwargs
+            **merged_kwargs,
         }
         await self._on_run_start(run_kwargs, old_items)
@@ -620,7 +646,7 @@ class ComputerAgent:
             combined_messages = old_items + new_items
             combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
             preprocessed_messages = await self._on_llm_start(combined_messages)
 
             loop_kwargs = {
                 "messages": preprocessed_messages,
                 "model": self.model,
@@ -629,7 +655,7 @@ class ComputerAgent:
                 "computer_handler": self.computer_handler,
                 "max_retries": self.max_retries,
                 "use_prompt_caching": self.use_prompt_caching,
-                **merged_kwargs
+                **merged_kwargs,
             }
 
             # Run agent loop iteration
@@ -641,13 +667,13 @@ class ComputerAgent:
                 _on_screenshot=self._on_screenshot,
             )
             result = get_json(result)
 
             # Lifecycle hook: Postprocess messages after the LLM call
             # Use cases:
             # - PII deanonymization (if you want tool calls to see PII)
             result["output"] = await self._on_llm_end(result.get("output", []))
             await self._on_responses(loop_kwargs, result)
 
             # Yield agent response
             yield result
@@ -659,7 +685,9 @@ class ComputerAgent:
 
             # Handle computer actions
             for item in result.get("output"):
-                partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
+                partial_items = await self._handle_item(
+                    item, self.computer_handler, ignore_call_ids=output_call_ids
+                )
                 new_items += partial_items
 
                 # Yield partial response
@@ -669,54 +697,52 @@ class ComputerAgent:
                             prompt_tokens=0,
                             completion_tokens=0,
                             total_tokens=0,
-                        )
+                        ),
                 }
 
         await self._on_run_end(loop_kwargs, old_items, new_items)
 
     async def predict_click(
-        self,
-        instruction: str,
-        image_b64: Optional[str] = None
+        self, instruction: str, image_b64: Optional[str] = None
    ) -> Optional[Tuple[int, int]]:
         """
         Predict click coordinates based on image and instruction.
 
         Args:
             instruction: Instruction for where to click
             image_b64: Base64 encoded image (optional, will take screenshot if not provided)
 
         Returns:
             None or tuple with (x, y) coordinates
         """
         if not self.agent_config_info:
             raise ValueError("Agent configuration not found")
 
         capabilities = self.get_capabilities()
         if "click" not in capabilities:
-            raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
-        if hasattr(self.agent_loop, 'predict_click'):
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
+            )
+        if hasattr(self.agent_loop, "predict_click"):
             if not image_b64:
                 if not self.computer_handler:
                     raise ValueError("Computer tool or image_b64 is required for predict_click")
                 image_b64 = await self.computer_handler.screenshot()
             return await self.agent_loop.predict_click(
-                model=self.model,
-                image_b64=image_b64,
-                instruction=instruction
+                model=self.model, image_b64=image_b64, instruction=instruction
             )
         return None
 
     def get_capabilities(self) -> List[AgentCapability]:
         """
         Get list of capabilities supported by the current agent config.
 
         Returns:
             List of capability strings (e.g., ["step", "click"])
         """
         if not self.agent_config_info:
             raise ValueError("Agent configuration not found")
 
-        if hasattr(self.agent_loop, 'get_capabilities'):
+        if hasattr(self.agent_loop, "get_capabilities"):
             return self.agent_loop.get_capabilities()
-        return ["step"]  # Default capability
+        return ["step"]  # Default capability
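Tying `run()`, `_handle_item()`, and `predict_click()` together, here is a minimal driver, assuming the package is importable as `agent` and that the chosen loop reports the needed capabilities; the model id and instruction are illustrative:

```python
import asyncio

from agent import ComputerAgent  # import path assumed


async def main() -> None:
    agent = ComputerAgent(model="openai/computer-use-preview")
    # One-shot grounding: returns an (x, y) tuple, or None when the loop
    # defines no predict_click; raises ValueError without "click" capability.
    coords = await agent.predict_click(
        instruction="Click the search box",
        image_b64="<base64-encoded PNG, elided>",
    )
    print(coords)


asyncio.run(main())
```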
@@ -3,17 +3,17 @@ Callback system for ComputerAgent preprocessing and postprocessing hooks.
 """
 
 from .base import AsyncCallbackHandler
+from .budget_manager import BudgetManagerCallback
 from .image_retention import ImageRetentionCallback
 from .logging import LoggingCallback
-from .trajectory_saver import TrajectorySaverCallback
-from .budget_manager import BudgetManagerCallback
-from .telemetry import TelemetryCallback
 from .operator_validator import OperatorNormalizerCallback
 from .prompt_instructions import PromptInstructionsCallback
+from .telemetry import TelemetryCallback
+from .trajectory_saver import TrajectorySaverCallback
 
 __all__ = [
     "AsyncCallbackHandler",
-    "ImageRetentionCallback",
+    "ImageRetentionCallback",
     "LoggingCallback",
     "TrajectorySaverCallback",
     "BudgetManagerCallback",
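With the re-exports sorted, callbacks can also be composed explicitly rather than through the convenience kwargs; note the `callbacks=` parameter is inferred from `self.callbacks` usage in agent.py and is an assumption here:

```python
from agent import ComputerAgent
from agent.callbacks import (
    BudgetManagerCallback,
    ImageRetentionCallback,
    LoggingCallback,
)

agent = ComputerAgent(
    model="openai/computer-use-preview",
    callbacks=[  # parameter name assumed; see ComputerAgent.__init__
        ImageRetentionCallback(only_n_most_recent_images=3),
        BudgetManagerCallback(max_budget=5.0),
        LoggingCallback(),
    ],
)
```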
@@ -3,7 +3,7 @@ Base callback handler interface for ComputerAgent preprocessing and postprocessi
 """
 
 from abc import ABC, abstractmethod
-from typing import List, Dict, Any, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 
 class AsyncCallbackHandler(ABC):
@@ -16,42 +16,52 @@ class AsyncCallbackHandler(ABC):
         """Called at the start of an agent run loop."""
         pass
 
-    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
+    async def on_run_end(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> None:
         """Called at the end of an agent run loop."""
         pass
 
-    async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
+    async def on_run_continue(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> bool:
         """Called during agent run loop to determine if execution should continue.
 
         Args:
             kwargs: Run arguments
             old_items: Original messages
             new_items: New messages generated during run
 
         Returns:
             True to continue execution, False to stop
         """
         return True
 
     async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """
         Called before messages are sent to the agent loop.
 
         Args:
             messages: List of message dictionaries to preprocess
 
         Returns:
             List of preprocessed message dictionaries
         """
         return messages
 
     async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """
         Called after the agent loop returns output.
 
         Args:
             output: List of output message dictionaries to postprocess
 
         Returns:
             List of postprocessed output dictionaries
         """
@@ -60,63 +70,67 @@ class AsyncCallbackHandler(ABC):
     async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
         """
         Called when a computer call is about to start.
 
         Args:
             item: The computer call item dictionary
         """
         pass
 
-    async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
+    async def on_computer_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
         """
         Called when a computer call has completed.
 
         Args:
             item: The computer call item dictionary
             result: The result of the computer call
         """
         pass
 
     async def on_function_call_start(self, item: Dict[str, Any]) -> None:
         """
         Called when a function call is about to start.
 
         Args:
             item: The function call item dictionary
         """
         pass
 
-    async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
+    async def on_function_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
         """
         Called when a function call has completed.
 
         Args:
             item: The function call item dictionary
             result: The result of the function call
         """
         pass
 
     async def on_text(self, item: Dict[str, Any]) -> None:
         """
         Called when a text message is encountered.
 
         Args:
             item: The message item dictionary
         """
         pass
 
     async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
         """
         Called when an API call is about to start.
 
         Args:
             kwargs: The kwargs being passed to the API call
         """
         pass
 
     async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
         """
         Called when an API call has completed.
 
         Args:
             kwargs: The kwargs that were passed to the API call
             result: The result of the API call
@@ -126,7 +140,7 @@ class AsyncCallbackHandler(ABC):
     async def on_usage(self, usage: Dict[str, Any]) -> None:
         """
         Called when usage information is received.
 
         Args:
             usage: The usage information
         """
@@ -135,7 +149,7 @@ class AsyncCallbackHandler(ABC):
     async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
         """
         Called when a screenshot is taken.
 
         Args:
             screenshot: The screenshot image
             name: The name of the screenshot
@@ -145,9 +159,9 @@ class AsyncCallbackHandler(ABC):
     async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
         """
         Called when responses are received.
 
         Args:
             kwargs: The kwargs being passed to the agent loop
             responses: The responses received
         """
-        pass
+        pass
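Because every hook has a no-op default (and agent.py dispatches via `hasattr`), a subclass only overrides the hooks it cares about. A minimal example against the interface above (import path assumed from this file's location):

```python
from typing import Any, Dict, List

from agent.callbacks.base import AsyncCallbackHandler  # import path assumed


class CountingCallback(AsyncCallbackHandler):
    """Toy callback: count LLM round-trips in a run."""

    def __init__(self) -> None:
        self.llm_calls = 0

    async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        self.llm_calls += 1
        return messages  # preprocessing hooks must return the message list
```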
@@ -1,17 +1,23 @@
-from typing import Dict, List, Any
+from typing import Any, Dict, List
+
 from .base import AsyncCallbackHandler
 
 
 class BudgetExceededError(Exception):
     """Exception raised when budget is exceeded."""
+
     pass
 
 
 class BudgetManagerCallback(AsyncCallbackHandler):
     """Budget manager callback that tracks usage costs and can stop execution when budget is exceeded."""
 
-    def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False):
+    def __init__(
+        self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False
+    ):
         """
         Initialize BudgetManagerCallback.
 
         Args:
             max_budget: Maximum budget allowed
             reset_after_each_run: Whether to reset budget after each run
@@ -21,24 +27,30 @@ class BudgetManagerCallback(AsyncCallbackHandler):
         self.reset_after_each_run = reset_after_each_run
         self.raise_error = raise_error
         self.total_cost = 0.0
 
     async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
         """Reset budget if configured to do so."""
         if self.reset_after_each_run:
             self.total_cost = 0.0
 
     async def on_usage(self, usage: Dict[str, Any]) -> None:
         """Track usage costs."""
         if "response_cost" in usage:
             self.total_cost += usage["response_cost"]
 
-    async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
+    async def on_run_continue(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> bool:
         """Check if budget allows continuation."""
         if self.total_cost >= self.max_budget:
             if self.raise_error:
-                raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
+                raise BudgetExceededError(
+                    f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}"
+                )
             else:
                 print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
             return False
         return True
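The two failure modes are worth seeing side by side: with `raise_error=True` the run dies loudly, otherwise `on_run_continue` simply returns False and the loop stops. A self-contained demo of the code above (import path assumed):

```python
import asyncio

from agent.callbacks.budget_manager import (  # import path assumed
    BudgetExceededError,
    BudgetManagerCallback,
)


async def demo() -> None:
    budget = BudgetManagerCallback(max_budget=2.5, raise_error=True)
    await budget.on_usage({"response_cost": 3.0})  # pushes total past the cap
    try:
        await budget.on_run_continue({}, [], [])
    except BudgetExceededError as exc:
        print(exc)  # Budget exceeded: $3.0 >= $2.5


asyncio.run(demo())
```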
@@ -2,7 +2,8 @@
 Image retention callback handler that limits the number of recent images in message history.
 """
 
-from typing import List, Dict, Any, Optional
+from typing import Any, Dict, List, Optional
+
 from .base import AsyncCallbackHandler
 
 
@@ -11,40 +12,40 @@ class ImageRetentionCallback(AsyncCallbackHandler):
     Callback handler that applies image retention policy to limit the number
     of recent images in message history to prevent context window overflow.
     """
 
     def __init__(self, only_n_most_recent_images: Optional[int] = None):
         """
         Initialize the image retention callback.
 
         Args:
             only_n_most_recent_images: If set, only keep the N most recent images in message history
         """
         self.only_n_most_recent_images = only_n_most_recent_images
 
     async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """
         Apply image retention policy to messages before sending to agent loop.
 
         Args:
             messages: List of message dictionaries
 
         Returns:
             List of messages with image retention policy applied
         """
         if self.only_n_most_recent_images is None:
             return messages
 
         return self._apply_image_retention(messages)
 
     def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Apply image retention policy to keep only the N most recent images.
 
         Removes computer_call_output items with image_url and their corresponding computer_call items,
         keeping only the most recent N image pairs based on only_n_most_recent_images setting.
 
         Args:
             messages: List of message dictionaries
 
         Returns:
             Filtered list of messages with image retention applied
         """
@@ -78,7 +79,11 @@ class ImageRetentionCallback(AsyncCallbackHandler):
                 # Remove the immediately preceding computer_call with matching call_id (if present)
                 call_id = messages[idx].get("call_id")
                 prev_idx = idx - 1
-                if prev_idx >= 0 and messages[prev_idx].get("type") == "computer_call" and messages[prev_idx].get("call_id") == call_id:
+                if (
+                    prev_idx >= 0
+                    and messages[prev_idx].get("type") == "computer_call"
+                    and messages[prev_idx].get("call_id") == call_id
+                ):
                     to_remove.add(prev_idx)
                     # Check a single reasoning immediately before that computer_call
                     r_idx = prev_idx - 1
@@ -87,4 +92,4 @@ class ImageRetentionCallback(AsyncCallbackHandler):
 
         # Construct filtered list
         filtered = [m for i, m in enumerate(messages) if i not in to_remove]
-        return filtered
+        return filtered
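Note that the filter removes the paired `computer_call` (and a directly preceding reasoning item) along with each dropped screenshot, so the history never contains an output without its call. In practice this callback is usually attached via the constructor flag shown earlier in agent.py:

```python
from agent import ComputerAgent  # import path assumed

# only_n_most_recent_images auto-appends ImageRetentionCallback (see agent.py)
agent = ComputerAgent(model="openai/computer-use-preview", only_n_most_recent_images=3)
```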
@@ -4,17 +4,18 @@ Logging callback for ComputerAgent that provides configurable logging of agent l
 
 import json
 import logging
-from typing import Dict, List, Any, Optional, Union
+from typing import Any, Dict, List, Optional, Union
+
 from .base import AsyncCallbackHandler
 
 
 def sanitize_image_urls(data: Any) -> Any:
     """
     Recursively search for 'image_url' keys and set their values to '[omitted]'.
 
     Args:
         data: Any data structure (dict, list, or primitive type)
 
     Returns:
         A deep copy of the data with all 'image_url' values replaced with '[omitted]'
     """
@@ -28,11 +29,11 @@ def sanitize_image_urls(data: Any) -> Any:
             # Recursively sanitize the value
             sanitized[key] = sanitize_image_urls(value)
         return sanitized
 
     elif isinstance(data, list):
         # Recursively sanitize each item in the list
         return [sanitize_image_urls(item) for item in data]
 
     else:
         # For primitive types (str, int, bool, None, etc.), return as-is
         return data
@@ -41,37 +42,36 @@ def sanitize_image_urls(data: Any) -> Any:
 class LoggingCallback(AsyncCallbackHandler):
     """
     Callback handler that logs agent lifecycle events with configurable verbosity.
 
     Logging levels:
     - DEBUG: All events including API calls, message preprocessing, and detailed outputs
-    - INFO: Major lifecycle events (start/end, messages, outputs)
+    - INFO: Major lifecycle events (start/end, messages, outputs)
     - WARNING: Only warnings and errors
     - ERROR: Only errors
     """
 
     def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO):
         """
         Initialize the logging callback.
 
         Args:
             logger: Logger instance to use. If None, creates a logger named 'agent.ComputerAgent'
             level: Logging level (logging.DEBUG, logging.INFO, etc.)
         """
-        self.logger = logger or logging.getLogger('agent.ComputerAgent')
+        self.logger = logger or logging.getLogger("agent.ComputerAgent")
         self.level = level
 
         # Set up logger if it doesn't have handlers
         if not self.logger.handlers:
             handler = logging.StreamHandler()
-            formatter = logging.Formatter(
-                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-            )
+            formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
             handler.setFormatter(formatter)
             self.logger.addHandler(handler)
         self.logger.setLevel(level)
 
     def _update_usage(self, usage: Dict[str, Any]) -> None:
         """Update total usage statistics."""
+
         def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
             for key, value in source.items():
                 if isinstance(value, dict):
@@ -82,18 +82,25 @@ class LoggingCallback(AsyncCallbackHandler):
                     if key not in target:
                         target[key] = 0
                     target[key] += value
 
         add_dicts(self.total_usage, usage)
 
     async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
         """Called before the run starts."""
         self.total_usage = {}
 
     async def on_usage(self, usage: Dict[str, Any]) -> None:
         """Called when usage information is received."""
         self._update_usage(usage)
 
-    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
+    async def on_run_end(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> None:
         """Called after the run ends."""
+
         def format_dict(d, indent=0):
             lines = []
             prefix = f" - {' ' * indent}"
@@ -106,10 +113,10 @@ class LoggingCallback(AsyncCallbackHandler):
             else:
                 lines.append(f"{prefix}{key}: {value}")
             return lines
 
         formatted_output = "\n".join(format_dict(self.total_usage))
         self.logger.info(f"Total usage:\n{formatted_output}")
 
     async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Called before LLM processing starts."""
         if self.logger.isEnabledFor(logging.INFO):
@@ -118,27 +125,27 @@ class LoggingCallback(AsyncCallbackHandler):
             sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
             self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}")
         return messages
 
     async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Called after LLM processing ends."""
         if self.logger.isEnabledFor(logging.DEBUG):
             sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
             self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}")
         return messages
 
     async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
         """Called when a computer call starts."""
         action = item.get("action", {})
         action_type = action.get("type", "unknown")
         action_args = {k: v for k, v in action.items() if k != "type"}
 
         # INFO level logging for the action
         self.logger.info(f"Computer: {action_type}({action_args})")
 
         # DEBUG level logging for full details
         if self.logger.isEnabledFor(logging.DEBUG):
             self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}")
 
     async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None:
         """Called when a computer call ends."""
         if self.logger.isEnabledFor(logging.DEBUG):
@@ -147,48 +154,52 @@ class LoggingCallback(AsyncCallbackHandler):
             if result:
                 sanitized_result = sanitize_image_urls(result)
                 self.logger.debug(f"Computer call result: {json.dumps(sanitized_result, indent=2)}")
 
     async def on_function_call_start(self, item: Dict[str, Any]) -> None:
         """Called when a function call starts."""
         name = item.get("name", "unknown")
         arguments = item.get("arguments", "{}")
 
         # INFO level logging for the function call
         self.logger.info(f"Function: {name}({arguments})")
 
         # DEBUG level logging for full details
         if self.logger.isEnabledFor(logging.DEBUG):
             self.logger.debug(f"Function call started: {name}")
 
     async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None:
         """Called when a function call ends."""
         # INFO level logging for function output (similar to function_call_output)
         if result:
             # Handle both list and direct result formats
             if isinstance(result, list) and len(result) > 0:
-                output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0])
+                output = (
+                    result[0].get("output", str(result))
+                    if isinstance(result[0], dict)
+                    else str(result[0])
+                )
             else:
                 output = str(result)
 
             # Truncate long outputs
             if len(output) > 100:
                 output = output[:100] + "..."
 
             self.logger.info(f"Output: {output}")
 
         # DEBUG level logging for full details
         if self.logger.isEnabledFor(logging.DEBUG):
             name = item.get("name", "unknown")
             self.logger.debug(f"Function call completed: {name}")
             if result:
                 self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}")
 
     async def on_text(self, item: Dict[str, Any]) -> None:
         """Called when a text message is encountered."""
         # Get the role to determine if it's Agent or User
         role = item.get("role", "unknown")
         content_items = item.get("content", [])
 
         # Process content items to build display text
         text_parts = []
         for content_item in content_items:
@@ -206,10 +217,10 @@ class LoggingCallback(AsyncCallbackHandler):
             else:
                 # Non-text content, show as [type]
                 text_parts.append(f"[{content_type}]")
 
         # Join all text parts
-        display_text = ''.join(text_parts) if text_parts else "[empty]"
+        display_text = "".join(text_parts) if text_parts else "[empty]"
 
         # Log with appropriate level and format
         if role == "assistant":
             self.logger.info(f"Agent: {display_text}")
@@ -219,7 +230,7 @@ class LoggingCallback(AsyncCallbackHandler):
             # Fallback for unknown roles, use debug level
             if self.logger.isEnabledFor(logging.DEBUG):
                 self.logger.debug(f"Text message ({role}): {display_text}")
 
     async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
         """Called when an API call is about to start."""
         if self.logger.isEnabledFor(logging.DEBUG):
@@ -232,16 +243,18 @@ class LoggingCallback(AsyncCallbackHandler):
             elif "input" in kwargs:
                 sanitized_input = sanitize_image_urls(kwargs["input"])
                 self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}")
 
     async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
         """Called when an API call has completed."""
         if self.logger.isEnabledFor(logging.DEBUG):
             model = kwargs.get("model", "unknown")
             self.logger.debug(f"API call completed for model: {model}")
-            self.logger.debug(f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}")
+            self.logger.debug(
+                f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}"
+            )
 
     async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None:
         """Called when a screenshot is taken."""
         if self.logger.isEnabledFor(logging.DEBUG):
             image_size = len(item) / 1024
-            self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
+            self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
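At INFO this callback prints one line per action; DEBUG additionally dumps full payloads with every `image_url` value replaced by `[omitted]`. Typical setup, using the constructor signature shown above (import path assumed):

```python
import logging

from agent.callbacks import LoggingCallback  # import path assumed

verbose = LoggingCallback(level=logging.DEBUG)  # full sanitized payload dumps
quiet = LoggingCallback(level=logging.WARNING)  # warnings and errors only
```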
@@ -9,6 +9,7 @@ Ensures agent output actions conform to expected schemas by fixing common issues
 This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
 The purpose is to avoid spending another LLM call to fix broken computer call syntax when possible.
 """
+
 from __future__ import annotations
 
 from typing import Any, Dict, List
@@ -48,6 +49,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
                 action["type"] = "type"
 
             action_type = action.get("type")
+
             def _keep_keys(action: Dict[str, Any], keys_to_keep: List[str]):
                 """Keep only the provided keys on action; delete everything else.
                 Always ensures required 'type' is present if listed in keys_to_keep.
@@ -55,6 +57,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
                 for key in list(action.keys()):
                     if key not in keys_to_keep:
                         del action[key]
+
             # rename "coordinate" to "x", "y"
             if "coordinate" in action:
                 action["x"] = action["coordinate"][0]
@@ -100,7 +103,6 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
             keep = required_keys_by_type.get(action_type or "")
             if keep:
                 _keep_keys(action, keep)
 
-
         # # Second pass: if an assistant message is immediately followed by a computer_call,
         # # replace the assistant message itself with a reasoning message with summary text.
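For reference, the `_keep_keys` pruning and the coordinate rename above amount to a transformation like this (values illustrative):

```python
# Before normalization: a near-miss click action from the model
action = {"type": "click", "coordinate": [120, 340], "extra": "noise"}
# After the hunks above run: coordinate is split into x/y and keys not in
# required_keys_by_type["click"] are dropped:
# {"type": "click", "x": 120, "y": 340}
```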
@@ -2,38 +2,41 @@
 PII anonymization callback handler using Microsoft Presidio for text and image redaction.
 """
 
-from typing import List, Dict, Any, Optional, Tuple
-from .base import AsyncCallbackHandler
 import base64
 import io
 import logging
+from typing import Any, Dict, List, Optional, Tuple
+
+from .base import AsyncCallbackHandler
 
 try:
     # TODO: Add Presidio dependencies
     from PIL import Image
+
     PRESIDIO_AVAILABLE = True
 except ImportError:
     PRESIDIO_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
 
 class PIIAnonymizationCallback(AsyncCallbackHandler):
     """
     Callback handler that anonymizes PII in text and images using Microsoft Presidio.
 
     This handler:
     1. Anonymizes PII in messages before sending to the agent loop
     2. Deanonymizes PII in tool calls and message outputs after the agent loop
     3. Redacts PII from images in computer_call_output messages
     """
 
     def __init__(
         self,
         # TODO: Any extra kwargs if needed
     ):
         """
         Initialize the PII anonymization callback.
 
         Args:
             anonymize_text: Whether to anonymize text content
             anonymize_images: Whether to redact images
@@ -46,16 +49,16 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
                 "Presidio is not available. Install with: "
                 "pip install cua-agent[pii-anonymization]"
             )
 
         # TODO: Implement __init__
 
     async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """
         Anonymize PII in messages before sending to agent loop.
 
         Args:
             messages: List of message dictionaries
 
         Returns:
             List of messages with PII anonymized
         """
@@ -63,16 +66,16 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
         for msg in messages:
             anonymized_msg = await self._anonymize_message(msg)
             anonymized_messages.append(anonymized_msg)
 
         return anonymized_messages
 
     async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """
         Deanonymize PII in tool calls and message outputs after agent loop.
 
         Args:
             output: List of output dictionaries
 
         Returns:
             List of output with PII deanonymized for tool calls
         """
@@ -84,13 +87,13 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
                 deanonymized_output.append(deanonymized_item)
             else:
                 deanonymized_output.append(item)
 
         return deanonymized_output
 
     async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
         # TODO: Implement _anonymize_message
         return message
 
     async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
         # TODO: Implement _deanonymize_item
         return item
@@ -2,17 +2,17 @@
 Telemetry callback handler for Computer-Use Agent (cua-agent)
 """
 
+import platform
 import time
 import uuid
-from typing import List, Dict, Any, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
-from .base import AsyncCallbackHandler
 from core.telemetry import (
-    record_event,
     is_telemetry_enabled,
+    record_event,
 )
 
-import platform
+from .base import AsyncCallbackHandler
 
 SYSTEM_INFO = {
     "os": platform.system().lower(),
@@ -20,32 +20,29 @@
     "python_version": platform.python_version(),
 }
 
 
 class TelemetryCallback(AsyncCallbackHandler):
     """
     Telemetry callback handler for Computer-Use Agent (cua-agent)
 
     Tracks agent usage, performance metrics, and optionally trajectory data.
     """
 
-    def __init__(
-        self,
-        agent,
-        log_trajectory: bool = False
-    ):
+    def __init__(self, agent, log_trajectory: bool = False):
         """
         Initialize telemetry callback.
 
         Args:
             agent: The ComputerAgent instance
             log_trajectory: Whether to log full trajectory items (opt-in)
         """
         self.agent = agent
         self.log_trajectory = log_trajectory
 
         # Generate session/run IDs
         self.session_id = str(uuid.uuid4())
         self.run_id = None
 
         # Track timing and metrics
         self.run_start_time = None
         self.step_count = 0
@@ -54,126 +51,133 @@ class TelemetryCallback(AsyncCallbackHandler):
             "prompt_tokens": 0,
             "completion_tokens": 0,
             "total_tokens": 0,
-            "response_cost": 0.0
+            "response_cost": 0.0,
         }
 
         # Record agent initialization
         if is_telemetry_enabled():
             self._record_agent_initialization()
 
     def _record_agent_initialization(self) -> None:
         """Record agent type/model and session initialization."""
         agent_info = {
             "session_id": self.session_id,
-            "agent_type": self.agent.agent_loop.__name__ if hasattr(self.agent, 'agent_loop') else 'unknown',
-            "model": getattr(self.agent, 'model', 'unknown'),
-            **SYSTEM_INFO
+            "agent_type": (
+                self.agent.agent_loop.__name__ if hasattr(self.agent, "agent_loop") else "unknown"
+            ),
+            "model": getattr(self.agent, "model", "unknown"),
+            **SYSTEM_INFO,
         }
 
         record_event("agent_session_start", agent_info)
 
     async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
         """Called at the start of an agent run loop."""
         if not is_telemetry_enabled():
             return
 
         self.run_id = str(uuid.uuid4())
         self.run_start_time = time.time()
         self.step_count = 0
 
         # Calculate input context size
         input_context_size = self._calculate_context_size(old_items)
 
         run_data = {
             "session_id": self.session_id,
             "run_id": self.run_id,
             "start_time": self.run_start_time,
             "input_context_size": input_context_size,
-            "num_existing_messages": len(old_items)
+            "num_existing_messages": len(old_items),
         }
 
         # Log trajectory if opted in
         if self.log_trajectory:
             trajectory = self._extract_trajectory(old_items)
             if trajectory:
                 run_data["uploaded_trajectory"] = trajectory
 
         record_event("agent_run_start", run_data)
 
-    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
+    async def on_run_end(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> None:
         """Called at the end of an agent run loop."""
         if not is_telemetry_enabled() or not self.run_start_time:
             return
 
         run_duration = time.time() - self.run_start_time
 
         run_data = {
             "session_id": self.session_id,
             "run_id": self.run_id,
             "end_time": time.time(),
             "duration_seconds": run_duration,
             "num_steps": self.step_count,
-            "total_usage": self.total_usage.copy()
+            "total_usage": self.total_usage.copy(),
         }
 
         # Log trajectory if opted in
         if self.log_trajectory:
             trajectory = self._extract_trajectory(new_items)
             if trajectory:
                 run_data["uploaded_trajectory"] = trajectory
 
         record_event("agent_run_end", run_data)
 
     async def on_usage(self, usage: Dict[str, Any]) -> None:
         """Called when usage information is received."""
         if not is_telemetry_enabled():
             return
 
         # Accumulate usage stats
         self.total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
-        self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
+        self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
         self.total_usage["total_tokens"] += usage.get("total_tokens", 0)
         self.total_usage["response_cost"] += usage.get("response_cost", 0.0)
 
         # Record individual usage event
         usage_data = {
             "session_id": self.session_id,
             "run_id": self.run_id,
             "step": self.step_count,
-            **usage
+            **usage,
         }
 
         record_event("agent_usage", usage_data)
 
     async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
         """Called when responses are received."""
         if not is_telemetry_enabled():
             return
 
         self.step_count += 1
         step_duration = None
 
         if self.step_start_time:
             step_duration = time.time() - self.step_start_time
 
         self.step_start_time = time.time()
 
         step_data = {
             "session_id": self.session_id,
             "run_id": self.run_id,
             "step": self.step_count,
-            "timestamp": self.step_start_time
+            "timestamp": self.step_start_time,
         }
 
         if step_duration is not None:
             step_data["duration_seconds"] = step_duration
 
         record_event("agent_step", step_data)
 
     def _calculate_context_size(self, items: List[Dict[str, Any]]) -> int:
         """Calculate approximate context size in tokens/characters."""
         total_size = 0
 
         for item in items:
             if item.get("type") == "message" and "content" in item:
                 content = item["content"]
@@ -185,25 +189,27 @@ class TelemetryCallback(AsyncCallbackHandler):
                         total_size += len(part["text"])
             elif "content" in item and isinstance(item["content"], str):
                 total_size += len(item["content"])
 
         return total_size
 
     def _extract_trajectory(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Extract trajectory items that should be logged."""
         trajectory = []
 
         for item in items:
             # Include user messages, assistant messages, reasoning, computer calls, and computer outputs
|
||||
if (
|
||||
item.get("role") == "user" or # User inputs
|
||||
(item.get("type") == "message" and item.get("role") == "assistant") or # Model outputs
|
||||
item.get("type") == "reasoning" or # Reasoning traces
|
||||
item.get("type") == "computer_call" or # Computer actions
|
||||
item.get("type") == "computer_call_output" # Computer outputs
|
||||
item.get("role") == "user" # User inputs
|
||||
or (
|
||||
item.get("type") == "message" and item.get("role") == "assistant"
|
||||
) # Model outputs
|
||||
or item.get("type") == "reasoning" # Reasoning traces
|
||||
or item.get("type") == "computer_call" # Computer actions
|
||||
or item.get("type") == "computer_call_output" # Computer outputs
|
||||
):
|
||||
# Create a copy of the item with timestamp
|
||||
trajectory_item = item.copy()
|
||||
trajectory_item["logged_at"] = time.time()
|
||||
trajectory.append(trajectory_item)
|
||||
|
||||
return trajectory
|
||||
|
||||
return trajectory
|
||||
|
||||
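
# Usage sketch (assumptions: a ComputerAgent instance named `agent` that
# exposes a mutable `callbacks` list -- adjust to your setup):
#
#   telemetry = TelemetryCallback(agent, log_trajectory=False)  # trajectory upload stays opt-in
#   agent.callbacks.append(telemetry)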
@@ -2,26 +2,28 @@
Trajectory saving callback handler for ComputerAgent.
"""

import os
import json
import uuid
from datetime import datetime
import base64
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, override
from PIL import Image, ImageDraw
import io
import json
import os
import uuid
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, override

from PIL import Image, ImageDraw

from .base import AsyncCallbackHandler


def sanitize_image_urls(data: Any) -> Any:
    """
    Recursively search for 'image_url' keys and set their values to '[omitted]'.

    Args:
        data: Any data structure (dict, list, or primitive type)

    Returns:
        A deep copy of the data with all 'image_url' values replaced with '[omitted]'
    """
@@ -35,17 +37,19 @@ def sanitize_image_urls(data: Any) -> Any:
            # Recursively sanitize the value
            sanitized[key] = sanitize_image_urls(value)
        return sanitized

    elif isinstance(data, list):
        # Recursively sanitize each item in the list
        return [sanitize_image_urls(item) for item in data]

    else:
        # For primitive types (str, int, bool, None, etc.), return as-is
        return data


def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]:
def extract_computer_call_outputs(
    items: List[Dict[str, Any]], screenshot_dir: Optional[Path]
) -> List[Dict[str, Any]]:
    """
    Save any base64-encoded screenshots from computer_call_output entries to files and
    replace their image_url with the saved file path when a call_id is present.
@@ -103,18 +107,21 @@ def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: O
        updated.append(msg)
    return updated


class TrajectorySaverCallback(AsyncCallbackHandler):
    """
    Callback handler that saves agent trajectories to disk.

    Saves each run as a separate trajectory with unique ID, and each turn
    within the trajectory gets its own folder with screenshots and responses.
    """

    def __init__(self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None):

    def __init__(
        self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None
    ):
        """
        Initialize trajectory saver.

        Args:
            trajectory_dir: Base directory to save trajectories
            reset_on_run: If True, reset trajectory_id/turn/artifact on each run.
@@ -129,7 +136,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
        self.reset_on_run = reset_on_run
        # Optional directory to store extracted screenshots from metadata/new_items
        self.screenshot_dir: Optional[Path] = Path(screenshot_dir) if screenshot_dir else None

        # Ensure trajectory directory exists
        self.trajectory_dir.mkdir(parents=True, exist_ok=True)

@@ -137,7 +144,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
        """Get the directory for the current turn."""
        if not self.trajectory_id:
            raise ValueError("Trajectory not initialized - call _on_run_start first")

        # format: trajectory_id/turn_000
        turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}"
        turn_dir.mkdir(parents=True, exist_ok=True)
@@ -166,6 +173,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):

    def _update_usage(self, usage: Dict[str, Any]) -> None:
        """Update total usage statistics."""

        def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
            for key, value in source.items():
                if isinstance(value, dict):
@@ -176,20 +184,21 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
                    if key not in target:
                        target[key] = 0
                    target[key] += value

        add_dicts(self.total_usage, usage)

    @override
    async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize trajectory tracking for a new run."""
        model = kwargs.get("model", "unknown")

        # Only reset trajectory state if reset_on_run is True or no trajectory exists
        if self.reset_on_run or not self.trajectory_id:
            model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
            if "+" in model:
                model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
            # strip non-alphanumeric characters from model_name_short
            model_name_short = ''.join(c for c in model_name_short if c.isalnum() or c == '_')
            model_name_short = "".join(c for c in model_name_short if c.isalnum() or c == "_")

            # id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
            now = datetime.now()
@@ -198,11 +207,11 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
            self.current_artifact = 0
            self.model = model
            self.total_usage = {}

            # Create trajectory directory
            trajectory_path = self.trajectory_dir / self.trajectory_id
            trajectory_path.mkdir(parents=True, exist_ok=True)

            # Save trajectory metadata (optionally extract screenshots to screenshot_dir)
            kwargs_to_save = kwargs.copy()
            try:
@@ -219,7 +228,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
                "status": "running",
                "kwargs": kwargs_to_save,
            }

            with open(trajectory_path / "metadata.json", "w") as f:
                json.dump(metadata, f, indent=2)
        else:
@@ -227,22 +236,27 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
            self.model = model

    @override
    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
    async def on_run_end(
        self,
        kwargs: Dict[str, Any],
        old_items: List[Dict[str, Any]],
        new_items: List[Dict[str, Any]],
    ) -> None:
        """Finalize run tracking by updating metadata with completion status, usage, and new items."""
        if not self.trajectory_id:
            return

        # Update metadata with completion status, total usage, and new items
        trajectory_path = self.trajectory_dir / self.trajectory_id
        metadata_path = trajectory_path / "metadata.json"

        # Read existing metadata
        if metadata_path.exists():
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
        else:
            metadata = {}

        # Update metadata with completion info
        # Optionally extract screenshots from new_items before persisting
        new_items_to_save = new_items
@@ -251,32 +265,34 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
        except Exception:
            pass

        metadata.update({
            "status": "completed",
            "completed_at": str(uuid.uuid1().time),
            "total_usage": self.total_usage,
            "new_items": new_items_to_save,
            "total_turns": self.current_turn
        })

        metadata.update(
            {
                "status": "completed",
                "completed_at": str(uuid.uuid1().time),
                "total_usage": self.total_usage,
                "new_items": new_items_to_save,
                "total_turns": self.current_turn,
            }
        )

        # Save updated metadata
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

    @override

    @override
    async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
        if not self.trajectory_id:
            return

        self._save_artifact("api_start", { "kwargs": kwargs })

        self._save_artifact("api_start", {"kwargs": kwargs})

    @override
    async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Save API call result."""
        if not self.trajectory_id:
            return

        self._save_artifact("api_result", { "kwargs": kwargs, "result": result })

        self._save_artifact("api_result", {"kwargs": kwargs, "result": result})

    @override
    async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
@@ -295,77 +311,83 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
        """Save responses to the current turn directory and update usage statistics."""
        if not self.trajectory_id:
            return

        # Save responses
        turn_dir = self._get_turn_dir()
        response_data = {
            "timestamp": str(uuid.uuid1().time),
            "model": self.model,
            "kwargs": kwargs,
            "response": responses
            "response": responses,
        }

        self._save_artifact("agent_response", response_data)

        # Increment turn counter
        self.current_turn += 1

    def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes:
        """
        Draw a red dot and crosshair at the specified coordinates on the image.

        Args:
            image_bytes: The original image as bytes
            x: X coordinate for the crosshair
            y: Y coordinate for the crosshair

        Returns:
            Modified image as bytes with red dot and crosshair
        """
        # Open the image
        image = Image.open(io.BytesIO(image_bytes))
        draw = ImageDraw.Draw(image)

        # Draw crosshair lines (red, 2px thick)
        crosshair_size = 20
        line_width = 2
        color = "red"

        # Horizontal line
        draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width)
        # Vertical line
        draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width)

        # Draw center dot (filled circle)
        dot_radius = 3
        draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color)

        draw.ellipse(
            [(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color
        )

        # Convert back to bytes
        output = io.BytesIO()
        image.save(output, format='PNG')
        image.save(output, format="PNG")
        return output.getvalue()

    @override
    async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
    async def on_computer_call_end(
        self, item: Dict[str, Any], result: List[Dict[str, Any]]
    ) -> None:
        """
        Called when a computer call has completed.
        Saves screenshots and computer call output.
        """
        if not self.trajectory_id:
            return

        self._save_artifact("computer_call_result", { "item": item, "result": result })

        self._save_artifact("computer_call_result", {"item": item, "result": result})

        # Check if action has x/y coordinates and there's a screenshot in the result
        action = item.get("action", {})
        if "x" in action and "y" in action:
            # Look for screenshot in the result
            for result_item in result:
                if (result_item.get("type") == "computer_call_output" and
                    result_item.get("output", {}).get("type") == "input_image"):

                if (
                    result_item.get("type") == "computer_call_output"
                    and result_item.get("output", {}).get("type") == "input_image"
                ):

                    image_url = result_item["output"]["image_url"]

                    # Extract base64 image data
                    if image_url.startswith("data:image/"):
                        # Format: data:image/png;base64,<base64_data>
@@ -373,26 +395,24 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
                    else:
                        # Assume it's just base64 data
                        base64_data = image_url

                    try:
                        # Decode the image
                        image_bytes = base64.b64decode(base64_data)

                        # Draw crosshair at the action coordinates
                        annotated_image = self._draw_crosshair_on_image(
                            image_bytes,
                            int(action["x"]),
                            int(action["y"])
                            image_bytes, int(action["x"]), int(action["y"])
                        )

                        # Save as screenshot_action
                        self._save_artifact("screenshot_action", annotated_image)

                    except Exception as e:
                        # If annotation fails, just log and continue
                        print(f"Failed to annotate screenshot: {e}")

                    break  # Only process the first screenshot found

        # Increment turn counter
        self.current_turn += 1
        self.current_turn += 1
@@ -3,7 +3,7 @@ CLI chat interface for agent - Computer Use Agent

Usage:
    python -m agent.cli <model_string>

Examples:
    python -m agent.cli openai/computer-use-preview
    python -m agent.cli anthropic/claude-3-5-sonnet-20241022
@@ -11,19 +11,22 @@ Examples:
"""

try:
    import asyncio
    import argparse
    import os
    import sys
    import json
    from typing import List, Dict, Any
    import dotenv
    import asyncio
    import base64
    import time
    import json
    import os
    import platform
    import sys
    import time
    from pathlib import Path
    from typing import Any, Dict, List

    import dotenv

    try:
        from PIL import Image, ImageDraw

        PIL_AVAILABLE = True
    except Exception:
        PIL_AVAILABLE = False
@@ -31,36 +34,44 @@ try:
except ImportError:
    if __name__ == "__main__":
        raise ImportError(
            "CLI dependencies not found. "
            "Please install with: pip install \"cua-agent[cli]\""
            "CLI dependencies not found. " 'Please install with: pip install "cua-agent[cli]"'
        )

# Load environment variables
dotenv.load_dotenv()


# Color codes for terminal output
class Colors:
    RESET = '\033[0m'
    BOLD = '\033[1m'
    DIM = '\033[2m'

    # Text colors
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'
    GRAY = '\033[90m'

    # Background colors
    BG_RED = '\033[41m'
    BG_GREEN = '\033[42m'
    BG_YELLOW = '\033[43m'
    BG_BLUE = '\033[44m'
    RESET = "\033[0m"
    BOLD = "\033[1m"
    DIM = "\033[2m"

def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = False, end: str = "\n", right: str = ""):
    # Text colors
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"
    WHITE = "\033[37m"
    GRAY = "\033[90m"

    # Background colors
    BG_RED = "\033[41m"
    BG_GREEN = "\033[42m"
    BG_YELLOW = "\033[43m"
    BG_BLUE = "\033[44m"


def print_colored(
    text: str,
    color: str = "",
    bold: bool = False,
    dim: bool = False,
    end: str = "\n",
    right: str = "",
):
    """Print colored text to terminal with optional right-aligned text."""
    prefix = ""
    if bold:
@@ -69,24 +80,25 @@ def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = Fa
        prefix += Colors.DIM
    if color:
        prefix += color

    if right:
        # Get terminal width (default to 80 if unable to determine)
        try:
            import shutil

            terminal_width = shutil.get_terminal_size().columns
        except:
            terminal_width = 80

        # Add right margin
        terminal_width -= 1

        # Calculate padding needed
        # Account for ANSI escape codes not taking visual space
        visible_left_len = len(text)
        visible_right_len = len(right)
        padding = terminal_width - visible_left_len - visible_right_len

        if padding > 0:
            output = f"{prefix}{text}{' ' * padding}{right}{Colors.RESET}"
        else:
@@ -94,7 +106,7 @@ def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = Fa
            output = f"{prefix}{text} {right}{Colors.RESET}"
    else:
        output = f"{prefix}{text}{Colors.RESET}"

    print(output, end=end)


@@ -113,29 +125,34 @@ def print_action(action_type: str, details: Dict[str, Any], total_cost: float):
        args_str = f"('{details['text']}')"
    elif action_type == "scroll" and "x" in details and "y" in details:
        args_str = f"({details['x']}, {details['y']})"

    if total_cost > 0:
        print_colored(f"🛠️ {action_type}{args_str}", dim=True, right=f"💸 ${total_cost:.2f}")
    else:
        print_colored(f"🛠️ {action_type}{args_str}", dim=True)


def print_welcome(model: str, agent_loop: str, container_name: str):
    """Print welcome message."""
    print_colored(f"Connected to {container_name} ({model}, {agent_loop})")
    print_colored("Type 'exit' to quit.", dim=True)


async def ainput(prompt: str = ""):
    return await asyncio.to_thread(input, prompt)

async def chat_loop(agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True):

async def chat_loop(
    agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True
):
    """Main chat loop with the agent."""
    print_welcome(model, agent.agent_config_info.agent_class.__name__, container_name)

    history = []

    if initial_prompt:
        history.append({"role": "user", "content": initial_prompt})

    total_cost = 0

    while True:
@@ -143,31 +160,31 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
        # Get user input with prompt
        print_colored("> ", end="")
        user_input = await ainput()

        if user_input.lower() in ['exit', 'quit', 'q']:

        if user_input.lower() in ["exit", "quit", "q"]:
            print_colored("\n👋 Goodbye!")
            break

        if not user_input:
            continue

        # Add user message to history
        history.append({"role": "user", "content": user_input})

        # Stream responses from the agent with spinner
        with yaspin(text="Thinking...", spinner="line", attrs=["dark"]) as spinner:
            spinner.hide()

            async for result in agent.run(history):
                # Add agent responses to history
                history.extend(result.get("output", []))

                if show_usage:
                    total_cost += result.get("usage", {}).get("response_cost", 0)

                # Process and display the output
                for item in result.get("output", []):
                    if item.get("type") == "message":
                    if item.get("type") == "message" and item.get("role") == "assistant":
                        # Display agent text response
                        content = item.get("content", [])
                        for content_part in content:
@@ -176,7 +193,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
                            if text:
                                spinner.hide()
                                print_colored(text)

                    elif item.get("type") == "computer_call":
                        # Display computer action
                        action = item.get("action", {})
@@ -186,7 +203,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
                        print_action(action_type, action, total_cost)
                        spinner.text = f"Performing {action_type}..."
                        spinner.show()

                    elif item.get("type") == "function_call":
                        # Display function call
                        function_name = item.get("name", "")
@@ -194,18 +211,18 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
                        print_colored(f"🔧 Calling function: {function_name}", dim=True)
                        spinner.text = f"Calling {function_name}..."
                        spinner.show()

                    elif item.get("type") == "function_call_output":
                        # Display function output (dimmed)
                        output = item.get("output", "")
                        if output and len(output.strip()) > 0:
                            spinner.hide()
                            print_colored(f"📤 {output}", dim=True)

            spinner.hide()
    if show_usage and total_cost > 0:
        print_colored(f"Total cost: ${total_cost:.2f}", dim=True)


async def main():
    """Main CLI function."""
@@ -218,104 +235,103 @@ Examples:
  python -m agent.cli anthropic/claude-3-5-sonnet-20241022
  python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
  python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
        """
        """,
    )

    parser.add_argument(
        "model",
        help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
        help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')",
    )

    parser.add_argument(
        "--provider",
        choices=["cloud", "lume", "winsandbox", "docker"],
        default="cloud",
        help="Computer provider to use: cloud (default), lume, winsandbox, or docker",
    )

    parser.add_argument(
        "--images",
        type=int,
        default=3,
        help="Number of recent images to keep in context (default: 3)"
        help="Number of recent images to keep in context (default: 3)",
    )

    parser.add_argument("--trajectory", action="store_true", help="Save trajectory for debugging")

    parser.add_argument("--budget", type=float, help="Maximum budget for the session (in dollars)")

    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")

    parser.add_argument(
        "--trajectory",
        action="store_true",
        help="Save trajectory for debugging"
    )

    parser.add_argument(
        "--budget",
        type=float,
        help="Maximum budget for the session (in dollars)"
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
        "-p",
        "--prompt",
        type=str,
        help="Initial prompt to send to the agent. Leave blank for interactive mode.",
    )

    parser.add_argument(
        "-p", "--prompt",
        type=str,
        help="Initial prompt to send to the agent. Leave blank for interactive mode."
        "--prompt-file",
        type=Path,
        help="Path to a UTF-8 text file whose contents will be used as the initial prompt. If provided, overrides --prompt.",
    )

    parser.add_argument(
        "--predict-click",
        dest="predict_click",
        type=str,
        help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it."
        help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it.",
    )

    parser.add_argument("-c", "--cache", action="store_true", help="Tell the API to enable caching")

    parser.add_argument(
        "-u", "--usage", action="store_true", help="Show total cost of the agent runs"
    )

    parser.add_argument(
        "-c", "--cache",
        action="store_true",
        help="Tell the API to enable caching"
    )

    parser.add_argument(
        "-u", "--usage",
        action="store_true",
        help="Show total cost of the agent runs"
    )

    parser.add_argument(
        "-r", "--max-retries",
        "-r",
        "--max-retries",
        type=int,
        default=3,
        help="Maximum number of retries for the LLM API calls"
        help="Maximum number of retries for the LLM API calls",
    )

    args = parser.parse_args()

    # Check for required environment variables
    container_name = os.getenv("CUA_CONTAINER_NAME")
    cua_api_key = os.getenv("CUA_API_KEY")

    # Prompt for missing environment variables

    # Prompt for missing environment variables (container name always required)
    if not container_name:
        print_colored("CUA_CONTAINER_NAME not set.", dim=True)
        print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
        container_name = input("Enter your CUA container name: ").strip()
        if not container_name:
            print_colored("❌ Container name is required.")
            sys.exit(1)

    if not cua_api_key:
        if args.provider == "cloud":
            print_colored("CUA_CONTAINER_NAME not set.", dim=True)
            print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
            container_name = input("Enter your CUA container name: ").strip()
            if not container_name:
                print_colored("❌ Container name is required.")
                sys.exit(1)
        else:
            container_name = "cli-sandbox"

    # Only require API key for cloud provider
    if args.provider == "cloud" and not cua_api_key:
        print_colored("CUA_API_KEY not set.", dim=True)
        cua_api_key = input("Enter your CUA API key: ").strip()
        if not cua_api_key:
            print_colored("❌ API key is required.")
            print_colored("❌ API key is required for cloud provider.")
            sys.exit(1)

    # Check for provider-specific API keys based on model
    provider_api_keys = {
        "openai/": "OPENAI_API_KEY",
        "anthropic/": "ANTHROPIC_API_KEY",
        "omniparser+": "OPENAI_API_KEY",
        "omniparser+": "ANTHROPIC_API_KEY",
    }

    # Find matching provider and check for API key
    for prefix, env_var in provider_api_keys.items():
        if args.model.startswith(prefix):
        if prefix in args.model:
            if not os.getenv(env_var):
                print_colored(f"{env_var} not set.", dim=True)
                api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
@@ -325,7 +341,7 @@ Examples:
                # Set the environment variable for the session
                os.environ[env_var] = api_key
            break

    # Import here to avoid import errors if dependencies are missing
    try:
        from agent import ComputerAgent
@@ -334,46 +350,62 @@ Examples:
        print_colored(f"❌ Import error: {e}", Colors.RED, bold=True)
        print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
        sys.exit(1)

    # Resolve provider -> os_type, provider_type, api key requirement
    provider_map = {
        "cloud": ("linux", "cloud", True),
        "lume": ("macos", "lume", False),
        "winsandbox": ("windows", "winsandbox", False),
        "docker": ("linux", "docker", False),
    }
    os_type, provider_type, needs_api_key = provider_map[args.provider]

    computer_kwargs = {
        "os_type": os_type,
        "provider_type": provider_type,
        "name": container_name,
    }
    if needs_api_key:
        computer_kwargs["api_key"] = cua_api_key  # type: ignore

    # Create computer instance
    async with Computer(
        os_type="linux",
        provider_type="cloud",
        name=container_name,
        api_key=cua_api_key
    ) as computer:

    async with Computer(**computer_kwargs) as computer:  # type: ignore

        # Create agent
        agent_kwargs = {
            "model": args.model,
            "tools": [computer],
            "trust_remote_code": True, # needed for some local models (e.g., InternVL, OpenCUA)
            "trust_remote_code": True,  # needed for some local models (e.g., InternVL, OpenCUA)
            "verbosity": 20 if args.verbose else 30,  # DEBUG vs WARNING
            "max_retries": args.max_retries
            "max_retries": args.max_retries,
        }

        if args.images > 0:
            agent_kwargs["only_n_most_recent_images"] = args.images

        if args.trajectory:
            agent_kwargs["trajectory_dir"] = "trajectories"

        if args.budget:
            agent_kwargs["max_trajectory_budget"] = {
                "max_budget": args.budget,
                "raise_error": True,
                "reset_after_each_run": False
                "reset_after_each_run": False,
            }

        if args.cache:
            agent_kwargs["use_prompt_caching"] = True

        agent = ComputerAgent(**agent_kwargs)

        # If predict-click mode is requested, run once and exit
        if args.predict_click:
            if not PIL_AVAILABLE:
                print_colored("❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow", Colors.RED, bold=True)
                print_colored(
                    "❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow",
                    Colors.RED,
                    bold=True,
                )
                sys.exit(1)

            instruction = args.predict_click
@@ -408,6 +440,7 @@ Examples:

            try:
                from io import BytesIO

                with Image.open(BytesIO(img_bytes)) as img:
                    img = img.convert("RGB")
                    draw = ImageDraw.Draw(img)
@@ -430,9 +463,9 @@ Examples:
                    if system == "windows":
                        os.startfile(str(out_path))  # type: ignore[attr-defined]
                    elif system == "darwin":
                        os.system(f"open \"{out_path}\"")
                        os.system(f'open "{out_path}"')
                    else:
                        os.system(f"xdg-open \"{out_path}\"")
                        os.system(f'xdg-open "{out_path}"')
                except Exception:
                    pass
            except Exception as e:
@@ -442,13 +475,21 @@ Examples:
            # Done
            sys.exit(0)

        # Start chat loop (default interactive mode)
        await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
        # Resolve initial prompt from --prompt-file or --prompt
        initial_prompt = args.prompt or ""
        if args.prompt_file:
            try:
                initial_prompt = args.prompt_file.read_text(encoding="utf-8")
            except Exception as e:
                print_colored(f"❌ Failed to read --prompt-file: {e}", Colors.RED, bold=True)
                sys.exit(1)

        # Start chat loop (default interactive mode)
        await chat_loop(agent, args.model, container_name, initial_prompt, args.usage)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except (KeyboardInterrupt, EOFError) as _:
        print_colored("\n\n👋 Goodbye!")
        print_colored("\n\n👋 Goodbye!")
@@ -6,27 +6,32 @@ computer interface types, supporting both the ComputerHandler protocol and the
Computer library interface.
"""

from computer import Computer as cuaComputer

from .base import AsyncComputerHandler
from .cua import cuaComputerHandler
from .custom import CustomComputerHandler
from computer import Computer as cuaComputer


def is_agent_computer(computer):
    """Check if the given computer is a ComputerHandler or CUA Computer."""
    return isinstance(computer, AsyncComputerHandler) or \
        isinstance(computer, cuaComputer) or \
        (isinstance(computer, dict)) #and "screenshot" in computer)
    return (
        isinstance(computer, AsyncComputerHandler)
        or isinstance(computer, cuaComputer)
        or (isinstance(computer, dict))
    )  # and "screenshot" in computer)


async def make_computer_handler(computer):
    """
    Create a computer handler from a computer interface.

    Args:
        computer: Either a ComputerHandler instance, Computer instance, or dict of functions

    Returns:
        ComputerHandler: A computer handler instance

    Raises:
        ValueError: If the computer type is not supported
    """
@@ -38,4 +43,4 @@ async def make_computer_handler(computer):
        return computer_handler
    if isinstance(computer, dict):
        return CustomComputerHandler(computer)
    raise ValueError(f"Unsupported computer type: {type(computer)}")
    raise ValueError(f"Unsupported computer type: {type(computer)}")
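
# Usage sketch (the dict form only requires a "screenshot" callable; the
# function name below is a hypothetical stand-in):
#
#   async def take_screenshot() -> bytes:
#       ...  # return raw PNG bytes
#
#   handler = await make_computer_handler({"screenshot": take_screenshot})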
@@ -2,23 +2,32 @@
Base computer interface protocol for agent interactions.
"""

from typing import Protocol, Literal, List, Dict, Any, Union, Optional, runtime_checkable
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    Union,
    runtime_checkable,
)


@runtime_checkable
class AsyncComputerHandler(Protocol):
    """Protocol defining the interface for computer interactions."""

    # ==== Computer-Use-Preview Action Space ====

    # ==== Computer-Use-Preview Action Space ====

    async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
        """Get the current environment type."""
        ...

    async def get_dimensions(self) -> tuple[int, int]:
        """Get screen dimensions as (width, height)."""
        ...

    async def screenshot(self, text: Optional[str] = None) -> str:
        """Take a screenshot and return as base64 string.

@@ -26,49 +35,49 @@ class AsyncComputerHandler(Protocol):
            text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
        """
        ...

    async def click(self, x: int, y: int, button: str = "left") -> None:
        """Click at coordinates with specified button."""
        ...

    async def double_click(self, x: int, y: int) -> None:
        """Double click at coordinates."""
        ...

    async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
        """Scroll at coordinates with specified scroll amounts."""
        ...

    async def type(self, text: str) -> None:
        """Type text."""
        ...

    async def wait(self, ms: int = 1000) -> None:
        """Wait for specified milliseconds."""
        ...

    async def move(self, x: int, y: int) -> None:
        """Move cursor to coordinates."""
        ...

    async def keypress(self, keys: Union[List[str], str]) -> None:
        """Press key combination."""
        ...

    async def drag(self, path: List[Dict[str, int]]) -> None:
        """Drag along specified path."""
        ...

    async def get_current_url(self) -> str:
        """Get current URL (for browser environments)."""
        ...

    # ==== Anthropic Action Space ====

    # ==== Anthropic Action Space ====

    async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
        """Left mouse down at coordinates."""
        ...

    async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
        """Left mouse up at coordinates."""
        ...
@@ -3,24 +3,27 @@ Computer handler implementation for OpenAI computer-use-preview protocol.
"""

import base64
from typing import Dict, List, Any, Literal, Union, Optional
from .base import AsyncComputerHandler
from typing import Any, Dict, List, Literal, Optional, Union

from computer import Computer

from .base import AsyncComputerHandler


class cuaComputerHandler(AsyncComputerHandler):
    """Computer handler that implements the Computer protocol using the computer interface."""

    def __init__(self, cua_computer: Computer):
        """Initialize with a computer interface (from tool schema)."""
        self.cua_computer = cua_computer
        self.interface = None

    async def _initialize(self):
        if hasattr(self.cua_computer, '_initialized') and not self.cua_computer._initialized:
        if hasattr(self.cua_computer, "_initialized") and not self.cua_computer._initialized:
            await self.cua_computer.run()
        self.interface = self.cua_computer.interface

    # ==== Computer-Use-Preview Action Space ====

    # ==== Computer-Use-Preview Action Space ====

    async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
        """Get the current environment type."""
@@ -32,7 +35,7 @@ class cuaComputerHandler(AsyncComputerHandler):
        assert self.interface is not None
        screen_size = await self.interface.get_screen_size()
        return screen_size["width"], screen_size["height"]

    async def screenshot(self, text: Optional[str] = None) -> str:
        """Take a screenshot and return as base64 string.

@@ -41,8 +44,8 @@ class cuaComputerHandler(AsyncComputerHandler):
        """
        assert self.interface is not None
        screenshot_bytes = await self.interface.screenshot()
        return base64.b64encode(screenshot_bytes).decode('utf-8')

        return base64.b64encode(screenshot_bytes).decode("utf-8")

    async def click(self, x: int, y: int, button: str = "left") -> None:
        """Click at coordinates with specified button."""
        assert self.interface is not None
@@ -53,34 +56,35 @@ class cuaComputerHandler(AsyncComputerHandler):
        else:
            # Default to left click for unknown buttons
            await self.interface.left_click(x, y)

    async def double_click(self, x: int, y: int) -> None:
        """Double click at coordinates."""
        assert self.interface is not None
        await self.interface.double_click(x, y)

    async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
        """Scroll at coordinates with specified scroll amounts."""
        assert self.interface is not None
        await self.interface.move_cursor(x, y)
        await self.interface.scroll(scroll_x, scroll_y)

    async def type(self, text: str) -> None:
        """Type text."""
        assert self.interface is not None
        await self.interface.type_text(text)

    async def wait(self, ms: int = 1000) -> None:
        """Wait for specified milliseconds."""
        assert self.interface is not None
        import asyncio

        await asyncio.sleep(ms / 1000.0)

    async def move(self, x: int, y: int) -> None:
        """Move cursor to coordinates."""
        assert self.interface is not None
        await self.interface.move_cursor(x, y)

    async def keypress(self, keys: Union[List[str], str]) -> None:
        """Press key combination."""
        assert self.interface is not None
@@ -91,38 +95,38 @@ class cuaComputerHandler(AsyncComputerHandler):
        else:
            # Handle key combinations
            await self.interface.hotkey(*keys)

    async def drag(self, path: List[Dict[str, int]]) -> None:
        """Drag along specified path."""
        assert self.interface is not None
        if not path:
            return

        # Start drag from first point
        start = path[0]
        await self.interface.mouse_down(start["x"], start["y"])

        # Move through path
        for point in path[1:]:
            await self.interface.move_cursor(point["x"], point["y"])

        # End drag at last point
        end = path[-1]
        await self.interface.mouse_up(end["x"], end["y"])

    async def get_current_url(self) -> str:
        """Get current URL (for browser environments)."""
        # This would need to be implemented based on the specific browser interface
        # For now, return empty string
        return ""

    # ==== Anthropic Computer Action Space ====
    # ==== Anthropic Computer Action Space ====
    async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
        """Left mouse down at coordinates."""
        assert self.interface is not None
        await self.interface.mouse_down(x, y, button="left")

    async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
        """Left mouse up at coordinates."""
        assert self.interface is not None
        await self.interface.mouse_up(x, y, button="left")
        await self.interface.mouse_up(x, y, button="left")
@@ -3,47 +3,49 @@ Custom computer handler implementation that accepts a dictionary of functions.
"""

import base64
from typing import Dict, List, Any, Literal, Union, Optional, Callable
from PIL import Image
import io
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from PIL import Image

from .base import AsyncComputerHandler


class CustomComputerHandler(AsyncComputerHandler):
    """Computer handler that implements the Computer protocol using a dictionary of custom functions."""

    def __init__(self, functions: Dict[str, Callable]):
        """
        Initialize with a dictionary of functions.

        Args:
            functions: Dictionary where keys are method names and values are callable functions.
                       Only 'screenshot' is required, all others are optional.

        Raises:
            ValueError: If required 'screenshot' function is not provided.
        """
        if 'screenshot' not in functions:
        if "screenshot" not in functions:
            raise ValueError("'screenshot' function is required in functions dictionary")

        self.functions = functions
        self._last_screenshot_size: Optional[tuple[int, int]] = None

    async def _call_function(self, func, *args, **kwargs):
        """
        Call a function, handling both async and sync functions.

        Args:
            func: The function to call
            *args: Positional arguments to pass to the function
            **kwargs: Keyword arguments to pass to the function

        Returns:
            The result of the function call
        """
        import asyncio
        import inspect

        if callable(func):
            if inspect.iscoroutinefunction(func):
                return await func(*args, **kwargs)
@@ -51,14 +53,14 @@ class CustomComputerHandler(AsyncComputerHandler):
                return func(*args, **kwargs)
        else:
            return func

    async def _get_value(self, attribute: str):
        """
        Get value for an attribute, checking both 'get_{attribute}' and '{attribute}' keys.

        Args:
            attribute: The attribute name to look for

        Returns:
            The value from the functions dict, called if callable, returned directly if not
        """
@@ -66,20 +68,20 @@ class CustomComputerHandler(AsyncComputerHandler):
        get_key = f"get_{attribute}"
        if get_key in self.functions:
            return await self._call_function(self.functions[get_key])

        # Check for '{attribute}'

        # Check for '{attribute}'
        if attribute in self.functions:
            return await self._call_function(self.functions[attribute])

        return None

    def _to_b64_str(self, img: Union[bytes, Image.Image, str]) -> str:
        """
        Convert image to base64 string.

        Args:
            img: Image as bytes, PIL Image, or base64 string

        Returns:
            str: Base64 encoded image string
        """
@@ -88,47 +90,47 @@ class CustomComputerHandler(AsyncComputerHandler):
            return img
        elif isinstance(img, bytes):
            # Raw bytes
            return base64.b64encode(img).decode('utf-8')
            return base64.b64encode(img).decode("utf-8")
        elif isinstance(img, Image.Image):
            # PIL Image
            buffer = io.BytesIO()
            img.save(buffer, format='PNG')
            return base64.b64encode(buffer.getvalue()).decode('utf-8')
            img.save(buffer, format="PNG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")

    # ==== Computer-Use-Preview Action Space ====

    # ==== Computer-Use-Preview Action Space ====

    async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
        """Get the current environment type."""
        result = await self._get_value('environment')
        result = await self._get_value("environment")
        if result is None:
            return "linux"
        assert result in ["windows", "mac", "linux", "browser"]
        return result # type: ignore
        return result  # type: ignore

    async def get_dimensions(self) -> tuple[int, int]:
        """Get screen dimensions as (width, height)."""
        result = await self._get_value('dimensions')
        result = await self._get_value("dimensions")
        if result is not None:
            return result # type: ignore

            return result  # type: ignore

        # Fallback: use last screenshot size if available
        if not self._last_screenshot_size:
            await self.screenshot()
        assert self._last_screenshot_size is not None, "Failed to get screenshot size"

        return self._last_screenshot_size

    async def screenshot(self, text: Optional[str] = None) -> str:
        """Take a screenshot and return as base64 string.

        Args:
            text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
        """
        result = await self._call_function(self.functions['screenshot'])
        b64_str = self._to_b64_str(result) # type: ignore

        result = await self._call_function(self.functions["screenshot"])
        b64_str = self._to_b64_str(result)  # type: ignore

        # Try to extract dimensions for fallback use
        try:
            if isinstance(result, Image.Image):
@@ -140,74 +142,75 @@ class CustomComputerHandler(AsyncComputerHandler):
        except Exception:
            # If we can't get dimensions, that's okay
            pass

        return b64_str

    async def click(self, x: int, y: int, button: str = "left") -> None:
        """Click at coordinates with specified button."""
        if 'click' in self.functions:
            await self._call_function(self.functions['click'], x, y, button)
        if "click" in self.functions:
            await self._call_function(self.functions["click"], x, y, button)
        # No-op if not implemented

    async def double_click(self, x: int, y: int) -> None:
        """Double click at coordinates."""
        if 'double_click' in self.functions:
            await self._call_function(self.functions['double_click'], x, y)
        if "double_click" in self.functions:
            await self._call_function(self.functions["double_click"], x, y)
        # No-op if not implemented

    async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
        """Scroll at coordinates with specified scroll amounts."""
        if 'scroll' in self.functions:
            await self._call_function(self.functions['scroll'], x, y, scroll_x, scroll_y)
        if "scroll" in self.functions:
            await self._call_function(self.functions["scroll"], x, y, scroll_x, scroll_y)
        # No-op if not implemented

    async def type(self, text: str) -> None:
        """Type text."""
        if 'type' in self.functions:
            await self._call_function(self.functions['type'], text)
        if "type" in self.functions:
            await self._call_function(self.functions["type"], text)
        # No-op if not implemented

    async def wait(self, ms: int = 1000) -> None:
        """Wait for specified milliseconds."""
        if 'wait' in self.functions:
            await self._call_function(self.functions['wait'], ms)
        if "wait" in self.functions:
            await self._call_function(self.functions["wait"], ms)
        else:
            # Default implementation
            import asyncio

            await asyncio.sleep(ms / 1000.0)

    async def move(self, x: int, y: int) -> None:
        """Move cursor to coordinates."""
        if 'move' in self.functions:
            await self._call_function(self.functions['move'], x, y)
        if "move" in self.functions:
            await self._call_function(self.functions["move"], x, y)
        # No-op if not implemented

    async def keypress(self, keys: Union[List[str], str]) -> None:
        """Press key combination."""
        if 'keypress' in self.functions:
            await self._call_function(self.functions['keypress'], keys)
        if "keypress" in self.functions:
            await self._call_function(self.functions["keypress"], keys)
        # No-op if not implemented

    async def drag(self, path: List[Dict[str, int]]) -> None:
        """Drag along specified path."""
        if 'drag' in self.functions:
            await self._call_function(self.functions['drag'], path)
        if "drag" in self.functions:
            await self._call_function(self.functions["drag"], path)
        # No-op if not implemented

    async def get_current_url(self) -> str:
        """Get current URL (for browser environments)."""
        if 'get_current_url' in self.functions:
            return await self._get_value('current_url') # type: ignore
        if "get_current_url" in self.functions:
            return await self._get_value("current_url")  # type: ignore
        return "" # Default fallback

    async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
        """Left mouse down at coordinates."""
        if 'left_mouse_down' in self.functions:
            await self._call_function(self.functions['left_mouse_down'], x, y)
        if "left_mouse_down" in self.functions:
            await self._call_function(self.functions["left_mouse_down"], x, y)
        # No-op if not implemented

    async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
        """Left mouse up at coordinates."""
        if 'left_mouse_up' in self.functions:
            await self._call_function(self.functions['left_mouse_up'], x, y)
        if "left_mouse_up" in self.functions:
            await self._call_function(self.functions["left_mouse_up"], x, y)
        # No-op if not implemented
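
# Usage sketch (PIL.ImageGrab is one possible screenshot source, an assumption
# here; sync callables and plain values both work, and unimplemented actions
# are simply no-ops):
#
#   from PIL import ImageGrab
#
#   handler = CustomComputerHandler({
#       "screenshot": lambda: ImageGrab.grab(),  # returns a PIL Image
#       "environment": "linux",
#   })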
@@ -3,47 +3,56 @@ Decorators for agent - agent_loop decorator
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .types import AgentConfigInfo
|
||||
|
||||
# Global registry
|
||||
_agent_configs: List[AgentConfigInfo] = []
|
||||
|
||||
|
||||
def register_agent(models: str, priority: int = 0):
|
||||
"""
|
||||
Decorator to register an AsyncAgentConfig class.
|
||||
|
||||
|
||||
Args:
|
||||
models: Regex pattern to match supported models
|
||||
priority: Priority for agent selection (higher = more priority)
|
||||
"""
|
||||
|
||||
def decorator(agent_class: type):
|
||||
# Validate that the class implements AsyncAgentConfig protocol
|
||||
if not hasattr(agent_class, 'predict_step'):
|
||||
raise ValueError(f"Agent class {agent_class.__name__} must implement predict_step method")
|
||||
if not hasattr(agent_class, 'predict_click'):
|
||||
raise ValueError(f"Agent class {agent_class.__name__} must implement predict_click method")
|
||||
if not hasattr(agent_class, 'get_capabilities'):
|
||||
raise ValueError(f"Agent class {agent_class.__name__} must implement get_capabilities method")
|
||||
|
||||
if not hasattr(agent_class, "predict_step"):
|
||||
raise ValueError(
|
||||
f"Agent class {agent_class.__name__} must implement predict_step method"
|
||||
)
|
||||
if not hasattr(agent_class, "predict_click"):
|
||||
raise ValueError(
|
||||
f"Agent class {agent_class.__name__} must implement predict_click method"
|
||||
)
|
||||
if not hasattr(agent_class, "get_capabilities"):
|
||||
raise ValueError(
|
||||
f"Agent class {agent_class.__name__} must implement get_capabilities method"
|
||||
)
|
||||
|
||||
# Register the agent config
|
||||
config_info = AgentConfigInfo(
|
||||
agent_class=agent_class,
|
||||
models_regex=models,
|
||||
priority=priority
|
||||
agent_class=agent_class, models_regex=models, priority=priority
|
||||
)
|
||||
_agent_configs.append(config_info)
|
||||
|
||||
|
||||
# Sort by priority (highest first)
|
||||
_agent_configs.sort(key=lambda x: x.priority, reverse=True)
|
||||
|
||||
|
||||
return agent_class
|
||||
|
||||
|
||||
return decorator
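
# Usage sketch (illustrative, not part of this diff): `MyAgentConfig` below is a
# hypothetical class; anything passed to @register_agent must provide
# predict_step, predict_click, and get_capabilities, or the decorator raises
# ValueError at registration time.
#
# @register_agent(models=r"my-provider/.*", priority=10)
# class MyAgentConfig:
#     async def predict_step(self, *args, **kwargs): ...
#     async def predict_click(self, *args, **kwargs): ...
#     def get_capabilities(self): ...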


def get_agent_configs() -> List[AgentConfigInfo]:
    """Get all registered agent configs"""
    return _agent_configs.copy()


def find_agent_config(model: str) -> Optional[AgentConfigInfo]:
    """Find the best matching agent config for a model"""
    for config_info in _agent_configs:

@@ -12,7 +12,7 @@ Components:
Usage:
    # Run the server and UI
    python -m agent.human_tool

    # Or run components separately
    python -m agent.human_tool.server  # API server only
    python -m agent.human_tool.ui      # UI only
@@ -21,9 +21,4 @@ Usage:
from .server import CompletionQueue, completion_queue
from .ui import HumanCompletionUI, create_ui

__all__ = [
    "CompletionQueue",
    "completion_queue",
    "HumanCompletionUI",
    "create_ui"
]
__all__ = ["CompletionQueue", "completion_queue", "HumanCompletionUI", "create_ui"]

@@ -8,6 +8,7 @@ with a Gradio UI for human interaction.

import gradio as gr
from fastapi import FastAPI

from .server import app as fastapi_app
from .ui import create_ui

@@ -18,6 +19,7 @@ gradio_demo = create_ui()
CUSTOM_PATH = "/gradio"
app = gr.mount_gradio_app(fastapi_app, gradio_demo, path=CUSTOM_PATH)


# Add a redirect from root to Gradio UI
@fastapi_app.get("/")
async def redirect_to_ui():
@@ -25,14 +27,16 @@ async def redirect_to_ui():
    return {
        "message": "Human Completion Server is running",
        "ui_url": "/gradio",
        "api_docs": "/docs"
        "api_docs": "/docs",
    }


if __name__ == "__main__":
    import uvicorn

    print("🚀 Starting Human-in-the-Loop Completion Server...")
    print("📊 API Server: http://localhost:8002")
    print("🎨 Gradio UI: http://localhost:8002/gradio")
    print("📚 API Docs: http://localhost:8002/docs")

    uvicorn.run(app, host="0.0.0.0", port=8002)

@@ -1,9 +1,9 @@
import asyncio
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
@@ -49,7 +49,7 @@ class CompletionQueue:
        self._queue: Dict[str, CompletionCall] = {}
        self._pending_order: List[str] = []
        self._lock = asyncio.Lock()

    async def add_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
        """Add a completion call to the queue."""
        async with self._lock:
@@ -59,42 +59,47 @@ class CompletionQueue:
                messages=messages,
                model=model,
                status=CompletionStatus.PENDING,
                created_at=datetime.now()
                created_at=datetime.now(),
            )
            self._queue[call_id] = completion_call
            self._pending_order.append(call_id)
            return call_id

    async def get_pending_calls(self) -> List[Dict[str, Any]]:
        """Get all pending completion calls."""
        async with self._lock:
            pending_calls = []
            for call_id in self._pending_order:
                if call_id in self._queue and self._queue[call_id].status == CompletionStatus.PENDING:
                if (
                    call_id in self._queue
                    and self._queue[call_id].status == CompletionStatus.PENDING
                ):
                    call = self._queue[call_id]
                    pending_calls.append({
                        "id": call.id,
                        "model": call.model,
                        "created_at": call.created_at.isoformat(),
                        "messages": call.messages
                    })
                    pending_calls.append(
                        {
                            "id": call.id,
                            "model": call.model,
                            "created_at": call.created_at.isoformat(),
                            "messages": call.messages,
                        }
                    )
            return pending_calls

    async def get_call_status(self, call_id: str) -> Optional[Dict[str, Any]]:
        """Get the status of a specific completion call."""
        async with self._lock:
            if call_id not in self._queue:
                return None

            call = self._queue[call_id]
            result = {
                "id": call.id,
                "status": call.status.value,
                "created_at": call.created_at.isoformat(),
                "model": call.model,
                "messages": call.messages
                "messages": call.messages,
            }

            if call.completed_at:
                result["completed_at"] = call.completed_at.isoformat()
            if call.response:
@@ -103,69 +108,74 @@ class CompletionQueue:
                result["tool_calls"] = call.tool_calls
            if call.error:
                result["error"] = call.error

            return result

    async def complete_call(self, call_id: str, response: Optional[str] = None, tool_calls: Optional[List[Dict[str, Any]]] = None) -> bool:

    async def complete_call(
        self,
        call_id: str,
        response: Optional[str] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
    ) -> bool:
        """Mark a completion call as completed with a response or tool calls."""
        async with self._lock:
            if call_id not in self._queue:
                return False

            call = self._queue[call_id]
            if call.status != CompletionStatus.PENDING:
                return False

            call.status = CompletionStatus.COMPLETED
            call.completed_at = datetime.now()
            call.response = response
            call.tool_calls = tool_calls

            # Remove from pending order
            if call_id in self._pending_order:
                self._pending_order.remove(call_id)

            return True

    async def fail_call(self, call_id: str, error: str) -> bool:
        """Mark a completion call as failed with an error."""
        async with self._lock:
            if call_id not in self._queue:
                return False

            call = self._queue[call_id]
            if call.status != CompletionStatus.PENDING:
                return False

            call.status = CompletionStatus.FAILED
            call.completed_at = datetime.now()
            call.error = error

            # Remove from pending order
            if call_id in self._pending_order:
                self._pending_order.remove(call_id)

            return True

    async def wait_for_completion(self, call_id: str, timeout: float = 300.0) -> Optional[str]:
        """Wait for a completion call to be completed and return the response."""
        start_time = asyncio.get_event_loop().time()

        while True:
            status = await self.get_call_status(call_id)
            if not status:
                return None

            if status["status"] == CompletionStatus.COMPLETED.value:
                return status.get("response")
            elif status["status"] == CompletionStatus.FAILED.value:
                raise Exception(f"Completion failed: {status.get('error', 'Unknown error')}")

            # Check timeout
            if asyncio.get_event_loop().time() - start_time > timeout:
                await self.fail_call(call_id, "Timeout waiting for human response")
                raise TimeoutError("Timeout waiting for human response")

            # Wait a bit before checking again
            await asyncio.sleep(0.5)
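
# Usage sketch (illustrative, not part of this diff): a producer coroutine can
# enqueue a call and block on the polling loop above until a human answers.
#
# async def ask_human(queue: CompletionQueue) -> Optional[str]:
#     call_id = await queue.add_completion(
#         messages=[{"role": "user", "content": "Approve this action?"}],
#         model="human/operator",  # model label is an assumption
#     )
#     return await queue.wait_for_completion(call_id, timeout=60.0)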

@@ -204,9 +214,7 @@ async def get_status(call_id: str):
async def complete_call(call_id: str, response: CompletionResponse):
    """Complete a call with a human response."""
    success = await completion_queue.complete_call(
        call_id,
        response=response.response,
        tool_calls=response.tool_calls
        call_id, response=response.response, tool_calls=response.tool_calls
    )
    if success:
        return {"status": "success", "message": "Call completed"}
@@ -219,7 +227,9 @@ async def fail_call(call_id: str, error: Dict[str, str]):
    """Mark a call as failed."""
    success = await completion_queue.fail_call(call_id, error.get("error", "Unknown error"))
    if not success:
        raise HTTPException(status_code=404, detail="Completion call not found or already completed")
        raise HTTPException(
            status_code=404, detail="Completion call not found or already completed"
        )
    return {"status": "failed"}


@@ -231,4 +241,5 @@ async def root():

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8002)

@@ -1,14 +1,17 @@
import gradio as gr
import json
import time
from typing import List, Dict, Any, Optional
from datetime import datetime
import requests
from .server import completion_queue
import base64
import io
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

import gradio as gr
import requests
from PIL import Image

from .server import completion_queue


class HumanCompletionUI:
    def __init__(self, server_url: str = "http://localhost:8002"):
        self.server_url = server_url
@@ -20,7 +23,7 @@ class HumanCompletionUI:
        self.current_button: str = "left"
        self.current_scroll_x: int = 0
        self.current_scroll_y: int = -120

    def format_messages_for_chatbot(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format messages for display in gr.Chatbot with type='messages'."""
        formatted = []
@@ -28,7 +31,7 @@ class HumanCompletionUI:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            tool_calls = msg.get("tool_calls", [])

            # Handle different content formats
            if isinstance(content, list):
                # Multi-modal content - can include text and images
@@ -55,7 +58,7 @@ class HumanCompletionUI:
                    else:
                        # For URL images, create gr.Image with URL
                        formatted_content.append(gr.Image(value=image_url))

                # Determine final content format
                if len(formatted_content) == 1:
                    content = formatted_content[0]
@@ -63,28 +66,28 @@ class HumanCompletionUI:
                    content = formatted_content
                else:
                    content = "[Empty content]"

            # Ensure role is valid for Gradio Chatbot
            if role not in ["user", "assistant"]:
                role = "assistant" if role == "system" else "user"

            # Invert roles for better display in human UI context
            # (what the AI says becomes "user", what human should respond becomes "assistant")
            if role == "user":
                role = "assistant"
            else:
                role = "user"

            # Add the main message if it has content
            if content and str(content).strip():
                formatted.append({"role": role, "content": content})

            # Handle tool calls - create separate messages for each tool call
            if tool_calls:
                for tool_call in tool_calls:
                    function_name = tool_call.get("function", {}).get("name", "unknown")
                    arguments_str = tool_call.get("function", {}).get("arguments", "{}")

                    try:
                        # Parse arguments to format them nicely
                        arguments = json.loads(arguments_str)
@@ -92,18 +95,20 @@ class HumanCompletionUI:
                    except json.JSONDecodeError:
                        # If parsing fails, use the raw string
                        formatted_args = arguments_str

                    # Create a formatted message for the tool call
                    tool_call_content = f"```json\n{formatted_args}\n```"
                    formatted.append({
                        "role": role,
                        "content": tool_call_content,
                        "metadata": {"title": f"🛠️ Used {function_name}"}
                    })

                    formatted.append(
                        {
                            "role": role,
                            "content": tool_call_content,
                            "metadata": {"title": f"🛠️ Used {function_name}"},
                        }
                    )

        return formatted

    def get_pending_calls(self) -> List[Dict[str, Any]]:
        """Get pending calls from the server."""
        try:
@@ -113,38 +118,39 @@ class HumanCompletionUI:
        except Exception as e:
            print(f"Error fetching pending calls: {e}")
            return []

    def complete_call_with_response(self, call_id: str, response: str) -> bool:
        """Complete a call with a text response."""
        try:
            response_data = {"response": response}
            response_obj = requests.post(
                f"{self.server_url}/complete/{call_id}",
                json=response_data,
                timeout=10
                f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
            )
            response_obj.raise_for_status()
            return True
        except requests.RequestException as e:
            print(f"Error completing call: {e}")
            return False

    def complete_call_with_tool_calls(self, call_id: str, tool_calls: List[Dict[str, Any]]) -> bool:
        """Complete a call with tool calls."""
        try:
            response_data = {"tool_calls": tool_calls}
            response_obj = requests.post(
                f"{self.server_url}/complete/{call_id}",
                json=response_data,
                timeout=10
                f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
            )
            response_obj.raise_for_status()
            return True
        except requests.RequestException as e:
            print(f"Error completing call: {e}")
            return False

    def complete_call(self, call_id: str, response: Optional[str] = None, tool_calls: Optional[List[Dict[str, Any]]] = None) -> bool:

    def complete_call(
        self,
        call_id: str,
        response: Optional[str] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
    ) -> bool:
        """Complete a call with either a response or tool calls."""
        try:
            response_data = {}
@@ -152,25 +158,23 @@ class HumanCompletionUI:
                response_data["response"] = response
            if tool_calls:
                response_data["tool_calls"] = tool_calls

            response_obj = requests.post(
                f"{self.server_url}/complete/{call_id}",
                json=response_data,
                timeout=10
                f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
            )
            response_obj.raise_for_status()
            return True
        except requests.RequestException as e:
            print(f"Error completing call: {e}")
            return False

    def get_last_image_from_messages(self, messages: List[Dict[str, Any]]) -> Optional[Any]:
        """Extract the last image from the messages for display above conversation."""
        last_image = None

        for msg in reversed(messages):  # Start from the last message
            content = msg.get("content", "")

            if isinstance(content, list):
                for item in reversed(content):  # Get the last image in the message
                    if item.get("type") == "image_url":
@@ -189,13 +193,13 @@ class HumanCompletionUI:
                    else:
                        # For URL images, return the URL
                        return image_url

        return last_image

    def refresh_pending_calls(self):
        """Refresh the list of pending calls."""
        pending_calls = self.get_pending_calls()

        if not pending_calls:
            return (
                gr.update(choices=["latest"], value="latest"),  # dropdown
@@ -205,27 +209,27 @@ class HumanCompletionUI:
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        # Sort pending calls by created_at to get oldest first
        sorted_calls = sorted(pending_calls, key=lambda x: x.get("created_at", ""))

        # Create choices for dropdown
        choices = [("latest", "latest")]  # Add "latest" option first

        for call in sorted_calls:
            call_id = call["id"]
            model = call.get("model", "unknown")
            created_at = call.get("created_at", "")
            # Format timestamp
            try:
                dt = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
                dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
                time_str = dt.strftime("%H:%M:%S")
            except:
                time_str = created_at

            choice_label = f"{call_id[:8]}... ({model}) - {time_str}"
            choices.append((choice_label, call_id))

        # Default to "latest" which shows the oldest pending conversation
        selected_call_id = "latest"
        if selected_call_id == "latest" and sorted_calls:
@@ -239,7 +243,7 @@ class HumanCompletionUI:
            conversation = []
            self.current_call_id = None
            self.last_image = None

        return (
            gr.update(choices=choices, value="latest"),
            gr.update(value=self.last_image),
@@ -248,7 +252,7 @@ class HumanCompletionUI:
            gr.update(visible=True),  # click_actions_group visible when there is a call
            gr.update(visible=True),  # actions_group visible when there is a call
        )

    def on_call_selected(self, selected_choice):
        """Handle when a call is selected from the dropdown."""
        if not selected_choice:
@@ -259,7 +263,7 @@ class HumanCompletionUI:
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        pending_calls = self.get_pending_calls()
        if not pending_calls:
            return (
@@ -269,7 +273,7 @@ class HumanCompletionUI:
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        # Handle "latest" option
        if selected_choice == "latest":
            # Sort calls by created_at to get oldest first
@@ -284,17 +288,17 @@ class HumanCompletionUI:
            if call_id_short in selected_choice:
                call_id = call["id"]
                break

        if not call_id:
            return (
                gr.update(value=None),  # no image
                gr.update(value=[]),  # empty chatbot
                gr.update(interactive=False)
                gr.update(interactive=False),
            )

        # Find the selected call
        selected_call = next((c for c in pending_calls if c["id"] == call_id), None)

        if not selected_call:
            return (
                gr.update(value=None),  # no image
@@ -303,12 +307,12 @@ class HumanCompletionUI:
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        conversation = self.format_messages_for_chatbot(selected_call.get("messages", []))
        self.current_call_id = call_id
        # Get the last image from messages
        self.last_image = self.get_last_image_from_messages(selected_call.get("messages", []))

        return (
            gr.update(value=self.last_image),
            gr.update(value=conversation),
@@ -316,110 +320,111 @@ class HumanCompletionUI:
            gr.update(visible=True),  # click_actions_group visible
            gr.update(visible=True),  # actions_group visible
        )

    def submit_response(self, response_text: str):
        """Submit a text response to the current call."""
        if not self.current_call_id:
            return (
                gr.update(value=response_text),  # keep response text
                gr.update(value="❌ No call selected")  # status
                gr.update(value="❌ No call selected"),  # status
            )

        if not response_text.strip():
            return (
                gr.update(value=response_text),  # keep response text
                gr.update(value="❌ Response cannot be empty")  # status
                gr.update(value="❌ Response cannot be empty"),  # status
            )

        success = self.complete_call_with_response(self.current_call_id, response_text)

        if success:
            status_msg = "✅ Response submitted successfully!"
            return (
                gr.update(value=""),  # clear response text
                gr.update(value=status_msg)  # status
                gr.update(value=status_msg),  # status
            )
        else:
            return (
                gr.update(value=response_text),  # keep response text
                gr.update(value="❌ Failed to submit response")  # status
                gr.update(value="❌ Failed to submit response"),  # status
            )

    def submit_action(self, action_type: str, **kwargs) -> str:
        """Submit a computer action as a tool call."""
        if not self.current_call_id:
            return "❌ No call selected"

        import uuid

        # Create tool call structure
        action_data = {"type": action_type, **kwargs}
        tool_call = {
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": "computer",
                "arguments": json.dumps(action_data)
            }
            "function": {"name": "computer", "arguments": json.dumps(action_data)},
        }

        success = self.complete_call_with_tool_calls(self.current_call_id, [tool_call])

        if success:
            return f"✅ {action_type.capitalize()} action submitted as tool call"
        else:
            return f"❌ Failed to submit {action_type} action"

    def submit_click_action(self, x: int, y: int, action_type: str = "click", button: str = "left") -> str:

    def submit_click_action(
        self, x: int, y: int, action_type: str = "click", button: str = "left"
    ) -> str:
        """Submit a coordinate-based action."""
        if action_type == "click":
            return self.submit_action(action_type, x=x, y=y, button=button)
        else:
            return self.submit_action(action_type, x=x, y=y)

    def submit_type_action(self, text: str) -> str:
        """Submit a type action."""
        return self.submit_action("type", text=text)

    def submit_hotkey_action(self, keys: str) -> str:
        """Submit a hotkey action."""
        return self.submit_action("keypress", keys=keys)

    def submit_wait_action(self) -> str:
        """Submit a wait action with no kwargs."""
        return self.submit_action("wait")

    def submit_description_click(self, description: str, action_type: str = "click", button: str = "left") -> str:

    def submit_description_click(
        self, description: str, action_type: str = "click", button: str = "left"
    ) -> str:
        """Submit a description-based action."""
        if action_type == "click":
            return self.submit_action(action_type, element_description=description, button=button)
        else:
            return self.submit_action(action_type, element_description=description)

    def wait_for_pending_calls(self, max_seconds: float = 10.0, check_interval: float = 0.2):
        """Wait for pending calls to appear or until max_seconds elapsed.

        This method loops and checks for pending calls at regular intervals,
        returning as soon as a pending call is found or the maximum wait time is reached.

        Args:
            max_seconds: Maximum number of seconds to wait
            check_interval: How often to check for pending calls (in seconds)
        """
        import time

        start_time = time.time()

        while time.time() - start_time < max_seconds:
            # Check if there are any pending calls
            pending_calls = self.get_pending_calls()
            if pending_calls:
                # Found pending calls, return immediately
                return self.refresh_pending_calls()

            # Wait before checking again
            time.sleep(check_interval)

        # Max wait time reached, return current state
        return self.refresh_pending_calls()

@@ -427,79 +432,73 @@
def create_ui():
    """Create the Gradio interface."""
    ui_handler = HumanCompletionUI()

    with gr.Blocks(title="Human-in-the-Loop Agent Tool", fill_width=True) as demo:
        gr.Markdown("# 🤖 Human-in-the-Loop Agent Tool")
        gr.Markdown("Review AI conversation requests and provide human responses.")

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    screenshot_image = gr.Image(
                        label="Interactive Screenshot",
                        interactive=False,
                        height=600
                        label="Interactive Screenshot", interactive=False, height=600
                    )

                # Action type selection for image clicks (wrapped for visibility control)
                with gr.Group(visible=False) as click_actions_group:
                    with gr.Row():
                        action_type_radio = gr.Dropdown(
                            label="Interactive Action",
                            choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down", "scroll"],
                            choices=[
                                "click",
                                "double_click",
                                "move",
                                "left_mouse_up",
                                "left_mouse_down",
                                "scroll",
                            ],
                            value="click",
                            scale=2
                            scale=2,
                        )
                        action_button_radio = gr.Dropdown(
                            label="Button",
                            choices=["left", "right", "wheel", "back", "forward"],
                            value="left",
                            visible=True,
                            scale=1
                            scale=1,
                        )
                        scroll_x_input = gr.Number(
                            label="scroll_x",
                            value=0,
                            visible=False,
                            scale=1
                            label="scroll_x", value=0, visible=False, scale=1
                        )
                        scroll_y_input = gr.Number(
                            label="scroll_y",
                            value=-120,
                            visible=False,
                            scale=1
                            label="scroll_y", value=-120, visible=False, scale=1
                        )

                conversation_chatbot = gr.Chatbot(
                    label="Conversation",
                    type="messages",
                    height=500,
                    show_copy_button=True
                    label="Conversation", type="messages", height=500, show_copy_button=True
                )

            with gr.Column(scale=1):
                with gr.Group():
                    call_dropdown = gr.Dropdown(
                        label="Select a pending conversation request",
                        choices=["latest"],
                        interactive=True,
                        value="latest"
                        value="latest",
                    )
                    refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
                    status_display = gr.Textbox(
                        label="Status",
                        interactive=False,
                        value="Ready to receive requests..."
                        label="Status", interactive=False, value="Ready to receive requests..."
                    )

                with gr.Group():
                    response_text = gr.Textbox(
                        label="Message",
                        lines=3,
                        placeholder="Enter your message here..."
                        label="Message", lines=3, placeholder="Enter your message here..."
                    )
                    submit_btn = gr.Button("📤 Submit Message", variant="primary", interactive=False)

                    submit_btn = gr.Button(
                        "📤 Submit Message", variant="primary", interactive=False
                    )

                # Action Accordions (wrapped for visibility control)
                with gr.Group(visible=False) as actions_group:
                    with gr.Tabs():
@@ -507,58 +506,73 @@ def create_ui():
                            with gr.Group():
                                description_text = gr.Textbox(
                                    label="Element Description",
                                    placeholder="e.g., 'Privacy and security option in left sidebar'"
                                    placeholder="e.g., 'Privacy and security option in left sidebar'",
                                )
                                with gr.Row():
                                    description_action_type = gr.Dropdown(
                                        label="Action",
                                        choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down"],
                                        value="click"
                                        choices=[
                                            "click",
                                            "double_click",
                                            "move",
                                            "left_mouse_up",
                                            "left_mouse_down",
                                        ],
                                        value="click",
                                    )
                                    description_button = gr.Dropdown(
                                        label="Button",
                                        choices=["left", "right", "wheel", "back", "forward"],
                                        value="left"
                                        value="left",
                                    )
                                description_submit_btn = gr.Button("Submit Click Action")

                        with gr.Tab("📝 Type Action"):
                            with gr.Group():
                                type_text = gr.Textbox(
                                    label="Text to Type",
                                    placeholder="Enter text to type..."
                                    label="Text to Type", placeholder="Enter text to type..."
                                )
                                type_submit_btn = gr.Button("Submit Type")

                        with gr.Tab("⌨️ Keypress Action"):
                            with gr.Group():
                                keypress_text = gr.Textbox(
                                    label="Keys",
                                    placeholder="e.g., ctrl+c, alt+tab"
                                    label="Keys", placeholder="e.g., ctrl+c, alt+tab"
                                )
                                keypress_submit_btn = gr.Button("Submit Keypress")

                        with gr.Tab("🧰 Misc Actions"):
                            with gr.Group():
                                misc_action_dropdown = gr.Dropdown(
                                    label="Action",
                                    choices=["wait"],
                                    value="wait"
                                    label="Action", choices=["wait"], value="wait"
                                )
                                misc_submit_btn = gr.Button("Submit Action")

        # Event handlers
        refresh_btn.click(
            fn=ui_handler.refresh_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        call_dropdown.change(
            fn=ui_handler.on_call_selected,
            inputs=[call_dropdown],
            outputs=[screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        def handle_image_click(evt: gr.SelectData):
            if evt.index is not None:
                x, y = evt.index
@@ -568,31 +582,44 @@ def create_ui():
                    sx_i = int(ui_handler.current_scroll_x or 0)
                    sy_i = int(ui_handler.current_scroll_y or 0)
                    # Submit a scroll action with x,y position and scroll deltas
                    result = ui_handler.submit_action("scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i)
                    result = ui_handler.submit_action(
                        "scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i
                    )
                else:
                    result = ui_handler.submit_click_action(x, y, action_type, button)
                ui_handler.wait_for_pending_calls()
                return result
            return "No coordinates selected"

        screenshot_image.select(
            fn=handle_image_click,
            outputs=[status_display]
        ).then(
        screenshot_image.select(fn=handle_image_click, outputs=[status_display]).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        # Response submission
        submit_btn.click(
            fn=ui_handler.submit_response,
            inputs=[response_text],
            outputs=[response_text, status_display]
            outputs=[response_text, status_display],
        ).then(
            fn=ui_handler.refresh_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        # Toggle visibility of controls based on action type
        def toggle_action_controls(action_type):
            # Button visible only for click
@@ -603,59 +630,63 @@ def create_ui():
            # Update state
            ui_handler.current_action_type = action_type or "click"
            return button_vis, scroll_x_vis, scroll_y_vis

        action_type_radio.change(
            fn=toggle_action_controls,
            inputs=[action_type_radio],
            outputs=[action_button_radio, scroll_x_input, scroll_y_input]
            outputs=[action_button_radio, scroll_x_input, scroll_y_input],
        )

        # Keep other control values in ui_handler state
        def on_button_change(val):
            ui_handler.current_button = (val or "left")
        action_button_radio.change(
            fn=on_button_change,
            inputs=[action_button_radio]
        )
            ui_handler.current_button = val or "left"

        action_button_radio.change(fn=on_button_change, inputs=[action_button_radio])

        def on_scroll_x_change(val):
            try:
                ui_handler.current_scroll_x = int(val) if val is not None else 0
            except Exception:
                ui_handler.current_scroll_x = 0
        scroll_x_input.change(
            fn=on_scroll_x_change,
            inputs=[scroll_x_input]
        )

        scroll_x_input.change(fn=on_scroll_x_change, inputs=[scroll_x_input])

        def on_scroll_y_change(val):
            try:
                ui_handler.current_scroll_y = int(val) if val is not None else 0
            except Exception:
                ui_handler.current_scroll_y = 0
        scroll_y_input.change(
            fn=on_scroll_y_change,
            inputs=[scroll_y_input]
        )

        scroll_y_input.change(fn=on_scroll_y_change, inputs=[scroll_y_input])

        type_submit_btn.click(
            fn=ui_handler.submit_type_action,
            inputs=[type_text],
            outputs=[status_display]
            fn=ui_handler.submit_type_action, inputs=[type_text], outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        keypress_submit_btn.click(
            fn=ui_handler.submit_hotkey_action,
            inputs=[keypress_text],
            outputs=[status_display]
            fn=ui_handler.submit_hotkey_action, inputs=[keypress_text], outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        def handle_description_submit(description, action_type, button):
            if description:
                result = ui_handler.submit_description_click(description, action_type, button)
@@ -666,12 +697,19 @@ def create_ui():
        description_submit_btn.click(
            fn=handle_description_submit,
            inputs=[description_text, description_action_type, description_button],
            outputs=[status_display]
            outputs=[status_display],
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        # Misc action handler
        def handle_misc_submit(selected_action):
            if selected_action == "wait":
@@ -681,20 +719,32 @@ def create_ui():
            return f"Unsupported misc action: {selected_action}"

        misc_submit_btn.click(
            fn=handle_misc_submit,
            inputs=[misc_action_dropdown],
            outputs=[status_display]
            fn=handle_misc_submit, inputs=[misc_action_dropdown], outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        # Load initial data
        demo.load(
            fn=ui_handler.refresh_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

    return demo
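
# Launch sketch (illustrative, not part of this diff): the demo returned by
# create_ui() can also be served standalone; the port is an assumption, and
# __main__.py instead mounts the same demo on the FastAPI app under /gradio.
#
# if __name__ == "__main__":
#     create_ui().launch(server_port=7860)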


@@ -8,21 +8,22 @@ Exports:
- run_full_dataset(dataset, ...)
- MCPComputerAgent
"""

import time
from typing import Any, Optional

from agent.computers import is_agent_computer
from datasets import load_dataset, Dataset
from hud.datasets import Task, run_dataset
from datasets import Dataset, load_dataset
from hud import trace
from hud.datasets import Task, run_dataset

from .agent import MCPComputerAgent


# ---------------------------------------------------------------------------
# Single-task runner
# ---------------------------------------------------------------------------


async def run_single_task(
    dataset: str | Dataset | list[dict[str, Any]],
    *,
@@ -47,24 +48,20 @@ async def run_single_task(

    # Load dataset and pick a sample
    if isinstance(dataset, str):
        dataset = load_dataset(dataset, split="train") # type: ignore[arg-type]
        dataset = load_dataset(dataset, split="train")  # type: ignore[arg-type]
    elif isinstance(dataset, list):
        dataset = dataset
    else:
        dataset = dataset["train"]

    sample_task = dataset[task_id]  # type: ignore[index]
    task_prompt = sample_task.get("prompt", f"Task {sample_task.get('id', 0)}")  # type: ignore[attr-defined]

    # Filter any existing Computer tools
    # The eval framework will add its own Computer tool per task
    if tools:
        tools = [
            tool
            for tool in tools
            if not is_agent_computer(tool)
        ]

        tools = [tool for tool in tools if not is_agent_computer(tool)]

    with trace(name=task_prompt):
        task = Task(**sample_task)  # type: ignore[arg-type]

@@ -87,13 +84,14 @@ async def run_single_task(
    )
    print(f"Running: {task_prompt}")
    result = await agent.run(task, max_steps=10)
    print(f"✅ Reward: {getattr(result, 'reward')}")
    print(f"✅ Reward: {result.reward}")


# ---------------------------------------------------------------------------
# Full-dataset runner
# ---------------------------------------------------------------------------


async def run_full_dataset(
    dataset: str | Dataset | list[dict[str, Any]],
    *,
@@ -121,9 +119,9 @@ async def run_full_dataset(

    # Run with our MCP-based agent class.
    if isinstance(dataset, str):
        dataset_name = dataset.split('/')[-1]
        dataset_name = dataset.split("/")[-1]
        job_name = job_name or f"Evaluation {dataset_name}"
        dataset = load_dataset(dataset, split=split) # type: ignore[arg-type]
        dataset = load_dataset(dataset, split=split)  # type: ignore[arg-type]
    else:
        dataset_name = "custom"
        job_name = job_name or f"Evaluation {time.strftime('%H:%M %Y-%m-%d')}"
@@ -131,12 +129,8 @@ async def run_full_dataset(
    # Filter any existing Computer tools
    # The eval framework will add its own Computer tool per task
    if tools:
        tools = [
            tool
            for tool in tools
            if not is_agent_computer(tool)
        ]

        tools = [tool for tool in tools if not is_agent_computer(tool)]

    # Execute evaluation
    return await run_dataset(
        name=job_name,
@@ -170,4 +164,4 @@ __all__ = [
    "run_single_task",
    "run_full_dataset",
    "MCPComputerAgent",
]
]
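
# Usage sketch (illustrative, not part of this diff): most keyword-only
# parameters of run_single_task are elided above, so the call below assumes a
# `task_id` parameter and uses a hypothetical dataset name.
#
# import asyncio
#
# asyncio.run(run_single_task("hud-evals/example-tasks", task_id=0))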

@@ -9,26 +9,26 @@ Key differences from the OpenAI OperatorAgent variant:
- Planning is executed via `ComputerAgent.run(messages)`.
- The first yielded result per step is returned as the agent response.
"""

from __future__ import annotations

import base64
import io
import uuid
from pathlib import Path
from typing import Any, ClassVar, Optional

import hud
import mcp.types as types
from agent.agent import ComputerAgent as BaseComputerAgent
from agent.callbacks import PromptInstructionsCallback
from agent.callbacks.trajectory_saver import TrajectorySaverCallback
from agent.computers import is_agent_computer
from agent.responses import make_failed_tool_call_items
from hud.agents import MCPAgent
from hud.tools.computer.settings import computer_settings
from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace

from agent.responses import make_failed_tool_call_items
from agent.computers import is_agent_computer
from PIL import Image
import mcp.types as types
import hud
import uuid
import base64
from pathlib import Path


class MCPComputerAgent(MCPAgent):
@@ -114,8 +114,10 @@ class MCPComputerAgent(MCPAgent):
        self.last_screenshot_b64 = None

        buffer = io.BytesIO()
        Image.new('RGB', (self.metadata["display_width"], self.metadata["display_height"])).save(buffer, format='PNG')
        self.last_screenshot_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        Image.new("RGB", (self.metadata["display_width"], self.metadata["display_height"])).save(
            buffer, format="PNG"
        )
        self.last_screenshot_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        # Ensure a computer shim is present so width/height/environment are known
        computer_shim = {
@@ -128,12 +130,8 @@ class MCPComputerAgent(MCPAgent):
        }
        agent_tools: list[Any] = [computer_shim]
        if tools:
            agent_tools.extend([
                tool
                for tool in tools
                if not is_agent_computer(tool)
            ])

            agent_tools.extend([tool for tool in tools if not is_agent_computer(tool)])

        agent_kwargs = {
            "model": self.model,
            "trajectory_dir": trajectory_dir,
@@ -150,9 +148,7 @@ class MCPComputerAgent(MCPAgent):
            "telemetry_enabled": telemetry_enabled,
        }

        self.computer_agent = BaseComputerAgent(
            **agent_kwargs
        )
        self.computer_agent = BaseComputerAgent(**agent_kwargs)

    async def get_system_messages(self) -> list[Any]:
        """Create initial messages.
@@ -161,9 +157,7 @@ class MCPComputerAgent(MCPAgent):
        """
        return []

    async def format_blocks(
        self, blocks: list[types.ContentBlock]
    ) -> list[dict[str, Any]]:
    async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str, Any]]:
        """
        Format blocks for OpenAI input format.

@@ -200,42 +194,49 @@ class MCPComputerAgent(MCPAgent):

        # Call the ComputerAgent LLM API
        async for result in self.computer_agent.run(messages):  # type: ignore[arg-type]
            items = result['output']
            items = result["output"]
            if not items or tool_calls:
                break

            for item in items:
                if item['type'] in ['reasoning', 'message', 'computer_call', 'function_call', 'function_call_output']:
                if item["type"] in [
                    "reasoning",
                    "message",
                    "computer_call",
                    "function_call",
                    "function_call_output",
                ]:
                    agent_result.append(item)

                # Add messages to output text
                if item['type'] == 'reasoning':
                if item["type"] == "reasoning":
                    output_text.extend(
                        f"Reasoning: {summary['text']}"
                        for summary in item['summary']
                        f"Reasoning: {summary['text']}" for summary in item["summary"]
                    )
                elif item['type'] == 'message':
                    if isinstance(item['content'], list):
                elif item["type"] == "message":
                    if isinstance(item["content"], list):
                        output_text.extend(
                            item['text']
                            for item in item['content']
                            if item['type'] == 'output_text'
                            item["text"]
                            for item in item["content"]
                            if item["type"] == "output_text"
                        )
                    elif isinstance(item['content'], str):
                        output_text.append(item['content'])

                    elif isinstance(item["content"], str):
                        output_text.append(item["content"])

                # If we get a tool call, we're not done
                if item['type'] == 'computer_call':
                if item["type"] == "computer_call":
                    id = item["call_id"]
                    tool_calls.append(MCPToolCall(
                        name="openai_computer",
                        arguments=item["action"],
                        id=id,
                    ))
                    tool_calls.append(
                        MCPToolCall(
                            name="openai_computer",
                            arguments=item["action"],
                            id=id,
                        )
                    )
                    is_done = False
                    self.tool_call_inputs[id] = agent_result
                    break

            # if we have tool calls, we should exit the loop
            if tool_calls:
                break
@@ -247,7 +248,7 @@ class MCPComputerAgent(MCPAgent):
            tool_calls=tool_calls,
            done=is_done,
        )

    def _log_image(self, image_b64: str):
        callbacks = self.computer_agent.callbacks
        for callback in callbacks:
@@ -257,9 +258,7 @@ class MCPComputerAgent(MCPAgent):
                callback._save_artifact("screenshot_after", image_bytes)

    async def format_tool_results(
        self,
        tool_calls: list[MCPToolCall],
        tool_results: list[MCPToolResult]
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[dict[str, Any]]:
        """Extract latest screenshot from tool results in dict form.

@@ -274,45 +273,60 @@ class MCPComputerAgent(MCPAgent):
            previous_output = self.previous_output.copy() or []

            # First we need to remove any pending computer_calls from the end of previous_output
            while previous_output and previous_output[-1]['type'] == 'computer_call':
            while previous_output and previous_output[-1]["type"] == "computer_call":
                previous_output.pop()
            messages.extend(previous_output)

            # If the call is a 'response', don't add the result
            if call.name == 'response':
            if call.name == "response":
                continue
            # Otherwise, if we have a result, we should add it to the messages
            content = [
                { "type": "input_text", "text": content.text } if isinstance(content, types.TextContent)
                else { "type": "input_image", "image_url": f"data:image/png;base64,{content.data}" } if isinstance(content, types.ImageContent)
                else { "type": "input_text", "text": "" }
                (
                    {"type": "input_text", "text": content.text}
                    if isinstance(content, types.TextContent)
                    else (
                        {
                            "type": "input_image",
                            "image_url": f"data:image/png;base64,{content.data}",
                        }
                        if isinstance(content, types.ImageContent)
                        else {"type": "input_text", "text": ""}
                    )
                )
                for content in result.content
            ]
            messages.append({
                "role": "user",
                "content": content,
            })
            messages.append(
                {
                    "role": "user",
                    "content": content,
                }
            )

            continue

        # Add the assistant's computer call
        messages.extend(self.tool_call_inputs[call.id])

        if result.isError:
            error_text = "".join([
                content.text
                for content in result.content
                if isinstance(content, types.TextContent)
            ])
            error_text = "".join(
                [
                    content.text
                    for content in result.content
                    if isinstance(content, types.TextContent)
                ]
            )

            # Replace computer call with failed tool call
            messages.pop()
            messages.extend(make_failed_tool_call_items(
                tool_name=call.name,
                tool_kwargs=call.arguments or {},
                error_message=error_text,
                call_id=call.id,
            ))
            messages.extend(
                make_failed_tool_call_items(
                    tool_name=call.name,
                    tool_kwargs=call.arguments or {},
                    error_message=error_text,
                    call_id=call.id,
                )
            )
        else:
            # Get the latest screenshot
            screenshots = [
@@ -325,23 +339,27 @@ class MCPComputerAgent(MCPAgent):
            if screenshots:
                self._log_image(screenshots[0])
                self.last_screenshot_b64 = screenshots[0]
                messages.append({
                    "type": "computer_call_output",
                    "call_id": call.id,
                    "output": {
                        "type": "input_image",
                        "image_url": f"data:image/png;base64,{screenshots[0]}"
                    },
                })
                messages.append(
                    {
                        "type": "computer_call_output",
                        "call_id": call.id,
                        "output": {
                            "type": "input_image",
                            "image_url": f"data:image/png;base64,{screenshots[0]}",
                        },
                    }
                )
            else:
                # Otherwise, replace computer call with failed tool call
                messages.pop()
                messages.extend(make_failed_tool_call_items(
                    tool_name=call.name,
                    tool_kwargs=call.arguments or {},
                    error_message="No screenshots returned.",
                    call_id=call.id,
                ))
                messages.extend(
                    make_failed_tool_call_items(
                        tool_name=call.name,
                        tool_kwargs=call.arguments or {},
                        error_message="No screenshots returned.",
                        call_id=call.id,
                    )
                )

        return messages
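
# For reference (sketch, not part of this diff): on success, format_tool_results
# appends one computer_call_output item per computer call, shaped roughly like:
#
# {
#     "type": "computer_call_output",
#     "call_id": "<MCPToolCall.id>",
#     "output": {"type": "input_image", "image_url": "data:image/png;base64,<b64>"},
# }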


@@ -7,30 +7,33 @@ OpenAI-like response blocks. We intentionally only support a single-step call
by consuming the first yielded result from `ComputerAgent.run()`.
"""

import traceback
import time
import traceback
import uuid
from typing import Any, Dict, List, Optional

from agent.agent import ComputerAgent as BaseComputerAgent
from agent.callbacks import PromptInstructionsCallback
from hud.tools.computer.settings import computer_settings
from PIL import Image
from hud.agents import OperatorAgent
from hud.tools.computer.settings import computer_settings

# OpenAI Responses typed models (required)
from openai.types.responses import (
    Response,
    ResponseComputerToolCall,
    ResponseInputParam,
    ResponseOutputItem,
    ResponseComputerToolCall,
    ResponseOutputMessage,
    ResponseOutputText,
    ResponseReasoningItem,
    ResponseUsage,
)
from PIL import Image

def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> List[ResponseOutputItem]:

def _map_agent_output_to_openai_blocks(
    output_items: List[Dict[str, Any]],
) -> List[ResponseOutputItem]:
    """Map our agent output items to OpenAI ResponseOutputItem typed models.

    Only a subset is supported: computer_call, assistant message (text), and reasoning.
@@ -40,14 +43,16 @@ def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> Li
    for item in output_items or []:
        t = item.get("type")
        if t == "computer_call":
            comp = ResponseComputerToolCall.model_validate({
                "id": item.get("id") or f"cu_{uuid.uuid4().hex}",
                "type": "computer_call",
                "call_id": item["call_id"],
                "action": item["action"],
                "pending_safety_checks": item.get("pending_safety_checks", []),
                "status": "completed",
            })
            comp = ResponseComputerToolCall.model_validate(
                {
                    "id": item.get("id") or f"cu_{uuid.uuid4().hex}",
                    "type": "computer_call",
                    "call_id": item["call_id"],
                    "action": item["action"],
                    "pending_safety_checks": item.get("pending_safety_checks", []),
                    "status": "completed",
                }
            )
            blocks.append(comp)
            # we will exit early here as the responses api only supports a single step
            break
@@ -55,31 +60,38 @@ def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> Li
            content_blocks: List[ResponseOutputText] = []
            for c in item.get("content", []) or []:
                content_blocks.append(
                    ResponseOutputText.model_validate({
                        "type": "output_text",
                        "text": c["text"],
                        "annotations": [],
                    })
                    ResponseOutputText.model_validate(
                        {
                            "type": "output_text",
                            "text": c["text"],
                            "annotations": [],
                        }
                    )
                )
            if content_blocks:
                msg = ResponseOutputMessage.model_validate({
                    "id": item.get("id") or f"msg_{uuid.uuid4()}",
                    "type": "message",
                    "role": "assistant",
                    "status": "completed",
                    "content": [ct.model_dump() for ct in content_blocks],
                })
                msg = ResponseOutputMessage.model_validate(
                    {
                        "id": item.get("id") or f"msg_{uuid.uuid4()}",
                        "type": "message",
                        "role": "assistant",
                        "status": "completed",
                        "content": [ct.model_dump() for ct in content_blocks],
                    }
                )
                blocks.append(msg)
        elif t == "reasoning":
            reasoning = ResponseReasoningItem.model_validate({
                "id": item.get("id") or f"rsn_{uuid.uuid4()}",
                "type": "reasoning",
                "summary": item["summary"],
            })
            reasoning = ResponseReasoningItem.model_validate(
                {
                    "id": item.get("id") or f"rsn_{uuid.uuid4()}",
                    "type": "reasoning",
                    "summary": item["summary"],
                }
            )
            blocks.append(reasoning)
        # Unhandled types are ignored
    return blocks
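
# Mapping sketch (illustrative, not part of this diff): a computer_call item
# short-circuits the loop above, so at most one tool-call block is emitted.
#
# blocks = _map_agent_output_to_openai_blocks(
#     [{"type": "computer_call", "call_id": "call_1", "action": {"type": "screenshot"}}]
# )
# # -> [ResponseComputerToolCall(call_id="call_1", ...)]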
|
||||
|
||||
|
||||
def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
|
||||
out: List[Dict[str, Any]] = []
|
||||
for it in list(items):
|
||||
@@ -92,6 +104,7 @@ def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
|
||||
out.append(dict(it)) # may raise if not mapping
|
||||
return out
|
||||
|
||||
|
||||
class FakeAsyncOpenAI:
|
||||
"""Minimal fake OpenAI client with only `responses.create` implemented.
|
||||
|
||||
@@ -132,10 +145,12 @@ class FakeAsyncOpenAI:
|
||||
# Pre-pend instructions message
|
||||
effective_input = full_input
|
||||
if instructions:
|
||||
effective_input = [{
|
||||
"role": "user",
|
||||
"content": instructions,
|
||||
}] + full_input
|
||||
effective_input = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": instructions,
|
||||
}
|
||||
] + full_input
|
||||
|
||||
# Run a single iteration of the ComputerAgent
|
||||
agent_result: Optional[Dict[str, Any]] = None
|
||||
@@ -152,32 +167,43 @@ class FakeAsyncOpenAI:
|
||||
blocks_to_cache = full_input + output
|
||||
for b in blocks_to_cache:
|
||||
bid = getattr(b, "id", None) or f"tmp-{hash(repr(b))}"
|
||||
self.blocks_cache[bid] = b # type: ignore[assignment]
|
||||
self.blocks_cache[bid] = b # type: ignore[assignment]
|
||||
block_ids.append(bid)
|
||||
response_id = agent_result.get("id") or f"fake-{int(time.time()*1000)}"
|
||||
self.context_cache[response_id] = block_ids
|
||||
|
||||
try:
|
||||
return Response.model_validate({
|
||||
"id": response_id,
|
||||
"created_at": time.time(),
|
||||
"object": "response",
|
||||
"model": model,
|
||||
"output": output,
|
||||
"parallel_tool_calls": False,
|
||||
"tool_choice": "auto",
|
||||
"tools": [],
|
||||
"previous_response_id": previous_response_id,
|
||||
"usage": ResponseUsage.model_validate({
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
"input_tokens_details": usage.get("input_tokens_details", { "cached_tokens": 0 }),
|
||||
"output_tokens_details": usage.get("output_tokens_details", { "reasoning_tokens": 0 }),
|
||||
}),
|
||||
})
|
||||
return Response.model_validate(
|
||||
{
|
||||
"id": response_id,
|
||||
"created_at": time.time(),
|
||||
"object": "response",
|
||||
"model": model,
|
||||
"output": output,
|
||||
"parallel_tool_calls": False,
|
||||
"tool_choice": "auto",
|
||||
"tools": [],
|
||||
"previous_response_id": previous_response_id,
|
||||
"usage": ResponseUsage.model_validate(
|
||||
{
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
"input_tokens_details": usage.get(
|
||||
"input_tokens_details", {"cached_tokens": 0}
|
||||
),
|
||||
"output_tokens_details": usage.get(
|
||||
"output_tokens_details", {"reasoning_tokens": 0}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ", e)
|
||||
print(
|
||||
f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ",
|
||||
e,
|
||||
)
|
||||
if attempt == max_retries - 1:
|
||||
print(traceback.format_exc())
|
||||
raise e
|
||||
@@ -221,9 +247,15 @@ class ProxyOperatorAgent(OperatorAgent):
|
||||
allowed_tools = allowed_tools or ["openai_computer"]
|
||||
|
||||
computer_shim = {
|
||||
'screenshot': lambda: Image.new('RGB', (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)),
|
||||
'environment': 'linux',
|
||||
'dimensions': (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)
|
||||
"screenshot": lambda: Image.new(
|
||||
"RGB",
|
||||
(computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT),
|
||||
),
|
||||
"environment": "linux",
|
||||
"dimensions": (
|
||||
computer_settings.OPENAI_COMPUTER_WIDTH,
|
||||
computer_settings.OPENAI_COMPUTER_HEIGHT,
|
||||
),
|
||||
}
|
||||
# Build tools ensuring the computer_shim is included
|
||||
agent_tools: list[Any] = [computer_shim]
|
||||
@@ -258,6 +290,7 @@ class ProxyOperatorAgent(OperatorAgent):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FakeAsyncOpenAI",
|
||||
"ProxyOperatorAgent",
|
||||
|
||||
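Note: the block/context caching above (a `context_cache` keyed by response id, a `blocks_cache` keyed by block id, and `previous_response_id` used to thread one response's blocks into the next call) can be read in isolation. A minimal standalone sketch of the same idea, with all names assumed rather than taken verbatim from the library:

import time
import uuid


class ResponseStore:
    """Hypothetical distillation of FakeAsyncOpenAI's two caches."""

    def __init__(self):
        self.blocks_cache = {}   # block id -> block payload
        self.context_cache = {}  # response id -> ordered list of block ids

    def save(self, blocks, previous_response_id=None):
        # Carry forward the previous response's blocks, then append the new ones.
        block_ids = list(self.context_cache.get(previous_response_id, []))
        for b in blocks:
            bid = b.get("id") or f"tmp-{uuid.uuid4().hex}"
            self.blocks_cache[bid] = b
            block_ids.append(bid)
        response_id = f"fake-{int(time.time() * 1000)}"
        self.context_cache[response_id] = block_ids
        return response_id

    def load(self, response_id):
        # Rehydrate the full conversation for a given response id.
        return [self.blocks_cache[bid] for bid in self.context_cache.get(response_id, [])]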
@@ -3,26 +3,34 @@ Agent loops for agent
"""

# Import the loops to register them
from . import anthropic
from . import openai
from . import uitars
from . import omniparser
from . import gta1
from . import composed_grounded
from . import glm45v
from . import opencua
from . import internvl
from . import holo
from . import (
    anthropic,
    composed_grounded,
    gemini,
    glm45v,
    gta1,
    holo,
    internvl,
    moondream3,
    omniparser,
    openai,
    opencua,
    qwen,
    uitars,
)

__all__ = [
    "anthropic",
    "openai",
    "uitars",
    "omniparser",
    "gta1",
    "composed_grounded",
    "glm45v",
    "anthropic",
    "openai",
    "uitars",
    "omniparser",
    "gta1",
    "composed_grounded",
    "glm45v",
    "opencua",
    "internvl",
    "holo",
]
    "moondream3",
    "gemini",
    "qwen",
]
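Importing each loop module is what triggers registration. A minimal sketch of the registry pattern these modules rely on (simplified and with assumed internals; the real decorator lives in agent/decorators.py and may store a richer config record than a bare class):

import re

_REGISTRY = []


def register_agent(models: str):
    """Map a model-name regex to the decorated agent config class."""
    def decorator(cls):
        _REGISTRY.append((re.compile(models), cls))
        return cls
    return decorator


def find_agent_config(model: str):
    """Return the first registered config whose pattern matches the model name."""
    for pattern, cls in _REGISTRY:
        if pattern.search(model):
            return cls
    return None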
(File diff suppressed because it is too large.)
@@ -2,13 +2,15 @@
Base protocol for async agent configurations
"""

from typing import Protocol, List, Dict, Any, Optional, Tuple, Union
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

from ..types import AgentCapability


class AsyncAgentConfig(Protocol):
    """Protocol defining the interface for async agent configurations."""

    @abstractmethod
    async def predict_step(
        self,
@@ -22,11 +24,11 @@ class AsyncAgentConfig(Protocol):
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Predict the next step based on input items.

        Args:
            messages: Input items following Responses format (message, function_call, computer_call)
            model: Model name to use
@@ -39,37 +41,34 @@ class AsyncAgentConfig(Protocol):
            _on_usage: Callback for usage tracking
            _on_screenshot: Callback for screenshot events
            **kwargs: Additional arguments

        Returns:
            Dictionary with "output" (output items) and "usage" array
        """
        ...

    @abstractmethod
    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str
        self, model: str, image_b64: str, instruction: str
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.

        Args:
            model: Model name to use
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            None or tuple with (x, y) coordinates
        """
        ...

    @abstractmethod
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by this agent config.

        Returns:
            List of capability strings (e.g., ["step", "click"])
        """
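Because AsyncAgentConfig is a Protocol, any class with matching methods satisfies it structurally, with no inheritance needed. A minimal sketch of a conforming implementation (names and the fixed coordinates are invented for illustration):

from typing import Any, Dict, List, Optional, Tuple


class CenterClickConfig:
    """Toy config: always 'clicks' the center of an assumed 1024x768 screen."""

    async def predict_step(self, messages, model, **kwargs) -> Dict[str, Any]:
        # A no-op step: empty output items and empty usage.
        return {"output": [], "usage": {}}

    async def predict_click(
        self, model: str, image_b64: str, instruction: str
    ) -> Optional[Tuple[int, int]]:
        return (512, 384)

    def get_capabilities(self) -> List[str]:
        return ["click", "step"]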
@@ -3,122 +3,117 @@ Composed-grounded agent loop implementation that combines grounding and thinking
Uses a two-stage approach: grounding model for element detection, thinking model for reasoning.
"""

import uuid
import asyncio
import json
import base64
from typing import Dict, List, Any, Optional, Tuple
import json
import uuid
from io import BytesIO
from PIL import Image
import litellm
from typing import Any, Dict, List, Optional, Tuple

import litellm
from PIL import Image

from ..agent import find_agent_config
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..responses import (
    convert_computer_calls_xy2desc,
    convert_responses_items_to_completion_messages,
    convert_completion_messages_to_responses_items,
    convert_computer_calls_desc2xy,
    get_all_element_descriptions
    convert_computer_calls_xy2desc,
    convert_responses_items_to_completion_messages,
    get_all_element_descriptions,
)
from ..agent import find_agent_config
from ..types import AgentCapability, AgentResponse, Messages, Tools

GROUNDED_COMPUTER_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "computer",
        "description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
        "parameters": {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "screenshot",
                        "click",
                        "double_click",
                        "drag",
                        "type",
                        "keypress",
                        "scroll",
                        "move",
                        "wait",
                        "get_current_url",
                        "get_dimensions",
                        "get_environment"
                    ],
                    "description": "The action to perform (required for all actions)"
                },
                "element_description": {
                    "type": "string",
                    "description": "Description of the element to interact with (required for click, double_click, move, scroll actions)"
                },
                "start_element_description": {
                    "type": "string",
                    "description": "Description of the element to start dragging from (required for drag action)"
                },
                "end_element_description": {
                    "type": "string",
                    "description": "Description of the element to drag to (required for drag action)"
                },
                "text": {
                    "type": "string",
                    "description": "The text to type (required for type action)"
                },
                "keys": {
                    "type": "array",
                    "items": {
                        "type": "string"
    "type": "function",
    "function": {
        "name": "computer",
        "description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
        "parameters": {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "screenshot",
                        "click",
                        "double_click",
                        "drag",
                        "type",
                        "keypress",
                        "scroll",
                        "move",
                        "wait",
                        "get_current_url",
                        "get_dimensions",
                        "get_environment",
                    ],
                    "description": "The action to perform (required for all actions)",
                },
                "element_description": {
                    "type": "string",
                    "description": "Description of the element to interact with (required for click, double_click, move, scroll actions)",
                },
                "start_element_description": {
                    "type": "string",
                    "description": "Description of the element to start dragging from (required for drag action)",
                },
                "end_element_description": {
                    "type": "string",
                    "description": "Description of the element to drag to (required for drag action)",
                },
                "text": {
                    "type": "string",
                    "description": "The text to type (required for type action)",
                },
                "keys": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Key(s) to press (required for keypress action)",
                },
                "button": {
                    "type": "string",
                    "enum": ["left", "right", "wheel", "back", "forward"],
                    "description": "The mouse button to use for click action (required for click and double_click action)",
                },
                "scroll_x": {
                    "type": "integer",
                    "description": "Horizontal scroll amount for scroll action (required for scroll action)",
                },
                "scroll_y": {
                    "type": "integer",
                    "description": "Vertical scroll amount for scroll action (required for scroll action)",
                },
            },
                    "description": "Key(s) to press (required for keypress action)"
            "required": ["action"],
        },
                "button": {
                    "type": "string",
                    "enum": [
                        "left",
                        "right",
                        "wheel",
                        "back",
                        "forward"
                    ],
                    "description": "The mouse button to use for click action (required for click and double_click action)",
                },
                "scroll_x": {
                    "type": "integer",
                    "description": "Horizontal scroll amount for scroll action (required for scroll action)",
                },
                "scroll_y": {
                    "type": "integer",
                    "description": "Vertical scroll amount for scroll action (required for scroll action)",
                },
            },
            "required": [
                "action"
            ]
        }
    }
    },
}
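For illustration, here are arguments a thinking model might emit against this grounded schema (values invented): pixel coordinates never appear, every target is referenced by a natural-language element description, which the grounding model later resolves to (x, y).

click_args = {
    "action": "click",
    "element_description": "blue Submit button at the bottom of the form",
    "button": "left",
}
drag_args = {
    "action": "drag",
    "start_element_description": "file icon named report.pdf",
    "end_element_description": "Trash icon in the dock",
}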
def _prepare_tools_for_grounded(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Prepare tools for grounded API format"""
    grounded_tools = []

    for schema in tool_schemas:
        if schema["type"] == "computer":
            grounded_tools.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
        else:
            grounded_tools.append(schema)

    return grounded_tools


def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str]:
    """Get the last computer call output image from messages."""
    for message in reversed(messages):
        if (isinstance(message, dict) and
            message.get("type") == "computer_call_output" and
            isinstance(message.get("output"), dict) and
            message["output"].get("type") == "input_image"):
        if (
            isinstance(message, dict)
            and message.get("type") == "computer_call_output"
            and isinstance(message.get("output"), dict)
            and message["output"].get("type") == "input_image"
        ):
            image_url = message["output"].get("image_url", "")
            if image_url.startswith("data:image/png;base64,"):
                return image_url.split(",", 1)[1]
@@ -129,14 +124,14 @@ def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str
class ComposedGroundedConfig(AsyncAgentConfig):
    """
    Composed-grounded agent configuration that uses both grounding and thinking models.

    The model parameter should be in format: "grounding_model+thinking_model"
    e.g., "huggingface-local/HelloKKMe/GTA1-7B+gemini/gemini-1.5-pro"
    """

    def __init__(self):
        self.desc2xy: Dict[str, Tuple[float, float]] = {}

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
@@ -150,11 +145,11 @@ class ComposedGroundedConfig(AsyncAgentConfig):
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Composed-grounded predict step implementation.

        Process:
        0. Store last computer call image, if none then take a screenshot
        1. Convert computer calls from xy to descriptions
@@ -167,18 +162,20 @@ class ComposedGroundedConfig(AsyncAgentConfig):
        """
        # Parse the composed model
        if "+" not in model:
            raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
            raise ValueError(
                f"Composed model must be in format 'grounding_model+thinking_model', got: {model}"
            )
        grounding_model, thinking_model = model.split("+", 1)

        pre_output_items = []

        # Step 0: Store last computer call image, if none then take a screenshot
        last_image_b64 = get_last_computer_call_image(messages)
        if last_image_b64 is None:
            # Take a screenshot
            screenshot_b64 = await computer_handler.screenshot()  # type: ignore
            if screenshot_b64:

                call_id = uuid.uuid4().hex
                pre_output_items += [
                    {
@@ -187,45 +184,42 @@ class ComposedGroundedConfig(AsyncAgentConfig):
                        "content": [
                            {
                                "type": "output_text",
                                "text": "Taking a screenshot to see the current computer screen."
                                "text": "Taking a screenshot to see the current computer screen.",
                            }
                        ]
                        ],
                    },
                    {
                        "action": {
                            "type": "screenshot"
                        },
                        "action": {"type": "screenshot"},
                        "call_id": call_id,
                        "status": "completed",
                        "type": "computer_call"
                        "type": "computer_call",
                    },
                    {
                        "type": "computer_call_output",
                        "call_id": call_id,
                        "output": {
                            "type": "input_image",
                            "image_url": f"data:image/png;base64,{screenshot_b64}"
                        }
                            "image_url": f"data:image/png;base64,{screenshot_b64}",
                        },
                    },
                ]
                last_image_b64 = screenshot_b64

                # Call screenshot callback if provided
                if _on_screenshot:
                    await _on_screenshot(screenshot_b64)

        tool_schemas = _prepare_tools_for_grounded(tools)  # type: ignore

        # Step 1: Convert computer calls from xy to descriptions
        input_messages = messages + pre_output_items
        messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)

        # Step 2: Convert responses items to completion messages
        completion_messages = convert_responses_items_to_completion_messages(
            messages_with_descriptions,
            allow_images_in_tool_results=False
            messages_with_descriptions, allow_images_in_tool_results=False
        )

        # Step 3: Call thinking model with litellm.acompletion
        api_kwargs = {
            "model": thinking_model,
@@ -233,98 +227,90 @@ class ComposedGroundedConfig(AsyncAgentConfig):
            "tools": tool_schemas,
            "max_retries": max_retries,
            "stream": stream,
            **kwargs
            **kwargs,
        }

        if use_prompt_caching:
            api_kwargs["use_prompt_caching"] = use_prompt_caching

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)

        # Make the completion call
        response = await litellm.acompletion(**api_kwargs)

        # Call API end hook
        if _on_api_end:
            await _on_api_end(api_kwargs, response)

        # Extract usage information
        usage = {
            **response.usage.model_dump(),  # type: ignore
            "response_cost": response._hidden_params.get("response_cost", 0.0),
        }
        if _on_usage:
            await _on_usage(usage)

        # Step 4: Convert completion messages back to responses items format
        response_dict = response.model_dump()  # type: ignore
        choice_messages = [choice["message"] for choice in response_dict["choices"]]
        thinking_output_items = []

        for choice_message in choice_messages:
            thinking_output_items.extend(convert_completion_messages_to_responses_items([choice_message]))
            thinking_output_items.extend(
                convert_completion_messages_to_responses_items([choice_message])
            )

        # Step 5: Get all element descriptions and populate desc2xy mapping
        element_descriptions = get_all_element_descriptions(thinking_output_items)

        if element_descriptions and last_image_b64:
            # Use grounding model to predict coordinates for each description
            grounding_agent_conf = find_agent_config(grounding_model)
            if grounding_agent_conf:
                grounding_agent = grounding_agent_conf.agent_class()

                for desc in element_descriptions:
                    for _ in range(3):  # try 3 times
                        coords = await grounding_agent.predict_click(
                            model=grounding_model,
                            image_b64=last_image_b64,
                            instruction=desc
                            model=grounding_model, image_b64=last_image_b64, instruction=desc
                        )
                        if coords:
                            self.desc2xy[desc] = coords
                            break

        # Step 6: Convert computer calls from descriptions back to xy coordinates
        final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)

        # Step 7: Return output and usage
        return {
            "output": pre_output_items + final_output_items,
            "usage": usage
        }
        return {"output": pre_output_items + final_output_items, "usage": usage}

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using the grounding model.

        For composed models, uses only the grounding model part for click prediction.
        """
        # Parse the composed model to get grounding model
        if "+" not in model:
            raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
            raise ValueError(
                f"Composed model must be in format 'grounding_model+thinking_model', got: {model}"
            )
        grounding_model, thinking_model = model.split("+", 1)

        # Find and use the grounding agent
        grounding_agent_conf = find_agent_config(grounding_model)
        if grounding_agent_conf:
            grounding_agent = grounding_agent_conf.agent_class()
            return await grounding_agent.predict_click(
                model=grounding_model,
                image_b64=image_b64,
                instruction=instruction,
                **kwargs
                model=grounding_model, image_b64=image_b64, instruction=instruction, **kwargs
            )

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["click", "step"]
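A usage sketch for the composed loop, assuming the "+"-joined model string format documented above (the model names are the example from the docstring, and screenshot_b64 is assumed to hold a base64-encoded PNG):

import asyncio


async def demo(screenshot_b64: str):
    config = ComposedGroundedConfig()
    # The grounding half resolves the description to pixel coordinates.
    return await config.predict_click(
        model="huggingface-local/HelloKKMe/GTA1-7B+gemini/gemini-1.5-pro",
        image_b64=screenshot_b64,
        instruction="red close button in the top right corner",
    )

# coords = asyncio.run(demo(my_screenshot_b64))  # e.g. (1884, 12)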
410
libs/python/agent/agent/loops/gemini.py
Normal file
@@ -0,0 +1,410 @@
"""
Gemini 2.5 Computer Use agent loop

Maps internal Agent SDK message format to Google's Gemini Computer Use API and back.

Key features:
- Lazy import of google.genai
- Configure Computer Use tool with excluded browser-specific predefined functions
- Optional custom function declarations hook for computer-call specific functions
- Convert Gemini function_call parts into internal computer_call actions
"""

from __future__ import annotations

import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image

from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability


def _lazy_import_genai():
    """Import google.genai lazily to avoid hard dependency unless used."""
    try:
        from google import genai  # type: ignore
        from google.genai import types  # type: ignore

        return genai, types
    except Exception as e:  # pragma: no cover
        raise RuntimeError(
            "google.genai is required for the Gemini Computer Use loop. Install the Google Gemini SDK."
        ) from e


def _data_url_to_bytes(data_url: str) -> Tuple[bytes, str]:
    """Convert a data URL to raw bytes and mime type."""
    if not data_url.startswith("data:"):
        # Assume it's base64 png payload
        try:
            return base64.b64decode(data_url), "image/png"
        except Exception:
            return b"", "application/octet-stream"
    header, b64 = data_url.split(",", 1)
    mime = "image/png"
    if ";" in header:
        mime = header.split(";")[0].split(":", 1)[1] or "image/png"
    return base64.b64decode(b64), mime


def _bytes_image_size(img_bytes: bytes) -> Tuple[int, int]:
    try:
        img = Image.open(io.BytesIO(img_bytes))
        return img.size
    except Exception:
        return (1024, 768)


def _find_last_user_text(messages: List[Dict[str, Any]]) -> List[str]:
    texts: List[str] = []
    for msg in reversed(messages):
        if msg.get("type") in (None, "message") and msg.get("role") == "user":
            content = msg.get("content")
            if isinstance(content, str):
                return [content]
            elif isinstance(content, list):
                for c in content:
                    if c.get("type") in ("input_text", "output_text") and c.get("text"):
                        texts.append(c["text"])  # newest first
                if texts:
                    return list(reversed(texts))
    return []


def _find_last_screenshot(messages: List[Dict[str, Any]]) -> Optional[bytes]:
    for msg in reversed(messages):
        if msg.get("type") == "computer_call_output":
            out = msg.get("output", {})
            if isinstance(out, dict) and out.get("type") in ("input_image", "computer_screenshot"):
                image_url = out.get("image_url", "")
                if image_url:
                    data, _ = _data_url_to_bytes(image_url)
                    return data
    return None


def _denormalize(v: int, size: int) -> int:
    # Gemini returns 0-999 normalized
    try:
        return max(0, min(size - 1, int(round(v / 1000 * size))))
    except Exception:
        return 0
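A worked example of the denormalization above: Gemini reports coordinates on a 0-999 grid, so a point at (500, 250) on a 1920x1080 screen lands at (960, 270), and values are clamped to stay on-screen.

# assert _denormalize(500, 1920) == 960    # round(500/1000 * 1920)
# assert _denormalize(250, 1080) == 270    # round(250/1000 * 1080)
# assert _denormalize(999, 1920) == 1918   # round(1918.08), within size - 1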
def _map_gemini_fc_to_computer_call(
    fc: Dict[str, Any],
    screen_w: int,
    screen_h: int,
) -> Optional[Dict[str, Any]]:
    name = fc.get("name")
    args = fc.get("args", {}) or {}

    action: Dict[str, Any] = {}
    if name == "click_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        action = {"type": "click", "x": x, "y": y, "button": "left"}
    elif name == "type_text_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        text = args.get("text", "")
        if args.get("press_enter") == True:
            text += "\n"
        action = {"type": "type", "x": x, "y": y, "text": text}
    elif name == "hover_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        action = {"type": "move", "x": x, "y": y}
    elif name == "key_combination":
        keys = str(args.get("keys", ""))
        action = {"type": "keypress", "keys": keys}
    elif name == "scroll_document":
        direction = args.get("direction", "down")
        magnitude = 800
        dx, dy = 0, 0
        if direction == "down":
            dy = magnitude
        elif direction == "up":
            dy = -magnitude
        elif direction == "right":
            dx = magnitude
        elif direction == "left":
            dx = -magnitude
        action = {
            "type": "scroll",
            "scroll_x": dx,
            "scroll_y": dy,
            "x": int(screen_w / 2),
            "y": int(screen_h / 2),
        }
    elif name == "scroll_at":
        x = _denormalize(int(args.get("x", 500)), screen_w)
        y = _denormalize(int(args.get("y", 500)), screen_h)
        direction = args.get("direction", "down")
        magnitude = int(args.get("magnitude", 800))
        dx, dy = 0, 0
        if direction == "down":
            dy = magnitude
        elif direction == "up":
            dy = -magnitude
        elif direction == "right":
            dx = magnitude
        elif direction == "left":
            dx = -magnitude
        action = {"type": "scroll", "scroll_x": dx, "scroll_y": dy, "x": x, "y": y}
    elif name == "drag_and_drop":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        dx = _denormalize(int(args.get("destination_x", x)), screen_w)
        dy = _denormalize(int(args.get("destination_y", y)), screen_h)
        action = {
            "type": "drag",
            "start_x": x,
            "start_y": y,
            "end_x": dx,
            "end_y": dy,
            "button": "left",
        }
    elif name == "wait_5_seconds":
        action = {"type": "wait"}
    else:
        # Unsupported / excluded browser-specific or custom function; ignore
        return None

    return {
        "type": "computer_call",
        "call_id": uuid.uuid4().hex,
        "status": "completed",
        "action": action,
    }


@register_agent(models=r"^gemini-2\.5-computer-use-preview-10-2025$")
class GeminiComputerUseConfig(AsyncAgentConfig):
    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        use_prompt_caching: Optional[bool] = False,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        genai, types = _lazy_import_genai()

        client = genai.Client()

        # Build excluded predefined functions for browser-specific behavior
        excluded = [
            "open_web_browser",
            "search",
            "navigate",
            "go_forward",
            "go_back",
            "scroll_document",
        ]
        # Optional custom functions: can be extended by host code via `tools` parameter later if desired
        CUSTOM_FUNCTION_DECLARATIONS: List[Any] = []

        # Compose tools config
        generate_content_config = types.GenerateContentConfig(
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=excluded,
                    )
                ),
                # types.Tool(function_declarations=CUSTOM_FUNCTION_DECLARATIONS),  # enable when custom functions needed
            ]
        )

        # Prepare contents: last user text + latest screenshot
        user_texts = _find_last_user_text(messages)
        screenshot_bytes = _find_last_screenshot(messages)

        parts: List[Any] = []
        for t in user_texts:
            parts.append(types.Part(text=t))

        screen_w, screen_h = 1024, 768
        if screenshot_bytes:
            screen_w, screen_h = _bytes_image_size(screenshot_bytes)
            parts.append(types.Part.from_bytes(data=screenshot_bytes, mime_type="image/png"))

        # If we don't have any content, at least pass an empty user part to prompt reasoning
        if not parts:
            parts = [types.Part(text="Proceed to the next action.")]

        contents = [types.Content(role="user", parts=parts)]

        api_kwargs = {
            "model": model,
            "contents": contents,
            "config": generate_content_config,
        }

        if _on_api_start:
            await _on_api_start(
                {
                    "model": api_kwargs["model"],
                    # "contents": api_kwargs["contents"],  # Disabled for now
                    "config": api_kwargs["config"],
                }
            )

        response = client.models.generate_content(**api_kwargs)

        if _on_api_end:
            await _on_api_end(
                {
                    "model": api_kwargs["model"],
                    # "contents": api_kwargs["contents"],  # Disabled for now
                    "config": api_kwargs["config"],
                },
                response,
            )

        # Usage (Gemini SDK may not always provide token usage; populate when available)
        usage: Dict[str, Any] = {}
        try:
            # Some SDKs expose response.usage; if available, copy
            if getattr(response, "usage_metadata", None):
                md = response.usage_metadata
                usage = {
                    "prompt_tokens": getattr(md, "prompt_token_count", None) or 0,
                    "completion_tokens": getattr(md, "candidates_token_count", None) or 0,
                    "total_tokens": getattr(md, "total_token_count", None) or 0,
                }
        except Exception:
            pass

        if _on_usage and usage:
            await _on_usage(usage)

        # Parse output into internal items
        output_items: List[Dict[str, Any]] = []

        candidate = response.candidates[0]
        # Text parts from the model (assistant message)
        text_parts: List[str] = []
        function_calls: List[Dict[str, Any]] = []
        for p in candidate.content.parts:
            if getattr(p, "text", None):
                text_parts.append(p.text)
            if getattr(p, "function_call", None):
                # p.function_call has name and args
                fc = {
                    "name": getattr(p.function_call, "name", None),
                    "args": dict(getattr(p.function_call, "args", {}) or {}),
                }
                function_calls.append(fc)

        if text_parts:
            output_items.append(
                {
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": "\n".join(text_parts)}],
                }
            )

        # Map function calls to internal computer_call actions
        for fc in function_calls:
            item = _map_gemini_fc_to_computer_call(fc, screen_w, screen_h)
            if item is not None:
                output_items.append(item)

        return {"output": output_items, "usage": usage}

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs,
    ) -> Optional[Tuple[float, float]]:
        """Ask Gemini CUA to output a single click action for the given instruction.

        Excludes all predefined tools except `click_at` and sends the screenshot.
        Returns pixel (x, y) if a click is proposed, else None.
        """
        genai, types = _lazy_import_genai()

        client = genai.Client()

        # Exclude all but click_at
        exclude_all_but_click = [
            "open_web_browser",
            "wait_5_seconds",
            "go_back",
            "go_forward",
            "search",
            "navigate",
            "hover_at",
            "type_text_at",
            "key_combination",
            "scroll_document",
            "scroll_at",
            "drag_and_drop",
        ]

        config = types.GenerateContentConfig(
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=exclude_all_but_click,
                    )
                )
            ]
        )

        # Prepare prompt parts
        try:
            img_bytes = base64.b64decode(image_b64)
        except Exception:
            img_bytes = b""

        w, h = _bytes_image_size(img_bytes) if img_bytes else (1024, 768)

        parts: List[Any] = [types.Part(text=f"Click {instruction}.")]
        if img_bytes:
            parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))

        contents = [types.Content(role="user", parts=parts)]

        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=config,
        )

        # Parse first click_at
        try:
            candidate = response.candidates[0]
            for p in candidate.content.parts:
                fc = getattr(p, "function_call", None)
                if fc and getattr(fc, "name", None) == "click_at":
                    args = dict(getattr(fc, "args", {}) or {})
                    x = _denormalize(int(args.get("x", 0)), w)
                    y = _denormalize(int(args.get("y", 0)), h)
                    return float(x), float(y)
        except Exception:
            return None

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        return ["click", "step"]
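To illustrate the function-call mapping in this file (values invented): a Gemini click_at on the normalized grid becomes an internal computer_call item with denormalized pixel coordinates and a fresh call id.

fc = {"name": "click_at", "args": {"x": 500, "y": 250}}
item = _map_gemini_fc_to_computer_call(fc, screen_w=1920, screen_h=1080)
# item["type"] == "computer_call", item["call_id"] is a new uuid hex, and
# item["action"] == {"type": "click", "x": 960, "y": 270, "button": "left"}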
@@ -4,33 +4,36 @@ Supports vision-language models for computer control with bounding box parsing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
LiteLLMCompletionResponsesConfig,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
|
||||
from PIL import Image
|
||||
|
||||
from ..decorators import register_agent
|
||||
from ..types import Messages, AgentResponse, Tools, AgentCapability
|
||||
from ..loops.base import AsyncAgentConfig
|
||||
from ..responses import (
|
||||
convert_responses_items_to_completion_messages,
|
||||
convert_completion_messages_to_responses_items,
|
||||
make_reasoning_item,
|
||||
make_output_text_item,
|
||||
convert_responses_items_to_completion_messages,
|
||||
make_click_item,
|
||||
make_double_click_item,
|
||||
make_drag_item,
|
||||
make_input_image_item,
|
||||
make_keypress_item,
|
||||
make_output_text_item,
|
||||
make_reasoning_item,
|
||||
make_scroll_item,
|
||||
make_type_item,
|
||||
make_wait_item,
|
||||
make_input_image_item
|
||||
)
|
||||
from ..types import AgentCapability, AgentResponse, Messages, Tools
|
||||
|
||||
# GLM-4.5V specific constants
|
||||
GLM_ACTION_SPACE = """
|
||||
@@ -251,16 +254,18 @@ Call rule: `FAIL()`
|
||||
}
|
||||
}"""
|
||||
|
||||
|
||||
def encode_image_to_base64(image_path: str) -> str:
|
||||
"""Encode image file to base64 string with data URI."""
|
||||
with open(image_path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
return f"data:image/png;base64,{encoded_string}"
|
||||
|
||||
|
||||
def parse_glm_response(response: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse GLM-4.5V response to extract action and memory.
|
||||
|
||||
|
||||
The special tokens <|begin_of_box|> and <|end_of_box|> mark bounding boxes.
|
||||
Coordinates are normalized values between 0 and 1000.
|
||||
"""
|
||||
@@ -274,26 +279,23 @@ def parse_glm_response(response: str) -> Dict[str, Any]:
|
||||
action_pattern = r"[\w_]+\([^)]*\)"
|
||||
matches = re.findall(action_pattern, response)
|
||||
action = matches[0] if matches else None
|
||||
|
||||
|
||||
# Extract memory section
|
||||
memory_pattern = r"Memory:(.*?)$"
|
||||
memory_match = re.search(memory_pattern, response, re.DOTALL)
|
||||
memory = memory_match.group(1).strip() if memory_match else "[]"
|
||||
|
||||
|
||||
# Extract action text (everything before Memory:)
|
||||
action_text_pattern = r'^(.*?)Memory:'
|
||||
action_text_pattern = r"^(.*?)Memory:"
|
||||
action_text_match = re.search(action_text_pattern, response, re.DOTALL)
|
||||
action_text = action_text_match.group(1).strip() if action_text_match else response
|
||||
|
||||
|
||||
# Clean up action text by removing special tokens
|
||||
if action_text:
|
||||
action_text = action_text.replace("<|begin_of_box|>", "").replace("<|end_of_box|>", "")
|
||||
|
||||
return {
|
||||
"action": action,
|
||||
"action_text": action_text,
|
||||
"memory": memory
|
||||
}
|
||||
|
||||
return {"action": action, "action_text": action_text, "memory": memory}
|
||||
|
||||
|
||||
def get_last_image_from_messages(messages: Messages) -> Optional[str]:
|
||||
"""Extract the last image from messages for processing."""
|
||||
@@ -314,23 +316,28 @@ def get_last_image_from_messages(messages: Messages) -> Optional[str]:
|
||||
image_url_obj = item.get("image_url", {})
|
||||
if isinstance(image_url_obj, dict):
|
||||
image_url = image_url_obj.get("url", "")
|
||||
if isinstance(image_url, str) and image_url.startswith("data:image/"):
|
||||
if isinstance(image_url, str) and image_url.startswith(
|
||||
"data:image/"
|
||||
):
|
||||
return image_url.split(",", 1)[1]
|
||||
return None
|
||||
|
||||
def convert_responses_items_to_glm45v_pc_prompt(messages: Messages, task: str, memory: str = "") -> List[Dict[str, Any]]:
|
||||
|
||||
def convert_responses_items_to_glm45v_pc_prompt(
|
||||
messages: Messages, task: str, memory: str = ""
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert responses items to GLM-4.5V PC prompt format with historical actions.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message items from the conversation
|
||||
task: The task description
|
||||
memory: Current memory state
|
||||
|
||||
|
||||
Returns:
|
||||
List of content items for the prompt (text and image_url items)
|
||||
"""
|
||||
action_space = GLM_ACTION_SPACE
|
||||
|
||||
|
||||
# Template head
|
||||
head_text = f"""You are a GUI Agent, and your primary task is to respond accurately to user requests or questions. In addition to directly answering the user's queries, you can also use tools or perform GUI operations directly until you fulfill the user's request or provide a correct answer. You should carefully read and understand the images and questions provided by the user, and engage in thinking and reflection when appropriate. The coordinates involved are all represented in thousandths (0-999).
|
||||
|
||||
@@ -345,7 +352,7 @@ Ubuntu
|
||||
|
||||
# Historical Actions and Current Memory
|
||||
History:"""
|
||||
|
||||
|
||||
# Template tail
|
||||
tail_text = f"""
|
||||
Memory:
|
||||
@@ -363,18 +370,18 @@ Memory:
|
||||
|
||||
Current Screenshot:
|
||||
"""
|
||||
|
||||
|
||||
# Build history from messages
|
||||
history = []
|
||||
history_images = []
|
||||
|
||||
|
||||
# Group messages into steps
|
||||
current_step = []
|
||||
step_num = 0
|
||||
|
||||
|
||||
for message in messages:
|
||||
msg_type = message.get("type")
|
||||
|
||||
|
||||
if msg_type == "reasoning":
|
||||
current_step.append(message)
|
||||
elif msg_type == "message" and message.get("role") == "assistant":
|
||||
@@ -386,7 +393,7 @@ Current Screenshot:
|
||||
# End of step - process it
|
||||
if current_step:
|
||||
step_num += 1
|
||||
|
||||
|
||||
# Extract bot thought from message content
|
||||
bot_thought = ""
|
||||
for item in current_step:
|
||||
@@ -397,14 +404,14 @@ Current Screenshot:
|
||||
bot_thought = content_item.get("text", "")
|
||||
break
|
||||
break
|
||||
|
||||
|
||||
# Extract action from computer_call
|
||||
action_text = ""
|
||||
for item in current_step:
|
||||
if item.get("type") == "computer_call":
|
||||
action = item.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
|
||||
|
||||
if action_type == "click":
|
||||
x, y = action.get("x", 0), action.get("y", 0)
|
||||
# Convert to 0-999 range (assuming screen dimensions)
|
||||
@@ -436,7 +443,7 @@ Current Screenshot:
|
||||
elif action_type == "wait":
|
||||
action_text = "WAIT()"
|
||||
break
|
||||
|
||||
|
||||
# Extract screenshot from computer_call_output
|
||||
screenshot_url = None
|
||||
for item in current_step:
|
||||
@@ -445,34 +452,34 @@ Current Screenshot:
|
||||
if output.get("type") == "input_image":
|
||||
screenshot_url = output.get("image_url", "")
|
||||
break
|
||||
|
||||
|
||||
# Store step info
|
||||
step_info = {
|
||||
"step_num": step_num,
|
||||
"bot_thought": bot_thought,
|
||||
"action_text": action_text,
|
||||
"screenshot_url": screenshot_url
|
||||
"screenshot_url": screenshot_url,
|
||||
}
|
||||
history.append(step_info)
|
||||
|
||||
|
||||
# Store screenshot for last 4 steps
|
||||
if screenshot_url:
|
||||
history_images.append(screenshot_url)
|
||||
|
||||
|
||||
current_step = []
|
||||
|
||||
|
||||
# Build content array with head, history, and tail
|
||||
content = []
|
||||
current_text = head_text
|
||||
|
||||
|
||||
total_history_steps = len(history)
|
||||
history_image_count = min(4, len(history_images)) # Last 4 images
|
||||
|
||||
|
||||
for step_idx, step_info in enumerate(history):
|
||||
step_num = step_info["step_num"]
|
||||
bot_thought = step_info["bot_thought"]
|
||||
action_text = step_info["action_text"]
|
||||
|
||||
|
||||
if step_idx < total_history_steps - history_image_count:
|
||||
# For steps beyond the last 4, use text placeholder
|
||||
current_text += f"\nstep {step_num}: Screenshot:(Omitted in context.) Thought: {bot_thought}\nAction: {action_text}"
|
||||
@@ -480,20 +487,21 @@ Current Screenshot:
|
||||
# For the last 4 steps, insert images
|
||||
current_text += f"\nstep {step_num}: Screenshot:"
|
||||
content.append({"type": "text", "text": current_text})
|
||||
|
||||
|
||||
# Add image
|
||||
img_idx = step_idx - (total_history_steps - history_image_count)
|
||||
if img_idx < len(history_images):
|
||||
content.append({"type": "image_url", "image_url": {"url": history_images[img_idx]}})
|
||||
|
||||
|
||||
current_text = f" Thought: {bot_thought}\nAction: {action_text}"
|
||||
|
||||
|
||||
# Add tail
|
||||
current_text += tail_text
|
||||
content.append({"type": "text", "text": current_text})
|
||||
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def model_dump(obj) -> Dict[str, Any]:
|
||||
if isinstance(obj, dict):
|
||||
return {k: model_dump(v) for k, v in obj.items()}
|
||||
@@ -502,58 +510,61 @@ def model_dump(obj) -> Dict[str, Any]:
|
||||
else:
|
||||
return obj
|
||||
|
||||
def convert_glm_completion_to_responses_items(response: ModelResponse, image_width: int, image_height: int) -> List[Dict[str, Any]]:
|
||||
|
||||
def convert_glm_completion_to_responses_items(
|
||||
response: ModelResponse, image_width: int, image_height: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert GLM-4.5V completion response to responses items format.
|
||||
|
||||
|
||||
Args:
|
||||
response: LiteLLM ModelResponse from GLM-4.5V
|
||||
image_width: Original image width for coordinate scaling
|
||||
image_height: Original image height for coordinate scaling
|
||||
|
||||
|
||||
Returns:
|
||||
List of response items in the proper format
|
||||
"""
|
||||
import uuid
|
||||
|
||||
|
||||
response_items = []
|
||||
|
||||
|
||||
if not response.choices or not response.choices[0].message:
|
||||
return response_items
|
||||
|
||||
|
||||
message = response.choices[0].message
|
||||
content = message.content or ""
|
||||
reasoning_content = getattr(message, 'reasoning_content', None)
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None)
|
||||
|
||||
# Add reasoning item if present
|
||||
if reasoning_content:
|
||||
reasoning_item = model_dump(make_reasoning_item(reasoning_content))
|
||||
response_items.append(reasoning_item)
|
||||
|
||||
|
||||
# Parse the content to extract action and text
|
||||
parsed_response = parse_glm_response(content)
|
||||
action = parsed_response.get("action", "")
|
||||
action_text = parsed_response.get("action_text", "")
|
||||
|
||||
|
||||
# Add message item with text content (excluding action and memory)
|
||||
if action_text:
|
||||
# Remove action from action_text if it's there
|
||||
clean_text = action_text
|
||||
if action and action in clean_text:
|
||||
clean_text = clean_text.replace(action, "").strip()
|
||||
|
||||
|
||||
# Remove memory section
|
||||
memory_pattern = r"Memory:\s*\[.*?\]\s*$"
|
||||
clean_text = re.sub(memory_pattern, "", clean_text, flags=re.DOTALL).strip()
|
||||
|
||||
|
||||
if clean_text:
|
||||
message_item = model_dump(make_output_text_item(clean_text))
|
||||
response_items.append(message_item)
|
||||
|
||||
|
||||
# Convert action to computer call if present
|
||||
if action:
|
||||
call_id = f"call_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
# Parse different action types and create appropriate computer calls
|
||||
if action.startswith("left_click"):
|
||||
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
@@ -566,7 +577,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("right_click"):
|
||||
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
if coord_match:
|
||||
@@ -577,7 +588,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("left_double_click"):
|
||||
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
if coord_match:
|
||||
@@ -588,7 +599,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("left_drag"):
|
||||
start_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
end_match = re.search(r"end_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
@@ -605,18 +616,18 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("key"):
|
||||
key_match = re.search(r"keys='([^']+)'", action)
|
||||
if key_match:
|
||||
keys = key_match.group(1)
|
||||
# Split keys by '+' for key combinations, or use as single key
|
||||
key_list = keys.split('+') if '+' in keys else [keys]
|
||||
key_list = keys.split("+") if "+" in keys else [keys]
|
||||
computer_call = model_dump(make_keypress_item(key_list))
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("type"):
|
||||
content_match = re.search(r"content='([^']*)'", action)
|
||||
if content_match:
|
||||
@@ -625,7 +636,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("scroll"):
|
||||
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
direction_match = re.search(r"direction='([^']+)'", action)
|
||||
@@ -648,15 +659,16 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action == "WAIT()":
|
||||
computer_call = model_dump(make_wait_item())
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
return response_items
|
||||
|
||||
|
||||
@register_agent(models=r"(?i).*GLM-4\.5V.*")
|
||||
class Glm4vConfig(AsyncAgentConfig):
|
||||
"""GLM-4.5V agent configuration using liteLLM."""
|
||||
@@ -674,11 +686,11 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Predict the next step using GLM-4.5V model.
|
||||
|
||||
|
||||
Args:
|
||||
messages: Input messages following Responses format
|
||||
model: Model name to use
|
||||
@@ -691,7 +703,7 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
_on_api_end: Callback for API end
|
||||
_on_usage: Callback for usage tracking
|
||||
_on_screenshot: Callback for screenshot events
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with "output" and "usage" keys
|
||||
"""
|
||||
@@ -708,7 +720,7 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
user_instruction = item.get("text", "")
|
||||
break
|
||||
break
|
||||
|
||||
|
||||
# Get the last image for processing
|
||||
last_image_b64 = get_last_image_from_messages(messages)
|
||||
if not last_image_b64 and computer_handler:
|
||||
@@ -718,35 +730,28 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
last_image_b64 = screenshot_b64
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(screenshot_b64)
|
||||
|
||||
|
||||
if not last_image_b64:
|
||||
raise ValueError("No image available for GLM-4.5V processing")
|
||||
|
||||
|
||||
# Convert responses items to GLM-4.5V PC prompt format with historical actions
|
||||
prompt_content = convert_responses_items_to_glm45v_pc_prompt(
|
||||
messages=messages,
|
||||
task=user_instruction,
|
||||
memory="[]" # Initialize with empty memory for now
|
||||
memory="[]", # Initialize with empty memory for now
|
||||
)
|
||||
|
||||
|
||||
# Add the current screenshot to the end
|
||||
prompt_content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{last_image_b64}"}
|
||||
})
|
||||
|
||||
prompt_content.append(
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{last_image_b64}"}}
|
||||
)
|
||||
|
||||
# Prepare messages for liteLLM
|
||||
litellm_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful GUI agent assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt_content
|
||||
}
|
||||
{"role": "system", "content": "You are a helpful GUI agent assistant."},
|
||||
{"role": "user", "content": prompt_content},
|
||||
]
|
||||
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
@@ -757,20 +762,20 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
# "skip_special_tokens": False,
|
||||
# }
|
||||
}
|
||||
|
||||
|
||||
# Add API callbacks
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
|
||||
# Call liteLLM
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
|
||||
# Get image dimensions for coordinate scaling
|
||||
image_width, image_height = 1920, 1080 # Default dimensions
|
||||
|
||||
|
||||
# Try to get actual dimensions from the image
|
||||
try:
|
||||
            image_data = base64.b64decode(last_image_b64)
@@ -778,41 +783,38 @@ class Glm4vConfig(AsyncAgentConfig):
                image_width, image_height = image.size
            except Exception:
                pass  # Use default dimensions

        # Convert GLM completion response to responses items
        response_items = convert_glm_completion_to_responses_items(response, image_width, image_height)
        response_items = convert_glm_completion_to_responses_items(
            response, image_width, image_height
        )

        # Extract usage information
        response_usage = {
            **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
            **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
                response.usage
            ).model_dump(),
            "response_cost": response._hidden_params.get("response_cost", 0.0),
        }
        if _on_usage:
            await _on_usage(response_usage)

        # Create agent response
        agent_response = {
            "output": response_items,
            "usage": response_usage
        }
        agent_response = {"output": response_items, "usage": response_usage}

        return agent_response

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using GLM-4.5V model.

        Args:
            model: Model name to use
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            Tuple with (x, y) coordinates or None
        """
@@ -824,22 +826,22 @@ Respond with a single click action in this format:
left_click(start_box='[x,y]')

Where x,y are coordinates normalized to 0-999 range."""

            # Prepare messages for liteLLM
            litellm_messages = [
                {
                    "role": "system",
                    "content": "You are a helpful GUI agent assistant."
                },
                {"role": "system", "content": "You are a helpful GUI agent assistant."},
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": click_prompt},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
                    ]
                }
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                        },
                    ],
                },
            ]

            # Prepare API call kwargs
            api_kwargs = {
                "model": model,
@@ -848,21 +850,21 @@ Where x,y are coordinates normalized to 0-999 range."""
                "temperature": 0.001,
                "extra_body": {
                    "skip_special_tokens": False,
                }
                },
            }

            # Call liteLLM
            response = await litellm.acompletion(**api_kwargs)

            # Extract response content
            response_content = response.choices[0].message.content.strip()
            print(response)

            # Parse response for click coordinates
            # Look for coordinates in the response, handling special tokens
            coord_pattern = r"<\|begin_of_box\|>.*?left_click\(start_box='?\[(\d+),(\d+)\]'?\).*?<\|end_of_box\|>"
            match = re.search(coord_pattern, response_content)

            if not match:
                # Fallback: look for coordinates without special tokens
                coord_pattern = r"left_click\(start_box='?\[(\d+),(\d+)\]'?\)"
@@ -870,7 +872,7 @@ Where x,y are coordinates normalized to 0-999 range."""

            if match:
                x, y = int(match.group(1)), int(match.group(2))

                # Get actual image dimensions for scaling
                try:
                    image_data = base64.b64decode(image_b64)
@@ -879,15 +881,15 @@ Where x,y are coordinates normalized to 0-999 range."""
                except Exception:
                    # Use default dimensions
                    image_width, image_height = 1920, 1080

                # Convert from 0-999 normalized coordinates to actual pixel coordinates
                actual_x = int((x / 999.0) * image_width)
                actual_y = int((y / 999.0) * image_height)

                return (actual_x, actual_y)

            return None

        except Exception as e:
            # Log error and return None
            print(f"Error in predict_click: {e}")
@@ -896,7 +898,7 @@ Where x,y are coordinates normalized to 0-999 range."""
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by this agent config.

        Returns:
            List of capability strings
        """
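# --- Illustrative sketch (not part of the diff above) ---
# The GLM-4.5V loop parses left_click(start_box='[x,y]') with x, y in a
# 0-999 normalized range, then rescales to pixels. A minimal round trip,
# assuming a 1920x1080 screenshot:
#
#   sample = "<|begin_of_box|>left_click(start_box='[499,249]')<|end_of_box|>"
#   m = re.search(r"left_click\(start_box='?\[(\d+),(\d+)\]'?\)", sample)
#   x, y = int(m.group(1)), int(m.group(2))                  # (499, 249)
#   actual = (int(x / 999.0 * 1920), int(y / 999.0 * 1080))  # (959, 269)
# --- end sketch ---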
@@ -5,75 +5,80 @@ Code: https://github.com/Yan98/GTA1
"""

import asyncio
import json
import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
import uuid
from PIL import Image
import litellm
import json
import math
import re
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm
from PIL import Image

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools

SYSTEM_PROMPT = '''
SYSTEM_PROMPT = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.

Output the coordinate pair exactly:
(x,y)
'''.strip()
""".strip()


def extract_coordinates(raw_string: str) -> Tuple[float, float]:
    """Extract coordinates from model output."""
    try:
        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
        return tuple(map(float, matches[0]))  # type: ignore
        return tuple(map(float, matches[0]))  # type: ignore
    except:
        return (0.0, 0.0)

def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:

def smart_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360
) -> Tuple[int, int]:
    """Smart resize function similar to qwen_vl_utils."""
    # Calculate the total pixels
    total_pixels = height * width

    # If already within bounds, return original dimensions
    if min_pixels <= total_pixels <= max_pixels:
        # Round to nearest factor
        new_height = (height // factor) * factor
        new_width = (width // factor) * factor
        return new_height, new_width

    # Calculate scaling factor
    if total_pixels > max_pixels:
        scale = (max_pixels / total_pixels) ** 0.5
    else:
        scale = (min_pixels / total_pixels) ** 0.5

    # Apply scaling
    new_height = int(height * scale)
    new_width = int(width * scale)

    # Round to nearest factor
    new_height = (new_height // factor) * factor
    new_width = (new_width // factor) * factor

    # Ensure minimum size
    new_height = max(new_height, factor)
    new_width = max(new_width, factor)

    return new_height, new_width

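# --- Illustrative sketch (not part of the committed file) ---
# Worked example of smart_resize: a 1920x1080 image has 2,073,600 pixels,
# already within [3136, 8847360], so only factor-of-28 rounding applies:
#
#   smart_resize(1080, 1920)  # -> (1064, 1904), since 1080 // 28 * 28 == 1064
#                             #    and 1920 // 28 * 28 == 1904
# --- end sketch ---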
@register_agent(models=r".*GTA1.*")
|
||||
class GTA1Config(AsyncAgentConfig):
|
||||
"""GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.current_model = None
|
||||
self.last_screenshot_b64 = None
|
||||
|
||||
|
||||
async def predict_step(
|
||||
self,
|
||||
@@ -87,25 +92,21 @@ class GTA1Config(AsyncAgentConfig):
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def predict_click(
|
||||
self,
|
||||
model: str,
|
||||
image_b64: str,
|
||||
instruction: str,
|
||||
**kwargs
|
||||
self, model: str, image_b64: str, instruction: str, **kwargs
|
||||
) -> Optional[Tuple[float, float]]:
|
||||
"""
|
||||
Predict click coordinates using GTA1 model via litellm.acompletion.
|
||||
|
||||
|
||||
Args:
|
||||
model: The GTA1 model name
|
||||
image_b64: Base64 encoded image
|
||||
instruction: Instruction for where to click
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates or None if prediction fails
|
||||
"""
|
||||
@@ -113,66 +114,62 @@ class GTA1Config(AsyncAgentConfig):
|
||||
image_data = base64.b64decode(image_b64)
|
||||
image = Image.open(BytesIO(image_data))
|
||||
width, height = image.width, image.height
|
||||
|
||||
|
||||
# Smart resize the image (similar to qwen_vl_utils)
|
||||
resized_height, resized_width = smart_resize(
|
||||
height, width,
|
||||
height,
|
||||
width,
|
||||
factor=28, # Default factor for Qwen models
|
||||
min_pixels=3136,
|
||||
max_pixels=4096 * 2160
|
||||
max_pixels=4096 * 2160,
|
||||
)
|
||||
resized_image = image.resize((resized_width, resized_height))
|
||||
scale_x, scale_y = width / resized_width, height / resized_height
|
||||
|
||||
|
||||
# Convert resized image back to base64
|
||||
buffered = BytesIO()
|
||||
resized_image.save(buffered, format="PNG")
|
||||
resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
|
||||
|
||||
# Prepare system and user messages
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
|
||||
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width),
|
||||
}
|
||||
|
||||
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{resized_image_b64}"
|
||||
}
|
||||
"image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": instruction
|
||||
}
|
||||
]
|
||||
{"type": "text", "text": instruction},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": [system_message, user_message],
|
||||
"max_tokens": 2056,
|
||||
"temperature": 0.0,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
# Use liteLLM acompletion
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
|
||||
# Extract response text
|
||||
output_text = response.choices[0].message.content # type: ignore
|
||||
|
||||
output_text = response.choices[0].message.content # type: ignore
|
||||
|
||||
# Extract and rescale coordinates
|
||||
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
|
||||
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
|
||||
pred_x *= scale_x
|
||||
pred_y *= scale_y
|
||||
|
||||
|
||||
return (math.floor(pred_x), math.floor(pred_y))
|
||||
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
|
||||
"""Return the capabilities supported by this agent."""
|
||||
return ["click"]
|
||||
|
||||
@@ -21,8 +21,8 @@ import litellm
from PIL import Image

from ..decorators import register_agent
from .base import AsyncAgentConfig
from ..types import AgentCapability
from .base import AsyncAgentConfig


def _strip_hf_prefix(model: str) -> str:
@@ -53,7 +53,9 @@ def _maybe_smart_resize(image: Image.Image, model: str) -> Tuple[Image.Image, Tu
    if image_processor is None:
        return image, (orig_w, orig_h)

    factor = getattr(image_processor, "patch_size", 14) * getattr(image_processor, "merge_size", 1)
    factor = getattr(image_processor, "patch_size", 14) * getattr(
        image_processor, "merge_size", 1
    )
    min_pixels = getattr(image_processor, "min_pixels", 256 * 256)
    max_pixels = getattr(image_processor, "max_pixels", 1536 * 1536)

@@ -18,13 +18,12 @@ import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image
import litellm
from PIL import Image

from ..decorators import register_agent
from .composed_grounded import ComposedGroundedConfig
from ..types import AgentCapability

from .composed_grounded import ComposedGroundedConfig

# Regex patterns for extracting coordinates
# Accept optional whitespace and optional decimal fractions
@@ -91,7 +90,7 @@ class InternVLConfig(ComposedGroundedConfig):
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """Fallback to a self-composed model"""
        return await super().predict_step(
@@ -105,15 +104,11 @@ class InternVLConfig(ComposedGroundedConfig):
            _on_api_end=_on_api_end,
            _on_usage=_on_usage,
            _on_screenshot=_on_screenshot,
            **kwargs
            **kwargs,
        )

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using InternVL via litellm.acompletion.
493
libs/python/agent/agent/loops/moondream3.py
Normal file
@@ -0,0 +1,493 @@
"""
Moondream3+ composed-grounded agent loop implementation.
Grounding is handled by a local Moondream3 preview model via Transformers.
Thinking is delegated to the trailing LLM in the composed model string: "moondream3+<thinking_model>".

Differences from composed_grounded:
- Provides a singleton Moondream3 client outside the class.
- predict_click uses model.point(image, instruction, settings={"max_objects": 1}) and returns pixel coordinates.
- If the last image was a screenshot (or we take one), run model.detect(image, "all form ui") to get bboxes, then
  run model.caption on each cropped bbox to label it. Overlay labels on the screenshot and emit via _on_screenshot.
- Add a user message listing all detected form UI names so the thinker can reference them.
- If the thinking model doesn't support vision, filter out image content before calling litellm.
"""

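# --- Illustrative usage sketch (not part of the committed file) ---
# A composed model string pairs Moondream3 grounding with a thinking LLM;
# the thinking-model half below is an assumed example value:
#
#   config = Moondream3PlusConfig()
#   coords = await config.predict_click(
#       model="moondream3+gpt-4o",
#       image_b64=screenshot_b64,
#       instruction="the Submit button",
#   )
#   # coords is an (x, y) pixel tuple, or None if pointing failed
# --- end sketch ---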
from __future__ import annotations

import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple

import litellm
from PIL import Image, ImageDraw, ImageFont

from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
    convert_completion_messages_to_responses_items,
    convert_computer_calls_desc2xy,
    convert_computer_calls_xy2desc,
    convert_responses_items_to_completion_messages,
    get_all_element_descriptions,
)
from ..types import AgentCapability

_MOONDREAM_SINGLETON = None


def get_moondream_model() -> Any:
    """Get a singleton instance of the Moondream3 preview model."""
    global _MOONDREAM_SINGLETON
    if _MOONDREAM_SINGLETON is None:
        try:
            import torch
            from transformers import AutoModelForCausalLM

            _MOONDREAM_SINGLETON = AutoModelForCausalLM.from_pretrained(
                "moondream/moondream3-preview",
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="cuda",
            )
        except ImportError as e:
            raise RuntimeError(
                "moondream3 requires torch and transformers. Install with: pip install cua-agent[moondream3]"
            ) from e
    return _MOONDREAM_SINGLETON


def _decode_image_b64(image_b64: str) -> Image.Image:
    data = base64.b64decode(image_b64)
    return Image.open(io.BytesIO(data)).convert("RGB")


def _image_to_b64(img: Image.Image) -> str:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def _supports_vision(model: str) -> bool:
    """Heuristic vision support detection for thinking model."""
    m = model.lower()
    vision_markers = [
        "gpt-4o",
        "gpt-4.1",
        "o1",
        "o3",
        "claude-3",
        "claude-3.5",
        "sonnet",
        "haiku",
        "opus",
        "gemini-1.5",
        "llava",
    ]
    return any(v in m for v in vision_markers)


def _filter_images_from_completion_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    filtered: List[Dict[str, Any]] = []
    for msg in messages:
        msg_copy = {**msg}
        content = msg_copy.get("content")
        if isinstance(content, list):
            msg_copy["content"] = [c for c in content if c.get("type") != "image_url"]
        filtered.append(msg_copy)
    return filtered


def _annotate_detect_and_label_ui(base_img: Image.Image, model_md) -> Tuple[str, List[str]]:
    """Detect UI elements with Moondream, caption each, draw labels with backgrounds.

    Args:
        base_img: PIL image of the screenshot (RGB or RGBA). Will be copied/converted internally.
        model_md: Moondream model instance with .detect() and .query() methods.

    Returns:
        A tuple of (annotated_image_base64_png, detected_names)
    """
    # Ensure RGBA for semi-transparent fills
    if base_img.mode != "RGBA":
        base_img = base_img.convert("RGBA")
    W, H = base_img.width, base_img.height

    # Detect objects
    try:
        detect_result = model_md.detect(base_img, "all ui elements")
        objects = detect_result.get("objects", []) if isinstance(detect_result, dict) else []
    except Exception:
        objects = []

    draw = ImageDraw.Draw(base_img)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    detected_names: List[str] = []

    for i, obj in enumerate(objects):
        try:
            # Clamp normalized coords and crop
            x_min = max(0.0, min(1.0, float(obj.get("x_min", 0.0))))
            y_min = max(0.0, min(1.0, float(obj.get("y_min", 0.0))))
            x_max = max(0.0, min(1.0, float(obj.get("x_max", 0.0))))
            y_max = max(0.0, min(1.0, float(obj.get("y_max", 0.0))))
            left, top, right, bottom = (
                int(x_min * W),
                int(y_min * H),
                int(x_max * W),
                int(y_max * H),
            )
            left, top = max(0, left), max(0, top)
            right, bottom = min(W - 1, right), min(H - 1, bottom)
            crop = base_img.crop((left, top, right, bottom))

            # Prompted short caption
            try:
                result = model_md.query(crop, "Caption this UI element in few words.")
                caption_text = (result or {}).get("answer", "")
            except Exception:
                caption_text = ""

            name = (caption_text or "").strip() or f"element_{i+1}"
            detected_names.append(name)

            # Draw bbox
            draw.rectangle([left, top, right, bottom], outline=(255, 215, 0, 255), width=2)

            # Label background with padding and rounded corners
            label = f"{i+1}. {name}"
            padding = 3
            if font:
                text_bbox = draw.textbbox((0, 0), label, font=font)
            else:
                text_bbox = draw.textbbox((0, 0), label)
            text_w = text_bbox[2] - text_bbox[0]
            text_h = text_bbox[3] - text_bbox[1]

            tx = left + 3
            ty = top - (text_h + 2 * padding + 4)
            if ty < 0:
                ty = top + 3

            bg_left = tx - padding
            bg_top = ty - padding
            bg_right = tx + text_w + padding
            bg_bottom = ty + text_h + padding
            try:
                draw.rounded_rectangle(
                    [bg_left, bg_top, bg_right, bg_bottom],
                    radius=4,
                    fill=(0, 0, 0, 160),
                    outline=(255, 215, 0, 200),
                    width=1,
                )
            except Exception:
                draw.rectangle(
                    [bg_left, bg_top, bg_right, bg_bottom],
                    fill=(0, 0, 0, 160),
                    outline=(255, 215, 0, 200),
                    width=1,
                )

            text_fill = (255, 255, 255, 255)
            if font:
                draw.text((tx, ty), label, fill=text_fill, font=font)
            else:
                draw.text((tx, ty), label, fill=text_fill)
        except Exception:
            continue

    # Encode PNG base64
    annotated = base_img
    if annotated.mode not in ("RGBA", "RGB"):
        annotated = annotated.convert("RGBA")
    annotated_b64 = _image_to_b64(annotated)
    return annotated_b64, detected_names


GROUNDED_COMPUTER_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "computer",
        "description": (
            "Control a computer by taking screenshots and interacting with UI elements. "
            "The screenshot action will include a list of detected form UI element names when available. "
            "Use element descriptions to locate and interact with UI elements on the screen."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "screenshot",
                        "click",
                        "double_click",
                        "drag",
                        "type",
                        "keypress",
                        "scroll",
                        "move",
                        "wait",
                        "get_current_url",
                        "get_dimensions",
                        "get_environment",
                    ],
                    "description": "The action to perform (required for all actions)",
                },
                "element_description": {
                    "type": "string",
                    "description": "Description of the element to interact with (required for click/double_click/move/scroll)",
                },
                "start_element_description": {
                    "type": "string",
                    "description": "Description of the element to start dragging from (required for drag)",
                },
                "end_element_description": {
                    "type": "string",
                    "description": "Description of the element to drag to (required for drag)",
                },
                "text": {
                    "type": "string",
                    "description": "The text to type (required for type)",
                },
                "keys": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Key(s) to press (required for keypress)",
                },
                "button": {
                    "type": "string",
                    "enum": ["left", "right", "wheel", "back", "forward"],
                    "description": "The mouse button to use for click/double_click",
                },
                "scroll_x": {
                    "type": "integer",
                    "description": "Horizontal scroll amount (required for scroll)",
                },
                "scroll_y": {
                    "type": "integer",
                    "description": "Vertical scroll amount (required for scroll)",
                },
            },
            "required": ["action"],
        },
    },
}

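# --- Illustrative sketch (not part of the committed file) ---
# With the grounded schema above, the thinking model emits element
# *descriptions* instead of coordinates; a typical function-call payload
# (hypothetical values) might look like:
#
#   {"action": "click", "element_description": "blue Login button", "button": "left"}
#
# predict_step later resolves each description to (x, y) via Moondream's
# point API (Step 5 below) before the call is executed.
# --- end sketch ---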
@register_agent(r"moondream3\+.*", priority=2)
|
||||
class Moondream3PlusConfig(AsyncAgentConfig):
|
||||
def __init__(self):
|
||||
self.desc2xy: Dict[str, Tuple[float, float]] = {}
|
||||
|
||||
async def predict_step(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
# Parse composed model: moondream3+<thinking_model>
|
||||
if "+" not in model:
|
||||
raise ValueError(f"Composed model must be 'moondream3+<thinking_model>', got: {model}")
|
||||
_, thinking_model = model.split("+", 1)
|
||||
|
||||
pre_output_items: List[Dict[str, Any]] = []
|
||||
|
||||
# Acquire last screenshot; if missing, take one
|
||||
last_image_b64: Optional[str] = None
|
||||
for message in reversed(messages):
|
||||
if (
|
||||
isinstance(message, dict)
|
||||
and message.get("type") == "computer_call_output"
|
||||
and isinstance(message.get("output"), dict)
|
||||
and message["output"].get("type") == "input_image"
|
||||
):
|
||||
image_url = message["output"].get("image_url", "")
|
||||
if image_url.startswith("data:image/png;base64,"):
|
||||
last_image_b64 = image_url.split(",", 1)[1]
|
||||
break
|
||||
|
||||
if last_image_b64 is None and computer_handler is not None:
|
||||
# Take a screenshot
|
||||
screenshot_b64 = await computer_handler.screenshot() # type: ignore
|
||||
if screenshot_b64:
|
||||
call_id = uuid.uuid4().hex
|
||||
pre_output_items += [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "Taking a screenshot to analyze the current screen.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "computer_call",
|
||||
"call_id": call_id,
|
||||
"status": "completed",
|
||||
"action": {"type": "screenshot"},
|
||||
},
|
||||
{
|
||||
"type": "computer_call_output",
|
||||
"call_id": call_id,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_b64}",
|
||||
},
|
||||
},
|
||||
]
|
||||
last_image_b64 = screenshot_b64
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(screenshot_b64)
|
||||
|
||||
# If we have a last screenshot, run Moondream detection and labeling
|
||||
detected_names: List[str] = []
|
||||
if last_image_b64 is not None:
|
||||
base_img = _decode_image_b64(last_image_b64)
|
||||
model_md = get_moondream_model()
|
||||
annotated_b64, detected_names = _annotate_detect_and_label_ui(base_img, model_md)
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(annotated_b64, "annotated_form_ui")
|
||||
|
||||
# Also push a user message listing all detected names
|
||||
if detected_names:
|
||||
names_text = "\n".join(f"- {n}" for n in detected_names)
|
||||
pre_output_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "input_text", "text": "Detected form UI elements on screen:"},
|
||||
{"type": "input_text", "text": names_text},
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Please continue with the next action needed to perform your task.",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
tool_schemas = []
|
||||
for schema in tools or []:
|
||||
if schema.get("type") == "computer":
|
||||
tool_schemas.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
|
||||
else:
|
||||
tool_schemas.append(schema)
|
||||
|
||||
# Step 1: Convert computer calls from xy to descriptions
|
||||
input_messages = messages + pre_output_items
|
||||
messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
|
||||
|
||||
# Step 2: Convert responses items to completion messages
|
||||
completion_messages = convert_responses_items_to_completion_messages(
|
||||
messages_with_descriptions,
|
||||
allow_images_in_tool_results=False,
|
||||
)
|
||||
|
||||
# Optionally filter images if model lacks vision
|
||||
if not _supports_vision(thinking_model):
|
||||
completion_messages = _filter_images_from_completion_messages(completion_messages)
|
||||
|
||||
# Step 3: Call thinking model with litellm.acompletion
|
||||
api_kwargs = {
|
||||
"model": thinking_model,
|
||||
"messages": completion_messages,
|
||||
"tools": tool_schemas,
|
||||
"max_retries": max_retries,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
if use_prompt_caching:
|
||||
api_kwargs["use_prompt_caching"] = use_prompt_caching
|
||||
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
usage = {
|
||||
**response.usage.model_dump(), # type: ignore
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(usage)
|
||||
|
||||
# Step 4: Convert completion messages back to responses items format
|
||||
response_dict = response.model_dump() # type: ignore
|
||||
choice_messages = [choice["message"] for choice in response_dict["choices"]]
|
||||
thinking_output_items: List[Dict[str, Any]] = []
|
||||
for choice_message in choice_messages:
|
||||
thinking_output_items.extend(
|
||||
convert_completion_messages_to_responses_items([choice_message])
|
||||
)
|
||||
|
||||
# Step 5: Use Moondream to get coordinates for each description
|
||||
element_descriptions = get_all_element_descriptions(thinking_output_items)
|
||||
if element_descriptions and last_image_b64:
|
||||
for desc in element_descriptions:
|
||||
for _ in range(3): # try 3 times
|
||||
coords = await self.predict_click(
|
||||
model=model,
|
||||
image_b64=last_image_b64,
|
||||
instruction=desc,
|
||||
)
|
||||
if coords:
|
||||
self.desc2xy[desc] = coords
|
||||
break
|
||||
|
||||
# Step 6: Convert computer calls from descriptions back to xy coordinates
|
||||
final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
|
||||
|
||||
# Step 7: Return output and usage
|
||||
return {"output": pre_output_items + final_output_items, "usage": usage}
|
||||
|
||||
async def predict_click(
|
||||
self,
|
||||
model: str,
|
||||
image_b64: str,
|
||||
instruction: str,
|
||||
**kwargs,
|
||||
) -> Optional[Tuple[float, float]]:
|
||||
"""Predict click coordinates using Moondream3's point API.
|
||||
|
||||
Returns pixel coordinates (x, y) as floats.
|
||||
"""
|
||||
img = _decode_image_b64(image_b64)
|
||||
W, H = img.width, img.height
|
||||
model_md = get_moondream_model()
|
||||
try:
|
||||
result = model_md.point(img, instruction, settings={"max_objects": 1})
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
pt = (result or {}).get("points", [])[0]
|
||||
x_norm = float(pt.get("x", 0.0))
|
||||
y_norm = float(pt.get("y", 0.0))
|
||||
x_px = max(0.0, min(float(W - 1), x_norm * W))
|
||||
y_px = max(0.0, min(float(H - 1), y_norm * H))
|
||||
return (x_px, y_px)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
|
||||
return ["click", "step"]
|
||||
@@ -5,100 +5,102 @@ Code: https://github.com/microsoft/OmniParser
"""

import asyncio
import json
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
import litellm
import inspect
import base64
import inspect
import json
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools

SOM_TOOL_SCHEMA = {
    "type": "function",
    "name": "computer",
    "description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": [
                    "screenshot",
                    "click",
                    "double_click",
                    "drag",
                    "type",
                    "keypress",
                    "scroll",
                    "move",
                    "wait",
                    "get_current_url",
                    "get_dimensions",
                    "get_environment"
                ],
                "description": "The action to perform"
            },
            "element_id": {
                "type": "integer",
                "description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
            },
            "start_element_id": {
                "type": "integer",
                "description": "The ID of the element to start dragging from (required for drag action)"
            },
            "end_element_id": {
                "type": "integer",
                "description": "The ID of the element to drag to (required for drag action)"
            },
            "text": {
                "type": "string",
                "description": "The text to type (required for type action)"
            },
            "keys": {
                "type": "string",
                "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
            },
            "button": {
                "type": "string",
                "description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
            },
            "scroll_x": {
                "type": "integer",
                "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
            },
            "scroll_y": {
                "type": "integer",
                "description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
            },
    "type": "function",
    "name": "computer",
    "description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": [
                    "screenshot",
                    "click",
                    "double_click",
                    "drag",
                    "type",
                    "keypress",
                    "scroll",
                    "move",
                    "wait",
                    "get_current_url",
                    "get_dimensions",
                    "get_environment",
                ],
                "description": "The action to perform",
            },
            "element_id": {
                "type": "integer",
                "description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)",
            },
            "start_element_id": {
                "type": "integer",
                "description": "The ID of the element to start dragging from (required for drag action)",
            },
            "end_element_id": {
                "type": "integer",
                "description": "The ID of the element to drag to (required for drag action)",
            },
            "text": {
                "type": "string",
                "description": "The text to type (required for type action)",
            },
            "keys": {
                "type": "string",
                "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')",
            },
            "button": {
                "type": "string",
                "description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
            },
            "scroll_x": {
                "type": "integer",
                "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
            },
            "scroll_y": {
                "type": "integer",
                "description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
            },
        },
        "required": ["action"],
    },
    "required": [
        "action"
    ]
}
}

OMNIPARSER_AVAILABLE = False
try:
    from som import OmniParser

    OMNIPARSER_AVAILABLE = True
except ImportError:
    pass
OMNIPARSER_SINGLETON = None


def get_parser():
    global OMNIPARSER_SINGLETON
    if OMNIPARSER_SINGLETON is None:
        OMNIPARSER_SINGLETON = OmniParser()
    return OMNIPARSER_SINGLETON


def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Get the last computer_call_output message from a messages list.

    Args:
        messages: List of messages to search through

    Returns:
        The last computer_call_output message dict, or None if not found
    """
@@ -107,11 +109,12 @@ def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Di
            return message
    return None


def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[Tools, dict]:
    """Prepare tools for OpenAI API format"""
    omniparser_tools = []
    id2xy = dict()

    for schema in tool_schemas:
        if schema["type"] == "computer":
            omniparser_tools.append(SOM_TOOL_SCHEMA)
@@ -122,72 +125,80 @@ def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[T
        elif schema["type"] == "function":
            # Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
            # Schema should be: {type, name, description, parameters}
            omniparser_tools.append({ "type": "function", **schema["function"] })

            omniparser_tools.append({"type": "function", **schema["function"]})

    return omniparser_tools, id2xy


async def replace_function_with_computer_call(item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]):
    item_type = item.get("type")

    def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
        if element_id is None:
            return (None, None)
        return id2xy.get(element_id, (None, None))
async def replace_function_with_computer_call(
    item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]
):
    item_type = item.get("type")

    if item_type == "function_call":
        fn_name = item.get("name")
        fn_args = json.loads(item.get("arguments", "{}"))
    def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
        if element_id is None:
            return (None, None)
        return id2xy.get(element_id, (None, None))

        item_id = item.get("id")
        call_id = item.get("call_id")

        if fn_name == "computer":
            action = fn_args.get("action")
            element_id = fn_args.get("element_id")
            start_element_id = fn_args.get("start_element_id")
            end_element_id = fn_args.get("end_element_id")
            text = fn_args.get("text")
            keys = fn_args.get("keys")
            button = fn_args.get("button")
            scroll_x = fn_args.get("scroll_x")
            scroll_y = fn_args.get("scroll_y")
    if item_type == "function_call":
        fn_name = item.get("name")
        fn_args = json.loads(item.get("arguments", "{}"))

            x, y = _get_xy(element_id)
            start_x, start_y = _get_xy(start_element_id)
            end_x, end_y = _get_xy(end_element_id)
        item_id = item.get("id")
        call_id = item.get("call_id")

            action_args = {
                "type": action,
                "x": x,
                "y": y,
                "start_x": start_x,
                "start_y": start_y,
                "end_x": end_x,
                "end_y": end_y,
                "text": text,
                "keys": keys,
                "button": button,
                "scroll_x": scroll_x,
                "scroll_y": scroll_y
            }
            # Remove None values to keep the JSON clean
            action_args = {k: v for k, v in action_args.items() if v is not None}
        if fn_name == "computer":
            action = fn_args.get("action")
            element_id = fn_args.get("element_id")
            start_element_id = fn_args.get("start_element_id")
            end_element_id = fn_args.get("end_element_id")
            text = fn_args.get("text")
            keys = fn_args.get("keys")
            button = fn_args.get("button")
            scroll_x = fn_args.get("scroll_x")
            scroll_y = fn_args.get("scroll_y")

            return [{
                "type": "computer_call",
                "action": action_args,
                "id": item_id,
                "call_id": call_id,
                "status": "completed"
            }]
            x, y = _get_xy(element_id)
            start_x, start_y = _get_xy(start_element_id)
            end_x, end_y = _get_xy(end_element_id)

    return [item]
            action_args = {
                "type": action,
                "x": x,
                "y": y,
                "start_x": start_x,
                "start_y": start_y,
                "end_x": end_x,
                "end_y": end_y,
                "text": text,
                "keys": keys,
                "button": button,
                "scroll_x": scroll_x,
                "scroll_y": scroll_y,
            }
            # Remove None values to keep the JSON clean
            action_args = {k: v for k, v in action_args.items() if v is not None}

async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]):
            return [
                {
                    "type": "computer_call",
                    "action": action_args,
                    "id": item_id,
                    "call_id": call_id,
                    "status": "completed",
                }
            ]

    return [item]


async def replace_computer_call_with_function(
    item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]
):
    """
    Convert computer_call back to function_call format.
    Also handles computer_call_output -> function_call_output conversion.

    Args:
        item: The item to convert
        xy2id: Mapping from (x, y) coordinates to element IDs
@@ -202,12 +213,12 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[

    if item_type == "computer_call":
        action_data = item.get("action", {})

        # Extract coordinates and convert back to element IDs
        element_id = _get_element_id(action_data.get("x"), action_data.get("y"))
        start_element_id = _get_element_id(action_data.get("start_x"), action_data.get("start_y"))
        end_element_id = _get_element_id(action_data.get("end_x"), action_data.get("end_y"))

        # Build function arguments
        fn_args = {
            "action": action_data.get("type"),
@@ -218,33 +229,36 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
            "keys": action_data.get("keys"),
            "button": action_data.get("button"),
            "scroll_x": action_data.get("scroll_x"),
            "scroll_y": action_data.get("scroll_y")
            "scroll_y": action_data.get("scroll_y"),
        }

        # Remove None values to keep the JSON clean
        fn_args = {k: v for k, v in fn_args.items() if v is not None}

        return [{
            "type": "function_call",
            "name": "computer",
            "arguments": json.dumps(fn_args),
            "id": item.get("id"),
            "call_id": item.get("call_id"),
            "status": "completed",

            # Fall back to string representation
            "content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})"
        }]

        return [
            {
                "type": "function_call",
                "name": "computer",
                "arguments": json.dumps(fn_args),
                "id": item.get("id"),
                "call_id": item.get("call_id"),
                "status": "completed",
                # Fall back to string representation
                "content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})",
            }
        ]

    elif item_type == "computer_call_output":
        # Simple conversion: computer_call_output -> function_call_output
        return [{
            "type": "function_call_output",
            "call_id": item.get("call_id"),
            "content": [item.get("output")],
            "id": item.get("id"),
            "status": "completed"
        }]
        return [
            {
                "type": "function_call_output",
                "call_id": item.get("call_id"),
                "content": [item.get("output")],
                "id": item.get("id"),
                "status": "completed",
            }
        ]

    return [item]

|
||||
@register_agent(models=r"omniparser\+.*|omni\+.*", priority=2)
|
||||
class OmniparserConfig(AsyncAgentConfig):
|
||||
"""Omniparser agent configuration implementing AsyncAgentConfig protocol."""
|
||||
|
||||
|
||||
async def predict_step(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
@@ -266,25 +280,27 @@ class OmniparserConfig(AsyncAgentConfig):
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI computer-use-preview agent loop using liteLLM responses.
|
||||
|
||||
|
||||
Supports OpenAI's computer use preview models.
|
||||
"""
|
||||
if not OMNIPARSER_AVAILABLE:
|
||||
raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")
|
||||
|
||||
raise ValueError(
|
||||
"omniparser loop requires som to be installed. Install it with `pip install cua-som`."
|
||||
)
|
||||
|
||||
tools = tools or []
|
||||
|
||||
llm_model = model.split('+')[-1]
|
||||
|
||||
llm_model = model.split("+")[-1]
|
||||
|
||||
# Prepare tools for OpenAI API
|
||||
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
|
||||
|
||||
# Find last computer_call_output
|
||||
last_computer_call_output = get_last_computer_call_output(messages) # type: ignore
|
||||
last_computer_call_output = get_last_computer_call_output(messages) # type: ignore
|
||||
if last_computer_call_output:
|
||||
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
|
||||
image_data = image_url.split(",")[-1]
|
||||
@@ -294,14 +310,17 @@ class OmniparserConfig(AsyncAgentConfig):
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(result.annotated_image_base64, "annotated_image")
|
||||
for element in result.elements:
|
||||
id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)
|
||||
|
||||
id2xy[element.id] = (
|
||||
(element.bbox.x1 + element.bbox.x2) / 2,
|
||||
(element.bbox.y1 + element.bbox.y2) / 2,
|
||||
)
|
||||
|
||||
# handle computer calls -> function calls
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
message = message.__dict__
|
||||
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
|
||||
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
|
||||
messages = new_messages
|
||||
|
||||
# Prepare API call kwargs
|
||||
@@ -312,13 +331,13 @@ class OmniparserConfig(AsyncAgentConfig):
|
||||
"stream": stream,
|
||||
"truncation": "auto",
|
||||
"num_retries": max_retries,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
|
||||
print(str(api_kwargs)[:1000])
|
||||
|
||||
# Use liteLLM responses
|
||||
@@ -330,60 +349,50 @@ class OmniparserConfig(AsyncAgentConfig):
|
||||
|
||||
# Extract usage information
|
||||
usage = {
|
||||
**response.usage.model_dump(), # type: ignore
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
|
||||
**response.usage.model_dump(), # type: ignore
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(usage)
|
||||
|
||||
# handle som function calls -> xy computer calls
|
||||
new_output = []
|
||||
for i in range(len(response.output)): # type: ignore
|
||||
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
|
||||
|
||||
return {
|
||||
"output": new_output,
|
||||
"usage": usage
|
||||
}
|
||||
|
||||
for i in range(len(response.output)): # type: ignore
|
||||
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
|
||||
|
||||
return {"output": new_output, "usage": usage}
|
||||
|
||||
async def predict_click(
|
||||
self,
|
||||
model: str,
|
||||
image_b64: str,
|
||||
instruction: str,
|
||||
**kwargs
|
||||
self, model: str, image_b64: str, instruction: str, **kwargs
|
||||
) -> Optional[Tuple[float, float]]:
|
||||
"""
|
||||
Predict click coordinates using OmniParser and LLM.
|
||||
|
||||
|
||||
Uses OmniParser to annotate the image with element IDs, then uses LLM
|
||||
to identify the correct element ID based on the instruction.
|
||||
"""
|
||||
if not OMNIPARSER_AVAILABLE:
|
||||
return None
|
||||
|
||||
|
||||
# Parse the image with OmniParser to get annotated image and elements
|
||||
parser = get_parser()
|
||||
result = parser.parse(image_b64)
|
||||
|
||||
|
||||
# Extract the LLM model from composed model string
|
||||
llm_model = model.split('+')[-1]
|
||||
|
||||
llm_model = model.split("+")[-1]
|
||||
|
||||
# Create system prompt for element ID prediction
|
||||
SYSTEM_PROMPT = f'''
|
||||
SYSTEM_PROMPT = """
|
||||
You are an expert UI element locator. Given a GUI image annotated with numerical IDs over each interactable element, along with a user's element description, provide the ID of the specified element.
|
||||
|
||||
The image shows UI elements with numbered overlays. Each number corresponds to a clickable/interactable element.
|
||||
|
||||
Output only the element ID as a single integer.
|
||||
'''.strip()
|
||||
|
||||
""".strip()
|
||||
|
||||
# Prepare messages for LLM
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": SYSTEM_PROMPT
|
||||
},
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -391,31 +400,25 @@ Output only the element ID as a single integer.
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{result.annotated_image_base64}"
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Find the element: {instruction}"
|
||||
}
|
||||
]
|
||||
}
|
||||
{"type": "text", "text": f"Find the element: {instruction}"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Call LLM to predict element ID
|
||||
response = await litellm.acompletion(
|
||||
model=llm_model,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.1
|
||||
model=llm_model, messages=messages, max_tokens=10, temperature=0.1
|
||||
)
|
||||
|
||||
|
||||
# Extract element ID from response
|
||||
response_text = response.choices[0].message.content.strip() # type: ignore
|
||||
|
||||
response_text = response.choices[0].message.content.strip() # type: ignore
|
||||
|
||||
# Try to parse the element ID
|
||||
try:
|
||||
element_id = int(response_text)
|
||||
|
||||
|
||||
# Find the element with this ID and return its center coordinates
|
||||
for element in result.elements:
|
||||
if element.id == element_id:
|
||||
@@ -425,9 +428,9 @@ Output only the element ID as a single integer.
|
||||
except ValueError:
|
||||
# If we can't parse the ID, return None
|
||||
pass
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
|
||||
"""Return the capabilities supported by this agent."""
|
||||
return ["step"]
|
||||
|
||||
@@ -6,12 +6,14 @@ import asyncio
import base64
import json
from io import BytesIO
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm
from PIL import Image

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..types import AgentCapability, AgentResponse, Messages, Tools


async def _map_computer_tool_to_openai(computer_handler: Any) -> Dict[str, Any]:
    """Map a computer tool to OpenAI's computer-use-preview tool schema"""
@@ -21,26 +23,26 @@ async def _map_computer_tool_to_openai(computer_handler: Any) -> Dict[str, Any]:
    except Exception:
        # Fallback to default dimensions if method fails
        width, height = 1024, 768

    # Get environment from the computer handler
    try:
        environment = await computer_handler.get_environment()
    except Exception:
        # Fallback to default environment if method fails
        environment = "linux"

    return {
        "type": "computer_use_preview",
        "display_width": width,
        "display_height": height,
        "environment": environment  # mac, windows, linux, browser
        "environment": environment,  # mac, windows, linux, browser
    }


async def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
    """Prepare tools for OpenAI API format"""
    openai_tools = []

    for schema in tool_schemas:
        if schema["type"] == "computer":
            # Map computer tool to OpenAI format
@@ -49,19 +51,19 @@ async def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools
        elif schema["type"] == "function":
            # Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
            # Schema should be: {type, name, description, parameters}
            openai_tools.append({ "type": "function", **schema["function"] })

            openai_tools.append({"type": "function", **schema["function"]})

    return openai_tools


@register_agent(models=r".*computer-use-preview.*")
@register_agent(models=r".*(^|/)computer-use-preview")
class OpenAIComputerUseConfig:
    """
    OpenAI computer-use-preview agent configuration using liteLLM responses.

    Supports OpenAI's computer use preview models.
    """

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
@@ -75,11 +77,11 @@ class OpenAIComputerUseConfig:
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Predict the next step based on input items.

        Args:
            messages: Input items following Responses format
            model: Model name to use
@@ -92,12 +94,12 @@ class OpenAIComputerUseConfig:
            _on_usage: Callback for usage tracking
            _on_screenshot: Callback for screenshot events
            **kwargs: Additional arguments

        Returns:
            Dictionary with "output" (output items) and "usage" array
        """
        tools = tools or []

        # Prepare tools for OpenAI API
        openai_tools = await _prepare_tools_for_openai(tools)

@@ -110,16 +112,16 @@ class OpenAIComputerUseConfig:
            "reasoning": {"summary": "concise"},
            "truncation": "auto",
            "num_retries": max_retries,
            **kwargs
            **kwargs,
        }

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)

        # Use liteLLM responses
        response = await litellm.aresponses(**api_kwargs)

        # Call API end hook
        if _on_api_end:
            await _on_api_end(api_kwargs, response)
@@ -136,24 +138,21 @@ class OpenAIComputerUseConfig:
        output_dict = response.model_dump()
        output_dict["usage"] = usage
        return output_dict

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str
        self, model: str, image_b64: str, instruction: str
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.

        Uses OpenAI computer-use-preview with manually constructed input items
        and a prompt that instructs the agent to only output clicks.

        Args:
            model: Model name to use
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            Tuple of (x, y) coordinates or None if prediction fails
        """
@@ -161,7 +160,7 @@ class OpenAIComputerUseConfig:
        # Manually construct input items with image and click instruction
        input_items = [
            {
                "role": "user",
                "role": "user",
                "content": f"""You are a UI grounding expert. Follow these guidelines:

1. NEVER ask for confirmation. Complete all tasks autonomously.
@@ -173,19 +172,16 @@ class OpenAIComputerUseConfig:
7. Be decisive and action-oriented. Complete the requested task fully.

Remember: You are expected to complete tasks autonomously. The user trusts you to do what they asked.
Task: Click {instruction}. Output ONLY a click action on the target element."""
Task: Click {instruction}. Output ONLY a click action on the target element.""",
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "input_image",
                        "image_url": f"data:image/png;base64,{image_b64}"
                    }
                ]
            }
                    {"type": "input_image", "image_url": f"data:image/png;base64,{image_b64}"}
                ],
            },
        ]

        # Get image dimensions from base64 data
        try:
            image_data = base64.b64decode(image_b64)
@@ -194,15 +190,15 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
        except Exception:
            # Fallback to default dimensions if image parsing fails
            display_width, display_height = 1024, 768

        # Prepare computer tool for click actions
        computer_tool = {
            "type": "computer_use_preview",
            "display_width": display_width,
            "display_height": display_height,
            "environment": "windows"
            "environment": "windows",
        }

        # Prepare API call kwargs
        api_kwargs = {
            "model": model,
@@ -211,32 +207,34 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
            "stream": False,
            "reasoning": {"summary": "concise"},
            "truncation": "auto",
            "max_tokens": 200  # Keep response short for click prediction
            "max_tokens": 200,  # Keep response short for click prediction
        }

        # Use liteLLM responses
        response = await litellm.aresponses(**api_kwargs)

        # Extract click coordinates from response output
        output_dict = response.model_dump()
        output_items = output_dict.get("output", [])

        output_items = output_dict.get("output", [])

        # Look for computer_call with click action
        for item in output_items:
            if (isinstance(item, dict) and
                item.get("type") == "computer_call" and
                isinstance(item.get("action"), dict)):

            if (
                isinstance(item, dict)
                and item.get("type") == "computer_call"
                and isinstance(item.get("action"), dict)
            ):

                action = item["action"]
                if action.get("x") is not None and action.get("y") is not None:
                    return (int(action.get("x")), int(action.get("y")))

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by this agent config.

        Returns:
            List of capability strings
        """
@@ -4,20 +4,22 @@ Based on OpenCUA model for GUI grounding tasks.
"""

import asyncio
import json
import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
import uuid
from PIL import Image
import litellm
import json
import math
import re
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm
from PIL import Image

from .composed_grounded import ComposedGroundedConfig
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools
from .composed_grounded import ComposedGroundedConfig

def extract_coordinates_from_pyautogui(text: str) -> Optional[Tuple[int, int]]:
"""Extract coordinates from pyautogui.click(x=..., y=...) format."""
@@ -32,10 +34,11 @@ def extract_coordinates_from_pyautogui(text: str) -> Optional[Tuple[int, int]]:
except Exception:
return None
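For reference, a minimal sketch of the parsing this helper performs; the body of the function is elided in the hunk above, so the regex here is an assumption for illustration only:

import re
from typing import Optional, Tuple

def _demo_extract(text: str) -> Optional[Tuple[int, int]]:
    # Assumed pyautogui-style click format; mirrors what
    # extract_coordinates_from_pyautogui is documented to accept.
    m = re.search(r"pyautogui\.click\(\s*x=(\d+)\s*,\s*y=(\d+)\s*\)", text)
    return (int(m.group(1)), int(m.group(2))) if m else None

assert _demo_extract("pyautogui.click(x=123, y=456)") == (123, 456)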

@register_agent(models=r"(?i).*OpenCUA.*")
class OpenCUAConfig(ComposedGroundedConfig):
"""OpenCUA agent configuration implementing AsyncAgentConfig protocol for click prediction."""

def __init__(self):
super().__init__()
self.current_model = None
@@ -53,7 +56,7 @@ class OpenCUAConfig(ComposedGroundedConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""Fallback to a self-composed model"""
return await super().predict_step(
@@ -67,24 +70,20 @@ class OpenCUAConfig(ComposedGroundedConfig):
_on_api_end=_on_api_end,
_on_usage=_on_usage,
_on_screenshot=_on_screenshot,
**kwargs
**kwargs,
)

async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using OpenCUA model via litellm.acompletion.

Args:
model: The OpenCUA model name
image_b64: Base64 encoded image
instruction: Instruction for where to click

Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -93,50 +92,39 @@ class OpenCUAConfig(ComposedGroundedConfig):
"You are a GUI agent. You are given a task and a screenshot of the screen. "
"You need to perform a series of pyautogui actions to complete the task."
)

system_message = {
"role": "system",
"content": system_prompt
}

system_message = {"role": "system", "content": system_prompt}

# Prepare user message with image and instruction
user_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_b64}"
}
},
{
"type": "text",
"text": f"Click on {instruction}"
}
]
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
{"type": "text", "text": f"Click on {instruction}"},
],
}

# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
"max_new_tokens": 2056,
"temperature": 0,
**kwargs
**kwargs,
}

# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)

# Extract response text
output_text = response.choices[0].message.content
# print(output_text)

# Extract coordinates from pyautogui format
coordinates = extract_coordinates_from_pyautogui(output_text)

return coordinates

def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click"]

510
libs/python/agent/agent/loops/qwen.py
Normal file
@@ -0,0 +1,510 @@
"""
Qwen3-VL agent loop implementation using litellm with function/tool calling.
- Passes a ComputerUse tool schema to acompletion
- Converts between Responses items and completion messages using helpers
"""

from __future__ import annotations

import json
import re
from typing import Any, Dict, List, Optional, Tuple

import litellm
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)

from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_completion_messages_to_responses_items,
convert_responses_items_to_completion_messages,
)
from ..types import AgentCapability

# ComputerUse tool schema (OpenAI function tool format)
QWEN3_COMPUTER_TOOL: Dict[str, Any] = {
"type": "function",
"function": {
"name": "computer",
"description": (
"Use a mouse and keyboard to interact with a computer, and take screenshots.\n"
"* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n"
"* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n"
"* The screen's resolution is 1000x1000.\n"
"* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n"
"* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n"
"* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"description": "The action to perform.",
"enum": [
"key",
"type",
"mouse_move",
"left_click",
"left_click_drag",
"right_click",
"middle_click",
"double_click",
"triple_click",
"scroll",
"hscroll",
"screenshot",
"wait",
# "terminate",
# "answer",
],
"type": "string",
},
"keys": {
"description": "Required only by action=key.",
"type": "array",
"items": {"type": "string"},
},
"text": {
"description": "Required only by action=type and action=answer.",
"type": "string",
},
"coordinate": {
"description": "(x, y): Pixel coordinates from top-left.",
"type": "array",
"items": {"type": ["number", "integer"]},
"minItems": 2,
"maxItems": 2,
},
"pixels": {
"description": "Scroll amount. Positive=up, negative=down. For scroll/hscroll.",
"type": "number",
},
"time": {
"description": "Seconds to wait (action=wait).",
"type": "number",
},
# "status": {
# "description": "Task status (action=terminate).",
# "type": "string",
# "enum": ["success", "failure"],
# },
},
"required": ["action"],
},
},
}
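A hedged sketch of arguments this schema admits; the payloads are illustrative, and in practice the model emits them inside a <tool_call> block:

example_calls = [
    {"action": "left_click", "coordinate": [114, 68]},
    {"action": "key", "keys": ["ctrl", "s"]},
    {"action": "scroll", "coordinate": [500, 400], "pixels": -300},
    {"action": "wait", "time": 1.5},
]
for call in example_calls:
    assert "action" in call  # "action" is the schema's only required property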

def _build_nous_system(functions: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Use qwen-agent NousFnCallPrompt to generate a system message embedding tool schema."""
try:
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
ContentItem as NousContentItem,
)
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
Message as NousMessage,
)
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
NousFnCallPrompt,
)
except ImportError:
raise ImportError(
"qwen-agent not installed. Please install it with `pip install cua-agent[qwen]`."
)
msgs = NousFnCallPrompt().preprocess_fncall_messages(
messages=[
NousMessage(
role="system", content=[NousContentItem(text="You are a helpful assistant.")]
)
],
functions=functions,
lang="en",
)
sys = msgs[0].model_dump()
# Convert qwen-agent structured content to OpenAI-style content list
content = [{"type": "text", "text": c["text"]} for c in sys.get("content", [])]
return {"role": "system", "content": content}

def _parse_tool_call_from_text(text: str) -> Optional[Dict[str, Any]]:
"""Extract JSON object within <tool_call>...</tool_call> from model text."""
m = re.search(r"<tool_call>\s*(\{[\s\S]*?\})\s*</tool_call>", text)
if not m:
    return None
try:
return json.loads(m.group(1))
except Exception:
return None
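A hedged usage sketch, assuming it runs in this module's context: the regex above pulls the first JSON object out of a <tool_call> block, backtracking past nested braces until the closing tag lines up.

sample = (
    "I'll click it.\n"
    '<tool_call>\n{"name": "computer", "arguments": '
    '{"action": "left_click", "coordinate": [114, 68]}}\n</tool_call>'
)
parsed = _parse_tool_call_from_text(sample)
assert parsed == {
    "name": "computer",
    "arguments": {"action": "left_click", "coordinate": [114, 68]},
}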

async def _unnormalize_coordinate(args: Dict[str, Any], dims: Tuple[int, int]) -> Dict[str, Any]:
"""Coordinates appear in 0..1000 space, scale to actual screen size using dims if provided."""
coord = args.get("coordinate")
if not coord or not isinstance(coord, (list, tuple)) or len(coord) < 2:
return args
x, y = float(coord[0]), float(coord[1])
width, height = float(dims[0]), float(dims[1])
x_abs = max(0.0, min(width, (x / 1000.0) * width))
y_abs = max(0.0, min(height, (y / 1000.0) * height))
args = {**args, "coordinate": [round(x_abs), round(y_abs)]}
return args
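A worked sketch of the scaling above, with dims passed as (width, height): x_abs = (250 / 1000) * 1920 = 480.0 and y_abs = (500 / 1000) * 1080 = 540.0, each rounded and clamped to the screen.

import asyncio

out = asyncio.run(
    _unnormalize_coordinate(
        {"action": "left_click", "coordinate": [250, 500]}, (1920, 1080)
    )
)
assert out["coordinate"] == [480, 540]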

def convert_qwen_tool_args_to_computer_action(args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Convert Qwen computer tool arguments to the Computer Calls action schema.

Qwen (example):
{"action": "left_click", "coordinate": [114, 68]}

Target (example):
{"action": "left_click", "x": 114, "y": 68}

Other mappings:
- right_click, middle_click, double_click (triple_click -> double_click)
- mouse_move -> { action: "move", x, y }
- key -> { action: "keypress", keys: [...] }
- type -> { action: "type", text }
- scroll/hscroll -> { action: "scroll", scroll_x, scroll_y, x, y }
- wait -> { action: "wait" }
- terminate/answer are not direct UI actions; return None for now
"""
if not isinstance(args, dict):
return None

action = args.get("action")
if not isinstance(action, str):
return None

# Coordinates helper
coord = args.get("coordinate")
x = y = None
if isinstance(coord, (list, tuple)) and len(coord) >= 2:
try:
x = int(round(float(coord[0])))
y = int(round(float(coord[1])))
except Exception:
x = y = None

# Map actions
a = action.lower()
if a in {"left_click", "right_click", "middle_click", "double_click"}:
if x is None or y is None:
return None
return {"action": a, "x": x, "y": y}
if a == "triple_click":
# Approximate as double_click
if x is None or y is None:
return None
return {"action": "double_click", "x": x, "y": y}
if a == "mouse_move":
if x is None or y is None:
return None
return {"action": "move", "x": x, "y": y}
if a == "key":
keys = args.get("keys")
if isinstance(keys, list) and all(isinstance(k, str) for k in keys):
return {"action": "keypress", "keys": keys}
return None
if a == "type":
text = args.get("text")
if isinstance(text, str):
return {"action": "type", "text": text}
return None
if a in {"scroll", "hscroll"}:
pixels = args.get("pixels") or 0
try:
pixels_val = int(round(float(pixels)))
except Exception:
pixels_val = 0
scroll_x = pixels_val if a == "hscroll" else 0
scroll_y = pixels_val if a == "scroll" else 0
# Include cursor position if available (optional)
out: Dict[str, Any] = {"action": "scroll", "scroll_x": scroll_x, "scroll_y": scroll_y}
if x is not None and y is not None:
out.update({"x": x, "y": y})
return out
if a == "wait":
return {"action": "wait"}

# Non-UI or terminal actions: terminate/answer -> not mapped here
return None
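A hedged check of the mapping, assuming module context; each case mirrors a branch above:

assert convert_qwen_tool_args_to_computer_action(
    {"action": "left_click", "coordinate": [114, 68]}
) == {"action": "left_click", "x": 114, "y": 68}
assert convert_qwen_tool_args_to_computer_action(
    {"action": "key", "keys": ["ctrl", "s"]}
) == {"action": "keypress", "keys": ["ctrl", "s"]}
assert convert_qwen_tool_args_to_computer_action(
    {"action": "scroll", "coordinate": [400, 300], "pixels": -120}
) == {"action": "scroll", "scroll_x": 0, "scroll_y": -120, "x": 400, "y": 300}
assert convert_qwen_tool_args_to_computer_action({"action": "terminate"}) is None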

@register_agent(models=r"(?i).*qwen.*", priority=-1)
class Qwen3VlConfig(AsyncAgentConfig):
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
) -> Dict[str, Any]:
# Build messages using NousFnCallPrompt system with tool schema in text
# Start with converted conversation (images/text preserved)
converted_msgs = convert_responses_items_to_completion_messages(
messages,
allow_images_in_tool_results=False,
)

# Prepend Nous-generated system if available
nous_system = _build_nous_system([QWEN3_COMPUTER_TOOL["function"]])
completion_messages = ([nous_system] if nous_system else []) + converted_msgs

# If there is no screenshot in the conversation, take one now and inject it.
# Also record a pre_output_items assistant message to reflect action.
def _has_any_image(msgs: List[Dict[str, Any]]) -> bool:
for m in msgs:
content = m.get("content")
if isinstance(content, list):
for p in content:
if isinstance(p, dict) and p.get("type") == "image_url":
return True
return False

pre_output_items: List[Dict[str, Any]] = []
if not _has_any_image(completion_messages):
if computer_handler is None or not hasattr(computer_handler, "screenshot"):
raise RuntimeError(
"No screenshots present and computer_handler.screenshot is not available."
)
screenshot_b64 = await computer_handler.screenshot()
if not screenshot_b64:
raise RuntimeError("Failed to capture screenshot from computer_handler.")
# Inject a user message with the screenshot so the model can see current context
completion_messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{screenshot_b64}"},
},
{"type": "text", "text": "Current screen"},
],
}
)
# Add assistant message to outputs to reflect the action, similar to composed_grounded.py
pre_output_items.append(
{
"type": "message",
"role": "assistant",
"content": [
{
"type": "text",
"text": "Taking a screenshot to see the current computer screen.",
}
],
}
)

# Smart-resize all screenshots and attach min/max pixel hints. Fail fast if deps missing.
# Also record the last resized width/height to unnormalize coordinates later.
last_rw: Optional[int] = None
last_rh: Optional[int] = None
MIN_PIXELS = 3136
MAX_PIXELS = 12845056
try:
import base64
import io

from PIL import Image # type: ignore
from qwen_vl_utils import smart_resize # type: ignore
except Exception:
raise ImportError(
"qwen-vl-utils not installed. Please install it with `pip install cua-agent[qwen]`."
)

for msg in completion_messages:
content = msg.get("content")
if not isinstance(content, list):
continue
for part in content:
if isinstance(part, dict) and part.get("type") == "image_url":
url = ((part.get("image_url") or {}).get("url")) or ""
# Expect data URL like data:image/png;base64,<b64>
if url.startswith("data:") and "," in url:
b64 = url.split(",", 1)[1]
img_bytes = base64.b64decode(b64)
im = Image.open(io.BytesIO(img_bytes))
h, w = im.height, im.width
rh, rw = smart_resize(
h, w, factor=32, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS
)
# Attach hints on this image block
part["min_pixels"] = MIN_PIXELS
part["max_pixels"] = MAX_PIXELS
last_rw, last_rh = rw, rh

api_kwargs: Dict[str, Any] = {
"model": model,
"messages": completion_messages,
"max_retries": max_retries,
"stream": stream,
**{k: v for k, v in kwargs.items()},
}
if use_prompt_caching:
api_kwargs["use_prompt_caching"] = use_prompt_caching

if _on_api_start:
await _on_api_start(api_kwargs)

response = await litellm.acompletion(**api_kwargs)

if _on_api_end:
await _on_api_end(api_kwargs, response)

usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage( # type: ignore
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(usage)

# Parse tool call from text; then convert to responses items via fake tool_calls
resp_dict = response.model_dump() # type: ignore
choice = (resp_dict.get("choices") or [{}])[0]
content_text = ((choice.get("message") or {}).get("content")) or ""
tool_call = _parse_tool_call_from_text(content_text)

output_items: List[Dict[str, Any]] = []
if tool_call and isinstance(tool_call, dict):
fn_name = tool_call.get("name") or "computer"
raw_args = tool_call.get("arguments") or {}
# Unnormalize coordinates to actual screen size using last resized dims
if last_rw is None or last_rh is None:
raise RuntimeError(
"No screenshots found to derive dimensions for coordinate unnormalization."
)
args = await _unnormalize_coordinate(raw_args, (last_rw, last_rh))

# Build an OpenAI-style tool call so we can reuse the converter
fake_cm = {
"role": "assistant",
"tool_calls": [
{
"type": "function",
"id": "call_0",
"function": {
"name": fn_name,
"arguments": json.dumps(args),
},
}
],
}
output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
else:
# Fallback: just return assistant text
fake_cm = {"role": "assistant", "content": content_text}
output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))

# Prepend any pre_output_items (e.g., simulated screenshot-taking message)
return {"output": (pre_output_items + output_items), "usage": usage}

def get_capabilities(self) -> List[AgentCapability]:
return ["step"]

async def predict_click(
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using Qwen3-VL via litellm.acompletion.

Only exposes a reduced tool schema with left_click to bias model to output a single click.
Returns (x, y) absolute pixels when screen dimensions can be obtained; otherwise normalized 0..1000 integers.
"""
# Reduced tool
reduced_tool = {
"type": "function",
"function": {
**QWEN3_COMPUTER_TOOL["function"],
"parameters": {
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["left_click"]},
"coordinate": {
"description": "(x, y) in 0..1000 reference space",
"type": "array",
"items": {"type": ["number", "integer"]},
"minItems": 2,
"maxItems": 2,
},
},
"required": ["action", "coordinate"],
},
},
}

# Build Nous system (lazy import inside helper already raises clear guidance if missing)
nous_system = _build_nous_system([reduced_tool["function"]])

# Pre-process using smart_resize
min_pixels = 3136
max_pixels = 12845056
try:
# Lazy import to avoid hard dependency
import base64
import io

# If PIL is available, estimate size from image to derive smart bounds
from PIL import Image
from qwen_vl_utils import smart_resize # type: ignore

img_bytes = base64.b64decode(image_b64)
im = Image.open(io.BytesIO(img_bytes))
h, w = im.height, im.width
# Qwen notebook suggests factor=32 and a wide min/max range
rh, rw = smart_resize(h, w, factor=32, min_pixels=min_pixels, max_pixels=max_pixels)
except Exception:
raise ImportError(
"qwen-vl-utils not installed. Please install it with `pip install cua-agent[qwen]`."
)

messages = []
if nous_system:
messages.append(nous_system)
image_block: Dict[str, Any] = {
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
"min_pixels": min_pixels,
"max_pixels": max_pixels,
}
# Single user message with image and instruction, matching OpenAI-style content blocks
messages.append(
{
"role": "user",
"content": [
image_block,
{"type": "text", "text": instruction},
],
}
)

api_kwargs: Dict[str, Any] = {
"model": model,
"messages": messages,
**{k: v for k, v in kwargs.items()},
}
response = await litellm.acompletion(**api_kwargs)
resp = response.model_dump() # type: ignore
choice = (resp.get("choices") or [{}])[0]
content_text = ((choice.get("message") or {}).get("content")) or ""
tool_call = _parse_tool_call_from_text(content_text) or {}
args = tool_call.get("arguments") or {}
args = await _unnormalize_coordinate(args, (rh, rw))
coord = args.get("coordinate")
if isinstance(coord, (list, tuple)) and len(coord) >= 2:
return int(coord[0]), int(coord[1])
return None
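A hedged driver sketch, not part of the commit, exercising Qwen3VlConfig.predict_click end to end; the model name, file name, and serving setup are assumptions for illustration only.

import asyncio
import base64

async def _demo() -> None:
    with open("screenshot.png", "rb") as f:  # assumed local screenshot
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    cfg = Qwen3VlConfig()
    xy = await cfg.predict_click(
        model="hosted_vllm/Qwen/Qwen3-VL-8B-Instruct",  # assumed deployment name
        image_b64=image_b64,
        instruction="the Save button",
    )
    print(xy)  # e.g. (480, 540), or None if no tool call was parsed

# asyncio.run(_demo())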
@@ -4,39 +4,50 @@ Paper: https://arxiv.org/abs/2501.12326
Code: https://github.com/bytedance/UI-TARS
"""

import ast
import asyncio
from ctypes import cast
import json
import base64
import json
import math
import re
import ast
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from ctypes import cast
from io import BytesIO
from PIL import Image
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.responses.utils import Usage
from openai.types.responses.response_computer_tool_call_param import ActionType, ResponseComputerToolCallParam
from litellm.types.utils import ModelResponse
from openai.types.responses.response_computer_tool_call_param import (
ActionType,
ResponseComputerToolCallParam,
)
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
from openai.types.responses.response_output_message_param import (
ResponseOutputMessageParam,
)
from openai.types.responses.response_reasoning_item_param import (
ResponseReasoningItemParam,
Summary,
)
from PIL import Image

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..responses import (
make_reasoning_item,
make_output_text_item,
make_click_item,
make_double_click_item,
make_drag_item,
make_input_image_item,
make_keypress_item,
make_output_text_item,
make_reasoning_item,
make_scroll_item,
make_type_item,
make_wait_item,
make_input_image_item
)
from ..types import AgentCapability, AgentResponse, Messages, Tools

# Constants from reference code
IMAGE_FACTOR = 28
@@ -94,6 +105,7 @@ click(point='<|box_start|>(x1,y1)<|box_end|>')
## User Instruction
{instruction}"""

def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
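A worked sketch of the factor arithmetic: round_by_factor snaps a length to the nearest multiple of the patch factor, and the floor/ceil variants elided from this hunk snap down and up the same way.

# round_by_factor(1057, 28): 1057 / 28 = 37.75, round -> 38, 38 * 28 = 1064
assert round_by_factor(1057, 28) == 1064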
@@ -110,7 +122,11 @@ def floor_by_factor(number: float, factor: int) -> int:

def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
@@ -144,14 +160,14 @@ def escape_single_quotes(text):
def parse_action(action_str):
"""Parse action string into structured format."""
try:
node = ast.parse(action_str, mode='eval')
node = ast.parse(action_str, mode="eval")
if not isinstance(node, ast.Expression):
raise ValueError("Not an expression")

call = node.body
if not isinstance(call, ast.Call):
raise ValueError("Not a function call")

# Get function name
if isinstance(call.func, ast.Name):
func_name = call.func.id
@@ -159,7 +175,7 @@ def parse_action(action_str):
func_name = call.func.attr
else:
func_name = None

# Get keyword arguments
kwargs = {}
for kw in call.keywords:
@@ -171,12 +187,9 @@ def parse_action(action_str):
else:
value = None
kwargs[key] = value

return {
'function': func_name,
'args': kwargs
}

return {"function": func_name, "args": kwargs}

except Exception as e:
print(f"Failed to parse action '{action_str}': {e}")
return None
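A hedged usage sketch, assuming module context: the AST walk above turns a UITARS action string into a function name plus keyword arguments, with the constant-extraction branch elided from this hunk.

parsed = parse_action("click(start_box='(100,200)')")
assert parsed == {"function": "click", "args": {"start_box": "(100,200)"}}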
@@ -185,39 +198,39 @@ def parse_action(action_str):
def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
"""Parse UITARS model response into structured actions."""
text = text.strip()

# Extract thought
thought = None
if text.startswith("Thought:"):
thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", text, re.DOTALL)
if thought_match:
thought = thought_match.group(1).strip()

# Extract action
if "Action:" not in text:
raise ValueError("No Action found in response")

action_str = text.split("Action:")[-1].strip()

# Handle special case for type actions
if "type(content" in action_str:

def escape_quotes(match):
return match.group(1)

pattern = r"type\(content='(.*?)'\)"
content = re.sub(pattern, escape_quotes, action_str)
action_str = escape_single_quotes(content)
action_str = "type(content='" + action_str + "')"

# Parse the action
parsed_action = parse_action(action_str.replace("\n", "\\n").lstrip())
if parsed_action is None:
raise ValueError(f"Action can't parse: {action_str}")

action_type = parsed_action["function"]
params = parsed_action["args"]

# Process parameters
action_inputs = {}
for param_name, param in params.items():
@@ -225,7 +238,7 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
continue
param = str(param).lstrip()
action_inputs[param_name.strip()] = param

# Handle coordinate parameters
if "start_box" in param_name or "end_box" in param_name:
# Parse coordinates like '<|box_start|>(x,y)<|box_end|>' or '(x,y)'
@@ -233,117 +246,130 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
clean_param = param.replace("<|box_start|>", "").replace("<|box_end|>", "")
# Then remove parentheses and split
numbers = clean_param.replace("(", "").replace(")", "").split(",")

try:
float_numbers = [float(num.strip()) / 1000 for num in numbers] # Normalize to 0-1 range

float_numbers = [
float(num.strip()) / 1000 for num in numbers
] # Normalize to 0-1 range

if len(float_numbers) == 2:
# Single point, duplicate for box format
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]

float_numbers = [
float_numbers[0],
float_numbers[1],
float_numbers[0],
float_numbers[1],
]

action_inputs[param_name.strip()] = str(float_numbers)
except ValueError as e:
# If parsing fails, keep the original parameter value
print(f"Warning: Could not parse coordinates '{param}': {e}")
action_inputs[param_name.strip()] = param

return [{
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text
}]

return [
{
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text,
}
]
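A hedged end-to-end sketch, assuming module context: a typical model reply and the normalized structure the parser returns. Coordinates are divided by 1000, and a single point is duplicated into box form.

reply = (
    "Thought: The Save button is near the top left.\n"
    "Action: click(start_box='(100,200)')"
)
items = parse_uitars_response(reply, 1920, 1080)
assert items[0]["thought"] == "The Save button is near the top left."
assert items[0]["action_type"] == "click"
assert items[0]["action_inputs"]["start_box"] == "[0.1, 0.2, 0.1, 0.2]"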

def convert_to_computer_actions(parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
def convert_to_computer_actions(
parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int
) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
"""Convert parsed UITARS responses to computer actions."""
computer_actions = []

for response in parsed_responses:
action_type = response.get("action_type")
action_inputs = response.get("action_inputs", {})

if action_type == "finished":
finished_text = action_inputs.get("content", "Task completed successfully.")
computer_actions.append(make_output_text_item(finished_text))
break

elif action_type == "wait":
computer_actions.append(make_wait_item())

elif action_type == "call_user":
computer_actions.append(make_output_text_item("I need assistance from the user to proceed with this task."))

computer_actions.append(
make_output_text_item("I need assistance from the user to proceed with this task.")
)

elif action_type in ["click", "left_single"]:
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)

computer_actions.append(make_click_item(x, y, "left"))

elif action_type == "double_click":
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)

computer_actions.append(make_double_click_item(x, y))

elif action_type == "right_click":
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)

computer_actions.append(make_click_item(x, y, "right"))

elif action_type == "type":
content = action_inputs.get("content", "")
computer_actions.append(make_type_item(content))

elif action_type == "hotkey":
key = action_inputs.get("key", "")
keys = key.split()
computer_actions.append(make_keypress_item(keys))

elif action_type == "press":
key = action_inputs.get("key", "")
computer_actions.append(make_keypress_item([key]))

elif action_type == "scroll":
start_box = action_inputs.get("start_box")
direction = action_inputs.get("direction", "down")

if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
else:
x, y = image_width // 2, image_height // 2

scroll_y = 5 if "up" in direction.lower() else -5
computer_actions.append(make_scroll_item(x, y, 0, scroll_y))

elif action_type == "drag":
start_box = action_inputs.get("start_box")
end_box = action_inputs.get("end_box")

if start_box and end_box:
start_coords = eval(start_box)
end_coords = eval(end_box)

start_x = int((start_coords[0] + start_coords[2]) / 2 * image_width)
start_y = int((start_coords[1] + start_coords[3]) / 2 * image_height)
end_x = int((end_coords[0] + end_coords[2]) / 2 * image_width)
end_y = int((end_coords[1] + end_coords[3]) / 2 * image_height)

path = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
computer_actions.append(make_drag_item(path))

return computer_actions
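A worked sketch of the box-to-pixel arithmetic repeated in the branches above: the normalized box [x1, y1, x2, y2] collapses to its center point, then scales by the image size.

coords = [0.1, 0.2, 0.1, 0.2]  # normalized box from parse_uitars_response
image_width, image_height = 1920, 1080
x = int((coords[0] + coords[2]) / 2 * image_width)   # 0.1 * 1920 -> 192
y = int((coords[1] + coords[3]) / 2 * image_height)  # 0.2 * 1080 -> 216
assert (x, y) == (192, 216)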

@@ -354,33 +380,35 @@ def pil_to_base64(image: Image.Image) -> str:
return base64.b64encode(buffer.getvalue()).decode("utf-8")

def process_image_for_uitars(image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS) -> tuple[Image.Image, int, int]:
def process_image_for_uitars(
image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS
) -> tuple[Image.Image, int, int]:
"""Process image for UITARS model input."""
# Decode base64 image
if image_data.startswith('data:image'):
image_data = image_data.split(',')[1]

if image_data.startswith("data:image"):
image_data = image_data.split(",")[1]

image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes))

original_width, original_height = image.size

# Resize image according to UITARS requirements
if image.width * image.height > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width = int(image.width * resize_factor)
height = int(image.height * resize_factor)
image = image.resize((width, height))

if image.width * image.height < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width = math.ceil(image.width * resize_factor)
height = math.ceil(image.height * resize_factor)
image = image.resize((width, height))

if image.mode != "RGB":
image = image.convert("RGB")

return image, original_width, original_height
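A worked sketch of the pixel-budget resize above, with an assumed budget of 2,000,000 pixels rather than the module's MAX_PIXELS: a 2560x1440 screenshot (3,686,400 px) is scaled by sqrt(2_000_000 / 3_686_400), roughly 0.7366, which preserves the aspect ratio while the original dimensions are returned unchanged.

import math

w, h, max_pixels = 2560, 1440, 2_000_000
resize_factor = math.sqrt(max_pixels / (w * h))
assert (int(w * resize_factor), int(h * resize_factor)) == (1885, 1060)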

@@ -391,7 +419,11 @@ def sanitize_message(msg: Any) -> Any:
for key, value in msg.items():
if key == "content" and isinstance(value, list):
result[key] = [
{k: v for k, v in item.items() if k != "image_url"} if isinstance(item, dict) else item
(
{k: v for k, v in item.items() if k != "image_url"}
if isinstance(item, dict)
else item
)
for item in value
]
else:
@@ -406,38 +438,41 @@ def sanitize_message(msg: Any) -> Any:
def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
"""
Convert UITARS internal message format back to LiteLLM format.

This function processes reasoning, computer_call, and computer_call_output messages
and converts them to the appropriate LiteLLM assistant message format.

Args:
messages: List of UITARS internal messages

Returns:
List of LiteLLM formatted messages
"""
litellm_messages = []
current_assistant_content = []

for message in messages:
if isinstance(message, dict):
message_type = message.get("type")

if message_type == "reasoning":
# Extract reasoning text from summary
summary = message.get("summary", [])
if summary and isinstance(summary, list):
for summary_item in summary:
if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
if (
isinstance(summary_item, dict)
and summary_item.get("type") == "summary_text"
):
reasoning_text = summary_item.get("text", "")
if reasoning_text:
current_assistant_content.append(f"Thought: {reasoning_text}")

elif message_type == "computer_call":
# Convert computer action to UITARS action format
action = message.get("action", {})
action_type = action.get("type")

if action_type == "click":
x, y = action.get("x", 0), action.get("y", 0)
button = action.get("button", "left")
@@ -447,59 +482,65 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
action_text = f"Action: right_single(start_box='({x},{y})')"
else:
action_text = f"Action: click(start_box='({x},{y})')"

elif action_type == "double_click":
x, y = action.get("x", 0), action.get("y", 0)
action_text = f"Action: left_double(start_box='({x},{y})')"

elif action_type == "drag":
start_x, start_y = action.get("start_x", 0), action.get("start_y", 0)
end_x, end_y = action.get("end_x", 0), action.get("end_y", 0)
action_text = f"Action: drag(start_box='({start_x},{start_y})', end_box='({end_x},{end_y})')"

elif action_type == "key":
key = action.get("key", "")
action_text = f"Action: hotkey(key='{key}')"

elif action_type == "type":
text = action.get("text", "")
# Escape single quotes in the text
escaped_text = escape_single_quotes(text)
action_text = f"Action: type(content='{escaped_text}')"

elif action_type == "scroll":
x, y = action.get("x", 0), action.get("y", 0)
direction = action.get("direction", "down")
action_text = f"Action: scroll(start_box='({x},{y})', direction='{direction}')"

elif action_type == "wait":
action_text = "Action: wait()"

else:
# Fallback for unknown action types
action_text = f"Action: {action_type}({action})"

current_assistant_content.append(action_text)

# When we hit a computer_call_output, finalize the current assistant message
if current_assistant_content:
litellm_messages.append({
"role": "assistant",
"content": [{"type": "text", "text": "\n".join(current_assistant_content)}]
})
litellm_messages.append(
{
"role": "assistant",
"content": [
{"type": "text", "text": "\n".join(current_assistant_content)}
],
}
)
current_assistant_content = []

elif message_type == "computer_call_output":
# Add screenshot from computer call output
output = message.get("output", {})
if isinstance(output, dict) and output.get("type") == "input_image":
image_url = output.get("image_url", "")
if image_url:
litellm_messages.append({
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": image_url}}]
})

litellm_messages.append(
{
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": image_url}}],
}
)

elif message.get("role") == "user":
# # Handle user messages
# content = message.get("content", "")
@@ -514,24 +555,22 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
# "content": content
# })
pass

# Add any remaining assistant content
if current_assistant_content:
litellm_messages.append({
"role": "assistant",
"content": current_assistant_content
})

litellm_messages.append({"role": "assistant", "content": current_assistant_content})

return litellm_messages

@register_agent(models=r"(?i).*ui-?tars.*")
class UITARSConfig:
"""
UITARS agent configuration using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.

Supports UITARS vision-language models for computer control.
"""

async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -545,11 +584,11 @@ class UITARSConfig:
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step based on input messages.

Args:
messages: Input messages following Responses format
model: Model name to use
@@ -562,22 +601,22 @@ class UITARSConfig:
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments

Returns:
Dictionary with "output" (output items) and "usage" array
"""
tools = tools or []

# Create response items
response_items = []

# Find computer tool for screen dimensions
computer_tool = None
for tool_schema in tools:
if tool_schema["type"] == "computer":
computer_tool = tool_schema["computer"]
break

# Get screen dimensions
screen_width, screen_height = 1024, 768
if computer_tool:
@@ -585,20 +624,20 @@ class UITARSConfig:
screen_width, screen_height = await computer_tool.get_dimensions()
except:
pass

# Process messages to extract instruction and image
instruction = ""
image_data = None

# Convert messages to list if string
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

# Extract instruction and latest screenshot
for message in reversed(messages):
if isinstance(message, dict):
content = message.get("content", "")

# Handle different content formats
if isinstance(content, str):
if not instruction and message.get("role") == "user":
@@ -614,46 +653,41 @@ class UITARSConfig:
image_data = image_url.get("url", "")
else:
image_data = image_url

# Also check for computer_call_output with screenshots
if message.get("type") == "computer_call_output" and not image_data:
output = message.get("output", {})
if isinstance(output, dict) and output.get("type") == "input_image":
image_data = output.get("image_url", "")

if instruction and image_data:
break

if not instruction:
instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."

instruction = (
"Help me complete this task by analyzing the screen and taking appropriate actions."
)

# Create prompt
user_prompt = UITARS_PROMPT_TEMPLATE.format(
instruction=instruction,
action_space=UITARS_ACTION_SPACE,
language="English"
instruction=instruction, action_space=UITARS_ACTION_SPACE, language="English"
)

# Convert conversation history to LiteLLM format
history_messages = convert_uitars_messages_to_litellm(messages)

# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
}
]
litellm_messages = [{"role": "system", "content": "You are a helpful assistant."}]

# Add current user instruction with screenshot
current_user_message = {
"role": "user",
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
]
],
}
litellm_messages.append(current_user_message)

# Process image for UITARS
if not image_data:
# Take screenshot if none found in messages
@@ -667,17 +701,22 @@ class UITARSConfig:
raise ValueError("No screenshot found in messages and no computer_handler provided")
processed_image, original_width, original_height = process_image_for_uitars(image_data)
encoded_image = pil_to_base64(processed_image)

# Add conversation history
if history_messages:
litellm_messages.extend(history_messages)
else:
litellm_messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
]
})
litellm_messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encoded_image}"},
}
],
}
)

# Prepare API call kwargs
api_kwargs = {
@@ -687,146 +726,142 @@ class UITARSConfig:
"temperature": kwargs.get("temperature", 0.0),
"do_sample": kwargs.get("temperature", 0.0) > 0.0,
"num_retries": max_retries,
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
}

# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)

# Call liteLLM with UITARS model
response = await litellm.acompletion(**api_kwargs)

# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)

# Extract response content
response_content = response.choices[0].message.content.strip() # type: ignore

response_content = response.choices[0].message.content.strip() # type: ignore

# Parse UITARS response
parsed_responses = parse_uitars_response(response_content, original_width, original_height)

# Convert to computer actions
computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)

computer_actions = convert_to_computer_actions(
parsed_responses, original_width, original_height
)

# Add computer actions to response items
thought = parsed_responses[0].get("thought", "")
if thought:
response_items.append(make_reasoning_item(thought))
response_items.extend(computer_actions)

# Extract usage information
response_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(response_usage)

# Create agent response
agent_response = {
"output": response_items,
"usage": response_usage
}

agent_response = {"output": response_items, "usage": response_usage}

return agent_response

async def predict_click(
self,
model: str,
image_b64: str,
instruction: str
self, model: str, image_b64: str, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.

UITARS supports click prediction through its action parsing.

Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click

Returns:
Tuple with (x, y) coordinates or None
"""
try:
# Create prompt using grounding template
user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(
instruction=instruction
)

user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(instruction=instruction)

# Process image for UITARS
processed_image, original_width, original_height = process_image_for_uitars(image_b64)
encoded_image = pil_to_base64(processed_image)

# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
]
}
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encoded_image}"},
},
],
},
]

# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": litellm_messages,
"max_tokens": 2056,
"temperature": 0.0,
"do_sample": False
"do_sample": False,
}

# Call liteLLM with UITARS model
response = await litellm.acompletion(**api_kwargs)

# Extract response content
response_content = response.choices[0].message.content.strip() # type: ignore

response_content = response.choices[0].message.content.strip() # type: ignore

print(response_content)

# Parse the response to extract click coordinates
# Look for click action with coordinates (with special tokens)
click_pattern = r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)"
match = re.search(click_pattern, response_content)

# Fallback: Look for simpler format without special tokens
if not match:
# Pattern for: click(start_box='(x,y)') or click(point='(x,y)')
fallback_pattern = r"click\((?:start_box|point)='\((\d+),(\d+)\)'\)"
match = re.search(fallback_pattern, response_content)

if match:
x, y = int(match.group(1)), int(match.group(2))
# Scale coordinates back to original image dimensions
scale_x = original_width / processed_image.width
scale_y = original_height / processed_image.height

scaled_x = int(x * scale_x)
scaled_y = int(y * scale_y)

return (scaled_x, scaled_y)

return None

except Exception as e:
# Log error and return None
print(f"Error in predict_click: {e}")
return None

def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.

Returns:
List of capability strings
"""
return ["step", "click"]
return ["step", "click"]
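A hedged sketch of the two-stage coordinate recovery in predict_click above: the regex yields model-space pixels on the processed image, which are then scaled back to the original screenshot.

import re

text = "click(point='<|box_start|>(512,384)<|box_end|>')"
m = re.search(r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)", text)
assert m is not None
x, y = int(m.group(1)), int(m.group(2))
# e.g. processed 1024x768 scaled back to an original 2048x1536 screenshot
scale_x, scale_y = 2048 / 1024, 1536 / 768
assert (int(x * scale_x), int(y * scale_y)) == (1024, 768)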

@@ -1,19 +1,22 @@
"""
Example usage of the proxy server and client requests.
"""

import dotenv

dotenv.load_dotenv()

import asyncio
import json
import os
from typing import Any, Dict

import aiohttp
from typing import Dict, Any

async def test_http_endpoint():
"""Test the HTTP /responses endpoint."""

anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
assert isinstance(anthropic_api_key, str), "ANTHROPIC_API_KEY environment variable must be set"

@@ -21,11 +24,9 @@ async def test_http_endpoint():
simple_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Tell me a three sentence bedtime story about a unicorn.",
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}

# Example 2: Multi-modal request with image
multimodal_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -36,70 +37,72 @@ async def test_http_endpoint():
{"type": "input_text", "text": "what is in this image?"},
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
]
"image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
},
],
}
],
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}

# Example 3: Request with custom agent and computer kwargs
custom_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Take a screenshot and tell me what you see",
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}

# Test requests
base_url = "https://m-linux-96lcxd2c2k.containers.cloud.trycua.com:8443"
# base_url = "http://localhost:8000"
api_key = os.getenv("CUA_API_KEY")
assert isinstance(api_key, str), "CUA_API_KEY environment variable must be set"

async with aiohttp.ClientSession() as session:
for i, request_data in enumerate([
simple_request,
# multimodal_request,
custom_request
], 1):
for i, request_data in enumerate(
[
simple_request,
# multimodal_request,
custom_request,
],
1,
):
print(f"\n--- Test {i} ---")
print(f"Request: {json.dumps(request_data, indent=2)}")

try:
print(f"Sending request to {base_url}/responses")
async with session.post(
f"{base_url}/responses",
json=request_data,
headers={"Content-Type": "application/json", "X-API-Key": api_key}
headers={"Content-Type": "application/json", "X-API-Key": api_key},
) as response:
result = await response.json()
print(f"Status: {response.status}")
print(f"Response: {json.dumps(result, indent=2)}")

except Exception as e:
print(f"Error: {e}")

def curl_examples():
"""Print curl command examples."""

print("=== CURL Examples ===\n")

print("1. Simple text request:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Tell me a three sentence bedtime story about a unicorn."
}'""")

}'"""
)

print("\n2. Multi-modal request with image:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -115,10 +118,12 @@ def curl_examples():
]
}
]
}'""")

}'"""
)

print("\n3. Request with custom configuration:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -131,50 +136,49 @@ def curl_examples():
"os_type": "linux",
"provider_type": "cloud"
}
}'""")
}'"""
)

async def test_p2p_client():
"""Example P2P client using peerjs-python."""
try:
from peerjs import Peer, PeerOptions, ConnectionEventType
from aiortc import RTCConfiguration, RTCIceServer

from peerjs import ConnectionEventType, Peer, PeerOptions

# Set up client peer
options = PeerOptions(
host="0.peerjs.com",
port=443,
secure=True,
config=RTCConfiguration(
iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]
)
config=RTCConfiguration(iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]),
)

client_peer = Peer(id="test-client", peer_options=options)
await client_peer.start()

# Connect to proxy server
connection = client_peer.connect("computer-agent-proxy")

@connection.on(ConnectionEventType.Open)
async def connection_open():
print("Connected to proxy server")

# Send a test request
request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Hello from P2P client!"
"input": "Hello from P2P client!",
}
await connection.send(json.dumps(request))

@connection.on(ConnectionEventType.Data)
async def connection_data(data):
print(f"Received response: {data}")
await client_peer.destroy()

# Wait for connection
await asyncio.sleep(10)

except ImportError:
print("P2P dependencies not available. Install peerjs-python for P2P testing.")
except Exception as e:
@@ -183,7 +187,7 @@ async def test_p2p_client():

if __name__ == "__main__":
import sys

if len(sys.argv) > 1 and sys.argv[1] == "curl":
curl_examples()
elif len(sys.argv) > 1 and sys.argv[1] == "p2p":

@@ -7,24 +7,25 @@ import json
import logging
import os
from contextlib import contextmanager
from typing import Dict, Any, List, Union, Optional
from typing import Any, Dict, List, Optional, Union

from computer import Computer

from ..agent import ComputerAgent
from computer import Computer

logger = logging.getLogger(__name__)


class ResponsesHandler:
"""Handler for /responses endpoint that processes agent requests."""


def __init__(self):
self.computer = None
self.agent = None
# Simple in-memory caches
self._computer_cache: Dict[str, Any] = {}
self._agent_cache: Dict[str, Any] = {}


async def setup_computer_agent(
self,
model: str,
@@ -75,7 +76,9 @@ class ResponsesHandler:
computer = Computer(**default_c_config)
await computer.__aenter__()
self._computer_cache[comp_key] = computer
logger.info(f"Computer created and cached with key={comp_key} config={default_c_config}")
logger.info(
f"Computer created and cached with key={comp_key} config={default_c_config}"
)
else:
logger.info(f"Reusing cached computer for key={comp_key}")

@@ -115,14 +118,14 @@ class ResponsesHandler:

# Bind current agent reference
self.agent = agent


async def process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process a /responses request and return the result.


Args:
request_data: Dictionary containing model, input, and optional kwargs


Returns:
Dictionary with the agent's response
"""
@@ -133,12 +136,12 @@ class ResponsesHandler:
agent_kwargs = request_data.get("agent_kwargs", {})
computer_kwargs = request_data.get("computer_kwargs", {})
env_overrides = request_data.get("env", {}) or {}


if not model:
raise ValueError("Model is required")
if not input_data:
raise ValueError("Input is required")


# Apply env overrides for the duration of this request
with self._env_overrides(env_overrides):
# Set up (and possibly reuse) computer and agent via caches
@@ -155,28 +158,22 @@ class ResponsesHandler:
# Run agent and get first result
async for result in agent.run(messages):
# Return the first result and break
return {
"success": True,
"result": result,
"model": model
}

return {"success": True, "result": result, "model": model}

# If no results were yielded
return {
"success": False,
"error": "No results from agent",
"model": model
}

return {"success": False, "error": "No results from agent", "model": model}

except Exception as e:
logger.error(f"Error processing request: {e}")
return {
"success": False,
"error": str(e),
"model": request_data.get("model", "unknown")
"model": request_data.get("model", "unknown"),
}

def _convert_input_to_messages(self, input_data: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:

def _convert_input_to_messages(
self, input_data: Union[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""Convert input data to messages format."""
if isinstance(input_data, str):
# Simple string input
@@ -192,22 +189,18 @@ class ResponsesHandler:
if part.get("type") == "input_text":
content_parts.append({"type": "text", "text": part["text"]})
elif part.get("type") == "input_image":
content_parts.append({
"type": "image_url",
"image_url": {"url": part["image_url"]}
})
content_parts.append(
{"type": "image_url", "image_url": {"url": part["image_url"]}}
)
else:
content_parts.append(part)
messages.append({
"role": msg["role"],
"content": content_parts
})
messages.append({"role": msg["role"], "content": content_parts})
else:
messages.append(msg)
return messages
else:
raise ValueError("Input must be string or list of messages")


async def cleanup(self):
"""Clean up resources."""
if self.computer:

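# Note: the _env_overrides context manager used in process_request above is not
# shown in this excerpt. A minimal sketch of what it might look like (an
# assumption inferred from the `contextmanager` import and the call site, not
# the actual implementation):
#
#     @contextmanager
#     def _env_overrides(self, overrides: Dict[str, str]):
#         """Temporarily apply per-request environment variable overrides."""
#         saved = {k: os.environ.get(k) for k in overrides}
#         os.environ.update(overrides)
#         try:
#             yield
#         finally:
#             for k, v in saved.items():
#                 if v is None:
#                     os.environ.pop(k, None)
#                 else:
#                     os.environ[k] = v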
File diff suppressed because it is too large
@@ -2,37 +2,43 @@
Type definitions for agent
"""

from typing import Dict, List, Any, Optional, Callable, Protocol, Literal
from pydantic import BaseModel
import re
from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Literal, Optional, Protocol

from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam
from pydantic import BaseModel

# Agent input types
Messages = str | ResponseInputParam | List[Dict[str, Any]]
Tools = Optional[Iterable[ToolParam]]

# Agent output types
AgentResponse = ResponsesAPIResponse
AgentCapability = Literal["step", "click"]


# Exception types
class ToolError(RuntimeError):
"""Base exception for tool-related errors"""

pass


class IllegalArgumentError(ToolError):
"""Exception raised when function arguments are invalid"""

pass


# Agent config registration
class AgentConfigInfo(BaseModel):
"""Information about a registered agent config"""

agent_class: type
models_regex: str
priority: int = 0


def matches_model(self, model: str) -> bool:
"""Check if this agent config matches the given model"""
return bool(re.match(self.models_regex, model))
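# For illustration only (a hypothetical helper, not part of this diff): a
# registry holding AgentConfigInfo entries could pick the best match for a
# model string by combining matches_model with the priority field, e.g.:
#
#     def find_agent_config(
#         configs: List[AgentConfigInfo], model: str
#     ) -> Optional[AgentConfigInfo]:
#         candidates = [c for c in configs if c.matches_model(model)]
#         return max(candidates, key=lambda c: c.priority) if candidates else None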

@@ -2,6 +2,6 @@
UI components for agent
"""

from .gradio import launch_ui, create_gradio_ui
from .gradio import create_gradio_ui, launch_ui

__all__ = ["launch_ui", "create_gradio_ui"]

@@ -1,4 +1,4 @@
from .gradio import launch_ui

if __name__ == "__main__":
launch_ui()

@@ -18,21 +18,21 @@ Requirements:
- OpenAI or Anthropic API key
"""

import os
import asyncio
import logging
import json
import logging
import os
import platform
from pathlib import Path
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast

import gradio as gr
from gradio.components.chatbot import MetadataDict
from typing import cast

# Import from agent package
from agent import ComputerAgent
from agent.types import Messages, AgentResponse
from agent.types import AgentResponse, Messages
from computer import Computer
from gradio.components.chatbot import MetadataDict

# Global variables
global_agent = None
@@ -42,11 +42,13 @@ SETTINGS_FILE = Path(".gradio_settings.json")
logging.basicConfig(level=logging.INFO)

import dotenv

if dotenv.load_dotenv():
print(f"DEBUG - Loaded environment variables from {dotenv.find_dotenv()}")
else:
print("DEBUG - No .env file found")


# --- Settings Load/Save Functions ---
def load_settings() -> Dict[str, Any]:
"""Loads settings from the JSON file."""
@@ -84,7 +86,7 @@ def save_settings(settings: Dict[str, Any]):
# async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
# """Add screenshot to chatbot when a screenshot is taken."""
# image_markdown = f""


# if self.chatbot_history is not None:
# self.chatbot_history.append(
# gr.ChatMessage(
@@ -141,7 +143,7 @@ def get_model_string(model_name: str, loop_provider: str) -> str:
ollama_model = model_name.split("OMNI: Ollama ", 1)[1]
return f"omniparser+ollama_chat/{ollama_model}"
return "omniparser+ollama_chat/llama3"


# Map based on loop provider
mapping = MODEL_MAPPINGS.get(loop_provider.lower(), MODEL_MAPPINGS["openai"])
return mapping.get(model_name, mapping["default"])
@@ -151,6 +153,7 @@ def get_ollama_models() -> List[str]:
"""Get available models from Ollama if installed."""
try:
import subprocess

result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
if result.returncode == 0:
lines = result.stdout.strip().split("\n")
@@ -174,16 +177,14 @@ def create_computer_instance(
os_type: str = "macos",
provider_type: str = "lume",
name: Optional[str] = None,
api_key: Optional[str] = None
api_key: Optional[str] = None,
) -> Computer:
"""Create or get the global Computer instance."""
global global_computer
if global_computer is None:
if provider_type == "localhost":
global_computer = Computer(
verbosity=verbosity,
os_type=os_type,
use_host_computer_server=True
verbosity=verbosity, os_type=os_type, use_host_computer_server=True
)
else:
global_computer = Computer(
@@ -191,7 +192,7 @@ def create_computer_instance(
os_type=os_type,
provider_type=provider_type,
name=name if name else "",
api_key=api_key
api_key=api_key,
)
return global_computer

@@ -217,7 +218,7 @@ def create_agent(
os_type=computer_os,
provider_type=computer_provider,
name=computer_name,
api_key=computer_api_key
api_key=computer_api_key,
)

# Handle custom models
@@ -233,12 +234,15 @@ def create_agent(
"only_n_most_recent_images": only_n_most_recent_images,
"verbosity": verbosity,
}


if save_trajectory:
agent_kwargs["trajectory_dir"] = "trajectories"


if max_trajectory_budget:
agent_kwargs["max_trajectory_budget"] = {"max_budget": max_trajectory_budget, "raise_error": True}
agent_kwargs["max_trajectory_budget"] = {
"max_budget": max_trajectory_budget,
"raise_error": True,
}

global_agent = ComputerAgent(**agent_kwargs)
return global_agent
@@ -247,7 +251,8 @@ def create_agent(
def launch_ui():
"""Standalone function to launch the Gradio app."""
from agent.ui.gradio.ui_components import create_gradio_ui
print(f"Starting Gradio app for CUA Agent...")

print("Starting Gradio app for CUA Agent...")
demo = create_gradio_ui()
demo.launch(share=False, inbrowser=True)

@@ -2,19 +2,25 @@
UI Components for the Gradio interface
"""

import os
import asyncio
import logging
import json
import logging
import os
import platform
from pathlib import Path
from typing import Dict, List, Optional, Any, cast
from typing import Any, Dict, List, Optional, cast

import gradio as gr
from gradio.components.chatbot import MetadataDict

from .app import (
load_settings, save_settings, create_agent, get_model_string,
get_ollama_models, global_agent, global_computer
create_agent,
get_model_string,
get_ollama_models,
global_agent,
global_computer,
load_settings,
save_settings,
)

# Global messages array to maintain conversation history
@@ -23,15 +29,15 @@ global_messages = []

def create_gradio_ui() -> gr.Blocks:
"""Create a Gradio UI for the Computer-Use Agent."""


# Load settings
saved_settings = load_settings()


# Check for API keys
openai_api_key = os.environ.get("OPENAI_API_KEY", "")
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY", "")
cua_api_key = os.environ.get("CUA_API_KEY", "")


# Model choices
openai_models = ["OpenAI: Computer-Use Preview"]
anthropic_models = [
@@ -43,10 +49,10 @@ def create_gradio_ui() -> gr.Blocks:
omni_models = [
"OMNI: OpenAI GPT-4o",
"OMNI: OpenAI GPT-4o mini",
"OMNI: Claude 3.7 Sonnet (20250219)",
"OMNI: Claude 3.5 Sonnet (20241022)"
"OMNI: Claude 3.7 Sonnet (20250219)",
"OMNI: Claude 3.5 Sonnet (20241022)",
]


# Check if API keys are available
has_openai_key = bool(openai_api_key)
has_anthropic_key = bool(anthropic_api_key)
@@ -59,15 +65,20 @@ def create_gradio_ui() -> gr.Blocks:

# Detect platform
is_mac = platform.system().lower() == "darwin"


# Format model choices
provider_to_models = {
"OPENAI": openai_models,
"ANTHROPIC": anthropic_models,
"OMNI": omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
"UITARS": ([
"huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
] if is_mac else []) + ["Custom model (OpenAI compatible API)"],
"UITARS": (
[
"huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
]
if is_mac
else []
)
+ ["Custom model (OpenAI compatible API)"],
}

# Apply saved settings
@@ -82,7 +93,9 @@ def create_gradio_ui() -> gr.Blocks:
elif initial_loop == "ANTHROPIC":
initial_model = anthropic_models[0] if anthropic_models else "No models available"
else: # OMNI
initial_model = omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
initial_model = (
omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
)

initial_custom_model = saved_settings.get("custom_model", "Qwen2.5-VL-7B-Instruct")
initial_provider_base_url = saved_settings.get("provider_base_url", "http://localhost:1234/v1")
@@ -96,16 +109,27 @@ def create_gradio_ui() -> gr.Blocks:
"Open Safari, search for 'macOS automation tools', and save the first three results as bookmarks",
"Configure SSH keys and set up a connection to a remote server",
]

def generate_python_code(agent_loop_choice, model_name, tasks, recent_images=3, save_trajectory=True, computer_os="linux", computer_provider="cloud", container_name="", cua_cloud_api_key="", max_budget=None):

def generate_python_code(
agent_loop_choice,
model_name,
tasks,
recent_images=3,
save_trajectory=True,
computer_os="linux",
computer_provider="cloud",
container_name="",
cua_cloud_api_key="",
max_budget=None,
):
"""Generate Python code for the current configuration and tasks."""
tasks_str = ""
for task in tasks:
if task and task.strip():
tasks_str += f' "{task}",\n'


model_string = get_model_string(model_name, agent_loop_choice)


computer_args = []
if computer_os != "macos":
computer_args.append(f'os_type="{computer_os}"')
@@ -115,14 +139,14 @@ def create_gradio_ui() -> gr.Blocks:
computer_args.append(f'name="{container_name}"')
if cua_cloud_api_key:
computer_args.append(f'api_key="{cua_cloud_api_key}"')


computer_args_str = ", ".join(computer_args)
if computer_args_str:
computer_args_str = f"({computer_args_str})"
else:
computer_args_str = "()"

code = f'''import asyncio

code = f"""import asyncio
from computer import Computer
from agent import ComputerAgent

@@ -131,22 +155,22 @@ async def main():
agent = ComputerAgent(
model="{model_string}",
tools=[computer],
only_n_most_recent_images={recent_images},'''

only_n_most_recent_images={recent_images},"""

if save_trajectory:
code += '''
trajectory_dir="trajectories",'''

code += """
trajectory_dir="trajectories","""

if max_budget:
code += f'''
max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},'''

code += '''
code += f"""
max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},"""

code += """
)
'''

"""

if tasks_str:
code += f'''
code += f"""
# Prompts for the computer-use agent
tasks = [
{tasks_str.rstrip()}
@@ -158,23 +182,23 @@ async def main():
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"])'''
print(item["content"][0]["text"])"""
else:
code += f'''
code += """
# Execute a single task
task = "Search for information about CUA on GitHub"
print(f"Executing task: {{task}}")
messages = [{{"role": "user", "content": task}}]
print(f"Executing task: {task}")
messages = [{"role": "user", "content": task}]
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"])'''
print(item["content"][0]["text"])"""

code += '''
code += """

if __name__ == "__main__":
asyncio.run(main())'''

asyncio.run(main())"""

return code

# Create the Gradio interface
@@ -199,11 +223,11 @@ if __name__ == "__main__":
value=generate_python_code(initial_loop, "gpt-4o", []),
interactive=False,
)


with gr.Accordion("Computer Configuration", open=True):
is_windows = platform.system().lower() == "windows"
is_mac = platform.system().lower() == "darwin"


providers = ["cloud", "localhost", "docker"]
if is_mac:
providers += ["lume"]
@@ -227,30 +251,30 @@ if __name__ == "__main__":
value=computer_choices[0],
info="Select the operating system for the computer",
)


computer_provider = gr.Radio(
choices=providers,
label="Provider",
value="lume" if is_mac else "cloud",
info="Select the computer provider",
)


container_name = gr.Textbox(
label="Container Name",
placeholder="Enter container name (optional)",
value=os.environ.get("CUA_CONTAINER_NAME", ""),
info="Optional name for the container",
)


cua_cloud_api_key = gr.Textbox(
label="CUA Cloud API Key",
placeholder="Enter your CUA Cloud API key",
value=os.environ.get("CUA_API_KEY", ""),
type="password",
info="Required for cloud provider",
visible=(not has_cua_key)
visible=(not has_cua_key),
)


with gr.Accordion("Agent Configuration", open=True):
agent_loop = gr.Dropdown(
choices=["OPENAI", "ANTHROPIC", "OMNI", "UITARS"],
@@ -267,90 +291,113 @@ if __name__ == "__main__":
value=openai_models[0] if openai_models else "No models available",
info="Select OpenAI model",
interactive=True,
visible=(initial_loop == "OPENAI")
visible=(initial_loop == "OPENAI"),
)


anthropic_model_choice = gr.Dropdown(
choices=anthropic_models,
label="Anthropic Model",
value=anthropic_models[0] if anthropic_models else "No models available",
value=(
anthropic_models[0] if anthropic_models else "No models available"
),
info="Select Anthropic model",
interactive=True,
visible=(initial_loop == "ANTHROPIC")
visible=(initial_loop == "ANTHROPIC"),
)


omni_model_choice = gr.Dropdown(
choices=omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
choices=omni_models
+ ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
label="OMNI Model",
value=omni_models[0] if omni_models else "Custom model (OpenAI compatible API)",
value=(
omni_models[0]
if omni_models
else "Custom model (OpenAI compatible API)"
),
info="Select OMNI model or choose a custom model option",
interactive=True,
visible=(initial_loop == "OMNI")
visible=(initial_loop == "OMNI"),
)


uitars_model_choice = gr.Dropdown(
choices=provider_to_models.get("UITARS", ["No models available"]),
label="UITARS Model",
value=provider_to_models.get("UITARS", ["No models available"])[0] if provider_to_models.get("UITARS") else "No models available",
value=(
provider_to_models.get("UITARS", ["No models available"])[0]
if provider_to_models.get("UITARS")
else "No models available"
),
info="Select UITARS model",
interactive=True,
visible=(initial_loop == "UITARS")
visible=(initial_loop == "UITARS"),
)


model_choice = gr.Textbox(visible=False)

# API key inputs
with gr.Group(visible=not has_openai_key and (initial_loop == "OPENAI" or initial_loop == "OMNI")) as openai_key_group:
with gr.Group(
visible=not has_openai_key
and (initial_loop == "OPENAI" or initial_loop == "OMNI")
) as openai_key_group:
openai_api_key_input = gr.Textbox(
label="OpenAI API Key",
placeholder="Enter your OpenAI API key",
value=os.environ.get("OPENAI_API_KEY", ""),
interactive=True,
type="password",
info="Required for OpenAI models"
info="Required for OpenAI models",
)

with gr.Group(visible=not has_anthropic_key and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")) as anthropic_key_group:

with gr.Group(
visible=not has_anthropic_key
and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")
) as anthropic_key_group:
anthropic_api_key_input = gr.Textbox(
label="Anthropic API Key",
placeholder="Enter your Anthropic API key",
value=os.environ.get("ANTHROPIC_API_KEY", ""),
interactive=True,
type="password",
info="Required for Anthropic models"
info="Required for Anthropic models",
)

# API key handlers
def set_openai_api_key(key):
if key and key.strip():
os.environ["OPENAI_API_KEY"] = key.strip()
print(f"DEBUG - Set OpenAI API key environment variable")
print("DEBUG - Set OpenAI API key environment variable")
return key


def set_anthropic_api_key(key):
if key and key.strip():
os.environ["ANTHROPIC_API_KEY"] = key.strip()
print(f"DEBUG - Set Anthropic API key environment variable")
print("DEBUG - Set Anthropic API key environment variable")
return key


openai_api_key_input.change(
fn=set_openai_api_key,
inputs=[openai_api_key_input],
outputs=[openai_api_key_input],
queue=False
queue=False,
)


anthropic_api_key_input.change(
fn=set_anthropic_api_key,
inputs=[anthropic_api_key_input],
outputs=[anthropic_api_key_input],
queue=False
queue=False,
)

# UI update function
def update_ui(loop=None, openai_model=None, anthropic_model=None, omni_model=None, uitars_model=None):
def update_ui(
loop=None,
openai_model=None,
anthropic_model=None,
omni_model=None,
uitars_model=None,
):
loop = loop or agent_loop.value


model_value = None
if loop == "OPENAI" and openai_model:
model_value = openai_model
@@ -360,21 +407,37 @@ if __name__ == "__main__":
model_value = omni_model
elif loop == "UITARS" and uitars_model:
model_value = uitars_model

openai_visible = (loop == "OPENAI")
anthropic_visible = (loop == "ANTHROPIC")
omni_visible = (loop == "OMNI")
uitars_visible = (loop == "UITARS")

show_openai_key = not has_openai_key and (loop == "OPENAI" or (loop == "OMNI" and model_value and "OpenAI" in model_value and "Custom" not in model_value))
show_anthropic_key = not has_anthropic_key and (loop == "ANTHROPIC" or (loop == "OMNI" and model_value and "Claude" in model_value and "Custom" not in model_value))


openai_visible = loop == "OPENAI"
anthropic_visible = loop == "ANTHROPIC"
omni_visible = loop == "OMNI"
uitars_visible = loop == "UITARS"

show_openai_key = not has_openai_key and (
loop == "OPENAI"
or (
loop == "OMNI"
and model_value
and "OpenAI" in model_value
and "Custom" not in model_value
)
)
show_anthropic_key = not has_anthropic_key and (
loop == "ANTHROPIC"
or (
loop == "OMNI"
and model_value
and "Claude" in model_value
and "Custom" not in model_value
)
)

is_custom_openai_api = model_value == "Custom model (OpenAI compatible API)"
is_custom_ollama = model_value == "Custom model (ollama)"
is_any_custom = is_custom_openai_api or is_custom_ollama


model_choice_value = model_value if model_value else ""


return [
gr.update(visible=openai_visible),
gr.update(visible=anthropic_visible),
@@ -385,15 +448,18 @@ if __name__ == "__main__":
gr.update(visible=is_any_custom),
gr.update(visible=is_custom_openai_api),
gr.update(visible=is_custom_openai_api),
gr.update(value=model_choice_value)
gr.update(value=model_choice_value),
]


# Custom model inputs
custom_model = gr.Textbox(
label="Custom Model Name",
placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct or llama3)",
value=initial_custom_model,
visible=(initial_model == "Custom model (OpenAI compatible API)" or initial_model == "Custom model (ollama)"),
visible=(
initial_model == "Custom model (OpenAI compatible API)"
or initial_model == "Custom model (ollama)"
),
interactive=True,
)

@@ -413,36 +479,56 @@ if __name__ == "__main__":
interactive=True,
type="password",
)


# Provider visibility update function
def update_provider_visibility(provider):
"""Update visibility of container name and API key based on selected provider."""
is_localhost = provider == "localhost"
return [
gr.update(visible=not is_localhost), # container_name
gr.update(visible=not is_localhost and not has_cua_key) # cua_cloud_api_key
gr.update(
visible=not is_localhost and not has_cua_key
), # cua_cloud_api_key
]


# Connect provider change event
computer_provider.change(
fn=update_provider_visibility,
inputs=[computer_provider],
outputs=[container_name, cua_cloud_api_key],
queue=False
queue=False,
)


# Connect UI update events
for dropdown in [agent_loop, omni_model_choice, uitars_model_choice, openai_model_choice, anthropic_model_choice]:
for dropdown in [
agent_loop,
omni_model_choice,
uitars_model_choice,
openai_model_choice,
anthropic_model_choice,
]:
dropdown.change(
fn=update_ui,
inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
outputs=[
openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
openai_key_group, anthropic_key_group,
custom_model, provider_base_url, provider_api_key,
model_choice
inputs=[
agent_loop,
openai_model_choice,
anthropic_model_choice,
omni_model_choice,
uitars_model_choice,
],
queue=False
outputs=[
openai_model_choice,
anthropic_model_choice,
omni_model_choice,
uitars_model_choice,
openai_key_group,
anthropic_key_group,
custom_model,
provider_base_url,
provider_api_key,
model_choice,
],
queue=False,
)

save_trajectory = gr.Checkbox(
@@ -461,7 +547,7 @@ if __name__ == "__main__":
info="Number of recent images to keep in context",
interactive=True,
)


max_budget = gr.Number(
label="Max Budget ($)",
value=lambda: None,
@@ -479,9 +565,7 @@ if __name__ == "__main__":
)

chatbot_history = gr.Chatbot(type="messages")
msg = gr.Textbox(
placeholder="Ask me to perform tasks in a virtual environment"
)
msg = gr.Textbox(placeholder="Ask me to perform tasks in a virtual environment")
clear = gr.Button("Clear")
cancel_button = gr.Button("Cancel", variant="stop")

@@ -498,11 +582,23 @@ if __name__ == "__main__":
global global_agent
if global_agent:
print("DEBUG - Cancelling agent task")
history.append(gr.ChatMessage(role="assistant", content="Task cancelled by user", metadata={"title": "❌ Cancelled"}))
history.append(
gr.ChatMessage(
role="assistant",
content="Task cancelled by user",
metadata={"title": "❌ Cancelled"},
)
)
else:
history.append(gr.ChatMessage(role="assistant", content="No active agent task to cancel", metadata={"title": "ℹ️ Info"}))
history.append(
gr.ChatMessage(
role="assistant",
content="No active agent task to cancel",
metadata={"title": "ℹ️ Info"},
)
)
return history


# Process response function
async def process_response(
history,
@@ -542,10 +638,13 @@ if __name__ == "__main__":
model_choice_value = uitars_model_value
else:
model_choice_value = "No models available"


# Determine if this is a custom model selection
is_custom_model_selected = model_choice_value in ["Custom model (OpenAI compatible API)", "Custom model (ollama)"]

is_custom_model_selected = model_choice_value in [
"Custom model (OpenAI compatible API)",
"Custom model (ollama)",
]

# Determine the model name string to analyze
if is_custom_model_selected:
model_string_to_analyze = custom_model_value
@@ -583,13 +682,19 @@ if __name__ == "__main__":
model_string=model_string,
save_trajectory=save_traj,
only_n_most_recent_images=recent_imgs,
custom_model_name=custom_model_value if is_custom_model_selected else None,
custom_model_name=(
custom_model_value if is_custom_model_selected else None
),
computer_os=computer_os,
computer_provider=computer_provider,
computer_name=container_name,
computer_api_key=cua_cloud_api_key,
verbosity=logging.DEBUG,
max_trajectory_budget=max_budget_value if max_budget_value and max_budget_value > 0 else None,
max_trajectory_budget=(
max_budget_value
if max_budget_value and max_budget_value > 0
else None
),
)

if global_agent is None:
@@ -605,7 +710,7 @@ if __name__ == "__main__":
# Add user message to global history
global global_messages
global_messages.append({"role": "user", "content": last_user_message})


# Stream responses from the agent
async for result in global_agent.run(global_messages):
global_messages += result.get("output", [])
@@ -613,18 +718,20 @@ if __name__ == "__main__":
# from pprint import pprint
# pprint(result)
# print(f"DEBUG - Agent response ------- END")


# Process the result output
for item in result.get("output", []):
if item.get("type") == "message":
content = item.get("content", [])
for content_part in content:
if content_part.get("text"):
history.append(gr.ChatMessage(
role=item.get("role", "assistant"),
content=content_part.get("text", ""),
metadata=content_part.get("metadata", {})
))
history.append(
gr.ChatMessage(
role=item.get("role", "assistant"),
content=content_part.get("text", ""),
metadata=content_part.get("metadata", {}),
)
)
elif item.get("type") == "computer_call":
action = item.get("action", {})
action_type = action.get("type", "")
@@ -632,43 +739,52 @@ if __name__ == "__main__":
action_title = f"🛠️ Performing {action_type}"
if action.get("x") and action.get("y"):
action_title += f" at ({action['x']}, {action['y']})"
history.append(gr.ChatMessage(
role="assistant",
content=f"```json\n{json.dumps(action)}\n```",
metadata={"title": action_title}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"```json\n{json.dumps(action)}\n```",
metadata={"title": action_title},
)
)
elif item.get("type") == "function_call":
function_name = item.get("name", "")
arguments = item.get("arguments", "{}")
history.append(gr.ChatMessage(
role="assistant",
content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
metadata={"title": f"Function Call: {function_name}"}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
metadata={"title": f"Function Call: {function_name}"},
)
)
elif item.get("type") == "function_call_output":
output = item.get("output", "")
history.append(gr.ChatMessage(
role="assistant",
content=f"📤 Function output:\n```\n{output}\n```",
metadata={"title": "Function Output"}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"📤 Function output:\n```\n{output}\n```",
metadata={"title": "Function Output"},
)
)
elif item.get("type") == "computer_call_output":
output = item.get("output", {}).get("image_url", "")
image_markdown = f""
history.append(gr.ChatMessage(
role="assistant",
content=image_markdown,
metadata={"title": "🖥️ Computer Output"}
))

history.append(
gr.ChatMessage(
role="assistant",
content=image_markdown,
metadata={"title": "🖥️ Computer Output"},
)
)

yield history


except Exception as e:
import traceback

traceback.print_exc()
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
yield history


# Connect the submit button
submit_event = msg.submit(
fn=chat_submit,
@@ -706,44 +822,77 @@ if __name__ == "__main__":
global global_messages
global_messages.clear()
return None


clear.click(clear_chat, None, chatbot_history, queue=False)


# Connect cancel button
cancel_button.click(
cancel_agent_task,
[chatbot_history],
[chatbot_history],
queue=False
cancel_agent_task, [chatbot_history], [chatbot_history], queue=False
)

# Code display update function
def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, recent_images_val, save_trajectory_val, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget_val):
def update_code_display(
agent_loop,
model_choice_val,
custom_model_val,
chat_history,
recent_images_val,
save_trajectory_val,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget_val,
):
messages = []
if chat_history:
for msg in chat_history:
if isinstance(msg, dict) and msg.get("role") == "user":
messages.append(msg.get("content", ""))


return generate_python_code(
agent_loop,
model_choice_val or custom_model_val or "gpt-4o",
messages,
agent_loop,
model_choice_val or custom_model_val or "gpt-4o",
messages,
recent_images_val,
save_trajectory_val,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget_val
max_budget_val,
)


# Update code display when configuration changes
for component in [agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget]:
for component in [
agent_loop,
model_choice,
custom_model,
chatbot_history,
recent_images,
save_trajectory,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget,
]:
component.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget],
outputs=[code_display]
inputs=[
agent_loop,
model_choice,
custom_model,
chatbot_history,
recent_images,
save_trajectory,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget,
],
outputs=[code_display],
)

return demo

@@ -5,26 +5,30 @@ This directory contains benchmarks designed to test agent providers in the Compu
## Overview

The benchmark system evaluates models on GUI grounding tasks, specifically click prediction accuracy. It supports both:

- **Computer Agent SDK providers** (using model strings like `"huggingface-local/HelloKKMe/GTA1-7B"`)
- **Reference agent implementations** (custom model classes implementing the `ModelProtocol`)

## Available Benchmarks

### 1. ScreenSpot-v2 (`ss-v2.py`)

- **Dataset**: ScreenSpot-v2 (click-only GUI grounding)
- **Format**: Standard resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage

### 2. ScreenSpot-Pro (`ss-pro.py`)

- **Dataset**: ScreenSpot-Pro (high-resolution click-only GUI grounding)
- **Format**: High-resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage

### 3. Interactive Testing (`interactive.py`)

- **Real-time testing**: Take screenshots and visualize model predictions
- **Commands**:
- Type instruction → test all models on last screenshot
- `screenshot` → take screenshot
- `models` → list available models
@@ -34,14 +38,16 @@ The benchmark system evaluates models on GUI grounding tasks, specifically click
## Running Benchmarks

### 1. Configure Models

Edit `utils.py` to specify which models you want to test in `get_available_models()`.
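A minimal sketch of what that might look like, mixing both kinds of entries (import paths are assumed from the layout shown in [contrib.md](contrib.md); substitute your own models):

```python
from typing import List, Union

from models import GTA1Model
from models.base import ModelProtocol  # assumed location of the protocol


def get_available_models() -> List[Union[str, ModelProtocol]]:
    return [
        # Computer Agent SDK provider, addressed by model string
        "huggingface-local/HelloKKMe/GTA1-7B",
        # Reference implementation, addressed by model class
        GTA1Model("HelloKKMe/GTA1-7B"),
    ]
```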

### 2. Run Benchmark

```bash
# ScreenSpot-v2 benchmark
python ss-v2.py --samples 50

# ScreenSpot-Pro benchmark
python ss-pro.py --samples 50

# Interactive testing
@@ -51,6 +57,7 @@ python interactive.py
## Output

### Console Output

```
Model Results:
Accuracy: 85.50% (171/200)
@@ -59,10 +66,11 @@ Model Results:
```

### Generated Files

- **Markdown Report**: `*_results.md` with detailed results tables
- **Visualizations**: `output/` directory with prediction visualizations
- **Interactive Output**: `interactive_output/` for interactive session results

## Contributing

To add a new reference model, follow the instructions in [contrib.md](contrib.md).

@@ -17,29 +17,29 @@ class YourModelName(ModelProtocol):
def __init__(self, model_path: str):
self.model_path = model_path
self._model = None


@property
def model_name(self) -> str:
return self.model_path


async def load_model(self) -> None:
"""Load the model into memory."""
# Your model loading logic here
pass


async def unload_model(self) -> None:
"""Unload the model from memory."""
# Your model cleanup logic here
pass


async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.


Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click


Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -56,7 +56,7 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
models = [
# Computer Agent SDK providers
"huggingface-local/HelloKKMe/GTA1-7B",


# Reference implementations
GTA1Model("HelloKKMe/GTA1-7B"),
YourModelName("path/to/your/model"), # Add your model here
@@ -79,6 +79,7 @@ This will help you verify that your model loads correctly and produces reasonabl
Here's a complete example of adding a hypothetical "MyVisionModel":

1. **Create `models/my_vision_model.py`:**

```python
import torch
from transformers import AutoModel, AutoProcessor
@@ -91,11 +92,11 @@ class MyVisionModel(ModelProtocol):
self.model_path = model_path
self.model = None
self.processor = None


@property
def model_name(self) -> str:
return f"MyVisionModel({self.model_path})"


async def load_model(self) -> None:
"""Load the model and processor."""
self.processor = AutoProcessor.from_pretrained(self.model_path)
@@ -104,7 +105,7 @@ class MyVisionModel(ModelProtocol):
torch_dtype=torch.float16,
device_map="auto"
)


async def unload_model(self) -> None:
"""Clean up model resources."""
del self.model
@@ -112,7 +113,7 @@ class MyVisionModel(ModelProtocol):
self.model = None
self.processor = None
torch.cuda.empty_cache()


async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""Predict click coordinates."""
try:
@@ -122,19 +123,19 @@ class MyVisionModel(ModelProtocol):
images=image,
return_tensors="pt"
)


# Run inference
with torch.no_grad():
outputs = self.model(**inputs)


# Extract coordinates (model-specific logic)
x, y = self._extract_coordinates(outputs)
return (int(x), int(y))


except Exception as e:
print(f"Prediction failed: {e}")
return None


def _extract_coordinates(self, outputs):
"""Extract x, y coordinates from model outputs."""
# Your model-specific coordinate extraction logic
@@ -142,6 +143,7 @@ class MyVisionModel(ModelProtocol):
```
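If your model emits its answer as literal `(x,y)` text, as the GTA1 reference model does, the extraction step can be a small regex. A sketch (mirroring the `_extract_coordinates` logic in `models/gta1.py`; adapt it to your model's output format):

```python
import re
from typing import Tuple


def extract_xy(raw: str) -> Tuple[int, int]:
    """Parse the first '(x,y)' pair out of the model's text output."""
    matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw)
    if not matches:
        return (0, 0)  # fallback, mirroring the GTA1 reference behavior
    x, y = matches[0]
    return (int(float(x)), int(float(y)))
```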

2. **Update `models/__init__.py`:**

```python
from .gta1 import GTA1Model
from .my_vision_model import MyVisionModel
@@ -150,6 +152,7 @@ __all__ = ["GTA1Model", "MyVisionModel"]
```

3. **Update `utils.py`:**

```python
from models import GTA1Model, MyVisionModel


@@ -9,60 +9,56 @@ Models are loaded/unloaded one at a time to avoid memory issues.
import asyncio
import os
from datetime import datetime
from typing import List, Dict, Any
from typing import Any, Dict, List

from utils import (
ModelWrapper,
take_screenshot,
get_available_models,
save_prediction_visualization,
get_available_models
take_screenshot,
)


async def predict_with_all_models(image, instruction: str, models) -> List[Dict[str, Any]]:
"""
Predict click coordinates with all models sequentially.


Args:
image: PIL Image to analyze
instruction: Instruction text
models: List of model instances


Returns:
List of prediction results
"""
predictions = []


for model in models:
model_wrapper = ModelWrapper(model)
print(f"\n🔄 Loading {model_wrapper.model_name}...")


try:
# Load model
await model_wrapper.load_model()


# Predict
coords = await model_wrapper.predict_click(image, instruction)

predictions.append({
'model_name': model_wrapper.model_name,
'coords': coords,
'error': None
})


predictions.append(
{"model_name": model_wrapper.model_name, "coords": coords, "error": None}
)

if coords:
print(f"✅ {model_wrapper.model_name}: ({coords[0]}, {coords[1]})")
else:
print(f"❌ {model_wrapper.model_name}: No prediction")


except Exception as e:
print(f"❌ {model_wrapper.model_name}: ERROR - {str(e)}")
predictions.append({
'model_name': model_wrapper.model_name,
'coords': None,
'error': str(e)
})

predictions.append(
{"model_name": model_wrapper.model_name, "coords": None, "error": str(e)}
)

finally:
# Always unload model to free memory
try:
@@ -70,7 +66,7 @@ async def predict_with_all_models(image, instruction: str, models) -> List[Dict[
print(f"🗑️ Unloaded {model_wrapper.model_name}")
except Exception as e:
print(f"⚠️ Error unloading {model_wrapper.model_name}: {e}")


return predictions

@@ -103,87 +99,91 @@ async def main():
Main interactive loop.
"""
print_header()


# Get available models
models = get_available_models()
print_models(models)


# Create output directory for visualizations
output_dir = "interactive_output"
os.makedirs(output_dir, exist_ok=True)


session_count = 0
last_screenshot = None
screenshot_timestamp = None


while True:
try:
# Get user input
print(f"\n{'='*40}")
user_input = input("🎯 Enter instruction (or command): ").strip()


if not user_input:
continue


# Handle commands
if user_input.lower() in ['quit', 'exit', 'q']:
if user_input.lower() in ["quit", "exit", "q"]:
print("👋 Goodbye!")
break

elif user_input.lower() == 'models':

elif user_input.lower() == "models":
print_models(models)
continue

elif user_input.lower() == 'screenshot':

elif user_input.lower() == "screenshot":
print("📸 Taking screenshot...")
try:
last_screenshot = take_screenshot()
screenshot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
screenshot_path = os.path.join(output_dir, f"screenshot_{screenshot_timestamp}.png")
screenshot_path = os.path.join(
output_dir, f"screenshot_{screenshot_timestamp}.png"
)
last_screenshot.save(screenshot_path)
print(f"✅ Screenshot captured and saved to: {screenshot_path}")
print(f"📝 Ready for instructions! Screenshot size: {last_screenshot.size}")
except Exception as e:
print(f"❌ Error taking screenshot: {e}")
continue


# Handle instruction input
if last_screenshot is None:
print("⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command.")
print(
"⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command."
)
continue


session_count += 1
print(f"\n🎯 Session {session_count}: '{user_input}'")
print(f"📷 Using screenshot from: {screenshot_timestamp}")


# Predict with all models using last screenshot
print(f"\n🤖 Testing {len(models)} models on screenshot...")
predictions = await predict_with_all_models(last_screenshot, user_input, models)


# Display results summary
print(f"\n📊 Results Summary:")
print("\n📊 Results Summary:")
print("-" * 50)
for pred in predictions:
if pred['coords']:
if pred["coords"]:
print(f"✅ {pred['model_name']}: ({pred['coords'][0]}, {pred['coords'][1]})")
elif pred['error']:
elif pred["error"]:
print(f"❌ {pred['model_name']}: ERROR - {pred['error']}")
else:
print(f"❌ {pred['model_name']}: No prediction")


# Save visualization
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
vis_filename = f"session_{session_count:03d}_{timestamp}.png"
vis_path = os.path.join(output_dir, vis_filename)


try:
save_prediction_visualization(last_screenshot, user_input, predictions, vis_path)
print(f"\n💾 Visualization saved to: {vis_path}")
except Exception as e:
print(f"⚠️ Error saving visualization: {e}")


print(f"\n✨ Session {session_count} completed!")


except KeyboardInterrupt:
print("\n\n👋 Interrupted by user. Goodbye!")
break

@@ -2,34 +2,37 @@
Base protocol for benchmark models.
"""

from typing import Protocol, Optional, Tuple
from typing import Optional, Protocol, Tuple

from PIL import Image


class ModelProtocol(Protocol):
"""Protocol for benchmark models that can predict click coordinates."""


@property
def model_name(self) -> str:
"""Return the name of the model."""
...


async def load_model(self) -> None:
"""Load the model into memory."""
...


async def unload_model(self) -> None:
"""Unload the model from memory."""
...

async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:

async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.


Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click


Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""

@@ -2,54 +2,51 @@
GTA1 model implementation for benchmarking.
"""

from typing import Optional, Tuple
from PIL import Image
import torch
import re
import gc
import re
from typing import Optional, Tuple

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from .base import ModelProtocol

class GTA1Model:
"""Ground truth GTA1 model implementation."""

def __init__(self, model_path: str = "HelloKKMe/GTA1-7B"):
self.model_path = model_path
self.model = None
self.processor = None
self.max_new_tokens = 32

self.system_prompt = '''

self.system_prompt = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.

Output the coordinate pair exactly:
(x,y)
'''.strip()

""".strip()

@property
def model_name(self) -> str:
"""Return the name of the model."""
return f"GTA1-{self.model_path.split('/')[-1]}"

async def load_model(self) -> None:
"""Load the model into memory."""
if self.model is None:
print(f"Loading GTA1 model: {self.model_path}")
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map="auto"
self.model_path, torch_dtype=torch.bfloat16, device_map="auto"
)
self.processor = AutoProcessor.from_pretrained(
self.model_path,
min_pixels=3136,
max_pixels=4096 * 2160
self.model_path, min_pixels=3136, max_pixels=4096 * 2160
)
print("GTA1 model loaded successfully")

async def unload_model(self) -> None:
"""Unload the model from memory."""
if self.model is not None:
@@ -62,23 +59,25 @@ Output the coordinate pair exactly:
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("GTA1 model unloaded")

def _extract_coordinates(self, raw_string: str) -> Tuple[int, int]:
"""Extract coordinates from model output."""
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return tuple(map(int, map(float, matches[0]))) # type: ignore
return tuple(map(int, map(float, matches[0])))  # type: ignore
except:
return (0, 0)

async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:

async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.

Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click

Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -87,76 +86,73 @@ Output the coordinate pair exactly:

assert self.processor is not None
assert self.model is not None

try:
width, height = image.width, image.height

# Resize image according to processor requirements
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=self.processor.image_processor.patch_size * self.processor.image_processor.merge_size,
factor=self.processor.image_processor.patch_size
* self.processor.image_processor.merge_size,
min_pixels=self.processor.image_processor.min_pixels,
max_pixels=self.processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height

# Prepare messages
system_message = {
"role": "system",
"content": self.system_prompt.format(height=resized_height, width=resized_width)
"content": self.system_prompt.format(height=resized_height, width=resized_width),
}

user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
{"type": "text", "text": instruction},
],
}

# Process inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message]) # type: ignore
image_inputs, video_inputs = process_vision_info([system_message, user_message])  # type: ignore
text = self.processor.apply_chat_template(
[system_message, user_message],
tokenize=False,
add_generation_prompt=True
[system_message, user_message], tokenize=False, add_generation_prompt=True
)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt"
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)

# Generate prediction
output_ids = self.model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=False,
temperature=1.0,
use_cache=True
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=False,
temperature=1.0,
use_cache=True,
)
generated_ids = [
output_ids[len(input_ids):]
output_ids[len(input_ids) :]
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
]
output_text = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)[0]

# Extract and rescale coordinates
pred_x, pred_y = self._extract_coordinates(output_text)
pred_x = int(pred_x * scale_x)
pred_y = int(pred_y * scale_y)

return (pred_x, pred_y)

except Exception as e:
print(f"Error in GTA1 prediction: {e}")
return None

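A quick illustration of the extract-and-rescale step in `predict_click` above (the raw output string and resolutions are made-up values, not from the commit):

```python
import re

raw = "(512, 384)"  # hypothetical model output in resized-image coordinates
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw)
pred_x, pred_y = tuple(map(int, map(float, matches[0])))  # (512, 384)

# Map back to the original resolution, e.g. a 1920x1080 screenshot
# that smart_resize shrank to 960x540.
scale_x, scale_y = 1920 / 960, 1080 / 540
print((int(pred_x * scale_x), int(pred_y * scale_y)))  # (1024, 768)
```

This mirrors why the code records `scale_x`/`scale_y` before inference: the model sees the resized image, but clicks must land in original-screen coordinates.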
@@ -15,103 +15,106 @@ from typing import Optional

from datasets import load_dataset
from tqdm import tqdm

from utils import (
ModelWrapper,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
ModelWrapper,
get_available_models,
get_gpu_memory
get_gpu_memory,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
)

async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None) -> dict:
async def evaluate_model(
model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None
) -> dict:
"""
Evaluate a model on the ScreenSpot-Pro dataset.

Args:
model_wrapper: ModelWrapper instance
dataset: ScreenSpot-Pro dataset (list of samples)
max_samples: Maximum number of samples to evaluate (None for all)

Returns:
Dictionary with evaluation results
"""
print(f"\nEvaluating model: {model_wrapper.model_name}")

# Load model
await model_wrapper.load_model()

total_samples = len(dataset)
if max_samples is not None:
total_samples = min(max_samples, total_samples)

correct_predictions = 0
error_predictions = 0
results = []

for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
sample = dataset[i]

# Extract sample data
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox']  # [x1, y1, x2, y2]
sample_id = sample['img_filename']

image = sample["image"]
instruction = sample["instruction"]
bbox = sample["bbox"]  # [x1, y1, x2, y2]
sample_id = sample["img_filename"]

# Predict click coordinates with timing
start_time = time.time()
click_coords = await model_wrapper.predict_click(image, instruction)
prediction_time = time.time() - start_time

# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)

if is_correct:
correct_predictions += 1

results.append({
'id': sample_id,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False,
'prediction_time': prediction_time
})

results.append(
{
"id": sample_id,
"instruction": instruction,
"bbox": bbox,
"predicted_coords": click_coords,
"is_correct": is_correct,
"failed": False,
"prediction_time": prediction_time,
}
)

# Unload model
await model_wrapper.unload_model()

# Calculate metrics
accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
error_rate = error_predictions / total_samples if total_samples > 0 else 0.0

# Calculate timing statistics
successful_times = [r['prediction_time'] for r in results if not r['failed']]
successful_times = [r["prediction_time"] for r in results if not r["failed"]]
avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
median_prediction_time = statistics.median(successful_times) if successful_times else 0.0
min_prediction_time = min(successful_times) if successful_times else 0.0
max_prediction_time = max(successful_times) if successful_times else 0.0

# Get VRAM statistics
vram_stats = model_wrapper.get_vram_stats()

return {
'model_name': model_wrapper.model_name,
'total_samples': total_samples,
'correct_predictions': correct_predictions,
'failed_predictions': error_predictions,
'accuracy': accuracy,
'failure_rate': error_rate,
'avg_prediction_time': avg_prediction_time,
'median_prediction_time': median_prediction_time,
'min_prediction_time': min_prediction_time,
'max_prediction_time': max_prediction_time,
'vram_max_mb': vram_stats['max_mb'],
'vram_avg_mb': vram_stats['avg_mb'],
'results': results
"model_name": model_wrapper.model_name,
"total_samples": total_samples,
"correct_predictions": correct_predictions,
"failed_predictions": error_predictions,
"accuracy": accuracy,
"failure_rate": error_rate,
"avg_prediction_time": avg_prediction_time,
"median_prediction_time": median_prediction_time,
"min_prediction_time": min_prediction_time,
"max_prediction_time": max_prediction_time,
"vram_max_mb": vram_stats["max_mb"],
"vram_avg_mb": vram_stats["avg_mb"],
"results": results,
}

@@ -120,42 +123,44 @@ async def main():
Main function to run the benchmark.
"""
# Parse command line arguments
parser = argparse.ArgumentParser(description='ScreenSpot-Pro Benchmark Script')
parser.add_argument('--samples', type=int, default=300,
help='Number of samples to evaluate (default: 300)')
parser.add_argument('--seed', type=int, default=42,
help='Random seed for shuffling (default: 42)')
parser = argparse.ArgumentParser(description="ScreenSpot-Pro Benchmark Script")
parser.add_argument(
"--samples", type=int, default=300, help="Number of samples to evaluate (default: 300)"
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
)
args = parser.parse_args()

# Set random seed
random.seed(args.seed)

# Load dataset
print("Loading ScreenSpot-Pro dataset...")
ds = load_dataset("lmms-lab/ScreenSpot-Pro")
dataset = ds['train']  # type: ignore
dataset = ds["train"]  # type: ignore
# Convert to list to support indexing
dataset_list = list(dataset)
print(f"Dataset loaded: {len(dataset_list)} samples")

# Shuffle dataset with seed
random.shuffle(dataset_list)
print(f"Dataset shuffled with seed {args.seed}")

# Get available models
models = get_available_models()

# Evaluation settings
max_samples = args.samples  # Use command line argument

# Run evaluations
all_results = []

for model in models:
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, dataset_list, max_samples)
all_results.append(result)

# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
@@ -164,15 +169,17 @@ async def main():
print(f" Error Rate: {result['failure_rate']*100:.2f}%")
print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
print(f" Median Time: {result['median_prediction_time']:.2f}s")
print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
print(
f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s"
)
print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")

# Print GPU memory info
gpu_memory = get_gpu_memory()
if gpu_memory and gpu_memory[0] > 0:
print(f" GPU Free Memory: {gpu_memory[0]:.1f}MB")

# Save results
if all_results:
save_results_to_markdown(all_results)
@@ -183,4 +190,4 @@ async def main():

if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

@@ -15,36 +15,37 @@ from typing import Optional

from datasets import load_dataset
from tqdm import tqdm

from utils import (
ModelWrapper,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
ModelWrapper,
get_available_models,
get_gpu_memory
get_gpu_memory,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
)

async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None) -> dict:
async def evaluate_model(
model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None
) -> dict:
"""
Evaluate a model on any iterable of samples.

Args:
model_wrapper: ModelWrapper instance
samples: Iterable of dicts with keys: image, bbox, instruction
max_samples: Maximum number of samples to evaluate (None for all)

Returns:
Dictionary with evaluation results
"""
print(f"\nEvaluating model: {model_wrapper.model_name}")

# Load model
await model_wrapper.load_model()

# Convert to list if needed and limit samples
if hasattr(samples, '__len__'):
if hasattr(samples, "__len__"):
total_samples = len(samples)
if max_samples is not None:
total_samples = min(max_samples, total_samples)
@@ -55,69 +56,71 @@ async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Opti
if max_samples is not None:
sample_list = sample_list[:max_samples]
total_samples = len(sample_list)

correct_predictions = 0
error_predictions = 0
results = []

for i, sample in enumerate(tqdm(sample_list, desc=f"Evaluating {model_wrapper.model_name}")):
# Extract required data (only these 3 keys matter)
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox']  # [x1, y1, x2, y2]

image = sample["image"]
instruction = sample["instruction"]
bbox = sample["bbox"]  # [x1, y1, x2, y2]

# Predict click coordinates with timing
start_time = time.time()
click_coords = await model_wrapper.predict_click(image, instruction)
prediction_time = time.time() - start_time

# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)

if is_correct:
correct_predictions += 1

results.append({
'sample_idx': i,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False,
'prediction_time': prediction_time
})

results.append(
{
"sample_idx": i,
"instruction": instruction,
"bbox": bbox,
"predicted_coords": click_coords,
"is_correct": is_correct,
"failed": False,
"prediction_time": prediction_time,
}
)

# Unload model
await model_wrapper.unload_model()

# Calculate metrics
accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
error_rate = error_predictions / total_samples if total_samples > 0 else 0.0

# Calculate timing statistics
successful_times = [r['prediction_time'] for r in results if not r['failed']]
successful_times = [r["prediction_time"] for r in results if not r["failed"]]
avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
median_prediction_time = statistics.median(successful_times) if successful_times else 0.0
min_prediction_time = min(successful_times) if successful_times else 0.0
max_prediction_time = max(successful_times) if successful_times else 0.0

# Get VRAM statistics
vram_stats = model_wrapper.get_vram_stats()

return {
'model_name': model_wrapper.model_name,
'total_samples': total_samples,
'correct_predictions': correct_predictions,
'failed_predictions': error_predictions,
'accuracy': accuracy,
'failure_rate': error_rate,
'avg_prediction_time': avg_prediction_time,
'median_prediction_time': median_prediction_time,
'min_prediction_time': min_prediction_time,
'max_prediction_time': max_prediction_time,
'vram_max_mb': vram_stats['max_mb'],
'vram_avg_mb': vram_stats['avg_mb'],
'results': results
"model_name": model_wrapper.model_name,
"total_samples": total_samples,
"correct_predictions": correct_predictions,
"failed_predictions": error_predictions,
"accuracy": accuracy,
"failure_rate": error_rate,
"avg_prediction_time": avg_prediction_time,
"median_prediction_time": median_prediction_time,
"min_prediction_time": min_prediction_time,
"max_prediction_time": max_prediction_time,
"vram_max_mb": vram_stats["max_mb"],
"vram_avg_mb": vram_stats["avg_mb"],
"results": results,
}

@@ -126,56 +129,60 @@ async def main():
Main function to run the benchmark.
"""
# Parse command line arguments
parser = argparse.ArgumentParser(description='ScreenSpot-v2 Benchmark Script')
parser.add_argument('--samples', type=int, default=500,
help='Number of samples to evaluate (default: 500)')
parser.add_argument('--seed', type=int, default=42,
help='Random seed for shuffling (default: 42)')
parser = argparse.ArgumentParser(description="ScreenSpot-v2 Benchmark Script")
parser.add_argument(
"--samples", type=int, default=500, help="Number of samples to evaluate (default: 500)"
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
)
args = parser.parse_args()

# Set random seed
random.seed(args.seed)

# Load dataset
print("Loading ScreenSpot-v2 dataset...")
ds = load_dataset("lmms-lab/ScreenSpot-v2")
dataset = ds['train']  # type: ignore
dataset = ds["train"]  # type: ignore
# Convert to simple list of dicts with only required keys
samples = []
for item in dataset:
# Convert dataset item to dict if needed
item_dict = dict(item) if hasattr(item, 'keys') else item

item_dict = dict(item) if hasattr(item, "keys") else item

# Convert ScreenSpot-v2 bbox format [x, y, w, h] to [x1, y1, x2, y2]
bbox_xywh = item_dict['bbox']  # type: ignore
bbox_xywh = item_dict["bbox"]  # type: ignore
x, y, w, h = bbox_xywh
bbox_xyxy = [x, y, x + w, y + h]

samples.append({
'image': item_dict['image'],  # type: ignore
'instruction': item_dict['instruction'],  # type: ignore
'bbox': bbox_xyxy
})

samples.append(
{
"image": item_dict["image"],  # type: ignore
"instruction": item_dict["instruction"],  # type: ignore
"bbox": bbox_xyxy,
}
)
print(f"Dataset loaded: {len(samples)} samples")

# Shuffle samples with seed
random.shuffle(samples)
print(f"Samples shuffled with seed {args.seed}")

# Get available models
models = get_available_models()

# Evaluation settings
max_samples = args.samples  # Use command line argument

# Run evaluations
all_results = []

for model in models:
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, samples, max_samples)
all_results.append(result)

# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
@@ -184,18 +191,22 @@ async def main():
print(f" Error Rate: {result['failure_rate']*100:.2f}%")
print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
print(f" Median Time: {result['median_prediction_time']:.2f}s")
print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
print(
f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s"
)
print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")

# Print GPU memory info
gpu_memory = get_gpu_memory()
if gpu_memory and gpu_memory[0] > 0:
print(f" GPU Free Memory: {gpu_memory[0]:.1f}MB")

# Save results
if all_results:
save_results_to_markdown(all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results")
save_results_to_markdown(
all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results"
)
save_visualizations(all_results, samples)
print("\nBenchmark completed successfully!")
else:
@@ -203,4 +214,4 @@ async def main():

if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

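For clarity, the xywh-to-xyxy conversion done in the loop above works as follows (the numbers are made up for illustration):

```python
x, y, w, h = 100, 50, 30, 20       # ScreenSpot-v2 style [x, y, w, h]
bbox_xyxy = [x, y, x + w, y + h]   # -> [100, 50, 130, 70]

# A predicted click at (115, 60) lands inside this box:
cx, cy = 115, 60
hit = bbox_xyxy[0] <= cx <= bbox_xyxy[2] and bbox_xyxy[1] <= cy <= bbox_xyxy[3]
print(hit)  # True
```

This is the same containment test that `is_click_in_bbox` in utils.py applies when scoring predictions.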
@@ -4,38 +4,40 @@ Shared utilities for ScreenSpot-Pro benchmarking and interactive testing.
"""

import dotenv

dotenv.load_dotenv()

import asyncio
import base64
import gc
import os
import sys
import subprocess as sp
import statistics
import subprocess as sp
import sys
from datetime import datetime
from io import BytesIO
from typing import List, Union, Tuple, Optional
from typing import List, Optional, Tuple, Union

import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
import gc
import torch

# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agent.agent import ComputerAgent
from models.base import ModelProtocol

def get_gpu_memory() -> List[int]:
"""
Get GPU memory usage using nvidia-smi.

Returns:
List of free memory values in MB for each GPU
"""
try:
command = "nvidia-smi --query-gpu=memory.free --format=csv"
memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
memory_free_info = sp.check_output(command.split()).decode("ascii").split("\n")[:-1][1:]
memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
return memory_free_values
except (sp.CalledProcessError, FileNotFoundError, IndexError):
@@ -51,39 +53,34 @@ def get_gpu_memory() -> List[int]:
def get_vram_usage() -> dict:
"""
Get current VRAM usage statistics.

Returns:
Dictionary with VRAM usage info (in MB)
"""
if torch.cuda.is_available():
device = torch.cuda.current_device()
allocated = torch.cuda.memory_allocated(device) / 1024 / 1024  # Convert to MB
reserved = torch.cuda.memory_reserved(device) / 1024 / 1024  # Convert to MB
reserved = torch.cuda.memory_reserved(device) / 1024 / 1024  # Convert to MB
total = torch.cuda.get_device_properties(device).total_memory / 1024 / 1024
return {
'allocated_mb': allocated,
'reserved_mb': reserved,
'total_mb': total,
'free_mb': total - reserved
"allocated_mb": allocated,
"reserved_mb": reserved,
"total_mb": total,
"free_mb": total - reserved,
}
else:
return {
'allocated_mb': 0.0,
'reserved_mb': 0.0,
'total_mb': 0.0,
'free_mb': 0.0
}
return {"allocated_mb": 0.0, "reserved_mb": 0.0, "total_mb": 0.0, "free_mb": 0.0}

def get_available_models() -> List[Union[str, ModelProtocol]]:
"""
Get list of available models for testing.

Returns:
List of model strings and model classes
"""
local_provider = "huggingface-local/"  # Options: huggingface-local/ or mlx/

# from models.gta1 import GTA1Model

models = [
@@ -94,42 +91,41 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
# f"{local_provider}HelloKKMe/GTA1-32B",
"openai/computer-use-preview+openai/gpt-4o-mini",
"anthropic/claude-opus-4-20250514+openai/gpt-4o-mini",

# === Reference model classes ===
# GTA1Model("HelloKKMe/GTA1-7B"),
# GTA1Model("HelloKKMe/GTA1-32B"),
# GTA1Model("HelloKKMe/GTA1-32B"),
]

return models

def is_click_in_bbox(click_coords: Optional[Tuple[int, int]], bbox: List[int]) -> bool:
"""
Check if click coordinates are within the bounding box.

Args:
click_coords: (x, y) coordinates or None
bbox: [x1, y1, x2, y2] bounding box

Returns:
True if click is within bbox, False otherwise
"""
if click_coords is None:
return False

x, y = click_coords
x1, y1, x2, y2 = bbox

return x1 <= x <= x2 and y1 <= y <= y2

def image_to_base64(image: Image.Image) -> str:
"""
Convert PIL Image to base64 string.

Args:
image: PIL Image

Returns:
Base64 encoded image string
"""
@@ -142,213 +138,252 @@ class ModelWrapper:
"""
Wrapper to provide unified interface for both ComputerAgent and custom models.
"""

def __init__(self, model: Union[str, ModelProtocol]):
self.model = model
self.is_computer_agent = isinstance(model, str)
self.agent: Optional[ComputerAgent] = None
self.vram_usage_history: List[float] = []  # Track VRAM usage over time

if self.is_computer_agent:
self.model_name = str(model)
else:
self.model_name = f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"

self.model_name = (
f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"
)

async def load_model(self) -> None:
"""Load the model."""
if self.is_computer_agent:
self.agent = ComputerAgent(model=str(self.model))
else:
await self.model.load_model() # type: ignore

await self.model.load_model()  # type: ignore

# Record initial VRAM usage after loading
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])

self.vram_usage_history.append(vram_info["allocated_mb"])

async def unload_model(self) -> None:
"""Unload the model."""
if not self.is_computer_agent:
await self.model.unload_model() # type: ignore
await self.model.unload_model()  # type: ignore
else:
del self.agent
self.agent = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Record VRAM usage after unloading
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])

self.vram_usage_history.append(vram_info["allocated_mb"])

def get_vram_stats(self) -> dict:
"""Get VRAM usage statistics for this model."""
if not self.vram_usage_history:
return {'max_mb': 0.0, 'avg_mb': 0.0}

return {"max_mb": 0.0, "avg_mb": 0.0}

return {
'max_mb': max(self.vram_usage_history),
'avg_mb': sum(self.vram_usage_history) / len(self.vram_usage_history)
"max_mb": max(self.vram_usage_history),
"avg_mb": sum(self.vram_usage_history) / len(self.vram_usage_history),
}

async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""Predict click coordinates."""
# Record VRAM usage before prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])

self.vram_usage_history.append(vram_info["allocated_mb"])

if self.is_computer_agent:
if self.agent is None:
await self.load_model()

if self.agent is not None:
image_b64 = image_to_base64(image)
result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64)

result = await self.agent.predict_click(
instruction=instruction, image_b64=image_b64
)

# Record VRAM usage after prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])

self.vram_usage_history.append(vram_info["allocated_mb"])

return result
return None
else:
result = await self.model.predict_click(image, instruction) # type: ignore

result = await self.model.predict_click(image, instruction)  # type: ignore

# Record VRAM usage after prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])

self.vram_usage_history.append(vram_info["allocated_mb"])

return result

def save_results_to_markdown(all_results: List[dict],output_file: str = "screenspot_pro_results.md", title: str = "ScreenSpot-Pro Benchmark Results") -> None:
def save_results_to_markdown(
all_results: List[dict],
output_file: str = "screenspot_pro_results.md",
title: str = "ScreenSpot-Pro Benchmark Results",
) -> None:
"""
Save evaluation results to a markdown table.

Args:
all_results: List of evaluation results for each model
output_file: Output markdown file path
"""
with open(output_file, 'w', encoding='utf-8') as f:
with open(output_file, "w", encoding="utf-8") as f:
f.write(f"# {title}\n\n")
f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

# Summary table
f.write("## Summary\n\n")
f.write("| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Median Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n")
f.write("|-------|---------------|---------|--------|----------|------------|--------------|-----------------|----------------|---------------|---------------|\n")

f.write(
"| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Median Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n"
)
f.write(
"|-------|---------------|---------|--------|----------|------------|--------------|-----------------|----------------|---------------|---------------|\n"
)

for result in all_results:
model_name = result['model_name']
total = result['total_samples']
correct = result['correct_predictions']
errors = result['failed_predictions']
accuracy = result['accuracy'] * 100
error_rate = result['failure_rate'] * 100
avg_time = result.get('avg_prediction_time', 0.0)
median_time = result.get('median_prediction_time', 0.0)
min_time = result.get('min_prediction_time', 0.0)
max_time = result.get('max_prediction_time', 0.0)
model_name = result["model_name"]
total = result["total_samples"]
correct = result["correct_predictions"]
errors = result["failed_predictions"]
accuracy = result["accuracy"] * 100
error_rate = result["failure_rate"] * 100
avg_time = result.get("avg_prediction_time", 0.0)
median_time = result.get("median_prediction_time", 0.0)
min_time = result.get("min_prediction_time", 0.0)
max_time = result.get("max_prediction_time", 0.0)
time_range = f"{min_time:.2f} - {max_time:.2f}"
vram_max = result.get('vram_max_mb', 0.0) / 1024
vram_avg = result.get('vram_avg_mb', 0.0) / 1024

f.write(f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {median_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n")

vram_max = result.get("vram_max_mb", 0.0) / 1024
vram_avg = result.get("vram_avg_mb", 0.0) / 1024

f.write(
f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {median_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n"
)

# Detailed results for each model
for result in all_results:
f.write(f"\n## {result['model_name']} - Detailed Results\n\n")
f.write("| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n")
f.write(
"| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n"
)
f.write("|-----------|-------------|------|-----------|---------|-------|----------|\n")

for sample_result in result['results'][:10]:  # Show first 10 samples
sample_idx = sample_result['sample_idx']
instruction = sample_result['instruction'][:50] + "..." if len(sample_result['instruction']) > 50 else sample_result['instruction']
bbox = str(sample_result['bbox'])
predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None"
correct = "PASS" if sample_result['is_correct'] else "FAIL"
error = "YES" if sample_result['failed'] else "NO"
pred_time = sample_result.get('prediction_time', 0.0)

f.write(f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n")

if len(result['results']) > 10:

for sample_result in result["results"][:10]:  # Show first 10 samples
sample_idx = sample_result["sample_idx"]
instruction = (
sample_result["instruction"][:50] + "..."
if len(sample_result["instruction"]) > 50
else sample_result["instruction"]
)
bbox = str(sample_result["bbox"])
predicted = (
str(sample_result["predicted_coords"])
if sample_result["predicted_coords"]
else "None"
)
correct = "PASS" if sample_result["is_correct"] else "FAIL"
error = "YES" if sample_result["failed"] else "NO"
pred_time = sample_result.get("prediction_time", 0.0)

f.write(
f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n"
)

if len(result["results"]) > 10:
f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n")

print(f"\nResults saved to: {output_file}")

def save_visualizations(all_results: List[dict], samples, output_dir: str = "output") -> None:
"""
Save visualizations of predicted coordinates vs bboxes to an output folder.

Args:
all_results: List of evaluation results for each model
samples: List of sample dicts with image, bbox, instruction keys
output_dir: Output directory path
"""
os.makedirs(output_dir, exist_ok=True)

for result in all_results:
model_name = result['model_name'].replace('/', '_').replace('\\', '_')
model_name = result["model_name"].replace("/", "_").replace("\\", "_")
model_dir = os.path.join(output_dir, model_name)
os.makedirs(model_dir, exist_ok=True)

print(f"Saving visualizations for {result['model_name']}...")

# Save first 10 samples for visualization
for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")):
for i, sample_result in enumerate(
tqdm(result["results"][:10], desc=f"Saving {model_name} visualizations")
):
# Get sample data using index
sample_idx = sample_result['sample_idx']

sample_idx = sample_result["sample_idx"]

if sample_idx < len(samples):
sample = samples[sample_idx]
image = sample['image'].copy()  # Make a copy to avoid modifying original
image = sample["image"].copy()  # Make a copy to avoid modifying original
else:
print(f"Warning: Could not find sample at index {sample_idx}")
continue

bbox = sample_result['bbox']
predicted_coords = sample_result['predicted_coords']
is_correct = sample_result['is_correct']

bbox = sample_result["bbox"]
predicted_coords = sample_result["predicted_coords"]
is_correct = sample_result["is_correct"]

# Draw on image
draw = ImageDraw.Draw(image)

# Draw bounding box (ground truth) in green
x1, y1, x2, y2 = bbox
draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
draw.text((x1, y1-20), "Ground Truth", fill="green")

draw.text((x1, y1 - 20), "Ground Truth", fill="green")

# Draw predicted click in red or blue
if predicted_coords is not None:
px, py = predicted_coords
color = "blue" if is_correct else "red"
# Draw crosshair
crosshair_size = 15
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)

draw.line(
[(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=3
)
draw.line(
[(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=3
)
draw.text((px + 10, py - 20), f"Predicted ({px},{py})", fill=color)

# Add status text
status = "CORRECT" if is_correct else "INCORRECT"
status_color = "blue" if is_correct else "red"
draw.text((10, 10), f"Status: {status}", fill=status_color)
draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")

draw.text(
(10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black"
)

# Save image
filename = f"sample_{i+1:02d}_idx{sample_idx}_{status.lower()}.png"
filepath = os.path.join(model_dir, filename)
image.save(filepath)

print(f"Visualizations saved to: {model_dir}")

def save_prediction_visualization(image: Image.Image, instruction: str, predictions: List[dict],
output_file: str = "interactive_prediction.png") -> None:
def save_prediction_visualization(
image: Image.Image,
instruction: str,
predictions: List[dict],
output_file: str = "interactive_prediction.png",
) -> None:
"""
Save visualization of multiple model predictions on a single image.

Args:
image: PIL Image to visualize
instruction: Instruction text
@@ -358,32 +393,32 @@ def save_prediction_visualization(image: Image.Image, instruction: str, predicti
# Create a copy of the image
vis_image = image.copy()
draw = ImageDraw.Draw(vis_image)

# Colors for different models
colors = ["red", "blue", "orange", "purple", "brown", "pink", "gray", "olive"]

# Draw predictions
for i, pred in enumerate(predictions):
color = colors[i % len(colors)]
model_name = pred['model_name']
coords = pred.get('coords')
error = pred.get('error')

model_name = pred["model_name"]
coords = pred.get("coords")
error = pred.get("error")

if coords is not None:
px, py = coords
# Draw crosshair
crosshair_size = 20
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=4)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=4)
draw.line([(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=4)
draw.line([(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=4)
# Draw model name
draw.text((px+15, py+15), f"{model_name}: ({px},{py})", fill=color)
draw.text((px + 15, py + 15), f"{model_name}: ({px},{py})", fill=color)
else:
# Draw error text
draw.text((10, 50 + i*20), f"{model_name}: ERROR - {error}", fill=color)

draw.text((10, 50 + i * 20), f"{model_name}: ERROR - {error}", fill=color)

# Add instruction at the top
draw.text((10, 10), f"Instruction: {instruction}", fill="black")

# Save image
vis_image.save(output_file)
print(f"Prediction visualization saved to: {output_file}")
@@ -392,12 +427,13 @@ def save_prediction_visualization(image: Image.Image, instruction: str, predicti
def take_screenshot() -> Image.Image:
"""
Take a screenshot of the current screen.

Returns:
PIL Image of the screenshot
"""
try:
import pyautogui

screenshot = pyautogui.screenshot()
return screenshot
except ImportError:
@@ -406,4 +442,3 @@ def take_screenshot() -> Image.Image:
except Exception as e:
print(f"Error taking screenshot: {e}")
raise

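As a hedged sketch of how the VRAM bookkeeping in `ModelWrapper` composes (assumes a CUDA-capable machine; on CPU-only boxes every reading is 0.0):

```python
import torch


def vram_allocated_mb() -> float:
    # Mirrors the "allocated_mb" field that get_vram_usage() reports.
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.memory_allocated(torch.cuda.current_device()) / 1024 / 1024


history = []
history.append(vram_allocated_mb())  # sampled after load_model()
history.append(vram_allocated_mb())  # sampled around each predict_click()
history.append(vram_allocated_mb())  # sampled after unload_model()

stats = {"max_mb": max(history), "avg_mb": sum(history) / len(history)}
print(stats)
```

Sampling at load, per prediction, and at unload is why `get_vram_stats()` can report both a peak and an average for each model without any continuous monitoring thread.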
@@ -9,58 +9,61 @@ from agent import ComputerAgent
from computer import Computer
from computer.helpers import sandboxed

@sandboxed()
def read_file(location: str) -> str:
"""Read contents of a file

Parameters
----------
location : str
Path to the file to read

Returns
-------
str
Contents of the file or error message
"""
try:
with open(location, 'r') as f:
with open(location, "r") as f:
return f.read()
except Exception as e:
return f"Error reading file: {str(e)}"

def save_note(content: str, filename: str = "note.txt") -> str:
"""Save content to a note file

Parameters
----------
content : str
Content to save to the file
filename : str, optional
Name of the file to save to (default is "note.txt")

Returns
-------
str
Success or error message
"""
try:
with open(filename, 'w') as f:
with open(filename, "w") as f:
f.write(content)
return f"Saved note to {filename}"
except Exception as e:
return f"Error saving note: {str(e)}"

def calculate(a: int, b: int) -> int:
"""Calculate the sum of two integers

Parameters
----------
a : int
First integer
b : int
Second integer

Returns
-------
int
@@ -68,15 +71,18 @@ def calculate(a: int, b: int) -> int:
"""
return a + b

async def main():
"""Example usage of ComputerAgent with different models"""

# Example 1: Using Claude with computer and custom tools
print("=== Example 1: Claude with Computer ===")

import os
import dotenv

import json
import os

import dotenv

dotenv.load_dotenv()

assert os.getenv("CUA_CONTAINER_NAME") is not None, "CUA_CONTAINER_NAME is not set"
@@ -86,38 +92,37 @@ async def main():
os_type="linux",
provider_type="cloud",
name=os.getenv("CUA_CONTAINER_NAME") or "",
api_key=os.getenv("CUA_API_KEY") or ""
api_key=os.getenv("CUA_API_KEY") or "",
) as computer:
agent = ComputerAgent(
# Supported models:

# == OpenAI CUA (computer-use-preview) ==
model="openai/computer-use-preview",

# == Anthropic CUA (Claude > 3.5) ==
# model="anthropic/claude-opus-4-20250514",
# model="anthropic/claude-opus-4-20250514",
# model="anthropic/claude-sonnet-4-20250514",
# model="anthropic/claude-3-7-sonnet-20250219",
# model="anthropic/claude-3-5-sonnet-20241022",

# == UI-TARS ==
# model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
# TODO: add local mlx provider
# model="mlx-community/UI-TARS-1.5-7B-6bit",
# model="ollama_chat/0000/ui-tars-1.5-7b",

# == Omniparser + Any LLM ==
# model="omniparser+..."
# model="omniparser+anthropic/claude-opus-4-20250514",

tools=[computer],
only_n_most_recent_images=3,
verbosity=logging.INFO,
trajectory_dir="trajectories",
use_prompt_caching=True,
max_trajectory_budget={ "max_budget": 1.0, "raise_error": True, "reset_after_each_run": False },
max_trajectory_budget={
"max_budget": 1.0,
"raise_error": True,
"reset_after_each_run": False,
},
)

history = []
while True:
user_input = input("> ")
@@ -143,5 +148,6 @@ async def main():
# elif item["type"] == "function_call_output":
#     print("===>", item["output"])

if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"

[project]
name = "cua-agent"
version = "0.4.0"
version = "0.4.35"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [
@@ -29,6 +29,11 @@ requires-python = ">=3.12"
[project.optional-dependencies]
openai = []
anthropic = []
qwen = [
"qwen-vl-utils",
"qwen-agent",
"Pillow>=10.0.0",
]
omni = [
"cua-som>=0.1.0,<0.2.0",
]
@@ -49,7 +54,7 @@ glm45v-hf = [
opencua-hf = [
"accelerate",
"torch",
"transformers==4.53.0",
"transformers>=4.53.0",
"tiktoken>=0.11.0",
"blobfile>=3.0.0"
]
@@ -60,6 +65,11 @@ internvl-hf = [
"einops",
"timm"
]
moondream3 = [
"accelerate",
"torch",
"transformers>=4.55.0"
]
ui = [
"gradio>=5.23.3",
"python-dotenv>=1.0.1",
@@ -68,7 +78,10 @@ cli = [
"yaspin>=3.1.0",
]
hud = [
"hud-python==0.4.26",
"hud-python==0.4.52",
]
gemini = [
"google-genai>=1.41.0",
]
all = [
# uitars requirements
@@ -88,7 +101,13 @@ all = [
# cli requirements
"yaspin>=3.1.0",
# hud requirements
"hud-python==0.4.26",
"hud-python==0.4.52",
# gemini requirements
"google-genai>=1.41.0",
# qwen requirements
"qwen-vl-utils",
"qwen-agent",
"Pillow>=10.0.0",
]

[tool.uv]
@@ -98,4 +117,4 @@ constraint-dependencies = ["fastrtc>0.43.0", "mlx-audio>0.2.3"]
distribution = true

[tool.pdm.build]
includes = ["agent/"]
includes = ["agent/"]
10
libs/python/computer-server/.bumpversion.cfg
Normal file
@@ -0,0 +1,10 @@
|
||||
[bumpversion]
|
||||
current_version = 0.1.27
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = computer-server-v{new_version}
|
||||
message = Bump cua-computer-server to v{new_version}
|
||||
|
||||
[bumpversion:file:pyproject.toml]
|
||||
search = version = "{current_version}"
|
||||
replace = version = "{new_version}"
|
||||
@@ -8,10 +8,11 @@
|
||||
</picture>
|
||||
</div>
|
||||
|
||||
[](#)
|
||||
[](#)
|
||||
[](https://discord.com/invite/mVnXXpdE85)
|
||||
[](https://pypi.org/project/cua-computer-server/)
|
||||
[](#)
|
||||
[](#)
|
||||
[](https://discord.com/invite/mVnXXpdE85)
|
||||
[](https://pypi.org/project/cua-computer-server/)
|
||||
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
@@ -42,4 +43,4 @@ Refer to this notebook for a step-by-step guide on how to use the Computer-Use S
|
||||
- [Commands](https://trycua.com/docs/libraries/computer-server/Commands)
|
||||
- [REST-API](https://trycua.com/docs/libraries/computer-server/REST-API)
|
||||
- [WebSocket-API](https://trycua.com/docs/libraries/computer-server/WebSocket-API)
|
||||
- [Index](https://trycua.com/docs/libraries/computer-server/index)
|
||||
- [Index](https://trycua.com/docs/libraries/computer-server/index)
|
||||
|
||||
@@ -4,6 +4,7 @@ This allows the server to be started with `python -m computer_server`.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
from .cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -36,7 +36,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
|
||||
help="Path to SSL private key file (enables HTTPS)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-certfile",
|
||||
"--ssl-certfile",
|
||||
type=str,
|
||||
help="Path to SSL certificate file (enables HTTPS)",
|
||||
)
|
||||
@@ -72,17 +72,19 @@ def main() -> None:
|
||||
|
||||
# Check if watchdog should be enabled
|
||||
container_name = os.environ.get("CONTAINER_NAME")
|
||||
enable_watchdog = args.watchdog or bool(container_name)
|
||||
|
||||
enable_watchdog = (args.watchdog or bool(container_name)) and not sys.platform.startswith("win")
|
||||
|
||||
if container_name:
|
||||
logger.info(f"Container environment detected (CONTAINER_NAME={container_name}), enabling watchdog")
|
||||
logger.info(
|
||||
f"Container environment detected (CONTAINER_NAME={container_name}), enabling watchdog"
|
||||
)
|
||||
elif args.watchdog:
|
||||
logger.info("Watchdog explicitly enabled via --watchdog flag")
|
||||
|
||||
|
||||
# Start watchdog if enabled
|
||||
if enable_watchdog:
|
||||
logger.info(f"Starting watchdog monitoring with {args.watchdog_interval}s interval")
|
||||
|
||||
|
||||
def run_watchdog_thread():
|
||||
"""Run watchdog in a separate thread."""
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -90,38 +92,32 @@ def main() -> None:
|
||||
try:
|
||||
# Create CLI args dict for watchdog
|
||||
cli_args = {
|
||||
'host': args.host,
|
||||
'port': args.port,
|
||||
'log_level': args.log_level,
|
||||
'ssl_keyfile': args.ssl_keyfile,
|
||||
'ssl_certfile': args.ssl_certfile
|
||||
"host": args.host,
|
||||
"port": args.port,
|
||||
"log_level": args.log_level,
|
||||
"ssl_keyfile": args.ssl_keyfile,
|
||||
"ssl_certfile": args.ssl_certfile,
|
||||
}
|
||||
|
||||
|
||||
# Create watchdog with restart settings
|
||||
from .watchdog import Watchdog
|
||||
watchdog = Watchdog(
|
||||
cli_args=cli_args,
|
||||
ping_interval=args.watchdog_interval
|
||||
)
|
||||
|
||||
watchdog = Watchdog(cli_args=cli_args, ping_interval=args.watchdog_interval)
|
||||
watchdog.restart_enabled = not args.no_restart
|
||||
|
||||
|
||||
loop.run_until_complete(watchdog.start_monitoring())
|
||||
except Exception as e:
|
||||
logger.error(f"Watchdog error: {e}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# Start watchdog in background thread
|
||||
watchdog_thread = threading.Thread(
|
||||
target=run_watchdog_thread,
|
||||
daemon=True,
|
||||
name="watchdog"
|
||||
)
|
||||
watchdog_thread = threading.Thread(target=run_watchdog_thread, daemon=True, name="watchdog")
|
||||
watchdog_thread.start()
|
||||
|
||||
# Create and start the server
|
||||
logger.info(f"Starting CUA Computer API server on {args.host}:{args.port}...")
|
||||
|
||||
|
||||
# Handle SSL configuration
|
||||
ssl_args = {}
|
||||
if args.ssl_keyfile and args.ssl_certfile:
|
||||
@@ -131,10 +127,12 @@ def main() -> None:
|
||||
}
|
||||
logger.info("HTTPS mode enabled with SSL certificates")
|
||||
elif args.ssl_keyfile or args.ssl_certfile:
|
||||
logger.warning("Both --ssl-keyfile and --ssl-certfile are required for HTTPS. Running in HTTP mode.")
|
||||
logger.warning(
|
||||
"Both --ssl-keyfile and --ssl-certfile are required for HTTPS. Running in HTTP mode."
|
||||
)
|
||||
else:
|
||||
logger.info("HTTP mode (no SSL certificates provided)")
|
||||
|
||||
|
||||
server = Server(host=args.host, port=args.port, log_level=args.log_level, **ssl_args)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
class BaseDioramaHandler:
|
||||
"""Base Diorama handler for unsupported OSes."""
|
||||
|
||||
async def diorama_cmd(self, action: str, arguments: dict = None) -> dict:
|
||||
return {"success": False, "error": "Diorama is not supported on this OS yet."}
|
||||
|
||||
@@ -1,31 +1,38 @@
#!/usr/bin/env python3
"""Diorama: A virtual desktop manager for macOS"""

import os
import asyncio
import logging
import sys
import io
import logging
import os
import sys
from typing import Union
from PIL import Image, ImageDraw

from computer_server.diorama.draw import capture_all_apps, AppActivationContext, get_frontmost_and_active_app, get_all_windows, get_running_apps

from computer_server.diorama.diorama_computer import DioramaComputer
from computer_server.diorama.draw import (
AppActivationContext,
capture_all_apps,
get_all_windows,
get_frontmost_and_active_app,
get_running_apps,
)
from computer_server.handlers.macos import *
from PIL import Image, ImageDraw

# simple, nicely formatted logging
logger = logging.getLogger(__name__)

automation_handler = MacOSAutomationHandler()

class Diorama:
"""Virtual desktop manager that provides automation capabilities for macOS applications.

Manages application windows and provides an interface for taking screenshots,
mouse interactions, keyboard input, and coordinate transformations between
screenshot space and screen space.
"""

_scheduler_queue = None
_scheduler_task = None
_loop = None
@@ -34,10 +41,10 @@ class Diorama:
@classmethod
def create_from_apps(cls, *args) -> DioramaComputer:
"""Create a DioramaComputer instance from a list of application names.

Args:
*args: Variable number of application names to include in the desktop

Returns:
DioramaComputer: A computer interface for the specified applications
"""
@@ -46,10 +53,10 @@ class Diorama:

# Dictionary to store cursor positions for each unique app_list hash
_cursor_positions = {}

def __init__(self, app_list):
"""Initialize a Diorama instance for the specified applications.

Args:
app_list: List of application names to manage
"""
@@ -57,10 +64,10 @@ class Diorama:
self.interface = self.Interface(self)
self.computer = DioramaComputer(self)
self.focus_context = None

# Create a hash for this app_list to use as a key
self.app_list_hash = hash(tuple(sorted(app_list)))

# Initialize cursor position for this app_list if it doesn't exist
if self.app_list_hash not in Diorama._cursor_positions:
Diorama._cursor_positions[self.app_list_hash] = (0, 0)
@@ -68,7 +75,7 @@ class Diorama:
@classmethod
def _ensure_scheduler(cls):
"""Ensure the async scheduler loop is running.

Creates and starts the scheduler task if it hasn't been started yet.
"""
if not cls._scheduler_started:
@@ -81,7 +88,7 @@ class Diorama:
@classmethod
async def _scheduler_loop(cls):
"""Main scheduler loop that processes automation commands.

Continuously processes commands from the scheduler queue, handling
screenshots, mouse actions, keyboard input, and scrolling operations.
"""
@@ -91,31 +98,37 @@ class Diorama:
args = cmd.get("arguments", {})
future = cmd.get("future")
logger.info(f"Processing command: {action} | args={args}")

app_whitelist = args.get("app_list", [])

all_windows = get_all_windows()
running_apps = get_running_apps()
frontmost_app, active_app_to_use, active_app_pid = get_frontmost_and_active_app(all_windows, running_apps, app_whitelist)
frontmost_app, active_app_to_use, active_app_pid = get_frontmost_and_active_app(
all_windows, running_apps, app_whitelist
)
focus_context = AppActivationContext(active_app_pid, active_app_to_use, logger)

with focus_context:
try:
if action == "screenshot":
logger.info(f"Taking screenshot for apps: {app_whitelist}")
result, img = capture_all_apps(
app_whitelist=app_whitelist,
save_to_disk=False,
take_focus=False
app_whitelist=app_whitelist, save_to_disk=False, take_focus=False
)
logger.info("Screenshot complete.")
if future:
future.set_result((result, img))
# Mouse actions
elif action in ["left_click", "right_click", "double_click", "move_cursor", "drag_to"]:
elif action in [
"left_click",
"right_click",
"double_click",
"move_cursor",
"drag_to",
]:
x = args.get("x")
y = args.get("y")

duration = args.get("duration", 0.5)
if action == "left_click":
await automation_handler.left_click(x, y)
@@ -134,7 +147,7 @@ class Diorama:
y = args.get("y")
if x is not None and y is not None:
await automation_handler.move_cursor(x, y)

clicks = args.get("clicks", 1)
if action == "scroll_up":
await automation_handler.scroll_up(clicks)
@@ -171,31 +184,31 @@ class Diorama:
if future:
future.set_exception(e)

class Interface():
class Interface:
"""Interface for interacting with the virtual desktop.

Provides methods for taking screenshots, mouse interactions, keyboard input,
and coordinate transformations between screenshot and screen coordinates.
"""

def __init__(self, diorama):
"""Initialize the interface with a reference to the parent Diorama instance.

Args:
diorama: The parent Diorama instance
"""
self._diorama = diorama

self._scene_hitboxes = []
self._scene_size = None

async def _send_cmd(self, action, arguments=None):
"""Send a command to the scheduler queue.

Args:
action (str): The action to perform
arguments (dict, optional): Arguments for the action

Returns:
The result of the command execution
"""
@@ -203,11 +216,13 @@ class Diorama:
loop = asyncio.get_event_loop()
future = loop.create_future()
logger.info(f"Enqueuing {action} command for apps: {self._diorama.app_list}")
await Diorama._scheduler_queue.put({
"action": action,
"arguments": {"app_list": self._diorama.app_list, **(arguments or {})},
"future": future
})
await Diorama._scheduler_queue.put(
{
"action": action,
"arguments": {"app_list": self._diorama.app_list, **(arguments or {})},
"future": future,
}
)
try:
return await future
except asyncio.CancelledError:
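`_send_cmd` and `_scheduler_loop` together are a plain asyncio command-queue pattern: the caller enqueues `{action, arguments, future}` and awaits the future, while a single scheduler task drains the queue and resolves each future with the action's result. Stripped of the Diorama specifics, the pattern reduces to this self-contained sketch (names are illustrative):

    import asyncio

    queue: asyncio.Queue = asyncio.Queue()

    async def scheduler():
        while True:
            cmd = await queue.get()
            # resolve the caller's future once the action is handled
            cmd["future"].set_result(f"did {cmd['action']}")

    async def send(action: str):
        fut = asyncio.get_running_loop().create_future()
        await queue.put({"action": action, "future": fut})
        return await fut

    async def demo():
        asyncio.ensure_future(scheduler())
        print(await send("screenshot"))

    asyncio.run(demo())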
@@ -216,21 +231,23 @@ class Diorama:

async def screenshot(self, as_bytes: bool = True) -> Union[str, Image.Image]:
"""Take a screenshot of the managed applications.

Args:
as_bytes (bool): If True, return base64-encoded bytes; if False, return PIL Image

Returns:
Union[str, Image.Image]: Base64-encoded PNG bytes or PIL Image object
"""
import base64

result, img = await self._send_cmd("screenshot")
self._scene_hitboxes = result.get("hitboxes", [])
self._scene_size = img.size

if as_bytes:
# PIL Image to bytes, then base64 encode for JSON
import io

img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_bytes = img_byte_arr.getvalue()
@@ -241,7 +258,7 @@ class Diorama:

async def left_click(self, x, y):
"""Perform a left mouse click at the specified coordinates.

Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -258,7 +275,7 @@ class Diorama:

async def right_click(self, x, y):
"""Perform a right mouse click at the specified coordinates.

Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -269,13 +286,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)

sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("right_click", {"x": sx, "y": sy})

async def double_click(self, x, y):
"""Perform a double mouse click at the specified coordinates.

Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -286,13 +303,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)

sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("double_click", {"x": sx, "y": sy})

async def move_cursor(self, x, y):
"""Move the mouse cursor to the specified coordinates.

Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -303,13 +320,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)

sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("move_cursor", {"x": sx, "y": sy})

async def drag_to(self, x, y, duration=0.5):
"""Drag the mouse from current position to the specified coordinates.

Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -321,13 +338,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)

sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("drag_to", {"x": sx, "y": sy, "duration": duration})

async def get_cursor_position(self):
"""Get the current cursor position in screen coordinates.

Returns:
tuple: (x, y) coordinates of the cursor in screen space
"""
@@ -335,7 +352,7 @@ class Diorama:

async def type_text(self, text):
"""Type the specified text using the keyboard.

Args:
text (str): The text to type
"""
@@ -343,7 +360,7 @@ class Diorama:

async def press_key(self, key):
"""Press a single key on the keyboard.

Args:
key (str): The key to press
"""
@@ -351,7 +368,7 @@ class Diorama:

async def hotkey(self, keys):
"""Press a combination of keys simultaneously.

Args:
keys (list): List of keys to press together
"""
@@ -359,7 +376,7 @@ class Diorama:

async def scroll_up(self, clicks: int = 1):
"""Scroll up at the current cursor position.

Args:
clicks (int): Number of scroll clicks to perform
"""
@@ -367,12 +384,12 @@ class Diorama:
app_list_hash = hash(tuple(sorted(self._diorama.app_list)))
last_pos = Diorama._cursor_positions.get(app_list_hash, (0, 0))
x, y = last_pos[0], last_pos[1]

await self._send_cmd("scroll_up", {"clicks": clicks, "x": x, "y": y})

async def scroll_down(self, clicks: int = 1):
"""Scroll down at the current cursor position.

Args:
clicks (int): Number of scroll clicks to perform
"""
@@ -380,18 +397,18 @@ class Diorama:
app_list_hash = hash(tuple(sorted(self._diorama.app_list)))
last_pos = Diorama._cursor_positions.get(app_list_hash, (0, 0))
x, y = last_pos[0], last_pos[1]

await self._send_cmd("scroll_down", {"clicks": clicks, "x": x, "y": y})

async def get_screen_size(self) -> dict[str, int]:
"""Get the size of the screenshot area.

Returns:
dict[str, int]: Dictionary with 'width' and 'height' keys
"""
if not self._scene_size:
await self.screenshot()
return { "width": self._scene_size[0], "height": self._scene_size[1] }
return {"width": self._scene_size[0], "height": self._scene_size[1]}

async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.
@@ -404,29 +421,29 @@ class Diorama:
tuple[float, float]: (x, y) absolute coordinates in screen space
"""
if not self._scene_hitboxes:
await self.screenshot() # get hitboxes
await self.screenshot()  # get hitboxes
# Try all hitboxes
for h in self._scene_hitboxes[::-1]:
rect_from = h.get("hitbox")
rect_to = h.get("target")
if not rect_from or len(rect_from) != 4:
continue

# check if (x, y) is inside rect_from
x0, y0, x1, y1 = rect_from
if x0 <= x <= x1 and y0 <= y <= y1:
logger.info(f"Found hitbox: {h}")
# remap (x, y) to rect_to
tx0, ty0, tx1, ty1 = rect_to

# calculate offset from x0, y0
offset_x = x - x0
offset_y = y - y0

# remap offset to rect_to
tx = tx0 + offset_x
ty = ty0 + offset_y

return tx, ty
return x, y
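Both converters walk the hitbox list back to front and, when the point lands inside a source rectangle, carry its offset from that rectangle's origin over to the target rectangle: tx = tx0 + (x - x0), ty = ty0 + (y - y0). The remap in isolation (rectangle values are illustrative):

    def remap(x, y, rect_from, rect_to):
        # rectangles are (x0, y0, x1, y1); the offset from the origin corner is preserved
        x0, y0, x1, y1 = rect_from
        tx0, ty0, tx1, ty1 = rect_to
        if x0 <= x <= x1 and y0 <= y <= y1:
            return tx0 + (x - x0), ty0 + (y - y0)
        return x, y  # outside every hitbox: pass through unchanged

    print(remap(15, 5, (10, 0, 110, 50), (200, 300, 300, 350)))  # -> (205, 305)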
@@ -441,34 +458,37 @@ class Diorama:
tuple[float, float]: (x, y) absolute coordinates in screenshot space
"""
if not self._scene_hitboxes:
await self.screenshot() # get hitboxes
await self.screenshot()  # get hitboxes
# Try all hitboxes
for h in self._scene_hitboxes[::-1]:
rect_from = h.get("target")
rect_to = h.get("hitbox")
if not rect_from or len(rect_from) != 4:
continue

# check if (x, y) is inside rect_from
x0, y0, x1, y1 = rect_from
if x0 <= x <= x1 and y0 <= y <= y1:
# remap (x, y) to rect_to
tx0, ty0, tx1, ty1 = rect_to

# calculate offset from x0, y0
offset_x = x - x0
offset_y = y - y0

# remap offset to rect_to
tx = tx0 + offset_x
ty = ty0 + offset_y

return tx, ty
return x, y

import pyautogui

import time

import pyautogui

async def main():
"""Main function demonstrating Diorama usage with multiple desktops and mouse tracking."""
desktop1 = Diorama.create_from_apps(["Discord", "Notes"])
@@ -511,7 +531,7 @@ async def main():
# Draw on a copy of the screenshot
frame = base_img.copy()
frame_draw = ImageDraw.Draw(frame)
frame_draw.ellipse((sx-5, sy-5, sx+5, sy+5), fill="blue", outline="blue")
frame_draw.ellipse((sx - 5, sy - 5, sx + 5, sy + 5), fill="blue", outline="blue")
# Save the frame
frame.save("app_screenshots/desktop3_mouse.png")
print(f"Mouse at screen ({mouse_x}, {mouse_y}) -> screenshot ({sx:.1f}, {sy:.1f})")
@@ -520,15 +540,13 @@ async def main():
print("Stopped tracking.")

draw.text((rect[0], rect[1]), str(idx), fill="red")

canvas.save("app_screenshots/desktop3_hitboxes.png")

# move mouse in a square spiral around the screen
import math
import random

step = 20 # pixels per move
dot_radius = 10
width = screen_size["width"]
@@ -539,11 +557,12 @@ async def main():
await desktop3.interface.move_cursor(x, y)
img = await desktop3.interface.screenshot(as_bytes=False)
draw = ImageDraw.Draw(img)
draw.ellipse((x-dot_radius, y-dot_radius, x+dot_radius, y+dot_radius), fill="red")
draw.ellipse((x - dot_radius, y - dot_radius, x + dot_radius, y + dot_radius), fill="red")
img.save("current.png")
await asyncio.sleep(0.03)
x += step
y = math.sin(x / width * math.pi * 2) * 50 + 25

if __name__ == "__main__":
asyncio.run(main())
@@ -1,14 +1,16 @@
import asyncio

class DioramaComputer:
"""
A minimal Computer-like interface for Diorama, compatible with ComputerAgent.
Implements _initialized, run(), and __aenter__ for agent compatibility.
"""

def __init__(self, diorama):
"""
Initialize the DioramaComputer with a diorama instance.

Args:
diorama: The diorama instance to wrap with a computer-like interface.
"""
@@ -19,10 +21,10 @@ class DioramaComputer:
async def __aenter__(self):
"""
Async context manager entry method for compatibility with ComputerAgent.

Ensures an event loop is running and marks the instance as initialized.
Creates a new event loop if none is currently running.

Returns:
DioramaComputer: The initialized instance.
"""
@@ -37,10 +39,10 @@ class DioramaComputer:
async def run(self):
"""
Run method stub for compatibility with ComputerAgent interface.

Ensures the instance is initialized before returning. If not already
initialized, calls __aenter__ to perform initialization.

Returns:
DioramaComputer: The initialized instance.
"""
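`DioramaComputer` exists so a Diorama desktop can be dropped in wherever a Computer object is expected. A usage sketch modeled on the demo `main()` above (macOS only, and the listed app must be running):

    from computer_server.diorama.diorama import Diorama

    async def demo():
        computer = Diorama.create_from_apps(["Notes"])   # returns a DioramaComputer
        async with computer as c:                        # __aenter__ marks it initialized
            img = await c.interface.screenshot(as_bytes=False)
            img.save("diorama.png")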
File diff suppressed because it is too large
@@ -1,13 +1,15 @@
import inspect
import platform
import sys
import platform
import inspect
from computer_server.diorama.diorama import Diorama
from computer_server.diorama.base import BaseDioramaHandler
from typing import Optional

from computer_server.diorama.base import BaseDioramaHandler
from computer_server.diorama.diorama import Diorama

class MacOSDioramaHandler(BaseDioramaHandler):
"""Handler for Diorama commands on macOS, using local diorama module."""

async def diorama_cmd(self, action: str, arguments: Optional[dict] = None) -> dict:
if platform.system().lower() != "darwin":
return {"success": False, "error": "Diorama is only supported on macOS."}
@@ -30,4 +32,5 @@ class MacOSDioramaHandler(BaseDioramaHandler):
return {"success": True, "result": result}
except Exception as e:
import traceback

return {"success": False, "error": str(e), "trace": traceback.format_exc()}
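A hedged call sketch for the handler above (macOS only; the `app_list` key mirrors the scheduler's argument convention and is an assumption here):

    import asyncio
    from computer_server.diorama.macos import MacOSDioramaHandler

    async def demo():
        resp = await MacOSDioramaHandler().diorama_cmd("screenshot", {"app_list": ["Notes"]})
        if not resp["success"]:
            print(resp["error"])  # unexpected failures also carry a "trace" key

    asyncio.run(demo())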
@@ -8,31 +8,31 @@ like the menubar and dock, which are needed for proper screenshot composition.

import sys
import time
from typing import Dict, Any, Optional, Tuple
from typing import Any, Dict, Optional, Tuple

# Import Objective-C bridge libraries
try:
import AppKit
import Foundation
from AppKit import NSRunningApplication, NSWorkspace
from ApplicationServices import (
AXUIElementCreateSystemWide,
AXUIElementCreateApplication,
AXUIElementCopyAttributeValue,
AXUIElementCopyAttributeValues,
kAXChildrenAttribute,
kAXRoleAttribute,
kAXTitleAttribute,
kAXPositionAttribute,
kAXSizeAttribute,
kAXErrorSuccess,
AXValueGetType,
kAXValueCGSizeType,
kAXValueCGPointType,
AXUIElementCreateApplication,
AXUIElementCreateSystemWide,
AXUIElementGetTypeID,
AXValueGetType,
AXValueGetValue,
kAXChildrenAttribute,
kAXErrorSuccess,
kAXMenuBarAttribute,
kAXPositionAttribute,
kAXRoleAttribute,
kAXSizeAttribute,
kAXTitleAttribute,
kAXValueCGPointType,
kAXValueCGSizeType,
)
from AppKit import NSWorkspace, NSRunningApplication
import Foundation
except ImportError:
print("Error: This script requires PyObjC to be installed.")
print("Please install it with: pip install pyobjc")
@@ -74,13 +74,8 @@ def element_value(element, type):

def get_element_bounds(element):
"""Get the bounds of an accessibility element"""
bounds = {
"x": 0,
"y": 0,
"width": 0,
"height": 0
}

bounds = {"x": 0, "y": 0, "width": 0, "height": 0}

# Get position
position_value = element_attribute(element, kAXPositionAttribute)
if position_value:
@@ -88,7 +83,7 @@ def get_element_bounds(element):
if position_value:
bounds["x"] = position_value.x
bounds["y"] = position_value.y

# Get size
size_value = element_attribute(element, kAXSizeAttribute)
if size_value:
@@ -96,7 +91,7 @@ def get_element_bounds(element):
if size_value:
bounds["width"] = size_value.width
bounds["height"] = size_value.height

return bounds

@@ -111,13 +106,13 @@ def find_dock_process():

def get_menubar_bounds():
"""Get the bounds of the macOS menubar

Returns:
Dictionary with x, y, width, height of the menubar
"""
# Get the system-wide accessibility element
system_element = AXUIElementCreateSystemWide()

# Try to find the menubar
menubar = element_attribute(system_element, kAXMenuBarAttribute)
if menubar is None:
@@ -127,19 +122,19 @@ def get_menubar_bounds():
app_pid = frontmost_app.processIdentifier()
app_element = AXUIElementCreateApplication(app_pid)
menubar = element_attribute(app_element, kAXMenuBarAttribute)

if menubar is None:
print("Error: Could not get menubar")
# Return default menubar bounds as fallback
return {"x": 0, "y": 0, "width": 1800, "height": 24}

# Get menubar bounds
return get_element_bounds(menubar)

def get_dock_bounds():
"""Get the bounds of the macOS Dock

Returns:
Dictionary with x, y, width, height of the Dock
"""
@@ -148,19 +143,19 @@ def get_dock_bounds():
print("Error: Could not find Dock process")
# Return empty bounds as fallback
return {"x": 0, "y": 0, "width": 0, "height": 0}

# Create an accessibility element for the Dock
dock_element = AXUIElementCreateApplication(dock_pid)
if dock_element is None:
print(f"Error: Could not create accessibility element for Dock (PID {dock_pid})")
return {"x": 0, "y": 0, "width": 0, "height": 0}

# Get the Dock's children
children = element_attribute(dock_element, kAXChildrenAttribute)
if not children or len(children) == 0:
print("Error: Could not get Dock children")
return {"x": 0, "y": 0, "width": 0, "height": 0}

# Find the Dock's list (first child is usually the main dock list)
dock_list = None
for child in children:
@@ -168,28 +163,25 @@ def get_dock_bounds():
if role == "AXList":
dock_list = child
break

if dock_list is None:
print("Error: Could not find Dock list")
return {"x": 0, "y": 0, "width": 0, "height": 0}

# Get the bounds of the dock list
return get_element_bounds(dock_list)

def get_ui_element_bounds():
"""Get the bounds of important UI elements like menubar and dock

Returns:
Dictionary with menubar and dock bounds
"""
menubar_bounds = get_menubar_bounds()
dock_bounds = get_dock_bounds()

return {
"menubar": menubar_bounds,
"dock": dock_bounds
}

return {"menubar": menubar_bounds, "dock": dock_bounds}

if __name__ == "__main__":
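Together these helpers hand the screenshot compositor the two system-UI rectangles it needs. A sketch of what the `__main__` block might probe, assuming the same module scope (macOS with PyObjC installed):

    if __name__ == "__main__":
        bounds = get_ui_element_bounds()  # {"menubar": {...}, "dock": {...}}
        content_top = bounds["menubar"]["y"] + bounds["menubar"]["height"]
        print(f"content area starts at y={content_top}; dock is {bounds['dock']['height']}px tall")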
@@ -1,24 +1,26 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

class BaseAccessibilityHandler(ABC):
"""Abstract base class for OS-specific accessibility handlers."""

@abstractmethod
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current window."""
pass

@abstractmethod
async def find_element(self, role: Optional[str] = None,
title: Optional[str] = None,
value: Optional[str] = None) -> Dict[str, Any]:
async def find_element(
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an element in the accessibility tree by criteria."""
pass

class BaseFileHandler(ABC):
"""Abstract base class for OS-specific file handlers."""

@abstractmethod
async def file_exists(self, path: str) -> Dict[str, Any]:
"""Check if a file exists at the specified path."""
@@ -43,7 +45,7 @@ class BaseFileHandler(ABC):
async def write_text(self, path: str, content: str) -> Dict[str, Any]:
"""Write text content to a file."""
pass

@abstractmethod
async def write_bytes(self, path: str, content_b64: str) -> Dict[str, Any]:
"""Write binary content to a file. Sent over the websocket as a base64 string."""
@@ -65,9 +67,11 @@ class BaseFileHandler(ABC):
pass

@abstractmethod
async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> Dict[str, Any]:
async def read_bytes(
self, path: str, offset: int = 0, length: Optional[int] = None
) -> Dict[str, Any]:
"""Read the binary contents of a file. Sent over the websocket as a base64 string.

Args:
path: Path to the file
offset: Byte offset to start reading from (default: 0)
@@ -80,9 +84,10 @@ class BaseFileHandler(ABC):
"""Get the size of a file in bytes."""
pass

class BaseAutomationHandler(ABC):
"""Abstract base class for OS-specific automation handlers.

Categories:
- Mouse Actions: Methods for mouse control
- Keyboard Actions: Methods for keyboard input
@@ -90,18 +95,22 @@ class BaseAutomationHandler(ABC):
- Screen Actions: Methods for screen interaction
- Clipboard Actions: Methods for clipboard operations
"""

# Mouse Actions
@abstractmethod
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse down at the current or specified position."""
pass

@abstractmethod
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse up at the current or specified position."""
pass

@abstractmethod
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left click at the current or specified position."""
@@ -113,7 +122,9 @@ class BaseAutomationHandler(ABC):
pass

@abstractmethod
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double click at the current or specified position."""
pass

@@ -123,9 +134,11 @@ class BaseAutomationHandler(ABC):
pass

@abstractmethod
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the cursor from current position to specified coordinates.

Args:
x: The x coordinate to drag to
y: The y coordinate to drag to
@@ -133,11 +146,13 @@ class BaseAutomationHandler(ABC):
duration: How long the drag should take in seconds
"""
pass

@abstractmethod
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the cursor from current position to specified coordinates.

Args:
path: A list of tuples of x and y coordinates to drag to
button: The mouse button to use ('left', 'middle', 'right')
@@ -150,12 +165,12 @@ class BaseAutomationHandler(ABC):
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold the specified key."""
pass

@abstractmethod
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release the specified key."""
pass

@abstractmethod
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text."""
@@ -176,7 +191,7 @@ class BaseAutomationHandler(ABC):
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll the specified amount."""
pass

@abstractmethod
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks."""
@@ -212,9 +227,9 @@ class BaseAutomationHandler(ABC):
@abstractmethod
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the clipboard content."""
pass
pass

@abstractmethod
async def run_command(self, command: str) -> Dict[str, Any]:
"""Run a command and return the output."""
pass
pass
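Note the convention the ABCs encode: every handler method returns a `{"success": bool, ...}` envelope rather than raising across the websocket boundary. A sketch of a concrete stub built on the base class (illustrative; a real subclass must implement every abstract method before it can be instantiated):

    class NullAutomationHandler(BaseAutomationHandler):
        """Acknowledges actions without touching the host -- useful for dry runs."""

        async def left_click(self, x=None, y=None):
            return {"success": True}

        # ...each remaining abstract method would return the same envelope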
@@ -1,68 +1,89 @@
import platform
import subprocess
from typing import Tuple, Type
from .base import BaseAccessibilityHandler, BaseAutomationHandler, BaseFileHandler

from computer_server.diorama.base import BaseDioramaHandler

from .base import BaseAccessibilityHandler, BaseAutomationHandler, BaseFileHandler

# Conditionally import platform-specific handlers
system = platform.system().lower()
if system == 'darwin':
from .macos import MacOSAccessibilityHandler, MacOSAutomationHandler
if system == "darwin":
from computer_server.diorama.macos import MacOSDioramaHandler
elif system == 'linux':

from .macos import MacOSAccessibilityHandler, MacOSAutomationHandler
elif system == "linux":
from .linux import LinuxAccessibilityHandler, LinuxAutomationHandler
elif system == 'windows':
elif system == "windows":
from .windows import WindowsAccessibilityHandler, WindowsAutomationHandler

from .generic import GenericFileHandler

class HandlerFactory:
"""Factory for creating OS-specific handlers."""

@staticmethod
def _get_current_os() -> str:
"""Determine the current OS.

Returns:
str: The OS type ('darwin' for macOS, 'linux' for Linux, or 'windows' for Windows)

Raises:
RuntimeError: If unable to determine the current OS
"""
try:
# Use platform.system() as primary method
system = platform.system().lower()
if system in ['darwin', 'linux', 'windows']:
if system in ["darwin", "linux", "windows"]:
return system

# Fallback to uname if platform.system() doesn't return expected values (Unix-like systems only)
result = subprocess.run(['uname', '-s'], capture_output=True, text=True)
result = subprocess.run(["uname", "-s"], capture_output=True, text=True)
if result.returncode == 0:
return result.stdout.strip().lower()

raise RuntimeError(f"Unsupported OS: {system}")
except Exception as e:
raise RuntimeError(f"Failed to determine current OS: {str(e)}")

@staticmethod
def create_handlers() -> Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]:
def create_handlers() -> (
Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]
):
"""Create and return appropriate handlers for the current OS.

Returns:
Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]: A tuple containing
the appropriate accessibility, automation, diorama, and file handlers for the current OS.

Raises:
NotImplementedError: If the current OS is not supported
RuntimeError: If unable to determine the current OS
"""
os_type = HandlerFactory._get_current_os()

if os_type == 'darwin':
return MacOSAccessibilityHandler(), MacOSAutomationHandler(), MacOSDioramaHandler(), GenericFileHandler()
elif os_type == 'linux':
return LinuxAccessibilityHandler(), LinuxAutomationHandler(), BaseDioramaHandler(), GenericFileHandler()
elif os_type == 'windows':
return WindowsAccessibilityHandler(), WindowsAutomationHandler(), BaseDioramaHandler(), GenericFileHandler()

if os_type == "darwin":
return (
MacOSAccessibilityHandler(),
MacOSAutomationHandler(),
MacOSDioramaHandler(),
GenericFileHandler(),
)
elif os_type == "linux":
return (
LinuxAccessibilityHandler(),
LinuxAutomationHandler(),
BaseDioramaHandler(),
GenericFileHandler(),
)
elif os_type == "windows":
return (
WindowsAccessibilityHandler(),
WindowsAutomationHandler(),
BaseDioramaHandler(),
GenericFileHandler(),
)
else:
raise NotImplementedError(f"OS '{os_type}' is not supported")
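The factory is the one place the server asks for per-OS implementations, so call sites stay platform-agnostic. Usage sketch (the module path is assumed):

    from computer_server.handlers.factory import HandlerFactory

    accessibility, automation, diorama, files = HandlerFactory.create_handlers()
    # on Linux this yields LinuxAccessibilityHandler, LinuxAutomationHandler,
    # the no-op BaseDioramaHandler, and the shared GenericFileHandler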
@@ -6,38 +6,41 @@ Includes:

"""

from pathlib import Path
from typing import Dict, Any, Optional
from .base import BaseFileHandler
import base64
from pathlib import Path
from typing import Any, Dict, Optional

from .base import BaseFileHandler

def resolve_path(path: str) -> Path:
"""Resolve a path to its absolute path. Expand ~ to the user's home directory.

Args:
path: The file or directory path to resolve

Returns:
Path: The resolved absolute path
"""
return Path(path).expanduser().resolve()

class GenericFileHandler(BaseFileHandler):
"""
Generic file handler that provides file system operations for all operating systems.

This class implements the BaseFileHandler interface and provides methods for
file and directory operations including reading, writing, creating, and deleting
files and directories.
"""

async def file_exists(self, path: str) -> Dict[str, Any]:
"""
Check if a file exists at the specified path.

Args:
path: The file path to check

Returns:
Dict containing 'success' boolean and either 'exists' boolean or 'error' string
"""
@@ -49,10 +52,10 @@ class GenericFileHandler(BaseFileHandler):
async def directory_exists(self, path: str) -> Dict[str, Any]:
"""
Check if a directory exists at the specified path.

Args:
path: The directory path to check

Returns:
Dict containing 'success' boolean and either 'exists' boolean or 'error' string
"""
@@ -64,25 +67,30 @@ class GenericFileHandler(BaseFileHandler):
async def list_dir(self, path: str) -> Dict[str, Any]:
"""
List all files and directories in the specified directory.

Args:
path: The directory path to list

Returns:
Dict containing 'success' boolean and either 'files' list of names or 'error' string
"""
try:
return {"success": True, "files": [p.name for p in resolve_path(path).iterdir() if p.is_file() or p.is_dir()]}
return {
"success": True,
"files": [
p.name for p in resolve_path(path).iterdir() if p.is_file() or p.is_dir()
],
}
except Exception as e:
return {"success": False, "error": str(e)}

async def read_text(self, path: str) -> Dict[str, Any]:
"""
Read the contents of a text file.

Args:
path: The file path to read from

Returns:
Dict containing 'success' boolean and either 'content' string or 'error' string
"""
@@ -94,11 +102,11 @@ class GenericFileHandler(BaseFileHandler):
async def write_text(self, path: str, content: str) -> Dict[str, Any]:
"""
Write text content to a file.

Args:
path: The file path to write to
content: The text content to write

Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -108,60 +116,64 @@ class GenericFileHandler(BaseFileHandler):
except Exception as e:
return {"success": False, "error": str(e)}

async def write_bytes(self, path: str, content_b64: str, append: bool = False) -> Dict[str, Any]:
async def write_bytes(
self, path: str, content_b64: str, append: bool = False
) -> Dict[str, Any]:
"""
Write binary content to a file from base64 encoded string.

Args:
path: The file path to write to
content_b64: Base64 encoded binary content
append: If True, append to existing file; if False, overwrite

Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
try:
mode = 'ab' if append else 'wb'
mode = "ab" if append else "wb"
with open(resolve_path(path), mode) as f:
f.write(base64.b64decode(content_b64))
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> Dict[str, Any]:

async def read_bytes(
self, path: str, offset: int = 0, length: Optional[int] = None
) -> Dict[str, Any]:
"""
Read binary content from a file and return as base64 encoded string.

Args:
path: The file path to read from
offset: Byte offset to start reading from
length: Number of bytes to read; if None, read entire file from offset

Returns:
Dict containing 'success' boolean and either 'content_b64' string or 'error' string
"""
try:
file_path = resolve_path(path)
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
if offset > 0:
f.seek(offset)

if length is not None:
content = f.read(length)
else:
content = f.read()

return {"success": True, "content_b64": base64.b64encode(content).decode('utf-8')}

return {"success": True, "content_b64": base64.b64encode(content).decode("utf-8")}
except Exception as e:
return {"success": False, "error": str(e)}

async def get_file_size(self, path: str) -> Dict[str, Any]:
"""
Get the size of a file in bytes.

Args:
path: The file path to get size for

Returns:
Dict containing 'success' boolean and either 'size' integer or 'error' string
"""
@@ -175,10 +187,10 @@ class GenericFileHandler(BaseFileHandler):
async def delete_file(self, path: str) -> Dict[str, Any]:
"""
Delete a file at the specified path.

Args:
path: The file path to delete

Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -191,13 +203,13 @@ class GenericFileHandler(BaseFileHandler):
async def create_dir(self, path: str) -> Dict[str, Any]:
"""
Create a directory at the specified path.

Creates parent directories if they don't exist and doesn't raise an error
if the directory already exists.

Args:
path: The directory path to create

Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -210,10 +222,10 @@ class GenericFileHandler(BaseFileHandler):
async def delete_dir(self, path: str) -> Dict[str, Any]:
"""
Delete an empty directory at the specified path.

Args:
path: The directory path to delete

Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
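Binary payloads cross the websocket as base64 strings, so a client round trip through the handler looks like this (a sketch; the path and payload are illustrative):

    import asyncio
    import base64

    async def roundtrip(handler):
        payload = base64.b64encode(b"\x00\x01\x02hello").decode()
        await handler.write_bytes("~/demo.bin", payload)                 # overwrite
        resp = await handler.read_bytes("~/demo.bin", offset=3, length=5)
        assert base64.b64decode(resp["content_b64"]) == b"hello"

    asyncio.run(roundtrip(GenericFileHandler()))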
@@ -7,14 +7,15 @@ To use GUI automation in a headless environment:
1. Install Xvfb: sudo apt-get install xvfb
2. Run with virtual display: xvfb-run python -m computer_server
"""
from typing import Dict, Any, List, Tuple, Optional
import logging
import subprocess

import asyncio
import base64
import os
import json
import logging
import os
import subprocess
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

# Configure logger
logger = logging.getLogger(__name__)
@@ -23,30 +24,36 @@ logger = logging.getLogger(__name__)
# This allows the server to run in headless environments
try:
import pyautogui

pyautogui.FAILSAFE = False

logger.info("pyautogui successfully imported, GUI automation available")
except Exception as e:
logger.warning(f"pyautogui import failed: {str(e)}. GUI operations will be simulated.")

from pynput.mouse import Button, Controller as MouseController
from pynput.keyboard import Key, Controller as KeyboardController
from pynput.keyboard import Controller as KeyboardController
from pynput.keyboard import Key
from pynput.mouse import Button
from pynput.mouse import Controller as MouseController

from .base import BaseAccessibilityHandler, BaseAutomationHandler

class LinuxAccessibilityHandler(BaseAccessibilityHandler):
"""Linux implementation of accessibility handler."""

async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current window.

Returns:
Dict[str, Any]: A dictionary containing success status and a simulated tree structure
since Linux doesn't have equivalent accessibility API like macOS.
"""
# Linux doesn't have equivalent accessibility API like macOS
# Return a minimal dummy tree
logger.info("Getting accessibility tree (simulated, no accessibility API available on Linux)")
logger.info(
"Getting accessibility tree (simulated, no accessibility API available on Linux)"
)
return {
"success": True,
"tree": {
@@ -54,32 +61,31 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
"title": "Linux Window",
"position": {"x": 0, "y": 0},
"size": {"width": 1920, "height": 1080},
"children": []
}
"children": [],
},
}

async def find_element(self, role: Optional[str] = None,
title: Optional[str] = None,
value: Optional[str] = None) -> Dict[str, Any]:

async def find_element(
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an element in the accessibility tree by criteria.

Args:
role: The role of the element to find.
title: The title of the element to find.
value: The value of the element to find.

Returns:
Dict[str, Any]: A dictionary indicating that element search is not supported on Linux.
"""
logger.info(f"Finding element with role={role}, title={title}, value={value} (not supported on Linux)")
return {
"success": False,
"message": "Element search not supported on Linux"
}

logger.info(
f"Finding element with role={role}, title={title}, value={value} (not supported on Linux)"
)
return {"success": False, "message": "Element search not supported on Linux"}

def get_cursor_position(self) -> Tuple[int, int]:
"""Get the current cursor position.

Returns:
Tuple[int, int]: The x and y coordinates of the cursor position.
Returns (0, 0) if pyautogui is not available.
@@ -89,13 +95,13 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
return pos.x, pos.y
except Exception as e:
logger.warning(f"Failed to get cursor position with pyautogui: {e}")

logger.info("Getting cursor position (simulated)")
return 0, 0

def get_screen_size(self) -> Tuple[int, int]:
"""Get the screen size.

Returns:
Tuple[int, int]: The width and height of the screen in pixels.
Returns (1920, 1080) if pyautogui is not available.
@@ -105,24 +111,28 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
return size.width, size.height
except Exception as e:
logger.warning(f"Failed to get screen size with pyautogui: {e}")

logger.info("Getting screen size (simulated)")
return 1920, 1080

class LinuxAutomationHandler(BaseAutomationHandler):
"""Linux implementation of automation handler using pyautogui."""

keyboard = KeyboardController()
mouse = MouseController()

# Mouse Actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Press and hold a mouse button at the specified coordinates.

Args:
x: The x coordinate to move to before pressing. If None, uses current position.
y: The y coordinate to move to before pressing. If None, uses current position.
button: The mouse button to press ("left", "right", or "middle").

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -133,15 +143,17 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:

async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Release a mouse button at the specified coordinates.

Args:
x: The x coordinate to move to before releasing. If None, uses current position.
y: The y coordinate to move to before releasing. If None, uses current position.
button: The mouse button to release ("left", "right", or "middle").

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -152,14 +164,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
"""Move the cursor to the specified coordinates.

Args:
x: The x coordinate to move to.
y: The y coordinate to move to.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -171,11 +183,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):

async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left mouse click at the specified coordinates.

Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -189,11 +201,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):

async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a right mouse click at the specified coordinates.

Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -205,13 +217,15 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}

async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double click at the specified coordinates.

Args:
x: The x coordinate to double click at. If None, clicks at current position.
y: The y coordinate to double click at. If None, clicks at current position.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -223,14 +237,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}

async def click(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def click(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse click with the specified button at the given coordinates.

Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.
button: The mouse button to click ("left", "right", or "middle").

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -242,15 +258,17 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}

async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag from the current position to the specified coordinates.

Args:
x: The x coordinate to drag to.
y: The y coordinate to drag to.
button: The mouse button to use for dragging.
duration: The time in seconds to take for the drag operation.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -260,16 +278,18 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}

async def drag(self, start_x: int, start_y: int, end_x: int, end_y: int, button: str = "left") -> Dict[str, Any]:
async def drag(
self, start_x: int, start_y: int, end_x: int, end_y: int, button: str = "left"
) -> Dict[str, Any]:
"""Drag from start coordinates to end coordinates.

Args:
start_x: The starting x coordinate.
start_y: The starting y coordinate.
end_x: The ending x coordinate.
end_y: The ending y coordinate.
button: The mouse button to use for dragging.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -280,14 +300,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}

async def drag_path(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_path(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag along a path defined by a list of coordinates.

Args:
path: A list of (x, y) coordinate tuples defining the drag path.
button: The mouse button to use for dragging.
duration: The time in seconds to take for each segment of the drag.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -304,10 +326,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Keyboard Actions
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold a key.

Args:
key: The key to press down.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -316,13 +338,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def key_up(self, key: str) -> Dict[str, Any]:
"""Release a key.

Args:
key: The key to release.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -331,13 +353,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text using the keyboard.

Args:
text: The text to type.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -350,10 +372,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):

async def press_key(self, key: str) -> Dict[str, Any]:
"""Press and release a key.

Args:
key: The key to press.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -365,10 +387,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):

async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
"""Press a combination of keys simultaneously.

Args:
keys: A list of keys to press together as a hotkey combination.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -381,11 +403,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Scrolling Actions
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll the mouse wheel.

Args:
x: The horizontal scroll amount.
y: The vertical scroll amount.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -394,13 +416,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks.

Args:
clicks: The number of scroll clicks to perform downward.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -412,10 +434,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):

async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll up by the specified number of clicks.

Args:
clicks: The number of scroll clicks to perform upward.

Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -428,13 +450,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Screen Actions
async def screenshot(self) -> Dict[str, Any]:
"""Take a screenshot of the current screen.

Returns:
Dict[str, Any]: A dictionary containing success status and base64-encoded image data,
or error message if failed.
"""
try:
from PIL import Image

screenshot = pyautogui.screenshot()
if not isinstance(screenshot, Image.Image):
return {"success": False, "error": "Failed to capture screenshot"}
@@ -448,7 +471,7 @@ class LinuxAutomationHandler(BaseAutomationHandler):
|
||||
|
||||
async def get_screen_size(self) -> Dict[str, Any]:
|
||||
"""Get the size of the screen.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing success status and screen dimensions,
|
||||
or error message if failed.
|
||||
@@ -461,7 +484,7 @@ class LinuxAutomationHandler(BaseAutomationHandler):
|
||||
|
||||
async def get_cursor_position(self) -> Dict[str, Any]:
|
||||
"""Get the current position of the cursor.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing success status and cursor coordinates,
|
||||
or error message if failed.
|
||||
@@ -475,13 +498,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
|
||||
# Clipboard Actions
|
||||
async def copy_to_clipboard(self) -> Dict[str, Any]:
|
||||
"""Get the current content of the clipboard.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing success status and clipboard content,
|
||||
or error message if failed.
|
||||
"""
|
||||
try:
|
||||
import pyperclip
|
||||
|
||||
content = pyperclip.paste()
|
||||
return {"success": True, "content": content}
|
||||
except Exception as e:
|
||||
@@ -489,15 +513,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
|
||||
|
||||
async def set_clipboard(self, text: str) -> Dict[str, Any]:
|
||||
"""Set the clipboard content to the specified text.
|
||||
|
||||
|
||||
Args:
|
||||
text: The text to copy to the clipboard.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with success status and error message if failed.
|
||||
"""
|
||||
try:
|
||||
import pyperclip
|
||||
|
||||
pyperclip.copy(text)
|
||||
return {"success": True}
|
||||
except Exception as e:
|
||||
@@ -506,10 +531,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
|
||||
# Command Execution
|
||||
async def run_command(self, command: str) -> Dict[str, Any]:
|
||||
"""Execute a shell command asynchronously.
|
||||
|
||||
|
||||
Args:
|
||||
command: The shell command to execute.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing success status, stdout, stderr,
|
||||
and return code, or error message if failed.
|
||||
@@ -517,18 +542,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
|
||||
try:
|
||||
# Create subprocess
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
# Wait for the subprocess to finish
|
||||
stdout, stderr = await process.communicate()
|
||||
# Return decoded output
|
||||
return {
|
||||
"success": True,
|
||||
"stdout": stdout.decode() if stdout else "",
|
||||
"stderr": stderr.decode() if stderr else "",
|
||||
"return_code": process.returncode
|
||||
"success": True,
|
||||
"stdout": stdout.decode() if stdout else "",
|
||||
"stderr": stderr.decode() if stderr else "",
|
||||
"return_code": process.returncode,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
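
# A minimal usage sketch for the handler above. The import path is an
# assumption (only the {"success": bool, ...} result-dict shape is taken
# from the code in this hunk):
#
#   import asyncio
#   from computer_server.handlers.linux import LinuxAutomationHandler  # hypothetical path
#
#   async def main() -> None:
#       handler = LinuxAutomationHandler()
#       result = await handler.run_command("uname -a")
#       if result["success"]:
#           print(result["stdout"], result["return_code"])
#       else:
#           print("failed:", result["error"])
#
#   asyncio.run(main())
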
@@ -1,54 +1,57 @@
import pyautogui

pyautogui.FAILSAFE = False
from pynput.mouse import Button, Controller as MouseController
from pynput.keyboard import Key, Controller as KeyboardController
import time
import asyncio
import base64
import copy
import json
import logging
import re
import time
from ctypes import POINTER, byref, c_void_p
from io import BytesIO
from typing import Optional, Dict, Any, List, Tuple
from ctypes import byref, c_void_p, POINTER
from AppKit import NSWorkspace  # type: ignore
from typing import Any, Dict, List, Optional, Tuple

import AppKit
import Foundation
import objc
from AppKit import NSWorkspace  # type: ignore
from ApplicationServices import AXUIElementCopyAttributeValue  # type: ignore
from ApplicationServices import AXUIElementCopyAttributeValues  # type: ignore
from ApplicationServices import AXUIElementCreateApplication  # type: ignore
from ApplicationServices import AXUIElementCreateSystemWide  # type: ignore
from ApplicationServices import AXUIElementGetTypeID  # type: ignore
from ApplicationServices import AXValueGetType  # type: ignore
from ApplicationServices import AXValueGetValue  # type: ignore
from ApplicationServices import kAXChildrenAttribute  # type: ignore
from ApplicationServices import kAXDescriptionAttribute  # type: ignore
from ApplicationServices import kAXEnabledAttribute  # type: ignore
from ApplicationServices import kAXErrorSuccess  # type: ignore
from ApplicationServices import kAXFocusedApplicationAttribute  # type: ignore
from ApplicationServices import kAXFocusedUIElementAttribute  # type: ignore
from ApplicationServices import kAXFocusedWindowAttribute  # type: ignore
from ApplicationServices import kAXMainWindowAttribute  # type: ignore
from ApplicationServices import kAXPositionAttribute  # type: ignore
from ApplicationServices import kAXRoleAttribute  # type: ignore
from ApplicationServices import kAXRoleDescriptionAttribute  # type: ignore
from ApplicationServices import kAXSelectedTextAttribute  # type: ignore
from ApplicationServices import kAXSelectedTextRangeAttribute  # type: ignore
from ApplicationServices import kAXSizeAttribute  # type: ignore
from ApplicationServices import kAXTitleAttribute  # type: ignore
from ApplicationServices import kAXValueAttribute  # type: ignore
from ApplicationServices import kAXValueCFRangeType  # type: ignore
from ApplicationServices import kAXValueCGPointType  # type: ignore
from ApplicationServices import kAXValueCGSizeType  # type: ignore
from ApplicationServices import kAXVisibleChildrenAttribute  # type: ignore
from ApplicationServices import kAXWindowsAttribute  # type: ignore
from pynput.keyboard import Controller as KeyboardController
from pynput.keyboard import Key
from pynput.mouse import Button
from pynput.mouse import Controller as MouseController
from Quartz.CoreGraphics import *  # type: ignore
from Quartz.CoreGraphics import CGPoint, CGSize  # type: ignore
import Foundation
from ApplicationServices import (
    AXUIElementCreateSystemWide,  # type: ignore
    AXUIElementCreateApplication,  # type: ignore
    AXUIElementCopyAttributeValue,  # type: ignore
    AXUIElementCopyAttributeValues,  # type: ignore
    kAXFocusedWindowAttribute,  # type: ignore
    kAXWindowsAttribute,  # type: ignore
    kAXMainWindowAttribute,  # type: ignore
    kAXChildrenAttribute,  # type: ignore
    kAXRoleAttribute,  # type: ignore
    kAXTitleAttribute,  # type: ignore
    kAXValueAttribute,  # type: ignore
    kAXDescriptionAttribute,  # type: ignore
    kAXEnabledAttribute,  # type: ignore
    kAXPositionAttribute,  # type: ignore
    kAXSizeAttribute,  # type: ignore
    kAXErrorSuccess,  # type: ignore
    AXValueGetType,  # type: ignore
    kAXValueCGSizeType,  # type: ignore
    kAXValueCGPointType,  # type: ignore
    kAXValueCFRangeType,  # type: ignore
    AXUIElementGetTypeID,  # type: ignore
    AXValueGetValue,  # type: ignore
    kAXVisibleChildrenAttribute,  # type: ignore
    kAXRoleDescriptionAttribute,  # type: ignore
    kAXFocusedApplicationAttribute,  # type: ignore
    kAXFocusedUIElementAttribute,  # type: ignore
    kAXSelectedTextAttribute,  # type: ignore
    kAXSelectedTextRangeAttribute,  # type: ignore
)
import objc
import re
import json
import copy
import asyncio

from .base import BaseAccessibilityHandler, BaseAutomationHandler
import logging

logger = logging.getLogger(__name__)
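
# The rewrite above replaces the grouped ApplicationServices import with
# sorted one-per-line imports; this is the output shape isort produces with
# force_single_line enabled. A pyproject.toml sketch (assumed, not taken
# from this diff):
#
#   [tool.isort]
#   profile = "black"
#   force_single_line = true
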
@@ -73,24 +76,26 @@ kCGWindowAlpha = "kCGWindowAlpha"  # Window opacity
NSApplicationActivationOptions = {
    "regular": 0,  # Default activation
    "bringing_all_windows_forward": 1 << 0,  # NSApplicationActivateAllWindows
    "ignoring_other_apps": 1 << 1  # NSApplicationActivateIgnoringOtherApps
    "ignoring_other_apps": 1 << 1,  # NSApplicationActivateIgnoringOtherApps
}


def CFAttributeToPyObject(attrValue):
    """Convert Core Foundation attribute values to Python objects.

    Args:
        attrValue: Core Foundation attribute value to convert

    Returns:
        Converted Python object or None if conversion fails
    """

    def list_helper(list_value):
        """Helper function to convert CF arrays to Python lists.

        Args:
            list_value: Core Foundation array to convert

        Returns:
            Python list containing converted items
        """
@@ -101,10 +106,10 @@ def CFAttributeToPyObject(attrValue):
    def number_helper(number_value):
        """Helper function to convert CF numbers to Python numbers.

        Args:
            number_value: Core Foundation number to convert

        Returns:
            Python int or float, or None if conversion fails
        """
@@ -123,10 +128,10 @@ def CFAttributeToPyObject(attrValue):
    def axuielement_helper(element_value):
        """Helper function to handle AX UI elements.

        Args:
            element_value: Accessibility UI element to process

        Returns:
            The element value unchanged
        """
@@ -164,11 +169,11 @@ def CFAttributeToPyObject(attrValue):
def element_attribute(element, attribute):
    """Get an attribute value from an accessibility element.

    Args:
        element: The accessibility element
        attribute: The attribute name to retrieve

    Returns:
        The attribute value or None if not found
    """
@@ -190,11 +195,11 @@ def element_attribute(element, attribute):
def element_value(element, type):
    """Extract a typed value from an accessibility element.

    Args:
        element: The accessibility element containing the value
        type: The expected value type

    Returns:
        The extracted value or None if extraction fails
    """
@@ -206,10 +211,10 @@ def element_value(element, type):
class UIElement:
    """Represents a UI element in the accessibility tree with position, size, and hierarchy information."""

    def __init__(self, element, offset_x=0, offset_y=0, max_depth=None, parents_visible_bbox=None):
        """Initialize a UIElement from an accessibility element.

        Args:
            element: The accessibility element to wrap
            offset_x: X offset for position calculations
@@ -297,7 +302,7 @@ class UIElement:
    def _set_bboxes(self, parents_visible_bbox):
        """Set bounding box and visible bounding box for the element.

        Args:
            parents_visible_bbox: Parent's visible bounding box for intersection calculation
        """
@@ -332,13 +337,13 @@ class UIElement:
    def _get_children(self, element, start_position, offset_x, offset_y):
        """Get child elements from the accessibility element.

        Args:
            element: The parent accessibility element
            start_position: Starting position for offset calculations
            offset_x: X offset for child positioning
            offset_y: Y offset for child positioning

        Returns:
            List of UIElement children
        """
@@ -371,7 +376,7 @@ class UIElement:
    def component_hash(self):
        """Generate a hash identifier for this component based on its properties.

        Returns:
            MD5 hash string of component properties
        """
@@ -388,10 +393,10 @@ class UIElement:
    def hash_from_string(self, string):
        """Generate MD5 hash from a string.

        Args:
            string: Input string to hash

        Returns:
            MD5 hash hexdigest or empty string if input is None/empty
        """
@@ -403,10 +408,10 @@ class UIElement:
    def children_content_hash(self, children):
        """Generate a hash representing the content and structure of child elements.

        Args:
            children: List of child UIElement objects

        Returns:
            Combined hash of children content and structure
        """
@@ -426,16 +431,17 @@ class UIElement:
    def to_dict(self):
        """Convert the UIElement to a dictionary representation.

        Returns:
            Dictionary containing all element properties and children
        """

        def children_to_dict(children):
            """Convert list of children to dictionary format.

            Args:
                children: List of UIElement children to convert

            Returns:
                List of dictionaries representing the children
            """
@@ -464,7 +470,7 @@ class UIElement:
            size = f"{self.size.width:.0f};{self.size.height:.0f}"
        else:
            size = ""

        return {
            "id": self.identifier,
            "name": self.name,
@@ -482,36 +488,38 @@ class UIElement:
        }


import Quartz
from AppKit import NSWorkspace, NSRunningApplication
from pathlib import Path

import Quartz
from AppKit import NSRunningApplication, NSWorkspace


def get_all_windows_zorder():
    """Get all windows in the system with their z-order information.

    Returns:
        List of window dictionaries sorted by z-index, containing window properties
        like id, name, pid, owner, bounds, layer, and opacity
    """
    window_list = Quartz.CGWindowListCopyWindowInfo(
        Quartz.kCGWindowListOptionOnScreenOnly,
        Quartz.kCGNullWindowID
        Quartz.kCGWindowListOptionOnScreenOnly, Quartz.kCGNullWindowID
    )
    z_order = {window['kCGWindowNumber']: z_index for z_index, window in enumerate(window_list[::-1])}
    z_order = {
        window["kCGWindowNumber"]: z_index for z_index, window in enumerate(window_list[::-1])
    }
    window_list_all = Quartz.CGWindowListCopyWindowInfo(
        Quartz.kCGWindowListOptionAll,
        Quartz.kCGNullWindowID
        Quartz.kCGWindowListOptionAll, Quartz.kCGNullWindowID
    )
    windows = []
    for window in window_list_all:
        window_id = window.get('kCGWindowNumber', 0)
        window_name = window.get('kCGWindowName', '')
        window_pid = window.get('kCGWindowOwnerPID', 0)
        window_bounds = window.get('kCGWindowBounds', {})
        window_owner = window.get('kCGWindowOwnerName', '')
        window_is_on_screen = window.get('kCGWindowIsOnscreen', False)
        layer = window.get('kCGWindowLayer', 0)
        opacity = window.get('kCGWindowAlpha', 1.0)
        window_id = window.get("kCGWindowNumber", 0)
        window_name = window.get("kCGWindowName", "")
        window_pid = window.get("kCGWindowOwnerPID", 0)
        window_bounds = window.get("kCGWindowBounds", {})
        window_owner = window.get("kCGWindowOwnerName", "")
        window_is_on_screen = window.get("kCGWindowIsOnscreen", False)
        layer = window.get("kCGWindowLayer", 0)
        opacity = window.get("kCGWindowAlpha", 1.0)
        z_index = z_order.get(window_id, -1)
        if window_name == "Dock" and window_owner == "Dock":
            role = "dock"
@@ -522,32 +530,35 @@ def get_all_windows_zorder():
        else:
            role = "app"
        if window_bounds:
            windows.append({
                "id": window_id,
                "name": window_name or "Unnamed Window",
                "pid": window_pid,
                "owner": window_owner,
                "role": role,
                "is_on_screen": window_is_on_screen,
                "bounds": {
                    "x": window_bounds.get('X', 0),
                    "y": window_bounds.get('Y', 0),
                    "width": window_bounds.get('Width', 0),
                    "height": window_bounds.get('Height', 0)
                },
                "layer": layer,
                "z_index": z_index,
                "opacity": opacity
            })
            windows.append(
                {
                    "id": window_id,
                    "name": window_name or "Unnamed Window",
                    "pid": window_pid,
                    "owner": window_owner,
                    "role": role,
                    "is_on_screen": window_is_on_screen,
                    "bounds": {
                        "x": window_bounds.get("X", 0),
                        "y": window_bounds.get("Y", 0),
                        "width": window_bounds.get("Width", 0),
                        "height": window_bounds.get("Height", 0),
                    },
                    "layer": layer,
                    "z_index": z_index,
                    "opacity": opacity,
                }
            )
    windows = sorted(windows, key=lambda x: x["z_index"])
    return windows
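
# The z_order comprehension above relies on CGWindowListCopyWindowInfo
# returning on-screen windows front-to-back; enumerating the reversed list
# therefore assigns ascending z_index values to windows nearer the front.
# A toy illustration of the same comprehension over plain values:
#
#   window_list = ["front", "middle", "back"]  # stand-ins for window info dicts
#   z_order = {name: z for z, name in enumerate(window_list[::-1])}
#   assert z_order == {"back": 0, "middle": 1, "front": 2}
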
def get_app_info(app):
    """Extract information from an NSRunningApplication object.

    Args:
        app: NSRunningApplication instance

    Returns:
        Dictionary containing app name, bundle ID, PID, and status flags
    """
@@ -560,12 +571,13 @@ def get_app_info(app):
        "terminated": app.isTerminated(),
    }


def get_menubar_items(active_app_pid=None):
    """Get menubar items for the active application.

    Args:
        active_app_pid: Process ID of the active application, or None to use frontmost app

    Returns:
        List of menubar item dictionaries with title, bounds, index, and app_pid
    """
@@ -591,26 +603,24 @@ def get_menubar_items(active_app_pid=None):
            position_value = element_attribute(item, kAXPositionAttribute)
            if position_value:
                position_value = element_value(position_value, kAXValueCGPointType)
                bounds["x"] = getattr(position_value, 'x', 0)
                bounds["y"] = getattr(position_value, 'y', 0)
                bounds["x"] = getattr(position_value, "x", 0)
                bounds["y"] = getattr(position_value, "y", 0)
            size_value = element_attribute(item, kAXSizeAttribute)
            if size_value:
                size_value = element_value(size_value, kAXValueCGSizeType)
                bounds["width"] = getattr(size_value, 'width', 0)
                bounds["height"] = getattr(size_value, 'height', 0)
            menubar_items.append({
                "title": title,
                "bounds": bounds,
                "index": i,
                "app_pid": active_app_pid
            })
                bounds["width"] = getattr(size_value, "width", 0)
                bounds["height"] = getattr(size_value, "height", 0)
            menubar_items.append(
                {"title": title, "bounds": bounds, "index": i, "app_pid": active_app_pid}
            )
    return menubar_items


def get_dock_items():
    """Get all items in the macOS Dock.

    Returns:
        List of dock item dictionaries with title, description, bounds, index,
        type, role, and subrole information
    """
    dock_items = []
@@ -648,13 +658,13 @@ def get_dock_items():
            position_value = element_attribute(item, kAXPositionAttribute)
            if position_value:
                position_value = element_value(position_value, kAXValueCGPointType)
                bounds["x"] = getattr(position_value, 'x', 0)
                bounds["y"] = getattr(position_value, 'y', 0)
                bounds["x"] = getattr(position_value, "x", 0)
                bounds["y"] = getattr(position_value, "y", 0)
            size_value = element_attribute(item, kAXSizeAttribute)
            if size_value:
                size_value = element_value(size_value, kAXValueCGSizeType)
                bounds["width"] = getattr(size_value, 'width', 0)
                bounds["height"] = getattr(size_value, 'height', 0)
                bounds["width"] = getattr(size_value, "width", 0)
                bounds["height"] = getattr(size_value, "height", 0)
            item_type = "unknown"
            if subrole == "AXApplicationDockItem":
                item_type = "application"
@@ -666,23 +676,26 @@ def get_dock_items():
                item_type = "separator"
            elif "trash" in title.lower():
                item_type = "trash"
            dock_items.append({
                "title": title,
                "description": description,
                "bounds": bounds,
                "index": i,
                "type": item_type,
                "role": role,
                "subrole": subrole
            })
            dock_items.append(
                {
                    "title": title,
                    "description": description,
                    "bounds": bounds,
                    "index": i,
                    "type": item_type,
                    "role": role,
                    "subrole": subrole,
                }
            )
    return dock_items


class MacOSAccessibilityHandler(BaseAccessibilityHandler):
    """Handler for macOS accessibility features and UI element inspection."""

    def get_desktop_state(self):
        """Get the current state of the desktop including windows, apps, menubar, and dock.

        Returns:
            Dictionary containing applications, windows, menubar_items, and dock_items
        """
@@ -696,7 +709,9 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
            pid = app.processIdentifier()
            try:
                app_elem = AXUIElementCreateApplication(pid)
                err, app_windows = AXUIElementCopyAttributeValue(app_elem, kAXWindowsAttribute, None)
                err, app_windows = AXUIElementCopyAttributeValue(
                    app_elem, kAXWindowsAttribute, None
                )
                trees = []
                if err == kAXErrorSuccess and app_windows:
                    for ax_win in app_windows:
@@ -713,31 +728,32 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
            pid = win["pid"]
            idx = pid_to_idx.get(pid, 0)
            ax_trees = pid_to_ax_trees.get(pid, [])
            win["children"] = ax_trees[idx]["children"] if idx < len(ax_trees) and "children" in ax_trees[idx] else []
            win["children"] = (
                ax_trees[idx]["children"]
                if idx < len(ax_trees) and "children" in ax_trees[idx]
                else []
            )
            pid_to_idx[pid] = idx + 1
            pid_to_window_ids.setdefault(pid, []).append(win["id"])
        for app in running_apps:
            info = get_app_info(app)
            app_pid = info["pid"]
            applications.append({
                "info": info,
                "windows": pid_to_window_ids.get(app_pid, [])
            })
            applications.append({"info": info, "windows": pid_to_window_ids.get(app_pid, [])})
        menubar_items = get_menubar_items()
        dock_items = get_dock_items()
        return {
            "applications": applications,
            "windows": windows,
            "menubar_items": menubar_items,
            "dock_items": dock_items
            "dock_items": dock_items,
        }

    def get_application_windows(self, pid: int):
        """Get all windows for a specific application.

        Args:
            pid: Process ID of the application

        Returns:
            List of accessibility window elements or empty list if none found
        """
@@ -753,7 +769,7 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
    def get_all_windows(self):
        """Get all visible windows in the system.

        Returns:
            List of window dictionaries with app information and window details
        """
@@ -791,7 +807,7 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
    def get_running_apps(self):
        """Get all currently running applications.

        Returns:
            List of NSRunningApplication objects
        """
@@ -803,11 +819,11 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
    def get_ax_attribute(self, element, attribute):
        """Get an accessibility attribute from an element.

        Args:
            element: The accessibility element
            attribute: The attribute name to retrieve

        Returns:
            The attribute value or None if not found
        """
@@ -815,10 +831,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
    def serialize_node(self, element):
        """Create a serializable dictionary representation of an accessibility element.

        Args:
            element: The accessibility element to serialize

        Returns:
            Dictionary containing element properties like role, title, value, position, and size
        """
@@ -851,16 +867,13 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
    async def get_accessibility_tree(self) -> Dict[str, Any]:
        """Get the complete accessibility tree for the current desktop state.

        Returns:
            Dictionary containing success status and desktop state information
        """
        try:
            desktop_state = self.get_desktop_state()
            return {
                "success": True,
                **desktop_state
            }
            return {"success": True, **desktop_state}

        except Exception as e:
            return {"success": False, "error": str(e)}
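
# Toy illustration of the per-pid index matching used in get_desktop_state
# above: the i-th CG window reported for a pid is paired with the i-th AX
# window tree fetched for that same pid.
#
#   pid_to_ax_trees = {42: [{"children": ["a"]}, {"children": ["b"]}]}
#   pid_to_idx = {}
#   for win in [{"pid": 42}, {"pid": 42}]:
#       idx = pid_to_idx.get(win["pid"], 0)
#       trees = pid_to_ax_trees.get(win["pid"], [])
#       win["children"] = trees[idx]["children"] if idx < len(trees) else []
#       pid_to_idx[win["pid"]] = idx + 1
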
@@ -869,12 +882,12 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
        self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
    ) -> Dict[str, Any]:
        """Find an accessibility element matching the specified criteria.

        Args:
            role: The accessibility role to match (optional)
            title: The title to match (optional)
            value: The value to match (optional)

        Returns:
            Dictionary containing success status and the found element or error message
        """
@@ -883,10 +896,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
        def match_element(element):
            """Check if an element matches the search criteria.

            Args:
                element: The accessibility element to check

            Returns:
                True if element matches all specified criteria, False otherwise
            """
@@ -900,10 +913,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
        def search_tree(element):
            """Recursively search the accessibility tree for matching elements.

            Args:
                element: The accessibility element to search from

            Returns:
                Serialized element dictionary if match found, None otherwise
            """
@@ -924,58 +937,71 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
        except Exception as e:
            return {"success": False, "error": str(e)}
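
# find_element above pairs a predicate (match_element) with a depth-first
# walk (search_tree). The shape of that walk, sketched over plain dicts:
#
#   def search_tree(node, match):
#       if match(node):
#           return node
#       for child in node.get("children", []):
#           found = search_tree(child, match)
#           if found is not None:
#               return found
#       return None
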
class MacOSAutomationHandler(BaseAutomationHandler):
    """Handler for macOS automation including mouse, keyboard, and screen operations."""

    # Mouse Actions
    mouse = MouseController()
    keyboard = KeyboardController()

    async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
    async def mouse_down(
        self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
    ) -> Dict[str, Any]:
        """Press and hold a mouse button at the specified coordinates.

        Args:
            x: X coordinate (optional, uses current position if None)
            y: Y coordinate (optional, uses current position if None)
            button: Mouse button to press ("left", "right", or "middle")

        Returns:
            Dictionary containing success status and error message if failed
        """
        try:
            if x is not None and y is not None:
                self.mouse.position = (x, y)
            self.mouse.press(Button.left if button == "left" else Button.right if button == "right" else Button.middle)
            self.mouse.press(
                Button.left
                if button == "left"
                else Button.right if button == "right" else Button.middle
            )
            return {"success": True}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
    async def mouse_up(
        self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
    ) -> Dict[str, Any]:
        """Release a mouse button at the specified coordinates.

        Args:
            x: X coordinate (optional, uses current position if None)
            y: Y coordinate (optional, uses current position if None)
            button: Mouse button to release ("left", "right", or "middle")

        Returns:
            Dictionary containing success status and error message if failed
        """
        try:
            if x is not None and y is not None:
                self.mouse.position = (x, y)
            self.mouse.release(Button.left if button == "left" else Button.right if button == "right" else Button.middle)
            self.mouse.release(
                Button.left
                if button == "left"
                else Button.right if button == "right" else Button.middle
            )
            return {"success": True}
        except Exception as e:
            return {"success": False, "error": str(e)}
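
# The nested conditional black reformats above is just a string-to-Button
# lookup; an equivalent helper (hypothetical, not in this diff) makes the
# mapping explicit:
#
#   from pynput.mouse import Button
#
#   _BUTTONS = {"left": Button.left, "right": Button.right, "middle": Button.middle}
#
#   def to_button(name: str) -> Button:
#       # Unknown names fall back to middle, matching the ternary's else-branch.
#       return _BUTTONS.get(name, Button.middle)
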
    async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
        """Perform a left mouse click at the specified coordinates.

        Args:
            x: X coordinate (optional, uses current position if None)
            y: Y coordinate (optional, uses current position if None)

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -989,11 +1015,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
        """Perform a right mouse click at the specified coordinates.

        Args:
            x: X coordinate (optional, uses current position if None)
            y: Y coordinate (optional, uses current position if None)

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1009,11 +1035,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
        self, x: Optional[int] = None, y: Optional[int] = None
    ) -> Dict[str, Any]:
        """Perform a double left mouse click at the specified coordinates.

        Args:
            x: X coordinate (optional, uses current position if None)
            y: Y coordinate (optional, uses current position if None)

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1027,11 +1053,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
        """Move the mouse cursor to the specified coordinates.

        Args:
            x: Target X coordinate
            y: Target Y coordinate

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1045,18 +1071,22 @@ class MacOSAutomationHandler(BaseAutomationHandler):
        self, x: int, y: int, button: str = "left", duration: float = 0.5
    ) -> Dict[str, Any]:
        """Drag from current position to target coordinates.

        Args:
            x: Target X coordinate
            y: Target Y coordinate
            button: Mouse button to use for dragging ("left", "right", or "middle")
            duration: Duration of the drag operation in seconds

        Returns:
            Dictionary containing success status and error message if failed
        """
        try:
            btn = Button.left if button == "left" else Button.right if button == "right" else Button.middle
            btn = (
                Button.left
                if button == "left"
                else Button.right if button == "right" else Button.middle
            )
            # Press
            self.mouse.press(btn)
            # Move with sleep to simulate drag duration
@@ -1082,19 +1112,23 @@ class MacOSAutomationHandler(BaseAutomationHandler):
        self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
    ) -> Dict[str, Any]:
        """Drag the mouse along a specified path of coordinates.

        Args:
            path: List of (x, y) coordinate tuples defining the drag path
            button: Mouse button to use for dragging ("left", "right", or "middle")
            duration: Total duration of the drag operation in seconds

        Returns:
            Dictionary containing success status and error message if failed
        """
        try:
            if not path or len(path) < 2:
                return {"success": False, "error": "Path must contain at least 2 points"}
            btn = Button.left if button == "left" else Button.right if button == "right" else Button.middle
            btn = (
                Button.left
                if button == "left"
                else Button.right if button == "right" else Button.middle
            )
            # Move to the first point
            self.mouse.position = path[0]
            self.mouse.press(btn)
@@ -1114,10 +1148,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    # Keyboard Actions
    async def key_down(self, key: str) -> Dict[str, Any]:
        """Press and hold a keyboard key.

        Args:
            key: Key name to press (using pyautogui key names)

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1127,13 +1161,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
            return {"success": True}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def key_up(self, key: str) -> Dict[str, Any]:
        """Release a keyboard key.

        Args:
            key: Key name to release (using pyautogui key names)

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1143,13 +1177,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
            return {"success": True}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def type_text(self, text: str) -> Dict[str, Any]:
        """Type text using the keyboard with Unicode support.

        Args:
            text: Text string to type

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1162,10 +1196,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def press_key(self, key: str) -> Dict[str, Any]:
        """Press and release a keyboard key.

        Args:
            key: Key name to press (using pyautogui key names)

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1178,10 +1212,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
        """Press a combination of keys simultaneously.

        Args:
            keys: List of key names to press together (using pyautogui key names)

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1195,11 +1229,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    # Scrolling Actions
    async def scroll(self, x: int, y: int) -> Dict[str, Any]:
        """Scroll the mouse wheel in the specified direction.

        Args:
            x: Horizontal scroll amount
            y: Vertical scroll amount (positive for up, negative for down)

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1208,13 +1242,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
            return {"success": True}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
        """Scroll down by the specified number of clicks.

        Args:
            clicks: Number of scroll clicks to perform

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1226,10 +1260,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
        """Scroll up by the specified number of clicks.

        Args:
            clicks: Number of scroll clicks to perform

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1242,7 +1276,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    # Screen Actions
    async def screenshot(self) -> Dict[str, Any]:
        """Capture a screenshot of the current screen.

        Returns:
            Dictionary containing success status and base64-encoded image data or error message
        """
@@ -1263,7 +1297,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def get_screen_size(self) -> Dict[str, Any]:
        """Get the dimensions of the current screen.

        Returns:
            Dictionary containing success status and screen size or error message
        """
@@ -1275,7 +1309,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def get_cursor_position(self) -> Dict[str, Any]:
        """Get the current position of the mouse cursor.

        Returns:
            Dictionary containing success status and cursor position or error message
        """
@@ -1288,7 +1322,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    # Clipboard Actions
    async def copy_to_clipboard(self) -> Dict[str, Any]:
        """Get the current content of the system clipboard.

        Returns:
            Dictionary containing success status and clipboard content or error message
        """
@@ -1302,10 +1336,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def set_clipboard(self, text: str) -> Dict[str, Any]:
        """Set the content of the system clipboard.

        Args:
            text: Text to copy to the clipboard

        Returns:
            Dictionary containing success status and error message if failed
        """
@@ -1319,28 +1353,26 @@ class MacOSAutomationHandler(BaseAutomationHandler):
    async def run_command(self, command: str) -> Dict[str, Any]:
        """Run a shell command and return its output.

        Args:
            command: Shell command to execute

        Returns:
            Dictionary containing success status, stdout, stderr, and return code
        """
        try:
            # Create subprocess
            process = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
                command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            )
            # Wait for the subprocess to finish
            stdout, stderr = await process.communicate()
            # Return decoded output
            return {
                "success": True,
                "stdout": stdout.decode() if stdout else "",
                "stderr": stderr.decode() if stderr else "",
                "return_code": process.returncode
                "return_code": process.returncode,
            }
        except Exception as e:
            return {"success": False, "error": str(e)}
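
# Both platform handlers share the same non-blocking subprocess pattern.
# Standalone, it looks like this (standard asyncio API, not specific to
# this codebase):
#
#   import asyncio
#
#   async def sh(cmd: str):
#       proc = await asyncio.create_subprocess_shell(
#           cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
#       )
#       out, err = await proc.communicate()
#       return out.decode(), err.decode(), proc.returncode
#
#   print(asyncio.run(sh("echo hello")))
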
@@ -4,15 +4,17 @@ Windows implementation of automation and accessibility handlers.
|
||||
This implementation uses pyautogui for GUI automation and Windows-specific APIs
|
||||
for accessibility and system operations.
|
||||
"""
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
from pynput.mouse import Controller as MouseController
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pynput.keyboard import Controller as KeyboardController
|
||||
from pynput.mouse import Controller as MouseController
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,6 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
# Try to import pyautogui
|
||||
try:
|
||||
import pyautogui
|
||||
|
||||
pyautogui.FAILSAFE = False
|
||||
logger.info("pyautogui successfully imported, GUI automation available")
|
||||
except Exception as e:
|
||||
@@ -28,58 +31,62 @@ except Exception as e:
|
||||
|
||||
# Try to import Windows-specific modules
|
||||
try:
|
||||
import win32gui
|
||||
import win32con
|
||||
import win32api
|
||||
import win32con
|
||||
import win32gui
|
||||
|
||||
logger.info("Windows API modules successfully imported")
|
||||
WINDOWS_API_AVAILABLE = True
|
||||
except Exception as e:
|
||||
logger.error(f"Windows API modules import failed: {str(e)}. Some Windows-specific features will be unavailable.")
|
||||
logger.error(
|
||||
f"Windows API modules import failed: {str(e)}. Some Windows-specific features will be unavailable."
|
||||
)
|
||||
WINDOWS_API_AVAILABLE = False
|
||||
|
||||
from .base import BaseAccessibilityHandler, BaseAutomationHandler
|
||||
|
||||
|
||||
class WindowsAccessibilityHandler(BaseAccessibilityHandler):
|
||||
"""Windows implementation of accessibility handler."""
|
||||
|
||||
|
||||
async def get_accessibility_tree(self) -> Dict[str, Any]:
|
||||
"""Get the accessibility tree of the current window.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing the success status and either
|
||||
the accessibility tree or an error message.
|
||||
Structure: {"success": bool, "tree": dict} or
|
||||
Structure: {"success": bool, "tree": dict} or
|
||||
{"success": bool, "error": str}
|
||||
"""
|
||||
if not WINDOWS_API_AVAILABLE:
|
||||
return {"success": False, "error": "Windows API not available"}
|
||||
|
||||
|
||||
try:
|
||||
# Get the foreground window
|
||||
hwnd = win32gui.GetForegroundWindow()
|
||||
if not hwnd:
|
||||
return {"success": False, "error": "No foreground window found"}
|
||||
|
||||
|
||||
# Get window information
|
||||
window_text = win32gui.GetWindowText(hwnd)
|
||||
rect = win32gui.GetWindowRect(hwnd)
|
||||
|
||||
|
||||
tree = {
|
||||
"role": "Window",
|
||||
"title": window_text,
|
||||
"position": {"x": rect[0], "y": rect[1]},
|
||||
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
|
||||
"children": []
|
||||
"children": [],
|
||||
}
|
||||
|
||||
|
||||
# Enumerate child windows
|
||||
def enum_child_proc(hwnd_child, children_list):
|
||||
"""Callback function to enumerate child windows and collect their information.
|
||||
|
||||
|
||||
Args:
|
||||
hwnd_child: Handle to the child window being enumerated.
|
||||
children_list: List to append child window information to.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True to continue enumeration, False to stop.
|
||||
"""
|
||||
@@ -87,46 +94,49 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
|
||||
child_text = win32gui.GetWindowText(hwnd_child)
|
||||
child_rect = win32gui.GetWindowRect(hwnd_child)
|
||||
child_class = win32gui.GetClassName(hwnd_child)
|
||||
|
||||
|
||||
child_info = {
|
||||
"role": child_class,
|
||||
"title": child_text,
|
||||
"position": {"x": child_rect[0], "y": child_rect[1]},
|
||||
"size": {"width": child_rect[2] - child_rect[0], "height": child_rect[3] - child_rect[1]},
|
||||
"children": []
|
||||
"size": {
|
||||
"width": child_rect[2] - child_rect[0],
|
||||
"height": child_rect[3] - child_rect[1],
|
||||
},
|
||||
"children": [],
|
||||
}
|
||||
children_list.append(child_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting child window info: {e}")
|
||||
return True
|
||||
|
||||
|
||||
win32gui.EnumChildWindows(hwnd, enum_child_proc, tree["children"])
|
||||
|
||||
|
||||
return {"success": True, "tree": tree}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting accessibility tree: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def find_element(self, role: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
value: Optional[str] = None) -> Dict[str, Any]:
|
||||
|
||||
async def find_element(
|
||||
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Find an element in the accessibility tree by criteria.
|
||||
|
||||
|
||||
Args:
|
||||
role (Optional[str]): The role or class name of the element to find.
|
||||
title (Optional[str]): The title or text of the element to find.
|
||||
value (Optional[str]): The value of the element (not used in Windows implementation).
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing the success status and either
|
||||
the found element or an error message.
|
||||
Structure: {"success": bool, "element": dict} or
|
||||
Structure: {"success": bool, "element": dict} or
|
||||
{"success": bool, "error": str}
|
||||
"""
|
||||
if not WINDOWS_API_AVAILABLE:
|
||||
return {"success": False, "error": "Windows API not available"}
|
||||
|
||||
|
||||
try:
|
||||
# Find window by title if specified
|
||||
if title:
|
||||
@@ -139,10 +149,10 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
|
||||
"role": "Window",
|
||||
"title": title,
|
||||
"position": {"x": rect[0], "y": rect[1]},
|
||||
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]}
|
||||
}
|
||||
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Find window by class name if role is specified
|
||||
if role:
|
||||
hwnd = win32gui.FindWindow(role, None)
|
||||
@@ -155,36 +165,40 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
|
||||
"role": role,
|
||||
"title": window_text,
|
||||
"position": {"x": rect[0], "y": rect[1]},
|
||||
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]}
|
||||
}
|
||||
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
return {"success": False, "error": "Element not found"}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding element: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
class WindowsAutomationHandler(BaseAutomationHandler):
|
||||
"""Windows implementation of automation handler using pyautogui and Windows APIs."""
|
||||
|
||||
|
||||
mouse = MouseController()
|
||||
keyboard = KeyboardController()
|
||||
|
||||
# Mouse Actions
|
||||
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
|
||||
async def mouse_down(
|
||||
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
|
||||
) -> Dict[str, Any]:
|
||||
"""Press and hold a mouse button at the specified coordinates.
|
||||
|
||||
|
||||
Args:
|
||||
x (Optional[int]): The x-coordinate to move to before pressing. If None, uses current position.
|
||||
y (Optional[int]): The y-coordinate to move to before pressing. If None, uses current position.
|
||||
button (str): The mouse button to press ("left", "right", or "middle").
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with success status and optional error message.
|
||||
"""
|
||||
if not pyautogui:
|
||||
return {"success": False, "error": "pyautogui not available"}
|
||||
|
||||
|
||||
try:
|
||||
if x is not None and y is not None:
|
||||
pyautogui.moveTo(x, y)
|
||||
@@ -192,21 +206,23 @@ class WindowsAutomationHandler(BaseAutomationHandler):
|
||||
return {"success": True}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
|
||||
|
||||
async def mouse_up(
|
||||
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
|
||||
) -> Dict[str, Any]:
|
||||
"""Release a mouse button at the specified coordinates.
|
||||
|
||||
|
||||
Args:
|
||||
x (Optional[int]): The x-coordinate to move to before releasing. If None, uses current position.
|
||||
y (Optional[int]): The y-coordinate to move to before releasing. If None, uses current position.
|
||||
button (str): The mouse button to release ("left", "right", or "middle").
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with success status and optional error message.
|
||||
"""
|
||||
if not pyautogui:
|
||||
return {"success": False, "error": "pyautogui not available"}
|
||||
|
||||
|
||||
try:
|
||||
if x is not None and y is not None:
|
||||
pyautogui.moveTo(x, y)
|
||||
@@ -214,20 +230,20 @@ class WindowsAutomationHandler(BaseAutomationHandler):
|
||||
return {"success": True}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
|
||||
"""Move the mouse cursor to the specified coordinates.
|
||||
|
||||
|
||||
Args:
|
||||
x (int): The x-coordinate to move to.
|
||||
y (int): The y-coordinate to move to.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with success status and optional error message.
|
||||
"""
|
||||
if not pyautogui:
|
||||
return {"success": False, "error": "pyautogui not available"}
|
||||
|
||||
|
||||
try:
|
||||
pyautogui.moveTo(x, y)
|
||||
return {"success": True}
|
||||
@@ -236,17 +252,17 @@ class WindowsAutomationHandler(BaseAutomationHandler):
|
||||
|
||||
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Perform a left mouse click at the specified coordinates.
|
||||
|
||||
|
||||
Args:
|
||||
x (Optional[int]): The x-coordinate to click at. If None, clicks at current position.
|
||||
y (Optional[int]): The y-coordinate to click at. If None, clicks at current position.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with success status and optional error message.
|
||||
"""
|
||||
if not pyautogui:
|
||||
return {"success": False, "error": "pyautogui not available"}
|
||||
|
||||
|
||||
try:
|
||||
if x is not None and y is not None:
|
||||
pyautogui.moveTo(x, y)
|
||||
@@ -257,17 +273,17 @@ class WindowsAutomationHandler(BaseAutomationHandler):
|
||||
|
||||
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Perform a right mouse click at the specified coordinates.
|
||||
|
||||
|
||||
Args:
|
||||
x (Optional[int]): The x-coordinate to click at. If None, clicks at current position.
|
||||
y (Optional[int]): The y-coordinate to click at. If None, clicks at current position.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -276,19 +292,21 @@ class WindowsAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}

async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double left mouse click at the specified coordinates.

Args:
x (Optional[int]): The x-coordinate to double-click at. If None, clicks at current position.
y (Optional[int]): The y-coordinate to double-click at. If None, clicks at current position.

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -297,52 +315,56 @@ class WindowsAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}

async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag from the current position to the specified coordinates.

Args:
x (int): The x-coordinate to drag to.
y (int): The y-coordinate to drag to.
button (str): The mouse button to use for dragging ("left", "right", or "middle").
duration (float): The time in seconds to take for the drag operation.

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
pyautogui.dragTo(x, y, duration=duration, button=button)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the mouse through a series of coordinates.

Args:
path (List[Tuple[int, int]]): A list of (x, y) coordinate tuples to drag through.
button (str): The mouse button to use for dragging ("left", "right", or "middle").
duration (float): The total time in seconds for the entire drag operation.

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
if not path:
return {"success": False, "error": "Path is empty"}

# Move to first position
pyautogui.moveTo(*path[0])

# Drag through all positions
for x, y in path[1:]:
pyautogui.dragTo(x, y, duration=duration/len(path), button=button)

pyautogui.dragTo(x, y, duration=duration / len(path), button=button)

return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
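
(Editor's illustration, not part of the diff.) The `drag` method above splits the total `duration` evenly by dividing by `len(path)`, even though only `len(path) - 1` segments are actually dragged, so the total drag time comes out slightly under `duration`. A minimal sketch of that arithmetic:

# Illustration only: timing behaviour of drag() above.
path = [(0, 0), (50, 50), (100, 100), (150, 150)]
duration = 0.5
per_segment = duration / len(path)            # 0.125s per dragTo call
actual_total = per_segment * (len(path) - 1)  # 0.375s for the 3 dragged segments
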
@@ -350,70 +372,68 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Keyboard Actions
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold a keyboard key.

Args:
key (str): The key to press down (e.g., 'ctrl', 'shift', 'a').

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
pyautogui.keyDown(key)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def key_up(self, key: str) -> Dict[str, Any]:
"""Release a keyboard key.

Args:
key (str): The key to release (e.g., 'ctrl', 'shift', 'a').

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
pyautogui.keyUp(key)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text.

Args:
text (str): The text to type.

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
pyautogui.write(text)
# use pynput for Unicode support
self.keyboard.type(text)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def press_key(self, key: str) -> Dict[str, Any]:
"""Press and release a keyboard key.

Args:
key (str): The key to press (e.g., 'enter', 'space', 'tab').

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
pyautogui.press(key)
return {"success": True}
@@ -422,16 +442,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):

async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
"""Press a combination of keys simultaneously.

Args:
keys (List[str]): The keys to press together (e.g., ['ctrl', 'c'], ['alt', 'tab']).

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
pyautogui.hotkey(*keys)
return {"success": True}
@@ -441,35 +461,35 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Scrolling Actions
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll vertically at the current cursor position.

Args:
x (int): Horizontal scroll amount (not used in pyautogui implementation).
y (int): Vertical scroll amount. Positive values scroll up, negative values scroll down.

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
self.mouse.scroll(x, y)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks.

Args:
clicks (int): The number of scroll clicks to perform downward.

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
pyautogui.scroll(-clicks)
return {"success": True}
@@ -478,16 +498,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):

async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll up by the specified number of clicks.

Args:
clicks (int): The number of scroll clicks to perform upward.

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
pyautogui.scroll(clicks)
return {"success": True}
@@ -497,22 +517,23 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Screen Actions
async def screenshot(self) -> Dict[str, Any]:
"""Capture a screenshot of the entire screen.

Returns:
Dict[str, Any]: A dictionary containing the success status and either
base64-encoded image data or an error message.
Structure: {"success": bool, "image_data": str} or
Structure: {"success": bool, "image_data": str} or
{"success": bool, "error": str}
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}

try:
from PIL import Image

screenshot = pyautogui.screenshot()
if not isinstance(screenshot, Image.Image):
return {"success": False, "error": "Failed to capture screenshot"}

buffered = BytesIO()
screenshot.save(buffered, format="PNG", optimize=True)
buffered.seek(0)
@@ -523,11 +544,11 @@ class WindowsAutomationHandler(BaseAutomationHandler):

async def get_screen_size(self) -> Dict[str, Any]:
"""Get the size of the screen in pixels.

Returns:
Dict[str, Any]: A dictionary containing the success status and either
screen size information or an error message.
Structure: {"success": bool, "size": {"width": int, "height": int}} or
Structure: {"success": bool, "size": {"width": int, "height": int}} or
{"success": bool, "error": str}
"""
try:
@@ -546,11 +567,11 @@ class WindowsAutomationHandler(BaseAutomationHandler):

async def get_cursor_position(self) -> Dict[str, Any]:
"""Get the current position of the mouse cursor.

Returns:
Dict[str, Any]: A dictionary containing the success status and either
cursor position or an error message.
Structure: {"success": bool, "position": {"x": int, "y": int}} or
Structure: {"success": bool, "position": {"x": int, "y": int}} or
{"success": bool, "error": str}
"""
try:
@@ -569,15 +590,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Clipboard Actions
async def copy_to_clipboard(self) -> Dict[str, Any]:
"""Get the current content of the clipboard.

Returns:
Dict[str, Any]: A dictionary containing the success status and either
clipboard content or an error message.
Structure: {"success": bool, "content": str} or
Structure: {"success": bool, "content": str} or
{"success": bool, "error": str}
"""
try:
import pyperclip

content = pyperclip.paste()
return {"success": True, "content": content}
except Exception as e:
@@ -585,15 +607,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):

async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the clipboard content to the specified text.

Args:
text (str): The text to copy to the clipboard.

Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
try:
import pyperclip

pyperclip.copy(text)
return {"success": True}
except Exception as e:
@@ -602,31 +625,29 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Command Execution
async def run_command(self, command: str) -> Dict[str, Any]:
"""Execute a shell command asynchronously.

Args:
command (str): The shell command to execute.

Returns:
Dict[str, Any]: A dictionary containing the success status and either
command output or an error message.
Structure: {"success": bool, "stdout": str, "stderr": str, "return_code": int} or
Structure: {"success": bool, "stdout": str, "stderr": str, "return_code": int} or
{"success": bool, "error": str}
"""
try:
# Create subprocess
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
# Wait for the subprocess to finish
stdout, stderr = await process.communicate()
# Return decoded output
return {
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode,
}
except Exception as e:
return {"success": False, "error": str(e)}
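
(Editor's illustration, not part of the diff.) A minimal usage sketch for `run_command` above; `handler` is a hypothetical instance of the automation handler:

# Illustration only; `handler` is a placeholder name.
import asyncio

async def demo():
    result = await handler.run_command("echo hello")
    if result["success"]:
        print(result["stdout"].strip(), result["return_code"])  # hello 0

asyncio.run(demo())
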
@@ -1,27 +1,37 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Header
from fastapi.responses import StreamingResponse, JSONResponse
from typing import List, Dict, Any, Optional, Union, Literal, cast
import uvicorn
import logging
import asyncio
import json
import traceback
import inspect
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp
import hashlib
import time
import inspect
import json
import logging
import os
import platform
import time
import traceback
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Dict, List, Literal, Optional, Union, cast

import aiohttp
import uvicorn
from fastapi import (
FastAPI,
Header,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

from .handlers.factory import HandlerFactory

# Authentication session TTL (in seconds). Override via env var CUA_AUTH_TTL_SECONDS. Default: 60s
AUTH_SESSION_TTL_SECONDS: int = int(os.environ.get("CUA_AUTH_TTL_SECONDS", "60"))
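
(Editor's illustration, not part of the diff.) The TTL above is read once at module import, so an override must be in place before the server module is imported:

# Illustration only: lengthen the auth cache to 5 minutes.
import os
os.environ["CUA_AUTH_TTL_SECONDS"] = "300"  # set before importing the server module
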
try:
from agent import ComputerAgent

HAS_AGENT = True
except ImportError:
HAS_AGENT = False
@@ -54,16 +64,20 @@ app.add_middleware(
protocol_version = 1
try:
from importlib.metadata import version

package_version = version("cua-computer-server")
except Exception:
# Fallback for cases where package is not installed or importlib.metadata is not available
try:
import pkg_resources

package_version = pkg_resources.get_distribution("cua-computer-server").version
except Exception:
package_version = "unknown"

accessibility_handler, automation_handler, diorama_handler, file_handler = HandlerFactory.create_handlers()
accessibility_handler, automation_handler, diorama_handler, file_handler = (
HandlerFactory.create_handlers()
)
handlers = {
"version": lambda: {"protocol": protocol_version, "package": package_version},
# App-Use commands
@@ -118,87 +132,91 @@ class AuthenticationManager:
def __init__(self):
self.sessions: Dict[str, Dict[str, Any]] = {}
self.container_name = os.environ.get("CONTAINER_NAME")

def _hash_credentials(self, container_name: str, api_key: str) -> str:
"""Create a hash of container name and API key for session identification"""
combined = f"{container_name}:{api_key}"
return hashlib.sha256(combined.encode()).hexdigest()

def _is_session_valid(self, session_data: Dict[str, Any]) -> bool:
"""Check if a session is still valid based on expiration time"""
if not session_data.get('valid', False):
if not session_data.get("valid", False):
return False

expires_at = session_data.get('expires_at', 0)

expires_at = session_data.get("expires_at", 0)
return time.time() < expires_at

async def auth(self, container_name: str, api_key: str) -> bool:
"""Authenticate container name and API key, using cached sessions when possible"""
# If no CONTAINER_NAME is set, always allow access (local development)
if not self.container_name:
logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)")
logger.info(
"No CONTAINER_NAME set in environment. Allowing access (local development mode)"
)
return True

# Layer 1: VM Identity Verification
if container_name != self.container_name:
logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}")
logger.warning(
f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}"
)
return False

# Create hash for session lookup
session_hash = self._hash_credentials(container_name, api_key)

# Check if we have a valid cached session
if session_hash in self.sessions:
session_data = self.sessions[session_hash]
if self._is_session_valid(session_data):
logger.info(f"Using cached authentication for container: {container_name}")
return session_data['valid']
return session_data["valid"]
else:
# Remove expired session
del self.sessions[session_hash]

# No valid cached session, authenticate with API
logger.info(f"Authenticating with TryCUA API for container: {container_name}")

try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {api_key}"
}

headers = {"Authorization": f"Bearer {api_key}"}

async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
f"https://www.cua.ai/api/vm/auth?container_name={container_name}",
headers=headers,
) as resp:
is_valid = resp.status == 200 and bool((await resp.text()).strip())

# Cache the result with configurable expiration
self.sessions[session_hash] = {
'valid': is_valid,
'expires_at': time.time() + AUTH_SESSION_TTL_SECONDS
"valid": is_valid,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}

if is_valid:
logger.info(f"Authentication successful for container: {container_name}")
else:
logger.warning(f"Authentication failed for container: {container_name}. Status: {resp.status}")

logger.warning(
f"Authentication failed for container: {container_name}. Status: {resp.status}"
)

return is_valid

except aiohttp.ClientError as e:
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
'valid': False,
'expires_at': time.time() + AUTH_SESSION_TTL_SECONDS
"valid": False,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}
return False
except Exception as e:
logger.error(f"Unexpected error during authentication: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
'valid': False,
'expires_at': time.time() + AUTH_SESSION_TTL_SECONDS
"valid": False,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}
return False
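
(Editor's illustration, not part of the diff.) The session cache above keys entries by a SHA-256 of "container:key" and treats them as live until `expires_at`; a minimal sketch of that check with made-up credentials:

# Illustration only: the hash-plus-TTL check used by AuthenticationManager.
import hashlib
import time

session_hash = hashlib.sha256(b"my-container:my-api-key").hexdigest()
session = {"valid": True, "expires_at": time.time() + 60}
still_valid = session["valid"] and time.time() < session["expires_at"]
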
@@ -218,6 +236,7 @@ class ConnectionManager:
manager = ConnectionManager()
auth_manager = AuthenticationManager()

@app.get("/status")
async def status():
sys = platform.system().lower()
@@ -234,80 +253,67 @@ async def status():
features.append("agent")
return {"status": "ok", "os_type": os_type, "features": features}

@app.websocket("/ws", name="websocket_endpoint")
async def websocket_endpoint(websocket: WebSocket):
global handlers

# WebSocket message size is configured at the app or endpoint level, not on the instance
await manager.connect(websocket)

# Check if CONTAINER_NAME is set (indicating cloud provider)
server_container_name = os.environ.get("CONTAINER_NAME")

# If cloud provider, perform authentication handshake
if server_container_name:
try:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...")

logger.info(
f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication..."
)

# Wait for authentication message
auth_data = await websocket.receive_json()

# Validate auth message format
if auth_data.get("command") != "authenticate":
await websocket.send_json({
"success": False,
"error": "First message must be authentication"
})
await websocket.send_json(
{"success": False, "error": "First message must be authentication"}
)
await websocket.close()
manager.disconnect(websocket)
return

# Extract credentials
client_api_key = auth_data.get("params", {}).get("api_key")
client_container_name = auth_data.get("params", {}).get("container_name")

# Validate credentials using AuthenticationManager
if not client_api_key:
await websocket.send_json({
"success": False,
"error": "API key required"
})
await websocket.send_json({"success": False, "error": "API key required"})
await websocket.close()
manager.disconnect(websocket)
return

if not client_container_name:
await websocket.send_json({
"success": False,
"error": "Container name required"
})
await websocket.send_json({"success": False, "error": "Container name required"})
await websocket.close()
manager.disconnect(websocket)
return

# Use AuthenticationManager for validation
is_authenticated = await auth_manager.auth(client_container_name, client_api_key)
if not is_authenticated:
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.send_json({"success": False, "error": "Authentication failed"})
await websocket.close()
manager.disconnect(websocket)
return

logger.info(f"Authentication successful for VM: {client_container_name}")
await websocket.send_json({
"success": True,
"message": "Authentication successful"
})

await websocket.send_json({"success": True, "message": "Authentication successful"})

except Exception as e:
logger.error(f"Error during authentication handshake: {str(e)}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.send_json({"success": False, "error": "Authentication failed"})
await websocket.close()
manager.disconnect(websocket)
return
@@ -330,7 +336,7 @@ async def websocket_endpoint(websocket: WebSocket):
handler_func = handlers[command]
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}

# Handle both sync and async functions
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(**filtered_params)
@@ -367,20 +373,21 @@ async def websocket_endpoint(websocket: WebSocket):
pass
manager.disconnect(websocket)

@app.post("/cmd")
async def cmd_endpoint(
request: Request,
container_name: Optional[str] = Header(None, alias="X-Container-Name"),
api_key: Optional[str] = Header(None, alias="X-API-Key")
api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
"""
Backup endpoint for when WebSocket connections fail.
Accepts commands via HTTP POST with streaming response.

Headers:
- X-Container-Name: Container name for cloud authentication
- X-API-Key: API key for cloud authentication

Body:
{
"command": "command_name",
@@ -388,7 +395,7 @@ async def cmd_endpoint(
}
"""
global handlers

# Parse request body
try:
body = await request.json()
@@ -396,32 +403,34 @@ async def cmd_endpoint(
params = body.get("params", {})
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")

if not command:
raise HTTPException(status_code=400, detail="Command is required")

# Check if CONTAINER_NAME is set (indicating cloud provider)
server_container_name = os.environ.get("CONTAINER_NAME")

# If cloud provider, perform authentication
if server_container_name:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")

logger.info(
f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication..."
)

# Validate required headers
if not container_name:
raise HTTPException(status_code=401, detail="Container name required")

if not api_key:
raise HTTPException(status_code=401, detail="API key required")

# Validate with AuthenticationManager
is_authenticated = await auth_manager.auth(container_name, api_key)
if not is_authenticated:
raise HTTPException(status_code=401, detail="Authentication failed")

if command not in handlers:
raise HTTPException(status_code=400, detail=f"Unknown command: {command}")

async def generate_response():
"""Generate streaming response for the command execution"""
try:
@@ -429,35 +438,36 @@ async def cmd_endpoint(
handler_func = handlers[command]
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}

# Handle both sync and async functions
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(**filtered_params)
else:
# Run sync functions in thread pool to avoid blocking event loop
result = await asyncio.to_thread(handler_func, **filtered_params)

# Stream the successful result
response_data = {"success": True, **result}
yield f"data: {json.dumps(response_data)}\n\n"

except Exception as cmd_error:
logger.error(f"Error executing command {command}: {str(cmd_error)}")
logger.error(traceback.format_exc())

# Stream the error result
error_data = {"success": False, "error": str(cmd_error)}
yield f"data: {json.dumps(error_data)}\n\n"

return StreamingResponse(
generate_response(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
}
},
)
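
(Editor's illustration, not part of the diff.) Exercising the `/cmd` endpoint above with aiohttp; host, container name, and API key are placeholders:

# Illustration only: the endpoint streams one "data: {...}" line per result.
import asyncio
import aiohttp

async def demo():
    headers = {"X-Container-Name": "my-container", "X-API-Key": "sk-..."}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://localhost:8000/cmd",
            json={"command": "get_screen_size", "params": {}},
            headers=headers,
        ) as resp:
            async for line in resp.content:
                print(line.decode(), end="")

asyncio.run(demo())
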
@app.post("/responses")
|
||||
async def agent_response_endpoint(
|
||||
request: Request,
|
||||
@@ -480,11 +490,17 @@ async def agent_response_endpoint(
|
||||
"""
|
||||
if not HAS_AGENT:
|
||||
raise HTTPException(status_code=501, detail="ComputerAgent not available")
|
||||
|
||||
|
||||
# Authenticate via AuthenticationManager if running in cloud (CONTAINER_NAME set)
|
||||
container_name = os.environ.get("CONTAINER_NAME")
|
||||
if container_name:
|
||||
is_public = os.environ.get("CUA_ENABLE_PUBLIC_PROXY", "").lower().strip() in ["1", "true", "yes", "y", "on"]
|
||||
is_public = os.environ.get("CUA_ENABLE_PUBLIC_PROXY", "").lower().strip() in [
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"y",
|
||||
"on",
|
||||
]
|
||||
if not is_public:
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=401, detail="Missing AGENT PROXY auth headers")
|
||||
@@ -511,10 +527,12 @@ async def agent_response_endpoint(
|
||||
def __init__(self, overrides: Dict[str, str]):
|
||||
self.overrides = overrides
|
||||
self._original: Dict[str, Optional[str]] = {}
|
||||
|
||||
def __enter__(self):
|
||||
for k, v in (self.overrides or {}).items():
|
||||
self._original[k] = os.environ.get(k)
|
||||
os.environ[k] = str(v)
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
for k, old in self._original.items():
|
||||
if old is None:
|
||||
@@ -598,9 +616,9 @@ async def agent_response_endpoint(
|
||||
start = path[0]
|
||||
await self._auto.mouse_down(start["x"], start["y"])
|
||||
for pt in path[1:]:
|
||||
await self._auto.move_cursor(pt["x"], pt["y"])
|
||||
await self._auto.move_cursor(pt["x"], pt["y"])
|
||||
end = path[-1]
|
||||
await self._auto.mouse_up(end["x"], end["y"])
|
||||
await self._auto.mouse_up(end["x"], end["y"])
|
||||
|
||||
async def get_current_url(self) -> str:
|
||||
# Not available in this server context
|
||||
@@ -667,7 +685,11 @@ async def agent_response_endpoint(
|
||||
async for result in agent.run(messages):
|
||||
total_output += result["output"]
|
||||
# Try to collect usage if present
|
||||
if isinstance(result, dict) and "usage" in result and isinstance(result["usage"], dict):
|
||||
if (
|
||||
isinstance(result, dict)
|
||||
and "usage" in result
|
||||
and isinstance(result["usage"], dict)
|
||||
):
|
||||
# Merge usage counters
|
||||
for k, v in result["usage"].items():
|
||||
if isinstance(v, (int, float)):
|
||||
@@ -686,14 +708,14 @@ async def agent_response_endpoint(
|
||||
logger.error(f"Error running agent: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
error = str(e)
|
||||
|
||||
|
||||
# Build response payload
|
||||
payload = {
|
||||
"model": model,
|
||||
"error": error,
|
||||
"output": total_output,
|
||||
"usage": total_usage,
|
||||
"status": "completed" if not error else "failed"
|
||||
"status": "completed" if not error else "failed",
|
||||
}
|
||||
|
||||
# CORS: allow any origin
|
||||
|
||||
@@ -5,8 +5,9 @@ Provides a clean API for starting and stopping the server.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uvicorn
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .main import app as fastapi_app
|
||||
@@ -32,8 +33,14 @@ class Server:
|
||||
await server.stop() # Stop the server
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000, log_level: str = "info",
|
||||
ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
log_level: str = "info",
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the server.
|
||||
|
||||
@@ -58,12 +65,12 @@ class Server:
|
||||
Start the server synchronously. This will block until the server is stopped.
|
||||
"""
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_level=self.log_level,
|
||||
ssl_keyfile=self.ssl_keyfile,
|
||||
ssl_certfile=self.ssl_certfile
|
||||
ssl_certfile=self.ssl_certfile,
|
||||
)
|
||||
|
||||
async def start_async(self) -> None:
|
||||
@@ -72,12 +79,12 @@ class Server:
|
||||
will run in the background.
|
||||
"""
|
||||
server_config = uvicorn.Config(
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_level=self.log_level,
|
||||
ssl_keyfile=self.ssl_keyfile,
|
||||
ssl_certfile=self.ssl_certfile
|
||||
ssl_certfile=self.ssl_certfile,
|
||||
)
|
||||
|
||||
self._should_exit.clear()
|
||||
|
||||
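
(Editor's illustration, not part of the diff.) Using the `Server` class above from an existing event loop; the import path is assumed from the file layout, and `stop()` is taken from the docstring:

# Illustration only: start_async() returns once the server is running.
import asyncio
from computer_server.server import Server  # module path assumed

async def demo():
    server = Server(host="0.0.0.0", port=8000)
    await server.start_async()
    # ... interact with the server here ...
    await server.stop()

asyncio.run(demo())
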
@@ -12,9 +12,10 @@ import platform
import subprocess
import sys
import time
import websockets
from typing import Optional

import websockets

logger = logging.getLogger(__name__)

@@ -45,62 +46,62 @@ class Watchdog:
"""Watchdog class to monitor server health via WebSocket connection.
Unix/Linux only - provides restart capabilities.
"""

def __init__(self, cli_args: Optional[dict] = None, ping_interval: int = 30):
"""
Initialize the watchdog.

Args:
cli_args: Dictionary of CLI arguments to replicate when restarting
ping_interval: Interval between ping checks in seconds
"""
# Check if running on Unix/Linux
if platform.system() not in ['Linux', 'Darwin']:
if platform.system() not in ["Linux", "Darwin"]:
raise RuntimeError("Watchdog is only supported on Unix/Linux systems")

# Store CLI arguments for restart
self.cli_args = cli_args or {}
self.host = self.cli_args.get('host', 'localhost')
self.port = self.cli_args.get('port', 8000)
self.host = self.cli_args.get("host", "localhost")
self.port = self.cli_args.get("port", 8000)
self.ping_interval = ping_interval
self.container_name = os.environ.get("CONTAINER_NAME")
self.running = False
self.restart_enabled = True

@property
def ws_uri(self) -> str:
"""Get the WebSocket URI using the current IP address.

Returns:
WebSocket URI for the Computer API Server
"""
ip_address = "localhost" if not self.container_name else f"{self.container_name}.containers.cloud.trycua.com"
ip_address = (
"localhost"
if not self.container_name
else f"{self.container_name}.containers.cloud.trycua.com"
)
protocol = "wss" if self.container_name else "ws"
port = "8443" if self.container_name else "8000"
return f"{protocol}://{ip_address}:{port}/ws"
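
(Editor's illustration, not part of the diff.) `ws_uri` above resolves to exactly one of two shapes:

# Illustration only; the container name is made up.
container_name = "my-box"  # or None in local mode
uri = (
    f"wss://{container_name}.containers.cloud.trycua.com:8443/ws"
    if container_name
    else "ws://localhost:8000/ws"
)
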
async def ping(self) -> bool:
"""
Test connection to the WebSocket endpoint.

Returns:
True if connection successful, False otherwise
"""
try:
# Create a simple ping message
ping_message = {
"command": "get_screen_size",
"params": {}
}

ping_message = {"command": "get_screen_size", "params": {}}

# Try to connect to the WebSocket
async with websockets.connect(
self.ws_uri,
max_size=1024 * 1024 * 10 # 10MB limit to match server
self.ws_uri, max_size=1024 * 1024 * 10 # 10MB limit to match server
) as websocket:
# Send ping message
await websocket.send(json.dumps(ping_message))

# Wait for any response or just close
try:
response = await asyncio.wait_for(websocket.recv(), timeout=5)
@@ -111,30 +112,27 @@ class Watchdog:
except Exception as e:
logger.warning(f"Ping failed: {e}")
return False

def kill_processes_on_port(self, port: int) -> bool:
"""
Kill any processes using the specified port.

Args:
port: Port number to check and kill processes on

Returns:
True if processes were killed or none found, False on error
"""
try:
# Find processes using the port
result = subprocess.run(
["lsof", "-ti", f":{port}"],
capture_output=True,
text=True,
timeout=10
["lsof", "-ti", f":{port}"], capture_output=True, text=True, timeout=10
)

if result.returncode == 0 and result.stdout.strip():
pids = result.stdout.strip().split('\n')
pids = result.stdout.strip().split("\n")
logger.info(f"Found {len(pids)} processes using port {port}: {pids}")

# Kill each process
for pid in pids:
if pid.strip():
@@ -145,42 +143,42 @@ class Watchdog:
logger.warning(f"Timeout killing process {pid}")
except Exception as e:
logger.warning(f"Error killing process {pid}: {e}")

return True
else:
logger.debug(f"No processes found using port {port}")
return True

except subprocess.TimeoutExpired:
logger.error(f"Timeout finding processes on port {port}")
return False
except Exception as e:
logger.error(f"Error finding processes on port {port}: {e}")
return False

def restart_server(self) -> bool:
"""
Attempt to restart the server by killing existing processes and starting new one.

Returns:
True if restart was attempted, False on error
"""
if not self.restart_enabled:
logger.info("Server restart is disabled")
return False

try:
logger.info("Attempting to restart server...")

# Kill processes on the port
port_to_kill = 8443 if self.container_name else self.port
if not self.kill_processes_on_port(port_to_kill):
logger.error("Failed to kill processes on port, restart aborted")
return False

# Wait a moment for processes to die
time.sleep(2)

# Try to restart the server
# In container mode, we can't easily restart, so just log
if self.container_name:
@@ -190,50 +188,50 @@ class Watchdog:
else:
# For local mode, try to restart the CLI
logger.info("Attempting to restart local server...")

# Get the current Python executable and script
python_exe = sys.executable

# Try to find the CLI module
try:
# Build command with all original CLI arguments
cmd = [python_exe, "-m", "computer_server.cli"]

# Add all CLI arguments except watchdog-related ones
for key, value in self.cli_args.items():
if key in ['watchdog', 'watchdog_interval', 'no_restart']:
if key in ["watchdog", "watchdog_interval", "no_restart"]:
continue # Skip watchdog args to avoid recursive watchdog

# Convert underscores to hyphens for CLI args
arg_name = f"--{key.replace('_', '-')}"

if isinstance(value, bool):
if value: # Only add flag if True
cmd.append(arg_name)
else:
cmd.extend([arg_name, str(value)])

logger.info(f"Starting server with command: {' '.join(cmd)}")

# Start process in background
subprocess.Popen(
cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True
start_new_session=True,
)

logger.info("Server restart initiated")
return True

except Exception as e:
logger.error(f"Failed to restart server: {e}")
return False

except Exception as e:
logger.error(f"Error during server restart: {e}")
return False

async def start_monitoring(self) -> None:
"""Start the watchdog monitoring loop."""
self.running = True
@@ -241,14 +239,14 @@ class Watchdog:
logger.info(f"Ping interval: {self.ping_interval} seconds")
if self.container_name:
logger.info(f"Container mode detected: {self.container_name}")

consecutive_failures = 0
max_failures = 3

while self.running:
try:
success = await self.ping()

if success:
if consecutive_failures > 0:
logger.info("Server connection restored")
@@ -257,15 +255,17 @@ class Watchdog:
else:
consecutive_failures += 1
logger.warning(f"Ping failed ({consecutive_failures}/{max_failures})")

if consecutive_failures >= max_failures:
logger.error(f"Server appears to be down after {max_failures} consecutive failures")

logger.error(
f"Server appears to be down after {max_failures} consecutive failures"
)

# Attempt to restart the server
if self.restart_enabled:
logger.info("Attempting automatic server restart...")
restart_success = self.restart_server()

if restart_success:
logger.info("Server restart initiated, waiting before next ping...")
# Wait longer after restart attempt
@@ -275,17 +275,17 @@ class Watchdog:
logger.error("Server restart failed")
else:
logger.warning("Automatic restart is disabled")

# Wait for next ping interval
await asyncio.sleep(self.ping_interval)

except asyncio.CancelledError:
logger.info("Watchdog monitoring cancelled")
break
except Exception as e:
logger.error(f"Unexpected error in watchdog loop: {e}")
await asyncio.sleep(self.ping_interval)

def stop_monitoring(self) -> None:
"""Stop the watchdog monitoring."""
self.running = False
@@ -295,13 +295,13 @@ class Watchdog:
async def run_watchdog(cli_args: Optional[dict] = None, ping_interval: int = 30) -> None:
"""
Run the watchdog monitoring.

Args:
cli_args: Dictionary of CLI arguments to replicate when restarting
ping_interval: Interval between ping checks in seconds
"""
watchdog = Watchdog(cli_args=cli_args, ping_interval=ping_interval)

try:
await watchdog.start_monitoring()
except KeyboardInterrupt:
@@ -313,21 +313,18 @@ async def run_watchdog(cli_args: Optional[dict] = None, ping_interval: int = 30)
if __name__ == "__main__":
# For testing the watchdog standalone
import argparse

parser = argparse.ArgumentParser(description="Run Computer API server watchdog")
parser.add_argument("--host", default="localhost", help="Server host to monitor")
parser.add_argument("--port", type=int, default=8000, help="Server port to monitor")
parser.add_argument("--ping-interval", type=int, default=30, help="Ping interval in seconds")

args = parser.parse_args()

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

cli_args = {
'host': args.host,
'port': args.port
}

cli_args = {"host": args.host, "port": args.port}
asyncio.run(run_watchdog(cli_args, args.ping_interval))
@@ -4,14 +4,15 @@ build-backend = "pdm.backend"

[project]
name = "cua-computer-server"
version = "0.1.0"
version = "0.1.27"

description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
authors = [
{ name = "TryCua", email = "gh@trycua.com" }
]
readme = "README.md"
license = { text = "MIT" }
requires-python = ">=3.9"
requires-python = ">=3.12"
dependencies = [
"fastapi>=0.111.0",
"uvicorn[standard]>=0.27.0",
@@ -21,7 +22,14 @@ dependencies = [
"pillow>=10.2.0",
"aiohttp>=3.9.1",
"pyperclip>=1.9.0",
"websockets>=12.0"
"websockets>=12.0",
# OS-specific runtime deps
"pyobjc-framework-Cocoa>=10.1; sys_platform == 'darwin'",
"pyobjc-framework-Quartz>=10.1; sys_platform == 'darwin'",
"pyobjc-framework-ApplicationServices>=10.1; sys_platform == 'darwin'",
"python-xlib>=0.33; sys_platform == 'linux'",
"pywin32>=310; sys_platform == 'win32'",
"pip-system-certs; sys_platform == 'win32'",
]

[project.optional-dependencies]
@@ -66,23 +74,4 @@ dev = [
]

[tool.pdm.scripts]
api = "python -m computer_server"

[tool.ruff]
line-length = 100
target-version = "py310"
select = ["E", "F", "B", "I"]
fix = true

[tool.ruff.format]
docstring-code-format = true

[tool.mypy]
strict = true
python_version = "3.10"
ignore_missing_imports = true
disallow_untyped_defs = true
check_untyped_defs = true
warn_return_any = true
show_error_codes = true
warn_unused_ignores = false
api = "python -m computer_server"
@@ -10,6 +10,7 @@ Usage:
"""

import sys

from computer_server.cli import main

if __name__ == "__main__":

@@ -6,18 +6,22 @@ This script tests both WebSocket (/ws) and REST (/cmd) connections to the Comput
and keeps it alive, allowing you to verify the server is running correctly.
"""

import argparse
import asyncio
import json
import websockets
import argparse
import sys
import aiohttp
import os
import sys

import aiohttp
import dotenv
import websockets

dotenv.load_dotenv()

async def test_websocket_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):

async def test_websocket_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None
):
"""Test WebSocket connection to the Computer Server."""
if container_name:
# Container mode: use WSS with container domain and port 8443
@@ -37,19 +41,16 @@ async def test_websocket_connection(host="localhost", port=8000, keep_alive=Fals
if not api_key:
print("Error: API key required for container connections")
return False

print("Sending authentication...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": api_key,
"container_name": container_name
}
"params": {"api_key": api_key, "container_name": container_name},
}
await websocket.send(json.dumps(auth_message))
auth_response = await websocket.recv()
print(f"Authentication response: {auth_response}")

# Check if authentication was successful
auth_data = json.loads(auth_response)
if not auth_data.get("success", False):
@@ -90,7 +91,9 @@ async def test_websocket_connection(host="localhost", port=8000, keep_alive=Fals
return True

async def test_rest_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
async def test_rest_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None
):
"""Test REST connection to the Computer Server."""
if container_name:
# Container mode: use HTTPS with container domain and port 8443
@@ -113,13 +116,11 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
return False
headers["X-Container-Name"] = container_name
headers["X-API-Key"] = api_key
print(f"Using container authentication headers")
print("Using container authentication headers")

# Test screenshot endpoint
async with session.post(
f"{base_url}/cmd",
json={"command": "screenshot", "params": {}},
headers=headers
f"{base_url}/cmd", json={"command": "screenshot", "params": {}}, headers=headers
) as response:
if response.status == 200:
text = await response.text()
@@ -133,7 +134,7 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
async with session.post(
f"{base_url}/cmd",
json={"command": "get_screen_size", "params": {}},
headers=headers
headers=headers,
) as response:
if response.status == 200:
text = await response.text()
@@ -151,7 +152,7 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
async with session.post(
f"{base_url}/cmd",
json={"command": "get_cursor_position", "params": {}},
headers=headers
headers=headers,
) as response:
if response.status == 200:
text = await response.text()
@@ -171,7 +172,9 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
return True

async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None):
async def test_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None
):
"""Test connection to the Computer Server using WebSocket or REST."""
if use_rest:
return await test_rest_connection(host, port, keep_alive, container_name, api_key)
@@ -183,40 +186,50 @@ def parse_args():
parser = argparse.ArgumentParser(description="Test connection to Computer Server")
parser.add_argument("--host", default="localhost", help="Host address (default: localhost)")
parser.add_argument("-p", "--port", type=int, default=8000, help="Port number (default: 8000)")
parser.add_argument("-c", "--container-name", help="Container name for cloud connection (uses WSS/HTTPS and port 8443)")
parser.add_argument("--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)")
parser.add_argument(
"-c",
"--container-name",
help="Container name for cloud connection (uses WSS/HTTPS and port 8443)",
)
parser.add_argument(
"--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)"
)
parser.add_argument("--keep-alive", action="store_true", help="Keep connection alive")
parser.add_argument("--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)")
parser.add_argument(
"--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)"
)
return parser.parse_args()

async def main():
args = parse_args()

# Convert hyphenated argument to underscore for function parameter
container_name = getattr(args, 'container_name', None)

container_name = getattr(args, "container_name", None)

# Get API key from argument or environment variable
api_key = getattr(args, 'api_key', None) or os.environ.get('CUA_API_KEY')

api_key = getattr(args, "api_key", None) or os.environ.get("CUA_API_KEY")

# Check if container name is provided but API key is missing
if container_name and not api_key:
print("Warning: Container name provided but no API key found.")
print("Please provide --api-key argument or set CUA_API_KEY environment variable.")
return 1

print(f"Testing {'REST' if args.rest else 'WebSocket'} connection...")
if container_name:
print(f"Container: {container_name}")
print(f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}")

print(
f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}"
)

success = await test_connection(
host=args.host,
port=args.port,
host=args.host,
port=args.port,
keep_alive=args.keep_alive,
container_name=container_name,
use_rest=args.rest,
api_key=api_key
api_key=api_key,
)
return 0 if success else 1
10
libs/python/computer/.bumpversion.cfg
Normal file
10
libs/python/computer/.bumpversion.cfg
Normal file
@@ -0,0 +1,10 @@
[bumpversion]
current_version = 0.4.7
commit = True
tag = True
tag_name = computer-v{new_version}
message = Bump cua-computer to v{new_version}

[bumpversion:file:pyproject.toml]
search = version = "{current_version}"
replace = version = "{new_version}"
@@ -8,10 +8,11 @@
</picture>
</div>

[](#)
[](#)
[](https://discord.com/invite/mVnXXpdE85)
[](https://pypi.org/project/cua-computer/)
[](#)
[](#)
[](https://discord.com/invite/mVnXXpdE85)
[](https://pypi.org/project/cua-computer/)

</h1>
</div>

@@ -29,11 +30,11 @@ from computer import Computer
computer = Computer(os_type="macos", display="1024x768", memory="8GB", cpu="4")
try:
await computer.run()

screenshot = await computer.interface.screenshot()
with open("screenshot.png", "wb") as f:
f.write(screenshot)

await computer.interface.move_cursor(100, 100)
await computer.interface.left_click()
await computer.interface.right_click(300, 300)

@@ -1,19 +1,22 @@
from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast
import asyncio
from .models import Computer as ComputerConfig, Display
from .interface.factory import InterfaceFactory
import time
from PIL import Image
import io
import re
from .logger import Logger, LogLevel
import json
import logging
from core.telemetry import is_telemetry_enabled, record_event
import os
from . import helpers

import platform
import re
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast

from core.telemetry import is_telemetry_enabled, record_event
from PIL import Image

from . import helpers
from .interface.factory import InterfaceFactory
from .logger import Logger, LogLevel
from .models import Computer as ComputerConfig
from .models import Display

SYSTEM_INFO = {
"os": platform.system().lower(),
@@ -27,6 +30,7 @@ from .providers.factory import VMProviderFactory

OSType = Literal["macos", "linux", "windows"]

class Computer:
"""Computer is the main class for interacting with the computer."""

@@ -40,8 +44,11 @@ class Computer:
Returns:
DioramaComputer: A proxy object with the Diorama interface, but using diorama_cmds.
"""
assert "app-use" in self.experiments, "App Usage is an experimental feature. Enable it by passing experiments=['app-use'] to Computer()"
assert (
"app-use" in self.experiments
), "App Usage is an experimental feature. Enable it by passing experiments=['app-use'] to Computer()"
from .diorama_computer import DioramaComputer

return DioramaComputer(self, apps)

def __init__(
@@ -63,7 +70,7 @@ class Computer:
storage: Optional[str] = None,
ephemeral: bool = False,
api_key: Optional[str] = None,
experiments: Optional[List[str]] = None
experiments: Optional[List[str]] = None,
):
"""Initialize a new Computer instance.

@@ -111,32 +118,36 @@ class Computer:
self.os_type = os_type
self.provider_type = provider_type
self.ephemeral = ephemeral

self.api_key = api_key
self.experiments = experiments or []

if "app-use" in self.experiments:
assert self.os_type == "macos", "App use experiment is only supported on macOS"

# The default is currently to use non-ephemeral storage
if storage and ephemeral and storage != "ephemeral":
raise ValueError("Storage path and ephemeral flag cannot be used together")

# Windows Sandbox always uses ephemeral storage
if self.provider_type == VMProviderType.WINSANDBOX:
if not ephemeral and storage != None and storage != "ephemeral":
self.logger.warning("Windows Sandbox storage is always ephemeral. Setting ephemeral=True.")
self.logger.warning(
"Windows Sandbox storage is always ephemeral. Setting ephemeral=True."
)
self.ephemeral = True
self.storage = "ephemeral"
else:
self.storage = "ephemeral" if ephemeral else storage

# For Lumier provider, store the first shared directory path to use
# for VM file sharing
self.shared_path = None
if shared_directories and len(shared_directories) > 0:
self.shared_path = shared_directories[0]
self.logger.info(f"Using first shared directory for VM file sharing: {self.shared_path}")
self.logger.info(
f"Using first shared directory for VM file sharing: {self.shared_path}"
)

# Store telemetry preference
self._telemetry_enabled = telemetry_enabled
@@ -154,8 +165,8 @@ class Computer:
self.interface_logger = Logger("computer.interface", verbosity)

if not use_host_computer_server:
if ":" not in image or len(image.split(":")) != 2:
raise ValueError("Image must be in the format <image_name>:<tag>")
if ":" not in image:
image = f"{image}:latest"
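
(Editor's illustration, not part of the diff.) The change above relaxes the old hard error on untagged images into a `:latest` default:

# Illustration only; the image name is made up.
image = "macos-sequoia-cua"
if ":" not in image:
    image = f"{image}:latest"  # bare names now default to :latest
assert image == "macos-sequoia-cua:latest"
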
if not name:
# Normalize the name to be used for the VM
@@ -263,8 +274,14 @@ class Computer:
self.logger.info(f"Starting VM: {self.image}")
if not self._provider_context:
try:
provider_type_name = self.provider_type.name if isinstance(self.provider_type, VMProviderType) else self.provider_type
self.logger.verbose(f"Initializing {provider_type_name} provider context...")
provider_type_name = (
self.provider_type.name
if isinstance(self.provider_type, VMProviderType)
else self.provider_type
)
self.logger.verbose(
f"Initializing {provider_type_name} provider context..."
)

# Explicitly set provider parameters
storage = "ephemeral" if self.ephemeral else self.storage
@@ -281,9 +298,13 @@ class Computer:
if self.provider_type == VMProviderType.LUMIER:
self.logger.info(f"Using VM image for Lumier provider: {image}")
if shared_path:
self.logger.info(f"Using shared path for Lumier provider: {shared_path}")
self.logger.info(
f"Using shared path for Lumier provider: {shared_path}"
)
if noVNC_port:
self.logger.info(f"Using noVNC port for Lumier provider: {noVNC_port}")
self.logger.info(
f"Using noVNC port for Lumier provider: {noVNC_port}"
)
self.config.vm_provider = VMProviderFactory.create_provider(
self.provider_type,
port=port,
@@ -339,11 +360,17 @@ class Computer:
except ImportError as ie:
self.logger.error(f"Failed to import provider dependencies: {ie}")
if str(ie).find("lume") >= 0 and str(ie).find("lumier") < 0:
self.logger.error("Please install with: pip install cua-computer[lume]")
self.logger.error(
"Please install with: pip install cua-computer[lume]"
)
elif str(ie).find("lumier") >= 0 or str(ie).find("docker") >= 0:
self.logger.error("Please install with: pip install cua-computer[lumier] and make sure Docker is installed")
self.logger.error(
"Please install with: pip install cua-computer[lumier] and make sure Docker is installed"
)
elif str(ie).find("cloud") >= 0:
self.logger.error("Please install with: pip install cua-computer[cloud]")
self.logger.error(
"Please install with: pip install cua-computer[cloud]"
)
raise
except Exception as e:
self.logger.error(f"Failed to initialize provider context: {e}")
@@ -354,16 +381,14 @@ class Computer:
try:
if self.config.vm_provider is None:
raise RuntimeError(f"VM provider not initialized for {self.config.name}")

vm = await self.config.vm_provider.get_vm(self.config.name)
self.logger.verbose(f"Found existing VM: {self.config.name}")
is_running = vm.get("status") == "running"
except Exception as e:
self.logger.error(f"VM not found: {self.config.name}")
self.logger.error(f"Error: {e}")
raise RuntimeError(
f"VM {self.config.name} could not be found or created."
)
raise RuntimeError(f"VM {self.config.name} could not be found or created.")

# Start the VM if it's not running
if not is_running:
@@ -376,13 +401,10 @@ class Computer:
path = os.path.abspath(os.path.expanduser(path))
if os.path.exists(path):
# Add path in format expected by Lume API
shared_dirs.append({
"hostPath": path,
"readOnly": False
})
shared_dirs.append({"hostPath": path, "readOnly": False})
else:
self.logger.warning(f"Shared directory does not exist: {path}")

# Prepare run options to pass to the provider
run_opts = {}

@@ -392,11 +414,11 @@ class Computer:
"width": self.config.display.width,
"height": self.config.display.height,
}

# Check if scale_factor exists before adding it
if hasattr(self.config.display, "scale_factor"):
display_info["scale_factor"] = self.config.display.scale_factor

run_opts["display"] = display_info

# Add shared directories if available
@@ -406,21 +428,23 @@ class Computer:
# Run the VM with the provider
try:
if self.config.vm_provider is None:
raise RuntimeError(f"VM provider not initialized for {self.config.name}")

raise RuntimeError(
f"VM provider not initialized for {self.config.name}"
)

# Use the complete run_opts we prepared earlier
# Handle ephemeral storage for run_vm method too
|
||||
storage_param = "ephemeral" if self.ephemeral else self.storage
|
||||
|
||||
|
||||
# Log the image being used
|
||||
self.logger.info(f"Running VM using image: {self.image}")
|
||||
|
||||
|
||||
# Call provider.run_vm with explicit image parameter
|
||||
response = await self.config.vm_provider.run_vm(
|
||||
image=self.image,
|
||||
name=self.config.name,
|
||||
run_opts=run_opts,
|
||||
storage=storage_param
|
||||
storage=storage_param,
|
||||
)
|
||||
self.logger.info(f"VM run response: {response if response else 'None'}")
|
||||
except Exception as run_error:
|
||||
@@ -432,14 +456,16 @@ class Computer:
|
||||
try:
|
||||
if self.provider_type == VMProviderType.LUMIER:
|
||||
max_retries = 60 # Increased for Lumier VM startup which takes longer
|
||||
retry_delay = 3 # 3 seconds between retries for Lumier
|
||||
retry_delay = 3 # 3 seconds between retries for Lumier
|
||||
else:
|
||||
max_retries = 30 # Default for other providers
|
||||
retry_delay = 2 # 2 seconds between retries
|
||||
|
||||
self.logger.info(f"Waiting up to {max_retries * retry_delay} seconds for VM to be ready...")
|
||||
retry_delay = 2 # 2 seconds between retries
|
||||
|
||||
self.logger.info(
|
||||
f"Waiting up to {max_retries * retry_delay} seconds for VM to be ready..."
|
||||
)
|
||||
ip = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)
|
||||
|
||||
|
||||
# If we get here, we have a valid IP
|
||||
self.logger.info(f"VM is ready with IP: {ip}")
|
||||
ip_address = ip
|
||||
@@ -451,13 +477,16 @@ class Computer:
|
||||
raise RuntimeError(f"VM failed to become ready: {wait_error}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize computer: {e}")
|
||||
self.logger.error(traceback.format_exc())
|
||||
raise RuntimeError(f"Failed to initialize computer: {e}")
|
||||
|
||||
try:
|
||||
# Verify we have a valid IP before initializing the interface
|
||||
if not ip_address or ip_address == "unknown" or ip_address == "0.0.0.0":
|
||||
raise RuntimeError(f"Cannot initialize interface - invalid IP address: {ip_address}")
|
||||
|
||||
raise RuntimeError(
|
||||
f"Cannot initialize interface - invalid IP address: {ip_address}"
|
||||
)
|
||||
|
||||
# Initialize the interface using the factory with the specified OS
|
||||
self.logger.info(f"Initializing interface for {self.os_type} at {ip_address}")
|
||||
from .interface.base import BaseComputerInterface
|
||||
@@ -467,18 +496,17 @@ class Computer:
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
os=self.os_type,
|
||||
ip_address=ip_address,
|
||||
api_key=self.api_key,
|
||||
vm_name=self.config.name
|
||||
vm_name=self.config.name,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
ip_address=ip_address
|
||||
os=self.os_type, ip_address=ip_address
|
||||
),
|
||||
)
|
||||
|
||||
@@ -508,10 +536,10 @@ class Computer:
|
||||
|
||||
# Set the initialization flag and clear the initializing flag
|
||||
self._initialized = True
|
||||
|
||||
|
||||
# Set this instance as the default computer for remote decorators
|
||||
helpers.set_default_computer(self)
|
||||
|
||||
|
||||
self.logger.info("Computer successfully initialized")
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -520,7 +548,7 @@ class Computer:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self.logger.debug(f"Computer initialization took {duration_ms:.2f}ms")
|
||||
return
|
||||
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the computer's WebSocket interface."""
|
||||
if self._interface:
|
||||
@@ -534,13 +562,17 @@ class Computer:
|
||||
self.logger.info("Stopping Computer...")
|
||||
|
||||
# In VM mode, first explicitly stop the VM, then exit the provider context
|
||||
if not self.use_host_computer_server and self._provider_context and self.config.vm_provider is not None:
|
||||
if (
|
||||
not self.use_host_computer_server
|
||||
and self._provider_context
|
||||
and self.config.vm_provider is not None
|
||||
):
|
||||
try:
|
||||
self.logger.info(f"Stopping VM {self.config.name}...")
|
||||
await self.config.vm_provider.stop_vm(
|
||||
name=self.config.name,
|
||||
storage=self.storage # Pass storage explicitly for clarity
|
||||
)
|
||||
name=self.config.name,
|
||||
storage=self.storage, # Pass storage explicitly for clarity
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error stopping VM: {e}")
|
||||
|
||||
@@ -551,55 +583,156 @@ class Computer:
await self.disconnect()
self.logger.info("Computer stopped")
except Exception as e:
self.logger.debug(f"Error during cleanup: {e}") # Log as debug since this might be expected
self.logger.debug(
f"Error during cleanup: {e}"
) # Log as debug since this might be expected
finally:
# Log stop time for performance monitoring
duration_ms = (time.time() - start_time) * 1000
self.logger.debug(f"Computer stop process took {duration_ms:.2f}ms")
return

async def start(self) -> None:
"""Start the computer."""
await self.run()

async def restart(self) -> None:
"""Restart the computer.

If using a VM provider that supports restart, this will issue a restart
without tearing down the provider context, then reconnect the interface.
Falls back to stop()+run() when a provider restart is not available.
"""
# Host computer server: just disconnect and run again
if self.use_host_computer_server:
try:
await self.disconnect()
finally:
await self.run()
return

# If no VM provider context yet, fall back to full run
if not getattr(self, "_provider_context", None) or self.config.vm_provider is None:
self.logger.info("No provider context active; performing full restart via run()")
await self.run()
return

# Gracefully close current interface connection if present
if self._interface:
try:
self._interface.close()
except Exception as e:
self.logger.debug(f"Error closing interface prior to restart: {e}")

# Attempt provider-level restart if implemented
try:
storage_param = "ephemeral" if self.ephemeral else self.storage
if hasattr(self.config.vm_provider, "restart_vm"):
self.logger.info(f"Restarting VM {self.config.name} via provider...")
await self.config.vm_provider.restart_vm(
name=self.config.name, storage=storage_param
)
else:
# Fallback: stop then start without leaving provider context
self.logger.info(
f"Provider has no restart_vm; performing stop+start for {self.config.name}..."
)
await self.config.vm_provider.stop_vm(name=self.config.name, storage=storage_param)
await self.config.vm_provider.run_vm(
image=self.image, name=self.config.name, run_opts={}, storage=storage_param
)
except Exception as e:
self.logger.error(f"Failed to restart VM via provider: {e}")
# As a last resort, do a full stop (with provider context exit) and run
try:
await self.stop()
finally:
await self.run()
return

# Wait for VM to be ready and reconnect interface
try:
self.logger.info("Waiting for VM to be ready after restart...")
if self.provider_type == VMProviderType.LUMIER:
max_retries = 60
retry_delay = 3
else:
max_retries = 30
retry_delay = 2
ip_address = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)

self.logger.info(f"Re-initializing interface for {self.os_type} at {ip_address}")
from .interface.base import BaseComputerInterface

if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name,
),
)
else:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
),
)

self.logger.info("Connecting to WebSocket interface after restart...")
await self._interface.wait_for_ready(timeout=30)
self.logger.info("Computer reconnected and ready after restart")
except Exception as e:
self.logger.error(f"Failed to reconnect after restart: {e}")
# Try a full reset if reconnection failed
try:
await self.stop()
finally:
await self.run()
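From the caller's side, the new restart path is a single awaitable; the provider-level fallbacks above are transparent. A usage sketch, assuming a Computer instance named computer that has already been run:

# Reboot the VM in place; the provider context stays alive when possible.
await computer.restart()

# restart() only returns once the interface has been re-created and is
# ready, so normal interface calls can resume immediately afterwards.
await computer.interface.left_click(100, 200)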
# @property
async def get_ip(self, max_retries: int = 15, retry_delay: int = 3) -> str:
"""Get the IP address of the VM or localhost if using host computer server.

This method delegates to the provider's get_ip method, which waits indefinitely
until the VM has a valid IP address.

Args:
max_retries: Unused parameter, kept for backward compatibility
retry_delay: Delay between retries in seconds (default: 2)

Returns:
IP address of the VM or localhost if using host computer server
"""
# For host computer server, always return localhost immediately
if self.use_host_computer_server:
return "127.0.0.1"

# Get IP from the provider - each provider implements its own waiting logic
if self.config.vm_provider is None:
raise RuntimeError("VM provider is not initialized")

# Log that we're waiting for the IP
self.logger.info(f"Waiting for VM {self.config.name} to get an IP address...")

# Call the provider's get_ip method which will wait indefinitely
storage_param = "ephemeral" if self.ephemeral else self.storage

# Log the image being used
self.logger.info(f"Running VM using image: {self.image}")

# Call provider.get_ip with explicit image parameter
ip = await self.config.vm_provider.get_ip(
name=self.config.name,
storage=storage_param,
retry_delay=retry_delay
name=self.config.name, storage=storage_param, retry_delay=retry_delay
)

# Log success
self.logger.info(f"VM {self.config.name} has IP address: {ip}")
return ip

async def wait_vm_ready(self) -> Optional[Dict[str, Any]]:
"""Wait for VM to be ready with an IP address.
@@ -687,8 +820,8 @@ class Computer:
if self.config.vm_provider is not None:
vm = await self.config.vm_provider.get_vm(self.config.name)
# VM data is returned as a dictionary from the Lumier provider
status = vm.get('status', 'unknown') if vm else "unknown"
ip = vm.get('ip_address') if vm else None
status = vm.get("status", "unknown") if vm else "unknown"
ip = vm.get("ip_address") if vm else None
else:
status = "unknown"
ip = None
@@ -705,16 +838,13 @@ class Computer:
self.logger.info(
f"Updating VM settings: CPU={cpu or self.config.cpu}, Memory={memory or self.config.memory}"
)
update_opts = {
"cpu": cpu or int(self.config.cpu),
"memory": memory or self.config.memory
}
update_opts = {"cpu": cpu or int(self.config.cpu), "memory": memory or self.config.memory}
if self.config.vm_provider is not None:
await self.config.vm_provider.update_vm(
name=self.config.name,
update_opts=update_opts,
storage=self.storage # Pass storage explicitly for clarity
)
await self.config.vm_provider.update_vm(
name=self.config.name,
update_opts=update_opts,
storage=self.storage,  # Pass storage explicitly for clarity
)
else:
raise RuntimeError("VM provider not initialized")
@@ -781,65 +911,94 @@ class Computer:
"""
return await self.interface.to_screenshot_coordinates(x, y)

# Add virtual environment management functions to computer interface
async def venv_install(self, venv_name: str, requirements: list[str]):
"""Install packages in a virtual environment.

Args:
venv_name: Name of the virtual environment
requirements: List of package requirements to install

Returns:
Tuple of (stdout, stderr) from the installation command
"""
requirements = requirements or []
# Windows vs POSIX handling
if self.os_type == "windows":
# Use %USERPROFILE% for home directory and cmd.exe semantics
venv_path = f"%USERPROFILE%\\.venvs\\{venv_name}"
ensure_dir_cmd = 'if not exist "%USERPROFILE%\\.venvs" mkdir "%USERPROFILE%\\.venvs"'
create_cmd = f'if not exist "{venv_path}" python -m venv "{venv_path}"'
requirements_str = " ".join(requirements)
# Activate via activate.bat and install
install_cmd = (
f'call "{venv_path}\\Scripts\\activate.bat" && pip install {requirements_str}'
if requirements_str
else "echo No requirements to install"
)
await self.interface.run_command(ensure_dir_cmd)
await self.interface.run_command(create_cmd)
return await self.interface.run_command(install_cmd)
else:
# POSIX (macOS/Linux)
venv_path = f"$HOME/.venvs/{venv_name}"
create_cmd = f'mkdir -p "$HOME/.venvs" && python3 -m venv "{venv_path}"'
# Check if venv exists, if not create it
check_cmd = f'test -d "{venv_path}" || ({create_cmd})'
_ = await self.interface.run_command(check_cmd)
# Install packages
requirements_str = " ".join(requirements)
install_cmd = (
f'. "{venv_path}/bin/activate" && pip install {requirements_str}'
if requirements_str
else "echo No requirements to install"
)
return await self.interface.run_command(install_cmd)

# Create virtual environment if it doesn't exist
venv_path = f"~/.venvs/{venv_name}"
create_cmd = f"mkdir -p ~/.venvs && python3 -m venv {venv_path}"

# Check if venv exists, if not create it
check_cmd = f"test -d {venv_path} || ({create_cmd})"
_ = await self.interface.run_command(check_cmd)

# Install packages
requirements_str = " ".join(requirements)
install_cmd = f". {venv_path}/bin/activate && pip install {requirements_str}"
return await self.interface.run_command(install_cmd)

async def venv_cmd(self, venv_name: str, command: str):
"""Execute a shell command in a virtual environment.

Args:
venv_name: Name of the virtual environment
command: Shell command to execute in the virtual environment

Returns:
Tuple of (stdout, stderr) from the command execution
"""
venv_path = f"~/.venvs/{venv_name}"

# Check if virtual environment exists
check_cmd = f"test -d {venv_path}"
result = await self.interface.run_command(check_cmd)

if result.stderr or "test:" in result.stdout: # venv doesn't exist
return "", f"Virtual environment '{venv_name}' does not exist. Create it first using venv_install."

# Activate virtual environment and run command
full_command = f". {venv_path}/bin/activate && {command}"
return await self.interface.run_command(full_command)

if self.os_type == "windows":
# Windows (cmd.exe)
venv_path = f"%USERPROFILE%\\.venvs\\{venv_name}"
# Check existence and signal if missing
check_cmd = f'if not exist "{venv_path}" (echo VENV_NOT_FOUND) else (echo VENV_FOUND)'
result = await self.interface.run_command(check_cmd)
if "VENV_NOT_FOUND" in getattr(result, "stdout", ""):
# Auto-create the venv with no requirements
await self.venv_install(venv_name, [])
# Activate and run the command
full_command = f'call "{venv_path}\\Scripts\\activate.bat" && {command}'
return await self.interface.run_command(full_command)
else:
# POSIX (macOS/Linux)
venv_path = f"$HOME/.venvs/{venv_name}"
# Check if virtual environment exists
check_cmd = f'test -d "{venv_path}"'
result = await self.interface.run_command(check_cmd)
if result.stderr or "test:" in result.stdout: # venv doesn't exist
# Auto-create the venv with no requirements
await self.venv_install(venv_name, [])
# Activate virtual environment and run command
full_command = f'. "{venv_path}/bin/activate" && {command}'
return await self.interface.run_command(full_command)
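A usage sketch of the two helpers above, assuming a connected Computer instance named computer; the venv name and package list are illustrative:

# Create (if needed) and populate a sandboxed environment inside the VM.
await computer.venv_install("scraping", ["requests", "beautifulsoup4"])

# Run a shell command with that environment activated; returns the same
# structured result as interface.run_command.
result = await computer.venv_cmd("scraping", "python -c 'import requests; print(requests.__version__)'")
print(result.stdout)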
async def venv_exec(self, venv_name: str, python_func, *args, **kwargs):
"""Execute Python function in a virtual environment using source code extraction.

Args:
venv_name: Name of the virtual environment
python_func: A callable function to execute
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function

Returns:
The result of the function execution, or raises any exception that occurred
"""
@@ -847,29 +1006,29 @@ class Computer:
import inspect
import json
import textwrap

try:
# Get function source code using inspect.getsource
source = inspect.getsource(python_func)
# Remove common leading whitespace (dedent)
func_source = textwrap.dedent(source).strip()

# Remove decorators
while func_source.lstrip().startswith("@"):
func_source = func_source.split("\n", 1)[1].strip()

# Get function name for execution
func_name = python_func.__name__

# Serialize args and kwargs as JSON (safer than dill for cross-version compatibility)
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)

except OSError as e:
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
except Exception as e:
raise Exception(f"Failed to reconstruct function source: {e}")

# Create Python code that will define and execute the function
python_code = f'''
import json
@@ -914,25 +1073,27 @@ output_json = json.dumps(output_payload, default=str)
# Print the JSON output with markers
print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
'''

# Encode the Python code in base64 to avoid shell escaping issues
encoded_code = base64.b64encode(python_code.encode('utf-8')).decode('ascii')

encoded_code = base64.b64encode(python_code.encode("utf-8")).decode("ascii")

# Execute the Python code in the virtual environment
python_command = f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
python_command = (
    f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
)
result = await self.venv_cmd(venv_name, python_command)

# Parse the output to extract the payload
start_marker = "<<<VENV_EXEC_START>>>"
end_marker = "<<<VENV_EXEC_END>>>"

# Print original stdout
print(result.stdout[:result.stdout.find(start_marker)])

print(result.stdout[: result.stdout.find(start_marker)])

if start_marker in result.stdout and end_marker in result.stdout:
start_idx = result.stdout.find(start_marker) + len(start_marker)
end_idx = result.stdout.find(end_marker)

if start_idx < end_idx:
output_json = result.stdout[start_idx:end_idx]

@@ -941,7 +1102,7 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
output_payload = json.loads(output_json)
except Exception as e:
raise Exception(f"Failed to decode output payload: {e}")

if output_payload["success"]:
return output_payload["result"]
else:
@@ -953,4 +1114,6 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
raise Exception("Invalid output format: markers found but no content between them")
else:
# Fallback: return stdout/stderr if no payload markers found
raise Exception(f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}")
raise Exception(
    f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}"
)
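Because venv_exec ships the function's source (not a pickle), the callable must be defined in a file where inspect.getsource can see it, and arguments must round-trip through JSON. A minimal sketch of the resulting call pattern:

def add(a, b):
    # This body is re-defined and executed inside the guest venv; only the
    # JSON-encoded return value travels back through the output markers.
    return a + b

result = await computer.venv_exec("default", add, 2, 3)
assert result == 5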
@@ -1,14 +1,17 @@
import asyncio
from .interface.models import KeyType, Key

from .interface.models import Key, KeyType

class DioramaComputer:
"""
A Computer-compatible proxy for Diorama that sends commands over the ComputerInterface.
"""

def __init__(self, computer, apps):
"""
Initialize the DioramaComputer with a computer instance and list of apps.

Args:
computer: The computer instance to proxy commands through
apps: List of applications available in the diorama environment
@@ -21,7 +24,7 @@ class DioramaComputer:
async def __aenter__(self):
"""
Async context manager entry point.

Returns:
self: The DioramaComputer instance
"""
@@ -31,7 +34,7 @@ class DioramaComputer:
async def run(self):
"""
Initialize and run the DioramaComputer if not already initialized.

Returns:
self: The DioramaComputer instance
"""
@@ -39,14 +42,16 @@ class DioramaComputer:
await self.__aenter__()
return self

class DioramaComputerInterface:
"""
Diorama Interface proxy that sends diorama_cmds via the Computer's interface.
"""

def __init__(self, computer, apps):
"""
Initialize the DioramaComputerInterface.

Args:
computer: The computer instance to send commands through
apps: List of applications available in the diorama environment
@@ -58,14 +63,14 @@ class DioramaComputerInterface:
async def _send_cmd(self, action, arguments=None):
"""
Send a command to the diorama interface through the computer.

Args:
action (str): The action/command to execute
arguments (dict, optional): Additional arguments for the command

Returns:
The result from the diorama command execution

Raises:
RuntimeError: If the computer interface is not initialized or command fails
"""
@@ -77,25 +82,30 @@ class DioramaComputerInterface:
raise RuntimeError("Computer interface not initialized. Call run() first.")
result = await iface.diorama_cmd(action, arguments)
if not result.get("success"):
raise RuntimeError(f"Diorama command failed: {result.get('error')}\n{result.get('trace')}")
raise RuntimeError(
    f"Diorama command failed: {result.get('error')}\n{result.get('trace')}"
)
return result.get("result")

async def screenshot(self, as_bytes=True):
"""
Take a screenshot of the diorama scene.

Args:
as_bytes (bool): If True, return image as bytes; if False, return PIL Image object

Returns:
bytes or PIL.Image: Screenshot data in the requested format
"""
from PIL import Image
import base64

from PIL import Image

result = await self._send_cmd("screenshot")
# assume result is a b64 string of an image
img_bytes = base64.b64decode(result)
import io

img = Image.open(io.BytesIO(img_bytes))
self._scene_size = img.size
return img_bytes if as_bytes else img
@@ -103,7 +113,7 @@ class DioramaComputerInterface:
async def get_screen_size(self):
"""
Get the dimensions of the diorama scene.

Returns:
dict: Dictionary containing 'width' and 'height' keys with pixel dimensions
"""
@@ -114,7 +124,7 @@ class DioramaComputerInterface:
async def move_cursor(self, x, y):
"""
Move the cursor to the specified coordinates.

Args:
x (int): X coordinate to move cursor to
y (int): Y coordinate to move cursor to
@@ -124,7 +134,7 @@ class DioramaComputerInterface:
async def left_click(self, x=None, y=None):
"""
Perform a left mouse click at the specified coordinates or current cursor position.

Args:
x (int, optional): X coordinate to click at. If None, clicks at current cursor position
y (int, optional): Y coordinate to click at. If None, clicks at current cursor position
@@ -134,7 +144,7 @@ class DioramaComputerInterface:
async def right_click(self, x=None, y=None):
"""
Perform a right mouse click at the specified coordinates or current cursor position.

Args:
x (int, optional): X coordinate to click at. If None, clicks at current cursor position
y (int, optional): Y coordinate to click at. If None, clicks at current cursor position
@@ -144,7 +154,7 @@ class DioramaComputerInterface:
async def double_click(self, x=None, y=None):
"""
Perform a double mouse click at the specified coordinates or current cursor position.

Args:
x (int, optional): X coordinate to double-click at. If None, clicks at current cursor position
y (int, optional): Y coordinate to double-click at. If None, clicks at current cursor position
@@ -154,7 +164,7 @@ class DioramaComputerInterface:
async def scroll_up(self, clicks=1):
"""
Scroll up by the specified number of clicks.

Args:
clicks (int): Number of scroll clicks to perform upward. Defaults to 1
"""
@@ -163,7 +173,7 @@ class DioramaComputerInterface:
async def scroll_down(self, clicks=1):
"""
Scroll down by the specified number of clicks.

Args:
clicks (int): Number of scroll clicks to perform downward. Defaults to 1
"""
@@ -172,7 +182,7 @@ class DioramaComputerInterface:
async def drag_to(self, x, y, duration=0.5):
"""
Drag from the current cursor position to the specified coordinates.

Args:
x (int): X coordinate to drag to
y (int): Y coordinate to drag to
@@ -183,7 +193,7 @@ class DioramaComputerInterface:
async def get_cursor_position(self):
"""
Get the current cursor position.

Returns:
dict: Dictionary containing the current cursor coordinates
"""
@@ -192,7 +202,7 @@ class DioramaComputerInterface:
async def type_text(self, text):
"""
Type the specified text at the current cursor position.

Args:
text (str): The text to type
"""
@@ -201,7 +211,7 @@ class DioramaComputerInterface:
async def press_key(self, key):
"""
Press a single key.

Args:
key: The key to press
"""
@@ -210,10 +220,10 @@ class DioramaComputerInterface:
async def hotkey(self, *keys):
"""
Press multiple keys simultaneously as a hotkey combination.

Args:
*keys: Variable number of keys to press together. Can be Key enum instances or strings

Raises:
ValueError: If any key is not a Key enum or string type
"""
@@ -224,7 +234,9 @@ class DioramaComputerInterface:
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
actual_keys.append(
    key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
)
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_cmd("hotkey", {"keys": actual_keys})
@@ -232,11 +244,11 @@ class DioramaComputerInterface:
async def to_screen_coordinates(self, x, y):
"""
Convert coordinates to screen coordinates.

Args:
x (int): X coordinate to convert
y (int): Y coordinate to convert

Returns:
dict: Dictionary containing the converted screen coordinates
"""
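A rough sketch of exercising the proxy; the import path is assumed, and how DioramaComputer normally wires up its interface is not shown in this diff, so the interface is constructed directly here. The apps list is illustrative and requires the "app-use" experiment (macOS only):

# 'computer' is a running Computer instance with the app-use experiment enabled.
iface = DioramaComputerInterface(computer, ["Safari", "Notes"])

# Each call is serialized as a diorama_cmd over the computer's interface.
png_bytes = await iface.screenshot()
await iface.left_click(100, 200)
await iface.type_text("hello")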
@@ -1,8 +1,9 @@
"""
Helper functions and decorators for the Computer module.
"""
import logging

import asyncio
import logging
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast

@@ -11,10 +12,11 @@ _default_computer = None

logger = logging.getLogger(__name__)

def set_default_computer(computer):
"""
Set the default computer instance to be used by the remote decorator.

Args:
computer: The computer instance to use as default
"""
@@ -25,21 +27,24 @@ def set_default_computer(computer):
def sandboxed(venv_name: str = "default", computer: str = "default", max_retries: int = 3):
"""
Decorator that wraps a function to be executed remotely via computer.venv_exec

Args:
venv_name: Name of the virtual environment to execute in
computer: The computer instance to use, or "default" to use the globally set default
max_retries: Maximum number of retries for the remote execution
"""

def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# Determine which computer instance to use
comp = computer if computer != "default" else _default_computer

if comp is None:
raise RuntimeError("No computer instance available. Either specify a computer instance or call set_default_computer() first.")

raise RuntimeError(
    "No computer instance available. Either specify a computer instance or call set_default_computer() first."
)

for i in range(max_retries):
try:
return await comp.venv_exec(venv_name, func, *args, **kwargs)
@@ -48,5 +53,7 @@ def sandboxed(venv_name: str = "default", computer: str = "default", max_retries
await asyncio.sleep(1)
if i == max_retries - 1:
raise e

return wrapper

return decorator
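Usage sketch for the decorator above (module path assumed from the file shown; the venv name is illustrative). Note that venv_exec strips decorator lines from the extracted source, so the @sandboxed line never reaches the guest:

from computer.helpers import sandboxed, set_default_computer

set_default_computer(computer)  # register the running Computer instance

@sandboxed(venv_name="demo")
def guest_hostname():
    # Runs inside the guest venv, not on the host.
    import socket
    return socket.gethostname()

name = await guest_hostname()  # the wrapper is async even though the body is sync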
@@ -2,12 +2,12 @@
Interface package for Computer SDK.
"""

from .factory import InterfaceFactory
from .base import BaseComputerInterface
from .factory import InterfaceFactory
from .macos import MacOSComputerInterface

__all__ = [
"InterfaceFactory",
"BaseComputerInterface",
"MacOSComputerInterface",
]
@@ -1,14 +1,23 @@
"""Base interface for computer control."""

from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Tuple, List
from typing import Any, Dict, List, Optional, Tuple

from ..logger import Logger, LogLevel
from .models import MouseButton, CommandResult
from .models import CommandResult, MouseButton

class BaseComputerInterface(ABC):
"""Base class for computer control interfaces."""

def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
def __init__(
    self,
    ip_address: str,
    username: str = "lume",
    password: str = "lume",
    api_key: Optional[str] = None,
    vm_name: Optional[str] = None,
):
"""Initialize interface.

Args:
@@ -24,7 +33,7 @@ class BaseComputerInterface(ABC):
self.api_key = api_key
self.vm_name = vm_name
self.logger = Logger("cua.interface", LogLevel.NORMAL)

# Optional default delay time between commands (in seconds)
self.delay: float = 0.0

@@ -55,9 +64,15 @@ class BaseComputerInterface(ABC):

# Mouse Actions
@abstractmethod
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left", delay: Optional[float] = None) -> None:
async def mouse_down(
    self,
    x: Optional[int] = None,
    y: Optional[int] = None,
    button: "MouseButton" = "left",
    delay: Optional[float] = None,
) -> None:
"""Press and hold a mouse button.

Args:
x: X coordinate to press at. If None, uses current cursor position.
y: Y coordinate to press at. If None, uses current cursor position.
@@ -65,11 +80,17 @@ class BaseComputerInterface(ABC):
delay: Optional delay in seconds after the action
"""
pass

@abstractmethod
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left", delay: Optional[float] = None) -> None:
async def mouse_up(
    self,
    x: Optional[int] = None,
    y: Optional[int] = None,
    button: "MouseButton" = "left",
    delay: Optional[float] = None,
) -> None:
"""Release a mouse button.

Args:
x: X coordinate to release at. If None, uses current cursor position.
y: Y coordinate to release at. If None, uses current cursor position.
@@ -77,11 +98,13 @@ class BaseComputerInterface(ABC):
delay: Optional delay in seconds after the action
"""
pass

@abstractmethod
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def left_click(
    self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
"""Perform a left mouse button click.

Args:
x: X coordinate to click at. If None, uses current cursor position.
y: Y coordinate to click at. If None, uses current cursor position.
@@ -90,9 +113,11 @@ class BaseComputerInterface(ABC):
pass

@abstractmethod
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def right_click(
    self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
"""Perform a right mouse button click.

Args:
x: X coordinate to click at. If None, uses current cursor position.
y: Y coordinate to click at. If None, uses current cursor position.
@@ -101,9 +126,11 @@ class BaseComputerInterface(ABC):
pass

@abstractmethod
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def double_click(
    self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
"""Perform a double left mouse button click.

Args:
x: X coordinate to double-click at. If None, uses current cursor position.
y: Y coordinate to double-click at. If None, uses current cursor position.
@@ -114,7 +141,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def move_cursor(self, x: int, y: int, delay: Optional[float] = None) -> None:
"""Move the cursor to the specified screen coordinates.

Args:
x: X coordinate to move cursor to.
y: Y coordinate to move cursor to.
@@ -123,7 +150,14 @@ class BaseComputerInterface(ABC):
pass

@abstractmethod
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5, delay: Optional[float] = None) -> None:
async def drag_to(
    self,
    x: int,
    y: int,
    button: str = "left",
    duration: float = 0.5,
    delay: Optional[float] = None,
) -> None:
"""Drag from current position to specified coordinates.

Args:
@@ -136,7 +170,13 @@ class BaseComputerInterface(ABC):
pass

@abstractmethod
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5, delay: Optional[float] = None) -> None:
async def drag(
    self,
    path: List[Tuple[int, int]],
    button: str = "left",
    duration: float = 0.5,
    delay: Optional[float] = None,
) -> None:
"""Drag the cursor along a path of coordinates.

Args:
@@ -151,27 +191,27 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def key_down(self, key: str, delay: Optional[float] = None) -> None:
"""Press and hold a key.

Args:
key: The key to press and hold (e.g., 'a', 'shift', 'ctrl').
delay: Optional delay in seconds after the action.
"""
pass

@abstractmethod
async def key_up(self, key: str, delay: Optional[float] = None) -> None:
"""Release a previously pressed key.

Args:
key: The key to release (e.g., 'a', 'shift', 'ctrl').
delay: Optional delay in seconds after the action.
"""
pass

@abstractmethod
async def type_text(self, text: str, delay: Optional[float] = None) -> None:
"""Type the specified text string.

Args:
text: The text string to type.
delay: Optional delay in seconds after the action.
@@ -181,7 +221,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def press_key(self, key: str, delay: Optional[float] = None) -> None:
"""Press and release a single key.

Args:
key: The key to press (e.g., 'a', 'enter', 'escape').
delay: Optional delay in seconds after the action.
@@ -191,7 +231,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def hotkey(self, *keys: str, delay: Optional[float] = None) -> None:
"""Press multiple keys simultaneously (keyboard shortcut).

Args:
*keys: Variable number of keys to press together (e.g., 'ctrl', 'c').
delay: Optional delay in seconds after the action.
@@ -202,18 +242,18 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def scroll(self, x: int, y: int, delay: Optional[float] = None) -> None:
"""Scroll the mouse wheel by specified amounts.

Args:
x: Horizontal scroll amount (positive = right, negative = left).
y: Vertical scroll amount (positive = up, negative = down).
delay: Optional delay in seconds after the action.
"""
pass

@abstractmethod
async def scroll_down(self, clicks: int = 1, delay: Optional[float] = None) -> None:
"""Scroll down by the specified number of clicks.

Args:
clicks: Number of scroll clicks to perform downward.
delay: Optional delay in seconds after the action.
@@ -223,7 +263,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def scroll_up(self, clicks: int = 1, delay: Optional[float] = None) -> None:
"""Scroll up by the specified number of clicks.

Args:
clicks: Number of scroll clicks to perform upward.
delay: Optional delay in seconds after the action.
@@ -252,7 +292,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def get_cursor_position(self) -> Dict[str, int]:
"""Get the current cursor position on screen.

Returns:
Dict with 'x' and 'y' keys containing cursor coordinates.
"""
@@ -262,7 +302,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def copy_to_clipboard(self) -> str:
"""Get the current clipboard content.

Returns:
The text content currently stored in the clipboard.
"""
@@ -271,7 +311,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def set_clipboard(self, text: str) -> None:
"""Set the clipboard content to the specified text.

Args:
text: The text to store in the clipboard.
"""
@@ -281,10 +321,10 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def file_exists(self, path: str) -> bool:
"""Check if a file exists at the specified path.

Args:
path: The file path to check.

Returns:
True if the file exists, False otherwise.
"""
@@ -293,128 +333,128 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def directory_exists(self, path: str) -> bool:
"""Check if a directory exists at the specified path.

Args:
path: The directory path to check.

Returns:
True if the directory exists, False otherwise.
"""
pass

@abstractmethod
async def list_dir(self, path: str) -> List[str]:
"""List the contents of a directory.

Args:
path: The directory path to list.

Returns:
List of file and directory names in the specified directory.
"""
pass

@abstractmethod
async def read_text(self, path: str) -> str:
"""Read the text contents of a file.

Args:
path: The file path to read from.

Returns:
The text content of the file.
"""
pass

@abstractmethod
async def write_text(self, path: str, content: str) -> None:
"""Write text content to a file.

Args:
path: The file path to write to.
content: The text content to write.
"""
pass

@abstractmethod
async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> bytes:
"""Read file binary contents with optional seeking support.

Args:
path: Path to the file
offset: Byte offset to start reading from (default: 0)
length: Number of bytes to read (default: None for entire file)
"""
pass

@abstractmethod
async def write_bytes(self, path: str, content: bytes) -> None:
"""Write binary content to a file.

Args:
path: The file path to write to.
content: The binary content to write.
"""
pass

@abstractmethod
async def delete_file(self, path: str) -> None:
"""Delete a file at the specified path.

Args:
path: The file path to delete.
"""
pass

@abstractmethod
async def create_dir(self, path: str) -> None:
"""Create a directory at the specified path.

Args:
path: The directory path to create.
"""
pass

@abstractmethod
async def delete_dir(self, path: str) -> None:
"""Delete a directory at the specified path.

Args:
path: The directory path to delete.
"""
pass

@abstractmethod
async def get_file_size(self, path: str) -> int:
"""Get the size of a file in bytes.

Args:
path: The file path to get the size of.

Returns:
The size of the file in bytes.
"""
pass

@abstractmethod
async def run_command(self, command: str) -> CommandResult:
"""Run shell command and return structured result.

Executes a shell command using subprocess.run with shell=True and check=False.
The command is run in the target environment and captures both stdout and stderr.

Args:
command (str): The shell command to execute

Returns:
CommandResult: A structured result containing:
- stdout (str): Standard output from the command
- stderr (str): Standard error from the command
- returncode (int): Exit code from the command (0 indicates success)

Raises:
RuntimeError: If the command execution fails at the system level

Example:
result = await interface.run_command("ls -la")
if result.returncode == 0:
@@ -428,12 +468,12 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def get_accessibility_tree(self) -> Dict:
"""Get the accessibility tree of the current screen.

Returns:
Dict containing the hierarchical accessibility information of screen elements.
"""
pass

@abstractmethod
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.
@@ -1,42 +1,44 @@
"""Factory for creating computer interfaces."""

from typing import Literal, Optional

from .base import BaseComputerInterface

class InterfaceFactory:
"""Factory for creating OS-specific computer interfaces."""

@staticmethod
def create_interface_for_os(
os: Literal['macos', 'linux', 'windows'],
os: Literal["macos", "linux", "windows"],
ip_address: str,
api_key: Optional[str] = None,
vm_name: Optional[str] = None
vm_name: Optional[str] = None,
) -> BaseComputerInterface:
"""Create an interface for the specified OS.

Args:
os: Operating system type ('macos', 'linux', or 'windows')
ip_address: IP address of the computer to control
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication

Returns:
BaseComputerInterface: The appropriate interface for the OS

Raises:
ValueError: If the OS type is not supported
"""
# Import implementations here to avoid circular imports
from .macos import MacOSComputerInterface
from .linux import LinuxComputerInterface
from .macos import MacOSComputerInterface
from .windows import WindowsComputerInterface

if os == 'macos':
if os == "macos":
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
elif os == 'linux':
elif os == "linux":
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
elif os == 'windows':
elif os == "windows":
return WindowsComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
else:
raise ValueError(f"Unsupported OS type: {os}")
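A sketch of calling the factory directly; Computer normally does this internally, as shown in the run() and restart() hunks earlier, and the IP address here is illustrative:

interface = InterfaceFactory.create_interface_for_os(
    os="linux",
    ip_address="192.168.64.5",
)
await interface.wait_for_ready(timeout=30)
result = await interface.run_command("uname -a")
print(result.stdout)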
@@ -2,21 +2,35 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
import websockets
|
||||
from PIL import Image
|
||||
|
||||
import websockets
|
||||
import aiohttp
|
||||
|
||||
from ..logger import Logger, LogLevel
|
||||
from ..utils import (
|
||||
bytes_to_image,
|
||||
decode_base64_image,
|
||||
draw_box,
|
||||
encode_base64_image,
|
||||
resize_image,
|
||||
)
|
||||
from .base import BaseComputerInterface
|
||||
from ..utils import decode_base64_image, encode_base64_image, bytes_to_image, draw_box, resize_image
|
||||
from .models import Key, KeyType, MouseButton, CommandResult
|
||||
from .models import CommandResult, Key, KeyType, MouseButton
|
||||
|
||||
|
||||
class GenericComputerInterface(BaseComputerInterface):
|
||||
"""Generic interface with common functionality for all supported platforms (Windows, Linux, macOS)."""
|
||||
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None, logger_name: str = "computer.interface.generic"):
|
||||
def __init__(
|
||||
self,
|
||||
ip_address: str,
|
||||
username: str = "lume",
|
||||
password: str = "lume",
|
||||
api_key: Optional[str] = None,
|
||||
vm_name: Optional[str] = None,
|
||||
logger_name: str = "computer.interface.generic",
|
||||
):
|
||||
super().__init__(ip_address, username, password, api_key, vm_name)
|
||||
self._ws = None
|
||||
self._reconnect_task = None
|
||||
@@ -38,7 +52,7 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
|
||||
async def _handle_delay(self, delay: Optional[float] = None):
|
||||
"""Handle delay between commands using async sleep.
|
||||
|
||||
|
||||
Args:
|
||||
delay: Optional delay in seconds. If None, uses self.delay.
|
||||
"""
|
||||
@@ -51,18 +65,18 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
@property
|
||||
def ws_uri(self) -> str:
|
||||
"""Get the WebSocket URI using the current IP address.
|
||||
|
||||
|
||||
Returns:
|
||||
WebSocket URI for the Computer API Server
|
||||
"""
|
||||
protocol = "wss" if self.api_key else "ws"
|
||||
port = "8443" if self.api_key else "8000"
|
||||
return f"{protocol}://{self.ip_address}:{port}/ws"
|
||||
|
||||
|
||||
@property
|
||||
def rest_uri(self) -> str:
|
||||
"""Get the REST URI using the current IP address.
|
||||
|
||||
|
||||
Returns:
|
||||
REST URI for the Computer API Server
|
||||
"""
|
||||
@@ -71,23 +85,41 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
return f"{protocol}://{self.ip_address}:{port}/cmd"
|
||||
|
||||
    # Mouse actions
    async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left", delay: Optional[float] = None) -> None:
    async def mouse_down(
        self,
        x: Optional[int] = None,
        y: Optional[int] = None,
        button: str = "left",
        delay: Optional[float] = None,
    ) -> None:
        await self._send_command("mouse_down", {"x": x, "y": y, "button": button})
        await self._handle_delay(delay)

    async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left", delay: Optional[float] = None) -> None:
    async def mouse_up(
        self,
        x: Optional[int] = None,
        y: Optional[int] = None,
        button: str = "left",
        delay: Optional[float] = None,
    ) -> None:
        await self._send_command("mouse_up", {"x": x, "y": y, "button": button})
        await self._handle_delay(delay)

    async def left_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
    async def left_click(
        self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
    ) -> None:
        await self._send_command("left_click", {"x": x, "y": y})
        await self._handle_delay(delay)

    async def right_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
    async def right_click(
        self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
    ) -> None:
        await self._send_command("right_click", {"x": x, "y": y})
        await self._handle_delay(delay)

    async def double_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
    async def double_click(
        self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
    ) -> None:
        await self._send_command("double_click", {"x": x, "y": y})
        await self._handle_delay(delay)

@@ -95,37 +127,40 @@ class GenericComputerInterface(BaseComputerInterface):
        await self._send_command("move_cursor", {"x": x, "y": y})
        await self._handle_delay(delay)

    async def drag_to(self, x: int, y: int, button: "MouseButton" = "left", duration: float = 0.5, delay: Optional[float] = None) -> None:
    async def drag_to(
        self,
        x: int,
        y: int,
        button: "MouseButton" = "left",
        duration: float = 0.5,
        delay: Optional[float] = None,
    ) -> None:
        await self._send_command(
            "drag_to", {"x": x, "y": y, "button": button, "duration": duration}
        )
        await self._handle_delay(delay)

    async def drag(self, path: List[Tuple[int, int]], button: "MouseButton" = "left", duration: float = 0.5, delay: Optional[float] = None) -> None:
        await self._send_command(
            "drag", {"path": path, "button": button, "duration": duration}
        )
    async def drag(
        self,
        path: List[Tuple[int, int]],
        button: "MouseButton" = "left",
        duration: float = 0.5,
        delay: Optional[float] = None,
    ) -> None:
        await self._send_command("drag", {"path": path, "button": button, "duration": duration})
        await self._handle_delay(delay)
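Taken together, the mouse methods support both one-shot gestures (drag_to, drag) and manually composed ones (mouse_down / mouse_up). A hypothetical usage sketch, assuming `interface` is an already-connected instance of this class:

async def demo_mouse(interface) -> None:
    await interface.left_click(100, 200)                  # single click at (100, 200)
    await interface.double_click(100, 200, delay=0.5)     # then pause 0.5 s
    await interface.drag_to(400, 300, duration=1.0)       # press, glide, release
    await interface.drag([(10, 10), (50, 50), (90, 10)])  # follow an explicit path
    # The same kind of gesture, composed by hand:
    await interface.mouse_down(10, 10)
    await interface.move_cursor(90, 10)
    await interface.mouse_up(90, 10)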
    # Keyboard Actions
    async def key_down(self, key: "KeyType", delay: Optional[float] = None) -> None:
        await self._send_command("key_down", {"key": key})
        await self._handle_delay(delay)

    async def key_up(self, key: "KeyType", delay: Optional[float] = None) -> None:
        await self._send_command("key_up", {"key": key})
        await self._handle_delay(delay)

    async def type_text(self, text: str, delay: Optional[float] = None) -> None:
        # Temporary fix for https://github.com/trycua/cua/issues/165
        # Check if text contains Unicode characters
        if any(ord(char) > 127 for char in text):
            # For Unicode text, use clipboard and paste
            await self.set_clipboard(text)
            await self.hotkey(Key.COMMAND, 'v')
        else:
            # For ASCII text, use the regular typing method
            await self._send_command("type_text", {"text": text})
        await self._send_command("type_text", {"text": text})
        await self._handle_delay(delay)
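The Unicode branch shown above routes non-ASCII input through the clipboard plus a paste hotkey (the noted workaround for trycua/cua#165); the detection itself is a plain code-point check, sketched here as a standalone helper:

def needs_clipboard_paste(text: str) -> bool:
    # True when direct key events can't represent the text: any code point
    # above 0x7F falls back to set_clipboard() + a paste hotkey.
    return any(ord(char) > 127 for char in text)

assert not needs_clipboard_paste("hello")
assert needs_clipboard_paste("héllo")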
    async def press(self, key: "KeyType", delay: Optional[float] = None) -> None:
@@ -203,10 +238,12 @@ class GenericComputerInterface(BaseComputerInterface):
            elif isinstance(key, str):
                # Try to convert to enum if it matches a known key
                key_or_enum = Key.from_string(key)
                actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
                actual_keys.append(
                    key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
                )
            else:
                raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")

        await self._send_command("hotkey", {"keys": actual_keys})
        await self._handle_delay(delay)

@@ -214,11 +251,11 @@ class GenericComputerInterface(BaseComputerInterface):
    async def scroll(self, x: int, y: int, delay: Optional[float] = None) -> None:
        await self._send_command("scroll", {"x": x, "y": y})
        await self._handle_delay(delay)

    async def scroll_down(self, clicks: int = 1, delay: Optional[float] = None) -> None:
        await self._send_command("scroll_down", {"clicks": clicks})
        await self._handle_delay(delay)

    async def scroll_up(self, clicks: int = 1, delay: Optional[float] = None) -> None:
        await self._send_command("scroll_up", {"clicks": clicks})
        await self._handle_delay(delay)
@@ -302,27 +339,32 @@ class GenericComputerInterface(BaseComputerInterface):
        await self._send_command("set_clipboard", {"text": text})

    # File Operations
    async def _write_bytes_chunked(self, path: str, content: bytes, append: bool = False, chunk_size: int = 1024 * 1024) -> None:
    async def _write_bytes_chunked(
        self, path: str, content: bytes, append: bool = False, chunk_size: int = 1024 * 1024
    ) -> None:
        """Write large files in chunks to avoid memory issues."""
        total_size = len(content)
        current_offset = 0

        while current_offset < total_size:
            chunk_end = min(current_offset + chunk_size, total_size)
            chunk_data = content[current_offset:chunk_end]

            # First chunk uses the original append flag, subsequent chunks always append
            chunk_append = append if current_offset == 0 else True

            result = await self._send_command("write_bytes", {
                "path": path,
                "content_b64": encode_base64_image(chunk_data),
                "append": chunk_append
            })

            result = await self._send_command(
                "write_bytes",
                {
                    "path": path,
                    "content_b64": encode_base64_image(chunk_data),
                    "append": chunk_append,
                },
            )

            if not result.get("success", False):
                raise RuntimeError(result.get("error", "Failed to write file chunk"))

            current_offset = chunk_end
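The chunk walk above is easy to misread in diff form; a minimal sketch of the same boundary arithmetic, assuming the default 1 MiB chunk size:

def iter_write_chunks(content: bytes, append: bool, chunk_size: int = 1024 * 1024):
    # Yields (chunk, append_flag): the first chunk honours the caller's flag,
    # every later chunk must append so the file is assembled in order.
    offset = 0
    while offset < len(content):
        end = min(offset + chunk_size, len(content))
        yield content[offset:end], (append if offset == 0 else True)
        offset = end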
    async def write_bytes(self, path: str, content: bytes, append: bool = False) -> None:
@@ -330,36 +372,39 @@ class GenericComputerInterface(BaseComputerInterface):
        if len(content) > 5 * 1024 * 1024:  # 5MB threshold
            await self._write_bytes_chunked(path, content, append)
            return

        result = await self._send_command("write_bytes", {"path": path, "content_b64": encode_base64_image(content), "append": append})
        result = await self._send_command(
            "write_bytes",
            {"path": path, "content_b64": encode_base64_image(content), "append": append},
        )
        if not result.get("success", False):
            raise RuntimeError(result.get("error", "Failed to write file"))

    async def _read_bytes_chunked(self, path: str, offset: int, total_length: int, chunk_size: int = 1024 * 1024) -> bytes:
    async def _read_bytes_chunked(
        self, path: str, offset: int, total_length: int, chunk_size: int = 1024 * 1024
    ) -> bytes:
        """Read large files in chunks to avoid memory issues."""
        chunks = []
        current_offset = offset
        remaining = total_length

        while remaining > 0:
            read_size = min(chunk_size, remaining)
            result = await self._send_command("read_bytes", {
                "path": path,
                "offset": current_offset,
                "length": read_size
            })
            result = await self._send_command(
                "read_bytes", {"path": path, "offset": current_offset, "length": read_size}
            )

            if not result.get("success", False):
                raise RuntimeError(result.get("error", "Failed to read file chunk"))

            content_b64 = result.get("content_b64", "")
            chunk_data = decode_base64_image(content_b64)
            chunks.append(chunk_data)

            current_offset += read_size
            remaining -= read_size

        return b''.join(chunks)
        return b"".join(chunks)

    async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> bytes:
        # For large files, use chunked reading
@@ -368,34 +413,36 @@ class GenericComputerInterface(BaseComputerInterface):
            file_size = await self.get_file_size(path)
            # If file is larger than 5MB, read in chunks
            if file_size > 5 * 1024 * 1024:  # 5MB threshold
                return await self._read_bytes_chunked(path, offset, file_size - offset if offset > 0 else file_size)
                return await self._read_bytes_chunked(
                    path, offset, file_size - offset if offset > 0 else file_size
                )

        result = await self._send_command("read_bytes", {
            "path": path,
            "offset": offset,
            "length": length
        })
        result = await self._send_command(
            "read_bytes", {"path": path, "offset": offset, "length": length}
        )
        if not result.get("success", False):
            raise RuntimeError(result.get("error", "Failed to read file"))
        content_b64 = result.get("content_b64", "")
        return decode_base64_image(content_b64)
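In effect, callers never see the 5 MB threshold: both directions degrade to chunked transfer transparently. A hypothetical round trip, again assuming `interface` is an already-connected instance:

async def copy_remote_file(interface, src: str, dst: str) -> None:
    data = await interface.read_bytes(src)   # chunked automatically above 5 MB
    await interface.write_bytes(dst, data)   # likewise chunked above 5 MB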
    async def read_text(self, path: str, encoding: str = 'utf-8') -> str:
    async def read_text(self, path: str, encoding: str = "utf-8") -> str:
        """Read text from a file with specified encoding.

        Args:
            path: Path to the file to read
            encoding: Text encoding to use (default: 'utf-8')

        Returns:
            str: The decoded text content of the file
        """
        content_bytes = await self.read_bytes(path)
        return content_bytes.decode(encoding)

    async def write_text(self, path: str, content: str, encoding: str = 'utf-8', append: bool = False) -> None:
    async def write_text(
        self, path: str, content: str, encoding: str = "utf-8", append: bool = False
    ) -> None:
        """Write text to a file with specified encoding.

        Args:
            path: Path to the file to write
            content: Text content to write
@@ -448,7 +495,7 @@ class GenericComputerInterface(BaseComputerInterface):
        return CommandResult(
            stdout=result.get("stdout", ""),
            stderr=result.get("stderr", ""),
            returncode=result.get("return_code", 0)
            returncode=result.get("return_code", 0),
        )

    # Accessibility Actions
@@ -458,7 +505,7 @@ class GenericComputerInterface(BaseComputerInterface):
        if not result.get("success", False):
            raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
        return result

    async def get_active_window_bounds(self) -> Dict[str, int]:
        """Get the bounds of the currently active window."""
        result = await self._send_command("get_active_window_bounds")
@@ -564,33 +611,30 @@ class GenericComputerInterface(BaseComputerInterface):
                    timeout=120,
                )
                self.logger.info("WebSocket connection established")

                # If api_key and vm_name are provided, perform authentication handshake
                if self.api_key and self.vm_name:
                    self.logger.info("Performing authentication handshake...")
                    auth_message = {
                        "command": "authenticate",
                        "params": {
                            "api_key": self.api_key,
                            "container_name": self.vm_name
                        }
                        "params": {"api_key": self.api_key, "container_name": self.vm_name},
                    }
                    await self._ws.send(json.dumps(auth_message))

                    # Wait for authentication response
                    async with self._recv_lock:
                        auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
                        auth_result = json.loads(auth_response)

                    if not auth_result.get("success"):
                        error_msg = auth_result.get("error", "Authentication failed")
                        self.logger.error(f"Authentication failed: {error_msg}")
                        await self._ws.close()
                        self._ws = None
                        raise ConnectionError(f"Authentication failed: {error_msg}")

                    self.logger.info("Authentication successful")

                self._reconnect_delay = 1  # Reset reconnect delay on successful connection
                self._last_ping = time.time()
                retry_count = 0  # Reset retry count on successful connection
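For reference, the handshake reduces to one JSON frame out and one JSON reply back. A sketch of the wire exchange; the credential values here are placeholders, not real keys:

import json

auth_message = {
    "command": "authenticate",
    "params": {"api_key": "sk-example", "container_name": "my-vm"},  # placeholders
}
frame = json.dumps(auth_message)
# The client then waits up to 10 s for the reply; any response without
# {"success": true} closes the socket and raises ConnectionError.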
@@ -600,7 +644,7 @@ class GenericComputerInterface(BaseComputerInterface):
                # Only log the first error at WARNING level, then every Nth attempt
                if retry_count == 1:
                    self.logger.warning(
                        f"Computer API Server not ready yet. Will retry automatically."
                        "Computer API Server not ready yet. Will retry automatically."
                    )
                elif retry_count % log_interval == 0:
                    self.logger.warning(
@@ -648,7 +692,7 @@ class GenericComputerInterface(BaseComputerInterface):
                    # Only log connection lost warnings at most once every min_warning_interval seconds
                    if current_time - last_warning_time >= min_warning_interval:
                        self.logger.warning(
                            f"Computer API Server connection lost. Will retry automatically."
                            "Computer API Server connection lost. Will retry automatically."
                        )
                        last_warning_time = current_time
                    else:
@@ -661,7 +705,7 @@ class GenericComputerInterface(BaseComputerInterface):
            except:
                pass
            self._ws = None

    async def _ensure_connection(self):
        """Ensure WebSocket connection is established."""
        if self._reconnect_task is None or self._reconnect_task.done():
@@ -730,32 +774,30 @@ class GenericComputerInterface(BaseComputerInterface):

        raise last_error if last_error else RuntimeError("Failed to send command")

    async def _send_command_rest(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
    async def _send_command_rest(
        self, command: str, params: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """Send command through REST API without retries or connection management."""
        try:
            # Prepare the request payload
            payload = {"command": command, "params": params or {}}

            # Prepare headers
            headers = {"Content-Type": "application/json"}
            if self.api_key:
                headers["X-API-Key"] = self.api_key
            if self.vm_name:
                headers["X-Container-Name"] = self.vm_name

            # Send the request
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.rest_uri,
                    json=payload,
                    headers=headers
                ) as response:
                async with session.post(self.rest_uri, json=payload, headers=headers) as response:
                    # Get the response text
                    response_text = await response.text()

                    # Trim whitespace
                    response_text = response_text.strip()

                    # Check if it starts with "data: "
                    if response_text.startswith("data: "):
                        # Extract everything after "data: "
@@ -766,38 +808,39 @@ class GenericComputerInterface(BaseComputerInterface):
                            return {
                                "success": False,
                                "error": "Server returned malformed response",
                                "message": response_text
                                "message": response_text,
                            }
                    else:
                        # Return error response
                        return {
                            "success": False,
                            "error": "Server returned malformed response",
                            "message": response_text
                            "message": response_text,
                        }

        except Exception as e:
            return {
                "success": False,
                "error": "Request failed",
                "message": str(e)
            }
            return {"success": False, "error": "Request failed", "message": str(e)}

    async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
        """Send command using REST API with WebSocket fallback."""
        # Try REST API first
        result = await self._send_command_rest(command, params)

        # If REST failed with "Request failed", try WebSocket as fallback
        if not result.get("success", True) and (result.get("error") == "Request failed" or result.get("error") == "Server returned malformed response"):
            self.logger.warning(f"REST API failed for command '{command}', trying WebSocket fallback")
        if not result.get("success", True) and (
            result.get("error") == "Request failed"
            or result.get("error") == "Server returned malformed response"
        ):
            self.logger.warning(
                f"REST API failed for command '{command}', trying WebSocket fallback"
            )
            try:
                return await self._send_command_ws(command, params)
            except Exception as e:
                self.logger.error(f"WebSocket fallback also failed: {e}")
                # Return the original REST error
                return result

        return result
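Condensed, the transport policy reads: REST first, WebSocket only for transport-level failures, and the original REST error wins if both paths fail. A sketch of that decision, assuming the two private senders shown above:

RETRYABLE_ERRORS = {"Request failed", "Server returned malformed response"}

async def send_with_fallback(interface, command: str, params=None) -> dict:
    result = await interface._send_command_rest(command, params)
    if not result.get("success", True) and result.get("error") in RETRYABLE_ERRORS:
        try:
            return await interface._send_command_ws(command, params)
        except Exception:
            return result  # surface the original REST error, not the fallback's
    return result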
    async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
@@ -808,7 +851,9 @@ class GenericComputerInterface(BaseComputerInterface):
            result = await self._send_command_rest("version", {})
            assert result.get("success", True)
        except Exception as e:
            self.logger.debug(f"REST API failed for command 'version', trying WebSocket fallback: {e}")
            self.logger.debug(
                f"REST API failed for command 'version', trying WebSocket fallback: {e}"
            )
            try:
                await self._wait_for_ready_ws(timeout, interval)
                return
@@ -957,7 +1002,7 @@ class GenericComputerInterface(BaseComputerInterface):
        # if self._ws:
        #     asyncio.create_task(self._ws.close())
        #     self._ws = None

    def force_close(self):
        """Force close the WebSocket connection.

@@ -970,4 +1015,3 @@ class GenericComputerInterface(BaseComputerInterface):
        if self._ws:
            asyncio.create_task(self._ws.close())
            self._ws = None

@@ -1,8 +1,19 @@
from typing import Optional

from .generic import GenericComputerInterface


class LinuxComputerInterface(GenericComputerInterface):
    """Interface for Linux."""

    def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
        super().__init__(ip_address, username, password, api_key, vm_name, "computer.interface.linux")
    def __init__(
        self,
        ip_address: str,
        username: str = "lume",
        password: str = "lume",
        api_key: Optional[str] = None,
        vm_name: Optional[str] = None,
    ):
        super().__init__(
            ip_address, username, password, api_key, vm_name, "computer.interface.linux"
        )
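A hypothetical end-to-end use of the Linux subclass; the import path and host address are assumptions based on the package layout and logger names shown here, not confirmed by this diff:

import asyncio

from computer.interface.linux import LinuxComputerInterface  # path assumed

async def main() -> None:
    interface = LinuxComputerInterface("192.168.64.2")  # placeholder address
    await interface.wait_for_ready(timeout=60)
    await interface.type_text("hello from LinuxComputerInterface")

asyncio.run(main())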
Some files were not shown because too many files have changed in this diff