Merge branch 'main' into models/opencua
@@ -29,16 +29,6 @@

```bash
pip install "cua-agent[all]"

# or install specific providers
pip install "cua-agent[openai]"     # OpenAI computer-use-preview support
pip install "cua-agent[anthropic]"  # Anthropic Claude support
pip install "cua-agent[omni]"       # Omniparser + any LLM support
pip install "cua-agent[uitars]"     # UI-TARS
pip install "cua-agent[uitars-mlx]" # UI-TARS + MLX support
pip install "cua-agent[uitars-hf]"  # UI-TARS + Huggingface support
pip install "cua-agent[glm45v-hf]"  # GLM-4.5V + Huggingface support
pip install "cua-agent[ui]"         # Gradio UI support
```

## Quick Start

@@ -79,303 +69,18 @@ if __name__ == "__main__":
    asyncio.run(main())
```
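
Only the tail of the Quick Start example survives the diff above. A minimal end-to-end sketch of the same pattern is shown below; the `Computer` constructor arguments and the exact import paths (`from computer import Computer`, `from agent import ComputerAgent`) are assumptions for illustration, not confirmed by this diff:

```python
import asyncio
from computer import Computer    # assumed import path
from agent import ComputerAgent  # assumed import path

async def main():
    # Connect to a computer instance (constructor arguments omitted/assumed)
    async with Computer() as computer:
        agent = ComputerAgent(
            model="anthropic/claude-3-5-sonnet-20241022",
            tools=[computer],
        )
        messages = [{"role": "user", "content": "Take a screenshot and describe what you see"}]
        async for result in agent.run(messages):
            for item in result["output"]:
                if item["type"] == "message":
                    print(item["content"][0]["text"])

if __name__ == "__main__":
    asyncio.run(main())
```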

## Supported Models

## Docs

### Anthropic Claude (Computer Use API)

```python
model="anthropic/claude-3-5-sonnet-20241022"
model="anthropic/claude-3-7-sonnet-20250219"
model="anthropic/claude-opus-4-20250514"
model="anthropic/claude-sonnet-4-20250514"
```

### OpenAI Computer Use Preview

```python
model="openai/computer-use-preview"
```

### UI-TARS (Local or Huggingface Inference)

```python
model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B"
model="ollama_chat/0000/ui-tars-1.5-7b"
```

### Omniparser + Any LLM

```python
model="omniparser+ollama_chat/mistral-small3.2"
model="omniparser+vertex_ai/gemini-pro"
model="omniparser+anthropic/claude-3-5-sonnet-20241022"
model="omniparser+openai/gpt-4o"
```

## Custom Tools

Define custom tools using decorated functions:

```python
from computer.helpers import sandboxed

@sandboxed()
def read_file(location: str) -> str:
    """Read contents of a file

    Parameters
    ----------
    location : str
        Path to the file to read

    Returns
    -------
    str
        Contents of the file or error message
    """
    try:
        with open(location, 'r') as f:
            return f.read()
    except Exception as e:
        return f"Error reading file: {str(e)}"

def calculate(a: int, b: int) -> int:
    """Calculate the sum of two integers"""
    return a + b

# Use with agent
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer, read_file, calculate]
)
```

## Callbacks System

The agent provides a comprehensive callback system for extending functionality:

### Built-in Callbacks

```python
from agent.callbacks import (
    ImageRetentionCallback,
    TrajectorySaverCallback,
    BudgetManagerCallback,
    LoggingCallback
)

agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    callbacks=[
        ImageRetentionCallback(only_n_most_recent_images=3),
        TrajectorySaverCallback(trajectory_dir="trajectories"),
        BudgetManagerCallback(max_budget=10.0, raise_error=True),
        LoggingCallback(level=logging.INFO)
    ]
)
```

### Custom Callbacks

```python
from agent.callbacks.base import AsyncCallbackHandler

class CustomCallback(AsyncCallbackHandler):
    async def on_llm_start(self, messages):
        """Preprocess messages before LLM call"""
        # Add custom preprocessing logic
        return messages

    async def on_llm_end(self, messages):
        """Postprocess messages after LLM call"""
        # Add custom postprocessing logic
        return messages

    async def on_usage(self, usage):
        """Track usage information"""
        print(f"Tokens used: {usage.total_tokens}")
```
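
A custom callback is registered the same way as the built-in ones, by passing an instance in the `callbacks` list (sketch):

```python
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    callbacks=[CustomCallback()]
)
```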

## Budget Management

Control costs with built-in budget management:

```python
# Simple budget limit
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    max_trajectory_budget=5.0  # $5 limit
)

# Advanced budget configuration
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    max_trajectory_budget={
        "max_budget": 10.0,
        "raise_error": True,  # Raise error when exceeded
        "reset_after_each_run": False  # Persistent across runs
    }
)
```

## Trajectory Management

Save and replay agent conversations:

```python
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    trajectory_dir="trajectories",  # Auto-save trajectories
    tools=[computer]
)

# Trajectories are saved with:
# - Complete conversation history
# - Usage statistics and costs
# - Timestamps and metadata
# - Screenshots and computer actions
```

## Configuration Options

### ComputerAgent Parameters

- `model`: Model identifier (required)
- `tools`: List of computer objects and decorated functions
- `callbacks`: List of callback handlers for extensibility
- `only_n_most_recent_images`: Limit recent images to prevent context overflow
- `verbosity`: Logging level (logging.INFO, logging.DEBUG, etc.)
- `trajectory_dir`: Directory to save conversation trajectories
- `max_retries`: Maximum API call retries (default: 3)
- `screenshot_delay`: Delay between actions and screenshots (default: 0.5s)
- `use_prompt_caching`: Enable prompt caching for supported models
- `max_trajectory_budget`: Budget limit configuration
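
As an illustration, several of these parameters can be combined on a single agent; the sketch below uses only parameters listed above:

```python
import logging

agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    only_n_most_recent_images=3,
    verbosity=logging.INFO,
    trajectory_dir="trajectories",
    max_retries=3,
    screenshot_delay=0.5,
    use_prompt_caching=False,
    max_trajectory_budget=5.0
)
```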

### Environment Variables

```bash
# Computer instance (cloud)
export CUA_CONTAINER_NAME="your-container-name"
export CUA_API_KEY="your-cua-api-key"

# LLM API keys
export ANTHROPIC_API_KEY="your-anthropic-key"
export OPENAI_API_KEY="your-openai-key"
```
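
The two `CUA_*` variables identify a cloud computer instance, while the LLM keys are read by the respective providers. A sketch of consuming them from Python follows; the `Computer` keyword arguments shown here are assumptions for illustration only, not confirmed by this document:

```python
import os
from computer import Computer  # assumed import path

computer = Computer(
    # hypothetical constructor arguments, shown for illustration only
    name=os.environ["CUA_CONTAINER_NAME"],
    api_key=os.environ["CUA_API_KEY"],
)
```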

## Advanced Usage

### Streaming Responses

```python
async for result in agent.run(messages, stream=True):
    # Process streaming chunks
    for item in result["output"]:
        if item["type"] == "message":
            print(item["content"][0]["text"], end="", flush=True)
        elif item["type"] == "computer_call":
            action = item["action"]
            print(f"\n[Action: {action['type']}]")
```

### Interactive Chat Loop

```python
history = []
while True:
    user_input = input("> ")
    if user_input.lower() in ['quit', 'exit']:
        break

    history.append({"role": "user", "content": user_input})

    async for result in agent.run(history):
        history += result["output"]

        # Display assistant responses
        for item in result["output"]:
            if item["type"] == "message":
                print(item["content"][0]["text"])
```

### Error Handling

```python
try:
    async for result in agent.run(messages):
        # Process results
        pass
except BudgetExceededException:
    print("Budget limit exceeded")
except Exception as e:
    print(f"Agent error: {e}")
```

## API Reference

### ComputerAgent.run()

```python
async def run(
    self,
    messages: Messages,
    stream: bool = False,
    **kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Run the agent with the given messages.

    Args:
        messages: List of message dictionaries
        stream: Whether to stream the response
        **kwargs: Additional arguments

    Returns:
        AsyncGenerator that yields response chunks
    """
```

### Message Format

```python
messages = [
    {
        "role": "user",
        "content": "Take a screenshot and describe what you see"
    },
    {
        "role": "assistant",
        "content": "I'll take a screenshot for you."
    }
]
```

### Response Format

```python
{
    "output": [
        {
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": "I can see..."}]
        },
        {
            "type": "computer_call",
            "action": {"type": "screenshot"},
            "call_id": "call_123"
        },
        {
            "type": "computer_call_output",
            "call_id": "call_123",
            "output": {"image_url": "data:image/png;base64,..."}
        }
    ],
    "usage": {
        "prompt_tokens": 150,
        "completion_tokens": 75,
        "total_tokens": 225,
        "response_cost": 0.01
    }
}
```

- [Agent Loops](https://trycua.com/docs/agent-sdk/agent-loops)
- [Supported Agents](https://trycua.com/docs/agent-sdk/supported-agents)
- [Supported Models](https://trycua.com/docs/agent-sdk/supported-models)
- [Chat History](https://trycua.com/docs/agent-sdk/chat-history)
- [Callbacks](https://trycua.com/docs/agent-sdk/callbacks)
- [Custom Tools](https://trycua.com/docs/agent-sdk/custom-tools)
- [Custom Computer Handlers](https://trycua.com/docs/agent-sdk/custom-computer-handlers)
- [Prompt Caching](https://trycua.com/docs/agent-sdk/prompt-caching)
- [Usage Tracking](https://trycua.com/docs/agent-sdk/usage-tracking)
- [Benchmarks](https://trycua.com/docs/agent-sdk/benchmarks)

## License

@@ -4,8 +4,10 @@ Adapters package for agent - Custom LLM adapters for LiteLLM

from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
from .human_adapter import HumanAdapter
from .mlxvlm_adapter import MLXVLMAdapter

__all__ = [
    "HuggingFaceLocalAdapter",
    "HumanAdapter",
    "MLXVLMAdapter",
]

libs/python/agent/agent/adapters/mlxvlm_adapter.py (new file, 359 lines)
@@ -0,0 +1,359 @@
import asyncio
import functools
import warnings
import io
import base64
import math
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional, Tuple, cast
from PIL import Image
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion

# Try to import MLX dependencies
try:
    import mlx.core as mx
    from mlx_vlm import load, generate
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config
    from transformers.tokenization_utils import PreTrainedTokenizer
    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False

# Constants for smart_resize
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

def round_by_factor(number: float, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor

def ceil_by_factor(number: float, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor

def floor_by_factor(number: float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor

def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar


class MLXVLMAdapter(CustomLLM):
    """MLX VLM Adapter for running vision-language models locally using MLX."""

    def __init__(self, **kwargs):
        """Initialize the adapter.

        Args:
            **kwargs: Additional arguments
        """
        super().__init__()

        self.models = {}  # Cache for loaded models
        self.processors = {}  # Cache for loaded processors
        self.configs = {}  # Cache for loaded configs
        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool

    def _load_model_and_processor(self, model_name: str):
        """Load model and processor if not already cached.

        Args:
            model_name: Name of the model to load

        Returns:
            Tuple of (model, processor, config)
        """
        if not MLX_AVAILABLE:
            raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")

        if model_name not in self.models:
            # Load model and processor
            model_obj, processor = load(
                model_name,
                processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
            )
            config = load_config(model_name)

            # Cache them
            self.models[model_name] = model_obj
            self.processors[model_name] = processor
            self.configs[model_name] = config

        return self.models[model_name], self.processors[model_name], self.configs[model_name]

    def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
        """Process coordinates in box tokens based on image resizing using smart_resize approach.

        Args:
            text: Text containing box tokens
            original_size: Original image size (width, height)
            model_size: Model processed image size (width, height)

        Returns:
            Text with processed coordinates
        """
        # Find all box tokens
        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"

        def process_coords(match):
            model_x, model_y = int(match.group(1)), int(match.group(2))
            # Scale coordinates from model space to original image space
            # Both original_size and model_size are in (width, height) format
            new_x = int(model_x * original_size[0] / model_size[0])  # Width
            new_y = int(model_y * original_size[1] / model_size[1])  # Height
            return f"<|box_start|>({new_x},{new_y})<|box_end|>"

        return re.sub(box_pattern, process_coords, text)

    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Image.Image], Dict[int, Tuple[int, int]], Dict[int, Tuple[int, int]]]:
        """Convert OpenAI format messages to MLX VLM format and extract images.

        Args:
            messages: Messages in OpenAI format

        Returns:
            Tuple of (processed_messages, images, original_sizes, model_sizes)
        """
        processed_messages = []
        images = []
        original_sizes = {}  # Track original sizes of images for coordinate mapping
        model_sizes = {}  # Track model processed sizes
        image_index = 0

        for message in messages:
            processed_message = {
                "role": message["role"],
                "content": []
            }

            content = message.get("content", [])
            if isinstance(content, str):
                # Simple text content
                processed_message["content"] = content
            elif isinstance(content, list):
                # Multi-modal content
                processed_content = []
                for item in content:
                    if item.get("type") == "text":
                        processed_content.append({
                            "type": "text",
                            "text": item.get("text", "")
                        })
                    elif item.get("type") == "image_url":
                        image_url = item.get("image_url", {}).get("url", "")
                        pil_image = None

                        if image_url.startswith("data:image/"):
                            # Extract base64 data
                            base64_data = image_url.split(',')[1]
                            # Convert base64 to PIL Image
                            image_data = base64.b64decode(base64_data)
                            pil_image = Image.open(io.BytesIO(image_data))
                        else:
                            # Handle file path or URL
                            pil_image = Image.open(image_url)

                        # Store original image size for coordinate mapping
                        original_size = pil_image.size
                        original_sizes[image_index] = original_size

                        # Use smart_resize to determine model size
                        # Note: smart_resize expects (height, width) but PIL gives (width, height)
                        height, width = original_size[1], original_size[0]
                        new_height, new_width = smart_resize(height, width)
                        # Store model size in (width, height) format for consistent coordinate processing
                        model_sizes[image_index] = (new_width, new_height)

                        # Resize the image using the calculated dimensions from smart_resize
                        resized_image = pil_image.resize((new_width, new_height))
                        images.append(resized_image)

                        # Add image placeholder to content
                        processed_content.append({
                            "type": "image"
                        })

                        image_index += 1

                processed_message["content"] = processed_content

            processed_messages.append(processed_message)

        return processed_messages, images, original_sizes, model_sizes

    def _generate(self, **kwargs) -> str:
        """Generate response using the local MLX VLM model.

        Args:
            **kwargs: Keyword arguments containing messages and model info

        Returns:
            Generated text response
        """
        messages = kwargs.get('messages', [])
        model_name = kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')
        max_tokens = kwargs.get('max_tokens', 128)

        # Warn about ignored kwargs
        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
        if ignored_kwargs:
            warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")

        # Load model and processor
        model, processor, config = self._load_model_and_processor(model_name)

        # Convert messages and extract images
        processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)

        # Process user text input with box coordinates after image processing
        # Swap original_size and model_size arguments for inverse transformation
        for msg_idx, msg in enumerate(processed_messages):
            if msg.get("role") == "user" and isinstance(msg.get("content"), str):
                content = msg.get("content", "")
                if "<|box_start|>" in content and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
                    orig_size = original_sizes[0]
                    model_size = model_sizes[0]
                    # Swap arguments to perform inverse transformation for user input
                    processed_messages[msg_idx]["content"] = self._process_coordinates(content, model_size, orig_size)

        try:
            # Format prompt according to model requirements using the processor directly
            prompt = processor.apply_chat_template(
                processed_messages,
                tokenize=False,
                add_generation_prompt=True,
                return_tensors='pt'
            )
            tokenizer = cast(PreTrainedTokenizer, processor)

            # Generate response
            text_content, usage = generate(
                model,
                tokenizer,
                str(prompt),
                images,  # type: ignore
                verbose=False,
                max_tokens=max_tokens
            )

        except Exception as e:
            raise RuntimeError(f"Error generating response: {str(e)}") from e

        # Process coordinates in the response back to original image space
        if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
            # Get original image size and model size (using the first image)
            orig_size = original_sizes[0]
            model_size = model_sizes[0]

            # Check if output contains box tokens that need processing
            if "<|box_start|>" in text_content:
                # Process coordinates from model space back to original image space
                text_content = self._process_coordinates(text_content, orig_size, model_size)

        return text_content

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        generated_text = self._generate(**kwargs)

        result = completion(
            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
            mock_response=generated_text,
        )
        return cast(ModelResponse, result)

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with generated text
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
        )

        result = await acompletion(
            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
            mock_response=generated_text,
        )
        return cast(ModelResponse, result)

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Returns:
            Iterator of GenericStreamingChunk
        """
        generated_text = self._generate(**kwargs)

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": generated_text,
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Returns:
            AsyncIterator of GenericStreamingChunk
        """
        # Run _generate in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        generated_text = await loop.run_in_executor(
            self._executor,
            functools.partial(self._generate, **kwargs)
        )

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": generated_text,
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk
@@ -3,6 +3,7 @@ ComputerAgent - Main agent class that selects and runs agent loops
"""

import asyncio
from pathlib import Path
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple

from litellm.responses.utils import Usage
@@ -22,6 +23,7 @@ import inspect
from .adapters import (
    HuggingFaceLocalAdapter,
    HumanAdapter,
    MLXVLMAdapter,
)
from .callbacks import (
    ImageRetentionCallback,
@@ -29,6 +31,7 @@ from .callbacks import (
    TrajectorySaverCallback,
    BudgetManagerCallback,
    TelemetryCallback,
    OperatorNormalizerCallback
)
from .computers import (
    AsyncComputerHandler,
@@ -160,7 +163,7 @@ class ComputerAgent:
        only_n_most_recent_images: Optional[int] = None,
        callbacks: Optional[List[Any]] = None,
        verbosity: Optional[int] = None,
        trajectory_dir: Optional[str] = None,
        trajectory_dir: Optional[str | Path | dict] = None,
        max_retries: Optional[int] = 3,
        screenshot_delay: Optional[float | int] = 0.5,
        use_prompt_caching: Optional[bool] = False,
@@ -187,7 +190,11 @@ class ComputerAgent:
            telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
            trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
            **kwargs: Additional arguments passed to the agent loop
        """
        """
        # If the loop is "human/human", we need to prefix a grounding model fallback
        if model in ["human/human", "human"]:
            model = "openai/computer-use-preview+human/human"

        self.model = model
        self.tools = tools or []
        self.custom_loop = custom_loop
@@ -204,6 +211,9 @@ class ComputerAgent:

        # == Add built-in callbacks ==

        # Prepend operator normalizer callback
        self.callbacks.insert(0, OperatorNormalizerCallback())

        # Add telemetry callback if telemetry_enabled is set
        if self.telemetry_enabled:
            if isinstance(self.telemetry_enabled, bool):
@@ -221,7 +231,10 @@ class ComputerAgent:

        # Add trajectory saver callback if trajectory_dir is set
        if self.trajectory_dir:
            self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))
            if isinstance(self.trajectory_dir, dict):
                self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
            elif isinstance(self.trajectory_dir, (str, Path)):
                self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))

        # Add budget manager if max_trajectory_budget is set
        if max_trajectory_budget:
@@ -238,9 +251,11 @@ class ComputerAgent:
            trust_remote_code=self.trust_remote_code or False
        )
        human_adapter = HumanAdapter()
        mlx_adapter = MLXVLMAdapter()
        litellm.custom_provider_map = [
            {"provider": "huggingface-local", "custom_handler": hf_adapter},
            {"provider": "human", "custom_handler": human_adapter}
            {"provider": "human", "custom_handler": human_adapter},
            {"provider": "mlx", "custom_handler": mlx_adapter}
        ]
        litellm.suppress_debug_info = True

@@ -8,6 +8,7 @@ from .logging import LoggingCallback
from .trajectory_saver import TrajectorySaverCallback
from .budget_manager import BudgetManagerCallback
from .telemetry import TelemetryCallback
from .operator_validator import OperatorNormalizerCallback

__all__ = [
    "AsyncCallbackHandler",
@@ -16,4 +17,5 @@ __all__ = [
    "TrajectorySaverCallback",
    "BudgetManagerCallback",
    "TelemetryCallback",
    "OperatorNormalizerCallback",
]

@@ -50,90 +50,41 @@ class ImageRetentionCallback(AsyncCallbackHandler):
        """
        if self.only_n_most_recent_images is None:
            return messages

        # First pass: Assign call_id to reasoning items based on the next computer_call
        messages_with_call_ids = []
        for i, msg in enumerate(messages):
            msg_copy = msg.copy() if isinstance(msg, dict) else msg

            # If this is a reasoning item without a call_id, find the next computer_call
            if (msg_copy.get("type") == "reasoning" and
                not msg_copy.get("call_id")):
                # Look ahead for the next computer_call
                for j in range(i + 1, len(messages)):
                    next_msg = messages[j]
                    if (next_msg.get("type") == "computer_call" and
                        next_msg.get("call_id")):
                        msg_copy["call_id"] = next_msg.get("call_id")
                        break

            messages_with_call_ids.append(msg_copy)

        # Find all computer_call_output items with images and their call_ids
        image_call_ids = []
        for msg in reversed(messages_with_call_ids):  # Process in reverse to get most recent first
            if (msg.get("type") == "computer_call_output" and
                isinstance(msg.get("output"), dict) and
                "image_url" in msg.get("output", {})):
                call_id = msg.get("call_id")
                if call_id and call_id not in image_call_ids:
                    image_call_ids.append(call_id)
                    if len(image_call_ids) >= self.only_n_most_recent_images:
                        break

        # Keep the most recent N image call_ids (reverse to get chronological order)
        keep_call_ids = set(image_call_ids[:self.only_n_most_recent_images])

        # Filter messages: remove computer_call, computer_call_output, and reasoning for old images
        filtered_messages = []
        for msg in messages_with_call_ids:
            msg_type = msg.get("type")
            call_id = msg.get("call_id")

            # Remove old computer_call items
            if msg_type == "computer_call" and call_id not in keep_call_ids:
                # Check if this call_id corresponds to an image call
                has_image_output = any(
                    m.get("type") == "computer_call_output" and
                    m.get("call_id") == call_id and
                    isinstance(m.get("output"), dict) and
                    "image_url" in m.get("output", {})
                    for m in messages_with_call_ids
                )
                if has_image_output:
                    continue  # Skip this computer_call

            # Remove old computer_call_output items with images
            if (msg_type == "computer_call_output" and
                call_id not in keep_call_ids and
                isinstance(msg.get("output"), dict) and
                "image_url" in msg.get("output", {})):
                continue  # Skip this computer_call_output

            # Remove old reasoning items that are paired with removed computer calls
            if (msg_type == "reasoning" and
                call_id and call_id not in keep_call_ids):
                # Check if this call_id corresponds to an image call that's being removed
                has_image_output = any(
                    m.get("type") == "computer_call_output" and
                    m.get("call_id") == call_id and
                    isinstance(m.get("output"), dict) and
                    "image_url" in m.get("output", {})
                    for m in messages_with_call_ids
                )
                if has_image_output:
                    continue  # Skip this reasoning item

            filtered_messages.append(msg)

        # Clean up: Remove call_id from reasoning items before returning
        final_messages = []
        for msg in filtered_messages:
            if msg.get("type") == "reasoning" and "call_id" in msg:
                # Create a copy without call_id for reasoning items
                cleaned_msg = {k: v for k, v in msg.items() if k != "call_id"}
                final_messages.append(cleaned_msg)
            else:
                final_messages.append(msg)

        return final_messages

        # Gather indices of all computer_call_output messages that contain an image_url
        output_indices: List[int] = []
        for idx, msg in enumerate(messages):
            if msg.get("type") == "computer_call_output":
                out = msg.get("output")
                if isinstance(out, dict) and ("image_url" in out):
                    output_indices.append(idx)

        # Nothing to trim
        if len(output_indices) <= self.only_n_most_recent_images:
            return messages

        # Determine which outputs to keep (most recent N)
        keep_output_indices = set(output_indices[-self.only_n_most_recent_images :])

        # Build set of indices to remove in one pass
        to_remove: set[int] = set()

        for idx in output_indices:
            if idx in keep_output_indices:
                continue  # keep this screenshot and its context

            to_remove.add(idx)  # remove the computer_call_output itself

            # Remove the immediately preceding computer_call with matching call_id (if present)
            call_id = messages[idx].get("call_id")
            prev_idx = idx - 1
            if prev_idx >= 0 and messages[prev_idx].get("type") == "computer_call" and messages[prev_idx].get("call_id") == call_id:
                to_remove.add(prev_idx)
                # Check a single reasoning immediately before that computer_call
                r_idx = prev_idx - 1
                if r_idx >= 0 and messages[r_idx].get("type") == "reasoning":
                    to_remove.add(r_idx)

        # Construct filtered list
        filtered = [m for i, m in enumerate(messages) if i not in to_remove]
        return filtered
libs/python/agent/agent/callbacks/operator_validator.py (new file, 138 lines)
@@ -0,0 +1,138 @@
"""
OperatorValidatorCallback

Ensures agent output actions conform to expected schemas by fixing common issues:
- click: add default button='left' if missing
- keypress: wrap keys string into a list
- etc.

This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
The purpose is to avoid spending another LLM call to fix broken computer call syntax when possible.
"""
from __future__ import annotations

from typing import Any, Dict, List

from .base import AsyncCallbackHandler


class OperatorNormalizerCallback(AsyncCallbackHandler):
    """Normalizes common computer call hallucinations / errors in computer call syntax."""

    async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Mutate in-place as requested, but still return the list for chaining
        for item in output or []:
            if item.get("type") != "computer_call":
                continue
            action = item.get("action")
            if not isinstance(action, dict):
                continue

            # rename mouse click actions to "click"
            for mouse_btn in ["left", "right", "wheel", "back", "forward"]:
                if action.get("type", "") == f"{mouse_btn}_click":
                    action["type"] = "click"
                    action["button"] = mouse_btn
            # rename hotkey actions to "keypress"
            for alias in ["hotkey", "key", "press", "key_press"]:
                if action.get("type", "") == alias:
                    action["type"] = "keypress"
            # assume click actions
            if "button" in action and "type" not in action:
                action["type"] = "click"
            if "click" in action and "type" not in action:
                action["type"] = "click"
            if ("scroll_x" in action or "scroll_y" in action) and "type" not in action:
                action["type"] = "scroll"
            if "text" in action and "type" not in action:
                action["type"] = "type"

            action_type = action.get("type")
            def _keep_keys(action: Dict[str, Any], keys_to_keep: List[str]):
                """Keep only the provided keys on action; delete everything else.
                Always ensures required 'type' is present if listed in keys_to_keep.
                """
                for key in list(action.keys()):
                    if key not in keys_to_keep:
                        del action[key]
            # rename "coordinate" to "x", "y"
            if "coordinate" in action:
                action["x"] = action["coordinate"][0]
                action["y"] = action["coordinate"][1]
                del action["coordinate"]
            if action_type == "click":
                # convert "click" to "button"
                if "button" not in action and "click" in action:
                    action["button"] = action["click"]
                    del action["click"]
                # default button to "left"
                action["button"] = action.get("button", "left")
            # add default scroll x, y if missing
            if action_type == "scroll":
                action["scroll_x"] = action.get("scroll_x", 0)
                action["scroll_y"] = action.get("scroll_y", 0)
            # ensure keys arg is a list (normalize aliases first)
            if action_type == "keypress":
                keys = action.get("keys")
                for keys_alias in ["keypress", "key", "press", "key_press", "text"]:
                    if keys_alias in action:
                        action["keys"] = action[keys_alias]
                        del action[keys_alias]
                keys = action.get("keys")
                if isinstance(keys, str):
                    action["keys"] = keys.replace("-", "+").split("+") if len(keys) > 1 else [keys]
            required_keys_by_type = {
                # OpenAI actions
                "click": ["type", "button", "x", "y"],
                "double_click": ["type", "x", "y"],
                "drag": ["type", "path"],
                "keypress": ["type", "keys"],
                "move": ["type", "x", "y"],
                "screenshot": ["type"],
                "scroll": ["type", "scroll_x", "scroll_y", "x", "y"],
                "type": ["type", "text"],
                "wait": ["type"],
                # Anthropic actions
                "left_mouse_down": ["type", "x", "y"],
                "left_mouse_up": ["type", "x", "y"],
                "triple_click": ["type", "button", "x", "y"],
            }
            keep = required_keys_by_type.get(action_type or "")
            if keep:
                _keep_keys(action, keep)


        # # Second pass: if an assistant message is immediately followed by a computer_call,
        # # replace the assistant message itself with a reasoning message with summary text.
        # if isinstance(output, list):
        #     for i, item in enumerate(output):
        #         # AssistantMessage shape: { type: 'message', role: 'assistant', content: OutputContent[] }
        #         if item.get("type") == "message" and item.get("role") == "assistant":
        #             next_idx = i + 1
        #             if next_idx >= len(output):
        #                 continue
        #             next_item = output[next_idx]
        #             if not isinstance(next_item, dict):
        #                 continue
        #             if next_item.get("type") != "computer_call":
        #                 continue
        #             contents = item.get("content") or []
        #             # Extract text from OutputContent[]
        #             text_parts: List[str] = []
        #             if isinstance(contents, list):
        #                 for c in contents:
        #                     if isinstance(c, dict) and c.get("type") == "output_text" and isinstance(c.get("text"), str):
        #                         text_parts.append(c["text"])
        #             text_content = "\n".join(text_parts).strip()
        #             # Replace assistant message with reasoning message
        #             output[i] = {
        #                 "type": "reasoning",
        #                 "summary": [
        #                     {
        #                         "type": "summary_text",
        #                         "text": text_content,
        #                     }
        #                 ],
        #             }

        return output
@@ -11,6 +11,8 @@ from pathlib import Path
from typing import List, Dict, Any, Optional, Union, override
from PIL import Image, ImageDraw
import io
from copy import deepcopy

from .base import AsyncCallbackHandler

def sanitize_image_urls(data: Any) -> Any:
@@ -43,6 +45,64 @@ def sanitize_image_urls(data: Any) -> Any:
    return data


def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]:
    """
    Save any base64-encoded screenshots from computer_call_output entries to files and
    replace their image_url with the saved file path when a call_id is present.

    Only operates if screenshot_dir is provided and exists; otherwise returns items unchanged.

    Args:
        items: List of message/result dicts potentially containing computer_call_output entries
        screenshot_dir: Directory to write screenshots into

    Returns:
        A new list with updated image_url fields when applicable.
    """
    if not items:
        return items
    if not screenshot_dir or not screenshot_dir.exists():
        return items

    updated: List[Dict[str, Any]] = []
    for item in items:
        # work on a shallow copy; deep copy nested 'output' if we modify it
        msg = dict(item)
        try:
            if msg.get("type") == "computer_call_output":
                call_id = msg.get("call_id")
                output = msg.get("output", {})
                image_url = output.get("image_url")
                if call_id and isinstance(image_url, str) and image_url.startswith("data:"):
                    # derive extension from MIME type e.g. data:image/png;base64,
                    try:
                        ext = image_url.split(";", 1)[0].split("/")[-1]
                        if not ext:
                            ext = "png"
                    except Exception:
                        ext = "png"
                    out_path = screenshot_dir / f"{call_id}.{ext}"
                    # write file if it doesn't exist
                    if not out_path.exists():
                        try:
                            b64_payload = image_url.split(",", 1)[1]
                            img_bytes = base64.b64decode(b64_payload)
                            out_path.parent.mkdir(parents=True, exist_ok=True)
                            with open(out_path, "wb") as f:
                                f.write(img_bytes)
                        except Exception:
                            # if anything fails, skip modifying this message
                            pass
                    # update image_url to file path
                    new_output = dict(output)
                    new_output["image_url"] = str(out_path)
                    msg["output"] = new_output
        except Exception:
            # do not block on malformed entries; keep original
            pass
        updated.append(msg)
    return updated

class TrajectorySaverCallback(AsyncCallbackHandler):
    """
    Callback handler that saves agent trajectories to disk.
@@ -51,7 +111,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
    within the trajectory gets its own folder with screenshots and responses.
    """

    def __init__(self, trajectory_dir: str, reset_on_run: bool = True):
    def __init__(self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None):
        """
        Initialize trajectory saver.

@@ -67,10 +127,12 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
        self.model: Optional[str] = None
        self.total_usage: Dict[str, Any] = {}
        self.reset_on_run = reset_on_run
        # Optional directory to store extracted screenshots from metadata/new_items
        self.screenshot_dir: Optional[Path] = Path(screenshot_dir) if screenshot_dir else None

        # Ensure trajectory directory exists
        self.trajectory_dir.mkdir(parents=True, exist_ok=True)

    def _get_turn_dir(self) -> Path:
        """Get the directory for the current turn."""
        if not self.trajectory_id:
@@ -94,6 +156,10 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
        # format: turn_000/0000_name.json
        artifact_filename = f"{self.current_artifact:04d}_{name}"
        artifact_path = turn_dir / f"{artifact_filename}.json"
        # add created_at
        if isinstance(artifact, dict):
            artifact = artifact.copy()
            artifact["created_at"] = str(uuid.uuid1().time)
        with open(artifact_path, "w") as f:
            json.dump(sanitize_image_urls(artifact), f, indent=2)
        self.current_artifact += 1
@@ -135,12 +201,21 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
        trajectory_path = self.trajectory_dir / self.trajectory_id
        trajectory_path.mkdir(parents=True, exist_ok=True)

        # Save trajectory metadata
        # Save trajectory metadata (optionally extract screenshots to screenshot_dir)
        kwargs_to_save = kwargs.copy()
        try:
            if "messages" in kwargs_to_save:
                kwargs_to_save["messages"] = extract_computer_call_outputs(
                    kwargs_to_save["messages"], self.screenshot_dir
                )
        except Exception:
            # If extraction fails, fall back to original messages
            pass
        metadata = {
            "trajectory_id": self.trajectory_id,
            "created_at": str(uuid.uuid1().time),
            "status": "running",
            "kwargs": kwargs,
            "kwargs": kwargs_to_save,
        }

        with open(trajectory_path / "metadata.json", "w") as f:
@@ -167,11 +242,18 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
            metadata = {}

        # Update metadata with completion info
        # Optionally extract screenshots from new_items before persisting
        new_items_to_save = new_items
        try:
            new_items_to_save = extract_computer_call_outputs(new_items, self.screenshot_dir)
        except Exception:
            pass

        metadata.update({
            "status": "completed",
            "completed_at": str(uuid.uuid1().time),
            "total_usage": self.total_usage,
            "new_items": sanitize_image_urls(new_items),
            "new_items": new_items_to_save,
            "total_turns": self.current_turn
        })

@@ -15,6 +15,11 @@ class HumanCompletionUI:
        self.current_call_id: Optional[str] = None
        self.refresh_interval = 2.0  # seconds
        self.last_image = None  # Store the last image for display
        # Track current interactive action controls
        self.current_action_type: str = "click"
        self.current_button: str = "left"
        self.current_scroll_x: int = 0
        self.current_scroll_y: int = -120

    def format_messages_for_chatbot(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format messages for display in gr.Chatbot with type='messages'."""
@@ -196,7 +201,9 @@ class HumanCompletionUI:
                gr.update(choices=["latest"], value="latest"),  # dropdown
                gr.update(value=None),  # image (no image)
                gr.update(value=[]),  # chatbot (empty messages)
                gr.update(interactive=False)  # submit button
                gr.update(interactive=False),  # submit button
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        # Sort pending calls by created_at to get oldest first
@@ -237,7 +244,9 @@ class HumanCompletionUI:
            gr.update(choices=choices, value="latest"),
            gr.update(value=self.last_image),
            gr.update(value=conversation),
            gr.update(interactive=bool(choices))
            gr.update(interactive=bool(choices)),
            gr.update(visible=True),  # click_actions_group visible when there is a call
            gr.update(visible=True),  # actions_group visible when there is a call
        )

    def on_call_selected(self, selected_choice):
@@ -246,7 +255,9 @@ class HumanCompletionUI:
            return (
                gr.update(value=None),  # no image
                gr.update(value=[]),  # empty chatbot
                gr.update(interactive=False)
                gr.update(interactive=False),
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        pending_calls = self.get_pending_calls()
@@ -254,7 +265,9 @@ class HumanCompletionUI:
            return (
                gr.update(value=None),  # no image
                gr.update(value=[]),  # empty chatbot
                gr.update(interactive=False)
                gr.update(interactive=False),
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        # Handle "latest" option
@@ -286,7 +299,9 @@ class HumanCompletionUI:
            return (
                gr.update(value=None),  # no image
                gr.update(value=[]),  # empty chatbot
                gr.update(interactive=False)
                gr.update(interactive=False),
                gr.update(visible=False),  # click_actions_group hidden
                gr.update(visible=False),  # actions_group hidden
            )

        conversation = self.format_messages_for_chatbot(selected_call.get("messages", []))
@@ -297,7 +312,9 @@ class HumanCompletionUI:
        return (
            gr.update(value=self.last_image),
            gr.update(value=conversation),
            gr.update(interactive=True)
            gr.update(interactive=True),
            gr.update(visible=True),  # click_actions_group visible
            gr.update(visible=True),  # actions_group visible
        )

    def submit_response(self, response_text: str):
@@ -368,6 +385,10 @@ class HumanCompletionUI:
        """Submit a hotkey action."""
        return self.submit_action("keypress", keys=keys)

    def submit_wait_action(self) -> str:
        """Submit a wait action with no kwargs."""
        return self.submit_action("wait")

    def submit_description_click(self, description: str, action_type: str = "click", button: str = "left") -> str:
        """Submit a description-based action."""
        if action_type == "click":
@@ -407,7 +428,7 @@ def create_ui():
    """Create the Gradio interface."""
    ui_handler = HumanCompletionUI()

    with gr.Blocks(title="Human-in-the-Loop Agent Tool") as demo:
    with gr.Blocks(title="Human-in-the-Loop Agent Tool", fill_width=True) as demo:
        gr.Markdown("# 🤖 Human-in-the-Loop Agent Tool")
        gr.Markdown("Review AI conversation requests and provide human responses.")

@@ -415,29 +436,42 @@ def create_ui():
            with gr.Column(scale=2):
                with gr.Group():
                    screenshot_image = gr.Image(
                        label="Screenshot",
                        label="Interactive Screenshot",
                        interactive=False,
                        height=600
                    )

                # Action type selection for image clicks
                with gr.Row():
                    action_type_radio = gr.Radio(
                        label="Action Type",
                        choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down"],
                        value="click",
                        scale=2
                    )
                    action_button_radio = gr.Radio(
                        label="Button (for click only)",
                        choices=["left", "right", "wheel", "back", "forward"],
                        value="left",
                        visible=True,
                        scale=1
                    )
                # Action type selection for image clicks (wrapped for visibility control)
                with gr.Group(visible=False) as click_actions_group:
                    with gr.Row():
                        action_type_radio = gr.Dropdown(
                            label="Interactive Action",
                            choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down", "scroll"],
                            value="click",
                            scale=2
                        )
                        action_button_radio = gr.Dropdown(
                            label="Button",
                            choices=["left", "right", "wheel", "back", "forward"],
                            value="left",
                            visible=True,
                            scale=1
                        )
                        scroll_x_input = gr.Number(
                            label="scroll_x",
                            value=0,
                            visible=False,
                            scale=1
                        )
                        scroll_y_input = gr.Number(
                            label="scroll_y",
                            value=-120,
                            visible=False,
                            scale=1
                        )

                conversation_chatbot = gr.Chatbot(
                    label="Messages",
                    label="Conversation",
                    type="messages",
                    height=500,
                    show_copy_button=True
@@ -446,99 +480,97 @@ def create_ui():
            with gr.Column(scale=1):
                with gr.Group():
                    call_dropdown = gr.Dropdown(
                        label="Select a pending call",
                        label="Select a pending conversation request",
                        choices=["latest"],
                        interactive=True,
                        value="latest"
                    )
                    refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
                    status_display = gr.Textbox(
                        label="Status",
                        interactive=False,
                        value="Ready to receive requests..."
                    )

                with gr.Group():
                    response_text = gr.Textbox(
                        label="Response",
                        label="Message",
                        lines=3,
                        placeholder="Enter your response here..."
                        placeholder="Enter your message here..."
                    )
                    submit_btn = gr.Button("📤 Submit Response", variant="primary", interactive=False)
                    submit_btn = gr.Button("📤 Submit Message", variant="primary", interactive=False)

                # Action Accordions
                with gr.Accordion("🖱️ Click Actions", open=False):
                    with gr.Group():
                        with gr.Row():
                            click_x = gr.Number(label="X", value=0, minimum=0)
                            click_y = gr.Number(label="Y", value=0, minimum=0)
                        with gr.Row():
                            click_action_type = gr.Dropdown(
                                label="Action Type",
                                choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down"],
                                value="click"
                            )
                            click_button = gr.Dropdown(
                                label="Button (for click only)",
                                choices=["left", "right", "wheel", "back", "forward"],
                                value="left"
                            )
                        click_submit_btn = gr.Button("Submit Action")

                with gr.Accordion("📝 Type Action", open=False):
                    with gr.Group():
                        type_text = gr.Textbox(
                            label="Text to Type",
                            placeholder="Enter text to type..."
                        )
                        type_submit_btn = gr.Button("Submit Type")

                with gr.Accordion("⌨️ Keypress Action", open=False):
                    with gr.Group():
                        keypress_text = gr.Textbox(
                            label="Keys",
                            placeholder="e.g., ctrl+c, alt+tab"
                        )
                        keypress_submit_btn = gr.Button("Submit Keypress")

                with gr.Accordion("🎯 Description Action", open=False):
                    with gr.Group():
                        description_text = gr.Textbox(
                            label="Element Description",
                            placeholder="e.g., 'Privacy and security option in left sidebar'"
                        )
                        with gr.Row():
                            description_action_type = gr.Dropdown(
                                label="Action Type",
                                choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down"],
                                value="click"
                            )
                            description_button = gr.Radio(
                                label="Button (for click only)",
                                choices=["left", "right", "wheel", "back", "forward"],
                                value="left"
                            )
                        description_submit_btn = gr.Button("Submit Description Action")

                status_display = gr.Textbox(
                    label="Status",
                    interactive=False,
                    value="Ready to receive calls..."
                )
                # Action Accordions (wrapped for visibility control)
                with gr.Group(visible=False) as actions_group:
                    with gr.Tabs():
                        with gr.Tab("🖱️ Click Actions"):
                            with gr.Group():
                                description_text = gr.Textbox(
                                    label="Element Description",
                                    placeholder="e.g., 'Privacy and security option in left sidebar'"
                                )
                                with gr.Row():
                                    description_action_type = gr.Dropdown(
                                        label="Action",
                                        choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down"],
                                        value="click"
                                    )
                                    description_button = gr.Dropdown(
                                        label="Button",
                                        choices=["left", "right", "wheel", "back", "forward"],
                                        value="left"
                                    )
                                description_submit_btn = gr.Button("Submit Click Action")

                        with gr.Tab("📝 Type Action"):
                            with gr.Group():
                                type_text = gr.Textbox(
                                    label="Text to Type",
                                    placeholder="Enter text to type..."
                                )
                                type_submit_btn = gr.Button("Submit Type")

                        with gr.Tab("⌨️ Keypress Action"):
                            with gr.Group():
                                keypress_text = gr.Textbox(
                                    label="Keys",
                                    placeholder="e.g., ctrl+c, alt+tab"
                                )
                                keypress_submit_btn = gr.Button("Submit Keypress")

                        with gr.Tab("🧰 Misc Actions"):
                            with gr.Group():
                                misc_action_dropdown = gr.Dropdown(
                                    label="Action",
                                    choices=["wait"],
                                    value="wait"
                                )
                                misc_submit_btn = gr.Button("Submit Action")

        # Event handlers
        refresh_btn.click(
            fn=ui_handler.refresh_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn]
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
        )

        call_dropdown.change(
            fn=ui_handler.on_call_selected,
            inputs=[call_dropdown],
            outputs=[screenshot_image, conversation_chatbot, submit_btn]
            outputs=[screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
        )

        def handle_image_click(evt: gr.SelectData):
            if evt.index is not None:
                x, y = evt.index
                action_type = action_type_radio.value or "click"
                button = action_button_radio.value or "left"
                result = ui_handler.submit_click_action(x, y, action_type, button)
                action_type = ui_handler.current_action_type or "click"
                button = ui_handler.current_button or "left"
                if action_type == "scroll":
                    sx_i = int(ui_handler.current_scroll_x or 0)
                    sy_i = int(ui_handler.current_scroll_y or 0)
                    # Submit a scroll action with x,y position and scroll deltas
                    result = ui_handler.submit_action("scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i)
                else:
                    result = ui_handler.submit_click_action(x, y, action_type, button)
                ui_handler.wait_for_pending_calls()
                return result
            return "No coordinates selected"
@@ -548,7 +580,7 @@ def create_ui():
            outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn]
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
        )

        # Response submission
@@ -558,27 +590,52 @@ def create_ui():
            outputs=[response_text, status_display]
        ).then(
            fn=ui_handler.refresh_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn]
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
        )

        # Toggle button radio visibility based on action type
        def toggle_button_visibility(action_type):
            return gr.update(visible=(action_type == "click"))
        # Toggle visibility of controls based on action type
        def toggle_action_controls(action_type):
            # Button visible only for click
            button_vis = gr.update(visible=(action_type == "click"))
            # Scroll inputs visible only for scroll
            scroll_x_vis = gr.update(visible=(action_type == "scroll"))
            scroll_y_vis = gr.update(visible=(action_type == "scroll"))
            # Update state
            ui_handler.current_action_type = action_type or "click"
            return button_vis, scroll_x_vis, scroll_y_vis

        action_type_radio.change(
            fn=toggle_button_visibility,
            fn=toggle_action_controls,
            inputs=[action_type_radio],
            outputs=[action_button_radio]
            outputs=[action_button_radio, scroll_x_input, scroll_y_input]
        )

        # Action accordion handlers
        click_submit_btn.click(
            fn=ui_handler.submit_click_action,
            inputs=[click_x, click_y, click_action_type, click_button],
            outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn]
        # Keep other control values in ui_handler state
        def on_button_change(val):
            ui_handler.current_button = (val or "left")
        action_button_radio.change(
            fn=on_button_change,
            inputs=[action_button_radio]
        )

        def on_scroll_x_change(val):
            try:
                ui_handler.current_scroll_x = int(val) if val is not None else 0
            except Exception:
                ui_handler.current_scroll_x = 0
        scroll_x_input.change(
            fn=on_scroll_x_change,
            inputs=[scroll_x_input]
        )

        def on_scroll_y_change(val):
            try:
                ui_handler.current_scroll_y = int(val) if val is not None else 0
            except Exception:
                ui_handler.current_scroll_y = 0
        scroll_y_input.change(
            fn=on_scroll_y_change,
            inputs=[scroll_y_input]
        )

        type_submit_btn.click(
@@ -587,7 +644,7 @@ def create_ui():
            outputs=[status_display]
        ).then(
            fn=ui_handler.wait_for_pending_calls,
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn]
            outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
        )

        keypress_submit_btn.click(
@@ -596,7 +653,7 @@ def create_ui():
outputs=[status_display]
|
||||
).then(
|
||||
fn=ui_handler.wait_for_pending_calls,
|
||||
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn]
|
||||
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
|
||||
)
|
||||
|
||||
def handle_description_submit(description, action_type, button):
|
||||
@@ -612,13 +669,30 @@ def create_ui():
|
||||
outputs=[status_display]
|
||||
).then(
|
||||
fn=ui_handler.wait_for_pending_calls,
|
||||
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn]
|
||||
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
|
||||
)
|
||||
|
||||
# Misc action handler
|
||||
def handle_misc_submit(selected_action):
|
||||
if selected_action == "wait":
|
||||
result = ui_handler.submit_wait_action()
|
||||
ui_handler.wait_for_pending_calls()
|
||||
return result
|
||||
return f"Unsupported misc action: {selected_action}"
|
||||
|
||||
misc_submit_btn.click(
|
||||
fn=handle_misc_submit,
|
||||
inputs=[misc_action_dropdown],
|
||||
outputs=[status_display]
|
||||
).then(
|
||||
fn=ui_handler.wait_for_pending_calls,
|
||||
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
|
||||
)
|
||||
|
||||
# Load initial data
|
||||
demo.load(
|
||||
fn=ui_handler.refresh_pending_calls,
|
||||
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn]
|
||||
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
|
||||
)
|
||||
|
||||
return demo
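For reference, a minimal sketch of serving the Blocks object returned above (assuming this module exposes `create_ui()`; the host and port values are illustrative):

```python
# Minimal sketch: launch the human-in-the-loop UI built by create_ui().
# The server_name/server_port values are illustrative defaults, not requirements.
if __name__ == "__main__":
    demo = create_ui()
    demo.queue()  # serialize callbacks so pending calls are handled in order
    demo.launch(server_name="0.0.0.0", server_port=7860)
```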
|
||||
|
||||
@@ -1,77 +1,228 @@
|
||||
"""HUD integration for ComputerAgent."""
|
||||
"""HUD integration: Generic HuggingFace dataset evaluation runner (CUA proxy).
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, Dict
|
||||
from hud import run_job as hud_run_job
|
||||
This module exposes two helpers to evaluate HUD-compatible datasets using
|
||||
HUD's OperatorAgent, while proxying model calls through our ComputerAgent via
|
||||
`FakeAsyncOpenAI` (see `agent/integrations/hud/proxy.py`).
|
||||
|
||||
from .agent import ComputerAgent
|
||||
from .adapter import ComputerAgentAdapter
|
||||
from .computer_handler import HUDComputerHandler
|
||||
Exports:
|
||||
- run_single_task(dataset, *, task_id=0, model=None, allowed_tools=None, ...)
- run_full_dataset(dataset, *, job_name=None, model=None, allowed_tools=None, max_concurrent=30, max_steps=50, ...)
|
||||
"""
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from PIL import Image
|
||||
from datasets import load_dataset, Dataset
|
||||
from hud.agents import OperatorAgent
|
||||
from hud.datasets import Task, run_dataset
|
||||
from hud.tools.computer.settings import computer_settings
|
||||
from hud import trace
|
||||
|
||||
from agent.agent import ComputerAgent as BaseComputerAgent
|
||||
from .proxy import FakeAsyncOpenAI
|
||||
|
||||
|
||||
async def run_job(
|
||||
model: str,
|
||||
task_or_taskset: Any,
|
||||
job_name: str,
|
||||
# Job kwargs
|
||||
auto_reply_question: bool = False,
|
||||
adapter_cls: Any = None,
|
||||
adapter_kwargs: Optional[Dict[str, Any]] = None,
|
||||
max_steps_per_task: int = 20,
|
||||
run_parallel: bool = True,
|
||||
job_metadata: Optional[Dict[str, Any]] = None,
|
||||
show_progress: bool = True,
|
||||
max_concurrent_env_creations: Optional[int] = 30, # Limits gym.make calls
|
||||
max_concurrent_agent_predictions: Optional[int] = None, # No limit on LLM calls
|
||||
max_concurrent_tasks: Optional[int] = 30, # Limits overall task concurrency
|
||||
**agent_kwargs: Any
|
||||
) -> Any:
|
||||
# ---------------------------------------------------------------------------
|
||||
# Proxy OperatorAgent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProxyOperatorAgent(OperatorAgent):
|
||||
"""OperatorAgent that proxies model calls through our ComputerAgent.
|
||||
|
||||
Accepts the same config keys we pass via hud.run_dataset `agent_config`:
|
||||
- model: str | None
|
||||
- allowed_tools: list[str] | None
|
||||
Additional kwargs are forwarded to OperatorAgent (if any are supported).
|
||||
"""
|
||||
Run a job using ComputerAgent with the specified model.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
trajectory_dir: str | dict | None = None,
|
||||
# === ComputerAgent kwargs ===
|
||||
tools: list[Any] | None = None,
|
||||
custom_loop: Any | None = None,
|
||||
only_n_most_recent_images: int | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
verbosity: int | None = None,
|
||||
max_retries: int | None = 3,
|
||||
screenshot_delay: float | int = 0.5,
|
||||
use_prompt_caching: bool | None = False,
|
||||
max_trajectory_budget: float | dict | None = None,
|
||||
telemetry_enabled: bool | None = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
model = model or "computer-use-preview"
|
||||
allowed_tools = allowed_tools or ["openai_computer"]
|
||||
|
||||
computer_shim = {
|
||||
'screenshot': lambda: Image.new('RGB', (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)),
|
||||
'environment': 'linux',
|
||||
'dimensions': (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)
|
||||
}
|
||||
# Build tools ensuring the computer_shim is included
|
||||
agent_tools: list[Any] = [computer_shim]
|
||||
if tools:
|
||||
agent_tools.extend(tools)
|
||||
|
||||
computer_agent = BaseComputerAgent(
|
||||
model=model,
|
||||
tools=agent_tools,
|
||||
custom_loop=custom_loop,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
callbacks=callbacks,
|
||||
verbosity=verbosity,
|
||||
trajectory_dir=trajectory_dir,
|
||||
max_retries=max_retries,
|
||||
screenshot_delay=screenshot_delay,
|
||||
use_prompt_caching=use_prompt_caching,
|
||||
max_trajectory_budget=max_trajectory_budget,
|
||||
telemetry_enabled=telemetry_enabled,
|
||||
)
|
||||
model_client = FakeAsyncOpenAI(computer_agent)
|
||||
|
||||
super().__init__(
|
||||
model_client=model_client, # type: ignore[arg-type]
|
||||
model=model,
|
||||
allowed_tools=allowed_tools,
|
||||
**kwargs,
|
||||
)
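Direct construction is rarely needed (hud.datasets.run_dataset builds instances from `agent_config`), but a hypothetical sketch of the equivalent call would be:

```python
# Illustrative only: mirrors the keys run_dataset passes via agent_config.
agent = ProxyOperatorAgent(
    model="anthropic/claude-3-7-sonnet-20250219",
    allowed_tools=["openai_computer"],
    trajectory_dir="trajectories",
    only_n_most_recent_images=3,
)
```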
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single-task runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_single_task(
|
||||
dataset: str | Dataset | list[dict[str, Any]],
|
||||
*,
|
||||
task_id: int = 0,
|
||||
model: str | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
# === ComputerAgent kwargs ===
|
||||
tools: list[Any] | None = None,
|
||||
custom_loop: Any | None = None,
|
||||
only_n_most_recent_images: int | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
verbosity: int | None = None,
|
||||
trajectory_dir: str | dict | None = None,
|
||||
max_retries: int | None = 3,
|
||||
screenshot_delay: float | int = 0.5,
|
||||
use_prompt_caching: bool | None = False,
|
||||
max_trajectory_budget: float | dict | None = None,
|
||||
telemetry_enabled: bool | None = True,
|
||||
) -> None:
|
||||
"""Load one task from the dataset and execute it with Operator+CUA proxy."""
|
||||
|
||||
# Load dataset and pick a sample
|
||||
if isinstance(dataset, str):
|
||||
dataset = load_dataset(dataset, split="train") # type: ignore[arg-type]
|
||||
elif isinstance(dataset, list):
|
||||
dataset = dataset
|
||||
else:
|
||||
dataset = dataset["train"]
|
||||
|
||||
Args:
|
||||
model: Model string for ComputerAgent (e.g., "anthropic/claude-3-5-sonnet-20241022")
|
||||
task_or_taskset: Task or TaskSet to run
|
||||
job_name: Name for the job
|
||||
auto_reply_question: Whether to auto-reply to questions
|
||||
adapter_cls: Custom adapter class (defaults to ComputerAgentAdapter)
|
||||
adapter_kwargs: Additional kwargs for the adapter
|
||||
max_steps_per_task: Maximum steps per task
|
||||
run_parallel: Whether to run tasks in parallel
|
||||
job_metadata: Additional metadata for the job
|
||||
show_progress: Whether to show progress
|
||||
max_concurrent_env_creations: Max concurrent environment creations
|
||||
max_concurrent_agent_predictions: Max concurrent agent predictions
|
||||
max_concurrent_tasks: Max concurrent tasks
|
||||
**agent_kwargs: Additional kwargs to pass to ComputerAgent
|
||||
|
||||
Returns:
|
||||
Job instance from HUD
|
||||
"""
|
||||
# combine verbose and verbosity kwargs
|
||||
if "verbose" in agent_kwargs:
|
||||
agent_kwargs["verbosity"] = logging.INFO
|
||||
del agent_kwargs["verbose"]
|
||||
verbose = True if agent_kwargs.get("verbosity", logging.WARNING) > logging.INFO else False
|
||||
|
||||
# run job
|
||||
return await hud_run_job(
|
||||
agent_cls=ComputerAgent,
|
||||
agent_kwargs={"model": model, **agent_kwargs},
|
||||
task_or_taskset=task_or_taskset,
|
||||
job_name=job_name,
|
||||
auto_reply_question=auto_reply_question,
|
||||
adapter_cls=adapter_cls,
|
||||
adapter_kwargs=adapter_kwargs,
|
||||
max_steps_per_task=max_steps_per_task,
|
||||
run_parallel=run_parallel,
|
||||
job_metadata=job_metadata,
|
||||
show_progress=show_progress,
|
||||
verbose=verbose,
|
||||
max_concurrent_env_creations=max_concurrent_env_creations,
|
||||
max_concurrent_agent_predictions=max_concurrent_agent_predictions,
|
||||
max_concurrent_tasks=max_concurrent_tasks
|
||||
sample_task = dataset[task_id] # type: ignore[index]
|
||||
task_prompt = sample_task.get("prompt", f"Task {sample_task.get('id', 0)}") # type: ignore[attr-defined]
|
||||
|
||||
with trace(name=task_prompt):
|
||||
task = Task(**sample_task) # type: ignore[arg-type]
|
||||
|
||||
agent = ProxyOperatorAgent(
|
||||
model=model,
|
||||
allowed_tools=allowed_tools,
|
||||
# === ComputerAgent kwargs passthrough ===
|
||||
tools=tools,
|
||||
custom_loop=custom_loop,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
callbacks=callbacks,
|
||||
verbosity=verbosity,
|
||||
trajectory_dir=trajectory_dir,
|
||||
max_retries=max_retries,
|
||||
screenshot_delay=screenshot_delay,
|
||||
use_prompt_caching=use_prompt_caching,
|
||||
max_trajectory_budget=max_trajectory_budget,
|
||||
telemetry_enabled=telemetry_enabled,
|
||||
)
|
||||
print(f"Running: {task_prompt}")
|
||||
result = await agent.run(task, max_steps=10)
|
||||
print(f"✅ Reward: {getattr(result, 'reward')}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full-dataset runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_full_dataset(
|
||||
dataset: str | Dataset | list[dict[str, Any]],
|
||||
*,
|
||||
job_name: Optional[str] = None,
|
||||
model: str | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
max_concurrent: int = 30,
|
||||
max_steps: int = 50,
|
||||
split: str = "train",
|
||||
trajectory_dir: str | dict | None = None,
|
||||
# === ComputerAgent kwargs ===
|
||||
tools: list[Any] | None = None,
|
||||
custom_loop: Any | None = None,
|
||||
only_n_most_recent_images: int | None = 5,
|
||||
callbacks: list[Any] | None = None,
|
||||
verbosity: int | None = None,
|
||||
max_retries: int | None = 3,
|
||||
screenshot_delay: float | int = 0.5,
|
||||
use_prompt_caching: bool | None = False,
|
||||
max_trajectory_budget: float | dict | None = None,
|
||||
telemetry_enabled: bool | None = True,
|
||||
) -> list[Any]:
|
||||
"""Run evaluation across the entire dataset using hud.datasets.run_dataset."""
|
||||
|
||||
# We pass OperatorAgent as the class and provide a config that injects our
|
||||
# FakeAsyncOpenAI per agent instantiation.
|
||||
|
||||
if isinstance(dataset, str):
|
||||
dataset_name = dataset.split('/')[-1]
|
||||
job_name = job_name or f"Evaluation {dataset_name}"
|
||||
dataset = load_dataset(dataset, split=split) # type: ignore[arg-type]
|
||||
else:
|
||||
dataset_name = "custom"
|
||||
job_name = job_name or f"Evaluation {time.strftime('%H:%M %Y-%m-%d')}"
|
||||
|
||||
# Execute evaluation
|
||||
return await run_dataset(
|
||||
name=job_name,
|
||||
dataset=dataset,
|
||||
agent_class=ProxyOperatorAgent,
|
||||
agent_config={
|
||||
"model": model,
|
||||
"allowed_tools": allowed_tools,
|
||||
"trajectory_dir": trajectory_dir,
|
||||
# === ComputerAgent kwargs passthrough ===
|
||||
"tools": tools,
|
||||
"custom_loop": custom_loop,
|
||||
"only_n_most_recent_images": only_n_most_recent_images,
|
||||
"callbacks": callbacks,
|
||||
"verbosity": verbosity,
|
||||
"max_retries": max_retries,
|
||||
"screenshot_delay": screenshot_delay,
|
||||
"use_prompt_caching": use_prompt_caching,
|
||||
"max_trajectory_budget": max_trajectory_budget,
|
||||
"telemetry_enabled": telemetry_enabled,
|
||||
},
|
||||
max_concurrent=max_concurrent,
|
||||
metadata={"dataset": dataset_name},
|
||||
max_steps=max_steps,
|
||||
auto_respond=True,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ComputerAgent", "ComputerAgentAdapter", "HUDComputerHandler", "run_job"]
|
||||
__all__ = [
|
||||
"run_single_task",
|
||||
"run_full_dataset",
|
||||
"ProxyOperatorAgent",
|
||||
]
|
||||
@@ -1,121 +0,0 @@
|
||||
"""HUD Adapter for ComputerAgent integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from hud.adapters.common import CLA, Adapter
|
||||
from hud.adapters.common.types import (
|
||||
CLAButton,
|
||||
CLAKey,
|
||||
ClickAction,
|
||||
CustomAction,
|
||||
DragAction,
|
||||
MoveAction,
|
||||
Point,
|
||||
PressAction,
|
||||
ResponseAction,
|
||||
ScreenshotFetch,
|
||||
ScrollAction,
|
||||
TypeAction,
|
||||
WaitAction,
|
||||
)
|
||||
|
||||
|
||||
class ComputerAgentAdapter(Adapter):
|
||||
"""Adapter for ComputerAgent to work with HUD."""
|
||||
|
||||
KEY_MAP: ClassVar[dict[str, CLAKey]] = {
|
||||
"return": "enter",
|
||||
"arrowup": "up",
|
||||
"arrowdown": "down",
|
||||
"arrowleft": "left",
|
||||
"arrowright": "right",
|
||||
"cmd": "ctrl",
|
||||
"super": "win",
|
||||
"meta": "win",
|
||||
}
|
||||
|
||||
BUTTON_MAP: ClassVar[dict[str, CLAButton]] = {
|
||||
"wheel": "middle",
|
||||
"middle": "middle",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# ComputerAgent default dimensions (can be overridden)
|
||||
self.agent_width = 1024
|
||||
self.agent_height = 768
|
||||
|
||||
def _map_key(self, key: str) -> CLAKey:
|
||||
"""Map a key to its standardized form."""
|
||||
return self.KEY_MAP.get(key.lower(), key.lower()) # type: ignore
|
||||
|
||||
def convert(self, data: Any) -> CLA:
|
||||
"""Convert a ComputerAgent action to a HUD action."""
|
||||
try:
|
||||
action_type = data.get("type")
|
||||
|
||||
if action_type == "click":
|
||||
x, y = data.get("x", 0), data.get("y", 0)
|
||||
button = data.get("button", "left")
|
||||
button = self.BUTTON_MAP.get(button, button)
|
||||
if button is None:
|
||||
button = "left"
|
||||
converted_action = ClickAction(point=Point(x=x, y=y), button=button)
|
||||
|
||||
elif action_type == "double_click":
|
||||
x, y = data.get("x", 0), data.get("y", 0)
|
||||
converted_action = ClickAction(point=Point(x=x, y=y), button="left", pattern=[100])
|
||||
|
||||
elif action_type == "scroll":
|
||||
x, y = int(data.get("x", 0)), int(data.get("y", 0))
|
||||
scroll_x = int(data.get("scroll_x", 0))
|
||||
scroll_y = int(data.get("scroll_y", 0))
|
||||
converted_action = ScrollAction(
|
||||
point=Point(x=x, y=y), scroll=Point(x=scroll_x, y=scroll_y)
|
||||
)
|
||||
|
||||
elif action_type == "type":
|
||||
text = data.get("text", "")
|
||||
converted_action = TypeAction(text=text, enter_after=False)
|
||||
|
||||
elif action_type == "wait":
|
||||
ms = data.get("ms", 1000)
|
||||
converted_action = WaitAction(time=ms)
|
||||
|
||||
elif action_type == "move":
|
||||
x, y = data.get("x", 0), data.get("y", 0)
|
||||
converted_action = MoveAction(point=Point(x=x, y=y))
|
||||
|
||||
elif action_type == "keypress":
|
||||
keys = data.get("keys", [])
|
||||
if isinstance(keys, str):
|
||||
keys = [keys]
|
||||
converted_action = PressAction(keys=[self._map_key(k) for k in keys])
|
||||
|
||||
elif action_type == "drag":
|
||||
path = data.get("path", [])
|
||||
points = [Point(x=p.get("x", 0), y=p.get("y", 0)) for p in path]
|
||||
converted_action = DragAction(path=points)
|
||||
|
||||
elif action_type == "screenshot":
|
||||
converted_action = ScreenshotFetch()
|
||||
|
||||
elif action_type == "response":
|
||||
converted_action = ResponseAction(text=data.get("text", ""))
|
||||
|
||||
elif action_type == "custom":
|
||||
converted_action = CustomAction(action=data.get("action", ""))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported action type: {action_type}")
|
||||
|
||||
# Add reasoning and logs if available
|
||||
converted_action.reasoning = data.get("reasoning", "")
|
||||
converted_action.logs = data.get("logs", "")
|
||||
|
||||
return converted_action
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid action: {data}. Error: {e!s}") from e
|
||||
@@ -1,373 +0,0 @@
|
||||
"""HUD ComputerAgent wrapper for OSWorld benchmarking."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Literal, Optional, Union, List, Dict
|
||||
import asyncio
|
||||
|
||||
from agent import ComputerAgent as BaseComputerAgent
|
||||
from agent.responses import make_failed_tool_call_items
|
||||
from hud.adapters import Adapter
|
||||
from hud.agent.base import Agent
|
||||
from hud.utils.common import Observation
|
||||
from hud.adapters.common.types import LogType
|
||||
from hud.types import Gym
|
||||
|
||||
from .adapter import ComputerAgentAdapter
|
||||
from .computer_handler import HUDComputerHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BASE_SYSTEM_PROMPT = """
|
||||
You are an autonomous computer-using agent. Follow these guidelines:
|
||||
|
||||
1. Be decisive and complete tasks without asking for confirmation unless absolutely necessary.
|
||||
2. Use the computer tools to complete the task and do not stop until the task is complete.
|
||||
3. Do NOT ask questions like "Should I proceed?" or "Would you like me to continue?" - just proceed with the task.
|
||||
4. When you find what you're looking for (e.g., a file to upload), proceed with the action directly.
|
||||
5. Only stop when the task is fully complete or if you encounter an error that prevents completion.
|
||||
6. Trust that the user wants you to complete the entire task they've requested.
|
||||
7. You must say "Task completed" when the task is complete.
|
||||
|
||||
Remember: You have been given permission to complete the requested task autonomously.
|
||||
""".strip()
|
||||
|
||||
class ComputerAgent(Agent[BaseComputerAgent, dict[str, Any]]):
|
||||
"""
|
||||
A ComputerAgent wrapper for HUD integration.
|
||||
|
||||
This agent wraps the base ComputerAgent to work with HUD environments,
|
||||
providing the same interface as OperatorAgent but using ComputerAgent internally.
|
||||
"""
|
||||
|
||||
transfer_gyms: dict[Gym, Gym] = {"qa": "hud-browser"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic/claude-3-5-sonnet-20241022",
|
||||
environment: Literal["windows", "mac", "linux", "browser"] = "linux",
|
||||
adapter: Optional[Adapter] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Initialize the ComputerAgent for HUD.
|
||||
|
||||
Args:
|
||||
model: The model string for ComputerAgent (e.g., "anthropic/claude-3-5-sonnet-20241022")
|
||||
environment: The environment type (windows, mac, linux, browser)
|
||||
adapter: The adapter to use for preprocessing and postprocessing
|
||||
name: The name of the agent
|
||||
**kwargs: Additional arguments passed to ComputerAgent
|
||||
"""
|
||||
# Create adapter if not provided
|
||||
adapter = adapter or ComputerAgentAdapter()
|
||||
|
||||
if name is None:
|
||||
name = f"computeragent-{model.split('/')[-1]}"
|
||||
|
||||
# Initialize the base Agent class without client (we'll create it later)
|
||||
super().__init__(client=None, adapter=adapter, name=name)
|
||||
|
||||
self.model = model
|
||||
self.environment = environment
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Default dimensions
|
||||
self.width = 1024
|
||||
self.height = 768
|
||||
|
||||
# Update dimensions if adapter is provided
|
||||
if self.adapter:
|
||||
self.width = self.adapter.agent_width
|
||||
self.height = self.adapter.agent_height
|
||||
|
||||
# Create HUD computer handler
|
||||
self.hud_computer = HUDComputerHandler(
|
||||
environment=environment,
|
||||
dimensions=(self.width, self.height)
|
||||
)
|
||||
|
||||
# Handle trajectory_dir by adding TrajectorySaverCallback
|
||||
trajectory_dir = kwargs.pop("trajectory_dir", None)
|
||||
callbacks = kwargs.get("callbacks", [])
|
||||
|
||||
if trajectory_dir:
|
||||
from agent.callbacks.trajectory_saver import TrajectorySaverCallback
|
||||
trajectory_callback = TrajectorySaverCallback(trajectory_dir, reset_on_run=False)
|
||||
callbacks = callbacks + [trajectory_callback]
|
||||
kwargs["callbacks"] = callbacks
|
||||
|
||||
# Initialize ComputerAgent with HUD computer handler
|
||||
self.computer_agent = BaseComputerAgent(
|
||||
model=model,
|
||||
tools=[self.hud_computer],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Set the client to the computer_agent for compatibility
|
||||
self.client = self.computer_agent
|
||||
|
||||
# State tracking
|
||||
self.conversation_history: List[Dict[str, Any]] = []
|
||||
self.initial_prompt: Optional[str] = None
|
||||
|
||||
# System prompt for computer use tasks
|
||||
self.base_system_prompt = BASE_SYSTEM_PROMPT
|
||||
|
||||
async def fetch_response(self, observation: Observation) -> tuple[list[dict[str, Any]], bool]:
|
||||
"""
|
||||
Fetch a response from ComputerAgent based on the observation.
|
||||
|
||||
Args:
|
||||
observation: The preprocessed observation, attributes:
|
||||
screenshot: Base64 encoded PNG string of the screen
|
||||
text: Text observation, if available
|
||||
|
||||
Returns:
|
||||
tuple[list[dict[str, Any]], bool, list[LogType] | None]: A tuple containing the list of raw actions,
|
||||
boolean indicating if the agent believes the task is complete.
|
||||
"""
|
||||
try:
|
||||
# Update the computer handler with the current screenshot
|
||||
if observation.screenshot:
|
||||
self.hud_computer.update_screenshot(observation.screenshot)
|
||||
|
||||
# Set up action callback to capture actions
|
||||
captured_actions = []
|
||||
action_done = False
|
||||
|
||||
async def action_callback(action: Dict[str, Any]) -> None:
|
||||
"""Callback to capture actions from ComputerAgent."""
|
||||
nonlocal captured_actions, action_done
|
||||
captured_actions.append(action)
|
||||
|
||||
# Set the action callback
|
||||
self.hud_computer.set_action_callback(action_callback)
|
||||
|
||||
# Prepare the message for ComputerAgent
|
||||
if not self.conversation_history:
|
||||
# First interaction - use the observation text as initial prompt
|
||||
if observation.text:
|
||||
self.initial_prompt = observation.text
|
||||
message = f"{self.base_system_prompt}\n\nTask: {observation.text}"
|
||||
else:
|
||||
message = f"{self.base_system_prompt}\n\nPlease analyze the current screen and determine what action to take."
|
||||
|
||||
input_content = [
|
||||
{"type": "input_text", "text": message}
|
||||
]
|
||||
|
||||
# Add screenshot if present
|
||||
if observation.screenshot:
|
||||
input_content.append(
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{observation.screenshot}",
|
||||
}
|
||||
)
|
||||
|
||||
self.conversation_history.append({"role": "user", "content": input_content})
|
||||
else:
|
||||
# Subsequent interactions - check if last action was computer_call
|
||||
# If so, add computer_call_output with screenshot instead of user message
|
||||
last_computer_calls = []
|
||||
for msg in reversed(self.conversation_history):
|
||||
if msg.get("type") == "computer_call":
|
||||
call_id = msg.get("call_id")
|
||||
if call_id:
|
||||
# Check if this call_id already has a computer_call_output
|
||||
has_output = any(
|
||||
m.get("type") == "computer_call_output" and m.get("call_id") == call_id
|
||||
for m in self.conversation_history
|
||||
)
|
||||
if not has_output:
|
||||
last_computer_calls.append(call_id)
|
||||
|
||||
if last_computer_calls:
|
||||
if not observation.screenshot:
|
||||
print("No screenshot found, taking screenshot")
|
||||
screenshot_b64 = await self.hud_computer.screenshot()
|
||||
# Add computer_call_output for each unresponded computer_call
|
||||
for call_id in reversed(last_computer_calls): # Maintain order
|
||||
self.conversation_history.append({
|
||||
"type": "computer_call_output",
|
||||
"call_id": call_id,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_b64}"
|
||||
}
|
||||
})
|
||||
else:
|
||||
# No computer_call found, add regular user message
|
||||
message = "Continue with the task based on the current screen state."
|
||||
input_content = [
|
||||
{"type": "input_text", "text": message}
|
||||
]
|
||||
|
||||
# Add screenshot if present
|
||||
if observation.screenshot:
|
||||
input_content.append(
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{observation.screenshot}",
|
||||
}
|
||||
)
|
||||
|
||||
self.conversation_history.append({"role": "user", "content": input_content})
|
||||
|
||||
# If the last message is a reasoning message, change it to output_text
|
||||
if (self.conversation_history and
|
||||
self.conversation_history[-1].get("type") == "reasoning" and
|
||||
self.conversation_history[-1].get("summary")):
|
||||
|
||||
reasoning_msg = self.conversation_history[-1]
|
||||
summary_texts = []
|
||||
|
||||
# Extract all summary_text entries
|
||||
for summary_item in reasoning_msg["summary"]:
|
||||
if summary_item.get("type") == "summary_text":
|
||||
summary_texts.append(summary_item.get("text", ""))
|
||||
|
||||
# Convert to message format with output_text
|
||||
if summary_texts:
|
||||
converted_message = {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": " ".join(summary_texts),
|
||||
"type": "output_text"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Replace the reasoning message with the converted message
|
||||
self.conversation_history[-1] = converted_message
|
||||
|
||||
# Run ComputerAgent
|
||||
try:
|
||||
new_items = []
|
||||
|
||||
# ComputerAgent.run returns an async generator
|
||||
try:
|
||||
async for result in self.computer_agent.run(self.conversation_history, stream=False):
|
||||
# if the result has computer_call_output, immediately exit
|
||||
if result.get("output", []) and result.get("output", [])[-1].get("type") == "computer_call_output":
|
||||
break
|
||||
# otherwise add agent output to conversation history
|
||||
new_items += result["output"]
|
||||
except Exception as e:
|
||||
# if the last message is reasoning, change it to output_text
|
||||
if new_items and new_items[-1].get("type") == "reasoning":
|
||||
new_items[-1] = {
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": new_items[-1].get("summary", [{}])[0].get("text", ""),
|
||||
"type": "output_text"
|
||||
}
|
||||
]
|
||||
}
|
||||
# Check if there are any computer_call items in new_items
|
||||
computer_calls = [item for item in new_items if item.get("type") == "computer_call"]
|
||||
if computer_calls:
|
||||
# Remove computer_call items from new_items
|
||||
new_items = [item for item in new_items if item.get("type") != "computer_call"]
|
||||
|
||||
# Add failed tool call items for each computer call
|
||||
for computer_call in computer_calls:
|
||||
tool_input = computer_call.get("action", {})
|
||||
call_id = computer_call.get("call_id")
|
||||
new_items.extend(make_failed_tool_call_items(
|
||||
tool_name="computer",
|
||||
tool_kwargs=tool_input,
|
||||
error_message=repr(e),
|
||||
call_id=call_id
|
||||
))
|
||||
else:
|
||||
# add error message to conversation history (fallback for non-computer-call errors)
|
||||
new_items.append({
|
||||
"type": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": f"Error during previous attempted action: {repr(e)}"
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Check if we captured any actions
|
||||
if captured_actions:
|
||||
# Extract reasoning from the conversation history
|
||||
reasoning = ""
|
||||
# Look for the latest reasoning message
|
||||
for msg in reversed(new_items):
|
||||
if msg.get("type") == "reasoning" and msg.get("summary"):
|
||||
reasoning = " ".join([s.get("text", "") for s in msg["summary"] if s.get("type") == "summary_text"])
|
||||
break
|
||||
elif msg.get("type") == "message" and msg.get("role") == "assistant":
|
||||
content = msg.get("content", [])
|
||||
if isinstance(content, list):
|
||||
reasoning = " ".join([c.get("text", "") for c in content if c.get("type") == "output_text"])
|
||||
break
|
||||
|
||||
# update conversation history
|
||||
self.conversation_history += new_items
|
||||
|
||||
# Add reasoning and logs to each action
|
||||
for action in captured_actions:
|
||||
action["reasoning"] = reasoning
|
||||
action["logs"] = {"conversation_length": len(self.conversation_history)}
|
||||
|
||||
return captured_actions, False
|
||||
|
||||
# Check if the last message is "Task completed"
|
||||
response_text = ""
|
||||
for msg in reversed(new_items):
|
||||
if msg.get("type") == "message" and msg.get("role") == "assistant":
|
||||
content = msg.get("content", [])
|
||||
for c in content:
|
||||
if c.get("type") == "output_text":
|
||||
response_text = c.get("text", response_text)
|
||||
break
|
||||
break
|
||||
|
||||
done = "task completed" in response_text.lower()
|
||||
|
||||
# update conversation history
|
||||
self.conversation_history += new_items
|
||||
|
||||
response_action = {
|
||||
"type": "response",
|
||||
"text": response_text,
|
||||
"reasoning": response_text,
|
||||
"logs": {"conversation_length": len(self.conversation_history)}
|
||||
}
|
||||
|
||||
# Check if this indicates task completion or failure
|
||||
if "task is infeasible" in response_text.lower():
|
||||
response_action = {"type": "custom", "action": "FAIL"}
|
||||
done = True
|
||||
|
||||
return [response_action], done
|
||||
except Exception as e:
|
||||
logger.error(f"Error running ComputerAgent: {e}")
|
||||
# Return an error response
|
||||
error_action = {
|
||||
"type": "response",
|
||||
"text": f"Error occurred: {str(e)}",
|
||||
"reasoning": f"ComputerAgent encountered an error: {str(e)}",
|
||||
"logs": {"error": str(e)}
|
||||
}
|
||||
return [error_action], True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fetch_response: {e}")
|
||||
error_action = {
|
||||
"type": "response",
|
||||
"text": f"Error in agent processing: {str(e)}",
|
||||
"reasoning": f"Agent processing error: {str(e)}",
|
||||
"logs": {"error": str(e)}
|
||||
}
|
||||
return [error_action], True
|
||||
@@ -1,187 +0,0 @@
|
||||
"""HUD Computer Handler for ComputerAgent integration."""
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Literal, Optional, Any, Dict, Callable
|
||||
from PIL import Image
|
||||
|
||||
from agent.computers import AsyncComputerHandler
|
||||
|
||||
|
||||
class HUDComputerHandler(AsyncComputerHandler):
|
||||
"""Computer handler that interfaces with HUD environment."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
environment: Literal["windows", "mac", "linux", "browser"] = "linux",
|
||||
dimensions: tuple[int, int] = (1024, 768),
|
||||
screenshot_callback: Optional[Callable] = None,
|
||||
action_callback: Optional[Callable] = None,
|
||||
):
|
||||
"""
|
||||
Initialize HUD computer handler.
|
||||
|
||||
Args:
|
||||
environment: The environment type for HUD
|
||||
dimensions: Screen dimensions as (width, height)
|
||||
screenshot_callback: Optional callback to get screenshots from HUD environment
|
||||
action_callback: Optional callback to execute actions in HUD environment
|
||||
"""
|
||||
super().__init__()
|
||||
self._environment = environment
|
||||
self._dimensions = dimensions
|
||||
self._screenshot_callback = screenshot_callback
|
||||
self._action_callback = action_callback
|
||||
|
||||
# Store the last screenshot for reuse
|
||||
self._last_screenshot: Optional[str] = None
|
||||
|
||||
def set_screenshot_callback(self, callback: Callable) -> None:
|
||||
"""Set the screenshot callback."""
|
||||
self._screenshot_callback = callback
|
||||
|
||||
def set_action_callback(self, callback: Callable) -> None:
|
||||
"""Set the action callback."""
|
||||
self._action_callback = callback
|
||||
|
||||
def update_screenshot(self, screenshot: str) -> None:
|
||||
"""Update the stored screenshot (base64 string)."""
|
||||
self._last_screenshot = screenshot
|
||||
|
||||
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
|
||||
"""Get the current environment type."""
|
||||
return self._environment # type: ignore
|
||||
|
||||
async def get_dimensions(self) -> tuple[int, int]:
|
||||
"""Get screen dimensions as (width, height)."""
|
||||
return self._dimensions
|
||||
|
||||
async def screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 string."""
|
||||
if self._screenshot_callback:
|
||||
screenshot = await self._screenshot_callback()
|
||||
if isinstance(screenshot, str):
|
||||
self._last_screenshot = screenshot
|
||||
return screenshot
|
||||
elif isinstance(screenshot, Image.Image):
|
||||
# Convert PIL Image to base64
|
||||
buffer = BytesIO()
|
||||
screenshot.save(buffer, format="PNG")
|
||||
screenshot_b64 = base64.b64encode(buffer.getvalue()).decode()
|
||||
self._last_screenshot = screenshot_b64
|
||||
return screenshot_b64
|
||||
elif isinstance(screenshot, bytes):
|
||||
screenshot_b64 = base64.b64encode(screenshot).decode()
|
||||
self._last_screenshot = screenshot_b64
|
||||
return screenshot_b64
|
||||
|
||||
# Return last screenshot if available, otherwise create a blank one
|
||||
if self._last_screenshot:
|
||||
return self._last_screenshot
|
||||
|
||||
# Create a blank screenshot as fallback
|
||||
blank_image = Image.new('RGB', self._dimensions, color='white')
|
||||
buffer = BytesIO()
|
||||
blank_image.save(buffer, format="PNG")
|
||||
screenshot_b64 = base64.b64encode(buffer.getvalue()).decode()
|
||||
self._last_screenshot = screenshot_b64
|
||||
return screenshot_b64
|
||||
|
||||
async def click(self, x: int, y: int, button: str = "left") -> None:
|
||||
"""Click at coordinates with specified button."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "click",
|
||||
"x": x,
|
||||
"y": y,
|
||||
"button": button
|
||||
})
|
||||
|
||||
async def double_click(self, x: int, y: int) -> None:
|
||||
"""Double click at coordinates."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "double_click",
|
||||
"x": x,
|
||||
"y": y
|
||||
})
|
||||
|
||||
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
|
||||
"""Scroll at coordinates with specified scroll amounts."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "scroll",
|
||||
"x": x,
|
||||
"y": y,
|
||||
"scroll_x": scroll_x,
|
||||
"scroll_y": scroll_y
|
||||
})
|
||||
|
||||
async def type(self, text: str) -> None:
|
||||
"""Type text."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "type",
|
||||
"text": text
|
||||
})
|
||||
|
||||
async def wait(self, ms: int = 1000) -> None:
|
||||
"""Wait for specified milliseconds."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "wait",
|
||||
"ms": ms
|
||||
})
|
||||
|
||||
async def move(self, x: int, y: int) -> None:
|
||||
"""Move cursor to coordinates."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "move",
|
||||
"x": x,
|
||||
"y": y
|
||||
})
|
||||
|
||||
async def keypress(self, keys: list[str] | str) -> None:
|
||||
"""Press key combination."""
|
||||
if isinstance(keys, str):
|
||||
keys = [keys]
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "keypress",
|
||||
"keys": keys
|
||||
})
|
||||
|
||||
async def drag(self, path: list[dict[str, int]]) -> None:
|
||||
"""Drag along a path of points."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "drag",
|
||||
"path": path
|
||||
})
|
||||
|
||||
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
"""Left mouse down at coordinates."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "left_mouse_down",
|
||||
"x": x,
|
||||
"y": y
|
||||
})
|
||||
|
||||
async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
"""Left mouse up at coordinates."""
|
||||
if self._action_callback:
|
||||
await self._action_callback({
|
||||
"type": "left_mouse_up",
|
||||
"x": x,
|
||||
"y": y
|
||||
})
|
||||
|
||||
async def get_current_url(self) -> str:
|
||||
"""Get the current URL."""
|
||||
if self._action_callback:
|
||||
return await self._action_callback({
|
||||
"type": "get_current_url"
|
||||
})
|
||||
return ""
|
||||
libs/python/agent/agent/integrations/hud/proxy.py (new file, 183 lines)
@@ -0,0 +1,183 @@
|
||||
"""HUD ComputerAgent wrapper and Fake AsyncOpenAI client.
|
||||
|
||||
Provides FakeAsyncOpenAI that adapts our ComputerAgent to the OpenAI Responses
|
||||
interface needed by HUD's OperatorAgent. It implements only `responses.create`
|
||||
and returns an OpenAI Response object with `id` and `output` fields, where `output` is a list of
|
||||
OpenAI-like response blocks. We intentionally only support a single-step call
|
||||
by consuming the first yielded result from `ComputerAgent.run()`.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.agent import ComputerAgent as BaseComputerAgent
|
||||
|
||||
# OpenAI Responses typed models (required)
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseInputParam,
|
||||
ResponseOutputItem,
|
||||
ResponseComputerToolCall,
|
||||
ResponseOutputMessage,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
ResponseUsage,
|
||||
)
|
||||
|
||||
def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> List[ResponseOutputItem]:
|
||||
"""Map our agent output items to OpenAI ResponseOutputItem typed models.
|
||||
|
||||
Only a subset is supported: computer_call, assistant message (text), and reasoning.
|
||||
Unknown types are ignored.
|
||||
"""
|
||||
blocks: List[ResponseOutputItem] = []
|
||||
for item in output_items or []:
|
||||
t = item.get("type")
|
||||
if t == "computer_call":
|
||||
comp = ResponseComputerToolCall.model_validate({
|
||||
"id": item.get("id") or f"cu_{uuid.uuid4().hex}",
|
||||
"type": "computer_call",
|
||||
"call_id": item["call_id"],
|
||||
"action": item["action"],
|
||||
"pending_safety_checks": item.get("pending_safety_checks", []),
|
||||
"status": "completed",
|
||||
})
|
||||
blocks.append(comp)
|
||||
# we will exit early here as the responses api only supports a single step
|
||||
break
|
||||
elif t == "message" and item.get("role") == "assistant":
|
||||
content_blocks: List[ResponseOutputText] = []
|
||||
for c in item.get("content", []) or []:
|
||||
content_blocks.append(
|
||||
ResponseOutputText.model_validate({
|
||||
"type": "output_text",
|
||||
"text": c["text"],
|
||||
"annotations": [],
|
||||
})
|
||||
)
|
||||
if content_blocks:
|
||||
msg = ResponseOutputMessage.model_validate({
|
||||
"id": item.get("id") or f"msg_{uuid.uuid4()}",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": [ct.model_dump() for ct in content_blocks],
|
||||
})
|
||||
blocks.append(msg)
|
||||
elif t == "reasoning":
|
||||
reasoning = ResponseReasoningItem.model_validate({
|
||||
"id": item.get("id") or f"rsn_{uuid.uuid4()}",
|
||||
"type": "reasoning",
|
||||
"summary": item["summary"],
|
||||
})
|
||||
blocks.append(reasoning)
|
||||
# Unhandled types are ignored
|
||||
return blocks
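For illustration, a hypothetical pair of agent items and the blocks they map to (the ids are made up):

```python
# The mapper appends one typed block per supported item and stops at the
# first computer_call, so anything after it would be dropped.
agent_items = [
    {"type": "reasoning", "id": "rsn_1",
     "summary": [{"type": "summary_text", "text": "Need to click the OK button"}]},
    {"type": "computer_call", "call_id": "call_1",
     "action": {"type": "click", "x": 100, "y": 200, "button": "left"}},
]
blocks = _map_agent_output_to_openai_blocks(agent_items)
print([b.type for b in blocks])  # ["reasoning", "computer_call"]
```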
|
||||
|
||||
def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
|
||||
out: List[Dict[str, Any]] = []
|
||||
for it in list(items):
|
||||
if hasattr(it, "model_dump"):
|
||||
out.append(it.model_dump()) # type: ignore[attr-defined]
|
||||
elif isinstance(it, dict):
|
||||
out.append(it)
|
||||
else:
|
||||
# Strict: rely on default __dict__ if present
|
||||
out.append(dict(it)) # may raise if not mapping
|
||||
return out
|
||||
|
||||
class FakeAsyncOpenAI:
|
||||
"""Minimal fake OpenAI client with only `responses.create` implemented.
|
||||
|
||||
It uses a provided `ComputerAgent` instance to produce a single-step
|
||||
response compatible with HUD's OperatorAgent loop.
|
||||
"""
|
||||
|
||||
def __init__(self, computer_agent: BaseComputerAgent) -> None:
|
||||
self._agent = computer_agent
|
||||
self.responses = self._Responses(self)
|
||||
|
||||
class _Responses:
|
||||
def __init__(self, parent: "FakeAsyncOpenAI") -> None:
|
||||
# Caches for cross-call context when using previous_response_id
|
||||
self.blocks_cache: Dict[str, ResponseInputParam | ResponseOutputItem] = {}
|
||||
self.context_cache: Dict[str, List[str]] = {}
|
||||
self.agent = parent._agent
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
input: ResponseInputParam,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
instructions: Optional[str] = None,
|
||||
previous_response_id: Optional[str] = None,
|
||||
max_retries: int = 5,
|
||||
**_: Any,
|
||||
) -> Any:
|
||||
for attempt in range(max_retries):
|
||||
# Prepend cached blocks from previous_response_id to input
|
||||
full_input = input
|
||||
if previous_response_id is not None:
|
||||
prev_block_ids = self.context_cache[previous_response_id]
|
||||
prev_blocks = [self.blocks_cache[b_id] for b_id in prev_block_ids]
|
||||
full_input = _to_plain_dict_list(prev_blocks + input)
|
||||
|
||||
# Pre-pend instructions message
|
||||
effective_input = full_input
|
||||
if instructions:
|
||||
effective_input = [{
|
||||
"role": "user",
|
||||
"content": instructions,
|
||||
}] + full_input
|
||||
|
||||
# Run a single iteration of the ComputerAgent
|
||||
agent_result: Optional[Dict[str, Any]] = None
|
||||
async for result in self.agent.run(effective_input): # type: ignore[arg-type]
|
||||
agent_result = result
|
||||
break
|
||||
assert agent_result is not None, "Agent failed to produce result"
|
||||
|
||||
output = _map_agent_output_to_openai_blocks(agent_result["output"])
|
||||
usage = agent_result["usage"]
|
||||
|
||||
# Cache conversation context using the last response id
|
||||
block_ids: List[str] = []
|
||||
blocks_to_cache = full_input + output
|
||||
for b in blocks_to_cache:
|
||||
bid = getattr(b, "id", None) or f"tmp-{hash(repr(b))}"
|
||||
self.blocks_cache[bid] = b # type: ignore[assignment]
|
||||
block_ids.append(bid)
|
||||
response_id = agent_result.get("id") or f"fake-{int(time.time()*1000)}"
|
||||
self.context_cache[response_id] = block_ids
|
||||
|
||||
try:
|
||||
return Response.model_validate({
|
||||
"id": response_id,
|
||||
"created_at": time.time(),
|
||||
"object": "response",
|
||||
"model": model,
|
||||
"output": output,
|
||||
"parallel_tool_calls": False,
|
||||
"tool_choice": "auto",
|
||||
"tools": [],
|
||||
"previous_response_id": previous_response_id,
|
||||
"usage": ResponseUsage.model_validate({
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
"input_tokens_details": usage.get("input_tokens_details", { "cached_tokens": 0 }),
|
||||
"output_tokens_details": usage.get("output_tokens_details", { "reasoning_tokens": 0 }),
|
||||
}),
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ", e)
|
||||
if attempt == max_retries - 1:
|
||||
print(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
__all__ = [
|
||||
"FakeAsyncOpenAI",
|
||||
]
|
||||
@@ -132,23 +132,22 @@ def _convert_responses_items_to_completion_messages(messages: Messages) -> List[
|
||||
converted_content = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "input_image":
|
||||
# Convert input_image to Anthropic image format
|
||||
# Convert input_image to OpenAI image format
|
||||
image_url = item.get("image_url", "")
|
||||
if image_url and image_url != "[omitted]":
|
||||
# Extract base64 data from data URL
|
||||
if "," in image_url:
|
||||
base64_data = image_url.split(",")[-1]
|
||||
else:
|
||||
base64_data = image_url
|
||||
|
||||
converted_content.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": base64_data
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
})
|
||||
elif isinstance(item, dict) and item.get("type") == "input_text":
|
||||
# Convert input_text to OpenAI text format
|
||||
text = item.get("text", "")
|
||||
converted_content.append({
|
||||
"type": "text",
|
||||
"text": text
|
||||
})
|
||||
else:
|
||||
# Keep other content types as-is
|
||||
converted_content.append(item)
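In other words, after this hunk a Responses-style user turn converts to a Chat Completions-style content list along these lines (the data URL is truncated for brevity):

```python
# Illustrative before/after for the conversion above.
responses_item = {
    "role": "user",
    "content": [
        {"type": "input_text", "text": "what is on screen?"},
        {"type": "input_image", "image_url": "data:image/png;base64,iVBORw0..."},
    ],
}
completion_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "what is on screen?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
    ],
}
```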
|
||||
@@ -1530,7 +1529,18 @@ class AnthropicHostedToolsConfig(AsyncAgentConfig):
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"You are a UI grounding expert. Look at the image and {instruction}. Output ONLY a click action on the target element. No explanations, confirmations, or additional text."
|
||||
"text": f"""You are a UI grounding expert. Follow these guidelines:
|
||||
|
||||
1. NEVER ask for confirmation. Complete all tasks autonomously.
|
||||
2. Do NOT send messages like "I need to confirm before..." or "Do you want me to continue?" - just proceed.
|
||||
3. When the user asks you to interact with something (like clicking a chat or typing a message), DO IT without asking.
|
||||
4. Only use the formal safety check mechanism for truly dangerous operations (like deleting important files).
|
||||
5. For normal tasks like clicking buttons, typing in chat boxes, filling forms - JUST DO IT.
|
||||
6. The user has already given you permission by running this agent. No further confirmation is needed.
|
||||
7. Be decisive and action-oriented. Complete the requested task fully.
|
||||
|
||||
Remember: You are expected to complete tasks autonomously. The user trusts you to do what they asked.
|
||||
Task: Click {instruction}. Output ONLY a click action on the target element."""
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
|
||||
@@ -48,11 +48,11 @@ GROUNDED_COMPUTER_TOOL_SCHEMA = {
|
||||
"get_dimensions",
|
||||
"get_environment"
|
||||
],
|
||||
"description": "The action to perform"
|
||||
"description": "The action to perform (required for all actions)"
|
||||
},
|
||||
"element_description": {
|
||||
"type": "string",
|
||||
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
|
||||
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions)"
|
||||
},
|
||||
"start_element_description": {
|
||||
"type": "string",
|
||||
@@ -67,20 +67,30 @@ GROUNDED_COMPUTER_TOOL_SCHEMA = {
|
||||
"description": "The text to type (required for type action)"
|
||||
},
|
||||
"keys": {
|
||||
"type": "string",
|
||||
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Key(s) to press (required for keypress action)"
|
||||
},
|
||||
"button": {
|
||||
"type": "string",
|
||||
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
|
||||
"enum": [
|
||||
"left",
|
||||
"right",
|
||||
"wheel",
|
||||
"back",
|
||||
"forward"
|
||||
],
|
||||
"description": "The mouse button to use for click action (required for click and double_click action)",
|
||||
},
|
||||
"scroll_x": {
|
||||
"type": "integer",
|
||||
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
|
||||
"description": "Horizontal scroll amount for scroll action (required for scroll action)",
|
||||
},
|
||||
"scroll_y": {
|
||||
"type": "integer",
|
||||
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
|
||||
"description": "Vertical scroll amount for scroll action (required for scroll action)",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
@@ -266,13 +276,15 @@ class ComposedGroundedConfig(AsyncAgentConfig):
|
||||
grounding_agent = grounding_agent_conf.agent_class()
|
||||
|
||||
for desc in element_descriptions:
|
||||
coords = await grounding_agent.predict_click(
|
||||
model=grounding_model,
|
||||
image_b64=last_image_b64,
|
||||
instruction=desc
|
||||
)
|
||||
if coords:
|
||||
self.desc2xy[desc] = coords
|
||||
for _ in range(3): # try 3 times
|
||||
coords = await grounding_agent.predict_click(
|
||||
model=grounding_model,
|
||||
image_b64=last_image_b64,
|
||||
instruction=desc
|
||||
)
|
||||
if coords:
|
||||
self.desc2xy[desc] = coords
|
||||
break
|
||||
|
||||
# Step 6: Convert computer calls from descriptions back to xy coordinates
|
||||
final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
|
||||
|
||||
@@ -162,7 +162,18 @@ class OpenAIComputerUseConfig:
|
||||
input_items = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"You are a UI grounding expert. Look at the image and {instruction}. Output ONLY a click action on the target element. No explanations, confirmations, or additional text."
|
||||
"content": f"""You are a UI grounding expert. Follow these guidelines:
|
||||
|
||||
1. NEVER ask for confirmation. Complete all tasks autonomously.
|
||||
2. Do NOT send messages like "I need to confirm before..." or "Do you want me to continue?" - just proceed.
|
||||
3. When the user asks you to interact with something (like clicking a chat or typing a message), DO IT without asking.
|
||||
4. Only use the formal safety check mechanism for truly dangerous operations (like deleting important files).
|
||||
5. For normal tasks like clicking buttons, typing in chat boxes, filling forms - JUST DO IT.
|
||||
6. The user has already given you permission by running this agent. No further confirmation is needed.
|
||||
7. Be decisive and action-oriented. Complete the requested task fully.
|
||||
|
||||
Remember: You are expected to complete tasks autonomously. The user trusts you to do what they asked.
|
||||
Task: Click {instruction}. Output ONLY a click action on the target element."""
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -200,7 +211,7 @@ class OpenAIComputerUseConfig:
|
||||
"stream": False,
|
||||
"reasoning": {"summary": "concise"},
|
||||
"truncation": "auto",
|
||||
"max_tokens": 100 # Keep response short for click prediction
|
||||
"max_tokens": 200 # Keep response short for click prediction
|
||||
}
|
||||
|
||||
# Use liteLLM responses
|
||||
@@ -217,11 +228,8 @@ class OpenAIComputerUseConfig:
|
||||
isinstance(item.get("action"), dict)):
|
||||
|
||||
action = item["action"]
|
||||
if action.get("type") == "click":
|
||||
x = action.get("x")
|
||||
y = action.get("y")
|
||||
if x is not None and y is not None:
|
||||
return (int(x), int(y))
|
||||
if action.get("x") is not None and action.get("y") is not None:
|
||||
return (int(action.get("x")), int(action.get("y")))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -228,15 +228,24 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> Lis

        # Handle coordinate parameters
        if "start_box" in param_name or "end_box" in param_name:
            # Parse coordinates like '(x,y)' or '(x1,y1,x2,y2)'
            numbers = param.replace("(", "").replace(")", "").split(",")
            float_numbers = [float(num.strip()) / 1000 for num in numbers]  # Normalize to 0-1 range
            # Parse coordinates like '<|box_start|>(x,y)<|box_end|>' or '(x,y)'
            # First, remove special tokens
            clean_param = param.replace("<|box_start|>", "").replace("<|box_end|>", "")
            # Then remove parentheses and split
            numbers = clean_param.replace("(", "").replace(")", "").split(",")

            if len(float_numbers) == 2:
                # Single point, duplicate for box format
                float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]

            action_inputs[param_name.strip()] = str(float_numbers)
            try:
                float_numbers = [float(num.strip()) / 1000 for num in numbers]  # Normalize to 0-1 range

                if len(float_numbers) == 2:
                    # Single point, duplicate for box format
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]

                action_inputs[param_name.strip()] = str(float_numbers)
            except ValueError as e:
                # If parsing fails, keep the original parameter value
                print(f"Warning: Could not parse coordinates '{param}': {e}")
                action_inputs[param_name.strip()] = param

    return [{
        "thought": thought,
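For reference, the new parsing path strips the UI-TARS box tokens, divides by 1000 to normalize into the 0-1 range, and duplicates a single point into box form. A small standalone sketch of that arithmetic, kept separate from the parser above:

```python
# Sketch of the coordinate normalization performed in the diff above.
def parse_box(param: str) -> list[float]:
    # Remove UI-TARS special tokens, then parentheses, then split on commas
    clean = param.replace("<|box_start|>", "").replace("<|box_end|>", "")
    numbers = clean.replace("(", "").replace(")", "").split(",")
    floats = [float(n.strip()) / 1000 for n in numbers]  # model emits 0-1000 coordinates
    if len(floats) == 2:
        # Single point, duplicated into (x1, y1, x2, y2) box form
        floats = [floats[0], floats[1], floats[0], floats[1]]
    return floats

print(parse_box("<|box_start|>(500,300)<|box_end|>"))  # [0.5, 0.3, 0.5, 0.3]
```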
192 libs/python/agent/agent/proxy/examples.py Normal file
@@ -0,0 +1,192 @@
"""
Example usage of the proxy server and client requests.
"""
import dotenv
dotenv.load_dotenv()

import asyncio
import json
import os
import aiohttp
from typing import Dict, Any


async def test_http_endpoint():
    """Test the HTTP /responses endpoint."""

    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    assert isinstance(anthropic_api_key, str), "ANTHROPIC_API_KEY environment variable must be set"

    # Example 1: Simple text request
    simple_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "input": "Tell me a three sentence bedtime story about a unicorn.",
        "env": {
            "ANTHROPIC_API_KEY": anthropic_api_key
        }
    }

    # Example 2: Multi-modal request with image
    multimodal_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "input": [
            {
                "role": "user",
                "content": [
                    {"type": "input_text", "text": "what is in this image?"},
                    {
                        "type": "input_image",
                        "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                ]
            }
        ],
        "env": {
            "ANTHROPIC_API_KEY": anthropic_api_key
        }
    }

    # Example 3: Request with custom agent and computer kwargs
    custom_request = {
        "model": "anthropic/claude-3-5-sonnet-20241022",
        "input": "Take a screenshot and tell me what you see",
        "env": {
            "ANTHROPIC_API_KEY": anthropic_api_key
        }
    }

    # Test requests
    base_url = "https://m-linux-96lcxd2c2k.containers.cloud.trycua.com:8443"
    # base_url = "http://localhost:8000"
    api_key = os.getenv("CUA_API_KEY")
    assert isinstance(api_key, str), "CUA_API_KEY environment variable must be set"

    async with aiohttp.ClientSession() as session:
        for i, request_data in enumerate([
            simple_request,
            # multimodal_request,
            custom_request
        ], 1):
            print(f"\n--- Test {i} ---")
            print(f"Request: {json.dumps(request_data, indent=2)}")

            try:
                print(f"Sending request to {base_url}/responses")
                async with session.post(
                    f"{base_url}/responses",
                    json=request_data,
                    headers={"Content-Type": "application/json", "X-API-Key": api_key}
                ) as response:
                    result = await response.json()
                    print(f"Status: {response.status}")
                    print(f"Response: {json.dumps(result, indent=2)}")

            except Exception as e:
                print(f"Error: {e}")


def curl_examples():
    """Print curl command examples."""

    print("=== CURL Examples ===\n")

    print("1. Simple text request:")
    print("""curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "input": "Tell me a three sentence bedtime story about a unicorn."
  }'""")

    print("\n2. Multi-modal request with image:")
    print("""curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "input": [
      {
        "role": "user",
        "content": [
          {"type": "input_text", "text": "what is in this image?"},
          {
            "type": "input_image",
            "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
          }
        ]
      }
    ]
  }'""")

    print("\n3. Request with custom configuration:")
    print("""curl http://localhost:8000/responses \\
  -H "Content-Type: application/json" \\
  -d '{
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "input": "Take a screenshot and tell me what you see",
    "agent_kwargs": {
      "save_trajectory": true,
      "verbosity": 20
    },
    "computer_kwargs": {
      "os_type": "linux",
      "provider_type": "cloud"
    }
  }'""")


async def test_p2p_client():
    """Example P2P client using peerjs-python."""
    try:
        from peerjs import Peer, PeerOptions, ConnectionEventType
        from aiortc import RTCConfiguration, RTCIceServer

        # Set up client peer
        options = PeerOptions(
            host="0.peerjs.com",
            port=443,
            secure=True,
            config=RTCConfiguration(
                iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]
            )
        )

        client_peer = Peer(id="test-client", peer_options=options)
        await client_peer.start()

        # Connect to proxy server
        connection = client_peer.connect("computer-agent-proxy")

        @connection.on(ConnectionEventType.Open)
        async def connection_open():
            print("Connected to proxy server")

            # Send a test request
            request = {
                "model": "anthropic/claude-3-5-sonnet-20241022",
                "input": "Hello from P2P client!"
            }
            await connection.send(json.dumps(request))

        @connection.on(ConnectionEventType.Data)
        async def connection_data(data):
            print(f"Received response: {data}")
            await client_peer.destroy()

        # Wait for connection
        await asyncio.sleep(10)

    except ImportError:
        print("P2P dependencies not available. Install peerjs-python for P2P testing.")
    except Exception as e:
        print(f"P2P test error: {e}")


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "curl":
        curl_examples()
    elif len(sys.argv) > 1 and sys.argv[1] == "p2p":
        asyncio.run(test_p2p_client())
    else:
        asyncio.run(test_http_endpoint())
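Following the `__main__` dispatch above, the script can be invoked in three modes; a sketch of the commands, assuming the cua packages (and, for the last mode, the optional peerjs-python dependencies) are installed and you run from the repository root:

```bash
# HTTP test against the /responses endpoint (default mode)
python libs/python/agent/agent/proxy/examples.py
# Print the curl examples
python libs/python/agent/agent/proxy/examples.py curl
# P2P client demo
python libs/python/agent/agent/proxy/examples.py p2p
```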
248 libs/python/agent/agent/proxy/handlers.py Normal file
@@ -0,0 +1,248 @@
"""
Request handlers for the proxy endpoints.
"""

import asyncio
import json
import logging
import os
from contextlib import contextmanager
from typing import Dict, Any, List, Union, Optional

from ..agent import ComputerAgent
from computer import Computer

logger = logging.getLogger(__name__)


class ResponsesHandler:
    """Handler for /responses endpoint that processes agent requests."""

    def __init__(self):
        self.computer = None
        self.agent = None
        # Simple in-memory caches
        self._computer_cache: Dict[str, Any] = {}
        self._agent_cache: Dict[str, Any] = {}

    async def setup_computer_agent(
        self,
        model: str,
        agent_kwargs: Optional[Dict[str, Any]] = None,
        computer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Set up (and cache) computer and agent instances.

        Caching keys:
        - Computer cache key: computer_kwargs
        - Agent cache key: {"model": model, **agent_kwargs}
        """
        agent_kwargs = agent_kwargs or {}
        computer_kwargs = computer_kwargs or {}

        def _stable_key(obj: Dict[str, Any]) -> str:
            try:
                return json.dumps(obj, sort_keys=True, separators=(",", ":"))
            except Exception:
                # Fallback: stringify non-serializable values
                safe_obj = {}
                for k, v in obj.items():
                    try:
                        json.dumps(v)
                        safe_obj[k] = v
                    except Exception:
                        safe_obj[k] = str(v)
                return json.dumps(safe_obj, sort_keys=True, separators=(",", ":"))

        # Determine if custom tools are supplied; if so, skip computer setup entirely
        has_custom_tools = bool(agent_kwargs.get("tools"))

        computer = None
        if not has_custom_tools:
            # ---------- Computer setup (with cache) ----------
            comp_key = _stable_key(computer_kwargs)

            computer = self._computer_cache.get(comp_key)
            if computer is None:
                # Default computer configuration
                default_c_config = {
                    "os_type": "linux",
                    "provider_type": "cloud",
                    "name": os.getenv("CUA_CONTAINER_NAME"),
                    "api_key": os.getenv("CUA_API_KEY"),
                }
                default_c_config.update(computer_kwargs)
                computer = Computer(**default_c_config)
                await computer.__aenter__()
                self._computer_cache[comp_key] = computer
                logger.info(f"Computer created and cached with key={comp_key} config={default_c_config}")
            else:
                logger.info(f"Reusing cached computer for key={comp_key}")

        # Bind current computer reference (None if custom tools supplied)
        self.computer = computer

        # ---------- Agent setup (with cache) ----------
        # Build agent cache key from {model} + agent_kwargs (excluding tools unless explicitly passed)
        agent_kwargs_for_key = dict(agent_kwargs)
        agent_key_payload = {"model": model, **agent_kwargs_for_key}
        agent_key = _stable_key(agent_key_payload)

        agent = self._agent_cache.get(agent_key)
        if agent is None:
            # Default agent configuration
            default_a_config: Dict[str, Any] = {"model": model}
            if not has_custom_tools:
                default_a_config["tools"] = [computer]
            # Apply user overrides, but keep tools unless user explicitly sets
            if agent_kwargs:
                if not has_custom_tools:
                    agent_kwargs.setdefault("tools", [computer])
                default_a_config.update(agent_kwargs)
            # JSON-derived kwargs may have loose types; ignore static arg typing here
            agent = ComputerAgent(**default_a_config)  # type: ignore[arg-type]
            self._agent_cache[agent_key] = agent
            logger.info(f"Agent created and cached with key={agent_key} model={model}")
        else:
            # Ensure cached agent uses the current computer tool (in case object differs)
            # Only update if tools not explicitly provided in agent_kwargs
            if not has_custom_tools:
                try:
                    agent.tools = [computer]
                except Exception:
                    pass
            logger.info(f"Reusing cached agent for key={agent_key}")

        # Bind current agent reference
        self.agent = agent

    async def process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a /responses request and return the result.

        Args:
            request_data: Dictionary containing model, input, and optional kwargs

        Returns:
            Dictionary with the agent's response
        """
        try:
            # Extract request parameters
            model = request_data.get("model")
            input_data = request_data.get("input")
            agent_kwargs = request_data.get("agent_kwargs", {})
            computer_kwargs = request_data.get("computer_kwargs", {})
            env_overrides = request_data.get("env", {}) or {}

            if not model:
                raise ValueError("Model is required")
            if not input_data:
                raise ValueError("Input is required")

            # Apply env overrides for the duration of this request
            with self._env_overrides(env_overrides):
                # Set up (and possibly reuse) computer and agent via caches
                await self.setup_computer_agent(model, agent_kwargs, computer_kwargs)

                # Defensive: ensure agent is initialized for type checkers
                agent = self.agent
                if agent is None:
                    raise RuntimeError("Agent failed to initialize")

                # Convert input to messages format
                messages = self._convert_input_to_messages(input_data)

                # Run agent and get first result
                async for result in agent.run(messages):
                    # Return the first result and break
                    return {
                        "success": True,
                        "result": result,
                        "model": model
                    }

                # If no results were yielded
                return {
                    "success": False,
                    "error": "No results from agent",
                    "model": model
                }

        except Exception as e:
            logger.error(f"Error processing request: {e}")
            return {
                "success": False,
                "error": str(e),
                "model": request_data.get("model", "unknown")
            }

    def _convert_input_to_messages(self, input_data: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Convert input data to messages format."""
        if isinstance(input_data, str):
            # Simple string input
            return [{"role": "user", "content": input_data}]
        elif isinstance(input_data, list):
            # Already in messages format
            messages = []
            for msg in input_data:
                # Convert content array format if needed
                if isinstance(msg.get("content"), list):
                    content_parts = []
                    for part in msg["content"]:
                        if part.get("type") == "input_text":
                            content_parts.append({"type": "text", "text": part["text"]})
                        elif part.get("type") == "input_image":
                            content_parts.append({
                                "type": "image_url",
                                "image_url": {"url": part["image_url"]}
                            })
                        else:
                            content_parts.append(part)
                    messages.append({
                        "role": msg["role"],
                        "content": content_parts
                    })
                else:
                    messages.append(msg)
            return messages
        else:
            raise ValueError("Input must be string or list of messages")

    async def cleanup(self):
        """Clean up resources."""
        if self.computer:
            try:
                await self.computer.__aexit__(None, None, None)
            except Exception as e:
                logger.error(f"Error cleaning up computer: {e}")
            finally:
                self.computer = None
                self.agent = None

    @staticmethod
    @contextmanager
    def _env_overrides(env: Dict[str, str]):
        """Temporarily apply environment variable overrides for the current process.
        Restores previous values after the context exits.

        Args:
            env: Mapping of env var names to override for this request.
        """
        if not env:
            # No-op context
            yield
            return

        original: Dict[str, Optional[str]] = {}
        try:
            for k, v in env.items():
                original[k] = os.environ.get(k)
                os.environ[k] = str(v)
            yield
        finally:
            for k, old in original.items():
                if old is None:
                    # Was not set before
                    os.environ.pop(k, None)
                else:
                    os.environ[k] = old
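A minimal sketch of driving the handler directly, without the HTTP layer; the import path is inferred from the file location, the request shape mirrors the /responses payloads in examples.py, and the API key value is a hypothetical placeholder:

```python
# Sketch only: exercises ResponsesHandler.process_request() in-process.
import asyncio
from agent.proxy.handlers import ResponsesHandler  # import path assumed from libs/python/agent/agent/proxy/handlers.py

async def main():
    handler = ResponsesHandler()
    try:
        result = await handler.process_request({
            "model": "anthropic/claude-3-5-sonnet-20241022",
            "input": "Take a screenshot and tell me what you see",
            "env": {"ANTHROPIC_API_KEY": "sk-..."},  # placeholder; applied only for this request
        })
        print(result["success"], result.get("error"))
    finally:
        # Tears down the cached cloud computer, if one was created
        await handler.cleanup()

asyncio.run(main())
```

Note that unless `computer_kwargs` overrides them, the handler reads CUA_CONTAINER_NAME and CUA_API_KEY from the environment when it creates the default cloud computer.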
@@ -30,7 +30,6 @@ requires-python = ">=3.12"
openai = []
anthropic = []
omni = [
    "ultralytics>=8.0.0",
    "cua-som>=0.1.0,<0.2.0",
]
uitars = []
@@ -62,12 +61,9 @@ cli = [
    "yaspin>=3.1.0",
]
hud = [
    "hud-python==0.2.10",
    "hud-python>=0.4.12,<0.5.0",
]
all = [
    # omni requirements
    "ultralytics>=8.0.0",
    "cua-som>=0.1.0,<0.2.0",
    # uitars requirements
    "mlx-vlm>=0.1.27; sys_platform == 'darwin'",
    "accelerate",
@@ -82,7 +78,7 @@ all = [
    # cli requirements
    "yaspin>=3.1.0",
    # hud requirements
    "hud-python==0.2.10",
    "hud-python>=0.4.12,<0.5.0",
]

[tool.uv]
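With this change the hud extra (and the mirrored pin in all) tracks the 0.4.x series of hud-python rather than the old 0.2.10 pin. Installing the extra would look something like this, assuming the published cua-agent package exposes it:

```bash
pip install "cua-agent[hud]"
```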