mirror of
https://github.com/trycua/computer.git
synced 2026-02-19 12:59:34 -06:00
Normalize common LLM output errors
This commit is contained in:
90
libs/python/agent/agent/callbacks/operator_validator.py
Normal file
90
libs/python/agent/agent/callbacks/operator_validator.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
OperatorValidatorCallback
|
||||
|
||||
Ensures agent output actions conform to expected schemas by fixing common issues:
|
||||
- click: add default button='left' if missing
|
||||
- keypress: split a keys string such as 'ctrl+c' or 'ctrl-c' into a list of keys
|
||||
|
||||
This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
|
||||
class OperatorValidatorCallback(AsyncCallbackHandler):
    """Validates and normalizes operator/computer actions in LLM outputs.

    Two passes are applied to the output list in ``on_llm_end``:

    1. Every ``computer_call`` item whose ``action`` is a dict is normalized
       in place (button shorthand, default click button, ``coordinate`` ->
       ``x``/``y``, stray position keys, ``keys`` string -> list).
    2. Every assistant ``message`` item that is immediately followed by a
       ``computer_call`` is replaced by a ``reasoning`` item carrying the
       message's ``output_text`` parts as its summary.
    """

    # Button names accepted in the "<button>_click" shorthand.
    _MOUSE_BUTTONS = ("left", "right", "wheel", "back", "forward")
    # Action types from which position keys are stripped.
    _POSITIONLESS_ACTIONS = ("type", "keypress", "screenshot", "wait")
    # Keys that describe a screen position.
    _POSITION_KEYS = ("coordinate", "x", "y")

    async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Normalize the LLM output items in place and return the same object.

        Args:
            output: The output items as dicts. Items that are not dicts are
                skipped. A falsy value (e.g. ``None``) is returned unchanged.

        Returns:
            The ``output`` object that was passed in, after in-place edits.
        """
        # First pass: repair the action dict of every computer_call item.
        for item in output or []:
            if not isinstance(item, dict) or item.get("type") != "computer_call":
                continue
            action = item.get("action")
            if isinstance(action, dict):
                self._normalize_action(action)

        # Second pass: an assistant message immediately followed by a
        # computer_call is replaced by a reasoning item with summary text.
        # Only lists support the index assignment this requires.
        if isinstance(output, list):
            # output[1:] is a snapshot; only position i is ever rewritten and
            # a computer_call item is never replaced, so lookahead is stable.
            for i, (item, next_item) in enumerate(zip(output, output[1:])):
                if not isinstance(item, dict) or not isinstance(next_item, dict):
                    continue
                if item.get("type") != "message" or item.get("role") != "assistant":
                    continue
                if next_item.get("type") != "computer_call":
                    continue
                output[i] = {
                    "type": "reasoning",
                    "summary": [
                        {
                            "type": "summary_text",
                            "text": self._extract_output_text(item),
                        }
                    ],
                }

        return output

    @classmethod
    def _normalize_action(cls, action: Dict[str, Any]) -> None:
        """Repair one computer_call ``action`` dict in place."""
        action_type = action.get("type")

        # "<button>_click" shorthand: accepted either as the action type
        # (e.g. {"type": "left_click"}) or as a stray key on the action.
        for mouse_btn in cls._MOUSE_BUTTONS:
            shorthand = f"{mouse_btn}_click"
            if action_type == shorthand or shorthand in action:
                action["type"] = "click"
                action["button"] = mouse_btn
        # Re-read the type so the checks below see the rewritten value.
        action_type = action.get("type")

        # A click without a (non-None) button defaults to the left button.
        if action_type == "click" and action.get("button") is None:
            action["button"] = "left"

        # Convert coordinate=[x, y] into separate x / y keys. A malformed
        # coordinate (not a sequence of at least two items) is left untouched
        # instead of raising.
        coordinate = action.get("coordinate")
        if isinstance(coordinate, (list, tuple)) and len(coordinate) >= 2:
            action["x"], action["y"] = coordinate[0], coordinate[1]
            del action["coordinate"]

        # Strip position keys from actions that do not take a position.
        if action_type in cls._POSITIONLESS_ACTIONS:
            for key in cls._POSITION_KEYS:
                action.pop(key, None)

        # Kept as a separate ``if`` (not ``elif``): "keypress" is also in
        # _POSITIONLESS_ACTIONS, so an elif here would never execute.
        if action_type == "keypress":
            keys = action.get("keys")
            if isinstance(keys, str):
                # "ctrl+c" / "ctrl-c" -> ["ctrl", "c"]; a single character
                # (including a lone "+" or "-") is wrapped as-is.
                if len(keys) > 1:
                    action["keys"] = keys.replace("-", "+").split("+")
                else:
                    action["keys"] = [keys]

    @staticmethod
    def _extract_output_text(message: Dict[str, Any]) -> str:
        """Join the ``output_text`` parts of a message's content with newlines.

        Returns an empty string when ``content`` is missing, is not a list,
        or holds no ``output_text`` parts with a string ``text``.
        """
        contents = message.get("content") or []
        if not isinstance(contents, list):
            return ""
        text_parts = [
            part["text"]
            for part in contents
            if isinstance(part, dict)
            and part.get("type") == "output_text"
            and isinstance(part.get("text"), str)
        ]
        return "\n".join(text_parts).strip()
|
||||
Reference in New Issue
Block a user