mirror of
https://github.com/trycua/computer.git
synced 2026-01-04 12:30:08 -06:00
improved trajectory saving with run_job
This commit is contained in:
@@ -51,12 +51,14 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
within the trajectory gets its own folder with screenshots and responses.
|
||||
"""
|
||||
|
||||
def __init__(self, trajectory_dir: str):
|
||||
def __init__(self, trajectory_dir: str, reset_on_run: bool = True):
|
||||
"""
|
||||
Initialize trajectory saver.
|
||||
|
||||
Args:
|
||||
trajectory_dir: Base directory to save trajectories
|
||||
reset_on_run: If True, reset trajectory_id/turn/artifact on each run.
|
||||
If False, continue using existing trajectory_id if set.
|
||||
"""
|
||||
self.trajectory_dir = Path(trajectory_dir)
|
||||
self.trajectory_id: Optional[str] = None
|
||||
@@ -64,6 +66,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
self.current_artifact: int = 0
|
||||
self.model: Optional[str] = None
|
||||
self.total_usage: Dict[str, Any] = {}
|
||||
self.reset_on_run = reset_on_run
|
||||
|
||||
# Ensure trajectory directory exists
|
||||
self.trajectory_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -113,32 +116,38 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Initialize trajectory tracking for a new run."""
|
||||
model = kwargs.get("model", "unknown")
|
||||
model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
|
||||
if "+" in model:
|
||||
model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
|
||||
|
||||
# Only reset trajectory state if reset_on_run is True or no trajectory exists
|
||||
if self.reset_on_run or not self.trajectory_id:
|
||||
model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
|
||||
if "+" in model:
|
||||
model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
|
||||
|
||||
# id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
|
||||
now = datetime.now()
|
||||
self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
|
||||
self.current_turn = 0
|
||||
self.current_artifact = 0
|
||||
self.model = model
|
||||
self.total_usage = {}
|
||||
|
||||
# Create trajectory directory
|
||||
trajectory_path = self.trajectory_dir / self.trajectory_id
|
||||
trajectory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save trajectory metadata
|
||||
metadata = {
|
||||
"trajectory_id": self.trajectory_id,
|
||||
"created_at": str(uuid.uuid1().time),
|
||||
"status": "running",
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
with open(trajectory_path / "metadata.json", "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
# id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
|
||||
now = datetime.now()
|
||||
self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
|
||||
self.current_turn = 0
|
||||
self.current_artifact = 0
|
||||
self.model = model
|
||||
self.total_usage = {}
|
||||
|
||||
# Create trajectory directory
|
||||
trajectory_path = self.trajectory_dir / self.trajectory_id
|
||||
trajectory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save trajectory metadata
|
||||
metadata = {
|
||||
"trajectory_id": self.trajectory_id,
|
||||
"created_at": str(uuid.uuid1().time),
|
||||
"status": "running",
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
with open(trajectory_path / "metadata.json", "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
else:
|
||||
# Continue with existing trajectory - just update model if needed
|
||||
self.model = model
|
||||
|
||||
@override
|
||||
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
|
||||
|
||||
@@ -7,6 +7,7 @@ from hud import run_job as hud_run_job
|
||||
from .agent import ComputerAgent
|
||||
from .adapter import ComputerAgentAdapter
|
||||
from .computer_handler import HUDComputerHandler
|
||||
from ..callbacks.trajectory_saver import TrajectorySaverCallback
|
||||
|
||||
|
||||
async def run_job(
|
||||
@@ -48,11 +49,21 @@ async def run_job(
|
||||
Returns:
|
||||
Job instance from HUD
|
||||
"""
|
||||
# Handle trajectory_dir by adding TrajectorySaverCallback
|
||||
trajectory_dir = agent_kwargs.pop("trajectory_dir", None)
|
||||
callbacks = agent_kwargs.get("callbacks", [])
|
||||
|
||||
if trajectory_dir:
|
||||
trajectory_callback = TrajectorySaverCallback(trajectory_dir, reset_on_run=False)
|
||||
callbacks = callbacks + [trajectory_callback]
|
||||
agent_kwargs["callbacks"] = callbacks
|
||||
|
||||
# combine verbose and verbosity kwargs
|
||||
if "verbose" in agent_kwargs:
|
||||
agent_kwargs["verbosity"] = logging.INFO
|
||||
del agent_kwargs["verbose"]
|
||||
verbose = True if agent_kwargs.get("verbosity", logging.WARNING) > logging.INFO else False
|
||||
|
||||
# run job
|
||||
return await hud_run_job(
|
||||
agent_cls=ComputerAgent,
|
||||
|
||||
Reference in New Issue
Block a user