mirror of
https://github.com/trycua/computer.git
synced 2025-12-31 10:29:59 -06:00
Added screenshot_dir and lazy loading of MLX
This commit is contained in:
@@ -78,8 +78,6 @@ class MLXVLMAdapter(CustomLLM):
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
super().__init__()
|
||||
if not MLX_AVAILABLE:
|
||||
raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
|
||||
|
||||
self.models = {} # Cache for loaded models
|
||||
self.processors = {} # Cache for loaded processors
|
||||
@@ -95,6 +93,9 @@ class MLXVLMAdapter(CustomLLM):
|
||||
Returns:
|
||||
Tuple of (model, processor, config)
|
||||
"""
|
||||
if not MLX_AVAILABLE:
|
||||
raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
|
||||
|
||||
if model_name not in self.models:
|
||||
# Load model and processor
|
||||
model_obj, processor = load(
|
||||
|
||||
@@ -3,6 +3,7 @@ ComputerAgent - Main agent class that selects and runs agent loops
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple
|
||||
|
||||
from litellm.responses.utils import Usage
|
||||
@@ -162,7 +163,7 @@ class ComputerAgent:
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
verbosity: Optional[int] = None,
|
||||
trajectory_dir: Optional[str] = None,
|
||||
trajectory_dir: Optional[str | Path | dict] = None,
|
||||
max_retries: Optional[int] = 3,
|
||||
screenshot_delay: Optional[float | int] = 0.5,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
@@ -223,7 +224,10 @@ class ComputerAgent:
|
||||
|
||||
# Add trajectory saver callback if trajectory_dir is set
|
||||
if self.trajectory_dir:
|
||||
self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))
|
||||
if isinstance(self.trajectory_dir, dict):
|
||||
self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
|
||||
elif isinstance(self.trajectory_dir, (str, Path)):
|
||||
self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
|
||||
|
||||
# Add budget manager if max_trajectory_budget is set
|
||||
if max_trajectory_budget:
|
||||
|
||||
@@ -11,6 +11,8 @@ from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Union, override
|
||||
from PIL import Image, ImageDraw
|
||||
import io
|
||||
from copy import deepcopy
|
||||
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
def sanitize_image_urls(data: Any) -> Any:
|
||||
@@ -43,6 +45,64 @@ def sanitize_image_urls(data: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Save any base64-encoded screenshots from computer_call_output entries to files and
|
||||
replace their image_url with the saved file path when a call_id is present.
|
||||
|
||||
Only operates if screenshot_dir is provided and exists; otherwise returns items unchanged.
|
||||
|
||||
Args:
|
||||
items: List of message/result dicts potentially containing computer_call_output entries
|
||||
screenshot_dir: Directory to write screenshots into
|
||||
|
||||
Returns:
|
||||
A new list with updated image_url fields when applicable.
|
||||
"""
|
||||
if not items:
|
||||
return items
|
||||
if not screenshot_dir or not screenshot_dir.exists():
|
||||
return items
|
||||
|
||||
updated: List[Dict[str, Any]] = []
|
||||
for item in items:
|
||||
# work on a shallow copy; deep copy nested 'output' if we modify it
|
||||
msg = dict(item)
|
||||
try:
|
||||
if msg.get("type") == "computer_call_output":
|
||||
call_id = msg.get("call_id")
|
||||
output = msg.get("output", {})
|
||||
image_url = output.get("image_url")
|
||||
if call_id and isinstance(image_url, str) and image_url.startswith("data:"):
|
||||
# derive extension from MIME type e.g. data:image/png;base64,
|
||||
try:
|
||||
ext = image_url.split(";", 1)[0].split("/")[-1]
|
||||
if not ext:
|
||||
ext = "png"
|
||||
except Exception:
|
||||
ext = "png"
|
||||
out_path = screenshot_dir / f"{call_id}.{ext}"
|
||||
# write file if it doesn't exist
|
||||
if not out_path.exists():
|
||||
try:
|
||||
b64_payload = image_url.split(",", 1)[1]
|
||||
img_bytes = base64.b64decode(b64_payload)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(out_path, "wb") as f:
|
||||
f.write(img_bytes)
|
||||
except Exception:
|
||||
# if anything fails, skip modifying this message
|
||||
pass
|
||||
# update image_url to file path
|
||||
new_output = dict(output)
|
||||
new_output["image_url"] = str(out_path)
|
||||
msg["output"] = new_output
|
||||
except Exception:
|
||||
# do not block on malformed entries; keep original
|
||||
pass
|
||||
updated.append(msg)
|
||||
return updated
|
||||
|
||||
class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Callback handler that saves agent trajectories to disk.
|
||||
@@ -51,7 +111,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
within the trajectory gets its own folder with screenshots and responses.
|
||||
"""
|
||||
|
||||
def __init__(self, trajectory_dir: str, reset_on_run: bool = True):
|
||||
def __init__(self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize trajectory saver.
|
||||
|
||||
@@ -67,10 +127,12 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
self.model: Optional[str] = None
|
||||
self.total_usage: Dict[str, Any] = {}
|
||||
self.reset_on_run = reset_on_run
|
||||
# Optional directory to store extracted screenshots from metadata/new_items
|
||||
self.screenshot_dir: Optional[Path] = Path(screenshot_dir) if screenshot_dir else None
|
||||
|
||||
# Ensure trajectory directory exists
|
||||
self.trajectory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _get_turn_dir(self) -> Path:
|
||||
"""Get the directory for the current turn."""
|
||||
if not self.trajectory_id:
|
||||
@@ -139,12 +201,21 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
trajectory_path = self.trajectory_dir / self.trajectory_id
|
||||
trajectory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save trajectory metadata
|
||||
# Save trajectory metadata (optionally extract screenshots to screenshot_dir)
|
||||
kwargs_to_save = kwargs.copy()
|
||||
try:
|
||||
if "messages" in kwargs_to_save:
|
||||
kwargs_to_save["messages"] = extract_computer_call_outputs(
|
||||
kwargs_to_save["messages"], self.screenshot_dir
|
||||
)
|
||||
except Exception:
|
||||
# If extraction fails, fall back to original messages
|
||||
pass
|
||||
metadata = {
|
||||
"trajectory_id": self.trajectory_id,
|
||||
"created_at": str(uuid.uuid1().time),
|
||||
"status": "running",
|
||||
"kwargs": kwargs,
|
||||
"kwargs": kwargs_to_save,
|
||||
}
|
||||
|
||||
with open(trajectory_path / "metadata.json", "w") as f:
|
||||
@@ -171,11 +242,18 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
metadata = {}
|
||||
|
||||
# Update metadata with completion info
|
||||
# Optionally extract screenshots from new_items before persisting
|
||||
new_items_to_save = new_items
|
||||
try:
|
||||
new_items_to_save = extract_computer_call_outputs(new_items, self.screenshot_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
metadata.update({
|
||||
"status": "completed",
|
||||
"completed_at": str(uuid.uuid1().time),
|
||||
"total_usage": self.total_usage,
|
||||
"new_items": new_items,
|
||||
"new_items": new_items_to_save,
|
||||
"total_turns": self.current_turn
|
||||
})
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class ProxyOperatorAgent(OperatorAgent):
|
||||
*,
|
||||
model: str | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
trajectory_dir: str | None = None,
|
||||
trajectory_dir: str | dict | None = None,
|
||||
# === ComputerAgent kwargs ===
|
||||
tools: list[Any] | None = None,
|
||||
custom_loop: Any | None = None,
|
||||
@@ -109,7 +109,7 @@ async def run_single_task(
|
||||
only_n_most_recent_images: int | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
verbosity: int | None = None,
|
||||
trajectory_dir: str | None = None,
|
||||
trajectory_dir: str | dict | None = None,
|
||||
max_retries: int | None = 3,
|
||||
screenshot_delay: float | int = 0.5,
|
||||
use_prompt_caching: bool | None = False,
|
||||
@@ -167,7 +167,7 @@ async def run_full_dataset(
|
||||
max_concurrent: int = 30,
|
||||
max_steps: int = 50,
|
||||
split: str = "train",
|
||||
trajectory_dir: str | None = None,
|
||||
trajectory_dir: str | dict | None = None,
|
||||
# === ComputerAgent kwargs ===
|
||||
tools: list[Any] | None = None,
|
||||
custom_loop: Any | None = None,
|
||||
|
||||
Reference in New Issue
Block a user