mirror of
https://github.com/trycua/computer.git
synced 2026-01-03 12:00:00 -06:00
Merge pull request #360 from trycua/fix/tool-errors
[Agent] Fix tool error propagation
This commit is contained in:
@@ -7,7 +7,13 @@ from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Calla
|
||||
|
||||
from litellm.responses.utils import Usage
|
||||
|
||||
from .types import Messages, AgentCapability
|
||||
from .types import (
|
||||
Messages,
|
||||
AgentCapability,
|
||||
ToolError,
|
||||
IllegalArgumentError
|
||||
)
|
||||
from .responses import make_tool_error_item, replace_failed_computer_calls_with_function_calls
|
||||
from .decorators import find_agent_config
|
||||
import json
|
||||
import litellm
|
||||
@@ -30,6 +36,15 @@ from .computers import (
|
||||
make_computer_handler
|
||||
)
|
||||
|
||||
def assert_callable_with(f, *args, **kwargs):
    """Verify that ``f`` accepts the given positional and keyword arguments.

    The arguments are bound against the signature of ``f``; ``f`` itself is
    never invoked.

    Args:
        f: The callable whose signature is inspected.
        *args: Positional arguments that would be passed to ``f``.
        **kwargs: Keyword arguments that would be passed to ``f``.

    Returns:
        True when the arguments bind to the signature.

    Raises:
        IllegalArgumentError: When a TypeError is raised while binding; the
            message lists the expected signature and the received arguments.
    """
    try:
        inspect.signature(f).bind(*args, **kwargs)
    except TypeError as bind_error:
        # Re-read the signature so the message shows what was expected.
        expected = inspect.signature(f)
        message = f"Expected {expected}, got args={args} kwargs={kwargs}"
        raise IllegalArgumentError(message) from bind_error
    else:
        return True
|
||||
|
||||
def get_json(obj: Any, max_depth: int = 10) -> Any:
|
||||
def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
|
||||
if seen is None:
|
||||
@@ -405,7 +420,8 @@ class ComputerAgent:
|
||||
|
||||
async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
||||
"""Handle each item; may cause a computer action + screenshot."""
|
||||
if ignore_call_ids and item.get("call_id") and item.get("call_id") in ignore_call_ids:
|
||||
call_id = item.get("call_id")
|
||||
if ignore_call_ids and call_id and call_id in ignore_call_ids:
|
||||
return []
|
||||
|
||||
item_type = item.get("type", None)
|
||||
@@ -419,96 +435,103 @@ class ComputerAgent:
|
||||
# print(content_item.get("text"))
|
||||
return []
|
||||
|
||||
if item_type == "computer_call":
|
||||
await self._on_computer_call_start(item)
|
||||
if not computer:
|
||||
raise ValueError("Computer handler is required for computer calls")
|
||||
try:
|
||||
if item_type == "computer_call":
|
||||
await self._on_computer_call_start(item)
|
||||
if not computer:
|
||||
raise ValueError("Computer handler is required for computer calls")
|
||||
|
||||
# Perform computer actions
|
||||
action = item.get("action")
|
||||
action_type = action.get("type")
|
||||
if action_type is None:
|
||||
print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
|
||||
return []
|
||||
# Perform computer actions
|
||||
action = item.get("action")
|
||||
action_type = action.get("type")
|
||||
if action_type is None:
|
||||
print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
|
||||
return []
|
||||
|
||||
# Extract action arguments (all fields except 'type')
|
||||
action_args = {k: v for k, v in action.items() if k != "type"}
|
||||
|
||||
# print(f"{action_type}({action_args})")
|
||||
|
||||
# Execute the computer action
|
||||
computer_method = getattr(computer, action_type, None)
|
||||
if computer_method:
|
||||
assert_callable_with(computer_method, **action_args)
|
||||
await computer_method(**action_args)
|
||||
else:
|
||||
print(f"Unknown computer action: {action_type}")
|
||||
return []
|
||||
|
||||
# Take screenshot after action
|
||||
if self.screenshot_delay and self.screenshot_delay > 0:
|
||||
await asyncio.sleep(self.screenshot_delay)
|
||||
screenshot_base64 = await computer.screenshot()
|
||||
await self._on_screenshot(screenshot_base64, "screenshot_after")
|
||||
|
||||
# Handle safety checks
|
||||
pending_checks = item.get("pending_safety_checks", [])
|
||||
acknowledged_checks = []
|
||||
for check in pending_checks:
|
||||
check_message = check.get("message", str(check))
|
||||
acknowledged_checks.append(check)
|
||||
# TODO: implement a callback for safety checks
|
||||
# if acknowledge_safety_check_callback(check_message, allow_always=True):
|
||||
# acknowledged_checks.append(check)
|
||||
# else:
|
||||
# raise ValueError(f"Safety check failed: {check_message}")
|
||||
|
||||
# Create call output
|
||||
call_output = {
|
||||
"type": "computer_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"acknowledged_safety_checks": acknowledged_checks,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_base64}",
|
||||
},
|
||||
}
|
||||
|
||||
# # Additional URL safety checks for browser environments
|
||||
# if await computer.get_environment() == "browser":
|
||||
# current_url = await computer.get_current_url()
|
||||
# call_output["output"]["current_url"] = current_url
|
||||
# # TODO: implement a callback for URL safety checks
|
||||
# # check_blocklisted_url(current_url)
|
||||
|
||||
result = [call_output]
|
||||
await self._on_computer_call_end(item, result)
|
||||
return result
|
||||
|
||||
# Extract action arguments (all fields except 'type')
|
||||
action_args = {k: v for k, v in action.items() if k != "type"}
|
||||
if item_type == "function_call":
|
||||
await self._on_function_call_start(item)
|
||||
# Perform function call
|
||||
function = self._get_tool(item.get("name"))
|
||||
if not function:
|
||||
raise ValueError(f"Function {item.get("name")} not found")
|
||||
|
||||
# print(f"{action_type}({action_args})")
|
||||
|
||||
# Execute the computer action
|
||||
computer_method = getattr(computer, action_type, None)
|
||||
if computer_method:
|
||||
await computer_method(**action_args)
|
||||
else:
|
||||
print(f"Unknown computer action: {action_type}")
|
||||
return []
|
||||
|
||||
# Take screenshot after action
|
||||
if self.screenshot_delay and self.screenshot_delay > 0:
|
||||
await asyncio.sleep(self.screenshot_delay)
|
||||
screenshot_base64 = await computer.screenshot()
|
||||
await self._on_screenshot(screenshot_base64, "screenshot_after")
|
||||
|
||||
# Handle safety checks
|
||||
pending_checks = item.get("pending_safety_checks", [])
|
||||
acknowledged_checks = []
|
||||
for check in pending_checks:
|
||||
check_message = check.get("message", str(check))
|
||||
acknowledged_checks.append(check)
|
||||
# TODO: implement a callback for safety checks
|
||||
# if acknowledge_safety_check_callback(check_message, allow_always=True):
|
||||
# acknowledged_checks.append(check)
|
||||
# else:
|
||||
# raise ValueError(f"Safety check failed: {check_message}")
|
||||
|
||||
# Create call output
|
||||
call_output = {
|
||||
"type": "computer_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"acknowledged_safety_checks": acknowledged_checks,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_base64}",
|
||||
},
|
||||
}
|
||||
|
||||
# # Additional URL safety checks for browser environments
|
||||
# if await computer.get_environment() == "browser":
|
||||
# current_url = await computer.get_current_url()
|
||||
# call_output["output"]["current_url"] = current_url
|
||||
# # TODO: implement a callback for URL safety checks
|
||||
# # check_blocklisted_url(current_url)
|
||||
|
||||
result = [call_output]
|
||||
await self._on_computer_call_end(item, result)
|
||||
return result
|
||||
|
||||
if item_type == "function_call":
|
||||
await self._on_function_call_start(item)
|
||||
# Perform function call
|
||||
function = self._get_tool(item.get("name"))
|
||||
if not function:
|
||||
raise ValueError(f"Function {item.get("name")} not found")
|
||||
|
||||
args = json.loads(item.get("arguments"))
|
||||
args = json.loads(item.get("arguments"))
|
||||
|
||||
# Execute function - use asyncio.to_thread for non-async functions
|
||||
if inspect.iscoroutinefunction(function):
|
||||
result = await function(**args)
|
||||
else:
|
||||
result = await asyncio.to_thread(function, **args)
|
||||
|
||||
# Create function call output
|
||||
call_output = {
|
||||
"type": "function_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"output": str(result),
|
||||
}
|
||||
|
||||
result = [call_output]
|
||||
await self._on_function_call_end(item, result)
|
||||
return result
|
||||
# Validate arguments before execution
|
||||
assert_callable_with(function, **args)
|
||||
|
||||
# Execute function - use asyncio.to_thread for non-async functions
|
||||
if inspect.iscoroutinefunction(function):
|
||||
result = await function(**args)
|
||||
else:
|
||||
result = await asyncio.to_thread(function, **args)
|
||||
|
||||
# Create function call output
|
||||
call_output = {
|
||||
"type": "function_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"output": str(result),
|
||||
}
|
||||
|
||||
result = [call_output]
|
||||
await self._on_function_call_end(item, result)
|
||||
return result
|
||||
except ToolError as e:
|
||||
return [make_tool_error_item(repr(e), call_id)]
|
||||
|
||||
return []
|
||||
|
||||
@@ -569,6 +592,7 @@ class ComputerAgent:
|
||||
# - PII anonymization
|
||||
# - Image retention policy
|
||||
combined_messages = old_items + new_items
|
||||
combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
|
||||
preprocessed_messages = await self._on_llm_start(combined_messages)
|
||||
|
||||
loop_kwargs = {
|
||||
|
||||
@@ -252,6 +252,53 @@ def make_failed_tool_call_items(tool_name: str, tool_kwargs: Dict[str, Any], err
|
||||
}
|
||||
]
|
||||
|
||||
def make_tool_error_item(error_message: str, call_id: Optional[str] = None) -> Dict[str, Any]:
    """Build a ``function_call_output`` item that reports a tool error.

    Args:
        error_message: Text stored under the ``"error"`` key of the
            JSON-encoded output.
        call_id: Identifier of the call the error belongs to. When falsy, a
            new identifier is obtained from ``random_id()``.

    Returns:
        A dict with the keys ``type``, ``call_id`` and ``output``.
    """
    resolved_call_id = call_id or random_id()
    error_payload = json.dumps({"error": error_message})
    return {
        "type": "function_call_output",
        "call_id": resolved_call_id,
        "output": error_payload,
    }
|
||||
|
||||
def replace_failed_computer_calls_with_function_calls(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Replace computer_call items with function_call items if they share a call_id with a function_call_output.
    This indicates the computer call failed and should be treated as a function call instead.
    We do this because the computer_call_output items do not support text output.

    Args:
        messages: List of message items to process

    Returns:
        A new list. The input list is left unmodified, and items that are not
        replaced are the same objects as in the input.
    """
    # Collect every truthy call_id that was answered by a function_call_output.
    failed_call_ids = {
        msg.get("call_id")
        for msg in messages
        if msg.get("type") == "function_call_output" and msg.get("call_id")
    }

    converted: List[Dict[str, Any]] = []
    for msg in messages:
        if msg.get("type") == "computer_call" and msg.get("call_id") in failed_call_ids:
            msg = {
                "type": "function_call",
                # Generate a fresh id only when the item carries no "id" key;
                # an existing value (even a falsy one) is kept as-is.
                "id": msg["id"] if "id" in msg else random_id(),
                "call_id": msg.get("call_id"),
                "name": "computer",
                # The computer action becomes the JSON arguments of the call.
                "arguments": json.dumps(msg.get("action", {})),
            }
        converted.append(msg)

    return converted
|
||||
|
||||
# Conversion functions between element descriptions and coordinates
|
||||
def convert_computer_calls_desc2xy(responses_items: List[Dict[str, Any]], desc2xy: Dict[str, tuple]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
|
||||
@@ -16,6 +16,15 @@ Tools = Optional[Iterable[ToolParam]]
|
||||
AgentResponse = ResponsesAPIResponse
|
||||
AgentCapability = Literal["step", "click"]
|
||||
|
||||
# Exception types
|
||||
class ToolError(RuntimeError):
    """Root of the exception hierarchy for tool-related failures."""
|
||||
|
||||
class IllegalArgumentError(ToolError):
    """Raised when the arguments supplied to a function are invalid."""
|
||||
|
||||
|
||||
# Agent config registration
|
||||
class AgentConfigInfo(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user