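Added extended ComputerAgent kwargs passthrough to ProxyOperatorAgent, run_single_task, and run_full_dataset; renamed OperatorValidatorCallback to OperatorNormalizerCallback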

This commit is contained in:
Dillon DuPont
2025-08-27 20:49:31 -04:00
parent 799d9d3ba7
commit 95cefc50f0
5 changed files with 422 additions and 23 deletions

View File

@@ -29,7 +29,7 @@ from .callbacks import (
TrajectorySaverCallback,
BudgetManagerCallback,
TelemetryCallback,
OperatorValidatorCallback
OperatorNormalizerCallback
)
from .computers import (
AsyncComputerHandler,
@@ -202,8 +202,8 @@ class ComputerAgent:
# == Add built-in callbacks ==
# Prepend operator validator callback
self.callbacks.insert(0, OperatorValidatorCallback())
# Prepend operator normalizer callback
self.callbacks.insert(0, OperatorNormalizerCallback())
# Add telemetry callback if telemetry_enabled is set
if self.telemetry_enabled:

View File

@@ -8,7 +8,7 @@ from .logging import LoggingCallback
from .trajectory_saver import TrajectorySaverCallback
from .budget_manager import BudgetManagerCallback
from .telemetry import TelemetryCallback
from .operator_validator import OperatorValidatorCallback
from .operator_validator import OperatorNormalizerCallback
__all__ = [
"AsyncCallbackHandler",
@@ -17,5 +17,5 @@ __all__ = [
"TrajectorySaverCallback",
"BudgetManagerCallback",
"TelemetryCallback",
"OperatorValidatorCallback",
"OperatorNormalizerCallback",
]

View File

@@ -4,6 +4,7 @@ OperatorValidatorCallback
Ensures agent output actions conform to expected schemas by fixing common issues:
- click: add default button='left' if missing
- keypress: wrap keys string into a list
- etc.
This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
"""
@@ -14,14 +15,12 @@ from typing import Any, Dict, List
from .base import AsyncCallbackHandler
class OperatorValidatorCallback(AsyncCallbackHandler):
"""Validates and normalizes operator/computer actions in LLM outputs."""
class OperatorNormalizerCallback(AsyncCallbackHandler):
"""Normalizes common computer call hallucinations / errors in computer call syntax."""
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# Mutate in-place as requested, but still return the list for chaining
for item in output or []:
if not isinstance(item, dict):
continue
if item.get("type") != "computer_call":
continue
action = item.get("action")
@@ -56,8 +55,6 @@ class OperatorValidatorCallback(AsyncCallbackHandler):
# replace the assistant message itself with a reasoning message with summary text.
if isinstance(output, list):
for i, item in enumerate(output):
if not isinstance(item, dict):
continue
# AssistantMessage shape: { type: 'message', role: 'assistant', content: OutputContent[] }
if item.get("type") == "message" and item.get("role") == "assistant":
next_idx = i + 1

View File

@@ -42,6 +42,17 @@ class ProxyOperatorAgent(OperatorAgent):
model: str | None = None,
allowed_tools: list[str] | None = None,
trajectory_dir: str | None = None,
# === ComputerAgent kwargs ===
tools: list[Any] | None = None,
custom_loop: Any | None = None,
only_n_most_recent_images: int | None = None,
callbacks: list[Any] | None = None,
verbosity: int | None = None,
max_retries: int | None = 3,
screenshot_delay: float | int = 0.5,
use_prompt_caching: bool | None = False,
max_trajectory_budget: float | dict | None = None,
telemetry_enabled: bool | None = True,
**kwargs: Any,
) -> None:
model = model or "computer-use-preview"
@@ -52,10 +63,24 @@ class ProxyOperatorAgent(OperatorAgent):
'environment': 'linux',
'dimensions': (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)
}
# Build tools ensuring the computer_shim is included
agent_tools: list[Any] = [computer_shim]
if tools:
agent_tools.extend(tools)
computer_agent = BaseComputerAgent(
model=model,
tools=[computer_shim],
trajectory_dir=trajectory_dir
model=model,
tools=agent_tools,
custom_loop=custom_loop,
only_n_most_recent_images=only_n_most_recent_images,
callbacks=callbacks,
verbosity=verbosity,
trajectory_dir=trajectory_dir,
max_retries=max_retries,
screenshot_delay=screenshot_delay,
use_prompt_caching=use_prompt_caching,
max_trajectory_budget=max_trajectory_budget,
telemetry_enabled=telemetry_enabled,
)
model_client = FakeAsyncOpenAI(computer_agent)
@@ -78,6 +103,18 @@ async def run_single_task(
task_id: int = 0,
model: str | None = None,
allowed_tools: list[str] | None = None,
# === ComputerAgent kwargs ===
tools: list[Any] | None = None,
custom_loop: Any | None = None,
only_n_most_recent_images: int | None = None,
callbacks: list[Any] | None = None,
verbosity: int | None = None,
trajectory_dir: str | None = None,
max_retries: int | None = 3,
screenshot_delay: float | int = 0.5,
use_prompt_caching: bool | None = False,
max_trajectory_budget: float | dict | None = None,
telemetry_enabled: bool | None = True,
) -> None:
"""Load one task from the dataset and execute it with Operator+CUA proxy."""
@@ -95,7 +132,22 @@ async def run_single_task(
with trace(name=task_prompt):
task = Task(**sample_task) # type: ignore[arg-type]
agent = ProxyOperatorAgent(model=model, allowed_tools=allowed_tools)
agent = ProxyOperatorAgent(
model=model,
allowed_tools=allowed_tools,
# === ComputerAgent kwargs passthrough ===
tools=tools,
custom_loop=custom_loop,
only_n_most_recent_images=only_n_most_recent_images,
callbacks=callbacks,
verbosity=verbosity,
trajectory_dir=trajectory_dir,
max_retries=max_retries,
screenshot_delay=screenshot_delay,
use_prompt_caching=use_prompt_caching,
max_trajectory_budget=max_trajectory_budget,
telemetry_enabled=telemetry_enabled,
)
print(f"Running: {task_prompt}")
result = await agent.run(task, max_steps=10)
print(f"✅ Reward: {getattr(result, 'reward')}")
@@ -116,6 +168,17 @@ async def run_full_dataset(
max_steps: int = 50,
split: str = "train",
trajectory_dir: str | None = None,
# === ComputerAgent kwargs ===
tools: list[Any] | None = None,
custom_loop: Any | None = None,
only_n_most_recent_images: int | None = 5,
callbacks: list[Any] | None = None,
verbosity: int | None = None,
max_retries: int | None = 3,
screenshot_delay: float | int = 0.5,
use_prompt_caching: bool | None = False,
max_trajectory_budget: float | dict | None = None,
telemetry_enabled: bool | None = True,
) -> list[Any]:
"""Run evaluation across the entire dataset using hud.datasets.run_dataset."""
@@ -135,7 +198,22 @@ async def run_full_dataset(
name=job_name,
dataset=dataset,
agent_class=ProxyOperatorAgent,
agent_config={"model": model, "allowed_tools": allowed_tools, "trajectory_dir": trajectory_dir},
agent_config={
"model": model,
"allowed_tools": allowed_tools,
"trajectory_dir": trajectory_dir,
# === ComputerAgent kwargs passthrough ===
"tools": tools,
"custom_loop": custom_loop,
"only_n_most_recent_images": only_n_most_recent_images,
"callbacks": callbacks,
"verbosity": verbosity,
"max_retries": max_retries,
"screenshot_delay": screenshot_delay,
"use_prompt_caching": use_prompt_caching,
"max_trajectory_budget": max_trajectory_budget,
"telemetry_enabled": telemetry_enabled,
},
max_concurrent=max_concurrent,
metadata={"dataset": dataset_name},
max_steps=max_steps,