Added screenshot_dir and lazy loading of MLX

This commit is contained in:
Dillon DuPont
2025-08-28 13:18:17 -04:00
parent c4ce791a49
commit e90997c4ff
4 changed files with 95 additions and 12 deletions

View File

@@ -78,8 +78,6 @@ class MLXVLMAdapter(CustomLLM):
**kwargs: Additional arguments
"""
super().__init__()
if not MLX_AVAILABLE:
raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
self.models = {} # Cache for loaded models
self.processors = {} # Cache for loaded processors
@@ -95,6 +93,9 @@ class MLXVLMAdapter(CustomLLM):
Returns:
Tuple of (model, processor, config)
"""
if not MLX_AVAILABLE:
raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
if model_name not in self.models:
# Load model and processor
model_obj, processor = load(

View File

@@ -3,6 +3,7 @@ ComputerAgent - Main agent class that selects and runs agent loops
"""
import asyncio
from pathlib import Path
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple
from litellm.responses.utils import Usage
@@ -162,7 +163,7 @@ class ComputerAgent:
only_n_most_recent_images: Optional[int] = None,
callbacks: Optional[List[Any]] = None,
verbosity: Optional[int] = None,
trajectory_dir: Optional[str] = None,
trajectory_dir: Optional[str | Path | dict] = None,
max_retries: Optional[int] = 3,
screenshot_delay: Optional[float | int] = 0.5,
use_prompt_caching: Optional[bool] = False,
@@ -223,7 +224,10 @@ class ComputerAgent:
# Add trajectory saver callback if trajectory_dir is set
if self.trajectory_dir:
self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))
if isinstance(self.trajectory_dir, dict):
self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
elif isinstance(self.trajectory_dir, (str, Path)):
self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
# Add budget manager if max_trajectory_budget is set
if max_trajectory_budget:

View File

@@ -11,6 +11,8 @@ from pathlib import Path
from typing import List, Dict, Any, Optional, Union, override
from PIL import Image, ImageDraw
import io
from copy import deepcopy
from .base import AsyncCallbackHandler
def sanitize_image_urls(data: Any) -> Any:
@@ -43,6 +45,64 @@ def sanitize_image_urls(data: Any) -> Any:
return data
def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]:
"""
Save any base64-encoded screenshots from computer_call_output entries to files and
replace their image_url with the saved file path when a call_id is present.
Only operates if screenshot_dir is provided and exists; otherwise returns items unchanged.
Args:
items: List of message/result dicts potentially containing computer_call_output entries
screenshot_dir: Directory to write screenshots into
Returns:
A new list with updated image_url fields when applicable.
"""
if not items:
return items
if not screenshot_dir or not screenshot_dir.exists():
return items
updated: List[Dict[str, Any]] = []
for item in items:
# work on a shallow copy; deep copy nested 'output' if we modify it
msg = dict(item)
try:
if msg.get("type") == "computer_call_output":
call_id = msg.get("call_id")
output = msg.get("output", {})
image_url = output.get("image_url")
if call_id and isinstance(image_url, str) and image_url.startswith("data:"):
# derive extension from MIME type e.g. data:image/png;base64,
try:
ext = image_url.split(";", 1)[0].split("/")[-1]
if not ext:
ext = "png"
except Exception:
ext = "png"
out_path = screenshot_dir / f"{call_id}.{ext}"
# write file if it doesn't exist
if not out_path.exists():
try:
b64_payload = image_url.split(",", 1)[1]
img_bytes = base64.b64decode(b64_payload)
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, "wb") as f:
f.write(img_bytes)
except Exception:
# if anything fails, skip modifying this message
pass
# update image_url to file path
new_output = dict(output)
new_output["image_url"] = str(out_path)
msg["output"] = new_output
except Exception:
# do not block on malformed entries; keep original
pass
updated.append(msg)
return updated
class TrajectorySaverCallback(AsyncCallbackHandler):
"""
Callback handler that saves agent trajectories to disk.
@@ -51,7 +111,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
within the trajectory gets its own folder with screenshots and responses.
"""
def __init__(self, trajectory_dir: str, reset_on_run: bool = True):
def __init__(self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None):
"""
Initialize trajectory saver.
@@ -67,10 +127,12 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.model: Optional[str] = None
self.total_usage: Dict[str, Any] = {}
self.reset_on_run = reset_on_run
# Optional directory to store extracted screenshots from metadata/new_items
self.screenshot_dir: Optional[Path] = Path(screenshot_dir) if screenshot_dir else None
# Ensure trajectory directory exists
self.trajectory_dir.mkdir(parents=True, exist_ok=True)
def _get_turn_dir(self) -> Path:
"""Get the directory for the current turn."""
if not self.trajectory_id:
@@ -139,12 +201,21 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
trajectory_path = self.trajectory_dir / self.trajectory_id
trajectory_path.mkdir(parents=True, exist_ok=True)
# Save trajectory metadata
# Save trajectory metadata (optionally extract screenshots to screenshot_dir)
kwargs_to_save = kwargs.copy()
try:
if "messages" in kwargs_to_save:
kwargs_to_save["messages"] = extract_computer_call_outputs(
kwargs_to_save["messages"], self.screenshot_dir
)
except Exception:
# If extraction fails, fall back to original messages
pass
metadata = {
"trajectory_id": self.trajectory_id,
"created_at": str(uuid.uuid1().time),
"status": "running",
"kwargs": kwargs,
"kwargs": kwargs_to_save,
}
with open(trajectory_path / "metadata.json", "w") as f:
@@ -171,11 +242,18 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
metadata = {}
# Update metadata with completion info
# Optionally extract screenshots from new_items before persisting
new_items_to_save = new_items
try:
new_items_to_save = extract_computer_call_outputs(new_items, self.screenshot_dir)
except Exception:
pass
metadata.update({
"status": "completed",
"completed_at": str(uuid.uuid1().time),
"total_usage": self.total_usage,
"new_items": new_items,
"new_items": new_items_to_save,
"total_turns": self.current_turn
})

View File

@@ -41,7 +41,7 @@ class ProxyOperatorAgent(OperatorAgent):
*,
model: str | None = None,
allowed_tools: list[str] | None = None,
trajectory_dir: str | None = None,
trajectory_dir: str | dict | None = None,
# === ComputerAgent kwargs ===
tools: list[Any] | None = None,
custom_loop: Any | None = None,
@@ -109,7 +109,7 @@ async def run_single_task(
only_n_most_recent_images: int | None = None,
callbacks: list[Any] | None = None,
verbosity: int | None = None,
trajectory_dir: str | None = None,
trajectory_dir: str | dict | None = None,
max_retries: int | None = 3,
screenshot_delay: float | int = 0.5,
use_prompt_caching: bool | None = False,
@@ -167,7 +167,7 @@ async def run_full_dataset(
max_concurrent: int = 30,
max_steps: int = 50,
split: str = "train",
trajectory_dir: str | None = None,
trajectory_dir: str | dict | None = None,
# === ComputerAgent kwargs ===
tools: list[Any] | None = None,
custom_loop: Any | None = None,