# computer/libs/python/agent/agent/agent.py
"""
ComputerAgent - Main agent class that selects and runs agent loops
"""
import asyncio
import inspect
import json
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
import litellm
import litellm.utils
from litellm.responses.utils import Usage
from .adapters import (
HuggingFaceLocalAdapter,
HumanAdapter,
MLXVLMAdapter,
)
from .callbacks import (
BudgetManagerCallback,
ImageRetentionCallback,
LoggingCallback,
OperatorNormalizerCallback,
PromptInstructionsCallback,
TelemetryCallback,
TrajectorySaverCallback,
)
from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
from .decorators import find_agent_config
from .responses import (
make_tool_error_item,
replace_failed_computer_calls_with_function_calls,
)
from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
def assert_callable_with(f, *args, **kwargs):
"""Check if function can be called with given arguments."""
try:
inspect.signature(f).bind(*args, **kwargs)
return True
except TypeError as e:
sig = inspect.signature(f)
raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
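# Illustrative usage (hedged sketch): a well-formed call binds silently, a bad
# one raises IllegalArgumentError before the tool is ever executed.
#
#   def click(x: int, y: int): ...
#   assert_callable_with(click, x=1, y=2)       # ok, returns True
#   assert_callable_with(click, button="left")  # raises IllegalArgumentError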
def get_json(obj: Any, max_depth: int = 10) -> Any:
    """Serialize obj into JSON-compatible data, guarding against cycles and excessive depth and dropping None values."""
def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
if seen is None:
seen = set()
# Use model_dump() if available
if hasattr(o, "model_dump"):
return o.model_dump()
# Check depth limit
if depth > max_depth:
return f"<max_depth_exceeded:{max_depth}>"
# Check for circular references using object id
obj_id = id(o)
if obj_id in seen:
return f"<circular_reference:{type(o).__name__}>"
# Handle Computer objects
if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
return f"<computer:{o.__class__.__name__}>"
# Handle objects with __dict__
if hasattr(o, "__dict__"):
seen.add(obj_id)
try:
result = {}
for k, v in o.__dict__.items():
if v is not None:
# Recursively serialize with updated depth and seen set
serialized_value = custom_serializer(v, depth + 1, seen.copy())
result[k] = serialized_value
return result
finally:
seen.discard(obj_id)
# Handle common types that might contain nested objects
elif isinstance(o, dict):
seen.add(obj_id)
try:
return {
k: custom_serializer(v, depth + 1, seen.copy())
for k, v in o.items()
if v is not None
}
finally:
seen.discard(obj_id)
elif isinstance(o, (list, tuple, set)):
seen.add(obj_id)
try:
return [
custom_serializer(item, depth + 1, seen.copy())
for item in o
if item is not None
]
finally:
seen.discard(obj_id)
# For basic types that json.dumps can handle
elif isinstance(o, (str, int, float, bool)) or o is None:
return o
# Fallback to string representation
else:
return str(o)
def remove_nones(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: remove_nones(v) for k, v in obj.items() if v is not None}
elif isinstance(obj, list):
return [remove_nones(item) for item in obj if item is not None]
return obj
# Serialize with circular reference and depth protection
serialized = custom_serializer(obj)
# Convert to JSON string and back to ensure JSON compatibility
json_str = json.dumps(serialized)
parsed = json.loads(json_str)
# Final cleanup of any remaining None values
return remove_nones(parsed)
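# Illustrative example: get_json flattens nested or cyclic structures into
# JSON-safe data, dropping None values and tagging cycles.
#
#   >>> d = {"a": 1, "b": None}
#   >>> d["self"] = d
#   >>> get_json(d)
#   {'a': 1, 'self': '<circular_reference:dict>'}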
def sanitize_message(msg: Any) -> Any:
"""Return a copy of the message with image_url omitted for computer_call_output messages."""
if msg.get("type") == "computer_call_output":
output = msg.get("output", {})
if isinstance(output, dict):
sanitized = msg.copy()
sanitized["output"] = {**output, "image_url": "[omitted]"}
return sanitized
return msg
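# Illustrative example: screenshot payloads are large base64 strings, so logged
# copies replace them with a placeholder.
#
#   >>> sanitize_message({
#   ...     "type": "computer_call_output",
#   ...     "output": {"type": "input_image", "image_url": "data:image/png;base64,..."},
#   ... })["output"]["image_url"]
#   '[omitted]'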
def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
    """Collect call_ids from computer_call_output and function_call_output messages."""
    call_ids = []
    for message in messages:
        if message.get("type") in ("computer_call_output", "function_call_output"):
            call_id = message.get("call_id")
            if call_id:
                call_ids.append(call_id)
    return call_ids
class ComputerAgent:
"""
Main agent class that automatically selects the appropriate agent loop
based on the model and executes tool calls.
"""
def __init__(
self,
model: str,
tools: Optional[List[Any]] = None,
custom_loop: Optional[Callable] = None,
only_n_most_recent_images: Optional[int] = None,
callbacks: Optional[List[Any]] = None,
instructions: Optional[str] = None,
verbosity: Optional[int] = None,
trajectory_dir: Optional[str | Path | dict] = None,
max_retries: Optional[int] = 3,
screenshot_delay: Optional[float | int] = 0.5,
use_prompt_caching: Optional[bool] = False,
max_trajectory_budget: Optional[float | dict] = None,
        telemetry_enabled: Optional[bool | dict] = True,
trust_remote_code: Optional[bool] = False,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
**kwargs,
):
"""
Initialize ComputerAgent.
Args:
model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
tools: List of tools (computer objects, decorated functions, etc.)
custom_loop: Custom agent loop function to use instead of auto-selection
only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
instructions: Optional system instructions to be passed to the model
verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
max_retries: Maximum number of retries for failed API calls
screenshot_delay: Delay before screenshots in seconds
use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
            telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default. May also be a dict of TelemetryCallback options.
            trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
            api_key: Optional API key forwarded to the model provider
            api_base: Optional API base URL forwarded to the model provider
            **kwargs: Additional arguments passed to the agent loop
"""
# If the loop is "human/human", we need to prefix a grounding model fallback
if model in ["human/human", "human"]:
model = "openai/computer-use-preview+human/human"
self.model = model
self.tools = tools or []
self.custom_loop = custom_loop
self.only_n_most_recent_images = only_n_most_recent_images
        # Copy so the built-in callbacks added below don't mutate a caller-provided list
        self.callbacks = list(callbacks) if callbacks else []
self.instructions = instructions
self.verbosity = verbosity
self.trajectory_dir = trajectory_dir
self.max_retries = max_retries
self.screenshot_delay = screenshot_delay
self.use_prompt_caching = use_prompt_caching
self.telemetry_enabled = telemetry_enabled
self.kwargs = kwargs
self.trust_remote_code = trust_remote_code
self.api_key = api_key
self.api_base = api_base
# == Add built-in callbacks ==
# Prepend operator normalizer callback
self.callbacks.insert(0, OperatorNormalizerCallback())
# Add prompt instructions callback if provided
if self.instructions:
self.callbacks.append(PromptInstructionsCallback(self.instructions))
# Add telemetry callback if telemetry_enabled is set
if self.telemetry_enabled:
if isinstance(self.telemetry_enabled, bool):
self.callbacks.append(TelemetryCallback(self))
else:
self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
# Add logging callback if verbosity is set
if self.verbosity is not None:
self.callbacks.append(LoggingCallback(level=self.verbosity))
# Add image retention callback if only_n_most_recent_images is set
if self.only_n_most_recent_images:
self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
# Add trajectory saver callback if trajectory_dir is set
if self.trajectory_dir:
if isinstance(self.trajectory_dir, dict):
self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
elif isinstance(self.trajectory_dir, (str, Path)):
self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
# Add budget manager if max_trajectory_budget is set
if max_trajectory_budget:
if isinstance(max_trajectory_budget, dict):
self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
else:
self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
# == Enable local model providers w/ LiteLLM ==
# Register local model providers
hf_adapter = HuggingFaceLocalAdapter(
device="auto", trust_remote_code=self.trust_remote_code or False
)
human_adapter = HumanAdapter()
mlx_adapter = MLXVLMAdapter()
litellm.custom_provider_map = [
{"provider": "huggingface-local", "custom_handler": hf_adapter},
{"provider": "human", "custom_handler": human_adapter},
{"provider": "mlx", "custom_handler": mlx_adapter},
]
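        # With this provider map, litellm resolves model strings by provider
        # prefix, e.g. "huggingface-local/<model-id>", "human/human", or
        # "mlx/<model-id>" (model ids here are illustrative).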
litellm.suppress_debug_info = True
# == Initialize computer agent ==
# Find the appropriate agent loop
if custom_loop:
self.agent_loop = custom_loop
self.agent_config_info = None
else:
config_info = find_agent_config(model)
if not config_info:
raise ValueError(f"No agent config found for model: {model}")
# Instantiate the agent config class
self.agent_loop = config_info.agent_class()
self.agent_config_info = config_info
self.tool_schemas = []
self.computer_handler = None
async def _initialize_computers(self):
"""Initialize computer objects"""
if not self.tool_schemas:
# Process tools and create tool schemas
self.tool_schemas = self._process_tools()
# Find computer tool and create interface adapter
computer_handler = None
for schema in self.tool_schemas:
if schema["type"] == "computer":
computer_handler = await make_computer_handler(schema["computer"])
break
self.computer_handler = computer_handler
    def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
        """Normalize input into a list of JSON-compatible message dicts."""
if isinstance(input, str):
return [{"role": "user", "content": input}]
return [get_json(msg) for msg in input]
def _process_tools(self) -> List[Dict[str, Any]]:
"""Process tools and create schemas for the agent loop"""
schemas = []
for tool in self.tools:
# Check if it's a computer object (has interface attribute)
if is_agent_computer(tool):
# This is a computer tool - will be handled by agent loop
schemas.append({"type": "computer", "computer": tool})
elif callable(tool):
# Use litellm.utils.function_to_dict to extract schema from docstring
try:
function_schema = litellm.utils.function_to_dict(tool)
schemas.append({"type": "function", "function": function_schema})
except Exception as e:
print(f"Warning: Could not process tool {tool}: {e}")
else:
print(f"Warning: Unknown tool type: {tool}")
return schemas
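    # Note: litellm.utils.function_to_dict builds the schema from the function's
    # docstring, so function tools should document their parameters, e.g.
    # (hedged sketch, numpydoc-style):
    #
    #   def search(query: str) -> str:
    #       """Search the web.
    #
    #       Parameters
    #       ----------
    #       query: str
    #           The search query.
    #       """
    #       ...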
def _get_tool(self, name: str) -> Optional[Callable]:
"""Get a tool by name"""
for tool in self.tools:
if hasattr(tool, "__name__") and tool.__name__ == name:
return tool
elif hasattr(tool, "func") and tool.func.__name__ == name:
return tool
return None
# ============================================================================
# AGENT RUN LOOP LIFECYCLE HOOKS
# ============================================================================
async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Initialize run tracking by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, "on_run_start"):
await callback.on_run_start(kwargs, old_items)
async def _on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Finalize run tracking by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, "on_run_end"):
await callback.on_run_end(kwargs, old_items, new_items)
async def _on_run_continue(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> bool:
"""Check if run should continue by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, "on_run_continue"):
should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
if not should_continue:
return False
return True
async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare messages for the LLM call by applying callbacks."""
result = messages
for callback in self.callbacks:
if hasattr(callback, "on_llm_start"):
result = await callback.on_llm_start(result)
return result
async def _on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Postprocess messages after the LLM call by applying callbacks."""
result = messages
for callback in self.callbacks:
if hasattr(callback, "on_llm_end"):
result = await callback.on_llm_end(result)
return result
async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""Called when responses are received."""
for callback in self.callbacks:
if hasattr(callback, "on_responses"):
await callback.on_responses(get_json(kwargs), get_json(responses))
async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a computer call is about to start."""
for callback in self.callbacks:
if hasattr(callback, "on_computer_call_start"):
await callback.on_computer_call_start(get_json(item))
async def _on_computer_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""Called when a computer call has completed."""
for callback in self.callbacks:
if hasattr(callback, "on_computer_call_end"):
await callback.on_computer_call_end(get_json(item), get_json(result))
async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a function call is about to start."""
for callback in self.callbacks:
if hasattr(callback, "on_function_call_start"):
await callback.on_function_call_start(get_json(item))
async def _on_function_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""Called when a function call has completed."""
for callback in self.callbacks:
if hasattr(callback, "on_function_call_end"):
await callback.on_function_call_end(get_json(item), get_json(result))
async def _on_text(self, item: Dict[str, Any]) -> None:
"""Called when a text message is encountered."""
for callback in self.callbacks:
if hasattr(callback, "on_text"):
await callback.on_text(get_json(item))
async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""Called when an LLM API call is about to start."""
for callback in self.callbacks:
if hasattr(callback, "on_api_start"):
await callback.on_api_start(get_json(kwargs))
async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Called when an LLM API call has completed."""
for callback in self.callbacks:
if hasattr(callback, "on_api_end"):
await callback.on_api_end(get_json(kwargs), get_json(result))
async def _on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
for callback in self.callbacks:
if hasattr(callback, "on_usage"):
await callback.on_usage(get_json(usage))
async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
"""Called when a screenshot is taken."""
for callback in self.callbacks:
if hasattr(callback, "on_screenshot"):
await callback.on_screenshot(screenshot, name)
# ============================================================================
# AGENT OUTPUT PROCESSING
# ============================================================================
async def _handle_item(
self,
item: Any,
computer: Optional[AsyncComputerHandler] = None,
ignore_call_ids: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Handle each item; may cause a computer action + screenshot."""
call_id = item.get("call_id")
if ignore_call_ids and call_id and call_id in ignore_call_ids:
return []
item_type = item.get("type", None)
if item_type == "message":
await self._on_text(item)
# # Print messages
# if item.get("content"):
# for content_item in item.get("content"):
# if content_item.get("text"):
# print(content_item.get("text"))
return []
try:
if item_type == "computer_call":
await self._on_computer_call_start(item)
if not computer:
raise ValueError("Computer handler is required for computer calls")
# Perform computer actions
action = item.get("action")
action_type = action.get("type")
if action_type is None:
print(
f"Action type cannot be `None`: action={action}, action_type={action_type}"
)
return []
# Extract action arguments (all fields except 'type')
action_args = {k: v for k, v in action.items() if k != "type"}
# print(f"{action_type}({action_args})")
# Execute the computer action
computer_method = getattr(computer, action_type, None)
if computer_method:
assert_callable_with(computer_method, **action_args)
await computer_method(**action_args)
else:
raise ToolError(f"Unknown computer action: {action_type}")
# Take screenshot after action
if self.screenshot_delay and self.screenshot_delay > 0:
await asyncio.sleep(self.screenshot_delay)
screenshot_base64 = await computer.screenshot()
await self._on_screenshot(screenshot_base64, "screenshot_after")
# Handle safety checks
pending_checks = item.get("pending_safety_checks", [])
acknowledged_checks = []
for check in pending_checks:
check_message = check.get("message", str(check))
acknowledged_checks.append(check)
# TODO: implement a callback for safety checks
# if acknowledge_safety_check_callback(check_message, allow_always=True):
# acknowledged_checks.append(check)
# else:
# raise ValueError(f"Safety check failed: {check_message}")
# Create call output
call_output = {
"type": "computer_call_output",
"call_id": item.get("call_id"),
"acknowledged_safety_checks": acknowledged_checks,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
}
# # Additional URL safety checks for browser environments
# if await computer.get_environment() == "browser":
# current_url = await computer.get_current_url()
# call_output["output"]["current_url"] = current_url
# # TODO: implement a callback for URL safety checks
# # check_blocklisted_url(current_url)
result = [call_output]
await self._on_computer_call_end(item, result)
return result
if item_type == "function_call":
await self._on_function_call_start(item)
# Perform function call
function = self._get_tool(item.get("name"))
if not function:
raise ToolError(f"Function {item.get('name')} not found")
                args = json.loads(item.get("arguments") or "{}")
# Validate arguments before execution
assert_callable_with(function, **args)
# Execute function - use asyncio.to_thread for non-async functions
if inspect.iscoroutinefunction(function):
result = await function(**args)
else:
result = await asyncio.to_thread(function, **args)
# Create function call output
call_output = {
"type": "function_call_output",
"call_id": item.get("call_id"),
"output": str(result),
}
result = [call_output]
await self._on_function_call_end(item, result)
return result
except ToolError as e:
return [make_tool_error_item(repr(e), call_id)]
return []
# ============================================================================
# MAIN AGENT LOOP
# ============================================================================
    async def run(
        self,
        messages: Messages,
        stream: bool = False,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ) -> AsyncGenerator[Dict[str, Any], None]:
"""
        Run the agent with the given messages using the Computer protocol handler pattern.
        Args:
            messages: List of message dictionaries
            stream: Whether to stream the response
            api_key: Optional per-run API key (overrides the constructor value)
            api_base: Optional per-run API base URL (overrides the constructor value)
            **kwargs: Additional arguments
Returns:
AsyncGenerator that yields response chunks
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
capabilities = self.get_capabilities()
if "step" not in capabilities:
raise ValueError(
f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
)
await self._initialize_computers()
# Merge kwargs and thread api credentials (run overrides constructor)
merged_kwargs = {**self.kwargs, **kwargs}
if (api_key is not None) or (self.api_key is not None):
merged_kwargs["api_key"] = api_key if api_key is not None else self.api_key
if (api_base is not None) or (self.api_base is not None):
merged_kwargs["api_base"] = api_base if api_base is not None else self.api_base
old_items = self._process_input(messages)
new_items = []
# Initialize run tracking
run_kwargs = {
"messages": messages,
"stream": stream,
"model": self.model,
"agent_loop": self.agent_config_info.agent_class.__name__,
**merged_kwargs,
}
await self._on_run_start(run_kwargs, old_items)
        # Loop until the most recent item is an assistant message
        while not (new_items and new_items[-1].get("role") == "assistant"):
# Lifecycle hook: Check if we should continue based on callbacks (e.g., budget manager)
should_continue = await self._on_run_continue(run_kwargs, old_items, new_items)
if not should_continue:
break
# Lifecycle hook: Prepare messages for the LLM call
# Use cases:
# - PII anonymization
# - Image retention policy
combined_messages = old_items + new_items
combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
preprocessed_messages = await self._on_llm_start(combined_messages)
loop_kwargs = {
"messages": preprocessed_messages,
"model": self.model,
"tools": self.tool_schemas,
"stream": False,
"computer_handler": self.computer_handler,
"max_retries": self.max_retries,
"use_prompt_caching": self.use_prompt_caching,
**merged_kwargs,
}
# Run agent loop iteration
result = await self.agent_loop.predict_step(
**loop_kwargs,
_on_api_start=self._on_api_start,
_on_api_end=self._on_api_end,
_on_usage=self._on_usage,
_on_screenshot=self._on_screenshot,
)
result = get_json(result)
# Lifecycle hook: Postprocess messages after the LLM call
# Use cases:
# - PII deanonymization (if you want tool calls to see PII)
result["output"] = await self._on_llm_end(result.get("output", []))
await self._on_responses(loop_kwargs, result)
# Yield agent response
yield result
# Add agent response to new_items
            new_items += result.get("output", [])
# Get output call ids
output_call_ids = get_output_call_ids(result.get("output", []))
# Handle computer actions
            for item in result.get("output", []):
partial_items = await self._handle_item(
item, self.computer_handler, ignore_call_ids=output_call_ids
)
new_items += partial_items
# Yield partial response
yield {
"output": partial_items,
"usage": Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
),
}
        # Use run_kwargs here: loop_kwargs is unbound if a callback stops the run before the first iteration
        await self._on_run_end(run_kwargs, old_items, new_items)
async def predict_click(
self, instruction: str, image_b64: Optional[str] = None
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
Args:
instruction: Instruction for where to click
image_b64: Base64 encoded image (optional, will take screenshot if not provided)
Returns:
            Tuple of (x, y) coordinates, or None if the agent loop does not support click prediction
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
capabilities = self.get_capabilities()
if "click" not in capabilities:
raise ValueError(
f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
)
if hasattr(self.agent_loop, "predict_click"):
if not image_b64:
if not self.computer_handler:
raise ValueError("Computer tool or image_b64 is required for predict_click")
image_b64 = await self.computer_handler.screenshot()
# Pass along api credentials if available
click_kwargs: Dict[str, Any] = {}
if self.api_key is not None:
click_kwargs["api_key"] = self.api_key
if self.api_base is not None:
click_kwargs["api_base"] = self.api_base
return await self.agent_loop.predict_click(
model=self.model, image_b64=image_b64, instruction=instruction, **click_kwargs
)
return None
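    # Example (illustrative; requires an agent loop with the "click" capability):
    #
    #   coords = await agent.predict_click("Click the submit button")
    #   if coords:
    #       x, y = coords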
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by the current agent config.
Returns:
List of capability strings (e.g., ["step", "click"])
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
if hasattr(self.agent_loop, "get_capabilities"):
return self.agent_loop.get_capabilities()
return ["step"] # Default capability