improved trajectory saving with run_job

This commit is contained in:
Dillon DuPont
2025-08-08 19:01:08 -04:00
parent 6d42c5d939
commit b23cac9e8b
2 changed files with 46 additions and 26 deletions

View File

@@ -51,12 +51,14 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
within the trajectory gets its own folder with screenshots and responses.
"""
def __init__(self, trajectory_dir: str):
def __init__(self, trajectory_dir: str, reset_on_run: bool = True):
"""
Initialize trajectory saver.
Args:
trajectory_dir: Base directory to save trajectories
reset_on_run: If True, reset trajectory_id/turn/artifact on each run.
If False, continue using existing trajectory_id if set.
"""
self.trajectory_dir = Path(trajectory_dir)
self.trajectory_id: Optional[str] = None
@@ -64,6 +66,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.current_artifact: int = 0
self.model: Optional[str] = None
self.total_usage: Dict[str, Any] = {}
self.reset_on_run = reset_on_run
# Ensure trajectory directory exists
self.trajectory_dir.mkdir(parents=True, exist_ok=True)
@@ -113,32 +116,38 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
    """Initialize (or resume) trajectory tracking when a run starts.

    A fresh trajectory is started when ``self.reset_on_run`` is true or when
    no ``self.trajectory_id`` has been assigned yet. Otherwise the existing
    trajectory is kept and only ``self.model`` is refreshed, so turn/artifact
    counters and accumulated usage carry over between runs.

    Args:
        kwargs: Run keyword arguments. ``kwargs["model"]`` (default
            ``"unknown"``) is used to build the trajectory id, and the whole
            mapping is written into ``metadata.json``, so it must be
            JSON-serializable.
        old_items: Items that exist before the run; not used here.
    """
    model = kwargs.get("model", "unknown")

    # Only reset trajectory state if reset_on_run is True or no trajectory exists
    if self.reset_on_run or not self.trajectory_id:
        # Short model tag: last path component of the part after the final "+",
        # lower-cased and capped at 16 characters.
        model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
        if "+" in model:
            # Composed models ("a+b") get the first 4 characters of "a" as a prefix.
            model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short

        # id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
        now = datetime.now()
        self.trajectory_id = (
            f"{now.strftime('%Y-%m-%d')}_{model_name_short}_"
            f"{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
        )
        self.current_turn = 0
        self.current_artifact = 0
        self.model = model
        self.total_usage = {}

        # Create trajectory directory
        trajectory_path = self.trajectory_dir / self.trajectory_id
        trajectory_path.mkdir(parents=True, exist_ok=True)

        # Save trajectory metadata
        metadata = {
            "trajectory_id": self.trajectory_id,
            # NOTE: this is the 60-bit timestamp field of a version-1 UUID
            # (100-ns intervals since 1582-10-15), not a Unix timestamp.
            # Kept unchanged so existing metadata readers are not affected.
            "created_at": str(uuid.uuid1().time),
            "status": "running",
            "kwargs": kwargs,
        }
        with open(trajectory_path / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
    else:
        # Continue with existing trajectory - just update model if needed
        self.model = model
@override
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:

View File

@@ -7,6 +7,7 @@ from hud import run_job as hud_run_job
from .agent import ComputerAgent
from .adapter import ComputerAgentAdapter
from .computer_handler import HUDComputerHandler
from ..callbacks.trajectory_saver import TrajectorySaverCallback
async def run_job(
@@ -48,11 +49,21 @@ async def run_job(
Returns:
Job instance from HUD
"""
# Handle trajectory_dir by adding TrajectorySaverCallback
trajectory_dir = agent_kwargs.pop("trajectory_dir", None)
callbacks = agent_kwargs.get("callbacks", [])
if trajectory_dir:
trajectory_callback = TrajectorySaverCallback(trajectory_dir, reset_on_run=False)
callbacks = callbacks + [trajectory_callback]
agent_kwargs["callbacks"] = callbacks
# combine verbose and verbosity kwargs
if "verbose" in agent_kwargs:
agent_kwargs["verbosity"] = logging.INFO
del agent_kwargs["verbose"]
verbose = True if agent_kwargs.get("verbosity", logging.WARNING) > logging.INFO else False
# run job
return await hud_run_job(
agent_cls=ComputerAgent,