"""Omni-specific agent loop implementation."""

import logging
import json
import re
import os
import asyncio
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, cast

from httpx import ConnectError, ReadTimeout

from .parser import OmniParser, ParseResult
from ...core.base import BaseLoop
from ...core.visualization import VisualizationHelper
from ...core.messages import StandardMessageManager, ImageRetentionConfig
from .utils import to_openai_agent_response_format
from ...core.types import AgentResponse, LLMProvider
from computer import Computer
from .clients.openai import OpenAIClient
from .clients.anthropic import AnthropicClient
from .clients.ollama import OllamaClient
from .clients.oaicompat import OAICompatClient
from .prompts import SYSTEM_PROMPT
from .api_handler import OmniAPIHandler
from .tools.manager import ToolManager
from .tools import ToolResult

logger = logging.getLogger(__name__)

def extract_data(input_string: str, data_type: str) -> str:
    """Extract content from code blocks."""
    pattern = f"```{data_type}" + r"(.*?)(```|$)"
    matches = re.findall(pattern, input_string, re.DOTALL)
    return matches[0][0].strip() if matches else input_string
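
# Illustrative example: extract_data pulls the body out of a fenced block of the
# requested type, e.g.
#   extract_data('Click it:\n```json\n{"Action": "None"}\n```', "json")
# returns '{"Action": "None"}'. If no matching fence is found, the input string
# is returned unchanged.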


class OmniLoop(BaseLoop):
    """Omni-specific implementation of the agent loop.

    This class extends BaseLoop to provide support for multimodal models
    from various providers (OpenAI, Anthropic, etc.) with UI parsing
    and desktop automation capabilities.
    """

    ###########################################
    # INITIALIZATION AND CONFIGURATION
    ###########################################

    def __init__(
        self,
        parser: OmniParser,
        provider: LLMProvider,
        api_key: str,
        model: str,
        computer: Computer,
        only_n_most_recent_images: Optional[int] = 2,
        base_dir: Optional[str] = "trajectories",
        max_retries: int = 3,
        retry_delay: float = 1.0,
        save_trajectory: bool = True,
        provider_base_url: Optional[str] = None,
        **kwargs,
    ):
        """Initialize the loop.

        Args:
            parser: Parser instance
            provider: API provider
            api_key: API key
            model: Model name
            computer: Computer instance
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
            base_dir: Base directory for saving experiment data
            max_retries: Maximum number of retries for API calls
            retry_delay: Delay between retries in seconds
            save_trajectory: Whether to save trajectory data
            provider_base_url: Base URL for the API provider (used for OAICOMPAT)
        """
        # Set parser and provider before initializing base class
        self.parser = parser
        self.provider = provider
        self.provider_base_url = provider_base_url

        # Initialize message manager with image retention config
        self.message_manager = StandardMessageManager(
            config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
        )

        # Initialize base class (which will set up experiment manager)
        super().__init__(
            computer=computer,
            model=model,
            api_key=api_key,
            max_retries=max_retries,
            retry_delay=retry_delay,
            base_dir=base_dir,
            save_trajectory=save_trajectory,
            only_n_most_recent_images=only_n_most_recent_images,
            **kwargs,
        )

        # Set API client attributes
        self.client = None
        self.retry_count = 0
        self.loop_task = None  # Store the loop task for cancellation

        # Initialize handlers
        self.api_handler = OmniAPIHandler(loop=self)
        self.viz_helper = VisualizationHelper(agent=self)

        # Initialize tool manager
        self.tool_manager = ToolManager(computer=computer, provider=provider)

        logger.info("OmniLoop initialized with StandardMessageManager")

    async def initialize(self) -> None:
        """Initialize the loop by setting up tools and clients."""
        # Initialize base class
        await super().initialize()

        # Initialize tool manager with error handling
        try:
            logger.info("Initializing tool manager...")
            await self.tool_manager.initialize()
            logger.info("Tool manager initialized successfully.")
        except Exception as e:
            logger.error(f"Error initializing tool manager: {str(e)}")
            logger.warning("Will attempt to initialize tools on first use.")

        # Initialize API clients based on provider
        if self.provider == LLMProvider.ANTHROPIC:
            self.client = AnthropicClient(
                api_key=self.api_key,
                model=self.model,
            )
        elif self.provider == LLMProvider.OPENAI:
            self.client = OpenAIClient(
                api_key=self.api_key,
                model=self.model,
            )
        elif self.provider == LLMProvider.OLLAMA:
            self.client = OllamaClient(
                api_key=self.api_key,
                model=self.model,
            )
        elif self.provider == LLMProvider.OAICOMPAT:
            self.client = OAICompatClient(
                api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
                model=self.model,
                provider_base_url=self.provider_base_url,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    ###########################################
    # CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
    ###########################################

    async def initialize_client(self) -> None:
        """Initialize the appropriate client based on provider.

        Implements abstract method from BaseLoop to set up the specific
        provider client (OpenAI, Anthropic, etc.).
        """
        try:
            logger.info(f"Initializing {self.provider} client with model {self.model}...")

            if self.provider == LLMProvider.OPENAI:
                self.client = OpenAIClient(api_key=self.api_key, model=self.model)
            elif self.provider == LLMProvider.ANTHROPIC:
                self.client = AnthropicClient(
                    api_key=self.api_key,
                    model=self.model,
                    max_retries=self.max_retries,
                    retry_delay=self.retry_delay,
                )
            elif self.provider == LLMProvider.OLLAMA:
                self.client = OllamaClient(
                    api_key=self.api_key,
                    model=self.model,
                )
            elif self.provider == LLMProvider.OAICOMPAT:
                self.client = OAICompatClient(
                    api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
                    model=self.model,
                    provider_base_url=self.provider_base_url,
                )
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")

            logger.info(f"Initialized {self.provider} client with model {self.model}")
        except Exception as e:
            logger.error(f"Error initializing client: {str(e)}")
            self.client = None
            raise RuntimeError(f"Failed to initialize client: {str(e)}")

    ###########################################
    # API CALL HANDLING
    ###########################################

    async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
        """Make API call to provider with retry logic."""
        # Create new turn directory for this API call
        self._create_turn_dir()

        request_data = None
        last_error = None

        for attempt in range(self.max_retries):
            try:
                # Ensure client is initialized
                if self.client is None:
                    logger.info(
                        f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
                    )
                    await self.initialize_client()
                    if self.client is None:
                        raise RuntimeError("Failed to initialize client")

                # Get messages in standard format from the message manager
                self.message_manager.messages = messages.copy()
                prepared_messages = self.message_manager.get_messages()

                # Special handling for Anthropic
                if self.provider == LLMProvider.ANTHROPIC:
                    # Convert to Anthropic format
                    anthropic_messages, anthropic_system = self.message_manager.to_anthropic_format(
                        prepared_messages
                    )

                    # Filter out any empty/invalid messages
                    filtered_messages = [
                        msg
                        for msg in anthropic_messages
                        if msg.get("role") in ["user", "assistant"]
                    ]

                    # Ensure there's at least one message for Anthropic
                    if not filtered_messages:
                        logger.warning(
                            "No valid messages found for Anthropic API call. Adding a default user message."
                        )
                        filtered_messages = [
                            {
                                "role": "user",
                                "content": [
                                    {"type": "text", "text": "Please help with this task."}
                                ],
                            }
                        ]

                    # Combine system prompts if needed
                    final_system_prompt = anthropic_system or system_prompt

                    # Log request
                    request_data = {
                        "messages": filtered_messages,
                        "max_tokens": self.max_tokens,
                        "system": final_system_prompt,
                    }

                    self._log_api_call("request", request_data)

                    # Make API call
                    response = await self.client.run_interleaved(
                        messages=filtered_messages,
                        system=final_system_prompt,
                        max_tokens=self.max_tokens,
                    )
                else:
                    # For OpenAI and others, use standard format directly
                    # Log request
                    request_data = {
                        "messages": prepared_messages,
                        "max_tokens": self.max_tokens,
                        "system": system_prompt,
                    }

                    self._log_api_call("request", request_data)

                    # Make API call
                    response = await self.client.run_interleaved(
                        messages=prepared_messages,
                        system=system_prompt,
                        max_tokens=self.max_tokens,
                    )

                # Log success response
                self._log_api_call("response", request_data, response)

                return response

            except (ConnectError, ReadTimeout) as e:
                last_error = e
                logger.warning(
                    f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
                )
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(self.retry_delay * (attempt + 1))  # Backoff grows with each attempt
                    # Reset client on connection errors to force re-initialization
                    self.client = None
                    continue

            except RuntimeError as e:
                # Handle client initialization errors specifically
                last_error = e
                self._log_api_call("error", request_data, error=e)
                logger.error(
                    f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
                )
                if attempt < self.max_retries - 1:
                    # Reset client to force re-initialization
                    self.client = None
                    await asyncio.sleep(self.retry_delay)
                    continue

            except Exception as e:
                # Log unexpected error
                last_error = e
                self._log_api_call("error", request_data, error=e)
                logger.error(f"Unexpected error in API call: {str(e)}")
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(self.retry_delay)
                    continue

        # If we get here, all retries failed
        error_message = f"API call failed after {self.max_retries} attempts"
        if last_error:
            error_message += f": {str(last_error)}"

        logger.error(error_message)
        raise RuntimeError(error_message)

    ###########################################
    # RESPONSE AND ACTION HANDLING
    ###########################################

    async def _handle_response(
        self, response: Any, messages: List[Dict[str, Any]], parsed_screen: ParseResult
    ) -> Tuple[bool, bool]:
        """Handle API response.

        Args:
            response: API response
            messages: List of messages to update
            parsed_screen: Current parsed screen information

        Returns:
            Tuple of (should_continue, action_screenshot_saved)
        """
        action_screenshot_saved = False

        # Helper function to safely add assistant messages using the message manager
        def add_assistant_message(content):
            if isinstance(content, str):
                # Convert string to proper format
                formatted_content = [{"type": "text", "text": content}]
                self.message_manager.add_assistant_message(formatted_content)
                logger.info("Added formatted text assistant message")
            elif isinstance(content, list):
                # Already in proper format
                self.message_manager.add_assistant_message(content)
                logger.info("Added structured assistant message")
            else:
                # Default case - convert to string
                formatted_content = [{"type": "text", "text": str(content)}]
                self.message_manager.add_assistant_message(formatted_content)
                logger.info("Added converted assistant message")

        try:
            # Step 1: Normalize response to standard format based on provider
            standard_content = []
            raw_text = None

            # Convert response to standardized content based on provider
            if self.provider == LLMProvider.ANTHROPIC:
                if hasattr(response, "content") and isinstance(response.content, list):
                    # Convert Anthropic response to standard format
                    for block in response.content:
                        if hasattr(block, "type"):
                            if block.type == "text":
                                standard_content.append({"type": "text", "text": block.text})
                                # Store raw text for JSON parsing
                                if raw_text is None:
                                    raw_text = block.text
                                else:
                                    raw_text += "\n" + block.text
                            else:
                                # Add other block types
                                block_dict = {}
                                for key, value in vars(block).items():
                                    if not key.startswith("_"):
                                        block_dict[key] = value
                                standard_content.append(block_dict)
                else:
                    logger.warning("Invalid Anthropic response format")
                    return True, action_screenshot_saved
            elif self.provider == LLMProvider.OLLAMA:
                try:
                    raw_text = response["message"]["content"]
                    standard_content = [{"type": "text", "text": raw_text}]
                except (KeyError, TypeError, IndexError) as e:
                    logger.error(f"Invalid response format: {str(e)}")
                    return True, action_screenshot_saved
            elif self.provider == LLMProvider.OAICOMPAT:
                try:
                    # OpenAI-compatible response format
                    raw_text = response["choices"][0]["message"]["content"]
                    standard_content = [{"type": "text", "text": raw_text}]
                except (KeyError, TypeError, IndexError) as e:
                    logger.error(f"Invalid response format: {str(e)}")
                    return True, action_screenshot_saved
            else:
                # Assume OpenAI or compatible format
                try:
                    raw_text = response["choices"][0]["message"]["content"]
                    standard_content = [{"type": "text", "text": raw_text}]
                except (KeyError, TypeError, IndexError) as e:
                    logger.error(f"Invalid response format: {str(e)}")
                    return True, action_screenshot_saved

            # Step 2: Add the normalized response to message history
            add_assistant_message(standard_content)

            # Step 3: Extract JSON from the content for action execution
            parsed_content = None

            # If we have raw text, try to extract JSON from it
            if raw_text:
                # Try different approaches to extract JSON
                try:
                    # First try to parse the whole content as JSON
                    parsed_content = json.loads(raw_text)
                    logger.info("Successfully parsed whole content as JSON")
                except json.JSONDecodeError:
                    try:
                        # Try to find JSON block
                        json_content = extract_data(raw_text, "json")
                        parsed_content = json.loads(json_content)
                        logger.info("Successfully parsed JSON from code block")
                    except (json.JSONDecodeError, IndexError):
                        try:
                            # Look for JSON object pattern
                            import re  # Local import to ensure availability

                            json_pattern = r"\{[^}]+\}"
                            json_match = re.search(json_pattern, raw_text)
                            if json_match:
                                json_str = json_match.group(0)
                                parsed_content = json.loads(json_str)
                                logger.info("Successfully parsed JSON from text")
                            else:
                                logger.error("No JSON found in content")
                                return True, action_screenshot_saved
                        except json.JSONDecodeError as e:
                            # Try to sanitize the JSON string and retry
                            try:
                                # Remove or replace invalid control characters
                                import re  # Local import to ensure availability

                                sanitized_text = re.sub(r"[\x00-\x1F\x7F]", "", raw_text)
                                # Try parsing again with sanitized text
                                parsed_content = json.loads(sanitized_text)
                                logger.info(
                                    "Successfully parsed JSON after sanitizing control characters"
                                )
                            except json.JSONDecodeError:
                                logger.error(f"Failed to parse JSON from text: {str(e)}")
                                return True, action_screenshot_saved

            # Step 4: Process the parsed content if available
            if parsed_content:
                # Clean up Box ID format
                if "Box ID" in parsed_content and isinstance(parsed_content["Box ID"], str):
                    parsed_content["Box ID"] = parsed_content["Box ID"].replace("Box #", "")

                # Add any explanatory text as reasoning if not present
                if "Explanation" not in parsed_content and raw_text:
                    # Extract any text before the JSON as reasoning
                    text_before_json = raw_text.split("{")[0].strip()
                    if text_before_json:
                        parsed_content["Explanation"] = text_before_json

                # Log the parsed content for debugging
                logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}")

                # Step 5: Execute the action
                try:
                    # Execute action using the common helper method
                    should_continue, action_screenshot_saved = (
                        await self._execute_action_with_tools(
                            parsed_content, cast(ParseResult, parsed_screen)
                        )
                    )

                    # Check if task is complete
                    if parsed_content.get("Action") == "None":
                        return False, action_screenshot_saved
                    return should_continue, action_screenshot_saved
                except Exception as e:
                    logger.error(f"Error executing action: {str(e)}")
                    # Update the last assistant message with error
                    error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
                    # Replace the last assistant message with the error
                    self.message_manager.add_assistant_message(error_message)
                    return False, action_screenshot_saved

            return True, action_screenshot_saved

        except Exception as e:
            logger.error(f"Error handling response: {str(e)}")
            # Add error message using the message manager
            error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
            self.message_manager.add_assistant_message(error_message)
            raise

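    # Illustrative example of the action JSON this loop expects the model to
    # produce (field names as consumed above and in _execute_action_with_tools):
    #   {"Explanation": "Open the search field", "Action": "left_click", "Box ID": "42"}
    # The JSON may be the whole reply, live inside a ```json fenced block, or be
    # embedded in surrounding prose; all three cases are handled above.
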
    ###########################################
    # SCREEN PARSING - IMPLEMENTING ABSTRACT METHOD
    ###########################################

    async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult:
        """Get parsed screen information with Screen Object Model.

        Extends the base class method to use the OmniParser to parse the screen
        and extract UI elements.

        Args:
            save_screenshot: Whether to save the screenshot (set to False when screenshots will be saved elsewhere)

        Returns:
            ParseResult containing screen information and elements
        """
        try:
            # Use the parser's parse_screen method which handles the screenshot internally
            parsed_screen = await self.parser.parse_screen(computer=self.computer)

            # Log information about the parsed results
            logger.info(
                f"Parsed screen with {len(parsed_screen.elements) if parsed_screen.elements else 0} elements"
            )

            # Save screenshot if requested and if we have image data
            if save_screenshot and self.save_trajectory and parsed_screen.annotated_image_base64:
                try:
                    # Extract just the image data (remove data:image/png;base64, prefix)
                    img_data = parsed_screen.annotated_image_base64
                    if "," in img_data:
                        img_data = img_data.split(",")[1]

                    # Process screenshot through hooks and save if needed
                    await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)

                    # Save with a generic "state" action type to indicate this is the current screen state
                    self._save_screenshot(img_data, action_type="state")
                except Exception as e:
                    logger.error(f"Error saving screenshot: {str(e)}")

            return parsed_screen

        except Exception as e:
            logger.error(f"Error getting parsed screen: {str(e)}")
            raise

    def _get_system_prompt(self) -> str:
        """Get the system prompt for the model."""
        return SYSTEM_PROMPT

    ###########################################
    # MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
    ###########################################

    async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
        """Run the agent loop with provided messages.

        Args:
            messages: List of messages in standard OpenAI format

        Yields:
            Agent responses in OpenAI-compatible format
        """
        try:
            logger.info(f"Starting OmniLoop run with {len(messages)} messages")

            # Initialize the message manager with the provided messages
            self.message_manager.messages = messages.copy()

            # Create queue for response streaming
            queue = asyncio.Queue()

            # Start loop in background task
            self.loop_task = asyncio.create_task(self._run_loop(queue, messages))

            # Process and yield messages as they arrive
            while True:
                try:
                    item = await queue.get()
                    if item is None:  # Stop signal
                        break
                    yield item
                    queue.task_done()
                except Exception as e:
                    logger.error(f"Error processing queue item: {str(e)}")
                    continue

            # Wait for loop to complete
            await self.loop_task

            # Send completion message
            yield {
                "role": "assistant",
                "content": "Task completed successfully.",
                "metadata": {"title": "✅ Complete"},
            }

        except Exception as e:
            logger.error(f"Error in run method: {str(e)}")
            yield {
                "role": "assistant",
                "content": f"Error: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }

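    # Consumption sketch (illustrative): run() is an async generator, so a caller
    # typically drives it with `async for`, e.g.
    #
    #   async for response in loop.run([{"role": "user", "content": "Open the browser"}]):
    #       print(response)
    #
    # Each queued item is an OpenAI-style response dict; the final yield reports
    # completion (or an error) via its "metadata" title.
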
    async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
        """Internal method to run the agent loop with provided messages.

        Args:
            queue: Queue to put responses into
            messages: List of messages in standard OpenAI format
        """
        # Continue running until explicitly told to stop
        running = True
        turn_created = False
        # Track if an action-specific screenshot has been saved this turn
        action_screenshot_saved = False

        attempt = 0
        max_attempts = 3

        while running and attempt < max_attempts:
            try:
                # Create a new turn directory if it's not already created
                if not turn_created:
                    self._create_turn_dir()
                    turn_created = True

                # Ensure client is initialized
                if self.client is None:
                    logger.info("Initializing client...")
                    await self.initialize_client()
                    if self.client is None:
                        raise RuntimeError("Failed to initialize client")
                    logger.info("Client initialized successfully")

                # Get up-to-date screen information
                parsed_screen = await self._get_parsed_screen_som()

                # Process screen info and update messages in standard format
                try:
                    # Get image from parsed screen
                    image = parsed_screen.annotated_image_base64 or None
                    if image:
                        # Save elements as JSON if we have a turn directory
                        if self.current_turn_dir and hasattr(parsed_screen, "elements"):
                            elements_path = os.path.join(self.current_turn_dir, "elements.json")
                            with open(elements_path, "w") as f:
                                # Convert elements to dicts for JSON serialization
                                elements_json = [
                                    elem.model_dump() for elem in parsed_screen.elements
                                ]
                                json.dump(elements_json, f, indent=2)
                            logger.info(f"Saved elements to {elements_path}")

                        # Remove data URL prefix if present
                        if "," in image:
                            image = image.split(",")[1]

                        # Add screenshot to message history using message manager
                        self.message_manager.add_user_message(
                            [
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"data:image/png;base64,{image}"},
                                }
                            ]
                        )
                        logger.info("Added screenshot to message history")
                except Exception as e:
                    logger.error(f"Error processing screen info: {str(e)}")
                    raise

                # Get system prompt
                system_prompt = self._get_system_prompt()

                # Make API call with retries using the APIHandler
                response = await self.api_handler.make_api_call(
                    self.message_manager.messages, system_prompt
                )

                # Handle the response (may execute actions)
                # Returns: (should_continue, action_screenshot_saved)
                should_continue, new_screenshot_saved = await self._handle_response(
                    response, self.message_manager.messages, parsed_screen
                )

                # Update whether an action screenshot was saved this turn
                action_screenshot_saved = action_screenshot_saved or new_screenshot_saved

                # Create OpenAI-compatible response format using utility function
                openai_compatible_response = await to_openai_agent_response_format(
                    response=response,
                    messages=self.message_manager.messages,
                    model=self.model,
                    parsed_screen=parsed_screen,
                    parser=self.parser,
                )

                # Log standardized response for ease of parsing
                self._log_api_call("agent_response", request=None, response=openai_compatible_response)

                # Put the response in the queue
                await queue.put(openai_compatible_response)

                # Check if we should continue this conversation
                running = should_continue

                # Create a new turn directory if we're continuing
                if running:
                    turn_created = False

                # Reset attempt counter on success
                attempt = 0

            except Exception as e:
                attempt += 1
                error_msg = f"Error in _run_loop method (attempt {attempt}/{max_attempts}): {str(e)}"
                logger.error(error_msg)

                # If this is our last attempt, provide more info about the error
                if attempt >= max_attempts:
                    logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")

                await queue.put({
                    "role": "assistant",
                    "content": f"Error: {str(e)}",
                    "metadata": {"title": "❌ Error"},
                })

                # Create a brief delay before retrying
                await asyncio.sleep(1)
            finally:
                # Signal that we're done
                await queue.put(None)

    async def cancel(self) -> None:
        """Cancel the currently running agent loop task.

        This method stops the ongoing processing in the agent loop
        by cancelling the loop_task if it exists and is running.
        """
        if self.loop_task and not self.loop_task.done():
            logger.info("Cancelling Omni loop task")
            self.loop_task.cancel()
            try:
                # Wait for the task to be cancelled with a timeout
                await asyncio.wait_for(self.loop_task, timeout=2.0)
            except asyncio.TimeoutError:
                logger.warning("Timeout while waiting for loop task to cancel")
            except asyncio.CancelledError:
                logger.info("Loop task cancelled successfully")
            except Exception as e:
                logger.error(f"Error while cancelling loop task: {str(e)}")
            finally:
                logger.info("Omni loop task cancelled")
        else:
            logger.info("No active Omni loop task to cancel")

    async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
        """Process model response to extract tool calls.

        Args:
            response_text: Model response text

        Returns:
            Extracted tool information, or None if no tool call was found
        """
        try:
            # Ensure tools are initialized before use
            await self._ensure_tools_initialized()

            # Look for tool use in the response
            if "function_call" in response_text or "tool_use" in response_text:
                # The extract_tool_call method should be implemented in the OmniAPIHandler
                # For now, we'll just use a simple approach
                # This will be replaced with the proper implementation
                tool_info = None
                if "function_call" in response_text:
                    # Extract function call params
                    try:
                        # Simple extraction - in real code this would be more robust
                        import json
                        import re

                        match = re.search(r'"function_call"\s*:\s*{([^}]+)}', response_text)
                        if match:
                            function_text = "{" + match.group(1) + "}"
                            tool_info = json.loads(function_text)
                    except Exception as e:
                        logger.error(f"Error extracting function call: {str(e)}")

                if tool_info:
                    try:
                        # Execute the tool
                        result = await self.tool_manager.execute_tool(
                            name=tool_info.get("name"), tool_input=tool_info.get("arguments", {})
                        )
                        # Handle the result
                        return {"tool_result": result}
                    except Exception as e:
                        error_msg = (
                            f"Error executing tool '{tool_info.get('name', 'unknown')}': {str(e)}"
                        )
                        logger.error(error_msg)
                        return {"tool_result": ToolResult(error=error_msg)}
        except Exception as e:
            logger.error(f"Error processing tool call: {str(e)}")

        return None

    async def process_response_with_tools(
        self, response_text: str, parsed_screen: Optional[ParseResult] = None
    ) -> Tuple[bool, str]:
        """Process model response and execute tools.

        Args:
            response_text: Model response text
            parsed_screen: Current parsed screen information (optional)

        Returns:
            Tuple of (action_taken, observation)
        """
        logger.info("Processing response with tools")

        # Process the response to extract tool calls
        tool_result = await self.process_model_response(response_text)

        if tool_result and "tool_result" in tool_result:
            # A tool was executed
            result = tool_result["tool_result"]
            if result.error:
                return False, f"ERROR: {result.error}"
            else:
                return True, result.output or "Tool executed successfully"

        # No action or tool call found
        return False, "No action taken - no tool call detected in response"

    ###########################################
    # UTILITY METHODS
    ###########################################

    async def _ensure_tools_initialized(self) -> None:
        """Ensure the tool manager and tools are initialized before use."""
        if not hasattr(self.tool_manager, "tools") or self.tool_manager.tools is None:
            logger.info("Tools not initialized. Initializing now...")
            await self.tool_manager.initialize()
            logger.info("Tools initialized successfully.")

    async def _execute_action_with_tools(
        self, action_data: Dict[str, Any], parsed_screen: ParseResult
    ) -> Tuple[bool, bool]:
        """Execute an action using the tools-based approach.

        Args:
            action_data: Dictionary containing action details
            parsed_screen: Current parsed screen information

        Returns:
            Tuple of (should_continue, action_screenshot_saved)
        """
        action_screenshot_saved = False
        action_type = None  # Initialize for possible use in post-action screenshot

        try:
            # Extract the action
            parsed_action = action_data.get("Action", "").lower()

            # Only process if we have a valid action
            if not parsed_action or parsed_action == "none":
                return False, action_screenshot_saved

            # Convert the parsed content to a format suitable for the tools system
            tool_name = "computer"  # Default to computer tool
            tool_args = {"action": parsed_action}

            # Add specific arguments based on action type
            if parsed_action in ["left_click", "right_click", "double_click", "move_cursor"]:
                # Calculate coordinates from Box ID using parser
                try:
                    box_id = int(action_data["Box ID"])
                    x, y = await self.parser.calculate_click_coordinates(
                        box_id, cast(ParseResult, parsed_screen)
                    )
                    tool_args["x"] = x
                    tool_args["y"] = y

                    # Visualize action if screenshot is available
                    if parsed_screen and parsed_screen.annotated_image_base64:
                        img_data = parsed_screen.annotated_image_base64
                        # Remove data URL prefix if present
                        if img_data.startswith("data:image"):
                            img_data = img_data.split(",")[1]
                        # Save visualization for coordinate-based actions
                        self.viz_helper.visualize_action(x, y, img_data)
                        action_screenshot_saved = True

                except (ValueError, KeyError) as e:
                    logger.error(f"Error processing Box ID: {str(e)}")
                    return False, action_screenshot_saved

            elif parsed_action == "type_text":
                tool_args["text"] = action_data.get("Value", "")
                # For type_text, store the value in the action type for screenshot naming
                action_type = f"type_{tool_args['text'][:20]}"  # Truncate if too long

            elif parsed_action == "press_key":
                tool_args["key"] = action_data.get("Value", "")
                action_type = f"press_{tool_args['key']}"

            elif parsed_action == "hotkey":
                value = action_data.get("Value", "")
                if isinstance(value, list):
                    tool_args["keys"] = value
                    action_type = f"hotkey_{'_'.join(value)}"
                else:
                    # Split string format like "command+space" into a list
                    keys = [k.strip() for k in value.lower().split("+")]
                    tool_args["keys"] = keys
                    action_type = f"hotkey_{value.replace('+', '_')}"

            elif parsed_action in ["scroll_down", "scroll_up"]:
                clicks = int(action_data.get("amount", 1))
                tool_args["amount"] = clicks
                action_type = f"scroll_{parsed_action.split('_')[1]}_{clicks}"

                # Visualize scrolling if screenshot is available
                if parsed_screen and parsed_screen.annotated_image_base64:
                    img_data = parsed_screen.annotated_image_base64
                    # Remove data URL prefix if present
                    if img_data.startswith("data:image"):
                        img_data = img_data.split(",")[1]
                    direction = "down" if parsed_action == "scroll_down" else "up"
                    # For scrolling, we save the visualization
                    self.viz_helper.visualize_scroll(direction, clicks, img_data)
                    action_screenshot_saved = True

            # Ensure tools are initialized before use
            await self._ensure_tools_initialized()

            # Execute tool with prepared arguments
            result = await self.tool_manager.execute_tool(name=tool_name, tool_input=tool_args)

            # Take a new screenshot after the action if we haven't already saved one
            if not action_screenshot_saved:
                try:
                    # Get a new screenshot after the action
                    new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False)
                    if new_parsed_screen and new_parsed_screen.annotated_image_base64:
                        img_data = new_parsed_screen.annotated_image_base64
                        # Remove data URL prefix if present
                        if img_data.startswith("data:image"):
                            img_data = img_data.split(",")[1]
                        # Save with action type if defined, otherwise use the action name
                        if action_type:
                            self._save_screenshot(img_data, action_type=action_type)
                        else:
                            self._save_screenshot(img_data, action_type=parsed_action)
                        action_screenshot_saved = True
                except Exception as screenshot_error:
                    logger.error(f"Error taking post-action screenshot: {str(screenshot_error)}")

            # Continue the loop if the action is not "None"
            return True, action_screenshot_saved

        except Exception as e:
            logger.error(f"Error executing action: {str(e)}")
            # Update the last assistant message with error
            error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
            # Replace the last assistant message with the error
            self.message_manager.add_assistant_message(error_message)
            return False, action_screenshot_saved
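

# Minimal usage sketch (illustrative only; how Computer and OmniParser are
# constructed is an assumption -- check their own modules for the real signatures):
#
#   computer = Computer()            # assumed default construction
#   parser = OmniParser()            # assumed default construction
#   loop = OmniLoop(
#       parser=parser,
#       provider=LLMProvider.OPENAI,
#       api_key="sk-...",            # placeholder credential
#       model="gpt-4o",              # example model name
#       computer=computer,
#   )
#   await loop.initialize()
#   async for response in loop.run([{"role": "user", "content": "Open the browser"}]):
#       ...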