From e90997c4ff5ced5072bd6ba70a8a05b3ad63ffb3 Mon Sep 17 00:00:00 2001 From: Dillon DuPont Date: Thu, 28 Aug 2025 13:18:17 -0400 Subject: [PATCH] Added screenshot_dir and lazy loading of MLX --- .../agent/agent/adapters/mlxvlm_adapter.py | 5 +- libs/python/agent/agent/agent.py | 8 +- .../agent/agent/callbacks/trajectory_saver.py | 88 +++++++++++++++++-- .../agent/agent/integrations/hud/__init__.py | 6 +- 4 files changed, 95 insertions(+), 12 deletions(-) diff --git a/libs/python/agent/agent/adapters/mlxvlm_adapter.py b/libs/python/agent/agent/adapters/mlxvlm_adapter.py index c38f4ad6..8255725b 100644 --- a/libs/python/agent/agent/adapters/mlxvlm_adapter.py +++ b/libs/python/agent/agent/adapters/mlxvlm_adapter.py @@ -78,8 +78,6 @@ class MLXVLMAdapter(CustomLLM): **kwargs: Additional arguments """ super().__init__() - if not MLX_AVAILABLE: - raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.") self.models = {} # Cache for loaded models self.processors = {} # Cache for loaded processors @@ -95,6 +93,9 @@ class MLXVLMAdapter(CustomLLM): Returns: Tuple of (model, processor, config) """ + if not MLX_AVAILABLE: + raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.") + if model_name not in self.models: # Load model and processor model_obj, processor = load( diff --git a/libs/python/agent/agent/agent.py b/libs/python/agent/agent/agent.py index 361a3549..f7ba6418 100644 --- a/libs/python/agent/agent/agent.py +++ b/libs/python/agent/agent/agent.py @@ -3,6 +3,7 @@ ComputerAgent - Main agent class that selects and runs agent loops """ import asyncio +from pathlib import Path from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple from litellm.responses.utils import Usage @@ -162,7 +163,7 @@ class ComputerAgent: only_n_most_recent_images: Optional[int] = None, callbacks: Optional[List[Any]] = None, verbosity: Optional[int] = None, - trajectory_dir: Optional[str] = None, + trajectory_dir: Optional[str | Path | dict] = None, max_retries: Optional[int] = 3, screenshot_delay: Optional[float | int] = 0.5, use_prompt_caching: Optional[bool] = False, @@ -223,7 +224,10 @@ class ComputerAgent: # Add trajectory saver callback if trajectory_dir is set if self.trajectory_dir: - self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir)) + if isinstance(self.trajectory_dir, dict): + self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir)) + elif isinstance(self.trajectory_dir, (str, Path)): + self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir))) # Add budget manager if max_trajectory_budget is set if max_trajectory_budget: diff --git a/libs/python/agent/agent/callbacks/trajectory_saver.py b/libs/python/agent/agent/callbacks/trajectory_saver.py index 53e4c189..a65722aa 100644 --- a/libs/python/agent/agent/callbacks/trajectory_saver.py +++ b/libs/python/agent/agent/callbacks/trajectory_saver.py @@ -11,6 +11,8 @@ from pathlib import Path from typing import List, Dict, Any, Optional, Union, override from PIL import Image, ImageDraw import io +from copy import deepcopy + from .base import AsyncCallbackHandler def sanitize_image_urls(data: Any) -> Any: @@ -43,6 +45,64 @@ def sanitize_image_urls(data: Any) -> Any: return data +def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]: + """ + Save any base64-encoded screenshots from computer_call_output entries to files and + replace their image_url with the saved file path when a call_id is present. + + Only operates if screenshot_dir is provided and exists; otherwise returns items unchanged. + + Args: + items: List of message/result dicts potentially containing computer_call_output entries + screenshot_dir: Directory to write screenshots into + + Returns: + A new list with updated image_url fields when applicable. + """ + if not items: + return items + if not screenshot_dir or not screenshot_dir.exists(): + return items + + updated: List[Dict[str, Any]] = [] + for item in items: + # work on a shallow copy; deep copy nested 'output' if we modify it + msg = dict(item) + try: + if msg.get("type") == "computer_call_output": + call_id = msg.get("call_id") + output = msg.get("output", {}) + image_url = output.get("image_url") + if call_id and isinstance(image_url, str) and image_url.startswith("data:"): + # derive extension from MIME type e.g. data:image/png;base64, + try: + ext = image_url.split(";", 1)[0].split("/")[-1] + if not ext: + ext = "png" + except Exception: + ext = "png" + out_path = screenshot_dir / f"{call_id}.{ext}" + # write file if it doesn't exist + if not out_path.exists(): + try: + b64_payload = image_url.split(",", 1)[1] + img_bytes = base64.b64decode(b64_payload) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "wb") as f: + f.write(img_bytes) + except Exception: + # if anything fails, skip modifying this message + pass + # update image_url to file path + new_output = dict(output) + new_output["image_url"] = str(out_path) + msg["output"] = new_output + except Exception: + # do not block on malformed entries; keep original + pass + updated.append(msg) + return updated + class TrajectorySaverCallback(AsyncCallbackHandler): """ Callback handler that saves agent trajectories to disk. @@ -51,7 +111,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler): within the trajectory gets its own folder with screenshots and responses. """ - def __init__(self, trajectory_dir: str, reset_on_run: bool = True): + def __init__(self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None): """ Initialize trajectory saver. @@ -67,10 +127,12 @@ class TrajectorySaverCallback(AsyncCallbackHandler): self.model: Optional[str] = None self.total_usage: Dict[str, Any] = {} self.reset_on_run = reset_on_run + # Optional directory to store extracted screenshots from metadata/new_items + self.screenshot_dir: Optional[Path] = Path(screenshot_dir) if screenshot_dir else None # Ensure trajectory directory exists self.trajectory_dir.mkdir(parents=True, exist_ok=True) - + def _get_turn_dir(self) -> Path: """Get the directory for the current turn.""" if not self.trajectory_id: @@ -139,12 +201,21 @@ class TrajectorySaverCallback(AsyncCallbackHandler): trajectory_path = self.trajectory_dir / self.trajectory_id trajectory_path.mkdir(parents=True, exist_ok=True) - # Save trajectory metadata + # Save trajectory metadata (optionally extract screenshots to screenshot_dir) + kwargs_to_save = kwargs.copy() + try: + if "messages" in kwargs_to_save: + kwargs_to_save["messages"] = extract_computer_call_outputs( + kwargs_to_save["messages"], self.screenshot_dir + ) + except Exception: + # If extraction fails, fall back to original messages + pass metadata = { "trajectory_id": self.trajectory_id, "created_at": str(uuid.uuid1().time), "status": "running", - "kwargs": kwargs, + "kwargs": kwargs_to_save, } with open(trajectory_path / "metadata.json", "w") as f: @@ -171,11 +242,18 @@ class TrajectorySaverCallback(AsyncCallbackHandler): metadata = {} # Update metadata with completion info + # Optionally extract screenshots from new_items before persisting + new_items_to_save = new_items + try: + new_items_to_save = extract_computer_call_outputs(new_items, self.screenshot_dir) + except Exception: + pass + metadata.update({ "status": "completed", "completed_at": str(uuid.uuid1().time), "total_usage": self.total_usage, - "new_items": new_items, + "new_items": new_items_to_save, "total_turns": self.current_turn }) diff --git a/libs/python/agent/agent/integrations/hud/__init__.py b/libs/python/agent/agent/integrations/hud/__init__.py index 21695026..0da87bfa 100644 --- a/libs/python/agent/agent/integrations/hud/__init__.py +++ b/libs/python/agent/agent/integrations/hud/__init__.py @@ -41,7 +41,7 @@ class ProxyOperatorAgent(OperatorAgent): *, model: str | None = None, allowed_tools: list[str] | None = None, - trajectory_dir: str | None = None, + trajectory_dir: str | dict | None = None, # === ComputerAgent kwargs === tools: list[Any] | None = None, custom_loop: Any | None = None, @@ -109,7 +109,7 @@ async def run_single_task( only_n_most_recent_images: int | None = None, callbacks: list[Any] | None = None, verbosity: int | None = None, - trajectory_dir: str | None = None, + trajectory_dir: str | dict | None = None, max_retries: int | None = 3, screenshot_delay: float | int = 0.5, use_prompt_caching: bool | None = False, @@ -167,7 +167,7 @@ async def run_full_dataset( max_concurrent: int = 30, max_steps: int = 50, split: str = "train", - trajectory_dir: str | None = None, + trajectory_dir: str | dict | None = None, # === ComputerAgent kwargs === tools: list[Any] | None = None, custom_loop: Any | None = None,