Merge pull request #360 from trycua/fix/tool-errors

[Agent] Fix tool error propagation
This commit is contained in:
ddupont
2025-08-19 15:54:31 -04:00
committed by GitHub
3 changed files with 168 additions and 88 deletions

View File

@@ -7,7 +7,13 @@ from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Calla
from litellm.responses.utils import Usage
from .types import Messages, AgentCapability
from .types import (
Messages,
AgentCapability,
ToolError,
IllegalArgumentError
)
from .responses import make_tool_error_item, replace_failed_computer_calls_with_function_calls
from .decorators import find_agent_config
import json
import litellm
@@ -30,6 +36,15 @@ from .computers import (
make_computer_handler
)
def assert_callable_with(f, *args, **kwargs):
    """Verify that ``f`` accepts the supplied positional and keyword arguments.

    The arguments are bound against the signature of ``f``; ``f`` itself is
    never invoked.

    Args:
        f: The callable whose signature is checked.
        *args: Positional arguments intended for ``f``.
        **kwargs: Keyword arguments intended for ``f``.

    Returns:
        True when the arguments bind to the signature.

    Raises:
        IllegalArgumentError: If binding fails with a ``TypeError``; the
            message shows the expected signature and the received arguments.
    """
    try:
        inspect.signature(f).bind(*args, **kwargs)
    except TypeError as bind_error:
        # Re-read the signature so the error message can show what was expected.
        expected = inspect.signature(f)
        raise IllegalArgumentError(
            f"Expected {expected}, got args={args} kwargs={kwargs}"
        ) from bind_error
    return True
def get_json(obj: Any, max_depth: int = 10) -> Any:
def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
if seen is None:
@@ -405,7 +420,8 @@ class ComputerAgent:
async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
"""Handle each item; may cause a computer action + screenshot."""
if ignore_call_ids and item.get("call_id") and item.get("call_id") in ignore_call_ids:
call_id = item.get("call_id")
if ignore_call_ids and call_id and call_id in ignore_call_ids:
return []
item_type = item.get("type", None)
@@ -419,96 +435,103 @@ class ComputerAgent:
# print(content_item.get("text"))
return []
if item_type == "computer_call":
await self._on_computer_call_start(item)
if not computer:
raise ValueError("Computer handler is required for computer calls")
try:
if item_type == "computer_call":
await self._on_computer_call_start(item)
if not computer:
raise ValueError("Computer handler is required for computer calls")
# Perform computer actions
action = item.get("action")
action_type = action.get("type")
if action_type is None:
print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
return []
# Perform computer actions
action = item.get("action")
action_type = action.get("type")
if action_type is None:
print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
return []
# Extract action arguments (all fields except 'type')
action_args = {k: v for k, v in action.items() if k != "type"}
# print(f"{action_type}({action_args})")
# Execute the computer action
computer_method = getattr(computer, action_type, None)
if computer_method:
assert_callable_with(computer_method, **action_args)
await computer_method(**action_args)
else:
print(f"Unknown computer action: {action_type}")
return []
# Take screenshot after action
if self.screenshot_delay and self.screenshot_delay > 0:
await asyncio.sleep(self.screenshot_delay)
screenshot_base64 = await computer.screenshot()
await self._on_screenshot(screenshot_base64, "screenshot_after")
# Handle safety checks
pending_checks = item.get("pending_safety_checks", [])
acknowledged_checks = []
for check in pending_checks:
check_message = check.get("message", str(check))
acknowledged_checks.append(check)
# TODO: implement a callback for safety checks
# if acknowledge_safety_check_callback(check_message, allow_always=True):
# acknowledged_checks.append(check)
# else:
# raise ValueError(f"Safety check failed: {check_message}")
# Create call output
call_output = {
"type": "computer_call_output",
"call_id": item.get("call_id"),
"acknowledged_safety_checks": acknowledged_checks,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
}
# # Additional URL safety checks for browser environments
# if await computer.get_environment() == "browser":
# current_url = await computer.get_current_url()
# call_output["output"]["current_url"] = current_url
# # TODO: implement a callback for URL safety checks
# # check_blocklisted_url(current_url)
result = [call_output]
await self._on_computer_call_end(item, result)
return result
# Extract action arguments (all fields except 'type')
action_args = {k: v for k, v in action.items() if k != "type"}
if item_type == "function_call":
await self._on_function_call_start(item)
# Perform function call
function = self._get_tool(item.get("name"))
if not function:
raise ValueError(f"Function {item.get("name")} not found")
# print(f"{action_type}({action_args})")
# Execute the computer action
computer_method = getattr(computer, action_type, None)
if computer_method:
await computer_method(**action_args)
else:
print(f"Unknown computer action: {action_type}")
return []
# Take screenshot after action
if self.screenshot_delay and self.screenshot_delay > 0:
await asyncio.sleep(self.screenshot_delay)
screenshot_base64 = await computer.screenshot()
await self._on_screenshot(screenshot_base64, "screenshot_after")
# Handle safety checks
pending_checks = item.get("pending_safety_checks", [])
acknowledged_checks = []
for check in pending_checks:
check_message = check.get("message", str(check))
acknowledged_checks.append(check)
# TODO: implement a callback for safety checks
# if acknowledge_safety_check_callback(check_message, allow_always=True):
# acknowledged_checks.append(check)
# else:
# raise ValueError(f"Safety check failed: {check_message}")
# Create call output
call_output = {
"type": "computer_call_output",
"call_id": item.get("call_id"),
"acknowledged_safety_checks": acknowledged_checks,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
}
# # Additional URL safety checks for browser environments
# if await computer.get_environment() == "browser":
# current_url = await computer.get_current_url()
# call_output["output"]["current_url"] = current_url
# # TODO: implement a callback for URL safety checks
# # check_blocklisted_url(current_url)
result = [call_output]
await self._on_computer_call_end(item, result)
return result
if item_type == "function_call":
await self._on_function_call_start(item)
# Perform function call
function = self._get_tool(item.get("name"))
if not function:
raise ValueError(f"Function {item.get("name")} not found")
args = json.loads(item.get("arguments"))
args = json.loads(item.get("arguments"))
# Execute function - use asyncio.to_thread for non-async functions
if inspect.iscoroutinefunction(function):
result = await function(**args)
else:
result = await asyncio.to_thread(function, **args)
# Create function call output
call_output = {
"type": "function_call_output",
"call_id": item.get("call_id"),
"output": str(result),
}
result = [call_output]
await self._on_function_call_end(item, result)
return result
# Validate arguments before execution
assert_callable_with(function, **args)
# Execute function - use asyncio.to_thread for non-async functions
if inspect.iscoroutinefunction(function):
result = await function(**args)
else:
result = await asyncio.to_thread(function, **args)
# Create function call output
call_output = {
"type": "function_call_output",
"call_id": item.get("call_id"),
"output": str(result),
}
result = [call_output]
await self._on_function_call_end(item, result)
return result
except ToolError as e:
return [make_tool_error_item(repr(e), call_id)]
return []
@@ -569,6 +592,7 @@ class ComputerAgent:
# - PII anonymization
# - Image retention policy
combined_messages = old_items + new_items
combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
preprocessed_messages = await self._on_llm_start(combined_messages)
loop_kwargs = {

View File

@@ -252,6 +252,53 @@ def make_failed_tool_call_items(tool_name: str, tool_kwargs: Dict[str, Any], err
}
]
def make_tool_error_item(error_message: str, call_id: Optional[str] = None) -> Dict[str, Any]:
    """Build a ``function_call_output`` item that reports a tool error.

    Args:
        error_message: Text describing the failure; it is stored under the
            ``"error"`` key of the JSON-encoded ``output`` field.
        call_id: Identifier of the call this output belongs to. When missing
            or empty, a new one is obtained from ``random_id()``.

    Returns:
        A dict with ``type``, ``call_id`` and ``output`` keys.
    """
    # Resolve the identifier first so a falsy call_id is replaced before use.
    if not call_id:
        call_id = random_id()

    error_payload = {"error": error_message}
    return {
        "type": "function_call_output",
        "call_id": call_id,
        "output": json.dumps(error_payload),
    }
def replace_failed_computer_calls_with_function_calls(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Replace computer_call items with function_call items if they share a call_id with a function_call_output.

    This indicates the computer call failed and should be treated as a function call instead.
    We do this because the computer_call_output items do not support text output.

    Args:
        messages: List of message items to process

    Returns:
        A new list; the input list and its items are not modified. Items that
        are not replaced are the same dict objects as in the input.
    """
    # Collect every non-empty call_id that has a function_call_output item.
    failed_call_ids = {
        msg.get("call_id")
        for msg in messages
        if msg.get("type") == "function_call_output" and msg.get("call_id")
    }

    result: List[Dict[str, Any]] = []
    for msg in messages:
        is_failed_computer_call = (
            msg.get("type") == "computer_call"
            and msg.get("call_id") in failed_call_ids
        )
        if not is_failed_computer_call:
            result.append(msg)
            continue

        # Only generate a fresh id when the item has none; a dict.get default
        # would call random_id() even when the existing id is kept.
        item_id = msg["id"] if "id" in msg else random_id()
        result.append({
            "type": "function_call",
            "id": item_id,
            "call_id": msg.get("call_id"),
            "name": "computer",
            "arguments": json.dumps(msg.get("action", {})),
        })
    return result
# Conversion functions between element descriptions and coordinates
def convert_computer_calls_desc2xy(responses_items: List[Dict[str, Any]], desc2xy: Dict[str, tuple]) -> List[Dict[str, Any]]:
"""

View File

@@ -16,6 +16,15 @@ Tools = Optional[Iterable[ToolParam]]
AgentResponse = ResponsesAPIResponse
AgentCapability = Literal["step", "click"]
# Exception types
class ToolError(RuntimeError):
    """Root of the exception hierarchy for failures raised by tools."""
class IllegalArgumentError(ToolError):
    """Tool error signalling that a function was given invalid arguments."""
# Agent config registration
class AgentConfigInfo(BaseModel):