mirror of
https://github.com/trycua/computer.git
synced 2026-01-03 03:49:58 -06:00
Added extended ComputerAgent kwargs passthrough; renamed OperatorValidatorCallback to OperatorNormalizerCallback
This commit is contained in:
@@ -29,7 +29,7 @@ from .callbacks import (
|
||||
TrajectorySaverCallback,
|
||||
BudgetManagerCallback,
|
||||
TelemetryCallback,
|
||||
OperatorValidatorCallback
|
||||
OperatorNormalizerCallback
|
||||
)
|
||||
from .computers import (
|
||||
AsyncComputerHandler,
|
||||
@@ -202,8 +202,8 @@ class ComputerAgent:
|
||||
|
||||
# == Add built-in callbacks ==
|
||||
|
||||
# Prepend operator validator callback
|
||||
self.callbacks.insert(0, OperatorValidatorCallback())
|
||||
# Prepend operator normalizer callback
|
||||
self.callbacks.insert(0, OperatorNormalizerCallback())
|
||||
|
||||
# Add telemetry callback if telemetry_enabled is set
|
||||
if self.telemetry_enabled:
|
||||
|
||||
@@ -8,7 +8,7 @@ from .logging import LoggingCallback
|
||||
from .trajectory_saver import TrajectorySaverCallback
|
||||
from .budget_manager import BudgetManagerCallback
|
||||
from .telemetry import TelemetryCallback
|
||||
from .operator_validator import OperatorValidatorCallback
|
||||
from .operator_validator import OperatorNormalizerCallback
|
||||
|
||||
__all__ = [
|
||||
"AsyncCallbackHandler",
|
||||
@@ -17,5 +17,5 @@ __all__ = [
|
||||
"TrajectorySaverCallback",
|
||||
"BudgetManagerCallback",
|
||||
"TelemetryCallback",
|
||||
"OperatorValidatorCallback",
|
||||
"OperatorNormalizerCallback",
|
||||
]
|
||||
|
||||
@@ -4,6 +4,7 @@ OperatorValidatorCallback
|
||||
Ensures agent output actions conform to expected schemas by fixing common issues:
|
||||
- click: add default button='left' if missing
|
||||
- keypress: wrap keys string into a list
|
||||
- etc.
|
||||
|
||||
This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
|
||||
"""
|
||||
@@ -14,14 +15,12 @@ from typing import Any, Dict, List
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
|
||||
class OperatorValidatorCallback(AsyncCallbackHandler):
|
||||
"""Validates and normalizes operator/computer actions in LLM outputs."""
|
||||
class OperatorNormalizerCallback(AsyncCallbackHandler):
|
||||
"""Normalizes common computer call hallucinations / errors in computer call syntax."""
|
||||
|
||||
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# Mutate in-place as requested, but still return the list for chaining
|
||||
for item in output or []:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") != "computer_call":
|
||||
continue
|
||||
action = item.get("action")
|
||||
@@ -56,8 +55,6 @@ class OperatorValidatorCallback(AsyncCallbackHandler):
|
||||
# replace the assistant message itself with a reasoning message with summary text.
|
||||
if isinstance(output, list):
|
||||
for i, item in enumerate(output):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
# AssistantMessage shape: { type: 'message', role: 'assistant', content: OutputContent[] }
|
||||
if item.get("type") == "message" and item.get("role") == "assistant":
|
||||
next_idx = i + 1
|
||||
|
||||
@@ -42,6 +42,17 @@ class ProxyOperatorAgent(OperatorAgent):
|
||||
model: str | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
trajectory_dir: str | None = None,
|
||||
# === ComputerAgent kwargs ===
|
||||
tools: list[Any] | None = None,
|
||||
custom_loop: Any | None = None,
|
||||
only_n_most_recent_images: int | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
verbosity: int | None = None,
|
||||
max_retries: int | None = 3,
|
||||
screenshot_delay: float | int = 0.5,
|
||||
use_prompt_caching: bool | None = False,
|
||||
max_trajectory_budget: float | dict | None = None,
|
||||
telemetry_enabled: bool | None = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
model = model or "computer-use-preview"
|
||||
@@ -52,10 +63,24 @@ class ProxyOperatorAgent(OperatorAgent):
|
||||
'environment': 'linux',
|
||||
'dimensions': (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)
|
||||
}
|
||||
# Build tools ensuring the computer_shim is included
|
||||
agent_tools: list[Any] = [computer_shim]
|
||||
if tools:
|
||||
agent_tools.extend(tools)
|
||||
|
||||
computer_agent = BaseComputerAgent(
|
||||
model=model,
|
||||
tools=[computer_shim],
|
||||
trajectory_dir=trajectory_dir
|
||||
model=model,
|
||||
tools=agent_tools,
|
||||
custom_loop=custom_loop,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
callbacks=callbacks,
|
||||
verbosity=verbosity,
|
||||
trajectory_dir=trajectory_dir,
|
||||
max_retries=max_retries,
|
||||
screenshot_delay=screenshot_delay,
|
||||
use_prompt_caching=use_prompt_caching,
|
||||
max_trajectory_budget=max_trajectory_budget,
|
||||
telemetry_enabled=telemetry_enabled,
|
||||
)
|
||||
model_client = FakeAsyncOpenAI(computer_agent)
|
||||
|
||||
@@ -78,6 +103,18 @@ async def run_single_task(
|
||||
task_id: int = 0,
|
||||
model: str | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
# === ComputerAgent kwargs ===
|
||||
tools: list[Any] | None = None,
|
||||
custom_loop: Any | None = None,
|
||||
only_n_most_recent_images: int | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
verbosity: int | None = None,
|
||||
trajectory_dir: str | None = None,
|
||||
max_retries: int | None = 3,
|
||||
screenshot_delay: float | int = 0.5,
|
||||
use_prompt_caching: bool | None = False,
|
||||
max_trajectory_budget: float | dict | None = None,
|
||||
telemetry_enabled: bool | None = True,
|
||||
) -> None:
|
||||
"""Load one task from the dataset and execute it with Operator+CUA proxy."""
|
||||
|
||||
@@ -95,7 +132,22 @@ async def run_single_task(
|
||||
with trace(name=task_prompt):
|
||||
task = Task(**sample_task) # type: ignore[arg-type]
|
||||
|
||||
agent = ProxyOperatorAgent(model=model, allowed_tools=allowed_tools)
|
||||
agent = ProxyOperatorAgent(
|
||||
model=model,
|
||||
allowed_tools=allowed_tools,
|
||||
# === ComputerAgent kwargs passthrough ===
|
||||
tools=tools,
|
||||
custom_loop=custom_loop,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
callbacks=callbacks,
|
||||
verbosity=verbosity,
|
||||
trajectory_dir=trajectory_dir,
|
||||
max_retries=max_retries,
|
||||
screenshot_delay=screenshot_delay,
|
||||
use_prompt_caching=use_prompt_caching,
|
||||
max_trajectory_budget=max_trajectory_budget,
|
||||
telemetry_enabled=telemetry_enabled,
|
||||
)
|
||||
print(f"Running: {task_prompt}")
|
||||
result = await agent.run(task, max_steps=10)
|
||||
print(f"✅ Reward: {getattr(result, 'reward')}")
|
||||
@@ -116,6 +168,17 @@ async def run_full_dataset(
|
||||
max_steps: int = 50,
|
||||
split: str = "train",
|
||||
trajectory_dir: str | None = None,
|
||||
# === ComputerAgent kwargs ===
|
||||
tools: list[Any] | None = None,
|
||||
custom_loop: Any | None = None,
|
||||
only_n_most_recent_images: int | None = 5,
|
||||
callbacks: list[Any] | None = None,
|
||||
verbosity: int | None = None,
|
||||
max_retries: int | None = 3,
|
||||
screenshot_delay: float | int = 0.5,
|
||||
use_prompt_caching: bool | None = False,
|
||||
max_trajectory_budget: float | dict | None = None,
|
||||
telemetry_enabled: bool | None = True,
|
||||
) -> list[Any]:
|
||||
"""Run evaluation across the entire dataset using hud.datasets.run_dataset."""
|
||||
|
||||
@@ -135,7 +198,22 @@ async def run_full_dataset(
|
||||
name=job_name,
|
||||
dataset=dataset,
|
||||
agent_class=ProxyOperatorAgent,
|
||||
agent_config={"model": model, "allowed_tools": allowed_tools, "trajectory_dir": trajectory_dir},
|
||||
agent_config={
|
||||
"model": model,
|
||||
"allowed_tools": allowed_tools,
|
||||
"trajectory_dir": trajectory_dir,
|
||||
# === ComputerAgent kwargs passthrough ===
|
||||
"tools": tools,
|
||||
"custom_loop": custom_loop,
|
||||
"only_n_most_recent_images": only_n_most_recent_images,
|
||||
"callbacks": callbacks,
|
||||
"verbosity": verbosity,
|
||||
"max_retries": max_retries,
|
||||
"screenshot_delay": screenshot_delay,
|
||||
"use_prompt_caching": use_prompt_caching,
|
||||
"max_trajectory_budget": max_trajectory_budget,
|
||||
"telemetry_enabled": telemetry_enabled,
|
||||
},
|
||||
max_concurrent=max_concurrent,
|
||||
metadata={"dataset": dataset_name},
|
||||
max_steps=max_steps,
|
||||
|
||||
Reference in New Issue
Block a user