Mirror of https://github.com/trycua/computer.git (synced 2025-12-31 10:29:59 -06:00)

Merge upstream/main to resolve conflicts with trycua/cua

libs/python/agent/.bumpversion.cfg (new file, 10 lines)
@@ -0,0 +1,10 @@
[bumpversion]
current_version = 0.4.35
commit = True
tag = True
tag_name = agent-v{new_version}
message = Bump cua-agent to v{new_version}

[bumpversion:file:pyproject.toml]
search = version = "{current_version}"
replace = version = "{new_version}"
@@ -8,10 +8,11 @@
</picture>
</div>

[](#)
[](#)
[](https://discord.com/invite/mVnXXpdE85)
[](https://pypi.org/project/cua-computer/)
[](#)
[](#)
[](https://discord.com/invite/mVnXXpdE85)
[](https://pypi.org/project/cua-computer/)

</h1>
</div>

@@ -47,7 +48,7 @@ async def main():
        name=os.getenv("CUA_CONTAINER_NAME"),
        api_key=os.getenv("CUA_API_KEY")
    ) as computer:

        # Create agent
        agent = ComputerAgent(
            model="anthropic/claude-3-5-sonnet-20241022",
@@ -56,10 +57,10 @@ async def main():
            trajectory_dir="trajectories",
            max_trajectory_budget=5.0  # $5 budget limit
        )

        # Run agent
        messages = [{"role": "user", "content": "Take a screenshot and tell me what you see"}]

        async for result in agent.run(messages):
            for item in result["output"]:
                if item["type"] == "message":
@@ -84,4 +85,4 @@ if __name__ == "__main__":

## License

MIT License - see LICENSE file for details.
MIT License - see LICENSE file for details.
@@ -5,19 +5,13 @@ agent - Decorator-based Computer Use Agent with liteLLM integration
import logging
import sys

from .decorators import register_agent
from .agent import ComputerAgent
from .types import Messages, AgentResponse

# Import loops to register them
from . import loops
from .agent import ComputerAgent
from .decorators import register_agent
from .types import AgentResponse, Messages

__all__ = [
    "register_agent",
    "ComputerAgent",
    "Messages",
    "AgentResponse"
]
__all__ = ["register_agent", "ComputerAgent", "Messages", "AgentResponse"]

__version__ = "0.4.0"
@@ -5,8 +5,9 @@ Usage:
    python -m agent.cli <model_string>
"""

import sys
import asyncio
import sys

from .cli import main

if __name__ == "__main__":
@@ -2,27 +2,30 @@ import asyncio
import functools
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse

# Try to import HuggingFace dependencies
try:
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

from .models import load_model as load_model_handler


class HuggingFaceLocalAdapter(CustomLLM):
    """HuggingFace Local Adapter for running vision-language models locally."""

    def __init__(self, device: str = "auto", trust_remote_code: bool = False, **kwargs):
        """Initialize the adapter.

        Args:
            device: Device to load model on ("auto", "cuda", "cpu", etc.)
            trust_remote_code: Whether to trust remote code
@@ -34,129 +37,120 @@ class HuggingFaceLocalAdapter(CustomLLM):
        # Cache for model handlers keyed by model_name
        self._handlers: Dict[str, Any] = {}
        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool

    def _get_handler(self, model_name: str):
        """Get or create a model handler for the given model name."""
        if model_name not in self._handlers:
            self._handlers[model_name] = load_model_handler(model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code)
            self._handlers[model_name] = load_model_handler(
                model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code
            )
        return self._handlers[model_name]

    def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert OpenAI format messages to HuggingFace format.

        Args:
            messages: Messages in OpenAI format

        Returns:
            Messages in HuggingFace format
        """
        converted_messages = []

        for message in messages:
            converted_message = {
                "role": message["role"],
                "content": []
            }
            converted_message = {"role": message["role"], "content": []}

            content = message.get("content", [])
            if isinstance(content, str):
                # Simple text content
                converted_message["content"].append({
                    "type": "text",
                    "text": content
                })
                converted_message["content"].append({"type": "text", "text": content})
            elif isinstance(content, list):
                # Multi-modal content
                for item in content:
                    if item.get("type") == "text":
                        converted_message["content"].append({
                            "type": "text",
                            "text": item.get("text", "")
                        })
                        converted_message["content"].append(
                            {"type": "text", "text": item.get("text", "")}
                        )
                    elif item.get("type") == "image_url":
                        # Convert image_url format to image format
                        image_url = item.get("image_url", {}).get("url", "")
                        converted_message["content"].append({
                            "type": "image",
                            "image": image_url
                        })
                        converted_message["content"].append({"type": "image", "image": image_url})

            converted_messages.append(converted_message)

        return converted_messages

    def _generate(self, **kwargs) -> str:
        """Generate response using the local HuggingFace model.

        Args:
            **kwargs: Keyword arguments containing messages and model info

        Returns:
            Generated text response
        """
        if not HF_AVAILABLE:
            raise ImportError(
                "HuggingFace transformers dependencies not found. "
                "Please install with: pip install \"cua-agent[uitars-hf]\""
                'Please install with: pip install "cua-agent[uitars-hf]"'
            )

        # Extract messages and model from kwargs
        messages = kwargs.get('messages', [])
        model_name = kwargs.get('model', 'ByteDance-Seed/UI-TARS-1.5-7B')
        max_new_tokens = kwargs.get('max_tokens', 128)
        messages = kwargs.get("messages", [])
        model_name = kwargs.get("model", "ByteDance-Seed/UI-TARS-1.5-7B")
        max_new_tokens = kwargs.get("max_tokens", 128)

        # Warn about ignored kwargs
        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
        ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
        if ignored_kwargs:
            warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")

        # Convert messages to HuggingFace format
        hf_messages = self._convert_messages(messages)

        # Delegate to model handler
        handler = self._get_handler(model_name)
        generated_text = handler.generate(hf_messages, max_new_tokens=max_new_tokens)
        return generated_text

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        generated_text = self._generate(**kwargs)

        return completion(
            model=f"huggingface-local/{kwargs['model']}",
            mock_response=generated_text,
        )

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
            self._executor, functools.partial(self._generate, **kwargs)
        )

        return await acompletion(
            model=f"huggingface-local/{kwargs['model']}",
            mock_response=generated_text,
        )

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Returns:
            Iterator of GenericStreamingChunk
        """
        generated_text = self._generate(**kwargs)

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
@@ -165,22 +159,21 @@ class HuggingFaceLocalAdapter(CustomLLM):
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Returns:
            AsyncIterator of GenericStreamingChunk
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
            self._executor, functools.partial(self._generate, **kwargs)
        )

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
@@ -189,5 +182,5 @@ class HuggingFaceLocalAdapter(CustomLLM):
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk
        yield generic_streaming_chunk
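For orientation, the mock_response calls above route locally generated text back through litellm. A minimal usage sketch, assuming the adapter is registered with litellm's custom-provider map under the "huggingface-local" prefix that the mock responses use (the import path is taken from the agent.py diff further down):

import litellm
from agent.adapters import HuggingFaceLocalAdapter  # path assumed from this diff

adapter = HuggingFaceLocalAdapter(device="auto")
# Register the adapter as a custom litellm provider
litellm.custom_provider_map = [
    {"provider": "huggingface-local", "custom_handler": adapter}
]
response = litellm.completion(
    model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
    messages=[{"role": "user", "content": "Describe this screen."}],
)
print(response.choices[0].message.content)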
@@ -1,22 +1,23 @@
import os
import asyncio
import os
from typing import Any, AsyncIterator, Dict, Iterator, List

import requests
from typing import List, Dict, Any, Iterator, AsyncIterator
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse


class HumanAdapter(CustomLLM):
    """Human Adapter for human-in-the-loop completions.

    This adapter sends completion requests to a human completion server
    where humans can review and respond to AI requests.
    """

    def __init__(self, base_url: str | None = None, timeout: float = 300.0, **kwargs):
        """Initialize the human adapter.

        Args:
            base_url: Base URL for the human completion server.
                Defaults to HUMAN_BASE_URL environment variable or http://localhost:8002
@@ -24,60 +25,58 @@ class HumanAdapter(CustomLLM):
            **kwargs: Additional arguments
        """
        super().__init__()
        self.base_url = base_url or os.getenv('HUMAN_BASE_URL', 'http://localhost:8002')
        self.base_url = base_url or os.getenv("HUMAN_BASE_URL", "http://localhost:8002")
        self.timeout = timeout

        # Ensure base_url doesn't end with slash
        self.base_url = self.base_url.rstrip('/')
        self.base_url = self.base_url.rstrip("/")

    def _queue_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
        """Queue a completion request and return the call ID.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Call ID for tracking the request

        Raises:
            Exception: If queueing fails
        """
        try:
            response = requests.post(
                f"{self.base_url}/queue",
                json={"messages": messages, "model": model},
                timeout=10
                f"{self.base_url}/queue", json={"messages": messages, "model": model}, timeout=10
            )
            response.raise_for_status()
            return response.json()["id"]
        except requests.RequestException as e:
            raise Exception(f"Failed to queue completion request: {e}")

    def _wait_for_completion(self, call_id: str) -> Dict[str, Any]:
        """Wait for human to complete the call.

        Args:
            call_id: ID of the queued completion call

        Returns:
            Dict containing response and/or tool_calls

        Raises:
            TimeoutError: If timeout is exceeded
            Exception: If completion fails
        """
        import time

        start_time = time.time()

        while True:
            try:
                # Check status
                status_response = requests.get(f"{self.base_url}/status/{call_id}")
                status_response.raise_for_status()
                status_data = status_response.json()

                if status_data["status"] == "completed":
                    result = {}
                    if "response" in status_data and status_data["response"]:
@@ -88,38 +87,41 @@ class HumanAdapter(CustomLLM):
                elif status_data["status"] == "failed":
                    error_msg = status_data.get("error", "Unknown error")
                    raise Exception(f"Completion failed: {error_msg}")

                # Check timeout
                if time.time() - start_time > self.timeout:
                    raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")
                    raise TimeoutError(
                        f"Timeout waiting for human response after {self.timeout} seconds"
                    )

                # Wait before checking again
                time.sleep(1.0)

            except requests.RequestException as e:
                if time.time() - start_time > self.timeout:
                    raise TimeoutError(f"Timeout waiting for human response: {e}")
                # Continue trying if we haven't timed out
                time.sleep(1.0)

    async def _async_wait_for_completion(self, call_id: str) -> Dict[str, Any]:
        """Async version of wait_for_completion.

        Args:
            call_id: ID of the queued completion call

        Returns:
            Dict containing response and/or tool_calls

        Raises:
            TimeoutError: If timeout is exceeded
            Exception: If completion fails
        """
        import aiohttp
        import time

        import aiohttp

        start_time = time.time()

        async with aiohttp.ClientSession() as session:
            while True:
                try:
@@ -127,7 +129,7 @@ class HumanAdapter(CustomLLM):
                    async with session.get(f"{self.base_url}/status/{call_id}") as response:
                        response.raise_for_status()
                        status_data = await response.json()

                    if status_data["status"] == "completed":
                        result = {}
                        if "response" in status_data and status_data["response"]:
@@ -138,166 +140,158 @@ class HumanAdapter(CustomLLM):
                    elif status_data["status"] == "failed":
                        error_msg = status_data.get("error", "Unknown error")
                        raise Exception(f"Completion failed: {error_msg}")

                    # Check timeout
                    if time.time() - start_time > self.timeout:
                        raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")

                        raise TimeoutError(
                            f"Timeout waiting for human response after {self.timeout} seconds"
                        )

                    # Wait before checking again
                    await asyncio.sleep(1.0)

                except Exception as e:
                    if time.time() - start_time > self.timeout:
                        raise TimeoutError(f"Timeout waiting for human response: {e}")
                    # Continue trying if we haven't timed out
                    await asyncio.sleep(1.0)

    def _generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:
        """Generate a human response for the given messages.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Dict containing response and/or tool_calls
        """
        # Queue the completion request
        call_id = self._queue_completion(messages, model)

        # Wait for human response
        response = self._wait_for_completion(call_id)

        return response

    async def _async_generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:

    async def _async_generate_response(
        self, messages: List[Dict[str, Any]], model: str
    ) -> Dict[str, Any]:
        """Async version of _generate_response.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Dict containing response and/or tool_calls
        """
        # Queue the completion request (sync operation)
        call_id = self._queue_completion(messages, model)

        # Wait for human response (async)
        response = await self._async_wait_for_completion(call_id)

        return response

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')
        messages = kwargs.get("messages", [])
        model = kwargs.get("model", "human")

        # Generate human response
        human_response_data = self._generate_response(messages, model)

        # Create ModelResponse with proper structure
        from litellm.types.utils import ModelResponse, Choices, Message
        import uuid
        import time

        import uuid

        from litellm.types.utils import Choices, Message, ModelResponse

        # Create message content based on response type
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Tool calls response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", ""),
                tool_calls=human_response_data["tool_calls"]
                tool_calls=human_response_data["tool_calls"],
            )
        else:
            # Text response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", "")
            )

        choice = Choices(
            finish_reason="stop",
            index=0,
            message=message
        )

            message = Message(role="assistant", content=human_response_data.get("response", ""))

        choice = Choices(finish_reason="stop", index=0, message=message)

        result = ModelResponse(
            id=f"human-{uuid.uuid4()}",
            choices=[choice],
            created=int(time.time()),
            model=f"human/{model}",
            object="chat.completion"
            object="chat.completion",
        )

        return result

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')
        messages = kwargs.get("messages", [])
        model = kwargs.get("model", "human")

        # Generate human response
        human_response_data = await self._async_generate_response(messages, model)

        # Create ModelResponse with proper structure
        from litellm.types.utils import ModelResponse, Choices, Message
        import uuid
        import time

        import uuid

        from litellm.types.utils import Choices, Message, ModelResponse

        # Create message content based on response type
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Tool calls response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", ""),
                tool_calls=human_response_data["tool_calls"]
                tool_calls=human_response_data["tool_calls"],
            )
        else:
            # Text response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", "")
            )

        choice = Choices(
            finish_reason="stop",
            index=0,
            message=message
        )

            message = Message(role="assistant", content=human_response_data.get("response", ""))

        choice = Choices(finish_reason="stop", index=0, message=message)

        result = ModelResponse(
            id=f"human-{uuid.uuid4()}",
            choices=[choice],
            created=int(time.time()),
            model=f"human/{model}",
            object="chat.completion"
            object="chat.completion",
        )

        return result

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Yields:
            Streaming chunks with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')
        messages = kwargs.get("messages", [])
        model = kwargs.get("model", "human")

        # Generate human response
        human_response_data = self._generate_response(messages, model)

        import time

        # Handle tool calls vs text response
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Stream tool calls as a single chunk
@@ -319,22 +313,26 @@ class HumanAdapter(CustomLLM):
            "is_finished": True,
            "text": response_text,
            "tool_use": None,
            "usage": {"completion_tokens": len(response_text.split()), "prompt_tokens": 0, "total_tokens": len(response_text.split())},
            "usage": {
                "completion_tokens": len(response_text.split()),
                "prompt_tokens": 0,
                "total_tokens": len(response_text.split()),
            },
        }
        yield generic_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Yields:
            Streaming chunks with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')
        messages = kwargs.get("messages", [])
        model = kwargs.get("model", "human")

        # Generate human response
        human_response = await self._async_generate_response(messages, model)

        # Return as single streaming chunk
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
@@ -342,7 +340,11 @@ class HumanAdapter(CustomLLM):
            "is_finished": True,
            "text": human_response,
            "tool_use": None,
            "usage": {"completion_tokens": len(human_response.split()), "prompt_tokens": 0, "total_tokens": len(human_response.split())},
            "usage": {
                "completion_tokens": len(human_response.split()),
                "prompt_tokens": 0,
                "total_tokens": len(human_response.split()),
            },
        }

        yield generic_streaming_chunk
        yield generic_streaming_chunk
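The HumanAdapter's queue-then-poll protocol assumes only two endpoints: POST {base_url}/queue returning {"id": ...}, and GET {base_url}/status/{id} returning a payload with "status" (pending, completed, or failed) plus "response", "tool_calls", or "error". A minimal stand-in server sketch, with endpoint shapes inferred from the adapter above; the /respond route is a hypothetical helper for the human side, not part of this diff:

import uuid
from fastapi import FastAPI

app = FastAPI()
calls = {}  # call_id -> status payload, in-memory only

@app.post("/queue")
def queue(payload: dict):
    # Matches the adapter's POST {base_url}/queue with {"messages", "model"}
    call_id = str(uuid.uuid4())
    calls[call_id] = {"status": "pending", "messages": payload["messages"]}
    return {"id": call_id}

@app.get("/status/{call_id}")
def status(call_id: str):
    # Matches the adapter's GET {base_url}/status/{call_id} polling loop
    return calls.get(call_id, {"status": "failed", "error": "unknown call id"})

@app.post("/respond/{call_id}")
def respond(call_id: str, payload: dict):
    # Hypothetical: how a human reviewer would complete a queued call
    calls[call_id] = {"status": "completed", "response": payload["response"]}
    return {"ok": True}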
@@ -1,24 +1,26 @@
import asyncio
import functools
import warnings
import io
import base64
import functools
import io
import math
import re
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional, Tuple, cast
from PIL import Image
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, cast

from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from PIL import Image

# Try to import MLX dependencies
try:
    import mlx.core as mx
    from mlx_vlm import load, generate
    from mlx_vlm import generate, load
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config
    from transformers.tokenization_utils import PreTrainedTokenizer

    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False
@@ -29,20 +31,28 @@ MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: float, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
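A quick worked example of the factor helpers above, restated here so the snippet runs standalone (IMAGE_FACTOR is 28 per the constants in this hunk):

import math

def round_by_factor(n: float, f: int) -> int:
    return round(n / f) * f

def ceil_by_factor(n: float, f: int) -> int:
    return math.ceil(n / f) * f

# 57 / 28 is about 2.04, so rounding snaps to 2 * 28 and the ceiling to 3 * 28.
assert round_by_factor(57, 28) == 56
assert ceil_by_factor(57, 28) == 84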
@@ -70,61 +80,62 @@ def smart_resize(

class MLXVLMAdapter(CustomLLM):
    """MLX VLM Adapter for running vision-language models locally using MLX."""

    def __init__(self, **kwargs):
        """Initialize the adapter.

        Args:
            **kwargs: Additional arguments
        """
        super().__init__()

        self.models = {}  # Cache for loaded models
        self.processors = {}  # Cache for loaded processors
        self.configs = {}  # Cache for loaded configs
        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool

    def _load_model_and_processor(self, model_name: str):
        """Load model and processor if not already cached.

        Args:
            model_name: Name of the model to load

        Returns:
            Tuple of (model, processor, config)
        """
        if not MLX_AVAILABLE:
            raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")

        if model_name not in self.models:
            # Load model and processor
            model_obj, processor = load(
                model_name,
                processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
                model_name, processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
            )
            config = load_config(model_name)

            # Cache them
            self.models[model_name] = model_obj
            self.processors[model_name] = processor
            self.configs[model_name] = config

        return self.models[model_name], self.processors[model_name], self.configs[model_name]

    def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:

    def _process_coordinates(
        self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]
    ) -> str:
        """Process coordinates in box tokens based on image resizing using smart_resize approach.

        Args:
            text: Text containing box tokens
            original_size: Original image size (width, height)
            model_size: Model processed image size (width, height)

        Returns:
            Text with processed coordinates
        """
        # Find all box tokens
        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"

        def process_coords(match):
            model_x, model_y = int(match.group(1)), int(match.group(2))
            # Scale coordinates from model space to original image space
@@ -132,15 +143,20 @@ class MLXVLMAdapter(CustomLLM):
            new_x = int(model_x * original_size[0] / model_size[0])  # Width
            new_y = int(model_y * original_size[1] / model_size[1])  # Height
            return f"<|box_start|>({new_x},{new_y})<|box_end|>"

        return re.sub(box_pattern, process_coords, text)
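A worked example of the rescaling in process_coords, with illustrative sizes only (a 1920x1080 screenshot and a hypothetical 1008x560 model-space size, both as (width, height)):

import re

original_size, model_size = (1920, 1080), (1008, 560)  # illustrative numbers
box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"

def process_coords(match: re.Match) -> str:
    # Scale from model space back to the original screenshot space
    model_x, model_y = int(match.group(1)), int(match.group(2))
    new_x = int(model_x * original_size[0] / model_size[0])
    new_y = int(model_y * original_size[1] / model_size[1])
    return f"<|box_start|>({new_x},{new_y})<|box_end|>"

text = "click <|box_start|>(504,280)<|box_end|>"
print(re.sub(box_pattern, process_coords, text))
# -> click <|box_start|>(960,540)<|box_end|>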
    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Image.Image], Dict[int, Tuple[int, int]], Dict[int, Tuple[int, int]]]:

    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[
        List[Dict[str, Any]],
        List[Image.Image],
        Dict[int, Tuple[int, int]],
        Dict[int, Tuple[int, int]],
    ]:
        """Convert OpenAI format messages to MLX VLM format and extract images.

        Args:
            messages: Messages in OpenAI format

        Returns:
            Tuple of (processed_messages, images, original_sizes, model_sizes)
        """
@@ -149,13 +165,10 @@ class MLXVLMAdapter(CustomLLM):
        original_sizes = {}  # Track original sizes of images for coordinate mapping
        model_sizes = {}  # Track model processed sizes
        image_index = 0

        for message in messages:
            processed_message = {
                "role": message["role"],
                "content": []
            }
            processed_message = {"role": message["role"], "content": []}

            content = message.get("content", [])
            if isinstance(content, str):
                # Simple text content
@@ -165,164 +178,163 @@ class MLXVLMAdapter(CustomLLM):
                processed_content = []
                for item in content:
                    if item.get("type") == "text":
                        processed_content.append({
                            "type": "text",
                            "text": item.get("text", "")
                        })
                        processed_content.append({"type": "text", "text": item.get("text", "")})
                    elif item.get("type") == "image_url":
                        image_url = item.get("image_url", {}).get("url", "")
                        pil_image = None

                        if image_url.startswith("data:image/"):
                            # Extract base64 data
                            base64_data = image_url.split(',')[1]
                            base64_data = image_url.split(",")[1]
                            # Convert base64 to PIL Image
                            image_data = base64.b64decode(base64_data)
                            pil_image = Image.open(io.BytesIO(image_data))
                        else:
                            # Handle file path or URL
                            pil_image = Image.open(image_url)

                        # Store original image size for coordinate mapping
                        original_size = pil_image.size
                        original_sizes[image_index] = original_size

                        # Use smart_resize to determine model size
                        # Note: smart_resize expects (height, width) but PIL gives (width, height)
                        height, width = original_size[1], original_size[0]
                        new_height, new_width = smart_resize(height, width)
                        # Store model size in (width, height) format for consistent coordinate processing
                        model_sizes[image_index] = (new_width, new_height)

                        # Resize the image using the calculated dimensions from smart_resize
                        resized_image = pil_image.resize((new_width, new_height))
                        images.append(resized_image)

                        # Add image placeholder to content
                        processed_content.append({
                            "type": "image"
                        })
                        processed_content.append({"type": "image"})

                        image_index += 1

                processed_message["content"] = processed_content

            processed_messages.append(processed_message)

        return processed_messages, images, original_sizes, model_sizes

    def _generate(self, **kwargs) -> str:
        """Generate response using the local MLX VLM model.

        Args:
            **kwargs: Keyword arguments containing messages and model info

        Returns:
            Generated text response
        """
        messages = kwargs.get('messages', [])
        model_name = kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')
        max_tokens = kwargs.get('max_tokens', 128)
        messages = kwargs.get("messages", [])
        model_name = kwargs.get("model", "mlx-community/UI-TARS-1.5-7B-4bit")
        max_tokens = kwargs.get("max_tokens", 128)

        # Warn about ignored kwargs
        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
        ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
        if ignored_kwargs:
            warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")

        # Load model and processor
        model, processor, config = self._load_model_and_processor(model_name)

        # Convert messages and extract images
        processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)

        # Process user text input with box coordinates after image processing
        # Swap original_size and model_size arguments for inverse transformation
        for msg_idx, msg in enumerate(processed_messages):
            if msg.get("role") == "user" and isinstance(msg.get("content"), str):
                content = msg.get("content", "")
                if "<|box_start|>" in content and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
                if (
                    "<|box_start|>" in content
                    and original_sizes
                    and model_sizes
                    and 0 in original_sizes
                    and 0 in model_sizes
                ):
                    orig_size = original_sizes[0]
                    model_size = model_sizes[0]
                    # Swap arguments to perform inverse transformation for user input
                    processed_messages[msg_idx]["content"] = self._process_coordinates(content, model_size, orig_size)

                    processed_messages[msg_idx]["content"] = self._process_coordinates(
                        content, model_size, orig_size
                    )

        try:
            # Format prompt according to model requirements using the processor directly
            prompt = processor.apply_chat_template(
                processed_messages,
                tokenize=False,
                add_generation_prompt=True,
                return_tensors='pt'
                processed_messages, tokenize=False, add_generation_prompt=True, return_tensors="pt"
            )
            tokenizer = cast(PreTrainedTokenizer, processor)

            # Generate response
            text_content, usage = generate(
                model,
                tokenizer,
                str(prompt),
                images,  # type: ignore
                model,
                tokenizer,
                str(prompt),
                images,  # type: ignore
                verbose=False,
                max_tokens=max_tokens
                max_tokens=max_tokens,
            )

        except Exception as e:
            raise RuntimeError(f"Error generating response: {str(e)}") from e

        # Process coordinates in the response back to original image space
        if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
            # Get original image size and model size (using the first image)
            orig_size = original_sizes[0]
            model_size = model_sizes[0]

            # Check if output contains box tokens that need processing
            if "<|box_start|>" in text_content:
                # Process coordinates from model space back to original image space
                text_content = self._process_coordinates(text_content, orig_size, model_size)

        return text_content

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        generated_text = self._generate(**kwargs)

        result = completion(
            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
            mock_response=generated_text,
        )
        return cast(ModelResponse, result)

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
            self._executor, functools.partial(self._generate, **kwargs)
        )

        result = await acompletion(
            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
            mock_response=generated_text,
        )
        return cast(ModelResponse, result)

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Returns:
            Iterator of GenericStreamingChunk
        """
        generated_text = self._generate(**kwargs)

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
@@ -331,22 +343,21 @@ class MLXVLMAdapter(CustomLLM):
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Returns:
            AsyncIterator of GenericStreamingChunk
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
            self._executor, functools.partial(self._generate, **kwargs)
        )

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
@@ -355,5 +366,5 @@ class MLXVLMAdapter(CustomLLM):
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk
        yield generic_streaming_chunk
@@ -2,32 +2,40 @@ from typing import Optional

try:
    from transformers import AutoConfig

    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

from .generic import GenericHFModel
from .internvl import InternVLModel
from .opencua import OpenCUAModel
from .qwen2_5_vl import Qwen2_5_VLModel
from .internvl import InternVLModel


def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
    """Factory function to load and return the right model handler instance.

    - If the underlying transformers config class matches OpenCUA, return OpenCUAModel
    - Otherwise, return GenericHFModel
    """
    if not HF_AVAILABLE:
        raise ImportError(
            "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
            'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
        )
    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    cls = cfg.__class__.__name__
    print(f"cls: {cls}")
    if "OpenCUA" in cls:
        return OpenCUAModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
        return OpenCUAModel(
            model_name=model_name, device=device, trust_remote_code=trust_remote_code
        )
    elif "Qwen2_5_VL" in cls:
        return Qwen2_5_VLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
        return Qwen2_5_VLModel(
            model_name=model_name, device=device, trust_remote_code=trust_remote_code
        )
    elif "InternVL" in cls:
        return InternVLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
        return InternVLModel(
            model_name=model_name, device=device, trust_remote_code=trust_remote_code
        )
    return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
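The factory above dispatches on the transformers config class name rather than on the model id string. A small sketch of that probe (the model id is only an example; requires transformers plus network or cache access):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
print(cfg.__class__.__name__)
# A name containing "Qwen2_5_VL" selects Qwen2_5_VLModel above;
# anything unmatched falls through to GenericHFModel.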
@@ -1,9 +1,10 @@
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional

# Hugging Face imports are local to avoid hard dependency at module import
try:
    import torch  # type: ignore
    from transformers import AutoModel, AutoProcessor  # type: ignore

    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class GenericHFModel:
    Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
    """

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
    def __init__(
        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
    ) -> None:
        if not HF_AVAILABLE:
            raise ImportError(
                "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
                'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
            )
        self.model_name = model_name
        self.device = device
@@ -64,7 +67,7 @@ class GenericHFModel:
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Trim prompt tokens from output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        # Decode
        output_text = self.processor.batch_decode(
@@ -1,19 +1,22 @@
from __future__ import annotations

from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional

# Hugging Face imports are local to avoid hard dependency at module import
try:
    import torch  # type: ignore
    from transformers import AutoModel, AutoTokenizer  # type: ignore
    # Attempt to import InternVL's model dependencies
    import einops as _  # type: ignore
    import timm as _  # type: ignore
    from PIL import Image  # type: ignore
    import torchvision.transforms as T  # type: ignore
    from torchvision.transforms.functional import InterpolationMode  # type: ignore
    import base64  # type: ignore
    from io import BytesIO  # type: ignore

    # Attempt to import InternVL's model dependencies
    import einops as _  # type: ignore
    import requests  # type: ignore
    import timm as _  # type: ignore
    import torch  # type: ignore
    import torchvision.transforms as T  # type: ignore
    from PIL import Image  # type: ignore
    from torchvision.transforms.functional import InterpolationMode  # type: ignore
    from transformers import AutoModel, AutoTokenizer  # type: ignore

    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False
@@ -25,10 +28,12 @@ class InternVLModel:
    Provides preprocessing to support multi-turn conversations with multiple images.
    """

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
    def __init__(
        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
    ) -> None:
        if not HF_AVAILABLE:
            raise ImportError(
                "InternVL dependencies not found. Install with: pip install \"cua-agent[internvl-hf]\""
                'InternVL dependencies not found. Install with: pip install "cua-agent[internvl-hf]"'
            )
        self.model_name = model_name
        self.device = device
@@ -60,16 +65,25 @@ class InternVLModel:

    def _build_transform(self, input_size: int) -> T.Compose:
        MEAN, STD = self.IMAGENET_MEAN, self.IMAGENET_STD
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD)
        ])
        transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=MEAN, std=STD),
            ]
        )
        return transform

    def _find_closest_aspect_ratio(self, aspect_ratio: float, target_ratios: List[tuple], width: int, height: int, image_size: int):
        best_ratio_diff = float('inf')
    def _find_closest_aspect_ratio(
        self,
        aspect_ratio: float,
        target_ratios: List[tuple],
        width: int,
        height: int,
        image_size: int,
    ):
        best_ratio_diff = float("inf")
        best_ratio = (1, 1)
        area = width * height
        for ratio in target_ratios:
@@ -83,17 +97,29 @@ class InternVLModel:
                best_ratio = ratio
        return best_ratio

    def _dynamic_preprocess(self, image: Image.Image, min_num: int = 1, max_num: int = 12, image_size: int = 448, use_thumbnail: bool = True) -> List[Image.Image]:
    def _dynamic_preprocess(
        self,
        image: Image.Image,
        min_num: int = 1,
        max_num: int = 12,
        image_size: int = 448,
        use_thumbnail: bool = True,
    ) -> List[Image.Image]:
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if i * j <= max_num and i * j >= min_num
        )
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        target_aspect_ratio = self._find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)
            aspect_ratio, target_ratios, orig_width, orig_height, image_size
        )

        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
@@ -106,7 +132,7 @@ class InternVLModel:
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
                ((i // (target_width // image_size)) + 1) * image_size,
            )
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
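To make the tile-grid search in _dynamic_preprocess concrete, a simplified sketch (the real _find_closest_aspect_ratio above also tie-breaks equal ratio differences by area; the image size here is illustrative):

min_num, max_num = 1, 12
# All (cols, rows) grids whose tile count lies in [min_num, max_num]
ratios = sorted(
    {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    },
    key=lambda r: r[0] * r[1],
)
aspect = 1920 / 1080  # about 1.78
best = min(ratios, key=lambda r: abs(aspect - r[0] / r[1]))
print(best)  # (2, 1): a 2x1 grid of 448x448 tiles, plus an optional thumbnail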
@@ -122,20 +148,24 @@ class InternVLModel:
            # data URL base64
            header, b64data = src.split(",", 1)
            img_bytes = base64.b64decode(b64data)
            return Image.open(BytesIO(img_bytes)).convert('RGB')
            return Image.open(BytesIO(img_bytes)).convert("RGB")
        if src.startswith("http://") or src.startswith("https://"):
            resp = requests.get(src, timeout=10)
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert('RGB')
            return Image.open(BytesIO(resp.content)).convert("RGB")
        # Assume local file path
        return Image.open(src).convert('RGB')
        return Image.open(src).convert("RGB")

    def _images_to_pixel_values(self, images: List[Image.Image], input_size: int = 448, max_num: int = 12):
    def _images_to_pixel_values(
        self, images: List[Image.Image], input_size: int = 448, max_num: int = 12
    ):
        transform = self._build_transform(input_size=input_size)
        pixel_values_list = []
        num_patches_list: List[int] = []
        for img in images:
            tiles = self._dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
            tiles = self._dynamic_preprocess(
                img, image_size=input_size, use_thumbnail=True, max_num=max_num
            )
            pv = [transform(tile) for tile in tiles]
            pv = torch.stack(pv)
            num_patches_list.append(pv.shape[0])
@@ -191,7 +221,9 @@ class InternVLModel:
            last_user_text_parts = parts_text or last_user_text_parts
        elif role == "assistant":
            # Only keep text content for history
            parts_text = [item.get("text", "") for item in content_items if item.get("type") == "text"]
            parts_text = [
                item.get("text", "") for item in content_items if item.get("type") == "text"
            ]
            text = "\n".join(parts_text).strip()
            if text:
                context_lines.append(f"Assistant: {text}")
@@ -200,7 +232,9 @@ class InternVLModel:
        pixel_values = None
        num_patches_list: List[int] = []
        if all_images:
            pixel_values, num_patches_list = self._images_to_pixel_values(all_images, input_size=448, max_num=12)
            pixel_values, num_patches_list = self._images_to_pixel_values(
                all_images, input_size=448, max_num=12
            )
        if pixel_values is not None:
            # Convert dtype/device as in docs
            pixel_values = pixel_values.to(torch.bfloat16)
@@ -246,7 +280,9 @@ class InternVLModel:
                    num_patches_list=num_patches_list,
                )
            else:
                response = self.model.chat(self.tokenizer, pixel_values, question, generation_config)
                response = self.model.chat(
                    self.tokenizer, pixel_values, question, generation_config
                )
        except Exception as e:
            # Fallback: return empty string to avoid crashing the adapter
            return ""
@@ -1,13 +1,18 @@
from typing import List, Dict, Any
import re
import base64
import re
from io import BytesIO
from typing import Any, Dict, List

try:
    import blobfile as _  # assert blobfile is installed
    import torch  # type: ignore
    from transformers import AutoTokenizer, AutoModel, AutoImageProcessor  # type: ignore
    from PIL import Image  # type: ignore
    import blobfile as _  # assert blobfile is installed
    from transformers import (  # type: ignore
        AutoImageProcessor,
        AutoModel,
        AutoTokenizer,
    )

    OPENCUA_AVAILABLE = True
except Exception:
    OPENCUA_AVAILABLE = False
@@ -16,10 +21,12 @@ except Exception:
class OpenCUAModel:
    """OpenCUA model handler using AutoTokenizer, AutoModel and AutoImageProcessor."""

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
    def __init__(
        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
    ) -> None:
        if not OPENCUA_AVAILABLE:
            raise ImportError(
                "OpenCUA requirements not found. Install with: pip install \"cua-agent[opencua-hf]\""
                'OpenCUA requirements not found. Install with: pip install "cua-agent[opencua-hf]"'
            )
        self.model_name = model_name
        self.device = device
@@ -56,7 +63,11 @@ class OpenCUAModel:
        return ""

    def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 512) -> str:
        assert self.model is not None and self.tokenizer is not None and self.image_processor is not None
        assert (
            self.model is not None
            and self.tokenizer is not None
            and self.image_processor is not None
        )

        # Tokenize text side using chat template
        input_ids = self.tokenizer.apply_chat_template(
@@ -74,7 +85,11 @@ class OpenCUAModel:
        pixel_values = torch.tensor(image_info["pixel_values"]).to(
            dtype=torch.bfloat16, device=self.model.device
        )
        grid_thws = torch.tensor(image_info["image_grid_thw"]) if "image_grid_thw" in image_info else None
        grid_thws = (
            torch.tensor(image_info["image_grid_thw"])
            if "image_grid_thw" in image_info
            else None
        )

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
@@ -1,9 +1,10 @@
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional

# Hugging Face imports are local to avoid hard dependency at module import
try:
    import torch  # type: ignore
    from transformers import AutoModelForImageTextToText, AutoProcessor  # type: ignore

    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class Qwen2_5_VLModel:
    Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
    """

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
    def __init__(
        self, model_name: str, device: str = "auto", trust_remote_code: bool = False
    ) -> None:
        if not HF_AVAILABLE:
            raise ImportError(
                "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
                'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
            )
        self.model_name = model_name
        self.device = device
@@ -64,7 +67,7 @@ class Qwen2_5_VLModel:
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Trim prompt tokens from output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        # Decode
        output_text = self.processor.batch_decode(
@@ -3,76 +3,83 @@ ComputerAgent - Main agent class that selects and runs agent loops
"""

import asyncio
from pathlib import Path
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple

from litellm.responses.utils import Usage

from .types import (
    Messages,
    AgentCapability,
    ToolError,
    IllegalArgumentError
)
from .responses import make_tool_error_item, replace_failed_computer_calls_with_function_calls
from .decorators import find_agent_config
import inspect
import json
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

import litellm
import litellm.utils
import inspect
from litellm.responses.utils import Usage

from .adapters import (
    HuggingFaceLocalAdapter,
    HumanAdapter,
    MLXVLMAdapter,
)
from .callbacks import (
    ImageRetentionCallback,
    LoggingCallback,
    TrajectorySaverCallback,
    BudgetManagerCallback,
    TelemetryCallback,
    ImageRetentionCallback,
    LoggingCallback,
    OperatorNormalizerCallback,
    PromptInstructionsCallback,
    TelemetryCallback,
    TrajectorySaverCallback,
)
from .computers import (
    AsyncComputerHandler,
    is_agent_computer,
    make_computer_handler
from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
from .decorators import find_agent_config
from .responses import (
    make_tool_error_item,
    replace_failed_computer_calls_with_function_calls,
)
from .types import AgentCapability, IllegalArgumentError, Messages, ToolError


def assert_callable_with(f, *args, **kwargs):
    """Check if function can be called with given arguments."""
    try:
        inspect.signature(f).bind(*args, **kwargs)
        return True
    except TypeError as e:
        sig = inspect.signature(f)
        raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
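# Illustrative sketch (not part of the commit): assert_callable_with binds the
# proposed arguments against the function's signature, so bad tool calls fail
# fast with a readable IllegalArgumentError. The `move` function is a made-up example.
#
#   def move(x: int, y: int) -> None: ...
#
#   assert_callable_with(move, 10, 20)    # True: (10, 20) binds to (x, y)
#   assert_callable_with(move, x=1, y=2)  # True: keyword binding also works
#   assert_callable_with(move, 1)         # raises IllegalArgumentError("Expected (x: int, y: int) -> None, got args=(1,) kwargs={}")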
def get_json(obj: Any, max_depth: int = 10) -> Any:
    def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
        if seen is None:
            seen = set()

        # Use model_dump() if available
        if hasattr(o, 'model_dump'):
        if hasattr(o, "model_dump"):
            return o.model_dump()

        # Check depth limit
        if depth > max_depth:
            return f"<max_depth_exceeded:{max_depth}>"

        # Check for circular references using object id
        obj_id = id(o)
        if obj_id in seen:
            return f"<circular_reference:{type(o).__name__}>"

        # Handle Computer objects
        if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
        if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
            return f"<computer:{o.__class__.__name__}>"

        # Handle objects with __dict__
        if hasattr(o, '__dict__'):
        if hasattr(o, "__dict__"):
            seen.add(obj_id)
            try:
                result = {}
@@ -84,7 +91,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                return result
            finally:
                seen.discard(obj_id)

        # Handle common types that might contain nested objects
        elif isinstance(o, dict):
            seen.add(obj_id)
@@ -96,7 +103,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                }
            finally:
                seen.discard(obj_id)

        elif isinstance(o, (list, tuple, set)):
            seen.add(obj_id)
            try:
@@ -107,32 +114,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                ]
            finally:
                seen.discard(obj_id)

        # For basic types that json.dumps can handle
        elif isinstance(o, (str, int, float, bool)) or o is None:
            return o

        # Fallback to string representation
        else:
            return str(o)

    def remove_nones(obj: Any) -> Any:
        if isinstance(obj, dict):
            return {k: remove_nones(v) for k, v in obj.items() if v is not None}
        elif isinstance(obj, list):
            return [remove_nones(item) for item in obj if item is not None]
        return obj

    # Serialize with circular reference and depth protection
    serialized = custom_serializer(obj)

    # Convert to JSON string and back to ensure JSON compatibility
    json_str = json.dumps(serialized)
    parsed = json.loads(json_str)

    # Final cleanup of any remaining None values
    return remove_nones(parsed)
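# Illustrative sketch (not part of the commit): get_json flattens arbitrary
# objects into JSON-safe data, guarding against cycles and deep nesting.
# The `Node` class here is hypothetical; the output shown is approximate.
#
#   class Node:
#       def __init__(self):
#           self.name = "root"
#           self.child = self  # deliberate cycle
#
#   get_json(Node())
#   # -> {"name": "root", "child": "<circular_reference:Node>"}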
def sanitize_message(msg: Any) -> Any:
    """Return a copy of the message with image_url omitted for computer_call_output messages."""
    if msg.get("type") == "computer_call_output":
@@ -143,19 +151,24 @@ def sanitize_message(msg: Any) -> Any:
            return sanitized
    return msg


def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
    call_ids = []
    for message in messages:
        if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
        if (
            message.get("type") == "computer_call_output"
            or message.get("type") == "function_call_output"
        ):
            call_ids.append(message.get("call_id"))
    return call_ids


class ComputerAgent:
    """
    Main agent class that automatically selects the appropriate agent loop
    based on the model and executes tool calls.
    """

    def __init__(
        self,
        model: str,
@@ -172,11 +185,11 @@ class ComputerAgent:
        max_trajectory_budget: Optional[float | dict] = None,
        telemetry_enabled: Optional[bool] = True,
        trust_remote_code: Optional[bool] = False,
        **kwargs
        **kwargs,
    ):
        """
        Initialize ComputerAgent.

        Args:
            model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
            tools: List of tools (computer objects, decorated functions, etc.)
@@ -193,11 +206,11 @@ class ComputerAgent:
            telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
            trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
            **kwargs: Additional arguments passed to the agent loop
        """
        # If the loop is "human/human", we need to prefix a grounding model fallback
        if model in ["human/human", "human"]:
            model = "openai/computer-use-preview+human/human"

        self.model = model
        self.tools = tools or []
        self.custom_loop = custom_loop
@@ -236,34 +249,33 @@ class ComputerAgent:
        # Add image retention callback if only_n_most_recent_images is set
        if self.only_n_most_recent_images:
            self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))

        # Add trajectory saver callback if trajectory_dir is set
        if self.trajectory_dir:
            if isinstance(self.trajectory_dir, dict):
                self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
            elif isinstance(self.trajectory_dir, (str, Path)):
                self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))

        # Add budget manager if max_trajectory_budget is set
        if max_trajectory_budget:
            if isinstance(max_trajectory_budget, dict):
                self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
            else:
                self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
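        # Illustrative (not part of the commit): both spellings above reach
        # BudgetManagerCallback the same way, e.g.
        #   ComputerAgent(model=..., max_trajectory_budget=5.0)
        #   ComputerAgent(model=..., max_trajectory_budget={"max_budget": 5.0, "raise_error": True})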
        # == Enable local model providers w/ LiteLLM ==

        # Register local model providers
        hf_adapter = HuggingFaceLocalAdapter(
            device="auto",
            trust_remote_code=self.trust_remote_code or False
            device="auto", trust_remote_code=self.trust_remote_code or False
        )
        human_adapter = HumanAdapter()
        mlx_adapter = MLXVLMAdapter()
        litellm.custom_provider_map = [
            {"provider": "huggingface-local", "custom_handler": hf_adapter},
            {"provider": "human", "custom_handler": human_adapter},
            {"provider": "mlx", "custom_handler": mlx_adapter}
            {"provider": "mlx", "custom_handler": mlx_adapter},
        ]
        litellm.suppress_debug_info = True
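        # Illustrative (not part of the commit): with these handlers registered,
        # model strings carrying the matching provider prefix route through
        # LiteLLM to the local adapters instead of a hosted API. The model names
        # below are placeholders.
        #   ComputerAgent(model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B")
        #   ComputerAgent(model="mlx/mlx-community/UI-TARS-1.5-7B-4bit")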
@@ -280,16 +292,16 @@ class ComputerAgent:
        # Instantiate the agent config class
        self.agent_loop = config_info.agent_class()
        self.agent_config_info = config_info

        self.tool_schemas = []
        self.computer_handler = None

    async def _initialize_computers(self):
        """Initialize computer objects"""
        if not self.tool_schemas:
            # Process tools and create tool schemas
            self.tool_schemas = self._process_tools()

            # Find computer tool and create interface adapter
            computer_handler = None
            for schema in self.tool_schemas:
@@ -297,7 +309,7 @@ class ComputerAgent:
                    computer_handler = await make_computer_handler(schema["computer"])
                    break
            self.computer_handler = computer_handler

    def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
        """Process input messages and create schemas for the agent loop"""
        if isinstance(input, str):
@@ -307,69 +319,73 @@ class ComputerAgent:
    def _process_tools(self) -> List[Dict[str, Any]]:
        """Process tools and create schemas for the agent loop"""
        schemas = []

        for tool in self.tools:
            # Check if it's a computer object (has interface attribute)
            if is_agent_computer(tool):
                # This is a computer tool - will be handled by agent loop
                schemas.append({
                    "type": "computer",
                    "computer": tool
                })
                schemas.append({"type": "computer", "computer": tool})
            elif callable(tool):
                # Use litellm.utils.function_to_dict to extract schema from docstring
                try:
                    function_schema = litellm.utils.function_to_dict(tool)
                    schemas.append({
                        "type": "function",
                        "function": function_schema
                    })
                    schemas.append({"type": "function", "function": function_schema})
                except Exception as e:
                    print(f"Warning: Could not process tool {tool}: {e}")
            else:
                print(f"Warning: Unknown tool type: {tool}")

        return schemas
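    # Illustrative (not part of the commit): a plain function with a docstring
    # is enough for litellm.utils.function_to_dict to build a schema from it.
    # `read_file` is a hypothetical tool.
    #
    #   def read_file(path: str) -> str:
    #       """Read a text file.
    #
    #       Parameters
    #       ----------
    #       path: The path of the file to read.
    #       """
    #       with open(path) as f:
    #           return f.read()
    #
    #   agent = ComputerAgent(model=..., tools=[computer, read_file])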
    def _get_tool(self, name: str) -> Optional[Callable]:
        """Get a tool by name"""
        for tool in self.tools:
            if hasattr(tool, '__name__') and tool.__name__ == name:
            if hasattr(tool, "__name__") and tool.__name__ == name:
                return tool
            elif hasattr(tool, 'func') and tool.func.__name__ == name:
            elif hasattr(tool, "func") and tool.func.__name__ == name:
                return tool
        return None

    # ============================================================================
    # AGENT RUN LOOP LIFECYCLE HOOKS
    # ============================================================================

    async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize run tracking by calling callbacks."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_start'):
            if hasattr(callback, "on_run_start"):
                await callback.on_run_start(kwargs, old_items)

    async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
    async def _on_run_end(
        self,
        kwargs: Dict[str, Any],
        old_items: List[Dict[str, Any]],
        new_items: List[Dict[str, Any]],
    ) -> None:
        """Finalize run tracking by calling callbacks."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_end'):
            if hasattr(callback, "on_run_end"):
                await callback.on_run_end(kwargs, old_items, new_items)

    async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
    async def _on_run_continue(
        self,
        kwargs: Dict[str, Any],
        old_items: List[Dict[str, Any]],
        new_items: List[Dict[str, Any]],
    ) -> bool:
        """Check if run should continue by calling callbacks."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_continue'):
            if hasattr(callback, "on_run_continue"):
                should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
                if not should_continue:
                    return False
        return True

    async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages for the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
            if hasattr(callback, 'on_llm_start'):
            if hasattr(callback, "on_llm_start"):
                result = await callback.on_llm_start(result)
        return result

@@ -377,82 +393,91 @@ class ComputerAgent:
        """Postprocess messages after the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
            if hasattr(callback, 'on_llm_end'):
            if hasattr(callback, "on_llm_end"):
                result = await callback.on_llm_end(result)
        return result

    async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """Called when responses are received."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_responses'):
            if hasattr(callback, "on_responses"):
                await callback.on_responses(get_json(kwargs), get_json(responses))

    async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a computer call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_computer_call_start'):
            if hasattr(callback, "on_computer_call_start"):
                await callback.on_computer_call_start(get_json(item))

    async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
    async def _on_computer_call_end(
        self, item: Dict[str, Any], result: List[Dict[str, Any]]
    ) -> None:
        """Called when a computer call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_computer_call_end'):
            if hasattr(callback, "on_computer_call_end"):
                await callback.on_computer_call_end(get_json(item), get_json(result))

    async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a function call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_function_call_start'):
            if hasattr(callback, "on_function_call_start"):
                await callback.on_function_call_start(get_json(item))

    async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
    async def _on_function_call_end(
        self, item: Dict[str, Any], result: List[Dict[str, Any]]
    ) -> None:
        """Called when a function call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_function_call_end'):
            if hasattr(callback, "on_function_call_end"):
                await callback.on_function_call_end(get_json(item), get_json(result))

    async def _on_text(self, item: Dict[str, Any]) -> None:
        """Called when a text message is encountered."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_text'):
            if hasattr(callback, "on_text"):
                await callback.on_text(get_json(item))

    async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """Called when an LLM API call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_api_start'):
            if hasattr(callback, "on_api_start"):
                await callback.on_api_start(get_json(kwargs))

    async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Called when an LLM API call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_api_end'):
            if hasattr(callback, "on_api_end"):
                await callback.on_api_end(get_json(kwargs), get_json(result))

    async def _on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_usage'):
            if hasattr(callback, "on_usage"):
                await callback.on_usage(get_json(usage))

    async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """Called when a screenshot is taken."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_screenshot'):
            if hasattr(callback, "on_screenshot"):
                await callback.on_screenshot(screenshot, name)

    # ============================================================================
    # AGENT OUTPUT PROCESSING
    # ============================================================================

    async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    async def _handle_item(
        self,
        item: Any,
        computer: Optional[AsyncComputerHandler] = None,
        ignore_call_ids: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Handle each item; may cause a computer action + screenshot."""
        call_id = item.get("call_id")
        if ignore_call_ids and call_id and call_id in ignore_call_ids:
            return []

        item_type = item.get("type", None)

        if item_type == "message":
            await self._on_text(item)
            # # Print messages
@@ -461,7 +486,7 @@ class ComputerAgent:
            #     if content_item.get("text"):
            #         print(content_item.get("text"))
            return []

        try:
            if item_type == "computer_call":
                await self._on_computer_call_start(item)
@@ -472,14 +497,16 @@ class ComputerAgent:
                action = item.get("action")
                action_type = action.get("type")
                if action_type is None:
                    print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
                    print(
                        f"Action type cannot be `None`: action={action}, action_type={action_type}"
                    )
                    return []

                # Extract action arguments (all fields except 'type')
                action_args = {k: v for k, v in action.items() if k != "type"}

                # print(f"{action_type}({action_args})")

                # Execute the computer action
                computer_method = getattr(computer, action_type, None)
                if computer_method:
@@ -487,13 +514,13 @@ class ComputerAgent:
                    await computer_method(**action_args)
                else:
                    raise ToolError(f"Unknown computer action: {action_type}")

                # Take screenshot after action
                if self.screenshot_delay and self.screenshot_delay > 0:
                    await asyncio.sleep(self.screenshot_delay)
                screenshot_base64 = await computer.screenshot()
                await self._on_screenshot(screenshot_base64, "screenshot_after")

                # Handle safety checks
                pending_checks = item.get("pending_safety_checks", [])
                acknowledged_checks = []
@@ -505,7 +532,7 @@ class ComputerAgent:
                #         acknowledged_checks.append(check)
                #     else:
                #         raise ValueError(f"Safety check failed: {check_message}")

                # Create call output
                call_output = {
                    "type": "computer_call_output",
@@ -516,25 +543,25 @@ class ComputerAgent:
                        "image_url": f"data:image/png;base64,{screenshot_base64}",
                    },
                }

                # # Additional URL safety checks for browser environments
                # if await computer.get_environment() == "browser":
                #     current_url = await computer.get_current_url()
                #     call_output["output"]["current_url"] = current_url
                #     # TODO: implement a callback for URL safety checks
                #     # check_blocklisted_url(current_url)

                result = [call_output]
                await self._on_computer_call_end(item, result)
                return result

            if item_type == "function_call":
                await self._on_function_call_start(item)
                # Perform function call
                function = self._get_tool(item.get("name"))
                if not function:
                    raise ToolError(f"Function {item.get("name")} not found")
                    raise ToolError(f"Function {item.get('name')} not found")

                args = json.loads(item.get("arguments"))

                # Validate arguments before execution
@@ -545,14 +572,14 @@ class ComputerAgent:
                    result = await function(**args)
                else:
                    result = await asyncio.to_thread(function, **args)

                # Create function call output
                call_output = {
                    "type": "function_call_output",
                    "call_id": item.get("call_id"),
                    "output": str(result),
                }

                result = [call_output]
                await self._on_function_call_end(item, result)
                return result
@@ -564,36 +591,35 @@ class ComputerAgent:
    # ============================================================================
    # MAIN AGENT LOOP
    # ============================================================================

    async def run(
        self,
        messages: Messages,
        stream: bool = False,
        **kwargs
        self, messages: Messages, stream: bool = False, **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run the agent with the given messages using Computer protocol handler pattern.

        Args:
            messages: List of message dictionaries
            stream: Whether to stream the response
            **kwargs: Additional arguments

        Returns:
            AsyncGenerator that yields response chunks
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")

        capabilities = self.get_capabilities()
        if "step" not in capabilities:
            raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")
            raise ValueError(
                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
            )

        await self._initialize_computers()

        # Merge kwargs
        merged_kwargs = {**self.kwargs, **kwargs}

        old_items = self._process_input(messages)
        new_items = []

@@ -603,7 +629,7 @@ class ComputerAgent:
            "stream": stream,
            "model": self.model,
            "agent_loop": self.agent_config_info.agent_class.__name__,
            **merged_kwargs
            **merged_kwargs,
        }
        await self._on_run_start(run_kwargs, old_items)

@@ -620,7 +646,7 @@ class ComputerAgent:
            combined_messages = old_items + new_items
            combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
            preprocessed_messages = await self._on_llm_start(combined_messages)

            loop_kwargs = {
                "messages": preprocessed_messages,
                "model": self.model,
@@ -629,7 +655,7 @@ class ComputerAgent:
                "computer_handler": self.computer_handler,
                "max_retries": self.max_retries,
                "use_prompt_caching": self.use_prompt_caching,
                **merged_kwargs
                **merged_kwargs,
            }

            # Run agent loop iteration
@@ -641,13 +667,13 @@ class ComputerAgent:
                _on_screenshot=self._on_screenshot,
            )
            result = get_json(result)

            # Lifecycle hook: Postprocess messages after the LLM call
            # Use cases:
            # - PII deanonymization (if you want tool calls to see PII)
            result["output"] = await self._on_llm_end(result.get("output", []))
            await self._on_responses(loop_kwargs, result)

            # Yield agent response
            yield result

@@ -659,7 +685,9 @@ class ComputerAgent:

            # Handle computer actions
            for item in result.get("output"):
                partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
                partial_items = await self._handle_item(
                    item, self.computer_handler, ignore_call_ids=output_call_ids
                )
                new_items += partial_items

                # Yield partial response
@@ -669,54 +697,52 @@ class ComputerAgent:
                        prompt_tokens=0,
                        completion_tokens=0,
                        total_tokens=0,
                    )
                ),
            }

        await self._on_run_end(loop_kwargs, old_items, new_items)
    async def predict_click(
        self,
        instruction: str,
        image_b64: Optional[str] = None
        self, instruction: str, image_b64: Optional[str] = None
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.

        Args:
            instruction: Instruction for where to click
            image_b64: Base64 encoded image (optional, will take screenshot if not provided)

        Returns:
            None or tuple with (x, y) coordinates
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")

        capabilities = self.get_capabilities()
        if "click" not in capabilities:
            raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
            raise ValueError(
                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
            )
        if hasattr(self.agent_loop, 'predict_click'):
        if hasattr(self.agent_loop, "predict_click"):
            if not image_b64:
                if not self.computer_handler:
                    raise ValueError("Computer tool or image_b64 is required for predict_click")
                image_b64 = await self.computer_handler.screenshot()
            return await self.agent_loop.predict_click(
                model=self.model,
                image_b64=image_b64,
                instruction=instruction
                model=self.model, image_b64=image_b64, instruction=instruction
            )
        return None
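    # Illustrative (not part of the commit): grounding-only usage. Assumes a
    # click-capable agent loop; the model name is a placeholder.
    #
    #   agent = ComputerAgent(model="huggingface-local/HelloKKMe/GTA1-7B", tools=[computer])
    #   coords = await agent.predict_click("click the Submit button")
    #   if coords is not None:
    #       x, y = coords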
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by the current agent config.

        Returns:
            List of capability strings (e.g., ["step", "click"])
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")

        if hasattr(self.agent_loop, 'get_capabilities'):
        if hasattr(self.agent_loop, "get_capabilities"):
            return self.agent_loop.get_capabilities()
        return ["step"]  # Default capability
@@ -3,17 +3,17 @@ Callback system for ComputerAgent preprocessing and postprocessing hooks.
"""

from .base import AsyncCallbackHandler
from .budget_manager import BudgetManagerCallback
from .image_retention import ImageRetentionCallback
from .logging import LoggingCallback
from .trajectory_saver import TrajectorySaverCallback
from .budget_manager import BudgetManagerCallback
from .telemetry import TelemetryCallback
from .operator_validator import OperatorNormalizerCallback
from .prompt_instructions import PromptInstructionsCallback
from .telemetry import TelemetryCallback
from .trajectory_saver import TrajectorySaverCallback

__all__ = [
    "AsyncCallbackHandler",
    "ImageRetentionCallback",
    "LoggingCallback",
    "TrajectorySaverCallback",
    "BudgetManagerCallback",
@@ -3,7 +3,7 @@ Base callback handler interface for ComputerAgent preprocessi
"""

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from typing import Any, Dict, List, Optional, Union


class AsyncCallbackHandler(ABC):
@@ -16,42 +16,52 @@ class AsyncCallbackHandler(ABC):
        """Called at the start of an agent run loop."""
        pass

    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
    async def on_run_end(
        self,
        kwargs: Dict[str, Any],
        old_items: List[Dict[str, Any]],
        new_items: List[Dict[str, Any]],
    ) -> None:
        """Called at the end of an agent run loop."""
        pass

    async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
    async def on_run_continue(
        self,
        kwargs: Dict[str, Any],
        old_items: List[Dict[str, Any]],
        new_items: List[Dict[str, Any]],
    ) -> bool:
        """Called during agent run loop to determine if execution should continue.

        Args:
            kwargs: Run arguments
            old_items: Original messages
            new_items: New messages generated during run

        Returns:
            True to continue execution, False to stop
        """
        return True

    async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Called before messages are sent to the agent loop.

        Args:
            messages: List of message dictionaries to preprocess

        Returns:
            List of preprocessed message dictionaries
        """
        return messages

    async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Called after the agent loop returns output.

        Args:
            output: List of output message dictionaries to postprocess

        Returns:
            List of postprocessed output dictionaries
        """
@@ -60,63 +70,67 @@ class AsyncCallbackHandler(ABC):
    async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """
        Called when a computer call is about to start.

        Args:
            item: The computer call item dictionary
        """
        pass

    async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
    async def on_computer_call_end(
        self, item: Dict[str, Any], result: List[Dict[str, Any]]
    ) -> None:
        """
        Called when a computer call has completed.

        Args:
            item: The computer call item dictionary
            result: The result of the computer call
        """
        pass

    async def on_function_call_start(self, item: Dict[str, Any]) -> None:
        """
        Called when a function call is about to start.

        Args:
            item: The function call item dictionary
        """
        pass

    async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
    async def on_function_call_end(
        self, item: Dict[str, Any], result: List[Dict[str, Any]]
    ) -> None:
        """
        Called when a function call has completed.

        Args:
            item: The function call item dictionary
            result: The result of the function call
        """
        pass

    async def on_text(self, item: Dict[str, Any]) -> None:
        """
        Called when a text message is encountered.

        Args:
            item: The message item dictionary
        """
        pass

    async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """
        Called when an API call is about to start.

        Args:
            kwargs: The kwargs being passed to the API call
        """
        pass

    async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """
        Called when an API call has completed.

        Args:
            kwargs: The kwargs that were passed to the API call
            result: The result of the API call
@@ -126,7 +140,7 @@ class AsyncCallbackHandler(ABC):
    async def on_usage(self, usage: Dict[str, Any]) -> None:
        """
        Called when usage information is received.

        Args:
            usage: The usage information
        """
@@ -135,7 +149,7 @@ class AsyncCallbackHandler(ABC):
    async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """
        Called when a screenshot is taken.

        Args:
            screenshot: The screenshot image
            name: The name of the screenshot
@@ -145,9 +159,9 @@ class AsyncCallbackHandler(ABC):
    async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """
        Called when responses are received.

        Args:
            kwargs: The kwargs being passed to the agent loop
            responses: The responses received
        """
        pass
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
|
||||
class BudgetExceededError(Exception):
|
||||
"""Exception raised when budget is exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BudgetManagerCallback(AsyncCallbackHandler):
|
||||
"""Budget manager callback that tracks usage costs and can stop execution when budget is exceeded."""
|
||||
|
||||
def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False):
|
||||
|
||||
def __init__(
|
||||
self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize BudgetManagerCallback.
|
||||
|
||||
|
||||
Args:
|
||||
max_budget: Maximum budget allowed
|
||||
reset_after_each_run: Whether to reset budget after each run
|
||||
@@ -21,24 +27,30 @@ class BudgetManagerCallback(AsyncCallbackHandler):
|
||||
self.reset_after_each_run = reset_after_each_run
|
||||
self.raise_error = raise_error
|
||||
self.total_cost = 0.0
|
||||
|
||||
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Reset budget if configured to do so."""
|
||||
if self.reset_after_each_run:
|
||||
self.total_cost = 0.0
|
||||
|
||||
|
||||
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Track usage costs."""
|
||||
if "response_cost" in usage:
|
||||
self.total_cost += usage["response_cost"]
|
||||
|
||||
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
|
||||
|
||||
async def on_run_continue(
|
||||
self,
|
||||
kwargs: Dict[str, Any],
|
||||
old_items: List[Dict[str, Any]],
|
||||
new_items: List[Dict[str, Any]],
|
||||
) -> bool:
|
||||
"""Check if budget allows continuation."""
|
||||
if self.total_cost >= self.max_budget:
|
||||
if self.raise_error:
|
||||
raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
|
||||
raise BudgetExceededError(
|
||||
f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}"
|
||||
)
|
||||
else:
|
||||
print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
Image retention callback handler that limits the number of recent images in message history.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
|
||||
@@ -11,40 +12,40 @@ class ImageRetentionCallback(AsyncCallbackHandler):
|
||||
Callback handler that applies image retention policy to limit the number
|
||||
of recent images in message history to prevent context window overflow.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, only_n_most_recent_images: Optional[int] = None):
|
||||
"""
|
||||
Initialize the image retention callback.
|
||||
|
||||
|
||||
Args:
|
||||
only_n_most_recent_images: If set, only keep the N most recent images in message history
|
||||
"""
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
|
||||
|
||||
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Apply image retention policy to messages before sending to agent loop.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
|
||||
Returns:
|
||||
List of messages with image retention policy applied
|
||||
"""
|
||||
if self.only_n_most_recent_images is None:
|
||||
return messages
|
||||
|
||||
|
||||
return self._apply_image_retention(messages)
|
||||
|
||||
|
||||
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Apply image retention policy to keep only the N most recent images.
|
||||
|
||||
|
||||
Removes computer_call_output items with image_url and their corresponding computer_call items,
|
||||
keeping only the most recent N image pairs based on only_n_most_recent_images setting.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
|
||||
Returns:
|
||||
Filtered list of messages with image retention applied
|
||||
"""
|
||||
@@ -78,7 +79,11 @@ class ImageRetentionCallback(AsyncCallbackHandler):
|
||||
# Remove the immediately preceding computer_call with matching call_id (if present)
|
||||
call_id = messages[idx].get("call_id")
|
||||
prev_idx = idx - 1
|
||||
if prev_idx >= 0 and messages[prev_idx].get("type") == "computer_call" and messages[prev_idx].get("call_id") == call_id:
|
||||
if (
|
||||
prev_idx >= 0
|
||||
and messages[prev_idx].get("type") == "computer_call"
|
||||
and messages[prev_idx].get("call_id") == call_id
|
||||
):
|
||||
to_remove.add(prev_idx)
|
||||
# Check a single reasoning immediately before that computer_call
|
||||
r_idx = prev_idx - 1
|
||||
@@ -87,4 +92,4 @@ class ImageRetentionCallback(AsyncCallbackHandler):
|
||||
|
||||
# Construct filtered list
|
||||
filtered = [m for i, m in enumerate(messages) if i not in to_remove]
|
||||
return filtered
|
||||
return filtered
|
||||
|
||||
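# Illustrative (not part of the commit): keep only the 3 most recent screenshots
# in context; older computer_call / computer_call_output pairs are dropped.
#
#   callback = ImageRetentionCallback(only_n_most_recent_images=3)
#   trimmed = await callback.on_llm_start(messages)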
@@ -4,17 +4,18 @@ Logging callback for ComputerAgent that provides configurable logging of agent l

import json
import logging
from typing import Dict, List, Any, Optional, Union
from typing import Any, Dict, List, Optional, Union

from .base import AsyncCallbackHandler


def sanitize_image_urls(data: Any) -> Any:
    """
    Recursively search for 'image_url' keys and set their values to '[omitted]'.

    Args:
        data: Any data structure (dict, list, or primitive type)

    Returns:
        A deep copy of the data with all 'image_url' values replaced with '[omitted]'
    """
@@ -28,11 +29,11 @@ def sanitize_image_urls(data: Any) -> Any:
            # Recursively sanitize the value
            sanitized[key] = sanitize_image_urls(value)
        return sanitized

    elif isinstance(data, list):
        # Recursively sanitize each item in the list
        return [sanitize_image_urls(item) for item in data]

    else:
        # For primitive types (str, int, bool, None, etc.), return as-is
        return data
||||
class LoggingCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Callback handler that logs agent lifecycle events with configurable verbosity.
|
||||
|
||||
|
||||
Logging levels:
|
||||
- DEBUG: All events including API calls, message preprocessing, and detailed outputs
|
||||
- INFO: Major lifecycle events (start/end, messages, outputs)
|
||||
- INFO: Major lifecycle events (start/end, messages, outputs)
|
||||
- WARNING: Only warnings and errors
|
||||
- ERROR: Only errors
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO):
|
||||
"""
|
||||
Initialize the logging callback.
|
||||
|
||||
|
||||
Args:
|
||||
logger: Logger instance to use. If None, creates a logger named 'agent.ComputerAgent'
|
||||
level: Logging level (logging.DEBUG, logging.INFO, etc.)
|
||||
"""
|
||||
self.logger = logger or logging.getLogger('agent.ComputerAgent')
|
||||
self.logger = logger or logging.getLogger("agent.ComputerAgent")
|
||||
self.level = level
|
||||
|
||||
|
||||
# Set up logger if it doesn't have handlers
|
||||
if not self.logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
self.logger.addHandler(handler)
|
||||
self.logger.setLevel(level)
|
||||
|
||||
|
||||
def _update_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Update total usage statistics."""
|
||||
|
||||
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
||||
for key, value in source.items():
|
||||
if isinstance(value, dict):
|
||||
@@ -82,18 +82,25 @@ class LoggingCallback(AsyncCallbackHandler):
|
||||
if key not in target:
|
||||
target[key] = 0
|
||||
target[key] += value
|
||||
|
||||
add_dicts(self.total_usage, usage)
|
||||
|
||||
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Called before the run starts."""
|
||||
self.total_usage = {}
|
||||
|
||||
|
||||
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Called when usage information is received."""
|
||||
self._update_usage(usage)
|
||||
|
||||
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
|
||||
async def on_run_end(
|
||||
self,
|
||||
kwargs: Dict[str, Any],
|
||||
old_items: List[Dict[str, Any]],
|
||||
new_items: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Called after the run ends."""
|
||||
|
||||
def format_dict(d, indent=0):
|
||||
lines = []
|
||||
prefix = f" - {' ' * indent}"
|
||||
@@ -106,10 +113,10 @@ class LoggingCallback(AsyncCallbackHandler):
|
||||
else:
|
||||
lines.append(f"{prefix}{key}: {value}")
|
||||
return lines
|
||||
|
||||
|
||||
formatted_output = "\n".join(format_dict(self.total_usage))
|
||||
self.logger.info(f"Total usage:\n{formatted_output}")
|
||||
|
||||
|
||||
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Called before LLM processing starts."""
|
||||
if self.logger.isEnabledFor(logging.INFO):
|
||||
@@ -118,27 +125,27 @@ class LoggingCallback(AsyncCallbackHandler):
|
||||
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
|
||||
self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}")
|
||||
return messages
|
||||
|
||||
|
||||
async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Called after LLM processing ends."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
|
||||
self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}")
|
||||
return messages
|
||||
|
||||
|
||||
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a computer call starts."""
|
||||
action = item.get("action", {})
|
||||
action_type = action.get("type", "unknown")
|
||||
action_args = {k: v for k, v in action.items() if k != "type"}
|
||||
|
||||
|
||||
# INFO level logging for the action
|
||||
self.logger.info(f"Computer: {action_type}({action_args})")
|
||||
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}")
|
||||
|
||||
|
||||
async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when a computer call ends."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
@@ -147,48 +154,52 @@ class LoggingCallback(AsyncCallbackHandler):
|
||||
if result:
|
||||
sanitized_result = sanitize_image_urls(result)
|
||||
self.logger.debug(f"Computer call result: {json.dumps(sanitized_result, indent=2)}")
|
||||
|
||||
|
||||
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a function call starts."""
|
||||
name = item.get("name", "unknown")
|
||||
arguments = item.get("arguments", "{}")
|
||||
|
||||
|
||||
# INFO level logging for the function call
|
||||
self.logger.info(f"Function: {name}({arguments})")
|
||||
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Function call started: {name}")
|
||||
|
||||
|
||||
async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when a function call ends."""
|
||||
# INFO level logging for function output (similar to function_call_output)
|
||||
if result:
|
||||
# Handle both list and direct result formats
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0])
|
||||
output = (
|
||||
result[0].get("output", str(result))
|
||||
if isinstance(result[0], dict)
|
||||
else str(result[0])
|
||||
)
|
||||
else:
|
||||
output = str(result)
|
||||
|
||||
|
||||
# Truncate long outputs
|
||||
if len(output) > 100:
|
||||
output = output[:100] + "..."
|
||||
|
||||
|
||||
self.logger.info(f"Output: {output}")
|
||||
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
name = item.get("name", "unknown")
|
||||
self.logger.debug(f"Function call completed: {name}")
|
||||
if result:
|
||||
self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}")
|
||||
|
||||
|
||||
async def on_text(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a text message is encountered."""
|
||||
# Get the role to determine if it's Agent or User
|
||||
role = item.get("role", "unknown")
|
||||
content_items = item.get("content", [])
|
||||
|
||||
|
||||
# Process content items to build display text
|
||||
text_parts = []
|
||||
for content_item in content_items:
|
||||
@@ -206,10 +217,10 @@ class LoggingCallback(AsyncCallbackHandler):
|
||||
else:
|
||||
# Non-text content, show as [type]
|
||||
text_parts.append(f"[{content_type}]")
|
||||
|
||||
|
||||
# Join all text parts
|
||||
display_text = ''.join(text_parts) if text_parts else "[empty]"
|
||||
|
||||
display_text = "".join(text_parts) if text_parts else "[empty]"
|
||||
|
||||
# Log with appropriate level and format
|
||||
if role == "assistant":
|
||||
self.logger.info(f"Agent: {display_text}")
|
||||
@@ -219,7 +230,7 @@ class LoggingCallback(AsyncCallbackHandler):
|
||||
# Fallback for unknown roles, use debug level
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Text message ({role}): {display_text}")
|
||||
|
||||
|
||||
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
|
||||
"""Called when an API call is about to start."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
@@ -232,16 +243,18 @@ class LoggingCallback(AsyncCallbackHandler):
|
||||
elif "input" in kwargs:
|
||||
sanitized_input = sanitize_image_urls(kwargs["input"])
|
||||
self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}")
|
||||
|
||||
|
||||
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when an API call has completed."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
model = kwargs.get("model", "unknown")
|
||||
self.logger.debug(f"API call completed for model: {model}")
|
||||
self.logger.debug(f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}")
|
||||
self.logger.debug(
|
||||
f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}"
|
||||
)
|
||||
|
||||
async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None:
|
||||
"""Called when a screenshot is taken."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
image_size = len(item) / 1024
|
||||
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
|
||||
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
|
||||
|
||||
@@ -9,6 +9,7 @@ Ensures agent output actions conform to expected schemas by fixing common issues
|
||||
This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
|
||||
The purpose is to avoid spending another LLM call to fix broken computer call syntax when possible.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
@@ -48,6 +49,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
|
||||
action["type"] = "type"
|
||||
|
||||
action_type = action.get("type")
|
||||
|
||||
def _keep_keys(action: Dict[str, Any], keys_to_keep: List[str]):
|
||||
"""Keep only the provided keys on action; delete everything else.
|
||||
Always ensures required 'type' is present if listed in keys_to_keep.
|
||||
@@ -55,6 +57,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
|
||||
for key in list(action.keys()):
|
||||
if key not in keys_to_keep:
|
||||
del action[key]
|
||||
|
||||
# rename "coordinate" to "x", "y"
|
||||
if "coordinate" in action:
|
||||
action["x"] = action["coordinate"][0]
|
||||
@@ -100,7 +103,6 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
|
||||
keep = required_keys_by_type.get(action_type or "")
|
||||
if keep:
|
||||
_keep_keys(action, keep)
|
||||
|
||||
|
||||
# # Second pass: if an assistant message is immediately followed by a computer_call,
|
||||
# # replace the assistant message itself with a reasoning message with summary text.
|
||||
|
||||
@@ -2,38 +2,41 @@
|
||||
PII anonymization callback handler using Microsoft Presidio for text and image redaction.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from .base import AsyncCallbackHandler
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
try:
|
||||
# TODO: Add Presidio dependencies
|
||||
from PIL import Image
|
||||
|
||||
PRESIDIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
PRESIDIO_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PIIAnonymizationCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Callback handler that anonymizes PII in text and images using Microsoft Presidio.
|
||||
|
||||
|
||||
This handler:
|
||||
1. Anonymizes PII in messages before sending to the agent loop
|
||||
2. Deanonymizes PII in tool calls and message outputs after the agent loop
|
||||
3. Redacts PII from images in computer_call_output messages
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# TODO: Any extra kwargs if needed
|
||||
):
|
||||
"""
|
||||
Initialize the PII anonymization callback.
|
||||
|
||||
|
||||
Args:
|
||||
anonymize_text: Whether to anonymize text content
|
||||
anonymize_images: Whether to redact images
|
||||
@@ -46,16 +49,16 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
|
||||
"Presidio is not available. Install with: "
|
||||
"pip install cua-agent[pii-anonymization]"
|
||||
)
|
||||
|
||||
|
||||
# TODO: Implement __init__
|
||||
|
||||
|
||||
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Anonymize PII in messages before sending to agent loop.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
|
||||
Returns:
|
||||
List of messages with PII anonymized
|
||||
"""
|
||||
@@ -63,16 +66,16 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
|
||||
for msg in messages:
|
||||
anonymized_msg = await self._anonymize_message(msg)
|
||||
anonymized_messages.append(anonymized_msg)
|
||||
|
||||
|
||||
return anonymized_messages
|
||||
|
||||
|
||||
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Deanonymize PII in tool calls and message outputs after agent loop.
|
||||
|
||||
|
||||
Args:
|
||||
output: List of output dictionaries
|
||||
|
||||
|
||||
Returns:
|
||||
List of output with PII deanonymized for tool calls
|
||||
"""
|
||||
@@ -84,13 +87,13 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
|
||||
deanonymized_output.append(deanonymized_item)
|
||||
else:
|
||||
deanonymized_output.append(item)
|
||||
|
||||
|
||||
return deanonymized_output
|
||||
|
||||
|
||||
async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# TODO: Implement _anonymize_message
|
||||
return message
|
||||
|
||||
|
||||
async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# TODO: Implement _deanonymize_item
|
||||
return item
|
||||
|
||||
@@ -2,17 +2,17 @@
|
||||
Telemetry callback handler for Computer-Use Agent (cua-agent)
|
||||
"""
|
||||
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .base import AsyncCallbackHandler
|
||||
from core.telemetry import (
|
||||
record_event,
|
||||
is_telemetry_enabled,
|
||||
record_event,
|
||||
)
|
||||
|
||||
import platform
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
SYSTEM_INFO = {
|
||||
"os": platform.system().lower(),
|
||||
@@ -20,32 +20,29 @@ SYSTEM_INFO = {
|
||||
"python_version": platform.python_version(),
|
||||
}
|
||||
|
||||
|
||||
class TelemetryCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Telemetry callback handler for Computer-Use Agent (cua-agent)
|
||||
|
||||
|
||||
Tracks agent usage, performance metrics, and optionally trajectory data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent,
|
||||
log_trajectory: bool = False
|
||||
):
|
||||
|
||||
def __init__(self, agent, log_trajectory: bool = False):
|
||||
"""
|
||||
Initialize telemetry callback.
|
||||
|
||||
|
||||
Args:
|
||||
agent: The ComputerAgent instance
|
||||
log_trajectory: Whether to log full trajectory items (opt-in)
|
||||
"""
|
||||
self.agent = agent
|
||||
self.log_trajectory = log_trajectory
|
||||
|
||||
|
||||
# Generate session/run IDs
|
||||
self.session_id = str(uuid.uuid4())
|
||||
self.run_id = None
|
||||
|
||||
|
||||
# Track timing and metrics
|
||||
self.run_start_time = None
|
||||
self.step_count = 0
|
||||
@@ -54,126 +51,133 @@ class TelemetryCallback(AsyncCallbackHandler):
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"response_cost": 0.0
|
||||
"response_cost": 0.0,
|
||||
}
|
||||
|
||||
|
||||
# Record agent initialization
|
||||
if is_telemetry_enabled():
|
||||
self._record_agent_initialization()
|
||||
|
||||
|
||||
def _record_agent_initialization(self) -> None:
|
||||
"""Record agent type/model and session initialization."""
|
||||
agent_info = {
|
||||
"session_id": self.session_id,
|
||||
"agent_type": self.agent.agent_loop.__name__ if hasattr(self.agent, 'agent_loop') else 'unknown',
|
||||
"model": getattr(self.agent, 'model', 'unknown'),
|
||||
**SYSTEM_INFO
|
||||
"agent_type": (
|
||||
self.agent.agent_loop.__name__ if hasattr(self.agent, "agent_loop") else "unknown"
|
||||
),
|
||||
"model": getattr(self.agent, "model", "unknown"),
|
||||
**SYSTEM_INFO,
|
||||
}
|
||||
|
||||
|
||||
record_event("agent_session_start", agent_info)
|
||||
|
||||
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Called at the start of an agent run loop."""
|
||||
if not is_telemetry_enabled():
|
||||
return
|
||||
|
||||
|
||||
self.run_id = str(uuid.uuid4())
|
||||
self.run_start_time = time.time()
|
||||
self.step_count = 0
|
||||
|
||||
|
||||
# Calculate input context size
|
||||
input_context_size = self._calculate_context_size(old_items)
|
||||
|
||||
|
||||
run_data = {
|
||||
"session_id": self.session_id,
|
||||
"run_id": self.run_id,
|
||||
"start_time": self.run_start_time,
|
||||
"input_context_size": input_context_size,
|
||||
"num_existing_messages": len(old_items)
|
||||
"num_existing_messages": len(old_items),
|
||||
}
|
||||
|
||||
|
||||
# Log trajectory if opted in
|
||||
if self.log_trajectory:
|
||||
trajectory = self._extract_trajectory(old_items)
|
||||
if trajectory:
|
||||
run_data["uploaded_trajectory"] = trajectory
|
||||
|
||||
|
||||
record_event("agent_run_start", run_data)
|
||||
|
||||
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
|
||||
|
||||
async def on_run_end(
|
||||
self,
|
||||
kwargs: Dict[str, Any],
|
||||
old_items: List[Dict[str, Any]],
|
||||
new_items: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Called at the end of an agent run loop."""
|
||||
if not is_telemetry_enabled() or not self.run_start_time:
|
||||
return
|
||||
|
||||
|
||||
run_duration = time.time() - self.run_start_time
|
||||
|
||||
|
||||
run_data = {
|
||||
"session_id": self.session_id,
|
||||
"run_id": self.run_id,
|
||||
"end_time": time.time(),
|
||||
"duration_seconds": run_duration,
|
||||
"num_steps": self.step_count,
|
||||
"total_usage": self.total_usage.copy()
|
||||
"total_usage": self.total_usage.copy(),
|
||||
}
|
||||
|
||||
|
||||
# Log trajectory if opted in
|
||||
if self.log_trajectory:
|
||||
trajectory = self._extract_trajectory(new_items)
|
||||
if trajectory:
|
||||
run_data["uploaded_trajectory"] = trajectory
|
||||
|
||||
|
||||
record_event("agent_run_end", run_data)
|
||||
|
||||
|
||||
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Called when usage information is received."""
|
||||
if not is_telemetry_enabled():
|
||||
return
|
||||
|
||||
|
||||
# Accumulate usage stats
|
||||
self.total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
||||
self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
||||
self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
||||
self.total_usage["total_tokens"] += usage.get("total_tokens", 0)
|
||||
self.total_usage["response_cost"] += usage.get("response_cost", 0.0)
|
||||
|
||||
|
||||
# Record individual usage event
|
||||
usage_data = {
|
||||
"session_id": self.session_id,
|
||||
"run_id": self.run_id,
|
||||
"step": self.step_count,
|
||||
**usage
|
||||
**usage,
|
||||
}
|
||||
|
||||
|
||||
record_event("agent_usage", usage_data)
|
||||
|
||||
|
||||
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
|
||||
"""Called when responses are received."""
|
||||
if not is_telemetry_enabled():
|
||||
return
|
||||
|
||||
|
||||
self.step_count += 1
|
||||
step_duration = None
|
||||
|
||||
|
||||
if self.step_start_time:
|
||||
step_duration = time.time() - self.step_start_time
|
||||
|
||||
|
||||
self.step_start_time = time.time()
|
||||
|
||||
|
||||
step_data = {
|
||||
"session_id": self.session_id,
|
||||
"run_id": self.run_id,
|
||||
"step": self.step_count,
|
||||
"timestamp": self.step_start_time
|
||||
"timestamp": self.step_start_time,
|
||||
}
|
||||
|
||||
|
||||
if step_duration is not None:
|
||||
step_data["duration_seconds"] = step_duration
|
||||
|
||||
|
||||
record_event("agent_step", step_data)
|
||||
|
||||
|
||||
def _calculate_context_size(self, items: List[Dict[str, Any]]) -> int:
|
||||
"""Calculate approximate context size in tokens/characters."""
|
||||
total_size = 0
|
||||
|
||||
|
||||
for item in items:
|
||||
if item.get("type") == "message" and "content" in item:
|
||||
content = item["content"]
|
||||
@@ -185,25 +189,27 @@ class TelemetryCallback(AsyncCallbackHandler):
|
||||
total_size += len(part["text"])
|
||||
elif "content" in item and isinstance(item["content"], str):
|
||||
total_size += len(item["content"])
|
||||
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
def _extract_trajectory(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Extract trajectory items that should be logged."""
|
||||
trajectory = []
|
||||
|
||||
|
||||
for item in items:
|
||||
# Include user messages, assistant messages, reasoning, computer calls, and computer outputs
|
||||
if (
|
||||
item.get("role") == "user" or # User inputs
|
||||
(item.get("type") == "message" and item.get("role") == "assistant") or # Model outputs
|
||||
item.get("type") == "reasoning" or # Reasoning traces
|
||||
item.get("type") == "computer_call" or # Computer actions
|
||||
item.get("type") == "computer_call_output" # Computer outputs
|
||||
item.get("role") == "user" # User inputs
|
||||
or (
|
||||
item.get("type") == "message" and item.get("role") == "assistant"
|
||||
) # Model outputs
|
||||
or item.get("type") == "reasoning" # Reasoning traces
|
||||
or item.get("type") == "computer_call" # Computer actions
|
||||
or item.get("type") == "computer_call_output" # Computer outputs
|
||||
):
|
||||
# Create a copy of the item with timestamp
|
||||
trajectory_item = item.copy()
|
||||
trajectory_item["logged_at"] = time.time()
|
||||
trajectory.append(trajectory_item)
|
||||
|
||||
return trajectory
|
||||
|
||||
return trajectory
|
||||
|
||||
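Taken together, the hooks above give TelemetryCallback a clear lifecycle: one agent_session_start per instance, then agent_run_start, per-step agent_usage and agent_step events, and a closing agent_run_end. A minimal driver sketch follows; the import path and the stub agent are assumptions, and in practice the ComputerAgent invokes these hooks itself:

import asyncio

from agent.callbacks.telemetry import TelemetryCallback  # assumed module path

class StubAgent:
    model = "anthropic/claude-3-5-sonnet-20241022"  # stand-in for a real ComputerAgent

async def demo() -> None:
    cb = TelemetryCallback(agent=StubAgent(), log_trajectory=False)
    await cb.on_run_start({"model": StubAgent.model}, old_items=[])
    await cb.on_usage({"prompt_tokens": 120, "completion_tokens": 40,
                       "total_tokens": 160, "response_cost": 0.002})
    await cb.on_run_end({}, old_items=[], new_items=[])

asyncio.run(demo())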
@@ -2,26 +2,28 @@
Trajectory saving callback handler for ComputerAgent.
"""

import os
import json
import uuid
from datetime import datetime
import base64
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, override
from PIL import Image, ImageDraw
import io
import json
import os
import uuid
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, override

from PIL import Image, ImageDraw

from .base import AsyncCallbackHandler


def sanitize_image_urls(data: Any) -> Any:
"""
Recursively search for 'image_url' keys and set their values to '[omitted]'.

Args:
data: Any data structure (dict, list, or primitive type)

Returns:
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
"""
@@ -35,17 +37,19 @@ def sanitize_image_urls(data: Any) -> Any:
# Recursively sanitize the value
sanitized[key] = sanitize_image_urls(value)
return sanitized

elif isinstance(data, list):
# Recursively sanitize each item in the list
return [sanitize_image_urls(item) for item in data]

else:
# For primitive types (str, int, bool, None, etc.), return as-is
return data


def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]:
def extract_computer_call_outputs(
items: List[Dict[str, Any]], screenshot_dir: Optional[Path]
) -> List[Dict[str, Any]]:
"""
Save any base64-encoded screenshots from computer_call_output entries to files and
replace their image_url with the saved file path when a call_id is present.
@@ -103,18 +107,21 @@ def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: O
updated.append(msg)
return updated


class TrajectorySaverCallback(AsyncCallbackHandler):
"""
Callback handler that saves agent trajectories to disk.

Saves each run as a separate trajectory with unique ID, and each turn
within the trajectory gets its own folder with screenshots and responses.
"""

def __init__(self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None):

def __init__(
self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None
):
"""
Initialize trajectory saver.

Args:
trajectory_dir: Base directory to save trajectories
reset_on_run: If True, reset trajectory_id/turn/artifact on each run.
@@ -129,7 +136,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.reset_on_run = reset_on_run
# Optional directory to store extracted screenshots from metadata/new_items
self.screenshot_dir: Optional[Path] = Path(screenshot_dir) if screenshot_dir else None

# Ensure trajectory directory exists
self.trajectory_dir.mkdir(parents=True, exist_ok=True)

@@ -137,7 +144,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"""Get the directory for the current turn."""
if not self.trajectory_id:
raise ValueError("Trajectory not initialized - call _on_run_start first")

# format: trajectory_id/turn_000
turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}"
turn_dir.mkdir(parents=True, exist_ok=True)
@@ -166,6 +173,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):

def _update_usage(self, usage: Dict[str, Any]) -> None:
"""Update total usage statistics."""

def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
for key, value in source.items():
if isinstance(value, dict):
@@ -176,20 +184,21 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
if key not in target:
target[key] = 0
target[key] += value

add_dicts(self.total_usage, usage)

@override
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Initialize trajectory tracking for a new run."""
model = kwargs.get("model", "unknown")

# Only reset trajectory state if reset_on_run is True or no trajectory exists
if self.reset_on_run or not self.trajectory_id:
model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
if "+" in model:
model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
# strip non-alphanumeric characters from model_name_short
model_name_short = ''.join(c for c in model_name_short if c.isalnum() or c == '_')
model_name_short = "".join(c for c in model_name_short if c.isalnum() or c == "_")

# id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
now = datetime.now()
@@ -198,11 +207,11 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.current_artifact = 0
self.model = model
self.total_usage = {}

# Create trajectory directory
trajectory_path = self.trajectory_dir / self.trajectory_id
trajectory_path.mkdir(parents=True, exist_ok=True)

# Save trajectory metadata (optionally extract screenshots to screenshot_dir)
kwargs_to_save = kwargs.copy()
try:
@@ -219,7 +228,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"status": "running",
"kwargs": kwargs_to_save,
}

with open(trajectory_path / "metadata.json", "w") as f:
json.dump(metadata, f, indent=2)
else:
@@ -227,22 +236,27 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.model = model

@override
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Finalize run tracking by updating metadata with completion status, usage, and new items."""
if not self.trajectory_id:
return

# Update metadata with completion status, total usage, and new items
trajectory_path = self.trajectory_dir / self.trajectory_id
metadata_path = trajectory_path / "metadata.json"

# Read existing metadata
if metadata_path.exists():
with open(metadata_path, "r") as f:
metadata = json.load(f)
else:
metadata = {}

# Update metadata with completion info
# Optionally extract screenshots from new_items before persisting
new_items_to_save = new_items
@@ -251,32 +265,34 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
except Exception:
pass

metadata.update({
"status": "completed",
"completed_at": str(uuid.uuid1().time),
"total_usage": self.total_usage,
"new_items": new_items_to_save,
"total_turns": self.current_turn
})

metadata.update(
{
"status": "completed",
"completed_at": str(uuid.uuid1().time),
"total_usage": self.total_usage,
"new_items": new_items_to_save,
"total_turns": self.current_turn,
}
)

# Save updated metadata
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)

@override

@override
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
if not self.trajectory_id:
return

self._save_artifact("api_start", { "kwargs": kwargs })

self._save_artifact("api_start", {"kwargs": kwargs})

@override
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Save API call result."""
if not self.trajectory_id:
return

self._save_artifact("api_result", { "kwargs": kwargs, "result": result })

self._save_artifact("api_result", {"kwargs": kwargs, "result": result})

@override
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
@@ -295,77 +311,83 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"""Save responses to the current turn directory and update usage statistics."""
if not self.trajectory_id:
return

# Save responses
turn_dir = self._get_turn_dir()
response_data = {
"timestamp": str(uuid.uuid1().time),
"model": self.model,
"kwargs": kwargs,
"response": responses
"response": responses,
}

self._save_artifact("agent_response", response_data)

# Increment turn counter
self.current_turn += 1

def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes:
"""
Draw a red dot and crosshair at the specified coordinates on the image.

Args:
image_bytes: The original image as bytes
x: X coordinate for the crosshair
y: Y coordinate for the crosshair

Returns:
Modified image as bytes with red dot and crosshair
"""
# Open the image
image = Image.open(io.BytesIO(image_bytes))
draw = ImageDraw.Draw(image)

# Draw crosshair lines (red, 2px thick)
crosshair_size = 20
line_width = 2
color = "red"

# Horizontal line
draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width)
# Vertical line
draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width)

# Draw center dot (filled circle)
dot_radius = 3
draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color)

draw.ellipse(
[(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color
)

# Convert back to bytes
output = io.BytesIO()
image.save(output, format='PNG')
image.save(output, format="PNG")
return output.getvalue()

@override
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def on_computer_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""
Called when a computer call has completed.
Saves screenshots and computer call output.
"""
if not self.trajectory_id:
return

self._save_artifact("computer_call_result", { "item": item, "result": result })

self._save_artifact("computer_call_result", {"item": item, "result": result})

# Check if action has x/y coordinates and there's a screenshot in the result
action = item.get("action", {})
if "x" in action and "y" in action:
# Look for screenshot in the result
for result_item in result:
if (result_item.get("type") == "computer_call_output" and
result_item.get("output", {}).get("type") == "input_image"):

if (
result_item.get("type") == "computer_call_output"
and result_item.get("output", {}).get("type") == "input_image"
):

image_url = result_item["output"]["image_url"]

# Extract base64 image data
if image_url.startswith("data:image/"):
# Format: data:image/png;base64,<base64_data>
@@ -373,26 +395,24 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
else:
# Assume it's just base64 data
base64_data = image_url

try:
# Decode the image
image_bytes = base64.b64decode(base64_data)

# Draw crosshair at the action coordinates
annotated_image = self._draw_crosshair_on_image(
image_bytes,
int(action["x"]),
int(action["y"])
image_bytes, int(action["x"]), int(action["y"])
)

# Save as screenshot_action
self._save_artifact("screenshot_action", annotated_image)

except Exception as e:
# If annotation fails, just log and continue
print(f"Failed to annotate screenshot: {e}")

break  # Only process the first screenshot found

# Increment turn counter
self.current_turn += 1
self.current_turn += 1
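The trajectory id built in on_run_start can be traced by hand from the code above; for a model string with an omniparser prefix, the short name works out as follows (a worked example of the existing logic, not new behavior):

model = "omniparser+anthropic/claude-3-5-sonnet-20241022"
model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]  # "claude-3-5-sonne"
if "+" in model:
    model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
model_name_short = "".join(c for c in model_name_short if c.isalnum() or c == "_")
print(model_name_short)  # omni_claude35sonne

Each run then lands under <trajectory_dir>/<yyyy-mm-dd_model_hhmmss_uuid[:4]>/turn_000, turn_001, and so on, per _get_turn_dir above.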
@@ -3,7 +3,7 @@ CLI chat interface for agent - Computer Use Agent

Usage:
python -m agent.cli <model_string>

Examples:
python -m agent.cli openai/computer-use-preview
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
@@ -11,19 +11,22 @@ Examples:
"""

try:
import asyncio
import argparse
import os
import sys
import json
from typing import List, Dict, Any
import dotenv
import asyncio
import base64
import time
import json
import os
import platform
import sys
import time
from pathlib import Path
from typing import Any, Dict, List

import dotenv

try:
from PIL import Image, ImageDraw

PIL_AVAILABLE = True
except Exception:
PIL_AVAILABLE = False
@@ -31,36 +34,44 @@ try:
except ImportError:
if __name__ == "__main__":
raise ImportError(
"CLI dependencies not found. "
"Please install with: pip install \"cua-agent[cli]\""
"CLI dependencies not found. " 'Please install with: pip install "cua-agent[cli]"'
)

# Load environment variables
dotenv.load_dotenv()

# Color codes for terminal output
class Colors:
RESET = '\033[0m'
BOLD = '\033[1m'
DIM = '\033[2m'

# Text colors
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
MAGENTA = '\033[35m'
CYAN = '\033[36m'
WHITE = '\033[37m'
GRAY = '\033[90m'

# Background colors
BG_RED = '\033[41m'
BG_GREEN = '\033[42m'
BG_YELLOW = '\033[43m'
BG_BLUE = '\033[44m'
RESET = "\033[0m"
BOLD = "\033[1m"
DIM = "\033[2m"

def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = False, end: str = "\n", right: str = ""):
# Text colors
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
WHITE = "\033[37m"
GRAY = "\033[90m"

# Background colors
BG_RED = "\033[41m"
BG_GREEN = "\033[42m"
BG_YELLOW = "\033[43m"
BG_BLUE = "\033[44m"


def print_colored(
text: str,
color: str = "",
bold: bool = False,
dim: bool = False,
end: str = "\n",
right: str = "",
):
"""Print colored text to terminal with optional right-aligned text."""
prefix = ""
if bold:
@@ -69,24 +80,25 @@ def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = Fa
prefix += Colors.DIM
if color:
prefix += color

if right:
# Get terminal width (default to 80 if unable to determine)
try:
import shutil

terminal_width = shutil.get_terminal_size().columns
except:
terminal_width = 80

# Add right margin
terminal_width -= 1

# Calculate padding needed
# Account for ANSI escape codes not taking visual space
visible_left_len = len(text)
visible_right_len = len(right)
padding = terminal_width - visible_left_len - visible_right_len

if padding > 0:
output = f"{prefix}{text}{' ' * padding}{right}{Colors.RESET}"
else:
@@ -94,7 +106,7 @@ def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = Fa
output = f"{prefix}{text} {right}{Colors.RESET}"
else:
output = f"{prefix}{text}{Colors.RESET}"

print(output, end=end)

@@ -113,29 +125,34 @@ def print_action(action_type: str, details: Dict[str, Any], total_cost: float):
args_str = f"('{details['text']}')"
elif action_type == "scroll" and "x" in details and "y" in details:
args_str = f"({details['x']}, {details['y']})"

if total_cost > 0:
print_colored(f"🛠️ {action_type}{args_str}", dim=True, right=f"💸 ${total_cost:.2f}")
else:
print_colored(f"🛠️ {action_type}{args_str}", dim=True)

def print_welcome(model: str, agent_loop: str, container_name: str):
"""Print welcome message."""
print_colored(f"Connected to {container_name} ({model}, {agent_loop})")
print_colored("Type 'exit' to quit.", dim=True)

async def ainput(prompt: str = ""):
return await asyncio.to_thread(input, prompt)

async def chat_loop(agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True):

async def chat_loop(
agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True
):
"""Main chat loop with the agent."""
print_welcome(model, agent.agent_config_info.agent_class.__name__, container_name)

history = []

if initial_prompt:
history.append({"role": "user", "content": initial_prompt})

total_cost = 0

while True:
@@ -143,31 +160,31 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
# Get user input with prompt
print_colored("> ", end="")
user_input = await ainput()

if user_input.lower() in ['exit', 'quit', 'q']:

if user_input.lower() in ["exit", "quit", "q"]:
print_colored("\n👋 Goodbye!")
break

if not user_input:
continue

# Add user message to history
history.append({"role": "user", "content": user_input})

# Stream responses from the agent with spinner
with yaspin(text="Thinking...", spinner="line", attrs=["dark"]) as spinner:
spinner.hide()

async for result in agent.run(history):
# Add agent responses to history
history.extend(result.get("output", []))

if show_usage:
total_cost += result.get("usage", {}).get("response_cost", 0)

# Process and display the output
for item in result.get("output", []):
if item.get("type") == "message":
if item.get("type") == "message" and item.get("role") == "assistant":
# Display agent text response
content = item.get("content", [])
for content_part in content:
@@ -176,7 +193,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
if text:
spinner.hide()
print_colored(text)

elif item.get("type") == "computer_call":
# Display computer action
action = item.get("action", {})
@@ -186,7 +203,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
print_action(action_type, action, total_cost)
spinner.text = f"Performing {action_type}..."
spinner.show()

elif item.get("type") == "function_call":
# Display function call
function_name = item.get("name", "")
@@ -194,18 +211,18 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
print_colored(f"🔧 Calling function: {function_name}", dim=True)
spinner.text = f"Calling {function_name}..."
spinner.show()

elif item.get("type") == "function_call_output":
# Display function output (dimmed)
output = item.get("output", "")
if output and len(output.strip()) > 0:
spinner.hide()
print_colored(f"📤 {output}", dim=True)

spinner.hide()
if show_usage and total_cost > 0:
print_colored(f"Total cost: ${total_cost:.2f}", dim=True)


async def main():
"""Main CLI function."""
@@ -218,104 +235,103 @@ Examples:
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
"""
""",
)

parser.add_argument(
"model",
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')",
)

parser.add_argument(
"--provider",
choices=["cloud", "lume", "winsandbox", "docker"],
default="cloud",
help="Computer provider to use: cloud (default), lume, winsandbox, or docker",
)

parser.add_argument(
"--images",
type=int,
default=3,
help="Number of recent images to keep in context (default: 3)"
help="Number of recent images to keep in context (default: 3)",
)

parser.add_argument("--trajectory", action="store_true", help="Save trajectory for debugging")

parser.add_argument("--budget", type=float, help="Maximum budget for the session (in dollars)")

parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")

parser.add_argument(
"--trajectory",
action="store_true",
help="Save trajectory for debugging"
)

parser.add_argument(
"--budget",
type=float,
help="Maximum budget for the session (in dollars)"
)

parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose logging"
"-p",
"--prompt",
type=str,
help="Initial prompt to send to the agent. Leave blank for interactive mode.",
)

parser.add_argument(
"-p", "--prompt",
type=str,
help="Initial prompt to send to the agent. Leave blank for interactive mode."
"--prompt-file",
type=Path,
help="Path to a UTF-8 text file whose contents will be used as the initial prompt. If provided, overrides --prompt.",
)

parser.add_argument(
"--predict-click",
dest="predict_click",
type=str,
help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it."
help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it.",
)

parser.add_argument("-c", "--cache", action="store_true", help="Tell the API to enable caching")

parser.add_argument(
"-u", "--usage", action="store_true", help="Show total cost of the agent runs"
)

parser.add_argument(
"-c", "--cache",
action="store_true",
help="Tell the API to enable caching"
)

parser.add_argument(
"-u", "--usage",
action="store_true",
help="Show total cost of the agent runs"
)

parser.add_argument(
"-r", "--max-retries",
"-r",
"--max-retries",
type=int,
default=3,
help="Maximum number of retries for the LLM API calls"
help="Maximum number of retries for the LLM API calls",
)

args = parser.parse_args()

# Check for required environment variables
container_name = os.getenv("CUA_CONTAINER_NAME")
cua_api_key = os.getenv("CUA_API_KEY")

# Prompt for missing environment variables

# Prompt for missing environment variables (container name always required)
if not container_name:
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)

if not cua_api_key:
if args.provider == "cloud":
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)
else:
container_name = "cli-sandbox"

# Only require API key for cloud provider
if args.provider == "cloud" and not cua_api_key:
print_colored("CUA_API_KEY not set.", dim=True)
cua_api_key = input("Enter your CUA API key: ").strip()
if not cua_api_key:
print_colored("❌ API key is required.")
print_colored("❌ API key is required for cloud provider.")
sys.exit(1)

# Check for provider-specific API keys based on model
provider_api_keys = {
"openai/": "OPENAI_API_KEY",
"anthropic/": "ANTHROPIC_API_KEY",
"omniparser+": "OPENAI_API_KEY",
"omniparser+": "ANTHROPIC_API_KEY",
}

# Find matching provider and check for API key
for prefix, env_var in provider_api_keys.items():
if args.model.startswith(prefix):
if prefix in args.model:
if not os.getenv(env_var):
print_colored(f"{env_var} not set.", dim=True)
api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
@@ -325,7 +341,7 @@ Examples:
# Set the environment variable for the session
os.environ[env_var] = api_key
break

# Import here to avoid import errors if dependencies are missing
try:
from agent import ComputerAgent
@@ -334,46 +350,62 @@ Examples:
print_colored(f"❌ Import error: {e}", Colors.RED, bold=True)
print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
sys.exit(1)

# Resolve provider -> os_type, provider_type, api key requirement
provider_map = {
"cloud": ("linux", "cloud", True),
"lume": ("macos", "lume", False),
"winsandbox": ("windows", "winsandbox", False),
"docker": ("linux", "docker", False),
}
os_type, provider_type, needs_api_key = provider_map[args.provider]

computer_kwargs = {
"os_type": os_type,
"provider_type": provider_type,
"name": container_name,
}
if needs_api_key:
computer_kwargs["api_key"] = cua_api_key  # type: ignore

# Create computer instance
async with Computer(
os_type="linux",
provider_type="cloud",
name=container_name,
api_key=cua_api_key
) as computer:

async with Computer(**computer_kwargs) as computer:  # type: ignore

# Create agent
agent_kwargs = {
"model": args.model,
"tools": [computer],
"trust_remote_code": True, # needed for some local models (e.g., InternVL, OpenCUA)
"trust_remote_code": True,  # needed for some local models (e.g., InternVL, OpenCUA)
"verbosity": 20 if args.verbose else 30,  # DEBUG vs WARNING
"max_retries": args.max_retries
"max_retries": args.max_retries,
}

if args.images > 0:
agent_kwargs["only_n_most_recent_images"] = args.images

if args.trajectory:
agent_kwargs["trajectory_dir"] = "trajectories"

if args.budget:
agent_kwargs["max_trajectory_budget"] = {
"max_budget": args.budget,
"raise_error": True,
"reset_after_each_run": False
"reset_after_each_run": False,
}

if args.cache:
agent_kwargs["use_prompt_caching"] = True

agent = ComputerAgent(**agent_kwargs)

# If predict-click mode is requested, run once and exit
if args.predict_click:
if not PIL_AVAILABLE:
print_colored("❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow", Colors.RED, bold=True)
print_colored(
"❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow",
Colors.RED,
bold=True,
)
sys.exit(1)

instruction = args.predict_click
@@ -408,6 +440,7 @@ Examples:

try:
from io import BytesIO

with Image.open(BytesIO(img_bytes)) as img:
img = img.convert("RGB")
draw = ImageDraw.Draw(img)
@@ -430,9 +463,9 @@ Examples:
if system == "windows":
os.startfile(str(out_path))  # type: ignore[attr-defined]
elif system == "darwin":
os.system(f"open \"{out_path}\"")
os.system(f'open "{out_path}"')
else:
os.system(f"xdg-open \"{out_path}\"")
os.system(f'xdg-open "{out_path}"')
except Exception:
pass
except Exception as e:
@@ -442,13 +475,21 @@ Examples:
# Done
sys.exit(0)

# Start chat loop (default interactive mode)
await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
# Resolve initial prompt from --prompt-file or --prompt
initial_prompt = args.prompt or ""
if args.prompt_file:
try:
initial_prompt = args.prompt_file.read_text(encoding="utf-8")
except Exception as e:
print_colored(f"❌ Failed to read --prompt-file: {e}", Colors.RED, bold=True)
sys.exit(1)

# Start chat loop (default interactive mode)
await chat_loop(agent, args.model, container_name, initial_prompt, args.usage)


if __name__ == "__main__":
try:
asyncio.run(main())
except (KeyboardInterrupt, EOFError) as _:
print_colored("\n\n👋 Goodbye!")
print_colored("\n\n👋 Goodbye!")
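Reading the reworked argparse surface as a whole, typical invocations after this change look like the following (flag spellings taken from the definitions above):

python -m agent.cli anthropic/claude-3-5-sonnet-20241022 --provider docker --images 5 --trajectory
python -m agent.cli openai/computer-use-preview --prompt-file task.txt --budget 2.50 --usage
python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B --predict-click "Click the Submit button"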
@@ -6,27 +6,32 @@ computer interface types, supporting both the ComputerHandler protocol and the
Computer library interface.
"""

from computer import Computer as cuaComputer

from .base import AsyncComputerHandler
from .cua import cuaComputerHandler
from .custom import CustomComputerHandler
from computer import Computer as cuaComputer


def is_agent_computer(computer):
"""Check if the given computer is a ComputerHandler or CUA Computer."""
return isinstance(computer, AsyncComputerHandler) or \
isinstance(computer, cuaComputer) or \
(isinstance(computer, dict)) #and "screenshot" in computer)
return (
isinstance(computer, AsyncComputerHandler)
or isinstance(computer, cuaComputer)
or (isinstance(computer, dict))
)  # and "screenshot" in computer)


async def make_computer_handler(computer):
"""
Create a computer handler from a computer interface.

Args:
computer: Either a ComputerHandler instance, Computer instance, or dict of functions

Returns:
ComputerHandler: A computer handler instance

Raises:
ValueError: If the computer type is not supported
"""
@@ -38,4 +43,4 @@ async def make_computer_handler(computer):
return computer_handler
if isinstance(computer, dict):
return CustomComputerHandler(computer)
raise ValueError(f"Unsupported computer type: {type(computer)}")
raise ValueError(f"Unsupported computer type: {type(computer)}")
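A short sketch of the dispatch above: a plain dict with at least a screenshot function passes is_agent_computer and is wrapped in CustomComputerHandler by make_computer_handler. The import path and the screenshot stub are assumptions:

import asyncio

from agent.computers import is_agent_computer, make_computer_handler  # assumed import path

def screenshot() -> bytes:
    return b""  # stand-in; a real implementation returns PNG bytes

async def demo() -> None:
    computer = {"screenshot": screenshot}
    assert is_agent_computer(computer)  # plain dicts pass the check above
    handler = await make_computer_handler(computer)
    print(type(handler).__name__)  # CustomComputerHandler

asyncio.run(demo())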
@@ -2,23 +2,32 @@
Base computer interface protocol for agent interactions.
"""

from typing import Protocol, Literal, List, Dict, Any, Union, Optional, runtime_checkable
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)


@runtime_checkable
class AsyncComputerHandler(Protocol):
"""Protocol defining the interface for computer interactions."""

# ==== Computer-Use-Preview Action Space ====

# ==== Computer-Use-Preview Action Space ====

async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
...

async def get_dimensions(self) -> tuple[int, int]:
"""Get screen dimensions as (width, height)."""
...

async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.

@@ -26,49 +35,49 @@ class AsyncComputerHandler(Protocol):
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
...

async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
...

async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
...

async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
...

async def type(self, text: str) -> None:
"""Type text."""
...

async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
...

async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
...

async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
...

async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
...

async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
...

# ==== Anthropic Action Space ====

# ==== Anthropic Action Space ====

async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
...

async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
...
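Because AsyncComputerHandler is a runtime_checkable Protocol, isinstance checks like the one in is_agent_computer pass for any object that merely has all of the methods above. A minimal structural sketch (a null backend for illustration, not a real implementation):

class NullComputer:
    # Implements just enough of the protocol above to satisfy isinstance checks.
    async def get_environment(self): return "linux"
    async def get_dimensions(self): return (1920, 1080)
    async def screenshot(self, text=None): return ""  # base64 PNG in practice
    async def click(self, x, y, button="left"): pass
    async def double_click(self, x, y): pass
    async def scroll(self, x, y, scroll_x, scroll_y): pass
    async def type(self, text): pass
    async def wait(self, ms=1000): pass
    async def move(self, x, y): pass
    async def keypress(self, keys): pass
    async def drag(self, path): pass
    async def get_current_url(self): return ""
    async def left_mouse_down(self, x=None, y=None): pass
    async def left_mouse_up(self, x=None, y=None): pass

Note that runtime_checkable isinstance only verifies method presence, not signatures or return types.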
@@ -3,24 +3,27 @@ Computer handler implementation for OpenAI computer-use-preview protocol.
"""

import base64
from typing import Dict, List, Any, Literal, Union, Optional
from .base import AsyncComputerHandler
from typing import Any, Dict, List, Literal, Optional, Union

from computer import Computer

from .base import AsyncComputerHandler


class cuaComputerHandler(AsyncComputerHandler):
"""Computer handler that implements the Computer protocol using the computer interface."""

def __init__(self, cua_computer: Computer):
"""Initialize with a computer interface (from tool schema)."""
self.cua_computer = cua_computer
self.interface = None

async def _initialize(self):
if hasattr(self.cua_computer, '_initialized') and not self.cua_computer._initialized:
if hasattr(self.cua_computer, "_initialized") and not self.cua_computer._initialized:
await self.cua_computer.run()
self.interface = self.cua_computer.interface

# ==== Computer-Use-Preview Action Space ====

# ==== Computer-Use-Preview Action Space ====

async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
@@ -32,7 +35,7 @@ class cuaComputerHandler(AsyncComputerHandler):
assert self.interface is not None
screen_size = await self.interface.get_screen_size()
return screen_size["width"], screen_size["height"]

async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.

@@ -41,8 +44,8 @@ class cuaComputerHandler(AsyncComputerHandler):
"""
assert self.interface is not None
screenshot_bytes = await self.interface.screenshot()
return base64.b64encode(screenshot_bytes).decode('utf-8')

return base64.b64encode(screenshot_bytes).decode("utf-8")

async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
assert self.interface is not None
@@ -53,34 +56,35 @@ class cuaComputerHandler(AsyncComputerHandler):
else:
# Default to left click for unknown buttons
await self.interface.left_click(x, y)

async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
assert self.interface is not None
await self.interface.double_click(x, y)

async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
assert self.interface is not None
await self.interface.move_cursor(x, y)
await self.interface.scroll(scroll_x, scroll_y)

async def type(self, text: str) -> None:
"""Type text."""
assert self.interface is not None
await self.interface.type_text(text)

async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
assert self.interface is not None
import asyncio

await asyncio.sleep(ms / 1000.0)

async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
assert self.interface is not None
await self.interface.move_cursor(x, y)

async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
assert self.interface is not None
@@ -91,38 +95,38 @@ class cuaComputerHandler(AsyncComputerHandler):
else:
# Handle key combinations
await self.interface.hotkey(*keys)

async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
assert self.interface is not None
if not path:
return

# Start drag from first point
start = path[0]
await self.interface.mouse_down(start["x"], start["y"])

# Move through path
for point in path[1:]:
await self.interface.move_cursor(point["x"], point["y"])

# End drag at last point
end = path[-1]
await self.interface.mouse_up(end["x"], end["y"])

async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
# This would need to be implemented based on the specific browser interface
# For now, return empty string
return ""

# ==== Anthropic Computer Action Space ====
# ==== Anthropic Computer Action Space ====
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
assert self.interface is not None
await self.interface.mouse_down(x, y, button="left")

async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
assert self.interface is not None
await self.interface.mouse_up(x, y, button="left")
await self.interface.mouse_up(x, y, button="left")
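The drag implementation above decomposes a path into mouse_down, intermediate move_cursor calls, and a final mouse_up; a call site looks like this sketch (handler construction elided):

async def demo(handler) -> None:
    await handler.drag([
        {"x": 100, "y": 100},  # mouse_down here
        {"x": 150, "y": 120},  # move_cursor through intermediate points
        {"x": 200, "y": 140},  # mouse_up here
    ])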
@@ -3,47 +3,49 @@ Custom computer handler implementation that accepts a dictionary of functions.
"""

import base64
from typing import Dict, List, Any, Literal, Union, Optional, Callable
from PIL import Image
import io
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from PIL import Image

from .base import AsyncComputerHandler


class CustomComputerHandler(AsyncComputerHandler):
"""Computer handler that implements the Computer protocol using a dictionary of custom functions."""

def __init__(self, functions: Dict[str, Callable]):
"""
Initialize with a dictionary of functions.

Args:
functions: Dictionary where keys are method names and values are callable functions.
Only 'screenshot' is required, all others are optional.

Raises:
ValueError: If required 'screenshot' function is not provided.
"""
if 'screenshot' not in functions:
if "screenshot" not in functions:
raise ValueError("'screenshot' function is required in functions dictionary")

self.functions = functions
self._last_screenshot_size: Optional[tuple[int, int]] = None

async def _call_function(self, func, *args, **kwargs):
"""
Call a function, handling both async and sync functions.

Args:
func: The function to call
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function

Returns:
The result of the function call
"""
import asyncio
import inspect

if callable(func):
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
@@ -51,14 +53,14 @@ class CustomComputerHandler(AsyncComputerHandler):
return func(*args, **kwargs)
else:
return func

async def _get_value(self, attribute: str):
"""
Get value for an attribute, checking both 'get_{attribute}' and '{attribute}' keys.

Args:
attribute: The attribute name to look for

Returns:
The value from the functions dict, called if callable, returned directly if not
"""
@@ -66,20 +68,20 @@ class CustomComputerHandler(AsyncComputerHandler):
get_key = f"get_{attribute}"
if get_key in self.functions:
return await self._call_function(self.functions[get_key])

# Check for '{attribute}'

# Check for '{attribute}'
if attribute in self.functions:
return await self._call_function(self.functions[attribute])

return None

def _to_b64_str(self, img: Union[bytes, Image.Image, str]) -> str:
"""
Convert image to base64 string.

Args:
img: Image as bytes, PIL Image, or base64 string

Returns:
str: Base64 encoded image string
"""
@@ -88,47 +90,47 @@ class CustomComputerHandler(AsyncComputerHandler):
return img
elif isinstance(img, bytes):
# Raw bytes
return base64.b64encode(img).decode('utf-8')
return base64.b64encode(img).decode("utf-8")
elif isinstance(img, Image.Image):
# PIL Image
buffer = io.BytesIO()
img.save(buffer, format='PNG')
return base64.b64encode(buffer.getvalue()).decode('utf-8')
img.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
raise ValueError(f"Unsupported image type: {type(img)}")

# ==== Computer-Use-Preview Action Space ====

# ==== Computer-Use-Preview Action Space ====

async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
result = await self._get_value('environment')
result = await self._get_value("environment")
if result is None:
return "linux"
assert result in ["windows", "mac", "linux", "browser"]
return result  # type: ignore
return result  # type: ignore

async def get_dimensions(self) -> tuple[int, int]:
"""Get screen dimensions as (width, height)."""
result = await self._get_value('dimensions')
result = await self._get_value("dimensions")
if result is not None:
return result  # type: ignore

return result  # type: ignore

# Fallback: use last screenshot size if available
if not self._last_screenshot_size:
await self.screenshot()
assert self._last_screenshot_size is not None, "Failed to get screenshot size"

return self._last_screenshot_size

async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.

Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
result = await self._call_function(self.functions['screenshot'])
b64_str = self._to_b64_str(result)  # type: ignore

result = await self._call_function(self.functions["screenshot"])
b64_str = self._to_b64_str(result)  # type: ignore

# Try to extract dimensions for fallback use
try:
if isinstance(result, Image.Image):
@@ -140,74 +142,75 @@ class CustomComputerHandler(AsyncComputerHandler):
except Exception:
# If we can't get dimensions, that's okay
pass

return b64_str

async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
if 'click' in self.functions:
await self._call_function(self.functions['click'], x, y, button)
if "click" in self.functions:
await self._call_function(self.functions["click"], x, y, button)
# No-op if not implemented

async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
if 'double_click' in self.functions:
await self._call_function(self.functions['double_click'], x, y)
if "double_click" in self.functions:
await self._call_function(self.functions["double_click"], x, y)
# No-op if not implemented

async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
if 'scroll' in self.functions:
await self._call_function(self.functions['scroll'], x, y, scroll_x, scroll_y)
if "scroll" in self.functions:
await self._call_function(self.functions["scroll"], x, y, scroll_x, scroll_y)
# No-op if not implemented

async def type(self, text: str) -> None:
"""Type text."""
if 'type' in self.functions:
await self._call_function(self.functions['type'], text)
if "type" in self.functions:
await self._call_function(self.functions["type"], text)
# No-op if not implemented

async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
if 'wait' in self.functions:
await self._call_function(self.functions['wait'], ms)
if "wait" in self.functions:
await self._call_function(self.functions["wait"], ms)
else:
# Default implementation
import asyncio

await asyncio.sleep(ms / 1000.0)

async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
if 'move' in self.functions:
await self._call_function(self.functions['move'], x, y)
if "move" in self.functions:
await self._call_function(self.functions["move"], x, y)
# No-op if not implemented

async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
if 'keypress' in self.functions:
await self._call_function(self.functions['keypress'], keys)
if "keypress" in self.functions:
await self._call_function(self.functions["keypress"], keys)
# No-op if not implemented

async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
if 'drag' in self.functions:
await self._call_function(self.functions['drag'], path)
if "drag" in self.functions:
await self._call_function(self.functions["drag"], path)
# No-op if not implemented

async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
if 'get_current_url' in self.functions:
return await self._get_value('current_url')  # type: ignore
if "get_current_url" in self.functions:
return await self._get_value("current_url")  # type: ignore
return ""  # Default fallback

async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
if 'left_mouse_down' in self.functions:
await self._call_function(self.functions['left_mouse_down'], x, y)
if "left_mouse_down" in self.functions:
await self._call_function(self.functions["left_mouse_down"], x, y)
# No-op if not implemented

async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
if 'left_mouse_up' in self.functions:
await self._call_function(self.functions['left_mouse_up'], x, y)
if "left_mouse_up" in self.functions:
await self._call_function(self.functions["left_mouse_up"], x, y)
# No-op if not implemented
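A usage sketch for the handler above: only 'screenshot' is mandatory, and get_dimensions falls back to the size of the most recent screenshot when no dimensions entry is supplied. The import path is an assumption, and the fallback relies on the dimension extraction elided in the hunk above:

import asyncio

from PIL import Image

from agent.computers.custom import CustomComputerHandler  # assumed import path

def screenshot() -> Image.Image:
    return Image.new("RGB", (1280, 800), "white")  # stand-in for a real capture

async def demo() -> None:
    handler = CustomComputerHandler({"screenshot": screenshot})
    print(await handler.get_dimensions())  # (1280, 800) via the screenshot-size fallback

asyncio.run(demo())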
@@ -3,47 +3,56 @@ Decorators for agent - agent_loop decorator
"""

from typing import List, Optional

from .types import AgentConfigInfo

# Global registry
_agent_configs: List[AgentConfigInfo] = []


def register_agent(models: str, priority: int = 0):
    """
    Decorator to register an AsyncAgentConfig class.

    Args:
        models: Regex pattern to match supported models
        priority: Priority for agent selection (higher = more priority)
    """

    def decorator(agent_class: type):
        # Validate that the class implements AsyncAgentConfig protocol
        if not hasattr(agent_class, 'predict_step'):
            raise ValueError(f"Agent class {agent_class.__name__} must implement predict_step method")
        if not hasattr(agent_class, 'predict_click'):
            raise ValueError(f"Agent class {agent_class.__name__} must implement predict_click method")
        if not hasattr(agent_class, 'get_capabilities'):
            raise ValueError(f"Agent class {agent_class.__name__} must implement get_capabilities method")
        if not hasattr(agent_class, "predict_step"):
            raise ValueError(
                f"Agent class {agent_class.__name__} must implement predict_step method"
            )
        if not hasattr(agent_class, "predict_click"):
            raise ValueError(
                f"Agent class {agent_class.__name__} must implement predict_click method"
            )
        if not hasattr(agent_class, "get_capabilities"):
            raise ValueError(
                f"Agent class {agent_class.__name__} must implement get_capabilities method"
            )

        # Register the agent config
        config_info = AgentConfigInfo(
            agent_class=agent_class,
            models_regex=models,
            priority=priority
            agent_class=agent_class, models_regex=models, priority=priority
        )
        _agent_configs.append(config_info)

        # Sort by priority (highest first)
        _agent_configs.sort(key=lambda x: x.priority, reverse=True)

        return agent_class

    return decorator


def get_agent_configs() -> List[AgentConfigInfo]:
    """Get all registered agent configs"""
    return _agent_configs.copy()


def find_agent_config(model: str) -> Optional[AgentConfigInfo]:
    """Find the best matching agent config for a model"""
    for config_info in _agent_configs:
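A minimal sketch of how the decorator above is meant to be consumed; the method bodies and any signature details beyond the three required method names are assumptions, since only the attribute checks are visible in this hunk:

@register_agent(models=r"myorg/.*", priority=10)
class MyAgentConfig:
    async def predict_step(self, *args, **kwargs):
        ...  # produce the next step for a matched model

    async def predict_click(self, *args, **kwargs):
        ...  # ground a click instruction to coordinates

    def get_capabilities(self):
        return ["step"]  # illustrative value

# find_agent_config("myorg/my-model") now resolves to MyAgentConfig first,
# because registered configs are sorted by priority (highest wins).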
@@ -12,7 +12,7 @@ Components:
Usage:
    # Run the server and UI
    python -m agent.human_tool

    # Or run components separately
    python -m agent.human_tool.server  # API server only
    python -m agent.human_tool.ui  # UI only
@@ -21,9 +21,4 @@ Usage:
from .server import CompletionQueue, completion_queue
from .ui import HumanCompletionUI, create_ui

__all__ = [
    "CompletionQueue",
    "completion_queue",
    "HumanCompletionUI",
    "create_ui"
]
__all__ = ["CompletionQueue", "completion_queue", "HumanCompletionUI", "create_ui"]

@@ -8,6 +8,7 @@ with a Gradio UI for human interaction.

import gradio as gr
from fastapi import FastAPI

from .server import app as fastapi_app
from .ui import create_ui

@@ -18,6 +19,7 @@ gradio_demo = create_ui()
CUSTOM_PATH = "/gradio"
app = gr.mount_gradio_app(fastapi_app, gradio_demo, path=CUSTOM_PATH)


# Add a redirect from root to Gradio UI
@fastapi_app.get("/")
async def redirect_to_ui():
@@ -25,14 +27,16 @@ async def redirect_to_ui():
    return {
        "message": "Human Completion Server is running",
        "ui_url": "/gradio",
        "api_docs": "/docs"
        "api_docs": "/docs",
    }


if __name__ == "__main__":
    import uvicorn

    print("🚀 Starting Human-in-the-Loop Completion Server...")
    print("📊 API Server: http://localhost:8002")
    print("🎨 Gradio UI: http://localhost:8002/gradio")
    print("📚 API Docs: http://localhost:8002/docs")

    uvicorn.run(app, host="0.0.0.0", port=8002)
@@ -1,9 +1,9 @@
import asyncio
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
@@ -49,7 +49,7 @@ class CompletionQueue:
        self._queue: Dict[str, CompletionCall] = {}
        self._pending_order: List[str] = []
        self._lock = asyncio.Lock()

    async def add_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
        """Add a completion call to the queue."""
        async with self._lock:
@@ -59,42 +59,47 @@ class CompletionQueue:
                messages=messages,
                model=model,
                status=CompletionStatus.PENDING,
                created_at=datetime.now()
                created_at=datetime.now(),
            )
            self._queue[call_id] = completion_call
            self._pending_order.append(call_id)
            return call_id

    async def get_pending_calls(self) -> List[Dict[str, Any]]:
        """Get all pending completion calls."""
        async with self._lock:
            pending_calls = []
            for call_id in self._pending_order:
                if call_id in self._queue and self._queue[call_id].status == CompletionStatus.PENDING:
                if (
                    call_id in self._queue
                    and self._queue[call_id].status == CompletionStatus.PENDING
                ):
                    call = self._queue[call_id]
                    pending_calls.append({
                        "id": call.id,
                        "model": call.model,
                        "created_at": call.created_at.isoformat(),
                        "messages": call.messages
                    })
                    pending_calls.append(
                        {
                            "id": call.id,
                            "model": call.model,
                            "created_at": call.created_at.isoformat(),
                            "messages": call.messages,
                        }
                    )
            return pending_calls

    async def get_call_status(self, call_id: str) -> Optional[Dict[str, Any]]:
        """Get the status of a specific completion call."""
        async with self._lock:
            if call_id not in self._queue:
                return None

            call = self._queue[call_id]
            result = {
                "id": call.id,
                "status": call.status.value,
                "created_at": call.created_at.isoformat(),
                "model": call.model,
                "messages": call.messages
                "messages": call.messages,
            }

            if call.completed_at:
                result["completed_at"] = call.completed_at.isoformat()
            if call.response:
@@ -103,69 +108,74 @@ class CompletionQueue:
                result["tool_calls"] = call.tool_calls
            if call.error:
                result["error"] = call.error

            return result

    async def complete_call(self, call_id: str, response: Optional[str] = None, tool_calls: Optional[List[Dict[str, Any]]] = None) -> bool:
    async def complete_call(
        self,
        call_id: str,
        response: Optional[str] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
    ) -> bool:
        """Mark a completion call as completed with a response or tool calls."""
        async with self._lock:
            if call_id not in self._queue:
                return False

            call = self._queue[call_id]
            if call.status != CompletionStatus.PENDING:
                return False

            call.status = CompletionStatus.COMPLETED
            call.completed_at = datetime.now()
            call.response = response
            call.tool_calls = tool_calls

            # Remove from pending order
            if call_id in self._pending_order:
                self._pending_order.remove(call_id)

            return True

    async def fail_call(self, call_id: str, error: str) -> bool:
        """Mark a completion call as failed with an error."""
        async with self._lock:
            if call_id not in self._queue:
                return False

            call = self._queue[call_id]
            if call.status != CompletionStatus.PENDING:
                return False

            call.status = CompletionStatus.FAILED
            call.completed_at = datetime.now()
            call.error = error

            # Remove from pending order
            if call_id in self._pending_order:
                self._pending_order.remove(call_id)

            return True

    async def wait_for_completion(self, call_id: str, timeout: float = 300.0) -> Optional[str]:
        """Wait for a completion call to be completed and return the response."""
        start_time = asyncio.get_event_loop().time()

        while True:
            status = await self.get_call_status(call_id)
            if not status:
                return None

            if status["status"] == CompletionStatus.COMPLETED.value:
                return status.get("response")
            elif status["status"] == CompletionStatus.FAILED.value:
                raise Exception(f"Completion failed: {status.get('error', 'Unknown error')}")

            # Check timeout
            if asyncio.get_event_loop().time() - start_time > timeout:
                await self.fail_call(call_id, "Timeout waiting for human response")
                raise TimeoutError("Timeout waiting for human response")

            # Wait a bit before checking again
            await asyncio.sleep(0.5)
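Taken together, the queue methods above support a simple produce/poll flow. A sketch, assuming the module-level `completion_queue` instance that this package exports:

import asyncio

async def demo() -> None:
    call_id = await completion_queue.add_completion(
        messages=[{"role": "user", "content": "Approve this step?"}],
        model="human",  # illustrative model label, not a value from this diff
    )
    # Polls get_call_status() every 0.5s; raises TimeoutError after `timeout`.
    response = await completion_queue.wait_for_completion(call_id, timeout=60.0)
    print(response)

# asyncio.run(demo())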
@@ -204,9 +214,7 @@ async def get_status(call_id: str):
async def complete_call(call_id: str, response: CompletionResponse):
    """Complete a call with a human response."""
    success = await completion_queue.complete_call(
        call_id,
        response=response.response,
        tool_calls=response.tool_calls
        call_id, response=response.response, tool_calls=response.tool_calls
    )
    if success:
        return {"status": "success", "message": "Call completed"}
@@ -219,7 +227,9 @@ async def fail_call(call_id: str, error: Dict[str, str]):
    """Mark a call as failed."""
    success = await completion_queue.fail_call(call_id, error.get("error", "Unknown error"))
    if not success:
        raise HTTPException(status_code=404, detail="Completion call not found or already completed")
        raise HTTPException(
            status_code=404, detail="Completion call not found or already completed"
        )
    return {"status": "failed"}

@@ -231,4 +241,5 @@ async def root():

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8002)
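The same endpoints can be exercised over plain HTTP; POST /complete/{call_id} is visible in this diff, while the pending-list path in this sketch is an assumption:

import requests

BASE = "http://localhost:8002"

pending = requests.get(f"{BASE}/pending", timeout=10).json()  # path assumed
if pending:
    call_id = pending[0]["id"]
    requests.post(
        f"{BASE}/complete/{call_id}",
        json={"response": "Looks good, proceed."},  # or {"tool_calls": [...]}
        timeout=10,
    ).raise_for_status()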
@@ -1,14 +1,17 @@
import gradio as gr
import json
import time
from typing import List, Dict, Any, Optional
from datetime import datetime
import requests
from .server import completion_queue
import base64
import io
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

import gradio as gr
import requests
from PIL import Image

from .server import completion_queue


class HumanCompletionUI:
    def __init__(self, server_url: str = "http://localhost:8002"):
        self.server_url = server_url
@@ -20,7 +23,7 @@ class HumanCompletionUI:
        self.current_button: str = "left"
        self.current_scroll_x: int = 0
        self.current_scroll_y: int = -120

    def format_messages_for_chatbot(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format messages for display in gr.Chatbot with type='messages'."""
        formatted = []
@@ -28,7 +31,7 @@ class HumanCompletionUI:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            tool_calls = msg.get("tool_calls", [])

            # Handle different content formats
            if isinstance(content, list):
                # Multi-modal content - can include text and images
@@ -55,7 +58,7 @@ class HumanCompletionUI:
                    else:
                        # For URL images, create gr.Image with URL
                        formatted_content.append(gr.Image(value=image_url))

            # Determine final content format
            if len(formatted_content) == 1:
                content = formatted_content[0]
@@ -63,28 +66,28 @@ class HumanCompletionUI:
                content = formatted_content
            else:
                content = "[Empty content]"

            # Ensure role is valid for Gradio Chatbot
            if role not in ["user", "assistant"]:
                role = "assistant" if role == "system" else "user"

            # Invert roles for better display in human UI context
            # (what the AI says becomes "user", what human should respond becomes "assistant")
            if role == "user":
                role = "assistant"
            else:
                role = "user"

            # Add the main message if it has content
            if content and str(content).strip():
                formatted.append({"role": role, "content": content})

            # Handle tool calls - create separate messages for each tool call
            if tool_calls:
                for tool_call in tool_calls:
                    function_name = tool_call.get("function", {}).get("name", "unknown")
                    arguments_str = tool_call.get("function", {}).get("arguments", "{}")

                    try:
                        # Parse arguments to format them nicely
                        arguments = json.loads(arguments_str)
@@ -92,18 +95,20 @@ class HumanCompletionUI:
                    except json.JSONDecodeError:
                        # If parsing fails, use the raw string
                        formatted_args = arguments_str

                    # Create a formatted message for the tool call
                    tool_call_content = f"```json\n{formatted_args}\n```"

                    formatted.append({
                        "role": role,
                        "content": tool_call_content,
                        "metadata": {"title": f"🛠️ Used {function_name}"}
                    })

                    formatted.append(
                        {
                            "role": role,
                            "content": tool_call_content,
                            "metadata": {"title": f"🛠️ Used {function_name}"},
                        }
                    )

        return formatted
    def get_pending_calls(self) -> List[Dict[str, Any]]:
        """Get pending calls from the server."""
        try:
@@ -113,38 +118,39 @@ class HumanCompletionUI:
        except Exception as e:
            print(f"Error fetching pending calls: {e}")
            return []

    def complete_call_with_response(self, call_id: str, response: str) -> bool:
        """Complete a call with a text response."""
        try:
            response_data = {"response": response}
            response_obj = requests.post(
                f"{self.server_url}/complete/{call_id}",
                json=response_data,
                timeout=10
                f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
            )
            response_obj.raise_for_status()
            return True
        except requests.RequestException as e:
            print(f"Error completing call: {e}")
            return False

    def complete_call_with_tool_calls(self, call_id: str, tool_calls: List[Dict[str, Any]]) -> bool:
        """Complete a call with tool calls."""
        try:
            response_data = {"tool_calls": tool_calls}
            response_obj = requests.post(
                f"{self.server_url}/complete/{call_id}",
                json=response_data,
                timeout=10
                f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
            )
            response_obj.raise_for_status()
            return True
        except requests.RequestException as e:
            print(f"Error completing call: {e}")
            return False

    def complete_call(self, call_id: str, response: Optional[str] = None, tool_calls: Optional[List[Dict[str, Any]]] = None) -> bool:
    def complete_call(
        self,
        call_id: str,
        response: Optional[str] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
    ) -> bool:
        """Complete a call with either a response or tool calls."""
        try:
            response_data = {}
@@ -152,25 +158,23 @@ class HumanCompletionUI:
                response_data["response"] = response
            if tool_calls:
                response_data["tool_calls"] = tool_calls

            response_obj = requests.post(
                f"{self.server_url}/complete/{call_id}",
                json=response_data,
                timeout=10
                f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
            )
            response_obj.raise_for_status()
            return True
        except requests.RequestException as e:
            print(f"Error completing call: {e}")
            return False

    def get_last_image_from_messages(self, messages: List[Dict[str, Any]]) -> Optional[Any]:
        """Extract the last image from the messages for display above conversation."""
        last_image = None

        for msg in reversed(messages):  # Start from the last message
            content = msg.get("content", "")

            if isinstance(content, list):
                for item in reversed(content):  # Get the last image in the message
                    if item.get("type") == "image_url":
@@ -189,13 +193,13 @@ class HumanCompletionUI:
                    else:
                        # For URL images, return the URL
                        return image_url

        return last_image

    def refresh_pending_calls(self):
        """Refresh the list of pending calls."""
        pending_calls = self.get_pending_calls()

        if not pending_calls:
            return (
                gr.update(choices=["latest"], value="latest"),  # dropdown
@@ -205,27 +209,27 @@ class HumanCompletionUI:
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        # Sort pending calls by created_at to get oldest first
        sorted_calls = sorted(pending_calls, key=lambda x: x.get("created_at", ""))

        # Create choices for dropdown
        choices = [("latest", "latest")]  # Add "latest" option first

        for call in sorted_calls:
            call_id = call["id"]
            model = call.get("model", "unknown")
            created_at = call.get("created_at", "")
            # Format timestamp
            try:
                dt = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
                dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
                time_str = dt.strftime("%H:%M:%S")
            except:
                time_str = created_at

            choice_label = f"{call_id[:8]}... ({model}) - {time_str}"
            choices.append((choice_label, call_id))

        # Default to "latest" which shows the oldest pending conversation
        selected_call_id = "latest"
        if selected_call_id == "latest" and sorted_calls:
|
||||
conversation = []
|
||||
self.current_call_id = None
|
||||
self.last_image = None
|
||||
|
||||
|
||||
return (
|
||||
gr.update(choices=choices, value="latest"),
|
||||
gr.update(value=self.last_image),
|
||||
@@ -248,7 +252,7 @@ class HumanCompletionUI:
|
||||
gr.update(visible=True), # click_actions_group visible when there is a call
|
||||
gr.update(visible=True), # actions_group visible when there is a call
|
||||
)
|
||||
|
||||
|
||||
def on_call_selected(self, selected_choice):
|
||||
"""Handle when a call is selected from the dropdown."""
|
||||
if not selected_choice:
|
||||
@@ -259,7 +263,7 @@ class HumanCompletionUI:
|
||||
gr.update(visible=False), # click_actions_group hidden
|
||||
gr.update(visible=False), # actions_group hidden
|
||||
)
|
||||
|
||||
|
||||
pending_calls = self.get_pending_calls()
|
||||
if not pending_calls:
|
||||
return (
|
||||
@@ -269,7 +273,7 @@ class HumanCompletionUI:
|
||||
gr.update(visible=False), # click_actions_group hidden
|
||||
gr.update(visible=False), # actions_group hidden
|
||||
)
|
||||
|
||||
|
||||
# Handle "latest" option
|
||||
if selected_choice == "latest":
|
||||
# Sort calls by created_at to get oldest first
|
||||
@@ -284,17 +288,17 @@ class HumanCompletionUI:
|
||||
if call_id_short in selected_choice:
|
||||
call_id = call["id"]
|
||||
break
|
||||
|
||||
|
||||
if not call_id:
|
||||
return (
|
||||
gr.update(value=None), # no image
|
||||
gr.update(value=[]), # empty chatbot
|
||||
gr.update(interactive=False)
|
||||
gr.update(interactive=False),
|
||||
)
|
||||
|
||||
|
||||
# Find the selected call
|
||||
selected_call = next((c for c in pending_calls if c["id"] == call_id), None)
|
||||
|
||||
|
||||
if not selected_call:
|
||||
return (
|
||||
gr.update(value=None), # no image
|
||||
@@ -303,12 +307,12 @@ class HumanCompletionUI:
|
||||
gr.update(visible=False), # click_actions_group hidden
|
||||
gr.update(visible=False), # actions_group hidden
|
||||
)
|
||||
|
||||
|
||||
conversation = self.format_messages_for_chatbot(selected_call.get("messages", []))
|
||||
self.current_call_id = call_id
|
||||
# Get the last image from messages
|
||||
self.last_image = self.get_last_image_from_messages(selected_call.get("messages", []))
|
||||
|
||||
|
||||
return (
|
||||
gr.update(value=self.last_image),
|
||||
gr.update(value=conversation),
|
||||
@@ -316,110 +320,111 @@ class HumanCompletionUI:
|
||||
gr.update(visible=True), # click_actions_group visible
|
||||
gr.update(visible=True), # actions_group visible
|
||||
)
|
||||
|
||||
|
||||
def submit_response(self, response_text: str):
|
||||
"""Submit a text response to the current call."""
|
||||
if not self.current_call_id:
|
||||
return (
|
||||
gr.update(value=response_text), # keep response text
|
||||
gr.update(value="❌ No call selected") # status
|
||||
gr.update(value="❌ No call selected"), # status
|
||||
)
|
||||
|
||||
|
||||
if not response_text.strip():
|
||||
return (
|
||||
gr.update(value=response_text), # keep response text
|
||||
gr.update(value="❌ Response cannot be empty") # status
|
||||
gr.update(value="❌ Response cannot be empty"), # status
|
||||
)
|
||||
|
||||
|
||||
success = self.complete_call_with_response(self.current_call_id, response_text)
|
||||
|
||||
|
||||
if success:
|
||||
status_msg = "✅ Response submitted successfully!"
|
||||
return (
|
||||
gr.update(value=""), # clear response text
|
||||
gr.update(value=status_msg) # status
|
||||
gr.update(value=status_msg), # status
|
||||
)
|
||||
else:
|
||||
return (
|
||||
gr.update(value=response_text), # keep response text
|
||||
gr.update(value="❌ Failed to submit response") # status
|
||||
gr.update(value="❌ Failed to submit response"), # status
|
||||
)
|
||||
|
||||
|
||||
def submit_action(self, action_type: str, **kwargs) -> str:
|
||||
"""Submit a computer action as a tool call."""
|
||||
if not self.current_call_id:
|
||||
return "❌ No call selected"
|
||||
|
||||
|
||||
import uuid
|
||||
|
||||
|
||||
# Create tool call structure
|
||||
action_data = {"type": action_type, **kwargs}
|
||||
tool_call = {
|
||||
"id": f"call_{uuid.uuid4().hex[:24]}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "computer",
|
||||
"arguments": json.dumps(action_data)
|
||||
}
|
||||
"function": {"name": "computer", "arguments": json.dumps(action_data)},
|
||||
}
|
||||
|
||||
|
||||
success = self.complete_call_with_tool_calls(self.current_call_id, [tool_call])
|
||||
|
||||
|
||||
if success:
|
||||
return f"✅ {action_type.capitalize()} action submitted as tool call"
|
||||
else:
|
||||
return f"❌ Failed to submit {action_type} action"
|
||||
|
||||
def submit_click_action(self, x: int, y: int, action_type: str = "click", button: str = "left") -> str:
|
||||
|
||||
def submit_click_action(
|
||||
self, x: int, y: int, action_type: str = "click", button: str = "left"
|
||||
) -> str:
|
||||
"""Submit a coordinate-based action."""
|
||||
if action_type == "click":
|
||||
return self.submit_action(action_type, x=x, y=y, button=button)
|
||||
else:
|
||||
return self.submit_action(action_type, x=x, y=y)
|
||||
|
||||
|
||||
def submit_type_action(self, text: str) -> str:
|
||||
"""Submit a type action."""
|
||||
return self.submit_action("type", text=text)
|
||||
|
||||
|
||||
def submit_hotkey_action(self, keys: str) -> str:
|
||||
"""Submit a hotkey action."""
|
||||
return self.submit_action("keypress", keys=keys)
|
||||
|
||||
|
||||
def submit_wait_action(self) -> str:
|
||||
"""Submit a wait action with no kwargs."""
|
||||
return self.submit_action("wait")
|
||||
|
||||
def submit_description_click(self, description: str, action_type: str = "click", button: str = "left") -> str:
|
||||
|
||||
def submit_description_click(
|
||||
self, description: str, action_type: str = "click", button: str = "left"
|
||||
) -> str:
|
||||
"""Submit a description-based action."""
|
||||
if action_type == "click":
|
||||
return self.submit_action(action_type, element_description=description, button=button)
|
||||
else:
|
||||
return self.submit_action(action_type, element_description=description)
|
||||
|
||||
|
||||
def wait_for_pending_calls(self, max_seconds: float = 10.0, check_interval: float = 0.2):
|
||||
"""Wait for pending calls to appear or until max_seconds elapsed.
|
||||
|
||||
|
||||
This method loops and checks for pending calls at regular intervals,
|
||||
returning as soon as a pending call is found or the maximum wait time is reached.
|
||||
|
||||
|
||||
Args:
|
||||
max_seconds: Maximum number of seconds to wait
|
||||
check_interval: How often to check for pending calls (in seconds)
|
||||
"""
|
||||
import time
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
while time.time() - start_time < max_seconds:
|
||||
# Check if there are any pending calls
|
||||
pending_calls = self.get_pending_calls()
|
||||
if pending_calls:
|
||||
# Found pending calls, return immediately
|
||||
return self.refresh_pending_calls()
|
||||
|
||||
|
||||
# Wait before checking again
|
||||
time.sleep(check_interval)
|
||||
|
||||
|
||||
# Max wait time reached, return current state
|
||||
return self.refresh_pending_calls()
|
||||
|
||||
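For reference, submit_action above wraps every UI action in an OpenAI-style function tool call. For a left click at (100, 200) the payload looks roughly like this; the id suffix is a random 24-hex-character string:

{
    "id": "call_<24 hex chars>",   # f"call_{uuid.uuid4().hex[:24]}"
    "type": "function",
    "function": {
        "name": "computer",
        "arguments": '{"type": "click", "x": 100, "y": 200, "button": "left"}',
    },
}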
@@ -427,79 +432,73 @@ class HumanCompletionUI:
def create_ui():
    """Create the Gradio interface."""
    ui_handler = HumanCompletionUI()

    with gr.Blocks(title="Human-in-the-Loop Agent Tool", fill_width=True) as demo:
        gr.Markdown("# 🤖 Human-in-the-Loop Agent Tool")
        gr.Markdown("Review AI conversation requests and provide human responses.")

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    screenshot_image = gr.Image(
                        label="Interactive Screenshot",
                        interactive=False,
                        height=600
                        label="Interactive Screenshot", interactive=False, height=600
                    )

                # Action type selection for image clicks (wrapped for visibility control)
                with gr.Group(visible=False) as click_actions_group:
                    with gr.Row():
                        action_type_radio = gr.Dropdown(
                            label="Interactive Action",
                            choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down", "scroll"],
                            choices=[
                                "click",
                                "double_click",
                                "move",
                                "left_mouse_up",
                                "left_mouse_down",
                                "scroll",
                            ],
                            value="click",
                            scale=2
                            scale=2,
                        )
                        action_button_radio = gr.Dropdown(
                            label="Button",
                            choices=["left", "right", "wheel", "back", "forward"],
                            value="left",
                            visible=True,
                            scale=1
                            scale=1,
                        )
                        scroll_x_input = gr.Number(
                            label="scroll_x",
                            value=0,
                            visible=False,
                            scale=1
                            label="scroll_x", value=0, visible=False, scale=1
                        )
                        scroll_y_input = gr.Number(
                            label="scroll_y",
                            value=-120,
                            visible=False,
                            scale=1
                            label="scroll_y", value=-120, visible=False, scale=1
                        )

                conversation_chatbot = gr.Chatbot(
                    label="Conversation",
                    type="messages",
                    height=500,
                    show_copy_button=True
                    label="Conversation", type="messages", height=500, show_copy_button=True
                )

            with gr.Column(scale=1):
                with gr.Group():
                    call_dropdown = gr.Dropdown(
                        label="Select a pending conversation request",
                        choices=["latest"],
                        interactive=True,
                        value="latest"
                        value="latest",
                    )
                    refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
                    status_display = gr.Textbox(
                        label="Status",
                        interactive=False,
                        value="Ready to receive requests..."
                        label="Status", interactive=False, value="Ready to receive requests..."
                    )

                with gr.Group():
                    response_text = gr.Textbox(
                        label="Message",
                        lines=3,
                        placeholder="Enter your message here..."
                        label="Message", lines=3, placeholder="Enter your message here..."
                    )
                    submit_btn = gr.Button("📤 Submit Message", variant="primary", interactive=False)

                    submit_btn = gr.Button(
                        "📤 Submit Message", variant="primary", interactive=False
                    )

                # Action Accordions (wrapped for visibility control)
                with gr.Group(visible=False) as actions_group:
                    with gr.Tabs():
@@ -507,58 +506,73 @@ def create_ui():
                            with gr.Group():
                                description_text = gr.Textbox(
                                    label="Element Description",
                                    placeholder="e.g., 'Privacy and security option in left sidebar'"
                                    placeholder="e.g., 'Privacy and security option in left sidebar'",
                                )
                                with gr.Row():
                                    description_action_type = gr.Dropdown(
                                        label="Action",
                                        choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down"],
                                        value="click"
                                        choices=[
                                            "click",
                                            "double_click",
                                            "move",
                                            "left_mouse_up",
                                            "left_mouse_down",
                                        ],
                                        value="click",
                                    )
                                    description_button = gr.Dropdown(
                                        label="Button",
                                        choices=["left", "right", "wheel", "back", "forward"],
                                        value="left"
                                        value="left",
                                    )
                                description_submit_btn = gr.Button("Submit Click Action")

                        with gr.Tab("📝 Type Action"):
                            with gr.Group():
                                type_text = gr.Textbox(
                                    label="Text to Type",
                                    placeholder="Enter text to type..."
                                    label="Text to Type", placeholder="Enter text to type..."
                                )
                                type_submit_btn = gr.Button("Submit Type")

                        with gr.Tab("⌨️ Keypress Action"):
                            with gr.Group():
                                keypress_text = gr.Textbox(
                                    label="Keys",
                                    placeholder="e.g., ctrl+c, alt+tab"
                                    label="Keys", placeholder="e.g., ctrl+c, alt+tab"
                                )
                                keypress_submit_btn = gr.Button("Submit Keypress")

                        with gr.Tab("🧰 Misc Actions"):
                            with gr.Group():
                                misc_action_dropdown = gr.Dropdown(
                                    label="Action",
                                    choices=["wait"],
                                    value="wait"
                                    label="Action", choices=["wait"], value="wait"
                                )
                                misc_submit_btn = gr.Button("Submit Action")

        # Event handlers
        refresh_btn.click(
            fn=ui_handler.refresh_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        call_dropdown.change(
            fn=ui_handler.on_call_selected,
            inputs=[call_dropdown],
            outputs=[screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        def handle_image_click(evt: gr.SelectData):
            if evt.index is not None:
                x, y = evt.index
@@ -568,31 +582,44 @@ def create_ui():
                    sx_i = int(ui_handler.current_scroll_x or 0)
                    sy_i = int(ui_handler.current_scroll_y or 0)
                    # Submit a scroll action with x,y position and scroll deltas
                    result = ui_handler.submit_action("scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i)
                    result = ui_handler.submit_action(
                        "scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i
                    )
                else:
                    result = ui_handler.submit_click_action(x, y, action_type, button)
                ui_handler.wait_for_pending_calls()
                return result
            return "No coordinates selected"

        screenshot_image.select(
            fn=handle_image_click,
            outputs=[status_display]
        ).then(
        screenshot_image.select(fn=handle_image_click, outputs=[status_display]).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        # Response submission
        submit_btn.click(
            fn=ui_handler.submit_response,
            inputs=[response_text],
            outputs=[response_text, status_display]
            outputs=[response_text, status_display],
        ).then(
            fn=ui_handler.refresh_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        # Toggle visibility of controls based on action type
        def toggle_action_controls(action_type):
            # Button visible only for click
@@ -603,59 +630,63 @@ def create_ui():
            # Update state
            ui_handler.current_action_type = action_type or "click"
            return button_vis, scroll_x_vis, scroll_y_vis

        action_type_radio.change(
            fn=toggle_action_controls,
            inputs=[action_type_radio],
            outputs=[action_button_radio, scroll_x_input, scroll_y_input]
            outputs=[action_button_radio, scroll_x_input, scroll_y_input],
        )

        # Keep other control values in ui_handler state
        def on_button_change(val):
            ui_handler.current_button = (val or "left")
        action_button_radio.change(
            fn=on_button_change,
            inputs=[action_button_radio]
        )
            ui_handler.current_button = val or "left"

        action_button_radio.change(fn=on_button_change, inputs=[action_button_radio])

        def on_scroll_x_change(val):
            try:
                ui_handler.current_scroll_x = int(val) if val is not None else 0
            except Exception:
                ui_handler.current_scroll_x = 0
        scroll_x_input.change(
            fn=on_scroll_x_change,
            inputs=[scroll_x_input]
        )

        scroll_x_input.change(fn=on_scroll_x_change, inputs=[scroll_x_input])

        def on_scroll_y_change(val):
            try:
                ui_handler.current_scroll_y = int(val) if val is not None else 0
            except Exception:
                ui_handler.current_scroll_y = 0
        scroll_y_input.change(
            fn=on_scroll_y_change,
            inputs=[scroll_y_input]
        )

        scroll_y_input.change(fn=on_scroll_y_change, inputs=[scroll_y_input])

        type_submit_btn.click(
            fn=ui_handler.submit_type_action,
            inputs=[type_text],
            outputs=[status_display]
            fn=ui_handler.submit_type_action, inputs=[type_text], outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        keypress_submit_btn.click(
            fn=ui_handler.submit_hotkey_action,
            inputs=[keypress_text],
            outputs=[status_display]
            fn=ui_handler.submit_hotkey_action, inputs=[keypress_text], outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        def handle_description_submit(description, action_type, button):
            if description:
                result = ui_handler.submit_description_click(description, action_type, button)
@@ -666,12 +697,19 @@ def create_ui():
        description_submit_btn.click(
            fn=handle_description_submit,
            inputs=[description_text, description_action_type, description_button],
            outputs=[status_display]
            outputs=[status_display],
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        # Misc action handler
        def handle_misc_submit(selected_action):
            if selected_action == "wait":
@@ -681,20 +719,32 @@ def create_ui():
            return f"Unsupported misc action: {selected_action}"

        misc_submit_btn.click(
            fn=handle_misc_submit,
            inputs=[misc_action_dropdown],
            outputs=[status_display]
            fn=handle_misc_submit, inputs=[misc_action_dropdown], outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

        # Load initial data
        demo.load(
            fn=ui_handler.refresh_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
            outputs=[
                call_dropdown,
                screenshot_image,
                conversation_chatbot,
                submit_btn,
                click_actions_group,
                actions_group,
            ],
        )

    return demo
@@ -8,21 +8,22 @@ Exports:
- run_full_dataset(dataset, ...)
- MCPComputerAgent
"""

import time
from typing import Any, Optional

from agent.computers import is_agent_computer
from datasets import load_dataset, Dataset
from hud.datasets import Task, run_dataset
from datasets import Dataset, load_dataset
from hud import trace
from hud.datasets import Task, run_dataset

from .agent import MCPComputerAgent

# ---------------------------------------------------------------------------
# Single-task runner
# ---------------------------------------------------------------------------


async def run_single_task(
    dataset: str | Dataset | list[dict[str, Any]],
    *,
@@ -47,24 +48,20 @@ async def run_single_task(

    # Load dataset and pick a sample
    if isinstance(dataset, str):
        dataset = load_dataset(dataset, split="train") # type: ignore[arg-type]
        dataset = load_dataset(dataset, split="train")  # type: ignore[arg-type]
    elif isinstance(dataset, list):
        dataset = dataset
    else:
        dataset = dataset["train"]

    sample_task = dataset[task_id]  # type: ignore[index]
    task_prompt = sample_task.get("prompt", f"Task {sample_task.get('id', 0)}")  # type: ignore[attr-defined]

    # Filter any existing Computer tools
    # The eval framework will add its own Computer tool per task
    if tools:
        tools = [
            tool
            for tool in tools
            if not is_agent_computer(tool)
        ]
        tools = [tool for tool in tools if not is_agent_computer(tool)]

    with trace(name=task_prompt):
        task = Task(**sample_task)  # type: ignore[arg-type]

@@ -87,13 +84,14 @@ async def run_single_task(
        )
        print(f"Running: {task_prompt}")
        result = await agent.run(task, max_steps=10)
        print(f"✅ Reward: {getattr(result, 'reward')}")
        print(f"✅ Reward: {result.reward}")


# ---------------------------------------------------------------------------
# Full-dataset runner
# ---------------------------------------------------------------------------


async def run_full_dataset(
    dataset: str | Dataset | list[dict[str, Any]],
    *,
@@ -121,9 +119,9 @@ async def run_full_dataset(

    # Run with our MCP-based agent class.
    if isinstance(dataset, str):
        dataset_name = dataset.split('/')[-1]
        dataset_name = dataset.split("/")[-1]
        job_name = job_name or f"Evaluation {dataset_name}"
        dataset = load_dataset(dataset, split=split) # type: ignore[arg-type]
        dataset = load_dataset(dataset, split=split)  # type: ignore[arg-type]
    else:
        dataset_name = "custom"
        job_name = job_name or f"Evaluation {time.strftime('%H:%M %Y-%m-%d')}"
@@ -131,12 +129,8 @@ async def run_full_dataset(
    # Filter any existing Computer tools
    # The eval framework will add its own Computer tool per task
    if tools:
        tools = [
            tool
            for tool in tools
            if not is_agent_computer(tool)
        ]
        tools = [tool for tool in tools if not is_agent_computer(tool)]

    # Execute evaluation
    return await run_dataset(
        name=job_name,
@@ -170,4 +164,4 @@ __all__ = [
    "run_single_task",
    "run_full_dataset",
    "MCPComputerAgent",
]
]
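A sketch of driving the runner above; the module path, dataset identifier, and any keyword arguments are placeholders, not values from this diff:

import asyncio

# run_single_task is exported via __all__ above; import path assumed.
# Keyword-only arguments (after the `*`) are omitted because the full
# signature is not shown in this hunk.
asyncio.run(run_single_task("hud-evals/example-dataset"))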
@@ -9,26 +9,26 @@ Key differences from the OpenAI OperatorAgent variant:
- Planning is executed via `ComputerAgent.run(messages)`.
- The first yielded result per step is returned as the agent response.
"""

from __future__ import annotations

import base64
import io
import uuid
from pathlib import Path
from typing import Any, ClassVar, Optional

import hud
import mcp.types as types
from agent.agent import ComputerAgent as BaseComputerAgent
from agent.callbacks import PromptInstructionsCallback
from agent.callbacks.trajectory_saver import TrajectorySaverCallback
from agent.computers import is_agent_computer
from agent.responses import make_failed_tool_call_items
from hud.agents import MCPAgent
from hud.tools.computer.settings import computer_settings
from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace

from agent.responses import make_failed_tool_call_items
from agent.computers import is_agent_computer
from PIL import Image
import mcp.types as types
import hud
import uuid
import base64
from pathlib import Path


class MCPComputerAgent(MCPAgent):
@@ -114,8 +114,10 @@ class MCPComputerAgent(MCPAgent):
        self.last_screenshot_b64 = None

        buffer = io.BytesIO()
        Image.new('RGB', (self.metadata["display_width"], self.metadata["display_height"])).save(buffer, format='PNG')
        self.last_screenshot_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        Image.new("RGB", (self.metadata["display_width"], self.metadata["display_height"])).save(
            buffer, format="PNG"
        )
        self.last_screenshot_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        # Ensure a computer shim is present so width/height/environment are known
        computer_shim = {
@@ -128,12 +130,8 @@ class MCPComputerAgent(MCPAgent):
        }
        agent_tools: list[Any] = [computer_shim]
        if tools:
            agent_tools.extend([
                tool
                for tool in tools
                if not is_agent_computer(tool)
            ])
            agent_tools.extend([tool for tool in tools if not is_agent_computer(tool)])

        agent_kwargs = {
            "model": self.model,
            "trajectory_dir": trajectory_dir,
@@ -150,9 +148,7 @@ class MCPComputerAgent(MCPAgent):
            "telemetry_enabled": telemetry_enabled,
        }

        self.computer_agent = BaseComputerAgent(
            **agent_kwargs
        )
        self.computer_agent = BaseComputerAgent(**agent_kwargs)

    async def get_system_messages(self) -> list[Any]:
        """Create initial messages.

@@ -161,9 +157,7 @@ class MCPComputerAgent(MCPAgent):
        """
        return []

    async def format_blocks(
        self, blocks: list[types.ContentBlock]
    ) -> list[dict[str, Any]]:
    async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str, Any]]:
        """
        Format blocks for OpenAI input format.

@@ -200,42 +194,49 @@ class MCPComputerAgent(MCPAgent):

        # Call the ComputerAgent LLM API
        async for result in self.computer_agent.run(messages):  # type: ignore[arg-type]
            items = result['output']
            items = result["output"]
            if not items or tool_calls:
                break

            for item in items:
                if item['type'] in ['reasoning', 'message', 'computer_call', 'function_call', 'function_call_output']:
                if item["type"] in [
                    "reasoning",
                    "message",
                    "computer_call",
                    "function_call",
                    "function_call_output",
                ]:
                    agent_result.append(item)

                # Add messages to output text
                if item['type'] == 'reasoning':
                if item["type"] == "reasoning":
                    output_text.extend(
                        f"Reasoning: {summary['text']}"
                        for summary in item['summary']
                        f"Reasoning: {summary['text']}" for summary in item["summary"]
                    )
                elif item['type'] == 'message':
                    if isinstance(item['content'], list):
                elif item["type"] == "message":
                    if isinstance(item["content"], list):
                        output_text.extend(
                            item['text']
                            for item in item['content']
                            if item['type'] == 'output_text'
                            item["text"]
                            for item in item["content"]
                            if item["type"] == "output_text"
                        )
                    elif isinstance(item['content'], str):
                        output_text.append(item['content'])
                    elif isinstance(item["content"], str):
                        output_text.append(item["content"])

                # If we get a tool call, we're not done
                if item['type'] == 'computer_call':
                if item["type"] == "computer_call":
                    id = item["call_id"]
                    tool_calls.append(MCPToolCall(
                        name="openai_computer",
                        arguments=item["action"],
                        id=id,
                    ))
                    tool_calls.append(
                        MCPToolCall(
                            name="openai_computer",
                            arguments=item["action"],
                            id=id,
                        )
                    )
                    is_done = False
                    self.tool_call_inputs[id] = agent_result
                    break

            # if we have tool calls, we should exit the loop
            if tool_calls:
                break
@@ -247,7 +248,7 @@ class MCPComputerAgent(MCPAgent):
            tool_calls=tool_calls,
            done=is_done,
        )

    def _log_image(self, image_b64: str):
        callbacks = self.computer_agent.callbacks
        for callback in callbacks:
@@ -257,9 +258,7 @@ class MCPComputerAgent(MCPAgent):
                callback._save_artifact("screenshot_after", image_bytes)

    async def format_tool_results(
        self,
        tool_calls: list[MCPToolCall],
        tool_results: list[MCPToolResult]
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[dict[str, Any]]:
        """Extract latest screenshot from tool results in dict form.

@@ -274,45 +273,60 @@ class MCPComputerAgent(MCPAgent):
            previous_output = self.previous_output.copy() or []

            # First we need to remove any pending computer_calls from the end of previous_output
            while previous_output and previous_output[-1]['type'] == 'computer_call':
            while previous_output and previous_output[-1]["type"] == "computer_call":
                previous_output.pop()
            messages.extend(previous_output)

            # If the call is a 'response', don't add the result
            if call.name == 'response':
            if call.name == "response":
                continue
            # Otherwise, if we have a result, we should add it to the messages
            content = [
                { "type": "input_text", "text": content.text } if isinstance(content, types.TextContent)
                else { "type": "input_image", "image_url": f"data:image/png;base64,{content.data}" } if isinstance(content, types.ImageContent)
                else { "type": "input_text", "text": "" }
                (
                    {"type": "input_text", "text": content.text}
                    if isinstance(content, types.TextContent)
                    else (
                        {
                            "type": "input_image",
                            "image_url": f"data:image/png;base64,{content.data}",
                        }
                        if isinstance(content, types.ImageContent)
                        else {"type": "input_text", "text": ""}
                    )
                )
                for content in result.content
            ]
            messages.append({
                "role": "user",
                "content": content,
            })
            messages.append(
                {
                    "role": "user",
                    "content": content,
                }
            )

            continue

            # Add the assistant's computer call
            messages.extend(self.tool_call_inputs[call.id])

            if result.isError:
                error_text = "".join([
                    content.text
                    for content in result.content
                    if isinstance(content, types.TextContent)
                ])
                error_text = "".join(
                    [
                        content.text
                        for content in result.content
                        if isinstance(content, types.TextContent)
                    ]
                )

                # Replace computer call with failed tool call
                messages.pop()
                messages.extend(make_failed_tool_call_items(
                    tool_name=call.name,
                    tool_kwargs=call.arguments or {},
                    error_message=error_text,
                    call_id=call.id,
                ))
                messages.extend(
                    make_failed_tool_call_items(
                        tool_name=call.name,
                        tool_kwargs=call.arguments or {},
                        error_message=error_text,
                        call_id=call.id,
                    )
                )
            else:
                # Get the latest screenshot
                screenshots = [
@@ -325,23 +339,27 @@ class MCPComputerAgent(MCPAgent):
                if screenshots:
                    self._log_image(screenshots[0])
                    self.last_screenshot_b64 = screenshots[0]
                    messages.append({
                        "type": "computer_call_output",
                        "call_id": call.id,
                        "output": {
                            "type": "input_image",
                            "image_url": f"data:image/png;base64,{screenshots[0]}"
                        },
                    })
                    messages.append(
                        {
                            "type": "computer_call_output",
                            "call_id": call.id,
                            "output": {
                                "type": "input_image",
                                "image_url": f"data:image/png;base64,{screenshots[0]}",
                            },
                        }
                    )
                else:
                    # Otherwise, replace computer call with failed tool call
                    messages.pop()
                    messages.extend(make_failed_tool_call_items(
                        tool_name=call.name,
                        tool_kwargs=call.arguments or {},
                        error_message="No screenshots returned.",
                        call_id=call.id,
                    ))
                    messages.extend(
                        make_failed_tool_call_items(
                            tool_name=call.name,
                            tool_kwargs=call.arguments or {},
                            error_message="No screenshots returned.",
                            call_id=call.id,
                        )
                    )

        return messages

@@ -7,30 +7,33 @@ OpenAI-like response blocks. We intentionally only support a single-step call
by consuming the first yielded result from `ComputerAgent.run()`.
"""

import traceback
import time
import traceback
import uuid
from typing import Any, Dict, List, Optional

from agent.agent import ComputerAgent as BaseComputerAgent
from agent.callbacks import PromptInstructionsCallback
from hud.tools.computer.settings import computer_settings
from PIL import Image
from hud.agents import OperatorAgent
from hud.tools.computer.settings import computer_settings

# OpenAI Responses typed models (required)
from openai.types.responses import (
    Response,
    ResponseComputerToolCall,
    ResponseInputParam,
    ResponseOutputItem,
    ResponseComputerToolCall,
    ResponseOutputMessage,
    ResponseOutputText,
    ResponseReasoningItem,
    ResponseUsage,
)
from PIL import Image

def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> List[ResponseOutputItem]:

def _map_agent_output_to_openai_blocks(
    output_items: List[Dict[str, Any]],
) -> List[ResponseOutputItem]:
    """Map our agent output items to OpenAI ResponseOutputItem typed models.

    Only a subset is supported: computer_call, assistant message (text), and reasoning.
@@ -40,14 +43,16 @@ def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> Li
    for item in output_items or []:
        t = item.get("type")
        if t == "computer_call":
            comp = ResponseComputerToolCall.model_validate({
                "id": item.get("id") or f"cu_{uuid.uuid4().hex}",
                "type": "computer_call",
                "call_id": item["call_id"],
                "action": item["action"],
                "pending_safety_checks": item.get("pending_safety_checks", []),
                "status": "completed",
            })
            comp = ResponseComputerToolCall.model_validate(
                {
                    "id": item.get("id") or f"cu_{uuid.uuid4().hex}",
                    "type": "computer_call",
                    "call_id": item["call_id"],
                    "action": item["action"],
                    "pending_safety_checks": item.get("pending_safety_checks", []),
                    "status": "completed",
                }
            )
            blocks.append(comp)
            # we will exit early here as the responses api only supports a single step
            break
@@ -55,31 +60,38 @@ def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> Li
            content_blocks: List[ResponseOutputText] = []
            for c in item.get("content", []) or []:
                content_blocks.append(
                    ResponseOutputText.model_validate({
                        "type": "output_text",
                        "text": c["text"],
                        "annotations": [],
                    })
                    ResponseOutputText.model_validate(
                        {
                            "type": "output_text",
                            "text": c["text"],
                            "annotations": [],
                        }
                    )
                )
            if content_blocks:
                msg = ResponseOutputMessage.model_validate({
                    "id": item.get("id") or f"msg_{uuid.uuid4()}",
                    "type": "message",
                    "role": "assistant",
                    "status": "completed",
                    "content": [ct.model_dump() for ct in content_blocks],
                })
                msg = ResponseOutputMessage.model_validate(
                    {
                        "id": item.get("id") or f"msg_{uuid.uuid4()}",
                        "type": "message",
                        "role": "assistant",
                        "status": "completed",
                        "content": [ct.model_dump() for ct in content_blocks],
                    }
                )
                blocks.append(msg)
        elif t == "reasoning":
            reasoning = ResponseReasoningItem.model_validate({
                "id": item.get("id") or f"rsn_{uuid.uuid4()}",
                "type": "reasoning",
                "summary": item["summary"],
            })
            reasoning = ResponseReasoningItem.model_validate(
                {
                    "id": item.get("id") or f"rsn_{uuid.uuid4()}",
                    "type": "reasoning",
                    "summary": item["summary"],
                }
            )
            blocks.append(reasoning)
        # Unhandled types are ignored
    return blocks


def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
    out: List[Dict[str, Any]] = []
    for it in list(items):
@@ -92,6 +104,7 @@ def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
        out.append(dict(it))  # may raise if not mapping
    return out


class FakeAsyncOpenAI:
    """Minimal fake OpenAI client with only `responses.create` implemented.

@@ -132,10 +145,12 @@ class FakeAsyncOpenAI:
        # Pre-pend instructions message
        effective_input = full_input
        if instructions:
            effective_input = [{
                "role": "user",
                "content": instructions,
            }] + full_input
            effective_input = [
                {
                    "role": "user",
                    "content": instructions,
                }
            ] + full_input

        # Run a single iteration of the ComputerAgent
        agent_result: Optional[Dict[str, Any]] = None
@@ -152,32 +167,43 @@ class FakeAsyncOpenAI:
            blocks_to_cache = full_input + output
            for b in blocks_to_cache:
                bid = getattr(b, "id", None) or f"tmp-{hash(repr(b))}"
                self.blocks_cache[bid] = b # type: ignore[assignment]
                self.blocks_cache[bid] = b  # type: ignore[assignment]
                block_ids.append(bid)
            response_id = agent_result.get("id") or f"fake-{int(time.time()*1000)}"
            self.context_cache[response_id] = block_ids

            try:
                return Response.model_validate({
                    "id": response_id,
                    "created_at": time.time(),
                    "object": "response",
                    "model": model,
                    "output": output,
                    "parallel_tool_calls": False,
                    "tool_choice": "auto",
                    "tools": [],
                    "previous_response_id": previous_response_id,
                    "usage": ResponseUsage.model_validate({
                        "input_tokens": usage.get("input_tokens", 0),
                        "output_tokens": usage.get("output_tokens", 0),
                        "total_tokens": usage.get("total_tokens", 0),
                        "input_tokens_details": usage.get("input_tokens_details", { "cached_tokens": 0 }),
                        "output_tokens_details": usage.get("output_tokens_details", { "reasoning_tokens": 0 }),
                    }),
                })
                return Response.model_validate(
                    {
                        "id": response_id,
                        "created_at": time.time(),
                        "object": "response",
                        "model": model,
                        "output": output,
                        "parallel_tool_calls": False,
                        "tool_choice": "auto",
                        "tools": [],
                        "previous_response_id": previous_response_id,
                        "usage": ResponseUsage.model_validate(
                            {
                                "input_tokens": usage.get("input_tokens", 0),
                                "output_tokens": usage.get("output_tokens", 0),
                                "total_tokens": usage.get("total_tokens", 0),
                                "input_tokens_details": usage.get(
                                    "input_tokens_details", {"cached_tokens": 0}
                                ),
                                "output_tokens_details": usage.get(
                                    "output_tokens_details", {"reasoning_tokens": 0}
                                ),
                            }
                        ),
                    }
                )
            except Exception as e:
                print(f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ", e)
                print(
                    f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ",
                    e,
                )
                if attempt == max_retries - 1:
                    print(traceback.format_exc())
                    raise e
@@ -221,9 +247,15 @@ class ProxyOperatorAgent(OperatorAgent):
        allowed_tools = allowed_tools or ["openai_computer"]

        computer_shim = {
            'screenshot': lambda: Image.new('RGB', (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)),
            'environment': 'linux',
            'dimensions': (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)
            "screenshot": lambda: Image.new(
                "RGB",
                (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT),
            ),
            "environment": "linux",
            "dimensions": (
                computer_settings.OPENAI_COMPUTER_WIDTH,
                computer_settings.OPENAI_COMPUTER_HEIGHT,
            ),
        }
        # Build tools ensuring the computer_shim is included
        agent_tools: list[Any] = [computer_shim]
@@ -258,6 +290,7 @@ class ProxyOperatorAgent(OperatorAgent):
            **kwargs,
        )


__all__ = [
    "FakeAsyncOpenAI",
    "ProxyOperatorAgent",
@@ -3,26 +3,34 @@ Agent loops for agent
|
||||
"""
|
||||
|
||||
# Import the loops to register them
|
||||
from . import anthropic
|
||||
from . import openai
|
||||
from . import uitars
|
||||
from . import omniparser
|
||||
from . import gta1
|
||||
from . import composed_grounded
|
||||
from . import glm45v
|
||||
from . import opencua
|
||||
from . import internvl
|
||||
from . import holo
|
||||
from . import (
|
||||
anthropic,
|
||||
composed_grounded,
|
||||
gemini,
|
||||
glm45v,
|
||||
gta1,
|
||||
holo,
|
||||
internvl,
|
||||
moondream3,
|
||||
omniparser,
|
||||
openai,
|
||||
opencua,
|
||||
qwen,
|
||||
uitars,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"uitars",
|
||||
"omniparser",
|
||||
"gta1",
|
||||
"composed_grounded",
|
||||
"glm45v",
|
||||
"anthropic",
|
||||
"openai",
|
||||
"uitars",
|
||||
"omniparser",
|
||||
"gta1",
|
||||
"composed_grounded",
|
||||
"glm45v",
|
||||
"opencua",
|
||||
"internvl",
|
||||
"holo",
|
||||
]
|
||||
"moondream3",
|
||||
"gemini",
|
||||
"qwen",
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
@@ -2,13 +2,15 @@
Base protocol for async agent configurations
"""

from typing import Protocol, List, Dict, Any, Optional, Tuple, Union
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

from ..types import AgentCapability


class AsyncAgentConfig(Protocol):
    """Protocol defining the interface for async agent configurations."""

    @abstractmethod
    async def predict_step(
        self,
@@ -22,11 +24,11 @@ class AsyncAgentConfig(Protocol):
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Predict the next step based on input items.

        Args:
            messages: Input items following Responses format (message, function_call, computer_call)
            model: Model name to use
@@ -39,37 +41,34 @@ class AsyncAgentConfig(Protocol):
            _on_usage: Callback for usage tracking
            _on_screenshot: Callback for screenshot events
            **kwargs: Additional arguments

        Returns:
            Dictionary with "output" (output items) and "usage" array
        """
        ...

    @abstractmethod
    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str
        self, model: str, image_b64: str, instruction: str
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.

        Args:
            model: Model name to use
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            None or tuple with (x, y) coordinates
        """
        ...

    @abstractmethod
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by this agent config.

        Returns:
            List of capability strings (e.g., ["step", "click"])
        """

@@ -3,122 +3,117 @@ Composed-grounded agent loop implementation that combines grounding and thinking
Uses a two-stage approach: grounding model for element detection, thinking model for reasoning.
"""

import uuid
import asyncio
import json
import base64
from typing import Dict, List, Any, Optional, Tuple
import json
import uuid
from io import BytesIO
from PIL import Image
import litellm
from typing import Any, Dict, List, Optional, Tuple

import litellm
from PIL import Image

from ..agent import find_agent_config
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..responses import (
    convert_computer_calls_xy2desc,
    convert_responses_items_to_completion_messages,
    convert_completion_messages_to_responses_items,
    convert_computer_calls_desc2xy,
    get_all_element_descriptions
    convert_computer_calls_xy2desc,
    convert_responses_items_to_completion_messages,
    get_all_element_descriptions,
)
from ..agent import find_agent_config
from ..types import AgentCapability, AgentResponse, Messages, Tools

GROUNDED_COMPUTER_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "computer",
        "description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
        "parameters": {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "screenshot",
                        "click",
                        "double_click",
                        "drag",
                        "type",
                        "keypress",
                        "scroll",
                        "move",
                        "wait",
                        "get_current_url",
                        "get_dimensions",
                        "get_environment"
                    ],
                    "description": "The action to perform (required for all actions)"
                },
                "element_description": {
                    "type": "string",
                    "description": "Description of the element to interact with (required for click, double_click, move, scroll actions)"
                },
                "start_element_description": {
                    "type": "string",
                    "description": "Description of the element to start dragging from (required for drag action)"
                },
                "end_element_description": {
                    "type": "string",
                    "description": "Description of the element to drag to (required for drag action)"
                },
                "text": {
                    "type": "string",
                    "description": "The text to type (required for type action)"
                },
                "keys": {
                    "type": "array",
                    "items": {
                        "type": "string"
    "type": "function",
    "function": {
        "name": "computer",
        "description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
        "parameters": {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "screenshot",
                        "click",
                        "double_click",
                        "drag",
                        "type",
                        "keypress",
                        "scroll",
                        "move",
                        "wait",
                        "get_current_url",
                        "get_dimensions",
                        "get_environment",
                    ],
                    "description": "The action to perform (required for all actions)",
                },
                "element_description": {
                    "type": "string",
                    "description": "Description of the element to interact with (required for click, double_click, move, scroll actions)",
                },
                "start_element_description": {
                    "type": "string",
                    "description": "Description of the element to start dragging from (required for drag action)",
                },
                "end_element_description": {
                    "type": "string",
                    "description": "Description of the element to drag to (required for drag action)",
                },
                "text": {
                    "type": "string",
                    "description": "The text to type (required for type action)",
                },
                "keys": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Key(s) to press (required for keypress action)",
                },
                "button": {
                    "type": "string",
                    "enum": ["left", "right", "wheel", "back", "forward"],
                    "description": "The mouse button to use for click action (required for click and double_click action)",
                },
                "scroll_x": {
                    "type": "integer",
                    "description": "Horizontal scroll amount for scroll action (required for scroll action)",
                },
                "scroll_y": {
                    "type": "integer",
                    "description": "Vertical scroll amount for scroll action (required for scroll action)",
                },
            },
                    "description": "Key(s) to press (required for keypress action)"
            "required": ["action"],
        },
                "button": {
                    "type": "string",
                    "enum": [
                        "left",
                        "right",
                        "wheel",
                        "back",
                        "forward"
                    ],
                    "description": "The mouse button to use for click action (required for click and double_click action)",
                },
                "scroll_x": {
                    "type": "integer",
                    "description": "Horizontal scroll amount for scroll action (required for scroll action)",
                },
                "scroll_y": {
                    "type": "integer",
                    "description": "Vertical scroll amount for scroll action (required for scroll action)",
                },
            },
            "required": [
                "action"
            ]
        }
    }
    },
}


def _prepare_tools_for_grounded(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Prepare tools for grounded API format"""
    grounded_tools = []

    for schema in tool_schemas:
        if schema["type"] == "computer":
            grounded_tools.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
        else:
            grounded_tools.append(schema)

    return grounded_tools
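
# Usage sketch (illustrative, not part of the commit; assumes the two
# definitions above are in scope): the computer tool schema is swapped for the
# description-based variant while every other tool schema passes through.
#
#     example_tools = [
#         {"type": "computer"},  # replaced by GROUNDED_COMPUTER_TOOL_SCHEMA
#         {"type": "function", "function": {"name": "save_note"}},  # untouched
#     ]
#     prepared = _prepare_tools_for_grounded(example_tools)
#     assert prepared[0] is GROUNDED_COMPUTER_TOOL_SCHEMA
#     assert prepared[1]["function"]["name"] == "save_note"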


def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str]:
    """Get the last computer call output image from messages."""
    for message in reversed(messages):
        if (isinstance(message, dict) and
            message.get("type") == "computer_call_output" and
            isinstance(message.get("output"), dict) and
            message["output"].get("type") == "input_image"):
        if (
            isinstance(message, dict)
            and message.get("type") == "computer_call_output"
            and isinstance(message.get("output"), dict)
            and message["output"].get("type") == "input_image"
        ):
            image_url = message["output"].get("image_url", "")
            if image_url.startswith("data:image/png;base64,"):
                return image_url.split(",", 1)[1]
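
# Behavior sketch (example input is invented): the scan runs newest-to-oldest
# and returns only the base64 payload of a PNG data URL, so given
#
#     example_messages = [
#         {"type": "message", "role": "user", "content": "open the settings page"},
#         {
#             "type": "computer_call_output",
#             "call_id": "call_1",
#             "output": {"type": "input_image", "image_url": "data:image/png;base64,iVBORw0KGgo="},
#         },
#     ]
#
# the helper yields the payload after the comma:
#
#     assert get_last_computer_call_image(example_messages) == "iVBORw0KGgo="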
@@ -129,14 +124,14 @@ def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str
class ComposedGroundedConfig(AsyncAgentConfig):
    """
    Composed-grounded agent configuration that uses both grounding and thinking models.

    The model parameter should be in format: "grounding_model+thinking_model"
    e.g., "huggingface-local/HelloKKMe/GTA1-7B+gemini/gemini-1.5-pro"
    """

    def __init__(self):
        self.desc2xy: Dict[str, Tuple[float, float]] = {}

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
@@ -150,11 +145,11 @@ class ComposedGroundedConfig(AsyncAgentConfig):
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Composed-grounded predict step implementation.

        Process:
        0. Store last computer call image, if none then take a screenshot
        1. Convert computer calls from xy to descriptions
@@ -167,18 +162,20 @@ class ComposedGroundedConfig(AsyncAgentConfig):
        """
        # Parse the composed model
        if "+" not in model:
            raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
            raise ValueError(
                f"Composed model must be in format 'grounding_model+thinking_model', got: {model}"
            )
        grounding_model, thinking_model = model.split("+", 1)

        pre_output_items = []

        # Step 0: Store last computer call image, if none then take a screenshot
        last_image_b64 = get_last_computer_call_image(messages)
        if last_image_b64 is None:
            # Take a screenshot
            screenshot_b64 = await computer_handler.screenshot() # type: ignore
            screenshot_b64 = await computer_handler.screenshot()  # type: ignore
            if screenshot_b64:

                call_id = uuid.uuid4().hex
                pre_output_items += [
                    {
@@ -187,45 +184,42 @@ class ComposedGroundedConfig(AsyncAgentConfig):
                        "content": [
                            {
                                "type": "output_text",
                                "text": "Taking a screenshot to see the current computer screen."
                                "text": "Taking a screenshot to see the current computer screen.",
                            }
                        ]
                        ],
                    },
                    {
                        "action": {
                            "type": "screenshot"
                        },
                        "action": {"type": "screenshot"},
                        "call_id": call_id,
                        "status": "completed",
                        "type": "computer_call"
                        "type": "computer_call",
                    },
                    {
                        "type": "computer_call_output",
                        "call_id": call_id,
                        "output": {
                            "type": "input_image",
                            "image_url": f"data:image/png;base64,{screenshot_b64}"
                        }
                            "image_url": f"data:image/png;base64,{screenshot_b64}",
                        },
                    },
                ]
                last_image_b64 = screenshot_b64

                # Call screenshot callback if provided
                if _on_screenshot:
                    await _on_screenshot(screenshot_b64)

        tool_schemas = _prepare_tools_for_grounded(tools) # type: ignore

        tool_schemas = _prepare_tools_for_grounded(tools)  # type: ignore

        # Step 1: Convert computer calls from xy to descriptions
        input_messages = messages + pre_output_items
        messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)

        # Step 2: Convert responses items to completion messages
        completion_messages = convert_responses_items_to_completion_messages(
            messages_with_descriptions,
            allow_images_in_tool_results=False
            messages_with_descriptions, allow_images_in_tool_results=False
        )

        # Step 3: Call thinking model with litellm.acompletion
        api_kwargs = {
            "model": thinking_model,
@@ -233,98 +227,90 @@ class ComposedGroundedConfig(AsyncAgentConfig):
            "tools": tool_schemas,
            "max_retries": max_retries,
            "stream": stream,
            **kwargs
            **kwargs,
        }

        if use_prompt_caching:
            api_kwargs["use_prompt_caching"] = use_prompt_caching

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)

        # Make the completion call
        response = await litellm.acompletion(**api_kwargs)

        # Call API end hook
        if _on_api_end:
            await _on_api_end(api_kwargs, response)

        # Extract usage information
        usage = {
            **response.usage.model_dump(), # type: ignore
            **response.usage.model_dump(),  # type: ignore
            "response_cost": response._hidden_params.get("response_cost", 0.0),
        }
        if _on_usage:
            await _on_usage(usage)

        # Step 4: Convert completion messages back to responses items format
        response_dict = response.model_dump() # type: ignore
        response_dict = response.model_dump()  # type: ignore
        choice_messages = [choice["message"] for choice in response_dict["choices"]]
        thinking_output_items = []

        for choice_message in choice_messages:
            thinking_output_items.extend(convert_completion_messages_to_responses_items([choice_message]))

            thinking_output_items.extend(
                convert_completion_messages_to_responses_items([choice_message])
            )

        # Step 5: Get all element descriptions and populate desc2xy mapping
        element_descriptions = get_all_element_descriptions(thinking_output_items)

        if element_descriptions and last_image_b64:
            # Use grounding model to predict coordinates for each description
            grounding_agent_conf = find_agent_config(grounding_model)
            if grounding_agent_conf:
                grounding_agent = grounding_agent_conf.agent_class()

                for desc in element_descriptions:
                    for _ in range(3): # try 3 times
                    for _ in range(3):  # try 3 times
                        coords = await grounding_agent.predict_click(
                            model=grounding_model,
                            image_b64=last_image_b64,
                            instruction=desc
                            model=grounding_model, image_b64=last_image_b64, instruction=desc
                        )
                        if coords:
                            self.desc2xy[desc] = coords
                            break

        # Step 6: Convert computer calls from descriptions back to xy coordinates
        final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)

        # Step 7: Return output and usage
        return {
            "output": pre_output_items + final_output_items,
            "usage": usage
        }

        return {"output": pre_output_items + final_output_items, "usage": usage}

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using the grounding model.

        For composed models, uses only the grounding model part for click prediction.
        """
        # Parse the composed model to get grounding model
        if "+" not in model:
            raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
            raise ValueError(
                f"Composed model must be in format 'grounding_model+thinking_model', got: {model}"
            )
        grounding_model, thinking_model = model.split("+", 1)

        # Find and use the grounding agent
        grounding_agent_conf = find_agent_config(grounding_model)
        if grounding_agent_conf:
            grounding_agent = grounding_agent_conf.agent_class()
            return await grounding_agent.predict_click(
                model=grounding_model,
                image_b64=image_b64,
                instruction=instruction,
                **kwargs
                model=grounding_model, image_b64=image_b64, instruction=instruction, **kwargs
            )

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["click", "step"]

410
libs/python/agent/agent/loops/gemini.py
Normal file
@@ -0,0 +1,410 @@
"""
Gemini 2.5 Computer Use agent loop

Maps internal Agent SDK message format to Google's Gemini Computer Use API and back.

Key features:
- Lazy import of google.genai
- Configure Computer Use tool with excluded browser-specific predefined functions
- Optional custom function declarations hook for computer-call specific functions
- Convert Gemini function_call parts into internal computer_call actions
"""

from __future__ import annotations

import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image

from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability


def _lazy_import_genai():
    """Import google.genai lazily to avoid hard dependency unless used."""
    try:
        from google import genai  # type: ignore
        from google.genai import types  # type: ignore

        return genai, types
    except Exception as e:  # pragma: no cover
        raise RuntimeError(
            "google.genai is required for the Gemini Computer Use loop. Install the Google Gemini SDK."
        ) from e


def _data_url_to_bytes(data_url: str) -> Tuple[bytes, str]:
    """Convert a data URL to raw bytes and mime type."""
    if not data_url.startswith("data:"):
        # Assume it's base64 png payload
        try:
            return base64.b64decode(data_url), "image/png"
        except Exception:
            return b"", "application/octet-stream"
    header, b64 = data_url.split(",", 1)
    mime = "image/png"
    if ";" in header:
        mime = header.split(";")[0].split(":", 1)[1] or "image/png"
    return base64.b64decode(b64), mime
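
# Round-trip check (illustrative, not in the diff): a data URL splits into
# (raw bytes, mime type), while a bare base64 payload is assumed to be a PNG.
#
#     import base64 as _b64
#     _payload = _b64.b64encode(b"\x89PNG").decode()
#     assert _data_url_to_bytes(f"data:image/png;base64,{_payload}") == (b"\x89PNG", "image/png")
#     assert _data_url_to_bytes(_payload) == (b"\x89PNG", "image/png")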


def _bytes_image_size(img_bytes: bytes) -> Tuple[int, int]:
    try:
        img = Image.open(io.BytesIO(img_bytes))
        return img.size
    except Exception:
        return (1024, 768)


def _find_last_user_text(messages: List[Dict[str, Any]]) -> List[str]:
    texts: List[str] = []
    for msg in reversed(messages):
        if msg.get("type") in (None, "message") and msg.get("role") == "user":
            content = msg.get("content")
            if isinstance(content, str):
                return [content]
            elif isinstance(content, list):
                for c in content:
                    if c.get("type") in ("input_text", "output_text") and c.get("text"):
                        texts.append(c["text"])  # newest first
                if texts:
                    return list(reversed(texts))
    return []


def _find_last_screenshot(messages: List[Dict[str, Any]]) -> Optional[bytes]:
    for msg in reversed(messages):
        if msg.get("type") == "computer_call_output":
            out = msg.get("output", {})
            if isinstance(out, dict) and out.get("type") in ("input_image", "computer_screenshot"):
                image_url = out.get("image_url", "")
                if image_url:
                    data, _ = _data_url_to_bytes(image_url)
                    return data
    return None


def _denormalize(v: int, size: int) -> int:
    # Gemini returns 0-999 normalized
    try:
        return max(0, min(size - 1, int(round(v / 1000 * size))))
    except Exception:
        return 0
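
# Sanity sketch (illustrative): Gemini emits coordinates normalized to a
# 0-999 grid, so _denormalize scales them to pixels and clamps to the screen.
#
#     assert _denormalize(500, 1920) == 960    # mid-screen on a 1920-px axis
#     assert _denormalize(999, 1080) == 1079   # last addressable row
#     assert _denormalize(1200, 1920) == 1919  # out-of-range output is clamped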


def _map_gemini_fc_to_computer_call(
    fc: Dict[str, Any],
    screen_w: int,
    screen_h: int,
) -> Optional[Dict[str, Any]]:
    name = fc.get("name")
    args = fc.get("args", {}) or {}

    action: Dict[str, Any] = {}
    if name == "click_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        action = {"type": "click", "x": x, "y": y, "button": "left"}
    elif name == "type_text_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        text = args.get("text", "")
        if args.get("press_enter") == True:
            text += "\n"
        action = {"type": "type", "x": x, "y": y, "text": text}
    elif name == "hover_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        action = {"type": "move", "x": x, "y": y}
    elif name == "key_combination":
        keys = str(args.get("keys", ""))
        action = {"type": "keypress", "keys": keys}
    elif name == "scroll_document":
        direction = args.get("direction", "down")
        magnitude = 800
        dx, dy = 0, 0
        if direction == "down":
            dy = magnitude
        elif direction == "up":
            dy = -magnitude
        elif direction == "right":
            dx = magnitude
        elif direction == "left":
            dx = -magnitude
        action = {
            "type": "scroll",
            "scroll_x": dx,
            "scroll_y": dy,
            "x": int(screen_w / 2),
            "y": int(screen_h / 2),
        }
    elif name == "scroll_at":
        x = _denormalize(int(args.get("x", 500)), screen_w)
        y = _denormalize(int(args.get("y", 500)), screen_h)
        direction = args.get("direction", "down")
        magnitude = int(args.get("magnitude", 800))
        dx, dy = 0, 0
        if direction == "down":
            dy = magnitude
        elif direction == "up":
            dy = -magnitude
        elif direction == "right":
            dx = magnitude
        elif direction == "left":
            dx = -magnitude
        action = {"type": "scroll", "scroll_x": dx, "scroll_y": dy, "x": x, "y": y}
    elif name == "drag_and_drop":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        dx = _denormalize(int(args.get("destination_x", x)), screen_w)
        dy = _denormalize(int(args.get("destination_y", y)), screen_h)
        action = {
            "type": "drag",
            "start_x": x,
            "start_y": y,
            "end_x": dx,
            "end_y": dy,
            "button": "left",
        }
    elif name == "wait_5_seconds":
        action = {"type": "wait"}
    else:
        # Unsupported / excluded browser-specific or custom function; ignore
        return None

    return {
        "type": "computer_call",
        "call_id": uuid.uuid4().hex,
        "status": "completed",
        "action": action,
    }
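
# Worked example (illustrative): a Gemini click_at at normalized (500, 500)
# becomes an internal computer_call in pixel space, while excluded or unknown
# function names are dropped rather than guessed at.
#
#     item = _map_gemini_fc_to_computer_call(
#         {"name": "click_at", "args": {"x": 500, "y": 500}}, screen_w=1920, screen_h=1080
#     )
#     assert item is not None
#     assert item["action"] == {"type": "click", "x": 960, "y": 540, "button": "left"}
#     assert _map_gemini_fc_to_computer_call({"name": "navigate", "args": {}}, 1920, 1080) is None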


@register_agent(models=r"^gemini-2\.5-computer-use-preview-10-2025$")
class GeminiComputerUseConfig(AsyncAgentConfig):
    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        use_prompt_caching: Optional[bool] = False,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        genai, types = _lazy_import_genai()

        client = genai.Client()

        # Build excluded predefined functions for browser-specific behavior
        excluded = [
            "open_web_browser",
            "search",
            "navigate",
            "go_forward",
            "go_back",
            "scroll_document",
        ]
        # Optional custom functions: can be extended by host code via `tools` parameter later if desired
        CUSTOM_FUNCTION_DECLARATIONS: List[Any] = []

        # Compose tools config
        generate_content_config = types.GenerateContentConfig(
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=excluded,
                    )
                ),
                # types.Tool(function_declarations=CUSTOM_FUNCTION_DECLARATIONS),  # enable when custom functions needed
            ]
        )

        # Prepare contents: last user text + latest screenshot
        user_texts = _find_last_user_text(messages)
        screenshot_bytes = _find_last_screenshot(messages)

        parts: List[Any] = []
        for t in user_texts:
            parts.append(types.Part(text=t))

        screen_w, screen_h = 1024, 768
        if screenshot_bytes:
            screen_w, screen_h = _bytes_image_size(screenshot_bytes)
            parts.append(types.Part.from_bytes(data=screenshot_bytes, mime_type="image/png"))

        # If we don't have any content, at least pass an empty user part to prompt reasoning
        if not parts:
            parts = [types.Part(text="Proceed to the next action.")]

        contents = [types.Content(role="user", parts=parts)]

        api_kwargs = {
            "model": model,
            "contents": contents,
            "config": generate_content_config,
        }

        if _on_api_start:
            await _on_api_start(
                {
                    "model": api_kwargs["model"],
                    # "contents": api_kwargs["contents"],  # Disabled for now
                    "config": api_kwargs["config"],
                }
            )

        response = client.models.generate_content(**api_kwargs)

        if _on_api_end:
            await _on_api_end(
                {
                    "model": api_kwargs["model"],
                    # "contents": api_kwargs["contents"],  # Disabled for now
                    "config": api_kwargs["config"],
                },
                response,
            )

        # Usage (Gemini SDK may not always provide token usage; populate when available)
        usage: Dict[str, Any] = {}
        try:
            # Some SDKs expose response.usage; if available, copy
            if getattr(response, "usage_metadata", None):
                md = response.usage_metadata
                usage = {
                    "prompt_tokens": getattr(md, "prompt_token_count", None) or 0,
                    "completion_tokens": getattr(md, "candidates_token_count", None) or 0,
                    "total_tokens": getattr(md, "total_token_count", None) or 0,
                }
        except Exception:
            pass

        if _on_usage and usage:
            await _on_usage(usage)

        # Parse output into internal items
        output_items: List[Dict[str, Any]] = []

        candidate = response.candidates[0]
        # Text parts from the model (assistant message)
        text_parts: List[str] = []
        function_calls: List[Dict[str, Any]] = []
        for p in candidate.content.parts:
            if getattr(p, "text", None):
                text_parts.append(p.text)
            if getattr(p, "function_call", None):
                # p.function_call has name and args
                fc = {
                    "name": getattr(p.function_call, "name", None),
                    "args": dict(getattr(p.function_call, "args", {}) or {}),
                }
                function_calls.append(fc)

        if text_parts:
            output_items.append(
                {
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": "\n".join(text_parts)}],
                }
            )

        # Map function calls to internal computer_call actions
        for fc in function_calls:
            item = _map_gemini_fc_to_computer_call(fc, screen_w, screen_h)
            if item is not None:
                output_items.append(item)

        return {"output": output_items, "usage": usage}

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs,
    ) -> Optional[Tuple[float, float]]:
        """Ask Gemini CUA to output a single click action for the given instruction.

        Excludes all predefined tools except `click_at` and sends the screenshot.
        Returns pixel (x, y) if a click is proposed, else None.
        """
        genai, types = _lazy_import_genai()

        client = genai.Client()

        # Exclude all but click_at
        exclude_all_but_click = [
            "open_web_browser",
            "wait_5_seconds",
            "go_back",
            "go_forward",
            "search",
            "navigate",
            "hover_at",
            "type_text_at",
            "key_combination",
            "scroll_document",
            "scroll_at",
            "drag_and_drop",
        ]

        config = types.GenerateContentConfig(
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=exclude_all_but_click,
                    )
                )
            ]
        )

        # Prepare prompt parts
        try:
            img_bytes = base64.b64decode(image_b64)
        except Exception:
            img_bytes = b""

        w, h = _bytes_image_size(img_bytes) if img_bytes else (1024, 768)

        parts: List[Any] = [types.Part(text=f"Click {instruction}.")]
        if img_bytes:
            parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))

        contents = [types.Content(role="user", parts=parts)]

        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=config,
        )

        # Parse first click_at
        try:
            candidate = response.candidates[0]
            for p in candidate.content.parts:
                fc = getattr(p, "function_call", None)
                if fc and getattr(fc, "name", None) == "click_at":
                    args = dict(getattr(fc, "args", {}) or {})
                    x = _denormalize(int(args.get("x", 0)), w)
                    y = _denormalize(int(args.get("y", 0)), h)
                    return float(x), float(y)
        except Exception:
            return None

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        return ["click", "step"]
@@ -4,33 +4,36 @@ Supports vision-language models for computer control with bounding box parsing.
"""

import asyncio
import json
import base64
import json
import re
from typing import Dict, List, Any, Optional, Tuple
from io import BytesIO
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple

import litellm
from litellm.responses.litellm_completion_transformation.transformation import (
    LiteLLMCompletionResponsesConfig,
)
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from PIL import Image

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..responses import (
    convert_responses_items_to_completion_messages,
    convert_completion_messages_to_responses_items,
    make_reasoning_item,
    make_output_text_item,
    convert_responses_items_to_completion_messages,
    make_click_item,
    make_double_click_item,
    make_drag_item,
    make_input_image_item,
    make_keypress_item,
    make_output_text_item,
    make_reasoning_item,
    make_scroll_item,
    make_type_item,
    make_wait_item,
    make_input_image_item
)
from ..types import AgentCapability, AgentResponse, Messages, Tools

# GLM-4.5V specific constants
GLM_ACTION_SPACE = """
@@ -251,16 +254,18 @@ Call rule: `FAIL()`
    }
}"""


def encode_image_to_base64(image_path: str) -> str:
    """Encode image file to base64 string with data URI."""
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/png;base64,{encoded_string}"


def parse_glm_response(response: str) -> Dict[str, Any]:
    """
    Parse GLM-4.5V response to extract action and memory.

    The special tokens <|begin_of_box|> and <|end_of_box|> mark bounding boxes.
    Coordinates are normalized values between 0 and 1000.
    """
@@ -274,26 +279,23 @@ def parse_glm_response(response: str) -> Dict[str, Any]:
    action_pattern = r"[\w_]+\([^)]*\)"
    matches = re.findall(action_pattern, response)
    action = matches[0] if matches else None

    # Extract memory section
    memory_pattern = r"Memory:(.*?)$"
    memory_match = re.search(memory_pattern, response, re.DOTALL)
    memory = memory_match.group(1).strip() if memory_match else "[]"

    # Extract action text (everything before Memory:)
    action_text_pattern = r'^(.*?)Memory:'
    action_text_pattern = r"^(.*?)Memory:"
    action_text_match = re.search(action_text_pattern, response, re.DOTALL)
    action_text = action_text_match.group(1).strip() if action_text_match else response

    # Clean up action text by removing special tokens
    if action_text:
        action_text = action_text.replace("<|begin_of_box|>", "").replace("<|end_of_box|>", "")

    return {
        "action": action,
        "action_text": action_text,
        "memory": memory
    }

    return {"action": action, "action_text": action_text, "memory": memory}
|
||||
|
||||
|
||||
def get_last_image_from_messages(messages: Messages) -> Optional[str]:
|
||||
"""Extract the last image from messages for processing."""
|
||||
@@ -314,23 +316,28 @@ def get_last_image_from_messages(messages: Messages) -> Optional[str]:
|
||||
image_url_obj = item.get("image_url", {})
|
||||
if isinstance(image_url_obj, dict):
|
||||
image_url = image_url_obj.get("url", "")
|
||||
if isinstance(image_url, str) and image_url.startswith("data:image/"):
|
||||
if isinstance(image_url, str) and image_url.startswith(
|
||||
"data:image/"
|
||||
):
|
||||
return image_url.split(",", 1)[1]
|
||||
return None
|
||||
|
||||
def convert_responses_items_to_glm45v_pc_prompt(messages: Messages, task: str, memory: str = "") -> List[Dict[str, Any]]:
|
||||
|
||||
def convert_responses_items_to_glm45v_pc_prompt(
|
||||
messages: Messages, task: str, memory: str = ""
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert responses items to GLM-4.5V PC prompt format with historical actions.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message items from the conversation
|
||||
task: The task description
|
||||
memory: Current memory state
|
||||
|
||||
|
||||
Returns:
|
||||
List of content items for the prompt (text and image_url items)
|
||||
"""
|
||||
action_space = GLM_ACTION_SPACE
|
||||
|
||||
|
||||
# Template head
|
||||
head_text = f"""You are a GUI Agent, and your primary task is to respond accurately to user requests or questions. In addition to directly answering the user's queries, you can also use tools or perform GUI operations directly until you fulfill the user's request or provide a correct answer. You should carefully read and understand the images and questions provided by the user, and engage in thinking and reflection when appropriate. The coordinates involved are all represented in thousandths (0-999).
|
||||
|
||||
@@ -345,7 +352,7 @@ Ubuntu
|
||||
|
||||
# Historical Actions and Current Memory
|
||||
History:"""
|
||||
|
||||
|
||||
# Template tail
|
||||
tail_text = f"""
|
||||
Memory:
|
||||
@@ -363,18 +370,18 @@ Memory:
|
||||
|
||||
Current Screenshot:
|
||||
"""
|
||||
|
||||
|
||||
# Build history from messages
|
||||
history = []
|
||||
history_images = []
|
||||
|
||||
|
||||
# Group messages into steps
|
||||
current_step = []
|
||||
step_num = 0
|
||||
|
||||
|
||||
for message in messages:
|
||||
msg_type = message.get("type")
|
||||
|
||||
|
||||
if msg_type == "reasoning":
|
||||
current_step.append(message)
|
||||
elif msg_type == "message" and message.get("role") == "assistant":
|
||||
@@ -386,7 +393,7 @@ Current Screenshot:
|
||||
# End of step - process it
|
||||
if current_step:
|
||||
step_num += 1
|
||||
|
||||
|
||||
# Extract bot thought from message content
|
||||
bot_thought = ""
|
||||
for item in current_step:
|
||||
@@ -397,14 +404,14 @@ Current Screenshot:
|
||||
bot_thought = content_item.get("text", "")
|
||||
break
|
||||
break
|
||||
|
||||
|
||||
# Extract action from computer_call
|
||||
action_text = ""
|
||||
for item in current_step:
|
||||
if item.get("type") == "computer_call":
|
||||
action = item.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
|
||||
|
||||
if action_type == "click":
|
||||
x, y = action.get("x", 0), action.get("y", 0)
|
||||
# Convert to 0-999 range (assuming screen dimensions)
|
||||
@@ -436,7 +443,7 @@ Current Screenshot:
|
||||
elif action_type == "wait":
|
||||
action_text = "WAIT()"
|
||||
break
|
||||
|
||||
|
||||
# Extract screenshot from computer_call_output
|
||||
screenshot_url = None
|
||||
for item in current_step:
|
||||
@@ -445,34 +452,34 @@ Current Screenshot:
|
||||
if output.get("type") == "input_image":
|
||||
screenshot_url = output.get("image_url", "")
|
||||
break
|
||||
|
||||
|
||||
# Store step info
|
||||
step_info = {
|
||||
"step_num": step_num,
|
||||
"bot_thought": bot_thought,
|
||||
"action_text": action_text,
|
||||
"screenshot_url": screenshot_url
|
||||
"screenshot_url": screenshot_url,
|
||||
}
|
||||
history.append(step_info)
|
||||
|
||||
|
||||
# Store screenshot for last 4 steps
|
||||
if screenshot_url:
|
||||
history_images.append(screenshot_url)
|
||||
|
||||
|
||||
current_step = []
|
||||
|
||||
|
||||
# Build content array with head, history, and tail
|
||||
content = []
|
||||
current_text = head_text
|
||||
|
||||
|
||||
total_history_steps = len(history)
|
||||
history_image_count = min(4, len(history_images)) # Last 4 images
|
||||
|
||||
|
||||
for step_idx, step_info in enumerate(history):
|
||||
step_num = step_info["step_num"]
|
||||
bot_thought = step_info["bot_thought"]
|
||||
action_text = step_info["action_text"]
|
||||
|
||||
|
||||
if step_idx < total_history_steps - history_image_count:
|
||||
# For steps beyond the last 4, use text placeholder
|
||||
current_text += f"\nstep {step_num}: Screenshot:(Omitted in context.) Thought: {bot_thought}\nAction: {action_text}"
|
||||
@@ -480,20 +487,21 @@ Current Screenshot:
|
||||
# For the last 4 steps, insert images
|
||||
current_text += f"\nstep {step_num}: Screenshot:"
|
||||
content.append({"type": "text", "text": current_text})
|
||||
|
||||
|
||||
# Add image
|
||||
img_idx = step_idx - (total_history_steps - history_image_count)
|
||||
if img_idx < len(history_images):
|
||||
content.append({"type": "image_url", "image_url": {"url": history_images[img_idx]}})
|
||||
|
||||
|
||||
current_text = f" Thought: {bot_thought}\nAction: {action_text}"
|
||||
|
||||
|
||||
# Add tail
|
||||
current_text += tail_text
|
||||
content.append({"type": "text", "text": current_text})
|
||||
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def model_dump(obj) -> Dict[str, Any]:
|
||||
if isinstance(obj, dict):
|
||||
return {k: model_dump(v) for k, v in obj.items()}
|
||||
@@ -502,58 +510,61 @@ def model_dump(obj) -> Dict[str, Any]:
|
||||
else:
|
||||
return obj
|
||||
|
||||
def convert_glm_completion_to_responses_items(response: ModelResponse, image_width: int, image_height: int) -> List[Dict[str, Any]]:
|
||||
|
||||
def convert_glm_completion_to_responses_items(
|
||||
response: ModelResponse, image_width: int, image_height: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert GLM-4.5V completion response to responses items format.
|
||||
|
||||
|
||||
Args:
|
||||
response: LiteLLM ModelResponse from GLM-4.5V
|
||||
image_width: Original image width for coordinate scaling
|
||||
image_height: Original image height for coordinate scaling
|
||||
|
||||
|
||||
Returns:
|
||||
List of response items in the proper format
|
||||
"""
|
||||
import uuid
|
||||
|
||||
|
||||
response_items = []
|
||||
|
||||
|
||||
if not response.choices or not response.choices[0].message:
|
||||
return response_items
|
||||
|
||||
|
||||
message = response.choices[0].message
|
||||
content = message.content or ""
|
||||
reasoning_content = getattr(message, 'reasoning_content', None)
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None)
|
||||
|
||||
# Add reasoning item if present
|
||||
if reasoning_content:
|
||||
reasoning_item = model_dump(make_reasoning_item(reasoning_content))
|
||||
response_items.append(reasoning_item)
|
||||
|
||||
|
||||
# Parse the content to extract action and text
|
||||
parsed_response = parse_glm_response(content)
|
||||
action = parsed_response.get("action", "")
|
||||
action_text = parsed_response.get("action_text", "")
|
||||
|
||||
|
||||
# Add message item with text content (excluding action and memory)
|
||||
if action_text:
|
||||
# Remove action from action_text if it's there
|
||||
clean_text = action_text
|
||||
if action and action in clean_text:
|
||||
clean_text = clean_text.replace(action, "").strip()
|
||||
|
||||
|
||||
# Remove memory section
|
||||
memory_pattern = r"Memory:\s*\[.*?\]\s*$"
|
||||
clean_text = re.sub(memory_pattern, "", clean_text, flags=re.DOTALL).strip()
|
||||
|
||||
|
||||
if clean_text:
|
||||
message_item = model_dump(make_output_text_item(clean_text))
|
||||
response_items.append(message_item)
|
||||
|
||||
|
||||
# Convert action to computer call if present
|
||||
if action:
|
||||
call_id = f"call_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
# Parse different action types and create appropriate computer calls
|
||||
if action.startswith("left_click"):
|
||||
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
@@ -566,7 +577,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("right_click"):
|
||||
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
if coord_match:
|
||||
@@ -577,7 +588,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("left_double_click"):
|
||||
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
if coord_match:
|
||||
@@ -588,7 +599,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("left_drag"):
|
||||
start_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
end_match = re.search(r"end_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
@@ -605,18 +616,18 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("key"):
|
||||
key_match = re.search(r"keys='([^']+)'", action)
|
||||
if key_match:
|
||||
keys = key_match.group(1)
|
||||
# Split keys by '+' for key combinations, or use as single key
|
||||
key_list = keys.split('+') if '+' in keys else [keys]
|
||||
key_list = keys.split("+") if "+" in keys else [keys]
|
||||
computer_call = model_dump(make_keypress_item(key_list))
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("type"):
|
||||
content_match = re.search(r"content='([^']*)'", action)
|
||||
if content_match:
|
||||
@@ -625,7 +636,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action.startswith("scroll"):
|
||||
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
|
||||
direction_match = re.search(r"direction='([^']+)'", action)
|
||||
@@ -648,15 +659,16 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
elif action == "WAIT()":
|
||||
computer_call = model_dump(make_wait_item())
|
||||
computer_call["call_id"] = call_id
|
||||
computer_call["status"] = "completed"
|
||||
response_items.append(computer_call)
|
||||
|
||||
|
||||
return response_items
|
||||
|
||||
|
||||
@register_agent(models=r"(?i).*GLM-4\.5V.*")
|
||||
class Glm4vConfig(AsyncAgentConfig):
|
||||
"""GLM-4.5V agent configuration using liteLLM."""
|
||||
@@ -674,11 +686,11 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Predict the next step using GLM-4.5V model.
|
||||
|
||||
|
||||
Args:
|
||||
messages: Input messages following Responses format
|
||||
model: Model name to use
|
||||
@@ -691,7 +703,7 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
_on_api_end: Callback for API end
|
||||
_on_usage: Callback for usage tracking
|
||||
_on_screenshot: Callback for screenshot events
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with "output" and "usage" keys
|
||||
"""
|
||||
@@ -708,7 +720,7 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
user_instruction = item.get("text", "")
|
||||
break
|
||||
break
|
||||
|
||||
|
||||
# Get the last image for processing
|
||||
last_image_b64 = get_last_image_from_messages(messages)
|
||||
if not last_image_b64 and computer_handler:
|
||||
@@ -718,35 +730,28 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
last_image_b64 = screenshot_b64
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(screenshot_b64)
|
||||
|
||||
|
||||
if not last_image_b64:
|
||||
raise ValueError("No image available for GLM-4.5V processing")
|
||||
|
||||
|
||||
# Convert responses items to GLM-4.5V PC prompt format with historical actions
|
||||
prompt_content = convert_responses_items_to_glm45v_pc_prompt(
|
||||
messages=messages,
|
||||
task=user_instruction,
|
||||
memory="[]" # Initialize with empty memory for now
|
||||
memory="[]", # Initialize with empty memory for now
|
||||
)
|
||||
|
||||
|
||||
# Add the current screenshot to the end
|
||||
prompt_content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{last_image_b64}"}
|
||||
})
|
||||
|
||||
prompt_content.append(
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{last_image_b64}"}}
|
||||
)
|
||||
|
||||
# Prepare messages for liteLLM
|
||||
litellm_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful GUI agent assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt_content
|
||||
}
|
||||
{"role": "system", "content": "You are a helpful GUI agent assistant."},
|
||||
{"role": "user", "content": prompt_content},
|
||||
]
|
||||
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
@@ -757,20 +762,20 @@ class Glm4vConfig(AsyncAgentConfig):
|
||||
# "skip_special_tokens": False,
|
||||
# }
|
||||
}
|
||||
|
||||
|
||||
# Add API callbacks
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
|
||||
# Call liteLLM
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
|
||||
# Get image dimensions for coordinate scaling
|
||||
image_width, image_height = 1920, 1080 # Default dimensions
|
||||
|
||||
|
||||
# Try to get actual dimensions from the image
|
||||
try:
|
||||
                        image_data = base64.b64decode(last_image_b64)
@@ -778,41 +783,38 @@ class Glm4vConfig(AsyncAgentConfig):
                        image_width, image_height = image.size
                    except Exception:
                        pass  # Use default dimensions

        # Convert GLM completion response to responses items
        response_items = convert_glm_completion_to_responses_items(response, image_width, image_height)

        response_items = convert_glm_completion_to_responses_items(
            response, image_width, image_height
        )

        # Extract usage information
        response_usage = {
            **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
            **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
                response.usage
            ).model_dump(),
            "response_cost": response._hidden_params.get("response_cost", 0.0),
        }
        if _on_usage:
            await _on_usage(response_usage)

        # Create agent response
        agent_response = {
            "output": response_items,
            "usage": response_usage
        }

        agent_response = {"output": response_items, "usage": response_usage}

        return agent_response

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using GLM-4.5V model.

        Args:
            model: Model name to use
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            Tuple with (x, y) coordinates or None
        """
@@ -824,22 +826,22 @@ Respond with a single click action in this format:
left_click(start_box='[x,y]')

Where x,y are coordinates normalized to 0-999 range."""

        # Prepare messages for liteLLM
        litellm_messages = [
            {
                "role": "system",
                "content": "You are a helpful GUI agent assistant."
            },
            {"role": "system", "content": "You are a helpful GUI agent assistant."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": click_prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
                ]
            }
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                ],
            },
        ]

        # Prepare API call kwargs
        api_kwargs = {
            "model": model,
@@ -848,21 +850,21 @@ Where x,y are coordinates normalized to 0-999 range."""
            "temperature": 0.001,
            "extra_body": {
                "skip_special_tokens": False,
            }
            },
        }

        # Call liteLLM
        response = await litellm.acompletion(**api_kwargs)

        # Extract response content
        response_content = response.choices[0].message.content.strip()
        print(response)

        # Parse response for click coordinates
        # Look for coordinates in the response, handling special tokens
        coord_pattern = r"<\|begin_of_box\|>.*?left_click\(start_box='?\[(\d+),(\d+)\]'?\).*?<\|end_of_box\|>"
        match = re.search(coord_pattern, response_content)

        if not match:
            # Fallback: look for coordinates without special tokens
            coord_pattern = r"left_click\(start_box='?\[(\d+),(\d+)\]'?\)"
@@ -870,7 +872,7 @@ Where x,y are coordinates normalized to 0-999 range."""

        if match:
            x, y = int(match.group(1)), int(match.group(2))

            # Get actual image dimensions for scaling
            try:
                image_data = base64.b64decode(image_b64)
@@ -879,15 +881,15 @@ Where x,y are coordinates normalized to 0-999 range."""
            except Exception:
                # Use default dimensions
                image_width, image_height = 1920, 1080

            # Convert from 0-999 normalized coordinates to actual pixel coordinates
            actual_x = int((x / 999.0) * image_width)
            actual_y = int((y / 999.0) * image_height)

            return (actual_x, actual_y)

        return None

    except Exception as e:
        # Log error and return None
        print(f"Error in predict_click: {e}")
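A minimal standalone sketch of the parse-and-rescale step in this hunk; the sample response string and the 1920x1080 screenshot size are assumptions, not values from the diff:

# Parse a GLM-style click action and rescale 0-999 coordinates to pixels.
import re

sample = "<|begin_of_box|>left_click(start_box='[512,300]')<|end_of_box|>"
m = re.search(r"left_click\(start_box='?\[(\d+),(\d+)\]'?\)", sample)
if m:
    x, y = int(m.group(1)), int(m.group(2))      # normalized to 0-999
    image_width, image_height = 1920, 1080       # assumed screenshot size
    print(int(x / 999.0 * image_width), int(y / 999.0 * image_height))  # -> 984 324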
@@ -896,7 +898,7 @@ Where x,y are coordinates normalized to 0-999 range."""
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by this agent config.

        Returns:
            List of capability strings
        """
@@ -5,75 +5,80 @@ Code: https://github.com/Yan98/GTA1
"""

import asyncio
import json
import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
import uuid
from PIL import Image
import litellm
import json
import math
import re
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm
from PIL import Image

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools

SYSTEM_PROMPT = '''
SYSTEM_PROMPT = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.

Output the coordinate pair exactly:
(x,y)
'''.strip()
""".strip()


def extract_coordinates(raw_string: str) -> Tuple[float, float]:
    """Extract coordinates from model output."""
    try:
        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
        return tuple(map(float, matches[0]))  # type: ignore
        return tuple(map(float, matches[0]))  # type: ignore
    except:
        return (0.0, 0.0)
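A quick standalone check of the coordinate regex used above; the sample model output is an assumption:

import re

matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", "The element is at (412.5, 233)")
print(tuple(map(float, matches[0])))  # -> (412.5, 233.0)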
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:

def smart_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360
) -> Tuple[int, int]:
    """Smart resize function similar to qwen_vl_utils."""
    # Calculate the total pixels
    total_pixels = height * width

    # If already within bounds, return original dimensions
    if min_pixels <= total_pixels <= max_pixels:
        # Round to nearest factor
        new_height = (height // factor) * factor
        new_width = (width // factor) * factor
        return new_height, new_width

    # Calculate scaling factor
    if total_pixels > max_pixels:
        scale = (max_pixels / total_pixels) ** 0.5
    else:
        scale = (min_pixels / total_pixels) ** 0.5

    # Apply scaling
    new_height = int(height * scale)
    new_width = int(width * scale)

    # Round to nearest factor
    new_height = (new_height // factor) * factor
    new_width = (new_width // factor) * factor

    # Ensure minimum size
    new_height = max(new_height, factor)
    new_width = max(new_width, factor)

    return new_height, new_width
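A sketch of how smart_resize above feeds the rescaling math used in predict_click below; the 2560x1440 input is an assumption:

# Dimensions stay factor-aligned; scale factors map model-space coords back to pixels.
h, w = smart_resize(1440, 2560, factor=28, min_pixels=3136, max_pixels=4096 * 2160)
scale_x, scale_y = 2560 / w, 1440 / h
print(h, w, scale_x, scale_y)  # -> 1428 2548 1.0047... 1.0084...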
@register_agent(models=r".*GTA1.*")
|
||||
class GTA1Config(AsyncAgentConfig):
|
||||
"""GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.current_model = None
|
||||
self.last_screenshot_b64 = None
|
||||
|
||||
|
||||
async def predict_step(
|
||||
self,
|
||||
@@ -87,25 +92,21 @@ class GTA1Config(AsyncAgentConfig):
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def predict_click(
|
||||
self,
|
||||
model: str,
|
||||
image_b64: str,
|
||||
instruction: str,
|
||||
**kwargs
|
||||
self, model: str, image_b64: str, instruction: str, **kwargs
|
||||
) -> Optional[Tuple[float, float]]:
|
||||
"""
|
||||
Predict click coordinates using GTA1 model via litellm.acompletion.
|
||||
|
||||
|
||||
Args:
|
||||
model: The GTA1 model name
|
||||
image_b64: Base64 encoded image
|
||||
instruction: Instruction for where to click
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates or None if prediction fails
|
||||
"""
|
||||
@@ -113,66 +114,62 @@ class GTA1Config(AsyncAgentConfig):
|
||||
image_data = base64.b64decode(image_b64)
|
||||
image = Image.open(BytesIO(image_data))
|
||||
width, height = image.width, image.height
|
||||
|
||||
|
||||
# Smart resize the image (similar to qwen_vl_utils)
|
||||
resized_height, resized_width = smart_resize(
|
||||
height, width,
|
||||
height,
|
||||
width,
|
||||
factor=28, # Default factor for Qwen models
|
||||
min_pixels=3136,
|
||||
max_pixels=4096 * 2160
|
||||
max_pixels=4096 * 2160,
|
||||
)
|
||||
resized_image = image.resize((resized_width, resized_height))
|
||||
scale_x, scale_y = width / resized_width, height / resized_height
|
||||
|
||||
|
||||
# Convert resized image back to base64
|
||||
buffered = BytesIO()
|
||||
resized_image.save(buffered, format="PNG")
|
||||
resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
|
||||
|
||||
# Prepare system and user messages
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
|
||||
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width),
|
||||
}
|
||||
|
||||
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{resized_image_b64}"
|
||||
}
|
||||
"image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": instruction
|
||||
}
|
||||
]
|
||||
{"type": "text", "text": instruction},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": [system_message, user_message],
|
||||
"max_tokens": 2056,
|
||||
"temperature": 0.0,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
# Use liteLLM acompletion
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
|
||||
# Extract response text
|
||||
output_text = response.choices[0].message.content # type: ignore
|
||||
|
||||
output_text = response.choices[0].message.content # type: ignore
|
||||
|
||||
# Extract and rescale coordinates
|
||||
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
|
||||
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
|
||||
pred_x *= scale_x
|
||||
pred_y *= scale_y
|
||||
|
||||
|
||||
return (math.floor(pred_x), math.floor(pred_y))
|
||||
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
|
||||
"""Return the capabilities supported by this agent."""
|
||||
return ["click"]
|
||||
|
||||
@@ -21,8 +21,8 @@ import litellm
from PIL import Image

from ..decorators import register_agent
from .base import AsyncAgentConfig
from ..types import AgentCapability
from .base import AsyncAgentConfig


def _strip_hf_prefix(model: str) -> str:
@@ -53,7 +53,9 @@ def _maybe_smart_resize(image: Image.Image, model: str) -> Tuple[Image.Image, Tu
    if image_processor is None:
        return image, (orig_w, orig_h)

    factor = getattr(image_processor, "patch_size", 14) * getattr(image_processor, "merge_size", 1)
    factor = getattr(image_processor, "patch_size", 14) * getattr(
        image_processor, "merge_size", 1
    )
    min_pixels = getattr(image_processor, "min_pixels", 256 * 256)
    max_pixels = getattr(image_processor, "max_pixels", 1536 * 1536)
@@ -18,13 +18,12 @@ import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image
import litellm
from PIL import Image

from ..decorators import register_agent
from .composed_grounded import ComposedGroundedConfig
from ..types import AgentCapability

from .composed_grounded import ComposedGroundedConfig

# Regex patterns for extracting coordinates
# Accept optional whitespace and optional decimal fractions
@@ -91,7 +90,7 @@ class InternVLConfig(ComposedGroundedConfig):
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """Fallback to a self-composed model"""
        return await super().predict_step(
@@ -105,15 +104,11 @@ class InternVLConfig(ComposedGroundedConfig):
            _on_api_end=_on_api_end,
            _on_usage=_on_usage,
            _on_screenshot=_on_screenshot,
            **kwargs
            **kwargs,
        )

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using InternVL via litellm.acompletion.
493  libs/python/agent/agent/loops/moondream3.py  Normal file
@@ -0,0 +1,493 @@
"""
Moondream3+ composed-grounded agent loop implementation.
Grounding is handled by a local Moondream3 preview model via Transformers.
Thinking is delegated to the trailing LLM in the composed model string: "moondream3+<thinking_model>".

Differences from composed_grounded:
- Provides a singleton Moondream3 client outside the class.
- predict_click uses model.point(image, instruction, settings={"max_objects": 1}) and returns pixel coordinates.
- If the last image was a screenshot (or we take one), run model.detect(image, "all form ui") to get bboxes, then
  run model.caption on each cropped bbox to label it. Overlay labels on the screenshot and emit via _on_screenshot.
- Add a user message listing all detected form UI names so the thinker can reference them.
- If the thinking model doesn't support vision, filter out image content before calling litellm.
"""

from __future__ import annotations

import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple

import litellm
from PIL import Image, ImageDraw, ImageFont

from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
    convert_completion_messages_to_responses_items,
    convert_computer_calls_desc2xy,
    convert_computer_calls_xy2desc,
    convert_responses_items_to_completion_messages,
    get_all_element_descriptions,
)
from ..types import AgentCapability

_MOONDREAM_SINGLETON = None


def get_moondream_model() -> Any:
    """Get a singleton instance of the Moondream3 preview model."""
    global _MOONDREAM_SINGLETON
    if _MOONDREAM_SINGLETON is None:
        try:
            import torch
            from transformers import AutoModelForCausalLM

            _MOONDREAM_SINGLETON = AutoModelForCausalLM.from_pretrained(
                "moondream/moondream3-preview",
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="cuda",
            )
        except ImportError as e:
            raise RuntimeError(
                "moondream3 requires torch and transformers. Install with: pip install cua-agent[moondream3]"
            ) from e
    return _MOONDREAM_SINGLETON
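A hedged usage sketch of the singleton above; it assumes a CUDA machine, the remote-code preview model, and the point() API exactly as described in the module docstring, with an assumed screenshot path:

from PIL import Image

md = get_moondream_model()
img = Image.open("screenshot.png")  # path is an assumption
result = md.point(img, "the Submit button", settings={"max_objects": 1})
print((result or {}).get("points", []))  # normalized [{"x": ..., "y": ...}]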
def _decode_image_b64(image_b64: str) -> Image.Image:
    data = base64.b64decode(image_b64)
    return Image.open(io.BytesIO(data)).convert("RGB")


def _image_to_b64(img: Image.Image) -> str:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def _supports_vision(model: str) -> bool:
    """Heuristic vision support detection for thinking model."""
    m = model.lower()
    vision_markers = [
        "gpt-4o",
        "gpt-4.1",
        "o1",
        "o3",
        "claude-3",
        "claude-3.5",
        "sonnet",
        "haiku",
        "opus",
        "gemini-1.5",
        "llava",
    ]
    return any(v in m for v in vision_markers)


def _filter_images_from_completion_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    filtered: List[Dict[str, Any]] = []
    for msg in messages:
        msg_copy = {**msg}
        content = msg_copy.get("content")
        if isinstance(content, list):
            msg_copy["content"] = [c for c in content if c.get("type") != "image_url"]
        filtered.append(msg_copy)
    return filtered
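A minimal sketch of the vision gate formed by the two helpers above: thinking models that miss every substring marker get their image parts stripped before the litellm call. The message content is illustrative:

msgs = [
    {"role": "user", "content": [
        {"type": "text", "text": "What is on screen?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ]}
]
if not _supports_vision("ollama/llama3.1"):  # no marker matches -> False
    msgs = _filter_images_from_completion_messages(msgs)
print(msgs[0]["content"])  # the image_url part has been removed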
def _annotate_detect_and_label_ui(base_img: Image.Image, model_md) -> Tuple[str, List[str]]:
    """Detect UI elements with Moondream, caption each, draw labels with backgrounds.

    Args:
        base_img: PIL image of the screenshot (RGB or RGBA). Will be copied/converted internally.
        model_md: Moondream model instance with .detect() and .query() methods.

    Returns:
        A tuple of (annotated_image_base64_png, detected_names)
    """
    # Ensure RGBA for semi-transparent fills
    if base_img.mode != "RGBA":
        base_img = base_img.convert("RGBA")
    W, H = base_img.width, base_img.height

    # Detect objects
    try:
        detect_result = model_md.detect(base_img, "all ui elements")
        objects = detect_result.get("objects", []) if isinstance(detect_result, dict) else []
    except Exception:
        objects = []

    draw = ImageDraw.Draw(base_img)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    detected_names: List[str] = []

    for i, obj in enumerate(objects):
        try:
            # Clamp normalized coords and crop
            x_min = max(0.0, min(1.0, float(obj.get("x_min", 0.0))))
            y_min = max(0.0, min(1.0, float(obj.get("y_min", 0.0))))
            x_max = max(0.0, min(1.0, float(obj.get("x_max", 0.0))))
            y_max = max(0.0, min(1.0, float(obj.get("y_max", 0.0))))
            left, top, right, bottom = (
                int(x_min * W),
                int(y_min * H),
                int(x_max * W),
                int(y_max * H),
            )
            left, top = max(0, left), max(0, top)
            right, bottom = min(W - 1, right), min(H - 1, bottom)
            crop = base_img.crop((left, top, right, bottom))

            # Prompted short caption
            try:
                result = model_md.query(crop, "Caption this UI element in few words.")
                caption_text = (result or {}).get("answer", "")
            except Exception:
                caption_text = ""

            name = (caption_text or "").strip() or f"element_{i+1}"
            detected_names.append(name)

            # Draw bbox
            draw.rectangle([left, top, right, bottom], outline=(255, 215, 0, 255), width=2)

            # Label background with padding and rounded corners
            label = f"{i+1}. {name}"
            padding = 3
            if font:
                text_bbox = draw.textbbox((0, 0), label, font=font)
            else:
                text_bbox = draw.textbbox((0, 0), label)
            text_w = text_bbox[2] - text_bbox[0]
            text_h = text_bbox[3] - text_bbox[1]

            tx = left + 3
            ty = top - (text_h + 2 * padding + 4)
            if ty < 0:
                ty = top + 3

            bg_left = tx - padding
            bg_top = ty - padding
            bg_right = tx + text_w + padding
            bg_bottom = ty + text_h + padding
            try:
                draw.rounded_rectangle(
                    [bg_left, bg_top, bg_right, bg_bottom],
                    radius=4,
                    fill=(0, 0, 0, 160),
                    outline=(255, 215, 0, 200),
                    width=1,
                )
            except Exception:
                draw.rectangle(
                    [bg_left, bg_top, bg_right, bg_bottom],
                    fill=(0, 0, 0, 160),
                    outline=(255, 215, 0, 200),
                    width=1,
                )

            text_fill = (255, 255, 255, 255)
            if font:
                draw.text((tx, ty), label, fill=text_fill, font=font)
            else:
                draw.text((tx, ty), label, fill=text_fill)
        except Exception:
            continue

    # Encode PNG base64
    annotated = base_img
    if annotated.mode not in ("RGBA", "RGB"):
        annotated = annotated.convert("RGBA")
    annotated_b64 = _image_to_b64(annotated)
    return annotated_b64, detected_names
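A tiny worked example of the normalized-bbox-to-pixel conversion performed above; the object values are assumptions:

W, H = 1920, 1080
obj = {"x_min": 0.25, "y_min": 0.10, "x_max": 0.40, "y_max": 0.15}
box = (int(obj["x_min"] * W), int(obj["y_min"] * H), int(obj["x_max"] * W), int(obj["y_max"] * H))
print(box)  # -> (480, 108, 768, 162)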
GROUNDED_COMPUTER_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "computer",
        "description": (
            "Control a computer by taking screenshots and interacting with UI elements. "
            "The screenshot action will include a list of detected form UI element names when available. "
            "Use element descriptions to locate and interact with UI elements on the screen."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "screenshot",
                        "click",
                        "double_click",
                        "drag",
                        "type",
                        "keypress",
                        "scroll",
                        "move",
                        "wait",
                        "get_current_url",
                        "get_dimensions",
                        "get_environment",
                    ],
                    "description": "The action to perform (required for all actions)",
                },
                "element_description": {
                    "type": "string",
                    "description": "Description of the element to interact with (required for click/double_click/move/scroll)",
                },
                "start_element_description": {
                    "type": "string",
                    "description": "Description of the element to start dragging from (required for drag)",
                },
                "end_element_description": {
                    "type": "string",
                    "description": "Description of the element to drag to (required for drag)",
                },
                "text": {
                    "type": "string",
                    "description": "The text to type (required for type)",
                },
                "keys": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Key(s) to press (required for keypress)",
                },
                "button": {
                    "type": "string",
                    "enum": ["left", "right", "wheel", "back", "forward"],
                    "description": "The mouse button to use for click/double_click",
                },
                "scroll_x": {
                    "type": "integer",
                    "description": "Horizontal scroll amount (required for scroll)",
                },
                "scroll_y": {
                    "type": "integer",
                    "description": "Vertical scroll amount (required for scroll)",
                },
            },
            "required": ["action"],
        },
    },
}
@register_agent(r"moondream3\+.*", priority=2)
|
||||
class Moondream3PlusConfig(AsyncAgentConfig):
|
||||
def __init__(self):
|
||||
self.desc2xy: Dict[str, Tuple[float, float]] = {}
|
||||
|
||||
async def predict_step(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
# Parse composed model: moondream3+<thinking_model>
|
||||
if "+" not in model:
|
||||
raise ValueError(f"Composed model must be 'moondream3+<thinking_model>', got: {model}")
|
||||
_, thinking_model = model.split("+", 1)
|
||||
|
||||
pre_output_items: List[Dict[str, Any]] = []
|
||||
|
||||
# Acquire last screenshot; if missing, take one
|
||||
last_image_b64: Optional[str] = None
|
||||
for message in reversed(messages):
|
||||
if (
|
||||
isinstance(message, dict)
|
||||
and message.get("type") == "computer_call_output"
|
||||
and isinstance(message.get("output"), dict)
|
||||
and message["output"].get("type") == "input_image"
|
||||
):
|
||||
image_url = message["output"].get("image_url", "")
|
||||
if image_url.startswith("data:image/png;base64,"):
|
||||
last_image_b64 = image_url.split(",", 1)[1]
|
||||
break
|
||||
|
||||
if last_image_b64 is None and computer_handler is not None:
|
||||
# Take a screenshot
|
||||
screenshot_b64 = await computer_handler.screenshot() # type: ignore
|
||||
if screenshot_b64:
|
||||
call_id = uuid.uuid4().hex
|
||||
pre_output_items += [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "Taking a screenshot to analyze the current screen.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "computer_call",
|
||||
"call_id": call_id,
|
||||
"status": "completed",
|
||||
"action": {"type": "screenshot"},
|
||||
},
|
||||
{
|
||||
"type": "computer_call_output",
|
||||
"call_id": call_id,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_b64}",
|
||||
},
|
||||
},
|
||||
]
|
||||
last_image_b64 = screenshot_b64
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(screenshot_b64)
|
||||
|
||||
# If we have a last screenshot, run Moondream detection and labeling
|
||||
detected_names: List[str] = []
|
||||
if last_image_b64 is not None:
|
||||
base_img = _decode_image_b64(last_image_b64)
|
||||
model_md = get_moondream_model()
|
||||
annotated_b64, detected_names = _annotate_detect_and_label_ui(base_img, model_md)
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(annotated_b64, "annotated_form_ui")
|
||||
|
||||
# Also push a user message listing all detected names
|
||||
if detected_names:
|
||||
names_text = "\n".join(f"- {n}" for n in detected_names)
|
||||
pre_output_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "input_text", "text": "Detected form UI elements on screen:"},
|
||||
{"type": "input_text", "text": names_text},
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "Please continue with the next action needed to perform your task.",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
tool_schemas = []
|
||||
for schema in tools or []:
|
||||
if schema.get("type") == "computer":
|
||||
tool_schemas.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
|
||||
else:
|
||||
tool_schemas.append(schema)
|
||||
|
||||
# Step 1: Convert computer calls from xy to descriptions
|
||||
input_messages = messages + pre_output_items
|
||||
messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
|
||||
|
||||
# Step 2: Convert responses items to completion messages
|
||||
completion_messages = convert_responses_items_to_completion_messages(
|
||||
messages_with_descriptions,
|
||||
allow_images_in_tool_results=False,
|
||||
)
|
||||
|
||||
# Optionally filter images if model lacks vision
|
||||
if not _supports_vision(thinking_model):
|
||||
completion_messages = _filter_images_from_completion_messages(completion_messages)
|
||||
|
||||
# Step 3: Call thinking model with litellm.acompletion
|
||||
api_kwargs = {
|
||||
"model": thinking_model,
|
||||
"messages": completion_messages,
|
||||
"tools": tool_schemas,
|
||||
"max_retries": max_retries,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
if use_prompt_caching:
|
||||
api_kwargs["use_prompt_caching"] = use_prompt_caching
|
||||
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
usage = {
|
||||
**response.usage.model_dump(), # type: ignore
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(usage)
|
||||
|
||||
# Step 4: Convert completion messages back to responses items format
|
||||
response_dict = response.model_dump() # type: ignore
|
||||
choice_messages = [choice["message"] for choice in response_dict["choices"]]
|
||||
thinking_output_items: List[Dict[str, Any]] = []
|
||||
for choice_message in choice_messages:
|
||||
thinking_output_items.extend(
|
||||
convert_completion_messages_to_responses_items([choice_message])
|
||||
)
|
||||
|
||||
# Step 5: Use Moondream to get coordinates for each description
|
||||
element_descriptions = get_all_element_descriptions(thinking_output_items)
|
||||
if element_descriptions and last_image_b64:
|
||||
for desc in element_descriptions:
|
||||
for _ in range(3): # try 3 times
|
||||
coords = await self.predict_click(
|
||||
model=model,
|
||||
image_b64=last_image_b64,
|
||||
instruction=desc,
|
||||
)
|
||||
if coords:
|
||||
self.desc2xy[desc] = coords
|
||||
break
|
||||
|
||||
# Step 6: Convert computer calls from descriptions back to xy coordinates
|
||||
final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
|
||||
|
||||
# Step 7: Return output and usage
|
||||
return {"output": pre_output_items + final_output_items, "usage": usage}
|
||||
|
||||
async def predict_click(
|
||||
self,
|
||||
model: str,
|
||||
image_b64: str,
|
||||
instruction: str,
|
||||
**kwargs,
|
||||
) -> Optional[Tuple[float, float]]:
|
||||
"""Predict click coordinates using Moondream3's point API.
|
||||
|
||||
Returns pixel coordinates (x, y) as floats.
|
||||
"""
|
||||
img = _decode_image_b64(image_b64)
|
||||
W, H = img.width, img.height
|
||||
model_md = get_moondream_model()
|
||||
try:
|
||||
result = model_md.point(img, instruction, settings={"max_objects": 1})
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
pt = (result or {}).get("points", [])[0]
|
||||
x_norm = float(pt.get("x", 0.0))
|
||||
y_norm = float(pt.get("y", 0.0))
|
||||
x_px = max(0.0, min(float(W - 1), x_norm * W))
|
||||
y_px = max(0.0, min(float(H - 1), y_norm * H))
|
||||
return (x_px, y_px)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
|
||||
return ["click", "step"]
|
||||
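A small sketch of the two conversions the class above relies on: splitting the composed model string, and mapping a normalized point to pixels. The values are assumptions:

model = "moondream3+openai/gpt-4o"
grounder, thinker = model.split("+", 1)
print(grounder, thinker)        # moondream3 openai/gpt-4o

x_norm, y_norm, W, H = 0.5, 0.25, 1920, 1080
print(x_norm * W, y_norm * H)   # -> 960.0 270.0 (pixel coordinates)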
@@ -5,100 +5,102 @@ Code: https://github.com/microsoft/OmniParser
"""

import asyncio
import json
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
import litellm
import inspect
import base64
import inspect
import json
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools

SOM_TOOL_SCHEMA = {
    "type": "function",
    "name": "computer",
    "description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": [
                    "screenshot",
                    "click",
                    "double_click",
                    "drag",
                    "type",
                    "keypress",
                    "scroll",
                    "move",
                    "wait",
                    "get_current_url",
                    "get_dimensions",
                    "get_environment"
                ],
                "description": "The action to perform"
            },
            "element_id": {
                "type": "integer",
                "description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
            },
            "start_element_id": {
                "type": "integer",
                "description": "The ID of the element to start dragging from (required for drag action)"
            },
            "end_element_id": {
                "type": "integer",
                "description": "The ID of the element to drag to (required for drag action)"
            },
            "text": {
                "type": "string",
                "description": "The text to type (required for type action)"
            },
            "keys": {
                "type": "string",
                "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
            },
            "button": {
                "type": "string",
                "description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
            },
            "scroll_x": {
                "type": "integer",
                "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
            },
            "scroll_y": {
                "type": "integer",
                "description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
            },
    "type": "function",
    "name": "computer",
    "description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": [
                    "screenshot",
                    "click",
                    "double_click",
                    "drag",
                    "type",
                    "keypress",
                    "scroll",
                    "move",
                    "wait",
                    "get_current_url",
                    "get_dimensions",
                    "get_environment",
                ],
                "description": "The action to perform",
            },
            "element_id": {
                "type": "integer",
                "description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)",
            },
            "start_element_id": {
                "type": "integer",
                "description": "The ID of the element to start dragging from (required for drag action)",
            },
            "end_element_id": {
                "type": "integer",
                "description": "The ID of the element to drag to (required for drag action)",
            },
            "text": {
                "type": "string",
                "description": "The text to type (required for type action)",
            },
            "keys": {
                "type": "string",
                "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')",
            },
            "button": {
                "type": "string",
                "description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
            },
            "scroll_x": {
                "type": "integer",
                "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
            },
            "scroll_y": {
                "type": "integer",
                "description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
            },
        },
        "required": ["action"],
    },
    "required": [
        "action"
    ]
    }
}

OMNIPARSER_AVAILABLE = False
try:
    from som import OmniParser

    OMNIPARSER_AVAILABLE = True
except ImportError:
    pass
OMNIPARSER_SINGLETON = None


def get_parser():
    global OMNIPARSER_SINGLETON
    if OMNIPARSER_SINGLETON is None:
        OMNIPARSER_SINGLETON = OmniParser()
    return OMNIPARSER_SINGLETON


def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Get the last computer_call_output message from a messages list.

    Args:
        messages: List of messages to search through

    Returns:
        The last computer_call_output message dict, or None if not found
    """
@@ -107,11 +109,12 @@ def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Di
            return message
    return None


def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[Tools, dict]:
    """Prepare tools for OpenAI API format"""
    omniparser_tools = []
    id2xy = dict()

    for schema in tool_schemas:
        if schema["type"] == "computer":
            omniparser_tools.append(SOM_TOOL_SCHEMA)
@@ -122,72 +125,80 @@ def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[T
        elif schema["type"] == "function":
            # Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
            # Schema should be: {type, name, description, parameters}
            omniparser_tools.append({ "type": "function", **schema["function"] })

            omniparser_tools.append({"type": "function", **schema["function"]})

    return omniparser_tools, id2xy
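A standalone sketch of how the id2xy map returned above gets filled from parsed element bboxes (the bbox values and element ID are assumptions):

class BBox:
    def __init__(self, x1, y1, x2, y2):
        self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2

id2xy = {}
bbox = BBox(100, 40, 180, 80)
id2xy[7] = ((bbox.x1 + bbox.x2) / 2, (bbox.y1 + bbox.y2) / 2)  # element center
print(id2xy)  # {7: (140.0, 60.0)}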
async def replace_function_with_computer_call(item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]):
    item_type = item.get("type")

    def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
        if element_id is None:
            return (None, None)
        return id2xy.get(element_id, (None, None))
async def replace_function_with_computer_call(
    item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]
):
    item_type = item.get("type")

    if item_type == "function_call":
        fn_name = item.get("name")
        fn_args = json.loads(item.get("arguments", "{}"))
    def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
        if element_id is None:
            return (None, None)
        return id2xy.get(element_id, (None, None))

        item_id = item.get("id")
        call_id = item.get("call_id")

        if fn_name == "computer":
            action = fn_args.get("action")
            element_id = fn_args.get("element_id")
            start_element_id = fn_args.get("start_element_id")
            end_element_id = fn_args.get("end_element_id")
            text = fn_args.get("text")
            keys = fn_args.get("keys")
            button = fn_args.get("button")
            scroll_x = fn_args.get("scroll_x")
            scroll_y = fn_args.get("scroll_y")
    if item_type == "function_call":
        fn_name = item.get("name")
        fn_args = json.loads(item.get("arguments", "{}"))

            x, y = _get_xy(element_id)
            start_x, start_y = _get_xy(start_element_id)
            end_x, end_y = _get_xy(end_element_id)
        item_id = item.get("id")
        call_id = item.get("call_id")

            action_args = {
                "type": action,
                "x": x,
                "y": y,
                "start_x": start_x,
                "start_y": start_y,
                "end_x": end_x,
                "end_y": end_y,
                "text": text,
                "keys": keys,
                "button": button,
                "scroll_x": scroll_x,
                "scroll_y": scroll_y
            }
            # Remove None values to keep the JSON clean
            action_args = {k: v for k, v in action_args.items() if v is not None}
        if fn_name == "computer":
            action = fn_args.get("action")
            element_id = fn_args.get("element_id")
            start_element_id = fn_args.get("start_element_id")
            end_element_id = fn_args.get("end_element_id")
            text = fn_args.get("text")
            keys = fn_args.get("keys")
            button = fn_args.get("button")
            scroll_x = fn_args.get("scroll_x")
            scroll_y = fn_args.get("scroll_y")

            return [{
                "type": "computer_call",
                "action": action_args,
                "id": item_id,
                "call_id": call_id,
                "status": "completed"
            }]
            x, y = _get_xy(element_id)
            start_x, start_y = _get_xy(start_element_id)
            end_x, end_y = _get_xy(end_element_id)

    return [item]
            action_args = {
                "type": action,
                "x": x,
                "y": y,
                "start_x": start_x,
                "start_y": start_y,
                "end_x": end_x,
                "end_y": end_y,
                "text": text,
                "keys": keys,
                "button": button,
                "scroll_x": scroll_x,
                "scroll_y": scroll_y,
            }
            # Remove None values to keep the JSON clean
            action_args = {k: v for k, v in action_args.items() if v is not None}

async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]):
            return [
                {
                    "type": "computer_call",
                    "action": action_args,
                    "id": item_id,
                    "call_id": call_id,
                    "status": "completed",
                }
            ]

    return [item]
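A worked example of the ID-to-coordinate substitution performed by the helper above (it assumes the function and id2xy map as defined in this file; the IDs and coordinates are illustrative):

import asyncio
import json

item = {
    "type": "function_call",
    "name": "computer",
    "arguments": json.dumps({"action": "click", "element_id": 7}),
    "id": "fc_1",
    "call_id": "call_1",
}
out = asyncio.run(replace_function_with_computer_call(item, {7: (140.0, 60.0)}))
print(out[0]["action"])  # {'type': 'click', 'x': 140.0, 'y': 60.0}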
async def replace_computer_call_with_function(
    item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]
):
    """
    Convert computer_call back to function_call format.
    Also handles computer_call_output -> function_call_output conversion.

    Args:
        item: The item to convert
        xy2id: Mapping from (x, y) coordinates to element IDs
@@ -202,12 +213,12 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[

    if item_type == "computer_call":
        action_data = item.get("action", {})

        # Extract coordinates and convert back to element IDs
        element_id = _get_element_id(action_data.get("x"), action_data.get("y"))
        start_element_id = _get_element_id(action_data.get("start_x"), action_data.get("start_y"))
        end_element_id = _get_element_id(action_data.get("end_x"), action_data.get("end_y"))

        # Build function arguments
        fn_args = {
            "action": action_data.get("type"),
@@ -218,33 +229,36 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
            "keys": action_data.get("keys"),
            "button": action_data.get("button"),
            "scroll_x": action_data.get("scroll_x"),
            "scroll_y": action_data.get("scroll_y")
            "scroll_y": action_data.get("scroll_y"),
        }

        # Remove None values to keep the JSON clean
        fn_args = {k: v for k, v in fn_args.items() if v is not None}

        return [{
            "type": "function_call",
            "name": "computer",
            "arguments": json.dumps(fn_args),
            "id": item.get("id"),
            "call_id": item.get("call_id"),
            "status": "completed",

            # Fall back to string representation
            "content": f"Used tool: {action_data.get('type')}({json.dumps(fn_args)})"
        }]

        return [
            {
                "type": "function_call",
                "name": "computer",
                "arguments": json.dumps(fn_args),
                "id": item.get("id"),
                "call_id": item.get("call_id"),
                "status": "completed",
                # Fall back to string representation
                "content": f"Used tool: {action_data.get('type')}({json.dumps(fn_args)})",
            }
        ]

    elif item_type == "computer_call_output":
        # Simple conversion: computer_call_output -> function_call_output
        return [{
            "type": "function_call_output",
            "call_id": item.get("call_id"),
            "content": [item.get("output")],
            "id": item.get("id"),
            "status": "completed"
        }]
        return [
            {
                "type": "function_call_output",
                "call_id": item.get("call_id"),
                "content": [item.get("output")],
                "id": item.get("id"),
                "status": "completed",
            }
        ]

    return [item]
@@ -252,7 +266,7 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
@register_agent(models=r"omniparser\+.*|omni\+.*", priority=2)
class OmniparserConfig(AsyncAgentConfig):
    """Omniparser agent configuration implementing AsyncAgentConfig protocol."""

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
@@ -266,25 +280,27 @@ class OmniparserConfig(AsyncAgentConfig):
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """
        OpenAI computer-use-preview agent loop using liteLLM responses.

        Supports OpenAI's computer use preview models.
        """
        if not OMNIPARSER_AVAILABLE:
            raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")

            raise ValueError(
                "omniparser loop requires som to be installed. Install it with `pip install cua-som`."
            )

        tools = tools or []

        llm_model = model.split('+')[-1]

        llm_model = model.split("+")[-1]

        # Prepare tools for OpenAI API
        openai_tools, id2xy = _prepare_tools_for_omniparser(tools)

        # Find last computer_call_output
        last_computer_call_output = get_last_computer_call_output(messages) # type: ignore
        last_computer_call_output = get_last_computer_call_output(messages)  # type: ignore
        if last_computer_call_output:
            image_url = last_computer_call_output.get("output", {}).get("image_url", "")
            image_data = image_url.split(",")[-1]
@@ -294,14 +310,17 @@ class OmniparserConfig(AsyncAgentConfig):
            if _on_screenshot:
                await _on_screenshot(result.annotated_image_base64, "annotated_image")
            for element in result.elements:
                id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)

                id2xy[element.id] = (
                    (element.bbox.x1 + element.bbox.x2) / 2,
                    (element.bbox.y1 + element.bbox.y2) / 2,
                )

        # handle computer calls -> function calls
        new_messages = []
        for message in messages:
            if not isinstance(message, dict):
                message = message.__dict__
            new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
            new_messages += await replace_computer_call_with_function(message, id2xy)  # type: ignore
        messages = new_messages

        # Prepare API call kwargs
@@ -312,13 +331,13 @@ class OmniparserConfig(AsyncAgentConfig):
            "stream": stream,
            "truncation": "auto",
            "num_retries": max_retries,
            **kwargs
            **kwargs,
        }

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)

        print(str(api_kwargs)[:1000])

        # Use liteLLM responses
@@ -330,60 +349,50 @@ class OmniparserConfig(AsyncAgentConfig):

        # Extract usage information
        usage = {
            **response.usage.model_dump(), # type: ignore
            "response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
            **response.usage.model_dump(),  # type: ignore
            "response_cost": response._hidden_params.get("response_cost", 0.0),  # type: ignore
        }
        if _on_usage:
            await _on_usage(usage)

        # handle som function calls -> xy computer calls
        new_output = []
        for i in range(len(response.output)): # type: ignore
            new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore

        return {
            "output": new_output,
            "usage": usage
        }

        for i in range(len(response.output)):  # type: ignore
            new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)  # type: ignore

        return {"output": new_output, "usage": usage}

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[float, float]]:
        """
        Predict click coordinates using OmniParser and LLM.

        Uses OmniParser to annotate the image with element IDs, then uses LLM
        to identify the correct element ID based on the instruction.
        """
        if not OMNIPARSER_AVAILABLE:
            return None

        # Parse the image with OmniParser to get annotated image and elements
        parser = get_parser()
        result = parser.parse(image_b64)

        # Extract the LLM model from composed model string
        llm_model = model.split('+')[-1]

        llm_model = model.split("+")[-1]

        # Create system prompt for element ID prediction
        SYSTEM_PROMPT = f'''
        SYSTEM_PROMPT = """
You are an expert UI element locator. Given a GUI image annotated with numerical IDs over each interactable element, along with a user's element description, provide the ID of the specified element.

The image shows UI elements with numbered overlays. Each number corresponds to a clickable/interactable element.

Output only the element ID as a single integer.
'''.strip()

""".strip()

        # Prepare messages for LLM
        messages = [
            {
                "role": "system",
                "content": SYSTEM_PROMPT
            },
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
@@ -391,31 +400,25 @@ Output only the element ID as a single integer.
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{result.annotated_image_base64}"
                        }
                        },
                    },
                    {
                        "type": "text",
                        "text": f"Find the element: {instruction}"
                    }
                ]
            }
                    {"type": "text", "text": f"Find the element: {instruction}"},
                ],
            },
        ]

        # Call LLM to predict element ID
        response = await litellm.acompletion(
            model=llm_model,
            messages=messages,
            max_tokens=10,
            temperature=0.1
            model=llm_model, messages=messages, max_tokens=10, temperature=0.1
        )

        # Extract element ID from response
        response_text = response.choices[0].message.content.strip() # type: ignore

        response_text = response.choices[0].message.content.strip()  # type: ignore

        # Try to parse the element ID
        try:
            element_id = int(response_text)

            # Find the element with this ID and return its center coordinates
            for element in result.elements:
                if element.id == element_id:
@@ -425,9 +428,9 @@ Output only the element ID as a single integer.
        except ValueError:
            # If we can't parse the ID, return None
            pass

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["step"]
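A standalone sketch of the ID-parsing step in predict_click above: the LLM answer is expected to be a bare integer, and anything else falls through to None (the sample answer is an assumption):

response_text = "7"
try:
    element_id = int(response_text)
    print("clicking element", element_id)
except ValueError:
    element_id = None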
@@ -6,12 +6,14 @@ import asyncio
import base64
import json
from io import BytesIO
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm
from PIL import Image

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..types import AgentCapability, AgentResponse, Messages, Tools


async def _map_computer_tool_to_openai(computer_handler: Any) -> Dict[str, Any]:
    """Map a computer tool to OpenAI's computer-use-preview tool schema"""
@@ -21,26 +23,26 @@ async def _map_computer_tool_to_openai(computer_handler: Any) -> Dict[str, Any]:
    except Exception:
        # Fallback to default dimensions if method fails
        width, height = 1024, 768

    # Get environment from the computer handler
    try:
        environment = await computer_handler.get_environment()
    except Exception:
        # Fallback to default environment if method fails
        environment = "linux"

    return {
        "type": "computer_use_preview",
        "display_width": width,
        "display_height": height,
        "environment": environment  # mac, windows, linux, browser
        "environment": environment,  # mac, windows, linux, browser
    }
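A minimal sketch of the function-tool flattening used in the tool preparation below: the nested {"function": {...}} schema becomes a flat {type, name, description, parameters} entry. The example schema is an assumption:

schema = {
    "type": "function",
    "function": {
        "name": "open_app",
        "description": "Open an application",
        "parameters": {"type": "object", "properties": {}},
    },
}
flat = {"type": "function", **schema["function"]}
print(flat["type"], flat["name"])  # function open_app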
async def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
    """Prepare tools for OpenAI API format"""
    openai_tools = []

    for schema in tool_schemas:
        if schema["type"] == "computer":
            # Map computer tool to OpenAI format
@@ -49,19 +51,19 @@ async def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools
        elif schema["type"] == "function":
            # Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
            # Schema should be: {type, name, description, parameters}
            openai_tools.append({ "type": "function", **schema["function"] })

            openai_tools.append({"type": "function", **schema["function"]})

    return openai_tools


@register_agent(models=r".*computer-use-preview.*")
@register_agent(models=r".*(^|/)computer-use-preview")
class OpenAIComputerUseConfig:
    """
    OpenAI computer-use-preview agent configuration using liteLLM responses.

    Supports OpenAI's computer use preview models.
    """

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
@@ -75,11 +77,11 @@ class OpenAIComputerUseConfig:
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Predict the next step based on input items.

        Args:
            messages: Input items following Responses format
            model: Model name to use
@@ -92,12 +94,12 @@ class OpenAIComputerUseConfig:
            _on_usage: Callback for usage tracking
            _on_screenshot: Callback for screenshot events
            **kwargs: Additional arguments

        Returns:
            Dictionary with "output" (output items) and "usage" array
        """
        tools = tools or []

        # Prepare tools for OpenAI API
        openai_tools = await _prepare_tools_for_openai(tools)

@@ -110,16 +112,16 @@ class OpenAIComputerUseConfig:
            "reasoning": {"summary": "concise"},
            "truncation": "auto",
            "num_retries": max_retries,
            **kwargs
            **kwargs,
        }

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)

        # Use liteLLM responses
        response = await litellm.aresponses(**api_kwargs)

        # Call API end hook
        if _on_api_end:
            await _on_api_end(api_kwargs, response)
@@ -136,24 +138,21 @@ class OpenAIComputerUseConfig:
        output_dict = response.model_dump()
        output_dict["usage"] = usage
        return output_dict

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str
        self, model: str, image_b64: str, instruction: str
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.

        Uses OpenAI computer-use-preview with manually constructed input items
        and a prompt that instructs the agent to only output clicks.

        Args:
            model: Model name to use
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            Tuple of (x, y) coordinates or None if prediction fails
        """
@@ -161,7 +160,7 @@ class OpenAIComputerUseConfig:
        # Manually construct input items with image and click instruction
        input_items = [
            {
                "role": "user",
                "role": "user",
                "content": f"""You are a UI grounding expert. Follow these guidelines:

1. NEVER ask for confirmation. Complete all tasks autonomously.
@@ -173,19 +172,16 @@ class OpenAIComputerUseConfig:
7. Be decisive and action-oriented. Complete the requested task fully.

Remember: You are expected to complete tasks autonomously. The user trusts you to do what they asked.
Task: Click {instruction}. Output ONLY a click action on the target element."""
Task: Click {instruction}. Output ONLY a click action on the target element.""",
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "input_image",
                        "image_url": f"data:image/png;base64,{image_b64}"
                    }
                ]
            }
                    {"type": "input_image", "image_url": f"data:image/png;base64,{image_b64}"}
                ],
            },
        ]

        # Get image dimensions from base64 data
        try:
            image_data = base64.b64decode(image_b64)
@@ -194,15 +190,15 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
        except Exception:
            # Fallback to default dimensions if image parsing fails
            display_width, display_height = 1024, 768

        # Prepare computer tool for click actions
        computer_tool = {
            "type": "computer_use_preview",
            "display_width": display_width,
            "display_height": display_height,
            "environment": "windows"
            "environment": "windows",
        }

        # Prepare API call kwargs
        api_kwargs = {
            "model": model,
@@ -211,32 +207,34 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
            "stream": False,
            "reasoning": {"summary": "concise"},
            "truncation": "auto",
            "max_tokens": 200 # Keep response short for click prediction
            "max_tokens": 200,  # Keep response short for click prediction
        }

        # Use liteLLM responses
        response = await litellm.aresponses(**api_kwargs)

        # Extract click coordinates from response output
        output_dict = response.model_dump()
        output_items = output_dict.get("output", [])

        output_items = output_dict.get("output", [])

        # Look for computer_call with click action
        for item in output_items:
            if (isinstance(item, dict) and
                item.get("type") == "computer_call" and
                isinstance(item.get("action"), dict)):

            if (
                isinstance(item, dict)
                and item.get("type") == "computer_call"
                and isinstance(item.get("action"), dict)
            ):

                action = item["action"]
                if action.get("x") is not None and action.get("y") is not None:
                    return (int(action.get("x")), int(action.get("y")))

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by this agent config.

        Returns:
            List of capability strings
        """
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import base64
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
|
||||
from io import BytesIO
|
||||
import uuid
|
||||
from PIL import Image
|
||||
import litellm
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from PIL import Image
|
||||
|
||||
from .composed_grounded import ComposedGroundedConfig
|
||||
from ..decorators import register_agent
|
||||
from ..types import Messages, AgentResponse, Tools, AgentCapability
|
||||
from ..loops.base import AsyncAgentConfig
|
||||
from ..types import AgentCapability, AgentResponse, Messages, Tools
|
||||
from .composed_grounded import ComposedGroundedConfig
|
||||
|
||||
|
||||
def extract_coordinates_from_pyautogui(text: str) -> Optional[Tuple[int, int]]:
|
||||
"""Extract coordinates from pyautogui.click(x=..., y=...) format."""
|
||||
@@ -32,10 +34,11 @@ def extract_coordinates_from_pyautogui(text: str) -> Optional[Tuple[int, int]]:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@register_agent(models=r"(?i).*OpenCUA.*")
|
||||
class OpenCUAConfig(ComposedGroundedConfig):
|
||||
"""OpenCUA agent configuration implementing AsyncAgentConfig protocol for click prediction."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.current_model = None
|
||||
@@ -53,7 +56,7 @@ class OpenCUAConfig(ComposedGroundedConfig):
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""Fallback to a self-composed model"""
|
||||
return await super().predict_step(
|
||||
@@ -67,24 +70,20 @@ class OpenCUAConfig(ComposedGroundedConfig):
|
||||
_on_api_end=_on_api_end,
|
||||
_on_usage=_on_usage,
|
||||
_on_screenshot=_on_screenshot,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def predict_click(
|
||||
self,
|
||||
model: str,
|
||||
image_b64: str,
|
||||
instruction: str,
|
||||
**kwargs
|
||||
self, model: str, image_b64: str, instruction: str, **kwargs
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
"""
|
||||
Predict click coordinates using OpenCUA model via litellm.acompletion.
|
||||
|
||||
|
||||
Args:
|
||||
model: The OpenCUA model name
|
||||
image_b64: Base64 encoded image
|
||||
instruction: Instruction for where to click
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates or None if prediction fails
|
||||
"""
|
||||
@@ -93,50 +92,39 @@ class OpenCUAConfig(ComposedGroundedConfig):
|
||||
"You are a GUI agent. You are given a task and a screenshot of the screen. "
|
||||
"You need to perform a series of pyautogui actions to complete the task."
|
||||
)
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
}
|
||||
|
||||
|
||||
system_message = {"role": "system", "content": system_prompt}
|
||||
|
||||
# Prepare user message with image and instruction
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{image_b64}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Click on {instruction}"
|
||||
}
|
||||
]
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
|
||||
{"type": "text", "text": f"Click on {instruction}"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": [system_message, user_message],
|
||||
"max_new_tokens": 2056,
|
||||
"temperature": 0,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
# Use liteLLM acompletion
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
|
||||
# Extract response text
|
||||
output_text = response.choices[0].message.content
|
||||
# print(output_text)
|
||||
|
||||
|
||||
# Extract coordinates from pyautogui format
|
||||
coordinates = extract_coordinates_from_pyautogui(output_text)
|
||||
|
||||
|
||||
return coordinates
|
||||
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
|
||||
"""Return the capabilities supported by this agent."""
|
||||
return ["click"]
|
||||
|
||||
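# Illustrative sketch, not part of the commit above: OpenCUA-style models reply
# with pyautogui calls, so a regex along these lines recovers the click point.
# The exact pattern used by extract_coordinates_from_pyautogui is elided in the
# hunk above, so this regex is an assumption.
import re

sample = "pyautogui.click(x=412, y=230)"
m = re.search(r"pyautogui\.click\(x=(\d+),\s*y=(\d+)\)", sample)
if m:
    print((int(m.group(1)), int(m.group(2))))  # (412, 230)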
510
libs/python/agent/agent/loops/qwen.py
Normal file
@@ -0,0 +1,510 @@
"""
Qwen3-VL agent loop implementation using litellm with function/tool calling.
- Passes a ComputerUse tool schema to acompletion
- Converts between Responses items and completion messages using helpers
"""

from __future__ import annotations

import json
import re
from typing import Any, Dict, List, Optional, Tuple

import litellm
from litellm.responses.litellm_completion_transformation.transformation import (
    LiteLLMCompletionResponsesConfig,
)

from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
    convert_completion_messages_to_responses_items,
    convert_responses_items_to_completion_messages,
)
from ..types import AgentCapability

# ComputerUse tool schema (OpenAI function tool format)
QWEN3_COMPUTER_TOOL: Dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "computer",
        "description": (
            "Use a mouse and keyboard to interact with a computer, and take screenshots.\n"
            "* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n"
            "* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n"
            "* The screen's resolution is 1000x1000.\n"
            "* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n"
            "* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n"
            "* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "action": {
                    "description": "The action to perform.",
                    "enum": [
                        "key",
                        "type",
                        "mouse_move",
                        "left_click",
                        "left_click_drag",
                        "right_click",
                        "middle_click",
                        "double_click",
                        "triple_click",
                        "scroll",
                        "hscroll",
                        "screenshot",
                        "wait",
                        # "terminate",
                        # "answer",
                    ],
                    "type": "string",
                },
                "keys": {
                    "description": "Required only by action=key.",
                    "type": "array",
                    "items": {"type": "string"},
                },
                "text": {
                    "description": "Required only by action=type and action=answer.",
                    "type": "string",
                },
                "coordinate": {
                    "description": "(x, y): Pixel coordinates from top-left.",
                    "type": "array",
                    "items": {"type": ["number", "integer"]},
                    "minItems": 2,
                    "maxItems": 2,
                },
                "pixels": {
                    "description": "Scroll amount. Positive=up, negative=down. For scroll/hscroll.",
                    "type": "number",
                },
                "time": {
                    "description": "Seconds to wait (action=wait).",
                    "type": "number",
                },
                # "status": {
                #     "description": "Task status (action=terminate).",
                #     "type": "string",
                #     "enum": ["success", "failure"],
                # },
            },
            "required": ["action"],
        },
    },
}
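# Illustrative sketch, not part of the new file: with the schema above embedded
# in the system prompt, the model is expected to answer with a <tool_call> JSON
# payload. Both examples below are hypothetical but conform to the declared
# parameters.
example_reply = (
    '<tool_call>{"name": "computer", '
    '"arguments": {"action": "left_click", "coordinate": [114, 68]}}</tool_call>'
)
example_scroll_args = {"action": "scroll", "coordinate": [500, 500], "pixels": -120}  # scroll down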
def _build_nous_system(functions: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Use qwen-agent NousFnCallPrompt to generate a system message embedding tool schema."""
    try:
        from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
            ContentItem as NousContentItem,
        )
        from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
            Message as NousMessage,
        )
        from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
            NousFnCallPrompt,
        )
    except ImportError:
        raise ImportError(
            "qwen-agent not installed. Please install it with `pip install cua-agent[qwen]`."
        )
    msgs = NousFnCallPrompt().preprocess_fncall_messages(
        messages=[
            NousMessage(
                role="system", content=[NousContentItem(text="You are a helpful assistant.")]
            )
        ],
        functions=functions,
        lang="en",
    )
    sys = msgs[0].model_dump()
    # Convert qwen-agent structured content to OpenAI-style content list
    content = [{"type": "text", "text": c["text"]} for c in sys.get("content", [])]
    return {"role": "system", "content": content}


def _parse_tool_call_from_text(text: str) -> Optional[Dict[str, Any]]:
    """Extract JSON object within <tool_call>...</tool_call> from model text."""
    m = re.search(r"<tool_call>\s*(\{[\s\S]*?\})\s*</tool_call>", text)
    if not m:
        return None
    try:
        return json.loads(m.group(1))
    except Exception:
        return None
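# Illustrative usage, not part of the new file: _parse_tool_call_from_text pulls
# the first JSON object out of the <tool_call>...</tool_call> markers.
#   _parse_tool_call_from_text(
#       'Clicking now. <tool_call>{"name": "computer", '
#       '"arguments": {"action": "left_click", "coordinate": [114, 68]}}</tool_call>'
#   )
#   -> {"name": "computer", "arguments": {"action": "left_click", "coordinate": [114, 68]}}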
async def _unnormalize_coordinate(args: Dict[str, Any], dims: Tuple[int, int]) -> Dict[str, Any]:
    """Coordinates appear in 0..1000 space, scale to actual screen size using dims if provided."""
    coord = args.get("coordinate")
    if not coord or not isinstance(coord, (list, tuple)) or len(coord) < 2:
        return args
    x, y = float(coord[0]), float(coord[1])
    width, height = float(dims[0]), float(dims[1])
    x_abs = max(0.0, min(width, (x / 1000.0) * width))
    y_abs = max(0.0, min(height, (y / 1000.0) * height))
    args = {**args, "coordinate": [round(x_abs), round(y_abs)]}
    return args
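# Worked example, not part of the new file: dims is (width, height), so on a
# 1920x1080 screen the model's 0..1000 reference point (500, 250) maps to
#   x_abs = (500 / 1000) * 1920 = 960.0, y_abs = (250 / 1000) * 1080 = 270.0
# i.e. await _unnormalize_coordinate({"coordinate": [500, 250]}, (1920, 1080))
# returns {"coordinate": [960, 270]}.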
def convert_qwen_tool_args_to_computer_action(args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Convert Qwen computer tool arguments to the Computer Calls action schema.

    Qwen (example):
        {"action": "left_click", "coordinate": [114, 68]}

    Target (example):
        {"action": "left_click", "x": 114, "y": 68}

    Other mappings:
    - right_click, middle_click, double_click (triple_click -> double_click)
    - mouse_move -> { action: "move", x, y }
    - key -> { action: "keypress", keys: [...] }
    - type -> { action: "type", text }
    - scroll/hscroll -> { action: "scroll", scroll_x, scroll_y, x, y }
    - wait -> { action: "wait" }
    - terminate/answer are not direct UI actions; return None for now
    """
    if not isinstance(args, dict):
        return None

    action = args.get("action")
    if not isinstance(action, str):
        return None

    # Coordinates helper
    coord = args.get("coordinate")
    x = y = None
    if isinstance(coord, (list, tuple)) and len(coord) >= 2:
        try:
            x = int(round(float(coord[0])))
            y = int(round(float(coord[1])))
        except Exception:
            x = y = None

    # Map actions
    a = action.lower()
    if a in {"left_click", "right_click", "middle_click", "double_click"}:
        if x is None or y is None:
            return None
        return {"action": a, "x": x, "y": y}
    if a == "triple_click":
        # Approximate as double_click
        if x is None or y is None:
            return None
        return {"action": "double_click", "x": x, "y": y}
    if a == "mouse_move":
        if x is None or y is None:
            return None
        return {"action": "move", "x": x, "y": y}
    if a == "key":
        keys = args.get("keys")
        if isinstance(keys, list) and all(isinstance(k, str) for k in keys):
            return {"action": "keypress", "keys": keys}
        return None
    if a == "type":
        text = args.get("text")
        if isinstance(text, str):
            return {"action": "type", "text": text}
        return None
    if a in {"scroll", "hscroll"}:
        pixels = args.get("pixels") or 0
        try:
            pixels_val = int(round(float(pixels)))
        except Exception:
            pixels_val = 0
        scroll_x = pixels_val if a == "hscroll" else 0
        scroll_y = pixels_val if a == "scroll" else 0
        # Include cursor position if available (optional)
        out: Dict[str, Any] = {"action": "scroll", "scroll_x": scroll_x, "scroll_y": scroll_y}
        if x is not None and y is not None:
            out.update({"x": x, "y": y})
        return out
    if a == "wait":
        return {"action": "wait"}

    # Non-UI or terminal actions: terminate/answer -> not mapped here
    return None
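# Illustrative mappings, not part of the new file, following the docstring above:
#   {"action": "left_click", "coordinate": [114, 68]} -> {"action": "left_click", "x": 114, "y": 68}
#   {"action": "key", "keys": ["ctrl", "c"]}          -> {"action": "keypress", "keys": ["ctrl", "c"]}
#   {"action": "scroll", "pixels": -120}              -> {"action": "scroll", "scroll_x": 0, "scroll_y": -120}
#   {"action": "terminate", "status": "success"}      -> None (not a direct UI action)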
@register_agent(models=r"(?i).*qwen.*", priority=-1)
|
||||
class Qwen3VlConfig(AsyncAgentConfig):
|
||||
async def predict_step(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
# Build messages using NousFnCallPrompt system with tool schema in text
|
||||
# Start with converted conversation (images/text preserved)
|
||||
converted_msgs = convert_responses_items_to_completion_messages(
|
||||
messages,
|
||||
allow_images_in_tool_results=False,
|
||||
)
|
||||
|
||||
# Prepend Nous-generated system if available
|
||||
nous_system = _build_nous_system([QWEN3_COMPUTER_TOOL["function"]])
|
||||
completion_messages = ([nous_system] if nous_system else []) + converted_msgs
|
||||
|
||||
# If there is no screenshot in the conversation, take one now and inject it.
|
||||
# Also record a pre_output_items assistant message to reflect action.
|
||||
def _has_any_image(msgs: List[Dict[str, Any]]) -> bool:
|
||||
for m in msgs:
|
||||
content = m.get("content")
|
||||
if isinstance(content, list):
|
||||
for p in content:
|
||||
if isinstance(p, dict) and p.get("type") == "image_url":
|
||||
return True
|
||||
return False
|
||||
|
||||
pre_output_items: List[Dict[str, Any]] = []
|
||||
if not _has_any_image(completion_messages):
|
||||
if computer_handler is None or not hasattr(computer_handler, "screenshot"):
|
||||
raise RuntimeError(
|
||||
"No screenshots present and computer_handler.screenshot is not available."
|
||||
)
|
||||
screenshot_b64 = await computer_handler.screenshot()
|
||||
if not screenshot_b64:
|
||||
raise RuntimeError("Failed to capture screenshot from computer_handler.")
|
||||
# Inject a user message with the screenshot so the model can see current context
|
||||
completion_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{screenshot_b64}"},
|
||||
},
|
||||
{"type": "text", "text": "Current screen"},
|
||||
],
|
||||
}
|
||||
)
|
||||
# Add assistant message to outputs to reflect the action, similar to composed_grounded.py
|
||||
pre_output_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Taking a screenshot to see the current computer screen.",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Smart-resize all screenshots and attach min/max pixel hints. Fail fast if deps missing.
|
||||
# Also record the last resized width/height to unnormalize coordinates later.
|
||||
last_rw: Optional[int] = None
|
||||
last_rh: Optional[int] = None
|
||||
MIN_PIXELS = 3136
|
||||
MAX_PIXELS = 12845056
|
||||
try:
|
||||
import base64
|
||||
import io
|
||||
|
||||
from PIL import Image # type: ignore
|
||||
from qwen_vl_utils import smart_resize # type: ignore
|
||||
except Exception:
|
||||
raise ImportError(
|
||||
"qwen-vl-utils not installed. Please install it with `pip install cua-agent[qwen]`."
|
||||
)
|
||||
|
||||
for msg in completion_messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "image_url":
|
||||
url = ((part.get("image_url") or {}).get("url")) or ""
|
||||
# Expect data URL like data:image/png;base64,<b64>
|
||||
if url.startswith("data:") and "," in url:
|
||||
b64 = url.split(",", 1)[1]
|
||||
img_bytes = base64.b64decode(b64)
|
||||
im = Image.open(io.BytesIO(img_bytes))
|
||||
h, w = im.height, im.width
|
||||
rh, rw = smart_resize(
|
||||
h, w, factor=32, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS
|
||||
)
|
||||
# Attach hints on this image block
|
||||
part["min_pixels"] = MIN_PIXELS
|
||||
part["max_pixels"] = MAX_PIXELS
|
||||
last_rw, last_rh = rw, rh
|
||||
|
||||
api_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": completion_messages,
|
||||
"max_retries": max_retries,
|
||||
"stream": stream,
|
||||
**{k: v for k, v in kwargs.items()},
|
||||
}
|
||||
if use_prompt_caching:
|
||||
api_kwargs["use_prompt_caching"] = use_prompt_caching
|
||||
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
usage = {
|
||||
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage( # type: ignore
|
||||
response.usage
|
||||
).model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(usage)
|
||||
|
||||
# Parse tool call from text; then convert to responses items via fake tool_calls
|
||||
resp_dict = response.model_dump() # type: ignore
|
||||
choice = (resp_dict.get("choices") or [{}])[0]
|
||||
content_text = ((choice.get("message") or {}).get("content")) or ""
|
||||
tool_call = _parse_tool_call_from_text(content_text)
|
||||
|
||||
output_items: List[Dict[str, Any]] = []
|
||||
if tool_call and isinstance(tool_call, dict):
|
||||
fn_name = tool_call.get("name") or "computer"
|
||||
raw_args = tool_call.get("arguments") or {}
|
||||
# Unnormalize coordinates to actual screen size using last resized dims
|
||||
if last_rw is None or last_rh is None:
|
||||
raise RuntimeError(
|
||||
"No screenshots found to derive dimensions for coordinate unnormalization."
|
||||
)
|
||||
args = await _unnormalize_coordinate(raw_args, (last_rw, last_rh))
|
||||
|
||||
# Build an OpenAI-style tool call so we can reuse the converter
|
||||
fake_cm = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"id": "call_0",
|
||||
"function": {
|
||||
"name": fn_name,
|
||||
"arguments": json.dumps(args),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
|
||||
else:
|
||||
# Fallback: just return assistant text
|
||||
fake_cm = {"role": "assistant", "content": content_text}
|
||||
output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
|
||||
|
||||
# Prepend any pre_output_items (e.g., simulated screenshot-taking message)
|
||||
return {"output": (pre_output_items + output_items), "usage": usage}
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
|
||||
return ["step"]
|
||||
|
||||
async def predict_click(
|
||||
self, model: str, image_b64: str, instruction: str, **kwargs
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
"""
|
||||
Predict click coordinates using Qwen3-VL via litellm.acompletion.
|
||||
|
||||
Only exposes a reduced tool schema with left_click to bias model to output a single click.
|
||||
Returns (x, y) absolute pixels when screen dimensions can be obtained; otherwise normalized 0..1000 integers.
|
||||
"""
|
||||
# Reduced tool
|
||||
reduced_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
**QWEN3_COMPUTER_TOOL["function"],
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {"type": "string", "enum": ["left_click"]},
|
||||
"coordinate": {
|
||||
"description": "(x, y) in 0..1000 reference space",
|
||||
"type": "array",
|
||||
"items": {"type": ["number", "integer"]},
|
||||
"minItems": 2,
|
||||
"maxItems": 2,
|
||||
},
|
||||
},
|
||||
"required": ["action", "coordinate"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Build Nous system (lazy import inside helper already raises clear guidance if missing)
|
||||
nous_system = _build_nous_system([reduced_tool["function"]])
|
||||
|
||||
# Pre-process using smart_resize
|
||||
min_pixels = 3136
|
||||
max_pixels = 12845056
|
||||
try:
|
||||
# Lazy import to avoid hard dependency
|
||||
import base64
|
||||
import io
|
||||
|
||||
# If PIL is available, estimate size from image to derive smart bounds
|
||||
from PIL import Image
|
||||
from qwen_vl_utils import smart_resize # type: ignore
|
||||
|
||||
img_bytes = base64.b64decode(image_b64)
|
||||
im = Image.open(io.BytesIO(img_bytes))
|
||||
h, w = im.height, im.width
|
||||
# Qwen notebook suggests factor=32 and a wide min/max range
|
||||
rh, rw = smart_resize(h, w, factor=32, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
except Exception:
|
||||
raise ImportError(
|
||||
"qwen-vl-utils not installed. Please install it with `pip install cua-agent[qwen]`."
|
||||
)
|
||||
|
||||
messages = []
|
||||
if nous_system:
|
||||
messages.append(nous_system)
|
||||
image_block: Dict[str, Any] = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
|
||||
"min_pixels": min_pixels,
|
||||
"max_pixels": max_pixels,
|
||||
}
|
||||
# Single user message with image and instruction, matching OpenAI-style content blocks
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
image_block,
|
||||
{"type": "text", "text": instruction},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
api_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**{k: v for k, v in kwargs.items()},
|
||||
}
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
resp = response.model_dump() # type: ignore
|
||||
choice = (resp.get("choices") or [{}])[0]
|
||||
content_text = ((choice.get("message") or {}).get("content")) or ""
|
||||
tool_call = _parse_tool_call_from_text(content_text) or {}
|
||||
args = tool_call.get("arguments") or {}
|
||||
args = await _unnormalize_coordinate(args, (rh, rw))
|
||||
coord = args.get("coordinate")
|
||||
if isinstance(coord, (list, tuple)) and len(coord) >= 2:
|
||||
return int(coord[0]), int(coord[1])
|
||||
return None
|
||||
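# Hypothetical usage sketch, not part of the new file; the model string is an
# assumption and depends on how litellm is configured in your deployment:
#   config = Qwen3VlConfig()
#   coords = await config.predict_click(
#       model="openrouter/qwen/qwen3-vl-instruct",  # hypothetical model id
#       image_b64=screenshot_b64,
#       instruction="the Save button",
#   )  # -> e.g. (960, 270), or None if no tool call was parsed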
@@ -4,39 +4,50 @@ Paper: https://arxiv.org/abs/2501.12326
Code: https://github.com/bytedance/UI-TARS
"""

import ast
import asyncio
from ctypes import cast
import json
import base64
import json
import math
import re
import ast
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from ctypes import cast
from io import BytesIO
from PIL import Image
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import litellm
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from litellm.responses.litellm_completion_transformation.transformation import (
    LiteLLMCompletionResponsesConfig,
)
from litellm.responses.utils import Usage
from openai.types.responses.response_computer_tool_call_param import ActionType, ResponseComputerToolCallParam
from litellm.types.utils import ModelResponse
from openai.types.responses.response_computer_tool_call_param import (
    ActionType,
    ResponseComputerToolCallParam,
)
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
from openai.types.responses.response_output_message_param import (
    ResponseOutputMessageParam,
)
from openai.types.responses.response_reasoning_item_param import (
    ResponseReasoningItemParam,
    Summary,
)
from PIL import Image

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..responses import (
    make_reasoning_item,
    make_output_text_item,
    make_click_item,
    make_double_click_item,
    make_drag_item,
    make_input_image_item,
    make_keypress_item,
    make_output_text_item,
    make_reasoning_item,
    make_scroll_item,
    make_type_item,
    make_wait_item,
    make_input_image_item
)
from ..types import AgentCapability, AgentResponse, Messages, Tools

# Constants from reference code
IMAGE_FACTOR = 28
@@ -94,6 +105,7 @@ click(point='<|box_start|>(x1,y1)<|box_end|>')
## User Instruction
{instruction}"""


def round_by_factor(number: float, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor
@@ -110,7 +122,11 @@ def floor_by_factor(number: float, factor: int) -> int:


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
@@ -144,14 +160,14 @@ def escape_single_quotes(text):
def parse_action(action_str):
    """Parse action string into structured format."""
    try:
        node = ast.parse(action_str, mode='eval')
        node = ast.parse(action_str, mode="eval")
        if not isinstance(node, ast.Expression):
            raise ValueError("Not an expression")

        call = node.body
        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")

        # Get function name
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
@@ -159,7 +175,7 @@ def parse_action(action_str):
            func_name = call.func.attr
        else:
            func_name = None

        # Get keyword arguments
        kwargs = {}
        for kw in call.keywords:
@@ -171,12 +187,9 @@ def parse_action(action_str):
            else:
                value = None
            kwargs[key] = value

        return {
            'function': func_name,
            'args': kwargs
        }

        return {"function": func_name, "args": kwargs}

    except Exception as e:
        print(f"Failed to parse action '{action_str}': {e}")
        return None
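# Illustrative usage, not part of the diff: parse_action reads a single
# UITARS-style call expression with ast, e.g.
#   parse_action("click(start_box='(100,200)')")
#   -> {"function": "click", "args": {"start_box": "(100,200)"}}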
@@ -185,39 +198,39 @@ def parse_action(action_str):
def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
    """Parse UITARS model response into structured actions."""
    text = text.strip()

    # Extract thought
    thought = None
    if text.startswith("Thought:"):
        thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", text, re.DOTALL)
        if thought_match:
            thought = thought_match.group(1).strip()

    # Extract action
    if "Action:" not in text:
        raise ValueError("No Action found in response")

    action_str = text.split("Action:")[-1].strip()

    # Handle special case for type actions
    if "type(content" in action_str:

        def escape_quotes(match):
            return match.group(1)

        pattern = r"type\(content='(.*?)'\)"
        content = re.sub(pattern, escape_quotes, action_str)
        action_str = escape_single_quotes(content)
        action_str = "type(content='" + action_str + "')"

    # Parse the action
    parsed_action = parse_action(action_str.replace("\n", "\\n").lstrip())
    if parsed_action is None:
        raise ValueError(f"Action can't parse: {action_str}")

    action_type = parsed_action["function"]
    params = parsed_action["args"]

    # Process parameters
    action_inputs = {}
    for param_name, param in params.items():
@@ -225,7 +238,7 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> Lis
            continue
        param = str(param).lstrip()
        action_inputs[param_name.strip()] = param

        # Handle coordinate parameters
        if "start_box" in param_name or "end_box" in param_name:
            # Parse coordinates like '<|box_start|>(x,y)<|box_end|>' or '(x,y)'
@@ -233,117 +246,130 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> Lis
            clean_param = param.replace("<|box_start|>", "").replace("<|box_end|>", "")
            # Then remove parentheses and split
            numbers = clean_param.replace("(", "").replace(")", "").split(",")

            try:
                float_numbers = [float(num.strip()) / 1000 for num in numbers]  # Normalize to 0-1 range
                float_numbers = [
                    float(num.strip()) / 1000 for num in numbers
                ]  # Normalize to 0-1 range

                if len(float_numbers) == 2:
                    # Single point, duplicate for box format
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                    float_numbers = [
                        float_numbers[0],
                        float_numbers[1],
                        float_numbers[0],
                        float_numbers[1],
                    ]

                action_inputs[param_name.strip()] = str(float_numbers)
            except ValueError as e:
                # If parsing fails, keep the original parameter value
                print(f"Warning: Could not parse coordinates '{param}': {e}")
                action_inputs[param_name.strip()] = param

    return [{
        "thought": thought,
        "action_type": action_type,
        "action_inputs": action_inputs,
        "text": text
    }]
    return [
        {
            "thought": thought,
            "action_type": action_type,
            "action_inputs": action_inputs,
            "text": text,
        }
    ]
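# Illustrative example, not part of the diff: a typical UITARS reply such as
#   "Thought: I need to open the browser.\nAction: click(start_box='<|box_start|>(512,384)<|box_end|>')"
# parses into a single entry whose box coordinates are normalized to 0-1:
#   [{"thought": "I need to open the browser.", "action_type": "click",
#     "action_inputs": {"start_box": "[0.512, 0.384, 0.512, 0.384]"}, "text": "..."}]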
def convert_to_computer_actions(parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
def convert_to_computer_actions(
    parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int
) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
    """Convert parsed UITARS responses to computer actions."""
    computer_actions = []

    for response in parsed_responses:
        action_type = response.get("action_type")
        action_inputs = response.get("action_inputs", {})

        if action_type == "finished":
            finished_text = action_inputs.get("content", "Task completed successfully.")
            computer_actions.append(make_output_text_item(finished_text))
            break

        elif action_type == "wait":
            computer_actions.append(make_wait_item())

        elif action_type == "call_user":
            computer_actions.append(make_output_text_item("I need assistance from the user to proceed with this task."))
            computer_actions.append(
                make_output_text_item("I need assistance from the user to proceed with this task.")
            )

        elif action_type in ["click", "left_single"]:
            start_box = action_inputs.get("start_box")
            if start_box:
                coords = eval(start_box)
                x = int((coords[0] + coords[2]) / 2 * image_width)
                y = int((coords[1] + coords[3]) / 2 * image_height)

                computer_actions.append(make_click_item(x, y, "left"))

        elif action_type == "double_click":
            start_box = action_inputs.get("start_box")
            if start_box:
                coords = eval(start_box)
                x = int((coords[0] + coords[2]) / 2 * image_width)
                y = int((coords[1] + coords[3]) / 2 * image_height)

                computer_actions.append(make_double_click_item(x, y))

        elif action_type == "right_click":
            start_box = action_inputs.get("start_box")
            if start_box:
                coords = eval(start_box)
                x = int((coords[0] + coords[2]) / 2 * image_width)
                y = int((coords[1] + coords[3]) / 2 * image_height)

                computer_actions.append(make_click_item(x, y, "right"))

        elif action_type == "type":
            content = action_inputs.get("content", "")
            computer_actions.append(make_type_item(content))

        elif action_type == "hotkey":
            key = action_inputs.get("key", "")
            keys = key.split()
            computer_actions.append(make_keypress_item(keys))

        elif action_type == "press":
            key = action_inputs.get("key", "")
            computer_actions.append(make_keypress_item([key]))

        elif action_type == "scroll":
            start_box = action_inputs.get("start_box")
            direction = action_inputs.get("direction", "down")

            if start_box:
                coords = eval(start_box)
                x = int((coords[0] + coords[2]) / 2 * image_width)
                y = int((coords[1] + coords[3]) / 2 * image_height)
            else:
                x, y = image_width // 2, image_height // 2

            scroll_y = 5 if "up" in direction.lower() else -5
            computer_actions.append(make_scroll_item(x, y, 0, scroll_y))

        elif action_type == "drag":
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")

            if start_box and end_box:
                start_coords = eval(start_box)
                end_coords = eval(end_box)

                start_x = int((start_coords[0] + start_coords[2]) / 2 * image_width)
                start_y = int((start_coords[1] + start_coords[3]) / 2 * image_height)
                end_x = int((end_coords[0] + end_coords[2]) / 2 * image_width)
                end_y = int((end_coords[1] + end_coords[3]) / 2 * image_height)

                path = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
                computer_actions.append(make_drag_item(path))

    return computer_actions

@@ -354,33 +380,35 @@ def pil_to_base64(image: Image.Image) -> str:
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def process_image_for_uitars(image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS) -> tuple[Image.Image, int, int]:
def process_image_for_uitars(
    image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS
) -> tuple[Image.Image, int, int]:
    """Process image for UITARS model input."""
    # Decode base64 image
    if image_data.startswith('data:image'):
        image_data = image_data.split(',')[1]
    if image_data.startswith("data:image"):
        image_data = image_data.split(",")[1]

    image_bytes = base64.b64decode(image_data)
    image = Image.open(BytesIO(image_bytes))

    original_width, original_height = image.size

    # Resize image according to UITARS requirements
    if image.width * image.height > max_pixels:
        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
        width = int(image.width * resize_factor)
        height = int(image.height * resize_factor)
        image = image.resize((width, height))

    if image.width * image.height < min_pixels:
        resize_factor = math.sqrt(min_pixels / (image.width * image.height))
        width = math.ceil(image.width * resize_factor)
        height = math.ceil(image.height * resize_factor)
        image = image.resize((width, height))

    if image.mode != "RGB":
        image = image.convert("RGB")

    return image, original_width, original_height
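# Worked example, not part of the diff: a screenshot larger than max_pixels is
# shrunk by resize_factor = sqrt(max_pixels / (width * height)), which preserves
# the aspect ratio while capping the pixel count; the original width and height
# are still returned so callers can rescale model coordinates back to real
# screen pixels.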
@@ -391,7 +419,11 @@ def sanitize_message(msg: Any) -> Any:
    for key, value in msg.items():
        if key == "content" and isinstance(value, list):
            result[key] = [
                {k: v for k, v in item.items() if k != "image_url"} if isinstance(item, dict) else item
                (
                    {k: v for k, v in item.items() if k != "image_url"}
                    if isinstance(item, dict)
                    else item
                )
                for item in value
            ]
        else:
@@ -406,38 +438,41 @@ def sanitize_message(msg: Any) -> Any:
def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
    """
    Convert UITARS internal message format back to LiteLLM format.

    This function processes reasoning, computer_call, and computer_call_output messages
    and converts them to the appropriate LiteLLM assistant message format.

    Args:
        messages: List of UITARS internal messages

    Returns:
        List of LiteLLM formatted messages
    """
    litellm_messages = []
    current_assistant_content = []

    for message in messages:
        if isinstance(message, dict):
            message_type = message.get("type")

            if message_type == "reasoning":
                # Extract reasoning text from summary
                summary = message.get("summary", [])
                if summary and isinstance(summary, list):
                    for summary_item in summary:
                        if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
                        if (
                            isinstance(summary_item, dict)
                            and summary_item.get("type") == "summary_text"
                        ):
                            reasoning_text = summary_item.get("text", "")
                            if reasoning_text:
                                current_assistant_content.append(f"Thought: {reasoning_text}")

            elif message_type == "computer_call":
                # Convert computer action to UITARS action format
                action = message.get("action", {})
                action_type = action.get("type")

                if action_type == "click":
                    x, y = action.get("x", 0), action.get("y", 0)
                    button = action.get("button", "left")
@@ -447,59 +482,65 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
                        action_text = f"Action: right_single(start_box='({x},{y})')"
                    else:
                        action_text = f"Action: click(start_box='({x},{y})')"

                elif action_type == "double_click":
                    x, y = action.get("x", 0), action.get("y", 0)
                    action_text = f"Action: left_double(start_box='({x},{y})')"

                elif action_type == "drag":
                    start_x, start_y = action.get("start_x", 0), action.get("start_y", 0)
                    end_x, end_y = action.get("end_x", 0), action.get("end_y", 0)
                    action_text = f"Action: drag(start_box='({start_x},{start_y})', end_box='({end_x},{end_y})')"

                elif action_type == "key":
                    key = action.get("key", "")
                    action_text = f"Action: hotkey(key='{key}')"

                elif action_type == "type":
                    text = action.get("text", "")
                    # Escape single quotes in the text
                    escaped_text = escape_single_quotes(text)
                    action_text = f"Action: type(content='{escaped_text}')"

                elif action_type == "scroll":
                    x, y = action.get("x", 0), action.get("y", 0)
                    direction = action.get("direction", "down")
                    action_text = f"Action: scroll(start_box='({x},{y})', direction='{direction}')"

                elif action_type == "wait":
                    action_text = "Action: wait()"

                else:
                    # Fallback for unknown action types
                    action_text = f"Action: {action_type}({action})"

                current_assistant_content.append(action_text)

            # When we hit a computer_call_output, finalize the current assistant message
            if current_assistant_content:
                litellm_messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": "\n".join(current_assistant_content)}]
                })
                litellm_messages.append(
                    {
                        "role": "assistant",
                        "content": [
                            {"type": "text", "text": "\n".join(current_assistant_content)}
                        ],
                    }
                )
                current_assistant_content = []

            elif message_type == "computer_call_output":
                # Add screenshot from computer call output
                output = message.get("output", {})
                if isinstance(output, dict) and output.get("type") == "input_image":
                    image_url = output.get("image_url", "")
                    if image_url:
                        litellm_messages.append({
                            "role": "user",
                            "content": [{"type": "image_url", "image_url": {"url": image_url}}]
                        })
                        litellm_messages.append(
                            {
                                "role": "user",
                                "content": [{"type": "image_url", "image_url": {"url": image_url}}],
                            }
                        )

            elif message.get("role") == "user":
                # # Handle user messages
                # content = message.get("content", "")
@@ -514,24 +555,22 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
                #     "content": content
                # })
                pass

    # Add any remaining assistant content
    if current_assistant_content:
        litellm_messages.append({
            "role": "assistant",
            "content": current_assistant_content
        })
        litellm_messages.append({"role": "assistant", "content": current_assistant_content})

    return litellm_messages


@register_agent(models=r"(?i).*ui-?tars.*")
class UITARSConfig:
    """
    UITARS agent configuration using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.

    Supports UITARS vision-language models for computer control.
    """

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
@@ -545,11 +584,11 @@ class UITARSConfig:
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Predict the next step based on input messages.

        Args:
            messages: Input messages following Responses format
            model: Model name to use
@@ -562,22 +601,22 @@ class UITARSConfig:
            _on_usage: Callback for usage tracking
            _on_screenshot: Callback for screenshot events
            **kwargs: Additional arguments

        Returns:
            Dictionary with "output" (output items) and "usage" array
        """
        tools = tools or []

        # Create response items
        response_items = []

        # Find computer tool for screen dimensions
        computer_tool = None
        for tool_schema in tools:
            if tool_schema["type"] == "computer":
                computer_tool = tool_schema["computer"]
                break

        # Get screen dimensions
        screen_width, screen_height = 1024, 768
        if computer_tool:
@@ -585,20 +624,20 @@ class UITARSConfig:
                screen_width, screen_height = await computer_tool.get_dimensions()
            except:
                pass

        # Process messages to extract instruction and image
        instruction = ""
        image_data = None

        # Convert messages to list if string
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # Extract instruction and latest screenshot
        for message in reversed(messages):
            if isinstance(message, dict):
                content = message.get("content", "")

                # Handle different content formats
                if isinstance(content, str):
                    if not instruction and message.get("role") == "user":
@@ -614,46 +653,41 @@ class UITARSConfig:
                            image_data = image_url.get("url", "")
                        else:
                            image_data = image_url

                # Also check for computer_call_output with screenshots
                if message.get("type") == "computer_call_output" and not image_data:
                    output = message.get("output", {})
                    if isinstance(output, dict) and output.get("type") == "input_image":
                        image_data = output.get("image_url", "")

                if instruction and image_data:
                    break

        if not instruction:
            instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."
            instruction = (
                "Help me complete this task by analyzing the screen and taking appropriate actions."
            )

        # Create prompt
        user_prompt = UITARS_PROMPT_TEMPLATE.format(
            instruction=instruction,
            action_space=UITARS_ACTION_SPACE,
            language="English"
            instruction=instruction, action_space=UITARS_ACTION_SPACE, language="English"
        )

        # Convert conversation history to LiteLLM format
        history_messages = convert_uitars_messages_to_litellm(messages)

        # Prepare messages for liteLLM
        litellm_messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            }
        ]
        litellm_messages = [{"role": "system", "content": "You are a helpful assistant."}]

        # Add current user instruction with screenshot
        current_user_message = {
            "role": "user",
            "role": "user",
            "content": [
                {"type": "text", "text": user_prompt},
            ]
            ],
        }
        litellm_messages.append(current_user_message)

        # Process image for UITARS
        if not image_data:
            # Take screenshot if none found in messages
@@ -667,17 +701,22 @@ class UITARSConfig:
                raise ValueError("No screenshot found in messages and no computer_handler provided")
        processed_image, original_width, original_height = process_image_for_uitars(image_data)
        encoded_image = pil_to_base64(processed_image)

        # Add conversation history
        if history_messages:
            litellm_messages.extend(history_messages)
        else:
            litellm_messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
                ]
            })
            litellm_messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{encoded_image}"},
                        }
                    ],
                }
            )

        # Prepare API call kwargs
        api_kwargs = {
@@ -687,146 +726,142 @@ class UITARSConfig:
            "temperature": kwargs.get("temperature", 0.0),
            "do_sample": kwargs.get("temperature", 0.0) > 0.0,
            "num_retries": max_retries,
            **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
            **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
        }

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)

        # Call liteLLM with UITARS model
        response = await litellm.acompletion(**api_kwargs)

        # Call API end hook
        if _on_api_end:
            await _on_api_end(api_kwargs, response)

        # Extract response content
        response_content = response.choices[0].message.content.strip()  # type: ignore
        response_content = response.choices[0].message.content.strip()  # type: ignore

        # Parse UITARS response
        parsed_responses = parse_uitars_response(response_content, original_width, original_height)

        # Convert to computer actions
        computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)
        computer_actions = convert_to_computer_actions(
            parsed_responses, original_width, original_height
        )

        # Add computer actions to response items
        thought = parsed_responses[0].get("thought", "")
        if thought:
            response_items.append(make_reasoning_item(thought))
        response_items.extend(computer_actions)

        # Extract usage information
        response_usage = {
            **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
            **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
                response.usage
            ).model_dump(),
            "response_cost": response._hidden_params.get("response_cost", 0.0),
        }
        if _on_usage:
            await _on_usage(response_usage)

        # Create agent response
        agent_response = {
            "output": response_items,
            "usage": response_usage
        }
        agent_response = {"output": response_items, "usage": response_usage}

        return agent_response

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str
        self, model: str, image_b64: str, instruction: str
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.

        UITARS supports click prediction through its action parsing.

        Args:
            model: Model name to use
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            Tuple with (x, y) coordinates or None
        """
        try:
            # Create prompt using grounding template
            user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(
                instruction=instruction
            )
            user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(instruction=instruction)

            # Process image for UITARS
            processed_image, original_width, original_height = process_image_for_uitars(image_b64)
            encoded_image = pil_to_base64(processed_image)

            # Prepare messages for liteLLM
            litellm_messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {"role": "system", "content": "You are a helpful assistant."},
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
                    ]
                }
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{encoded_image}"},
                        },
                    ],
                },
            ]

            # Prepare API call kwargs
            api_kwargs = {
                "model": model,
                "messages": litellm_messages,
                "max_tokens": 2056,
                "temperature": 0.0,
                "do_sample": False
                "do_sample": False,
            }

            # Call liteLLM with UITARS model
            response = await litellm.acompletion(**api_kwargs)

            # Extract response content
            response_content = response.choices[0].message.content.strip()  # type: ignore
            response_content = response.choices[0].message.content.strip()  # type: ignore

            print(response_content)

            # Parse the response to extract click coordinates
            # Look for click action with coordinates (with special tokens)
            click_pattern = r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)"
            match = re.search(click_pattern, response_content)

            # Fallback: Look for simpler format without special tokens
            if not match:
                # Pattern for: click(start_box='(x,y)') or click(point='(x,y)')
                fallback_pattern = r"click\((?:start_box|point)='\((\d+),(\d+)\)'\)"
                match = re.search(fallback_pattern, response_content)

            if match:
                x, y = int(match.group(1)), int(match.group(2))
                # Scale coordinates back to original image dimensions
                scale_x = original_width / processed_image.width
                scale_y = original_height / processed_image.height

                scaled_x = int(x * scale_x)
                scaled_y = int(y * scale_y)

                return (scaled_x, scaled_y)

            return None

        except Exception as e:
            # Log error and return None
            print(f"Error in predict_click: {e}")
            return None

    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by this agent config.

        Returns:
            List of capability strings
        """
        return ["step", "click"]
        return ["step", "click"]
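# Illustrative example, not part of the diff: predict_click above matches replies like
#   "click(point='<|box_start|>(412,230)<|box_end|>')"   (primary pattern)
#   "click(start_box='(412,230)')"                        (fallback pattern)
# and rescales (412, 230) from the processed image back to the original resolution
# via scale_x = original_width / processed_image.width.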
@@ -1,19 +1,22 @@
"""
Example usage of the proxy server and client requests.
"""

import dotenv

dotenv.load_dotenv()

import asyncio
import json
import os
from typing import Any, Dict

import aiohttp
from typing import Dict, Any


async def test_http_endpoint():
    """Test the HTTP /responses endpoint."""

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    assert isinstance(anthropic_api_key, str), "ANTHROPIC_API_KEY environment variable must be set"

@@ -21,11 +24,9 @@ async def test_http_endpoint():
    simple_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "input": "Tell me a three sentence bedtime story about a unicorn.",
        "env": {
            "ANTHROPIC_API_KEY": anthropic_api_key
        }
        "env": {"ANTHROPIC_API_KEY": anthropic_api_key},
    }

    # Example 2: Multi-modal request with image
    multimodal_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
@@ -36,70 +37,72 @@ async def test_http_endpoint():
                {"type": "input_text", "text": "what is in this image?"},
                {
                    "type": "input_image",
                    "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                }
            ]
                    "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
                },
            ],
        }
        ],
        "env": {
            "ANTHROPIC_API_KEY": anthropic_api_key
        }
        "env": {"ANTHROPIC_API_KEY": anthropic_api_key},
    }

    # Example 3: Request with custom agent and computer kwargs
    custom_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "input": "Take a screenshot and tell me what you see",
        "env": {
            "ANTHROPIC_API_KEY": anthropic_api_key
        }
        "env": {"ANTHROPIC_API_KEY": anthropic_api_key},
    }

    # Test requests
    base_url = "https://m-linux-96lcxd2c2k.containers.cloud.trycua.com:8443"
    # base_url = "http://localhost:8000"
    api_key = os.getenv("CUA_API_KEY")
    assert isinstance(api_key, str), "CUA_API_KEY environment variable must be set"

    async with aiohttp.ClientSession() as session:
        for i, request_data in enumerate([
            simple_request,
            # multimodal_request,
            custom_request
        ], 1):
        for i, request_data in enumerate(
            [
                simple_request,
                # multimodal_request,
                custom_request,
            ],
            1,
        ):
            print(f"\n--- Test {i} ---")
            print(f"Request: {json.dumps(request_data, indent=2)}")

            try:
                print(f"Sending request to {base_url}/responses")
                async with session.post(
                    f"{base_url}/responses",
                    json=request_data,
                    headers={"Content-Type": "application/json", "X-API-Key": api_key}
                    headers={"Content-Type": "application/json", "X-API-Key": api_key},
                ) as response:
                    result = await response.json()
                    print(f"Status: {response.status}")
                    print(f"Response: {json.dumps(result, indent=2)}")

            except Exception as e:
                print(f"Error: {e}")


def curl_examples():
    """Print curl command examples."""

    print("=== CURL Examples ===\n")

    print("1. Simple text request:")
    print("""curl http://localhost:8000/responses \\
    print(
        """curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "input": "Tell me a three sentence bedtime story about a unicorn."
  }'""")
  }'"""
    )

    print("\n2. Multi-modal request with image:")
    print("""curl http://localhost:8000/responses \\
    print(
        """curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
@@ -115,10 +118,12 @@ def curl_examples():
      ]
    }
  ]
  }'""")
  }'"""
    )

    print("\n3. Request with custom configuration:")
    print("""curl http://localhost:8000/responses \\
    print(
        """curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
@@ -131,50 +136,49 @@ def curl_examples():
    "os_type": "linux",
    "provider_type": "cloud"
  }
  }'""")
  }'"""
    )


async def test_p2p_client():
    """Example P2P client using peerjs-python."""
    try:
        from peerjs import Peer, PeerOptions, ConnectionEventType
        from aiortc import RTCConfiguration, RTCIceServer

        from peerjs import ConnectionEventType, Peer, PeerOptions

        # Set up client peer
        options = PeerOptions(
            host="0.peerjs.com",
            port=443,
            secure=True,
            config=RTCConfiguration(
                iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]
            )
            config=RTCConfiguration(iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]),
        )

        client_peer = Peer(id="test-client", peer_options=options)
        await client_peer.start()

        # Connect to proxy server
        connection = client_peer.connect("computer-agent-proxy")

        @connection.on(ConnectionEventType.Open)
        async def connection_open():
            print("Connected to proxy server")

            # Send a test request
            request = {
                "model": "anthropic/claude-3-5-sonnet-20241022",
                "input": "Hello from P2P client!"
                "input": "Hello from P2P client!",
            }
            await connection.send(json.dumps(request))

        @connection.on(ConnectionEventType.Data)
        async def connection_data(data):
            print(f"Received response: {data}")
            await client_peer.destroy()

        # Wait for connection
        await asyncio.sleep(10)

    except ImportError:
        print("P2P dependencies not available. Install peerjs-python for P2P testing.")
    except Exception as e:
@@ -183,7 +187,7 @@ async def test_p2p_client():

if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "curl":
        curl_examples()
    elif len(sys.argv) > 1 and sys.argv[1] == "p2p":

@@ -7,24 +7,25 @@ import json
import logging
import os
from contextlib import contextmanager
from typing import Dict, Any, List, Union, Optional
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from computer import Computer
|
||||
|
||||
from ..agent import ComputerAgent
|
||||
from computer import Computer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponsesHandler:
|
||||
"""Handler for /responses endpoint that processes agent requests."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.computer = None
|
||||
self.agent = None
|
||||
# Simple in-memory caches
|
||||
self._computer_cache: Dict[str, Any] = {}
|
||||
self._agent_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
async def setup_computer_agent(
|
||||
self,
|
||||
model: str,
|
||||
@@ -75,7 +76,9 @@ class ResponsesHandler:
|
||||
computer = Computer(**default_c_config)
|
||||
await computer.__aenter__()
|
||||
self._computer_cache[comp_key] = computer
|
||||
logger.info(f"Computer created and cached with key={comp_key} config={default_c_config}")
|
||||
logger.info(
|
||||
f"Computer created and cached with key={comp_key} config={default_c_config}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Reusing cached computer for key={comp_key}")
|
||||
|
||||
@@ -115,14 +118,14 @@ class ResponsesHandler:
|
||||
|
||||
# Bind current agent reference
|
||||
self.agent = agent
|
||||
|
||||
|
||||
async def process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a /responses request and return the result.
|
||||
|
||||
|
||||
Args:
|
||||
request_data: Dictionary containing model, input, and optional kwargs
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with the agent's response
|
||||
"""
|
||||
@@ -133,12 +136,12 @@ class ResponsesHandler:
|
||||
agent_kwargs = request_data.get("agent_kwargs", {})
|
||||
computer_kwargs = request_data.get("computer_kwargs", {})
|
||||
env_overrides = request_data.get("env", {}) or {}
|
||||
|
||||
|
||||
if not model:
|
||||
raise ValueError("Model is required")
|
||||
if not input_data:
|
||||
raise ValueError("Input is required")
|
||||
|
||||
|
||||
# Apply env overrides for the duration of this request
|
||||
with self._env_overrides(env_overrides):
|
||||
# Set up (and possibly reuse) computer and agent via caches
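(The `_env_overrides` context manager itself does not appear in this diff. As a rough, non-authoritative sketch of what it plausibly does — applying the request's environment variables for the duration of the block and restoring the previous values afterwards, assuming nothing beyond `os.environ`:)

```python
# Hypothetical sketch only -- the real _env_overrides implementation is not
# shown in this diff. It applies overrides, then restores the prior values.
import os
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def _env_overrides(overrides: Dict[str, str]) -> Iterator[None]:
    saved = {key: os.environ.get(key) for key in overrides}  # remember old values
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)  # variable did not exist before
            else:
                os.environ[key] = old
```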
@@ -155,28 +158,22 @@ class ResponsesHandler:
            # Run agent and get first result
            async for result in agent.run(messages):
                # Return the first result and break
                return {
                    "success": True,
                    "result": result,
                    "model": model
                }

                return {"success": True, "result": result, "model": model}

            # If no results were yielded
            return {
                "success": False,
                "error": "No results from agent",
                "model": model
            }

            return {"success": False, "error": "No results from agent", "model": model}

        except Exception as e:
            logger.error(f"Error processing request: {e}")
            return {
                "success": False,
                "error": str(e),
                "model": request_data.get("model", "unknown")
                "model": request_data.get("model", "unknown"),
            }

    def _convert_input_to_messages(self, input_data: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:

    def _convert_input_to_messages(
        self, input_data: Union[str, List[Dict[str, Any]]]
    ) -> List[Dict[str, Any]]:
        """Convert input data to messages format."""
        if isinstance(input_data, str):
            # Simple string input
@@ -192,22 +189,18 @@ class ResponsesHandler:
                        if part.get("type") == "input_text":
                            content_parts.append({"type": "text", "text": part["text"]})
                        elif part.get("type") == "input_image":
                            content_parts.append({
                                "type": "image_url",
                                "image_url": {"url": part["image_url"]}
                            })
                            content_parts.append(
                                {"type": "image_url", "image_url": {"url": part["image_url"]}}
                            )
                        else:
                            content_parts.append(part)
                    messages.append({
                        "role": msg["role"],
                        "content": content_parts
                    })
                    messages.append({"role": msg["role"], "content": content_parts})
                else:
                    messages.append(msg)
            return messages
        else:
            raise ValueError("Input must be string or list of messages")

    async def cleanup(self):
        """Clean up resources."""
        if self.computer:

File diff suppressed because it is too large

@@ -2,37 +2,43 @@
Type definitions for agent
"""

from typing import Dict, List, Any, Optional, Callable, Protocol, Literal
from pydantic import BaseModel
import re
from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Literal, Optional, Protocol

from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam
from pydantic import BaseModel

# Agent input types
Messages = str | ResponseInputParam | List[Dict[str, Any]]
Tools = Optional[Iterable[ToolParam]]

# Agent output types
AgentResponse = ResponsesAPIResponse
AgentResponse = ResponsesAPIResponse
AgentCapability = Literal["step", "click"]


# Exception types
class ToolError(RuntimeError):
    """Base exception for tool-related errors"""

    pass


class IllegalArgumentError(ToolError):
    """Exception raised when function arguments are invalid"""

    pass


# Agent config registration
class AgentConfigInfo(BaseModel):
    """Information about a registered agent config"""

    agent_class: type
    models_regex: str
    priority: int = 0

    def matches_model(self, model: str) -> bool:
        """Check if this agent config matches the given model"""
        return bool(re.match(self.models_regex, model))
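(For illustration, the registration model above can be exercised directly. The regex and priority below are hypothetical values, not a registration that exists in the codebase:)

```python
# Hypothetical values for illustration only.
info = AgentConfigInfo(agent_class=object, models_regex=r"anthropic/.*", priority=10)

assert info.matches_model("anthropic/claude-3-5-sonnet-20241022") is True
assert info.matches_model("openai/computer-use-preview") is False  # no prefix match
```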

@@ -2,6 +2,6 @@
UI components for agent
"""

from .gradio import launch_ui, create_gradio_ui
from .gradio import create_gradio_ui, launch_ui

__all__ = ["launch_ui", "create_gradio_ui"]

@@ -1,4 +1,4 @@
from .gradio import launch_ui

if __name__ == "__main__":
    launch_ui()
    launch_ui()

@@ -18,21 +18,21 @@ Requirements:
- OpenAI or Anthropic API key
"""

import os
import asyncio
import logging
import json
import logging
import os
import platform
from pathlib import Path
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast

import gradio as gr
from gradio.components.chatbot import MetadataDict
from typing import cast

# Import from agent package
from agent import ComputerAgent
from agent.types import Messages, AgentResponse
from agent.types import AgentResponse, Messages
from computer import Computer
from gradio.components.chatbot import MetadataDict

# Global variables
global_agent = None
@@ -42,11 +42,13 @@ SETTINGS_FILE = Path(".gradio_settings.json")
logging.basicConfig(level=logging.INFO)

import dotenv

if dotenv.load_dotenv():
    print(f"DEBUG - Loaded environment variables from {dotenv.find_dotenv()}")
else:
    print("DEBUG - No .env file found")


# --- Settings Load/Save Functions ---
def load_settings() -> Dict[str, Any]:
    """Loads settings from the JSON file."""
@@ -84,7 +86,7 @@ def save_settings(settings: Dict[str, Any]):
# async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
#     """Add screenshot to chatbot when a screenshot is taken."""
#     image_markdown = f""

#     if self.chatbot_history is not None:
#         self.chatbot_history.append(
#             gr.ChatMessage(
@@ -141,7 +143,7 @@ def get_model_string(model_name: str, loop_provider: str) -> str:
            ollama_model = model_name.split("OMNI: Ollama ", 1)[1]
            return f"omniparser+ollama_chat/{ollama_model}"
        return "omniparser+ollama_chat/llama3"

    # Map based on loop provider
    mapping = MODEL_MAPPINGS.get(loop_provider.lower(), MODEL_MAPPINGS["openai"])
    return mapping.get(model_name, mapping["default"])
@@ -151,6 +153,7 @@ def get_ollama_models() -> List[str]:
    """Get available models from Ollama if installed."""
    try:
        import subprocess

        result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
        if result.returncode == 0:
            lines = result.stdout.strip().split("\n")
@@ -174,16 +177,14 @@ def create_computer_instance(
    os_type: str = "macos",
    provider_type: str = "lume",
    name: Optional[str] = None,
    api_key: Optional[str] = None
    api_key: Optional[str] = None,
) -> Computer:
    """Create or get the global Computer instance."""
    global global_computer
    if global_computer is None:
        if provider_type == "localhost":
            global_computer = Computer(
                verbosity=verbosity,
                os_type=os_type,
                use_host_computer_server=True
                verbosity=verbosity, os_type=os_type, use_host_computer_server=True
            )
        else:
            global_computer = Computer(
@@ -191,7 +192,7 @@ def create_computer_instance(
                os_type=os_type,
                provider_type=provider_type,
                name=name if name else "",
                api_key=api_key
                api_key=api_key,
            )
    return global_computer

@@ -217,7 +218,7 @@ def create_agent(
        os_type=computer_os,
        provider_type=computer_provider,
        name=computer_name,
        api_key=computer_api_key
        api_key=computer_api_key,
    )

    # Handle custom models
@@ -233,12 +234,15 @@ def create_agent(
        "only_n_most_recent_images": only_n_most_recent_images,
        "verbosity": verbosity,
    }

    if save_trajectory:
        agent_kwargs["trajectory_dir"] = "trajectories"

    if max_trajectory_budget:
        agent_kwargs["max_trajectory_budget"] = {"max_budget": max_trajectory_budget, "raise_error": True}
        agent_kwargs["max_trajectory_budget"] = {
            "max_budget": max_trajectory_budget,
            "raise_error": True,
        }

    global_agent = ComputerAgent(**agent_kwargs)
    return global_agent
@@ -247,7 +251,8 @@ def create_agent(
def launch_ui():
    """Standalone function to launch the Gradio app."""
    from agent.ui.gradio.ui_components import create_gradio_ui
    print(f"Starting Gradio app for CUA Agent...")

    print("Starting Gradio app for CUA Agent...")
    demo = create_gradio_ui()
    demo.launch(share=False, inbrowser=True)

@@ -2,19 +2,25 @@
UI Components for the Gradio interface
"""

import os
import asyncio
import logging
import json
import logging
import os
import platform
from pathlib import Path
from typing import Dict, List, Optional, Any, cast
from typing import Any, Dict, List, Optional, cast

import gradio as gr
from gradio.components.chatbot import MetadataDict

from .app import (
    load_settings, save_settings, create_agent, get_model_string,
    get_ollama_models, global_agent, global_computer
    create_agent,
    get_model_string,
    get_ollama_models,
    global_agent,
    global_computer,
    load_settings,
    save_settings,
)

# Global messages array to maintain conversation history
@@ -23,15 +29,15 @@ global_messages = []

def create_gradio_ui() -> gr.Blocks:
    """Create a Gradio UI for the Computer-Use Agent."""

    # Load settings
    saved_settings = load_settings()

    # Check for API keys
    openai_api_key = os.environ.get("OPENAI_API_KEY", "")
    anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY", "")
    cua_api_key = os.environ.get("CUA_API_KEY", "")

    # Model choices
    openai_models = ["OpenAI: Computer-Use Preview"]
    anthropic_models = [
@@ -43,10 +49,10 @@ def create_gradio_ui() -> gr.Blocks:
    omni_models = [
        "OMNI: OpenAI GPT-4o",
        "OMNI: OpenAI GPT-4o mini",
        "OMNI: Claude 3.7 Sonnet (20250219)",
        "OMNI: Claude 3.5 Sonnet (20241022)"
        "OMNI: Claude 3.7 Sonnet (20250219)",
        "OMNI: Claude 3.5 Sonnet (20241022)",
    ]

    # Check if API keys are available
    has_openai_key = bool(openai_api_key)
    has_anthropic_key = bool(anthropic_api_key)
@@ -59,15 +65,20 @@ def create_gradio_ui() -> gr.Blocks:

    # Detect platform
    is_mac = platform.system().lower() == "darwin"

    # Format model choices
    provider_to_models = {
        "OPENAI": openai_models,
        "ANTHROPIC": anthropic_models,
        "OMNI": omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
        "UITARS": ([
            "huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
        ] if is_mac else []) + ["Custom model (OpenAI compatible API)"],
        "UITARS": (
            [
                "huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
            ]
            if is_mac
            else []
        )
        + ["Custom model (OpenAI compatible API)"],
    }

    # Apply saved settings
@@ -82,7 +93,9 @@ def create_gradio_ui() -> gr.Blocks:
    elif initial_loop == "ANTHROPIC":
        initial_model = anthropic_models[0] if anthropic_models else "No models available"
    else:  # OMNI
        initial_model = omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
        initial_model = (
            omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
        )

    initial_custom_model = saved_settings.get("custom_model", "Qwen2.5-VL-7B-Instruct")
    initial_provider_base_url = saved_settings.get("provider_base_url", "http://localhost:1234/v1")
@@ -96,16 +109,27 @@ def create_gradio_ui() -> gr.Blocks:
        "Open Safari, search for 'macOS automation tools', and save the first three results as bookmarks",
        "Configure SSH keys and set up a connection to a remote server",
    ]

    def generate_python_code(agent_loop_choice, model_name, tasks, recent_images=3, save_trajectory=True, computer_os="linux", computer_provider="cloud", container_name="", cua_cloud_api_key="", max_budget=None):

    def generate_python_code(
        agent_loop_choice,
        model_name,
        tasks,
        recent_images=3,
        save_trajectory=True,
        computer_os="linux",
        computer_provider="cloud",
        container_name="",
        cua_cloud_api_key="",
        max_budget=None,
    ):
        """Generate Python code for the current configuration and tasks."""
        tasks_str = ""
        for task in tasks:
            if task and task.strip():
                tasks_str += f'    "{task}",\n'

        model_string = get_model_string(model_name, agent_loop_choice)

        computer_args = []
        if computer_os != "macos":
            computer_args.append(f'os_type="{computer_os}"')
@@ -115,14 +139,14 @@ def create_gradio_ui() -> gr.Blocks:
            computer_args.append(f'name="{container_name}"')
        if cua_cloud_api_key:
            computer_args.append(f'api_key="{cua_cloud_api_key}"')

        computer_args_str = ", ".join(computer_args)
        if computer_args_str:
            computer_args_str = f"({computer_args_str})"
        else:
            computer_args_str = "()"

        code = f'''import asyncio

        code = f"""import asyncio
from computer import Computer
from agent import ComputerAgent

@@ -131,22 +155,22 @@ async def main():
    agent = ComputerAgent(
        model="{model_string}",
        tools=[computer],
        only_n_most_recent_images={recent_images},'''

        only_n_most_recent_images={recent_images},"""

        if save_trajectory:
            code += '''
        trajectory_dir="trajectories",'''

            code += """
        trajectory_dir="trajectories","""

        if max_budget:
            code += f'''
        max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},'''

        code += '''
            code += f"""
        max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},"""

        code += """
    )
'''

"""

        if tasks_str:
            code += f'''
            code += f"""
    # Prompts for the computer-use agent
    tasks = [
{tasks_str.rstrip()}
@@ -158,23 +182,23 @@ async def main():
    async for result in agent.run(messages):
        for item in result["output"]:
            if item["type"] == "message":
                print(item["content"][0]["text"])'''
                print(item["content"][0]["text"])"""
        else:
            code += f'''
            code += """
    # Execute a single task
    task = "Search for information about CUA on GitHub"
    print(f"Executing task: {{task}}")
    messages = [{{"role": "user", "content": task}}]
    print(f"Executing task: {task}")
    messages = [{"role": "user", "content": task}]
    async for result in agent.run(messages):
        for item in result["output"]:
            if item["type"] == "message":
                print(item["content"][0]["text"])'''
                print(item["content"][0]["text"])"""

        code += '''
        code += """

if __name__ == "__main__":
    asyncio.run(main())'''

    asyncio.run(main())"""

        return code

    # Create the Gradio interface
@@ -199,11 +223,11 @@ if __name__ == "__main__":
                value=generate_python_code(initial_loop, "gpt-4o", []),
                interactive=False,
            )

            with gr.Accordion("Computer Configuration", open=True):
                is_windows = platform.system().lower() == "windows"
                is_mac = platform.system().lower() == "darwin"

                providers = ["cloud", "localhost", "docker"]
                if is_mac:
                    providers += ["lume"]
@@ -227,30 +251,30 @@ if __name__ == "__main__":
                    value=computer_choices[0],
                    info="Select the operating system for the computer",
                )

                computer_provider = gr.Radio(
                    choices=providers,
                    label="Provider",
                    value="lume" if is_mac else "cloud",
                    info="Select the computer provider",
                )

                container_name = gr.Textbox(
                    label="Container Name",
                    placeholder="Enter container name (optional)",
                    value=os.environ.get("CUA_CONTAINER_NAME", ""),
                    info="Optional name for the container",
                )

                cua_cloud_api_key = gr.Textbox(
                    label="CUA Cloud API Key",
                    placeholder="Enter your CUA Cloud API key",
                    value=os.environ.get("CUA_API_KEY", ""),
                    type="password",
                    info="Required for cloud provider",
                    visible=(not has_cua_key)
                    visible=(not has_cua_key),
                )

            with gr.Accordion("Agent Configuration", open=True):
                agent_loop = gr.Dropdown(
                    choices=["OPENAI", "ANTHROPIC", "OMNI", "UITARS"],
@@ -267,90 +291,113 @@ if __name__ == "__main__":
                    value=openai_models[0] if openai_models else "No models available",
                    info="Select OpenAI model",
                    interactive=True,
                    visible=(initial_loop == "OPENAI")
                    visible=(initial_loop == "OPENAI"),
                )

                anthropic_model_choice = gr.Dropdown(
                    choices=anthropic_models,
                    label="Anthropic Model",
                    value=anthropic_models[0] if anthropic_models else "No models available",
                    value=(
                        anthropic_models[0] if anthropic_models else "No models available"
                    ),
                    info="Select Anthropic model",
                    interactive=True,
                    visible=(initial_loop == "ANTHROPIC")
                    visible=(initial_loop == "ANTHROPIC"),
                )

                omni_model_choice = gr.Dropdown(
                    choices=omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
                    choices=omni_models
                    + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
                    label="OMNI Model",
                    value=omni_models[0] if omni_models else "Custom model (OpenAI compatible API)",
                    value=(
                        omni_models[0]
                        if omni_models
                        else "Custom model (OpenAI compatible API)"
                    ),
                    info="Select OMNI model or choose a custom model option",
                    interactive=True,
                    visible=(initial_loop == "OMNI")
                    visible=(initial_loop == "OMNI"),
                )

                uitars_model_choice = gr.Dropdown(
                    choices=provider_to_models.get("UITARS", ["No models available"]),
                    label="UITARS Model",
                    value=provider_to_models.get("UITARS", ["No models available"])[0] if provider_to_models.get("UITARS") else "No models available",
                    value=(
                        provider_to_models.get("UITARS", ["No models available"])[0]
                        if provider_to_models.get("UITARS")
                        else "No models available"
                    ),
                    info="Select UITARS model",
                    interactive=True,
                    visible=(initial_loop == "UITARS")
                    visible=(initial_loop == "UITARS"),
                )

                model_choice = gr.Textbox(visible=False)

                # API key inputs
                with gr.Group(visible=not has_openai_key and (initial_loop == "OPENAI" or initial_loop == "OMNI")) as openai_key_group:
                with gr.Group(
                    visible=not has_openai_key
                    and (initial_loop == "OPENAI" or initial_loop == "OMNI")
                ) as openai_key_group:
                    openai_api_key_input = gr.Textbox(
                        label="OpenAI API Key",
                        placeholder="Enter your OpenAI API key",
                        value=os.environ.get("OPENAI_API_KEY", ""),
                        interactive=True,
                        type="password",
                        info="Required for OpenAI models"
                        info="Required for OpenAI models",
                    )

                with gr.Group(visible=not has_anthropic_key and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")) as anthropic_key_group:

                with gr.Group(
                    visible=not has_anthropic_key
                    and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")
                ) as anthropic_key_group:
                    anthropic_api_key_input = gr.Textbox(
                        label="Anthropic API Key",
                        placeholder="Enter your Anthropic API key",
                        value=os.environ.get("ANTHROPIC_API_KEY", ""),
                        interactive=True,
                        type="password",
                        info="Required for Anthropic models"
                        info="Required for Anthropic models",
                    )

            # API key handlers
            def set_openai_api_key(key):
                if key and key.strip():
                    os.environ["OPENAI_API_KEY"] = key.strip()
                    print(f"DEBUG - Set OpenAI API key environment variable")
                    print("DEBUG - Set OpenAI API key environment variable")
                return key

            def set_anthropic_api_key(key):
                if key and key.strip():
                    os.environ["ANTHROPIC_API_KEY"] = key.strip()
                    print(f"DEBUG - Set Anthropic API key environment variable")
                    print("DEBUG - Set Anthropic API key environment variable")
                return key

            openai_api_key_input.change(
                fn=set_openai_api_key,
                inputs=[openai_api_key_input],
                outputs=[openai_api_key_input],
                queue=False
                queue=False,
            )

            anthropic_api_key_input.change(
                fn=set_anthropic_api_key,
                inputs=[anthropic_api_key_input],
                outputs=[anthropic_api_key_input],
                queue=False
                queue=False,
            )

            # UI update function
            def update_ui(loop=None, openai_model=None, anthropic_model=None, omni_model=None, uitars_model=None):
            def update_ui(
                loop=None,
                openai_model=None,
                anthropic_model=None,
                omni_model=None,
                uitars_model=None,
            ):
                loop = loop or agent_loop.value

                model_value = None
                if loop == "OPENAI" and openai_model:
                    model_value = openai_model
@@ -360,21 +407,37 @@ if __name__ == "__main__":
                    model_value = omni_model
                elif loop == "UITARS" and uitars_model:
                    model_value = uitars_model

                openai_visible = (loop == "OPENAI")
                anthropic_visible = (loop == "ANTHROPIC")
                omni_visible = (loop == "OMNI")
                uitars_visible = (loop == "UITARS")

                show_openai_key = not has_openai_key and (loop == "OPENAI" or (loop == "OMNI" and model_value and "OpenAI" in model_value and "Custom" not in model_value))
                show_anthropic_key = not has_anthropic_key and (loop == "ANTHROPIC" or (loop == "OMNI" and model_value and "Claude" in model_value and "Custom" not in model_value))

                openai_visible = loop == "OPENAI"
                anthropic_visible = loop == "ANTHROPIC"
                omni_visible = loop == "OMNI"
                uitars_visible = loop == "UITARS"

                show_openai_key = not has_openai_key and (
                    loop == "OPENAI"
                    or (
                        loop == "OMNI"
                        and model_value
                        and "OpenAI" in model_value
                        and "Custom" not in model_value
                    )
                )
                show_anthropic_key = not has_anthropic_key and (
                    loop == "ANTHROPIC"
                    or (
                        loop == "OMNI"
                        and model_value
                        and "Claude" in model_value
                        and "Custom" not in model_value
                    )
                )

                is_custom_openai_api = model_value == "Custom model (OpenAI compatible API)"
                is_custom_ollama = model_value == "Custom model (ollama)"
                is_any_custom = is_custom_openai_api or is_custom_ollama

                model_choice_value = model_value if model_value else ""

                return [
                    gr.update(visible=openai_visible),
                    gr.update(visible=anthropic_visible),
@@ -385,15 +448,18 @@ if __name__ == "__main__":
                    gr.update(visible=is_any_custom),
                    gr.update(visible=is_custom_openai_api),
                    gr.update(visible=is_custom_openai_api),
                    gr.update(value=model_choice_value)
                    gr.update(value=model_choice_value),
                ]

            # Custom model inputs
            custom_model = gr.Textbox(
                label="Custom Model Name",
                placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct or llama3)",
                value=initial_custom_model,
                visible=(initial_model == "Custom model (OpenAI compatible API)" or initial_model == "Custom model (ollama)"),
                visible=(
                    initial_model == "Custom model (OpenAI compatible API)"
                    or initial_model == "Custom model (ollama)"
                ),
                interactive=True,
            )

@@ -413,36 +479,56 @@ if __name__ == "__main__":
                interactive=True,
                type="password",
            )

            # Provider visibility update function
            def update_provider_visibility(provider):
                """Update visibility of container name and API key based on selected provider."""
                is_localhost = provider == "localhost"
                return [
                    gr.update(visible=not is_localhost),  # container_name
                    gr.update(visible=not is_localhost and not has_cua_key)  # cua_cloud_api_key
                    gr.update(
                        visible=not is_localhost and not has_cua_key
                    ),  # cua_cloud_api_key
                ]

            # Connect provider change event
            computer_provider.change(
                fn=update_provider_visibility,
                inputs=[computer_provider],
                outputs=[container_name, cua_cloud_api_key],
                queue=False
                queue=False,
            )

            # Connect UI update events
            for dropdown in [agent_loop, omni_model_choice, uitars_model_choice, openai_model_choice, anthropic_model_choice]:
            for dropdown in [
                agent_loop,
                omni_model_choice,
                uitars_model_choice,
                openai_model_choice,
                anthropic_model_choice,
            ]:
                dropdown.change(
                    fn=update_ui,
                    inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
                    outputs=[
                        openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
                        openai_key_group, anthropic_key_group,
                        custom_model, provider_base_url, provider_api_key,
                        model_choice
                    inputs=[
                        agent_loop,
                        openai_model_choice,
                        anthropic_model_choice,
                        omni_model_choice,
                        uitars_model_choice,
                    ],
                    queue=False
                    outputs=[
                        openai_model_choice,
                        anthropic_model_choice,
                        omni_model_choice,
                        uitars_model_choice,
                        openai_key_group,
                        anthropic_key_group,
                        custom_model,
                        provider_base_url,
                        provider_api_key,
                        model_choice,
                    ],
                    queue=False,
                )

            save_trajectory = gr.Checkbox(
@@ -461,7 +547,7 @@ if __name__ == "__main__":
                info="Number of recent images to keep in context",
                interactive=True,
            )

            max_budget = gr.Number(
                label="Max Budget ($)",
                value=lambda: None,
@@ -479,9 +565,7 @@ if __name__ == "__main__":
            )

            chatbot_history = gr.Chatbot(type="messages")
            msg = gr.Textbox(
                placeholder="Ask me to perform tasks in a virtual environment"
            )
            msg = gr.Textbox(placeholder="Ask me to perform tasks in a virtual environment")
            clear = gr.Button("Clear")
            cancel_button = gr.Button("Cancel", variant="stop")

@@ -498,11 +582,23 @@ if __name__ == "__main__":
                global global_agent
                if global_agent:
                    print("DEBUG - Cancelling agent task")
                    history.append(gr.ChatMessage(role="assistant", content="Task cancelled by user", metadata={"title": "❌ Cancelled"}))
                    history.append(
                        gr.ChatMessage(
                            role="assistant",
                            content="Task cancelled by user",
                            metadata={"title": "❌ Cancelled"},
                        )
                    )
                else:
                    history.append(gr.ChatMessage(role="assistant", content="No active agent task to cancel", metadata={"title": "ℹ️ Info"}))
                    history.append(
                        gr.ChatMessage(
                            role="assistant",
                            content="No active agent task to cancel",
                            metadata={"title": "ℹ️ Info"},
                        )
                    )
                return history

            # Process response function
            async def process_response(
                history,
@@ -542,10 +638,13 @@ if __name__ == "__main__":
                    model_choice_value = uitars_model_value
                else:
                    model_choice_value = "No models available"

                # Determine if this is a custom model selection
                is_custom_model_selected = model_choice_value in ["Custom model (OpenAI compatible API)", "Custom model (ollama)"]

                is_custom_model_selected = model_choice_value in [
                    "Custom model (OpenAI compatible API)",
                    "Custom model (ollama)",
                ]

                # Determine the model name string to analyze
                if is_custom_model_selected:
                    model_string_to_analyze = custom_model_value
@@ -583,13 +682,19 @@ if __name__ == "__main__":
                    model_string=model_string,
                    save_trajectory=save_traj,
                    only_n_most_recent_images=recent_imgs,
                    custom_model_name=custom_model_value if is_custom_model_selected else None,
                    custom_model_name=(
                        custom_model_value if is_custom_model_selected else None
                    ),
                    computer_os=computer_os,
                    computer_provider=computer_provider,
                    computer_name=container_name,
                    computer_api_key=cua_cloud_api_key,
                    verbosity=logging.DEBUG,
                    max_trajectory_budget=max_budget_value if max_budget_value and max_budget_value > 0 else None,
                    max_trajectory_budget=(
                        max_budget_value
                        if max_budget_value and max_budget_value > 0
                        else None
                    ),
                )

                if global_agent is None:
@@ -605,7 +710,7 @@ if __name__ == "__main__":
                # Add user message to global history
                global global_messages
                global_messages.append({"role": "user", "content": last_user_message})

                # Stream responses from the agent
                async for result in global_agent.run(global_messages):
                    global_messages += result.get("output", [])
@@ -613,18 +718,20 @@ if __name__ == "__main__":
                    # from pprint import pprint
                    # pprint(result)
                    # print(f"DEBUG - Agent response ------- END")

                    # Process the result output
                    for item in result.get("output", []):
                        if item.get("type") == "message":
                            content = item.get("content", [])
                            for content_part in content:
                                if content_part.get("text"):
                                    history.append(gr.ChatMessage(
                                        role=item.get("role", "assistant"),
                                        content=content_part.get("text", ""),
                                        metadata=content_part.get("metadata", {})
                                    ))
                                    history.append(
                                        gr.ChatMessage(
                                            role=item.get("role", "assistant"),
                                            content=content_part.get("text", ""),
                                            metadata=content_part.get("metadata", {}),
                                        )
                                    )
                        elif item.get("type") == "computer_call":
                            action = item.get("action", {})
                            action_type = action.get("type", "")
@@ -632,43 +739,52 @@ if __name__ == "__main__":
                            action_title = f"🛠️ Performing {action_type}"
                            if action.get("x") and action.get("y"):
                                action_title += f" at ({action['x']}, {action['y']})"
                            history.append(gr.ChatMessage(
                                role="assistant",
                                content=f"```json\n{json.dumps(action)}\n```",
                                metadata={"title": action_title}
                            ))
                            history.append(
                                gr.ChatMessage(
                                    role="assistant",
                                    content=f"```json\n{json.dumps(action)}\n```",
                                    metadata={"title": action_title},
                                )
                            )
                        elif item.get("type") == "function_call":
                            function_name = item.get("name", "")
                            arguments = item.get("arguments", "{}")
                            history.append(gr.ChatMessage(
                                role="assistant",
                                content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
                                metadata={"title": f"Function Call: {function_name}"}
                            ))
                            history.append(
                                gr.ChatMessage(
                                    role="assistant",
                                    content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
                                    metadata={"title": f"Function Call: {function_name}"},
                                )
                            )
                        elif item.get("type") == "function_call_output":
                            output = item.get("output", "")
                            history.append(gr.ChatMessage(
                                role="assistant",
                                content=f"📤 Function output:\n```\n{output}\n```",
                                metadata={"title": "Function Output"}
                            ))
                            history.append(
                                gr.ChatMessage(
                                    role="assistant",
                                    content=f"📤 Function output:\n```\n{output}\n```",
                                    metadata={"title": "Function Output"},
                                )
                            )
                        elif item.get("type") == "computer_call_output":
                            output = item.get("output", {}).get("image_url", "")
                            image_markdown = f""
                            history.append(gr.ChatMessage(
                                role="assistant",
                                content=image_markdown,
                                metadata={"title": "🖥️ Computer Output"}
                            ))

                            history.append(
                                gr.ChatMessage(
                                    role="assistant",
                                    content=image_markdown,
                                    metadata={"title": "🖥️ Computer Output"},
                                )
                            )

                    yield history

            except Exception as e:
                import traceback

                traceback.print_exc()
                history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
                yield history

            # Connect the submit button
            submit_event = msg.submit(
                fn=chat_submit,
@@ -706,44 +822,77 @@ if __name__ == "__main__":
                global global_messages
                global_messages.clear()
                return None

            clear.click(clear_chat, None, chatbot_history, queue=False)

            # Connect cancel button
            cancel_button.click(
                cancel_agent_task,
                [chatbot_history],
                [chatbot_history],
                queue=False
                cancel_agent_task, [chatbot_history], [chatbot_history], queue=False
            )

            # Code display update function
            def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, recent_images_val, save_trajectory_val, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget_val):
            def update_code_display(
                agent_loop,
                model_choice_val,
                custom_model_val,
                chat_history,
                recent_images_val,
                save_trajectory_val,
                computer_os,
                computer_provider,
                container_name,
                cua_cloud_api_key,
                max_budget_val,
            ):
                messages = []
                if chat_history:
                    for msg in chat_history:
                        if isinstance(msg, dict) and msg.get("role") == "user":
                            messages.append(msg.get("content", ""))

                return generate_python_code(
                    agent_loop,
                    model_choice_val or custom_model_val or "gpt-4o",
                    messages,
                    agent_loop,
                    model_choice_val or custom_model_val or "gpt-4o",
                    messages,
                    recent_images_val,
                    save_trajectory_val,
                    computer_os,
                    computer_provider,
                    container_name,
                    cua_cloud_api_key,
                    max_budget_val
                    max_budget_val,
                )

            # Update code display when configuration changes
            for component in [agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget]:
            for component in [
                agent_loop,
                model_choice,
                custom_model,
                chatbot_history,
                recent_images,
                save_trajectory,
                computer_os,
                computer_provider,
                container_name,
                cua_cloud_api_key,
                max_budget,
            ]:
                component.change(
                    update_code_display,
                    inputs=[agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget],
                    outputs=[code_display]
                    inputs=[
                        agent_loop,
                        model_choice,
                        custom_model,
                        chatbot_history,
                        recent_images,
                        save_trajectory,
                        computer_os,
                        computer_provider,
                        container_name,
                        cua_cloud_api_key,
                        max_budget,
                    ],
                    outputs=[code_display],
                )

    return demo

@@ -5,26 +5,30 @@ This directory contains benchmarks designed to test agent providers in the Compu
## Overview

The benchmark system evaluates models on GUI grounding tasks, specifically click prediction accuracy. It supports both:

- **Computer Agent SDK providers** (using model strings like `"huggingface-local/HelloKKMe/GTA1-7B"`)
- **Reference agent implementations** (custom model classes implementing the `ModelProtocol`)

## Available Benchmarks

### 1. ScreenSpot-v2 (`ss-v2.py`)

- **Dataset**: ScreenSpot-v2 (click-only GUI grounding)
- **Format**: Standard resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage

### 2. ScreenSpot-Pro (`ss-pro.py`)
### 2. ScreenSpot-Pro (`ss-pro.py`)

- **Dataset**: ScreenSpot-Pro (high-resolution click-only GUI grounding)
- **Format**: High-resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage

### 3. Interactive Testing (`interactive.py`)

- **Real-time testing**: Take screenshots and visualize model predictions
- **Commands**:
- **Commands**:
  - Type instruction → test all models on last screenshot
  - `screenshot` → take screenshot
  - `models` → list available models
@@ -34,14 +38,16 @@ The benchmark system evaluates models on GUI grounding tasks, specifically click
## Running Benchmarks

### 1. Configure Models

Edit `utils.py` to specify which models you want to test in `get_available_models()`.
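
A sketch of a `get_available_models()` that mixes both kinds of entries (the exact list is up to you; the `GTA1Model` import mirrors the example shown in contrib.md):

```python
from models import GTA1Model


def get_available_models():
    # Mix of Computer Agent SDK model strings and ModelProtocol instances.
    return [
        "huggingface-local/HelloKKMe/GTA1-7B",  # SDK provider string
        GTA1Model("HelloKKMe/GTA1-7B"),         # reference implementation
    ]
```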
|
||||
|
||||
### 2. Run Benchmark
|
||||
|
||||
```bash
|
||||
# ScreenSpot-v2 benchmark
|
||||
python ss-v2.py --samples 50
|
||||
|
||||
# ScreenSpot-Pro benchmark
|
||||
# ScreenSpot-Pro benchmark
|
||||
python ss-pro.py --samples 50
|
||||
|
||||
# Interactive testing
|
||||
@@ -51,6 +57,7 @@ python interactive.py
|
||||
## Output
|
||||
|
||||
### Console Output
|
||||
|
||||
```
|
||||
Model Results:
|
||||
Accuracy: 85.50% (171/200)
|
||||
@@ -59,10 +66,11 @@ Model Results:
|
||||
```
|
||||
|
||||
### Generated Files
|
||||
|
||||
- **Markdown Report**: `*_results.md` with detailed results tables
|
||||
- **Visualizations**: `output/` directory with prediction visualizations
|
||||
- **Interactive Output**: `interactive_output/` for interactive session results
|
||||
|
||||
## Contributing
|
||||
|
||||
To add a new reference model, follow the instructions in [contrib.md](contrib.md).
|
||||
To add a new reference model, follow the instructions in [contrib.md](contrib.md).
|
||||
|
||||
@@ -17,29 +17,29 @@ class YourModelName(ModelProtocol):
|
||||
def __init__(self, model_path: str):
|
||||
self.model_path = model_path
|
||||
self._model = None
|
||||
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self.model_path
|
||||
|
||||
|
||||
async def load_model(self) -> None:
|
||||
"""Load the model into memory."""
|
||||
# Your model loading logic here
|
||||
pass
|
||||
|
||||
|
||||
async def unload_model(self) -> None:
|
||||
"""Unload the model from memory."""
|
||||
# Your model cleanup logic here
|
||||
pass
|
||||
|
||||
|
||||
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
|
||||
"""
|
||||
Predict click coordinates for the given image and instruction.
|
||||
|
||||
|
||||
Args:
|
||||
image: PIL Image to analyze
|
||||
instruction: Text instruction describing what to click
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates or None if prediction fails
|
||||
"""
|
||||
@@ -56,7 +56,7 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
|
||||
models = [
|
||||
# Computer Agent SDK providers
|
||||
"huggingface-local/HelloKKMe/GTA1-7B",
|
||||
|
||||
|
||||
# Reference implementations
|
||||
GTA1Model("HelloKKMe/GTA1-7B"),
|
||||
YourModelName("path/to/your/model"), # Add your model here
|
||||
@@ -79,6 +79,7 @@ This will help you verify that your model loads correctly and produces reasonabl
|
||||
Here's a complete example of adding a hypothetical "MyVisionModel":
|
||||
|
||||
1. **Create `models/my_vision_model.py`:**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
@@ -91,11 +92,11 @@ class MyVisionModel(ModelProtocol):
|
||||
self.model_path = model_path
|
||||
self.model = None
|
||||
self.processor = None
|
||||
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return f"MyVisionModel({self.model_path})"
|
||||
|
||||
|
||||
async def load_model(self) -> None:
|
||||
"""Load the model and processor."""
|
||||
self.processor = AutoProcessor.from_pretrained(self.model_path)
|
||||
@@ -104,7 +105,7 @@ class MyVisionModel(ModelProtocol):
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
|
||||
async def unload_model(self) -> None:
|
||||
"""Clean up model resources."""
|
||||
del self.model
|
||||
@@ -112,7 +113,7 @@ class MyVisionModel(ModelProtocol):
|
||||
self.model = None
|
||||
self.processor = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
|
||||
"""Predict click coordinates."""
|
||||
try:
|
||||
@@ -122,19 +123,19 @@ class MyVisionModel(ModelProtocol):
|
||||
images=image,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
|
||||
# Run inference
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
|
||||
# Extract coordinates (model-specific logic)
|
||||
x, y = self._extract_coordinates(outputs)
|
||||
return (int(x), int(y))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Prediction failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _extract_coordinates(self, outputs):
|
||||
"""Extract x, y coordinates from model outputs."""
|
||||
# Your model-specific coordinate extraction logic
|
||||
@@ -142,6 +143,7 @@ class MyVisionModel(ModelProtocol):
|
||||
```
|
||||
|
||||
2. **Update `models/__init__.py`:**
|
||||
|
||||
```python
|
||||
from .gta1 import GTA1Model
|
||||
from .my_vision_model import MyVisionModel
|
||||
@@ -150,6 +152,7 @@ __all__ = ["GTA1Model", "MyVisionModel"]
|
||||
```
|
||||
|
||||
3. **Update `utils.py`:**
|
||||
|
||||
```python
|
||||
from models import GTA1Model, MyVisionModel
|
||||
|
||||
|
||||
@@ -9,60 +9,56 @@ Models are loaded/unloaded one at a time to avoid memory issues.
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from utils import (
|
||||
ModelWrapper,
|
||||
take_screenshot,
|
||||
get_available_models,
|
||||
save_prediction_visualization,
|
||||
get_available_models
|
||||
take_screenshot,
|
||||
)
|
||||
|
||||
|
||||
async def predict_with_all_models(image, instruction: str, models) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Predict click coordinates with all models sequentially.
|
||||
|
||||
|
||||
Args:
|
||||
image: PIL Image to analyze
|
||||
instruction: Instruction text
|
||||
models: List of model instances
|
||||
|
||||
|
||||
Returns:
|
||||
List of prediction results
|
||||
"""
|
||||
predictions = []
|
||||
|
||||
|
||||
for model in models:
|
||||
model_wrapper = ModelWrapper(model)
|
||||
print(f"\n🔄 Loading {model_wrapper.model_name}...")
|
||||
|
||||
|
||||
try:
|
||||
# Load model
|
||||
await model_wrapper.load_model()
|
||||
|
||||
|
||||
# Predict
|
||||
coords = await model_wrapper.predict_click(image, instruction)
|
||||
|
||||
predictions.append({
|
||||
'model_name': model_wrapper.model_name,
|
||||
'coords': coords,
|
||||
'error': None
|
||||
})
|
||||
|
||||
|
||||
predictions.append(
|
||||
{"model_name": model_wrapper.model_name, "coords": coords, "error": None}
|
||||
)
|
||||
|
||||
if coords:
|
||||
print(f"✅ {model_wrapper.model_name}: ({coords[0]}, {coords[1]})")
|
||||
else:
|
||||
print(f"❌ {model_wrapper.model_name}: No prediction")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ {model_wrapper.model_name}: ERROR - {str(e)}")
|
||||
predictions.append({
|
||||
'model_name': model_wrapper.model_name,
|
||||
'coords': None,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
predictions.append(
|
||||
{"model_name": model_wrapper.model_name, "coords": None, "error": str(e)}
|
||||
)
|
||||
|
||||
finally:
|
||||
# Always unload model to free memory
|
||||
try:
|
||||
@@ -70,7 +66,7 @@ async def predict_with_all_models(image, instruction: str, models) -> List[Dict[
|
||||
print(f"🗑️ Unloaded {model_wrapper.model_name}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error unloading {model_wrapper.model_name}: {e}")
|
||||
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
@@ -103,87 +99,91 @@ async def main():
|
||||
Main interactive loop.
|
||||
"""
|
||||
print_header()
|
||||
|
||||
|
||||
# Get available models
|
||||
models = get_available_models()
|
||||
print_models(models)
|
||||
|
||||
|
||||
# Create output directory for visualizations
|
||||
output_dir = "interactive_output"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
session_count = 0
|
||||
last_screenshot = None
|
||||
screenshot_timestamp = None
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Get user input
|
||||
print(f"\n{'='*40}")
|
||||
user_input = input("🎯 Enter instruction (or command): ").strip()
|
||||
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
|
||||
# Handle commands
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
|
||||
elif user_input.lower() == 'models':
|
||||
|
||||
elif user_input.lower() == "models":
|
||||
print_models(models)
|
||||
continue
|
||||
|
||||
elif user_input.lower() == 'screenshot':
|
||||
|
||||
elif user_input.lower() == "screenshot":
|
||||
print("📸 Taking screenshot...")
|
||||
try:
|
||||
last_screenshot = take_screenshot()
|
||||
screenshot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
screenshot_path = os.path.join(output_dir, f"screenshot_{screenshot_timestamp}.png")
|
||||
screenshot_path = os.path.join(
|
||||
output_dir, f"screenshot_{screenshot_timestamp}.png"
|
||||
)
|
||||
last_screenshot.save(screenshot_path)
|
||||
print(f"✅ Screenshot captured and saved to: {screenshot_path}")
|
||||
print(f"📝 Ready for instructions! Screenshot size: {last_screenshot.size}")
|
||||
except Exception as e:
|
||||
print(f"❌ Error taking screenshot: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Handle instruction input
|
||||
if last_screenshot is None:
|
||||
print("⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command.")
|
||||
print(
|
||||
"⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command."
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
session_count += 1
|
||||
print(f"\n🎯 Session {session_count}: '{user_input}'")
|
||||
print(f"📷 Using screenshot from: {screenshot_timestamp}")
|
||||
|
||||
|
||||
# Predict with all models using last screenshot
|
||||
print(f"\n🤖 Testing {len(models)} models on screenshot...")
|
||||
predictions = await predict_with_all_models(last_screenshot, user_input, models)
|
||||
|
||||
|
||||
# Display results summary
|
||||
print(f"\n📊 Results Summary:")
|
||||
print("\n📊 Results Summary:")
|
||||
print("-" * 50)
|
||||
for pred in predictions:
|
||||
if pred['coords']:
|
||||
if pred["coords"]:
|
||||
print(f"✅ {pred['model_name']}: ({pred['coords'][0]}, {pred['coords'][1]})")
|
||||
elif pred['error']:
|
||||
elif pred["error"]:
|
||||
print(f"❌ {pred['model_name']}: ERROR - {pred['error']}")
|
||||
else:
|
||||
print(f"❌ {pred['model_name']}: No prediction")
|
||||
|
||||
|
||||
# Save visualization
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
vis_filename = f"session_{session_count:03d}_{timestamp}.png"
|
||||
vis_path = os.path.join(output_dir, vis_filename)
|
||||
|
||||
|
||||
try:
|
||||
save_prediction_visualization(last_screenshot, user_input, predictions, vis_path)
|
||||
print(f"\n💾 Visualization saved to: {vis_path}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error saving visualization: {e}")
|
||||
|
||||
|
||||
print(f"\n✨ Session {session_count} completed!")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n👋 Interrupted by user. Goodbye!")
|
||||
break
|
||||
|
||||
@@ -2,34 +2,37 @@
|
||||
Base protocol for benchmark models.
|
||||
"""
|
||||
|
||||
from typing import Protocol, Optional, Tuple
|
||||
from typing import Optional, Protocol, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ModelProtocol(Protocol):
|
||||
"""Protocol for benchmark models that can predict click coordinates."""
|
||||
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Return the name of the model."""
|
||||
...
|
||||
|
||||
|
||||
async def load_model(self) -> None:
|
||||
"""Load the model into memory."""
|
||||
...
|
||||
|
||||
|
||||
async def unload_model(self) -> None:
|
||||
"""Unload the model from memory."""
|
||||
...
|
||||
|
||||
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
|
||||
|
||||
async def predict_click(
|
||||
self, image: Image.Image, instruction: str
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
"""
|
||||
Predict click coordinates for the given image and instruction.
|
||||
|
||||
|
||||
Args:
|
||||
image: PIL Image to analyze
|
||||
instruction: Text instruction describing what to click
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates or None if prediction fails
|
||||
"""
|
||||
|
||||
@@ -2,54 +2,51 @@
GTA1 model implementation for benchmarking.
"""

from typing import Optional, Tuple
from PIL import Image
import torch
import re
import gc
import re
from typing import Optional, Tuple

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from .base import ModelProtocol


class GTA1Model:
    """Ground truth GTA1 model implementation."""

    def __init__(self, model_path: str = "HelloKKMe/GTA1-7B"):
        self.model_path = model_path
        self.model = None
        self.processor = None
        self.max_new_tokens = 32

        self.system_prompt = '''
        self.system_prompt = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.

Output the coordinate pair exactly:
(x,y)
'''.strip()
""".strip()

    @property
    def model_name(self) -> str:
        """Return the name of the model."""
        return f"GTA1-{self.model_path.split('/')[-1]}"

    async def load_model(self) -> None:
        """Load the model into memory."""
        if self.model is None:
            print(f"Loading GTA1 model: {self.model_path}")
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto"
                self.model_path, torch_dtype=torch.bfloat16, device_map="auto"
            )
            self.processor = AutoProcessor.from_pretrained(
                self.model_path,
                min_pixels=3136,
                max_pixels=4096 * 2160
                self.model_path, min_pixels=3136, max_pixels=4096 * 2160
            )
            print("GTA1 model loaded successfully")

    async def unload_model(self) -> None:
        """Unload the model from memory."""
        if self.model is not None:
@@ -62,23 +59,25 @@ Output the coordinate pair exactly:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("GTA1 model unloaded")

    def _extract_coordinates(self, raw_string: str) -> Tuple[int, int]:
        """Extract coordinates from model output."""
        try:
            matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
            return tuple(map(int, map(float, matches[0]))) # type: ignore
            return tuple(map(int, map(float, matches[0])))  # type: ignore
        except:
            return (0, 0)

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
    async def predict_click(
        self, image: Image.Image, instruction: str
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates for the given image and instruction.

        Args:
            image: PIL Image to analyze
            instruction: Text instruction describing what to click

        Returns:
            Tuple of (x, y) coordinates or None if prediction fails
        """
@@ -87,76 +86,73 @@ Output the coordinate pair exactly:

        assert self.processor is not None
        assert self.model is not None

        try:
            width, height = image.width, image.height

            # Resize image according to processor requirements
            resized_height, resized_width = smart_resize(
                image.height,
                image.width,
                factor=self.processor.image_processor.patch_size * self.processor.image_processor.merge_size,
                factor=self.processor.image_processor.patch_size
                * self.processor.image_processor.merge_size,
                min_pixels=self.processor.image_processor.min_pixels,
                max_pixels=self.processor.image_processor.max_pixels,
            )
            resized_image = image.resize((resized_width, resized_height))
            scale_x, scale_y = width / resized_width, height / resized_height

            # Prepare messages
            system_message = {
                "role": "system",
                "content": self.system_prompt.format(height=resized_height, width=resized_width)
                "content": self.system_prompt.format(height=resized_height, width=resized_width),
            }

            user_message = {
                "role": "user",
                "content": [
                    {"type": "image", "image": resized_image},
                    {"type": "text", "text": instruction}
                ]
                    {"type": "text", "text": instruction},
                ],
            }

            # Process inputs
            image_inputs, video_inputs = process_vision_info([system_message, user_message]) # type: ignore
            image_inputs, video_inputs = process_vision_info([system_message, user_message])  # type: ignore
            text = self.processor.apply_chat_template(
                [system_message, user_message],
                tokenize=False,
                add_generation_prompt=True
                [system_message, user_message], tokenize=False, add_generation_prompt=True
            )
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt"
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)

            # Generate prediction
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                temperature=1.0,
                use_cache=True
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                temperature=1.0,
                use_cache=True,
            )
            generated_ids = [
                output_ids[len(input_ids):]
                output_ids[len(input_ids) :]
                for input_ids, output_ids in zip(inputs.input_ids, output_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )[0]

            # Extract and rescale coordinates
            pred_x, pred_y = self._extract_coordinates(output_text)
            pred_x = int(pred_x * scale_x)
            pred_y = int(pred_y * scale_y)

            return (pred_x, pred_y)

        except Exception as e:
            print(f"Error in GTA1 prediction: {e}")
            return None

@@ -15,103 +15,106 @@ from typing import Optional

from datasets import load_dataset
from tqdm import tqdm

from utils import (
    ModelWrapper,
    is_click_in_bbox,
    save_results_to_markdown,
    save_visualizations,
    ModelWrapper,
    get_available_models,
    get_gpu_memory
    get_gpu_memory,
    is_click_in_bbox,
    save_results_to_markdown,
    save_visualizations,
)


async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None) -> dict:
async def evaluate_model(
    model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None
) -> dict:
    """
    Evaluate a model on the ScreenSpot-Pro dataset.

    Args:
        model_wrapper: ModelWrapper instance
        dataset: ScreenSpot-Pro dataset (list of samples)
        max_samples: Maximum number of samples to evaluate (None for all)

    Returns:
        Dictionary with evaluation results
    """
    print(f"\nEvaluating model: {model_wrapper.model_name}")

    # Load model
    await model_wrapper.load_model()

    total_samples = len(dataset)
    if max_samples is not None:
        total_samples = min(max_samples, total_samples)

    correct_predictions = 0
    error_predictions = 0
    results = []

    for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
        sample = dataset[i]

        # Extract sample data
        image = sample['image']
        instruction = sample['instruction']
        bbox = sample['bbox']  # [x1, y1, x2, y2]
        sample_id = sample['img_filename']
        image = sample["image"]
        instruction = sample["instruction"]
        bbox = sample["bbox"]  # [x1, y1, x2, y2]
        sample_id = sample["img_filename"]

        # Predict click coordinates with timing
        start_time = time.time()
        click_coords = await model_wrapper.predict_click(image, instruction)
        prediction_time = time.time() - start_time

        # Check if prediction is correct
        is_correct = is_click_in_bbox(click_coords, bbox)

        if is_correct:
            correct_predictions += 1

        results.append({
            'id': sample_id,
            'instruction': instruction,
            'bbox': bbox,
            'predicted_coords': click_coords,
            'is_correct': is_correct,
            'failed': False,
            'prediction_time': prediction_time
        })

        results.append(
            {
                "id": sample_id,
                "instruction": instruction,
                "bbox": bbox,
                "predicted_coords": click_coords,
                "is_correct": is_correct,
                "failed": False,
                "prediction_time": prediction_time,
            }
        )

    # Unload model
    await model_wrapper.unload_model()

    # Calculate metrics
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
    error_rate = error_predictions / total_samples if total_samples > 0 else 0.0

    # Calculate timing statistics
    successful_times = [r['prediction_time'] for r in results if not r['failed']]
    successful_times = [r["prediction_time"] for r in results if not r["failed"]]
    avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
    median_prediction_time = statistics.median(successful_times) if successful_times else 0.0
    min_prediction_time = min(successful_times) if successful_times else 0.0
    max_prediction_time = max(successful_times) if successful_times else 0.0

    # Get VRAM statistics
    vram_stats = model_wrapper.get_vram_stats()

    return {
        'model_name': model_wrapper.model_name,
        'total_samples': total_samples,
        'correct_predictions': correct_predictions,
        'failed_predictions': error_predictions,
        'accuracy': accuracy,
        'failure_rate': error_rate,
        'avg_prediction_time': avg_prediction_time,
        'median_prediction_time': median_prediction_time,
        'min_prediction_time': min_prediction_time,
        'max_prediction_time': max_prediction_time,
        'vram_max_mb': vram_stats['max_mb'],
        'vram_avg_mb': vram_stats['avg_mb'],
        'results': results
        "model_name": model_wrapper.model_name,
        "total_samples": total_samples,
        "correct_predictions": correct_predictions,
        "failed_predictions": error_predictions,
        "accuracy": accuracy,
        "failure_rate": error_rate,
        "avg_prediction_time": avg_prediction_time,
        "median_prediction_time": median_prediction_time,
        "min_prediction_time": min_prediction_time,
        "max_prediction_time": max_prediction_time,
        "vram_max_mb": vram_stats["max_mb"],
        "vram_avg_mb": vram_stats["avg_mb"],
        "results": results,
    }

@@ -120,42 +123,44 @@ async def main():
    Main function to run the benchmark.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='ScreenSpot-Pro Benchmark Script')
    parser.add_argument('--samples', type=int, default=300,
                        help='Number of samples to evaluate (default: 300)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for shuffling (default: 42)')
    parser = argparse.ArgumentParser(description="ScreenSpot-Pro Benchmark Script")
    parser.add_argument(
        "--samples", type=int, default=300, help="Number of samples to evaluate (default: 300)"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
    )
    args = parser.parse_args()

    # Set random seed
    random.seed(args.seed)

    # Load dataset
    print("Loading ScreenSpot-Pro dataset...")
    ds = load_dataset("lmms-lab/ScreenSpot-Pro")
    dataset = ds['train']  # type: ignore
    dataset = ds["train"]  # type: ignore
    # Convert to list to support indexing
    dataset_list = list(dataset)
    print(f"Dataset loaded: {len(dataset_list)} samples")

    # Shuffle dataset with seed
    random.shuffle(dataset_list)
    print(f"Dataset shuffled with seed {args.seed}")

    # Get available models
    models = get_available_models()

    # Evaluation settings
    max_samples = args.samples  # Use command line argument

    # Run evaluations
    all_results = []

    for model in models:
        model_wrapper = ModelWrapper(model)
        result = await evaluate_model(model_wrapper, dataset_list, max_samples)
        all_results.append(result)

        # Print summary
        print(f"\n{result['model_name']} Results:")
        print(f" Accuracy: {result['accuracy']*100:.2f}%")
@@ -164,15 +169,17 @@ async def main():
        print(f" Error Rate: {result['failure_rate']*100:.2f}%")
        print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
        print(f" Median Time: {result['median_prediction_time']:.2f}s")
        print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
        print(
            f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s"
        )
        print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
        print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")

        # Print GPU memory info
        gpu_memory = get_gpu_memory()
        if gpu_memory and gpu_memory[0] > 0:
            print(f" GPU Free Memory: {gpu_memory[0]:.1f}MB")

    # Save results
    if all_results:
        save_results_to_markdown(all_results)
@@ -183,4 +190,4 @@ async def main():


if __name__ == "__main__":
    asyncio.run(main())
    asyncio.run(main())

@@ -15,36 +15,37 @@ from typing import Optional

from datasets import load_dataset
from tqdm import tqdm

from utils import (
    ModelWrapper,
    is_click_in_bbox,
    save_results_to_markdown,
    save_visualizations,
    ModelWrapper,
    get_available_models,
    get_gpu_memory
    get_gpu_memory,
    is_click_in_bbox,
    save_results_to_markdown,
    save_visualizations,
)


async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None) -> dict:
async def evaluate_model(
    model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None
) -> dict:
    """
    Evaluate a model on any iterable of samples.

    Args:
        model_wrapper: ModelWrapper instance
        samples: Iterable of dicts with keys: image, bbox, instruction
        max_samples: Maximum number of samples to evaluate (None for all)

    Returns:
        Dictionary with evaluation results
    """
    print(f"\nEvaluating model: {model_wrapper.model_name}")

    # Load model
    await model_wrapper.load_model()

    # Convert to list if needed and limit samples
    if hasattr(samples, '__len__'):
    if hasattr(samples, "__len__"):
        total_samples = len(samples)
        if max_samples is not None:
            total_samples = min(max_samples, total_samples)
@@ -55,69 +56,71 @@ async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Opti
        if max_samples is not None:
            sample_list = sample_list[:max_samples]
        total_samples = len(sample_list)

    correct_predictions = 0
    error_predictions = 0
    results = []

    for i, sample in enumerate(tqdm(sample_list, desc=f"Evaluating {model_wrapper.model_name}")):
        # Extract required data (only these 3 keys matter)
        image = sample['image']
        instruction = sample['instruction']
        bbox = sample['bbox']  # [x1, y1, x2, y2]
        image = sample["image"]
        instruction = sample["instruction"]
        bbox = sample["bbox"]  # [x1, y1, x2, y2]

        # Predict click coordinates with timing
        start_time = time.time()
        click_coords = await model_wrapper.predict_click(image, instruction)
        prediction_time = time.time() - start_time

        # Check if prediction is correct
        is_correct = is_click_in_bbox(click_coords, bbox)

        if is_correct:
            correct_predictions += 1

        results.append({
            'sample_idx': i,
            'instruction': instruction,
            'bbox': bbox,
            'predicted_coords': click_coords,
            'is_correct': is_correct,
            'failed': False,
            'prediction_time': prediction_time
        })

        results.append(
            {
                "sample_idx": i,
                "instruction": instruction,
                "bbox": bbox,
                "predicted_coords": click_coords,
                "is_correct": is_correct,
                "failed": False,
                "prediction_time": prediction_time,
            }
        )

    # Unload model
    await model_wrapper.unload_model()

    # Calculate metrics
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
    error_rate = error_predictions / total_samples if total_samples > 0 else 0.0

    # Calculate timing statistics
    successful_times = [r['prediction_time'] for r in results if not r['failed']]
    successful_times = [r["prediction_time"] for r in results if not r["failed"]]
    avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
    median_prediction_time = statistics.median(successful_times) if successful_times else 0.0
    min_prediction_time = min(successful_times) if successful_times else 0.0
    max_prediction_time = max(successful_times) if successful_times else 0.0

    # Get VRAM statistics
    vram_stats = model_wrapper.get_vram_stats()

    return {
        'model_name': model_wrapper.model_name,
        'total_samples': total_samples,
        'correct_predictions': correct_predictions,
        'failed_predictions': error_predictions,
        'accuracy': accuracy,
        'failure_rate': error_rate,
        'avg_prediction_time': avg_prediction_time,
        'median_prediction_time': median_prediction_time,
        'min_prediction_time': min_prediction_time,
        'max_prediction_time': max_prediction_time,
        'vram_max_mb': vram_stats['max_mb'],
        'vram_avg_mb': vram_stats['avg_mb'],
        'results': results
        "model_name": model_wrapper.model_name,
        "total_samples": total_samples,
        "correct_predictions": correct_predictions,
        "failed_predictions": error_predictions,
        "accuracy": accuracy,
        "failure_rate": error_rate,
        "avg_prediction_time": avg_prediction_time,
        "median_prediction_time": median_prediction_time,
        "min_prediction_time": min_prediction_time,
        "max_prediction_time": max_prediction_time,
        "vram_max_mb": vram_stats["max_mb"],
        "vram_avg_mb": vram_stats["avg_mb"],
        "results": results,
    }

@@ -126,56 +129,60 @@ async def main():
    Main function to run the benchmark.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='ScreenSpot-v2 Benchmark Script')
    parser.add_argument('--samples', type=int, default=500,
                        help='Number of samples to evaluate (default: 500)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for shuffling (default: 42)')
    parser = argparse.ArgumentParser(description="ScreenSpot-v2 Benchmark Script")
    parser.add_argument(
        "--samples", type=int, default=500, help="Number of samples to evaluate (default: 500)"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
    )
    args = parser.parse_args()

    # Set random seed
    random.seed(args.seed)

    # Load dataset
    print("Loading ScreenSpot-v2 dataset...")
    ds = load_dataset("lmms-lab/ScreenSpot-v2")
    dataset = ds['train']  # type: ignore
    dataset = ds["train"]  # type: ignore
    # Convert to simple list of dicts with only required keys
    samples = []
    for item in dataset:
        # Convert dataset item to dict if needed
        item_dict = dict(item) if hasattr(item, 'keys') else item
        item_dict = dict(item) if hasattr(item, "keys") else item

        # Convert ScreenSpot-v2 bbox format [x, y, w, h] to [x1, y1, x2, y2]
        bbox_xywh = item_dict['bbox']  # type: ignore
        bbox_xywh = item_dict["bbox"]  # type: ignore
        x, y, w, h = bbox_xywh
        bbox_xyxy = [x, y, x + w, y + h]

        samples.append({
            'image': item_dict['image'],  # type: ignore
            'instruction': item_dict['instruction'],  # type: ignore
            'bbox': bbox_xyxy
        })

        samples.append(
            {
                "image": item_dict["image"],  # type: ignore
                "instruction": item_dict["instruction"],  # type: ignore
                "bbox": bbox_xyxy,
            }
        )
    print(f"Dataset loaded: {len(samples)} samples")

    # Shuffle samples with seed
    random.shuffle(samples)
    print(f"Samples shuffled with seed {args.seed}")

    # Get available models
    models = get_available_models()

    # Evaluation settings
    max_samples = args.samples  # Use command line argument

    # Run evaluations
    all_results = []

    for model in models:
        model_wrapper = ModelWrapper(model)
        result = await evaluate_model(model_wrapper, samples, max_samples)
        all_results.append(result)

        # Print summary
        print(f"\n{result['model_name']} Results:")
        print(f" Accuracy: {result['accuracy']*100:.2f}%")
@@ -184,18 +191,22 @@ async def main():
        print(f" Error Rate: {result['failure_rate']*100:.2f}%")
        print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
        print(f" Median Time: {result['median_prediction_time']:.2f}s")
        print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
        print(
            f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s"
        )
        print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
        print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")

        # Print GPU memory info
        gpu_memory = get_gpu_memory()
        if gpu_memory and gpu_memory[0] > 0:
            print(f" GPU Free Memory: {gpu_memory[0]:.1f}MB")

    # Save results
    if all_results:
        save_results_to_markdown(all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results")
        save_results_to_markdown(
            all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results"
        )
        save_visualizations(all_results, samples)
        print("\nBenchmark completed successfully!")
    else:
@@ -203,4 +214,4 @@ async def main():


if __name__ == "__main__":
    asyncio.run(main())
    asyncio.run(main())

@@ -4,38 +4,40 @@ Shared utilities for ScreenSpot-Pro benchmarking and interactive testing.
"""

import dotenv

dotenv.load_dotenv()

import asyncio
import base64
import gc
import os
import sys
import subprocess as sp
import statistics
import subprocess as sp
import sys
from datetime import datetime
from io import BytesIO
from typing import List, Union, Tuple, Optional
from typing import List, Optional, Tuple, Union

import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
import gc
import torch

# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agent.agent import ComputerAgent
from models.base import ModelProtocol


def get_gpu_memory() -> List[int]:
    """
    Get GPU memory usage using nvidia-smi.

    Returns:
        List of free memory values in MB for each GPU
    """
    try:
        command = "nvidia-smi --query-gpu=memory.free --format=csv"
        memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
        memory_free_info = sp.check_output(command.split()).decode("ascii").split("\n")[:-1][1:]
        memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
        return memory_free_values
    except (sp.CalledProcessError, FileNotFoundError, IndexError):
@@ -51,39 +53,34 @@ def get_gpu_memory() -> List[int]:
def get_vram_usage() -> dict:
    """
    Get current VRAM usage statistics.

    Returns:
        Dictionary with VRAM usage info (in MB)
    """
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        allocated = torch.cuda.memory_allocated(device) / 1024 / 1024  # Convert to MB
        reserved = torch.cuda.memory_reserved(device) / 1024 / 1024  # Convert to MB
        reserved = torch.cuda.memory_reserved(device) / 1024 / 1024  # Convert to MB
        total = torch.cuda.get_device_properties(device).total_memory / 1024 / 1024
        return {
            'allocated_mb': allocated,
            'reserved_mb': reserved,
            'total_mb': total,
            'free_mb': total - reserved
            "allocated_mb": allocated,
            "reserved_mb": reserved,
            "total_mb": total,
            "free_mb": total - reserved,
        }
    else:
        return {
            'allocated_mb': 0.0,
            'reserved_mb': 0.0,
            'total_mb': 0.0,
            'free_mb': 0.0
        }
        return {"allocated_mb": 0.0, "reserved_mb": 0.0, "total_mb": 0.0, "free_mb": 0.0}


def get_available_models() -> List[Union[str, ModelProtocol]]:
    """
    Get list of available models for testing.

    Returns:
        List of model strings and model classes
    """
    local_provider = "huggingface-local/"  # Options: huggingface-local/ or mlx/

    # from models.gta1 import GTA1Model

    models = [
@@ -94,42 +91,41 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
        # f"{local_provider}HelloKKMe/GTA1-32B",
        "openai/computer-use-preview+openai/gpt-4o-mini",
        "anthropic/claude-opus-4-20250514+openai/gpt-4o-mini",

        # === Reference model classes ===
        # GTA1Model("HelloKKMe/GTA1-7B"),
        # GTA1Model("HelloKKMe/GTA1-32B"),
        # GTA1Model("HelloKKMe/GTA1-32B"),
    ]

    return models


def is_click_in_bbox(click_coords: Optional[Tuple[int, int]], bbox: List[int]) -> bool:
    """
    Check if click coordinates are within the bounding box.

    Args:
        click_coords: (x, y) coordinates or None
        bbox: [x1, y1, x2, y2] bounding box

    Returns:
        True if click is within bbox, False otherwise
    """
    if click_coords is None:
        return False

    x, y = click_coords
    x1, y1, x2, y2 = bbox

    return x1 <= x <= x2 and y1 <= y <= y2


def image_to_base64(image: Image.Image) -> str:
    """
    Convert PIL Image to base64 string.

    Args:
        image: PIL Image

    Returns:
        Base64 encoded image string
    """
@@ -142,213 +138,252 @@ class ModelWrapper:
    """
    Wrapper to provide unified interface for both ComputerAgent and custom models.
    """

    def __init__(self, model: Union[str, ModelProtocol]):
        self.model = model
        self.is_computer_agent = isinstance(model, str)
        self.agent: Optional[ComputerAgent] = None
        self.vram_usage_history: List[float] = []  # Track VRAM usage over time

        if self.is_computer_agent:
            self.model_name = str(model)
        else:
            self.model_name = f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"
            self.model_name = (
                f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"
            )

    async def load_model(self) -> None:
        """Load the model."""
        if self.is_computer_agent:
            self.agent = ComputerAgent(model=str(self.model))
        else:
            await self.model.load_model() # type: ignore
            await self.model.load_model()  # type: ignore

        # Record initial VRAM usage after loading
        vram_info = get_vram_usage()
        self.vram_usage_history.append(vram_info['allocated_mb'])
        self.vram_usage_history.append(vram_info["allocated_mb"])

    async def unload_model(self) -> None:
        """Unload the model."""
        if not self.is_computer_agent:
            await self.model.unload_model() # type: ignore
            await self.model.unload_model()  # type: ignore
        else:
            del self.agent
            self.agent = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Record VRAM usage after unloading
        vram_info = get_vram_usage()
        self.vram_usage_history.append(vram_info['allocated_mb'])
        self.vram_usage_history.append(vram_info["allocated_mb"])

    def get_vram_stats(self) -> dict:
        """Get VRAM usage statistics for this model."""
        if not self.vram_usage_history:
            return {'max_mb': 0.0, 'avg_mb': 0.0}
            return {"max_mb": 0.0, "avg_mb": 0.0}

        return {
            'max_mb': max(self.vram_usage_history),
            'avg_mb': sum(self.vram_usage_history) / len(self.vram_usage_history)
            "max_mb": max(self.vram_usage_history),
            "avg_mb": sum(self.vram_usage_history) / len(self.vram_usage_history),
        }

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
    async def predict_click(
        self, image: Image.Image, instruction: str
    ) -> Optional[Tuple[int, int]]:
        """Predict click coordinates."""
        # Record VRAM usage before prediction
        vram_info = get_vram_usage()
        self.vram_usage_history.append(vram_info['allocated_mb'])
        self.vram_usage_history.append(vram_info["allocated_mb"])

        if self.is_computer_agent:
            if self.agent is None:
                await self.load_model()

            if self.agent is not None:
                image_b64 = image_to_base64(image)
                result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64)
                result = await self.agent.predict_click(
                    instruction=instruction, image_b64=image_b64
                )

                # Record VRAM usage after prediction
                vram_info = get_vram_usage()
                self.vram_usage_history.append(vram_info['allocated_mb'])
                self.vram_usage_history.append(vram_info["allocated_mb"])

                return result
            return None
        else:
            result = await self.model.predict_click(image, instruction) # type: ignore
            result = await self.model.predict_click(image, instruction)  # type: ignore

            # Record VRAM usage after prediction
            vram_info = get_vram_usage()
            self.vram_usage_history.append(vram_info['allocated_mb'])
            self.vram_usage_history.append(vram_info["allocated_mb"])

            return result

def save_results_to_markdown(all_results: List[dict],output_file: str = "screenspot_pro_results.md", title: str = "ScreenSpot-Pro Benchmark Results") -> None:
def save_results_to_markdown(
    all_results: List[dict],
    output_file: str = "screenspot_pro_results.md",
    title: str = "ScreenSpot-Pro Benchmark Results",
) -> None:
    """
    Save evaluation results to a markdown table.

    Args:
        all_results: List of evaluation results for each model
        output_file: Output markdown file path
    """
    with open(output_file, 'w', encoding='utf-8') as f:
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(f"# {title}\n\n")
        f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        # Summary table
        f.write("## Summary\n\n")
        f.write("| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Median Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n")
        f.write("|-------|---------------|---------|--------|----------|------------|--------------|-----------------|----------------|---------------|---------------|\n")
        f.write(
            "| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Median Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n"
        )
        f.write(
            "|-------|---------------|---------|--------|----------|------------|--------------|-----------------|----------------|---------------|---------------|\n"
        )

        for result in all_results:
            model_name = result['model_name']
            total = result['total_samples']
            correct = result['correct_predictions']
            errors = result['failed_predictions']
            accuracy = result['accuracy'] * 100
            error_rate = result['failure_rate'] * 100
            avg_time = result.get('avg_prediction_time', 0.0)
            median_time = result.get('median_prediction_time', 0.0)
            min_time = result.get('min_prediction_time', 0.0)
            max_time = result.get('max_prediction_time', 0.0)
            model_name = result["model_name"]
            total = result["total_samples"]
            correct = result["correct_predictions"]
            errors = result["failed_predictions"]
            accuracy = result["accuracy"] * 100
            error_rate = result["failure_rate"] * 100
            avg_time = result.get("avg_prediction_time", 0.0)
            median_time = result.get("median_prediction_time", 0.0)
            min_time = result.get("min_prediction_time", 0.0)
            max_time = result.get("max_prediction_time", 0.0)
            time_range = f"{min_time:.2f} - {max_time:.2f}"
            vram_max = result.get('vram_max_mb', 0.0) / 1024
            vram_avg = result.get('vram_avg_mb', 0.0) / 1024

            f.write(f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {median_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n")

            vram_max = result.get("vram_max_mb", 0.0) / 1024
            vram_avg = result.get("vram_avg_mb", 0.0) / 1024

            f.write(
                f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {median_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n"
            )

        # Detailed results for each model
        for result in all_results:
            f.write(f"\n## {result['model_name']} - Detailed Results\n\n")
            f.write("| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n")
            f.write(
                "| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n"
            )
            f.write("|-----------|-------------|------|-----------|---------|-------|----------|\n")

            for sample_result in result['results'][:10]:  # Show first 10 samples
                sample_idx = sample_result['sample_idx']
                instruction = sample_result['instruction'][:50] + "..." if len(sample_result['instruction']) > 50 else sample_result['instruction']
                bbox = str(sample_result['bbox'])
                predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None"
                correct = "PASS" if sample_result['is_correct'] else "FAIL"
                error = "YES" if sample_result['failed'] else "NO"
                pred_time = sample_result.get('prediction_time', 0.0)

                f.write(f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n")

            if len(result['results']) > 10:

            for sample_result in result["results"][:10]:  # Show first 10 samples
                sample_idx = sample_result["sample_idx"]
                instruction = (
                    sample_result["instruction"][:50] + "..."
                    if len(sample_result["instruction"]) > 50
                    else sample_result["instruction"]
                )
                bbox = str(sample_result["bbox"])
                predicted = (
                    str(sample_result["predicted_coords"])
                    if sample_result["predicted_coords"]
                    else "None"
                )
                correct = "PASS" if sample_result["is_correct"] else "FAIL"
                error = "YES" if sample_result["failed"] else "NO"
                pred_time = sample_result.get("prediction_time", 0.0)

                f.write(
                    f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n"
                )

            if len(result["results"]) > 10:
                f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n")

    print(f"\nResults saved to: {output_file}")

def save_visualizations(all_results: List[dict], samples, output_dir: str = "output") -> None:
    """
    Save visualizations of predicted coordinates vs bboxes to an output folder.

    Args:
        all_results: List of evaluation results for each model
        samples: List of sample dicts with image, bbox, instruction keys
        output_dir: Output directory path
    """
    os.makedirs(output_dir, exist_ok=True)

    for result in all_results:
        model_name = result['model_name'].replace('/', '_').replace('\\', '_')
        model_name = result["model_name"].replace("/", "_").replace("\\", "_")
        model_dir = os.path.join(output_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)

        print(f"Saving visualizations for {result['model_name']}...")

        # Save first 10 samples for visualization
        for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")):
        for i, sample_result in enumerate(
            tqdm(result["results"][:10], desc=f"Saving {model_name} visualizations")
        ):
            # Get sample data using index
            sample_idx = sample_result['sample_idx']
            sample_idx = sample_result["sample_idx"]

            if sample_idx < len(samples):
                sample = samples[sample_idx]
                image = sample['image'].copy()  # Make a copy to avoid modifying original
                image = sample["image"].copy()  # Make a copy to avoid modifying original
            else:
                print(f"Warning: Could not find sample at index {sample_idx}")
                continue

            bbox = sample_result['bbox']
            predicted_coords = sample_result['predicted_coords']
            is_correct = sample_result['is_correct']

            bbox = sample_result["bbox"]
            predicted_coords = sample_result["predicted_coords"]
            is_correct = sample_result["is_correct"]

            # Draw on image
            draw = ImageDraw.Draw(image)

            # Draw bounding box (ground truth) in green
            x1, y1, x2, y2 = bbox
            draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
            draw.text((x1, y1-20), "Ground Truth", fill="green")
            draw.text((x1, y1 - 20), "Ground Truth", fill="green")

            # Draw predicted click in red or blue
            if predicted_coords is not None:
                px, py = predicted_coords
                color = "blue" if is_correct else "red"
                # Draw crosshair
                crosshair_size = 15
                draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
                draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
                draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)
                draw.line(
                    [(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=3
                )
                draw.line(
                    [(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=3
                )
                draw.text((px + 10, py - 20), f"Predicted ({px},{py})", fill=color)

            # Add status text
            status = "CORRECT" if is_correct else "INCORRECT"
            status_color = "blue" if is_correct else "red"
            draw.text((10, 10), f"Status: {status}", fill=status_color)
            draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")
            draw.text(
                (10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black"
            )

            # Save image
            filename = f"sample_{i+1:02d}_idx{sample_idx}_{status.lower()}.png"
            filepath = os.path.join(model_dir, filename)
            image.save(filepath)

        print(f"Visualizations saved to: {model_dir}")


def save_prediction_visualization(image: Image.Image, instruction: str, predictions: List[dict],
                                  output_file: str = "interactive_prediction.png") -> None:
def save_prediction_visualization(
    image: Image.Image,
    instruction: str,
    predictions: List[dict],
    output_file: str = "interactive_prediction.png",
) -> None:
    """
    Save visualization of multiple model predictions on a single image.

    Args:
        image: PIL Image to visualize
        instruction: Instruction text
@@ -358,32 +393,32 @@ def save_prediction_visualization(image: Image.Image, instruction: str, predicti
    # Create a copy of the image
    vis_image = image.copy()
    draw = ImageDraw.Draw(vis_image)

    # Colors for different models
    colors = ["red", "blue", "orange", "purple", "brown", "pink", "gray", "olive"]

    # Draw predictions
    for i, pred in enumerate(predictions):
        color = colors[i % len(colors)]
        model_name = pred['model_name']
        coords = pred.get('coords')
        error = pred.get('error')
        model_name = pred["model_name"]
        coords = pred.get("coords")
        error = pred.get("error")

        if coords is not None:
            px, py = coords
            # Draw crosshair
            crosshair_size = 20
            draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=4)
            draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=4)
            draw.line([(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=4)
            draw.line([(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=4)
            # Draw model name
            draw.text((px+15, py+15), f"{model_name}: ({px},{py})", fill=color)
            draw.text((px + 15, py + 15), f"{model_name}: ({px},{py})", fill=color)
        else:
            # Draw error text
            draw.text((10, 50 + i*20), f"{model_name}: ERROR - {error}", fill=color)
            draw.text((10, 50 + i * 20), f"{model_name}: ERROR - {error}", fill=color)

    # Add instruction at the top
    draw.text((10, 10), f"Instruction: {instruction}", fill="black")

    # Save image
    vis_image.save(output_file)
    print(f"Prediction visualization saved to: {output_file}")
@@ -392,12 +427,13 @@ def save_prediction_visualization(image: Image.Image, instruction: str, predicti
def take_screenshot() -> Image.Image:
    """
    Take a screenshot of the current screen.

    Returns:
        PIL Image of the screenshot
    """
    try:
        import pyautogui

        screenshot = pyautogui.screenshot()
        return screenshot
    except ImportError:
@@ -406,4 +442,3 @@ def take_screenshot() -> Image.Image:
    except Exception as e:
        print(f"Error taking screenshot: {e}")
        raise

@@ -9,58 +9,61 @@ from agent import ComputerAgent
from computer import Computer
from computer.helpers import sandboxed


@sandboxed()
def read_file(location: str) -> str:
    """Read contents of a file

    Parameters
    ----------
    location : str
        Path to the file to read

    Returns
    -------
    str
        Contents of the file or error message
    """
    try:
        with open(location, 'r') as f:
        with open(location, "r") as f:
            return f.read()
    except Exception as e:
        return f"Error reading file: {str(e)}"


def save_note(content: str, filename: str = "note.txt") -> str:
    """Save content to a note file

    Parameters
    ----------
    content : str
        Content to save to the file
    filename : str, optional
        Name of the file to save to (default is "note.txt")

    Returns
    -------
    str
        Success or error message
    """
    try:
        with open(filename, 'w') as f:
        with open(filename, "w") as f:
            f.write(content)
        return f"Saved note to {filename}"
    except Exception as e:
        return f"Error saving note: {str(e)}"


def calculate(a: int, b: int) -> int:
    """Calculate the sum of two integers

    Parameters
    ----------
    a : int
        First integer
    b : int
        Second integer

    Returns
    -------
    int
@@ -68,15 +71,18 @@ def calculate(a: int, b: int) -> int:
    """
    return a + b


async def main():
    """Example usage of ComputerAgent with different models"""

    # Example 1: Using Claude with computer and custom tools
    print("=== Example 1: Claude with Computer ===")

    import os
    import dotenv

    import json
    import os

    import dotenv

    dotenv.load_dotenv()

    assert os.getenv("CUA_CONTAINER_NAME") is not None, "CUA_CONTAINER_NAME is not set"
@@ -86,38 +92,37 @@ async def main():
        os_type="linux",
        provider_type="cloud",
        name=os.getenv("CUA_CONTAINER_NAME") or "",
        api_key=os.getenv("CUA_API_KEY") or ""
        api_key=os.getenv("CUA_API_KEY") or "",
    ) as computer:
        agent = ComputerAgent(
            # Supported models:

            # == OpenAI CUA (computer-use-preview) ==
            model="openai/computer-use-preview",

            # == Anthropic CUA (Claude > 3.5) ==
            # model="anthropic/claude-opus-4-20250514",
            # model="anthropic/claude-opus-4-20250514",
            # model="anthropic/claude-sonnet-4-20250514",
            # model="anthropic/claude-3-7-sonnet-20250219",
            # model="anthropic/claude-3-5-sonnet-20241022",

            # == UI-TARS ==
            # model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
            # TODO: add local mlx provider
            # model="mlx-community/UI-TARS-1.5-7B-6bit",
            # model="ollama_chat/0000/ui-tars-1.5-7b",

            # == Omniparser + Any LLM ==
            # model="omniparser+..."
            # model="omniparser+anthropic/claude-opus-4-20250514",

            tools=[computer],
            only_n_most_recent_images=3,
            verbosity=logging.INFO,
            trajectory_dir="trajectories",
            use_prompt_caching=True,
            max_trajectory_budget={ "max_budget": 1.0, "raise_error": True, "reset_after_each_run": False },
            max_trajectory_budget={
                "max_budget": 1.0,
                "raise_error": True,
                "reset_after_each_run": False,
            },
        )

        history = []
        while True:
            user_input = input("> ")
@@ -143,5 +148,6 @@ async def main():
            # elif item["type"] == "function_call_output":
            #     print("===>", item["output"])


if __name__ == "__main__":
    asyncio.run(main())
    asyncio.run(main())

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"

[project]
name = "cua-agent"
version = "0.4.0"
version = "0.4.35"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [
@@ -29,6 +29,11 @@ requires-python = ">=3.12"
[project.optional-dependencies]
openai = []
anthropic = []
qwen = [
    "qwen-vl-utils",
    "qwen-agent",
    "Pillow>=10.0.0",
]
omni = [
    "cua-som>=0.1.0,<0.2.0",
]
@@ -49,7 +54,7 @@ glm45v-hf = [
opencua-hf = [
    "accelerate",
    "torch",
    "transformers==4.53.0",
    "transformers>=4.53.0",
    "tiktoken>=0.11.0",
    "blobfile>=3.0.0"
]
@@ -60,6 +65,11 @@ internvl-hf = [
    "einops",
    "timm"
]
moondream3 = [
    "accelerate",
    "torch",
    "transformers>=4.55.0"
]
ui = [
    "gradio>=5.23.3",
    "python-dotenv>=1.0.1",
@@ -68,7 +78,10 @@ cli = [
    "yaspin>=3.1.0",
]
hud = [
    "hud-python==0.4.26",
    "hud-python==0.4.52",
]
gemini = [
    "google-genai>=1.41.0",
]
all = [
    # uitars requirements
@@ -88,7 +101,13 @@ all = [
    # cli requirements
    "yaspin>=3.1.0",
    # hud requirements
    "hud-python==0.4.26",
    "hud-python==0.4.52",
    # gemini requirements
    "google-genai>=1.41.0",
    # qwen requirements
    "qwen-vl-utils",
    "qwen-agent",
    "Pillow>=10.0.0",
]

[tool.uv]
@@ -98,4 +117,4 @@ constraint-dependencies = ["fastrtc>0.43.0", "mlx-audio>0.2.3"]
distribution = true

[tool.pdm.build]
includes = ["agent/"]
includes = ["agent/"]