diff --git a/libs/agent/agent/core/base.py b/libs/agent/agent/core/base.py
index fb91d855..702be207 100644
--- a/libs/agent/agent/core/base.py
+++ b/libs/agent/agent/core/base.py
@@ -5,10 +5,12 @@ import asyncio
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional
+from agent.providers.omni.parser import ParseResult
from computer import Computer
from .messages import StandardMessageManager, ImageRetentionConfig
from .types import AgentResponse
from .experiment import ExperimentManager
+from .callbacks import CallbackManager, CallbackHandler
logger = logging.getLogger(__name__)
@@ -27,6 +29,7 @@ class BaseLoop(ABC):
base_dir: Optional[str] = "trajectories",
save_trajectory: bool = True,
only_n_most_recent_images: Optional[int] = 2,
+ callback_handlers: Optional[List[CallbackHandler]] = None,
**kwargs,
):
"""Initialize base agent loop.
@@ -75,6 +78,9 @@ class BaseLoop(ABC):
# Initialize basic tracking
self.turn_count = 0
+
+ # Initialize callback manager
+ self.callback_manager = CallbackManager(handlers=callback_handlers or [])
async def initialize(self) -> None:
"""Initialize both the API client and computer interface with retries."""
@@ -187,3 +193,17 @@ class BaseLoop(ABC):
"""
if self.experiment_manager:
self.experiment_manager.save_screenshot(img_base64, action_type)
+
+ ###########################################
+ # EVENT HOOKS / CALLBACKS
+ ###########################################
+
+ async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+ """Process a screenshot through callback managers
+
+ Args:
+ screenshot_base64: Base64 encoded screenshot
+ action_type: Type of action that triggered the screenshot
+ """
+ if hasattr(self, 'callback_manager'):
+ await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)
diff --git a/libs/agent/agent/core/callbacks.py b/libs/agent/agent/core/callbacks.py
index 70eca5ad..59cd0e5a 100644
--- a/libs/agent/agent/core/callbacks.py
+++ b/libs/agent/agent/core/callbacks.py
@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol
+from agent.providers.omni.parser import ParseResult
+
logger = logging.getLogger(__name__)
class ContentCallback(Protocol):
@@ -20,6 +22,10 @@ class APICallback(Protocol):
"""Protocol for API callbacks."""
def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
+class ScreenshotCallback(Protocol):
+ """Protocol for screenshot callbacks."""
+    def __call__(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None: ...
+
class BaseCallbackManager(ABC):
"""Base class for callback managers."""
@@ -110,7 +116,20 @@ class CallbackManager:
"""
for handler in self.handlers:
await handler.on_error(error, **kwargs)
-
+
+ async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+ """Called when a screenshot is taken.
+
+ Args:
+ screenshot_base64: Base64 encoded screenshot
+ action_type: Type of action that triggered the screenshot
+ parsed_screen: Optional output from parsing the screenshot
+
+ Returns:
+ Modified screenshot or original if no modifications
+ """
+ for handler in self.handlers:
+ await handler.on_screenshot(screenshot_base64, action_type, parsed_screen)
class CallbackHandler(ABC):
"""Base class for callback handlers."""
@@ -144,4 +163,40 @@ class CallbackHandler(ABC):
error: Exception that occurred
**kwargs: Additional data
"""
- pass
\ No newline at end of file
+ pass
+
+ @abstractmethod
+ async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+ """Called when a screenshot is taken.
+
+ Args:
+ screenshot_base64: Base64 encoded screenshot
+ action_type: Type of action that triggered the screenshot
+
+ Returns:
+ Optional modified screenshot
+ """
+ pass
+
+class DefaultCallbackHandler(CallbackHandler):
+ """Default implementation of CallbackHandler with no-op methods.
+
+ This class implements all abstract methods from CallbackHandler,
+ allowing subclasses to override only the methods they need.
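+
+    Example (a minimal sketch; the class name is illustrative):
+
+        class ChatScreenshotLogger(DefaultCallbackHandler):
+            async def on_screenshot(self, screenshot_base64, action_type="", parsed_screen=None):
+                print(f"screenshot taken after {action_type}")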
+ """
+
+ async def on_action_start(self, action: str, **kwargs) -> None:
+ """Default no-op implementation."""
+ pass
+
+ async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
+ """Default no-op implementation."""
+ pass
+
+ async def on_error(self, error: Exception, **kwargs) -> None:
+ """Default no-op implementation."""
+ pass
+
+ async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
+ """Default no-op implementation."""
+ pass
\ No newline at end of file
diff --git a/libs/agent/agent/providers/omni/clients/oaicompat.py b/libs/agent/agent/providers/omni/clients/oaicompat.py
index bddc95e2..6a95896a 100644
--- a/libs/agent/agent/providers/omni/clients/oaicompat.py
+++ b/libs/agent/agent/providers/omni/clients/oaicompat.py
@@ -45,8 +45,8 @@ class OAICompatClient(BaseOmniClient):
max_tokens: Maximum tokens to generate
temperature: Generation temperature
"""
- super().__init__(api_key="EMPTY", model=model)
- self.api_key = "EMPTY" # Local endpoints typically don't require an API key
+ super().__init__(api_key=api_key or "EMPTY", model=model)
+ self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
self.model = model
self.provider_base_url = (
provider_base_url or "http://localhost:8000/v1"
@@ -146,10 +146,18 @@ class OAICompatClient(BaseOmniClient):
base_url = self.provider_base_url or "http://localhost:8000/v1"
# Check if the base URL already includes the chat/completions endpoint
+
endpoint_url = base_url
if not endpoint_url.endswith("/chat/completions"):
+ # If URL is RunPod format, make it OpenAI compatible
+ if endpoint_url.startswith("https://api.runpod.ai/v2/"):
+ # Extract RunPod endpoint ID
+ parts = endpoint_url.split("/")
+ if len(parts) >= 5:
+ runpod_id = parts[4]
+ endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
# If the URL ends with /v1, append /chat/completions
- if endpoint_url.endswith("/v1"):
+ elif endpoint_url.endswith("/v1"):
endpoint_url = f"{endpoint_url}/chat/completions"
# If the URL doesn't end with /v1, make sure it has a proper structure
elif not endpoint_url.endswith("/"):
diff --git a/libs/agent/agent/providers/omni/loop.py b/libs/agent/agent/providers/omni/loop.py
index 7fb80654..b53c120c 100644
--- a/libs/agent/agent/providers/omni/loop.py
+++ b/libs/agent/agent/providers/omni/loop.py
@@ -147,7 +147,7 @@ class OmniLoop(BaseLoop):
)
elif self.provider == LLMProvider.OAICOMPAT:
self.client = OAICompatClient(
- api_key="EMPTY", # Local endpoints typically don't require an API key
+ api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
@@ -183,7 +183,7 @@ class OmniLoop(BaseLoop):
)
elif self.provider == LLMProvider.OAICOMPAT:
self.client = OAICompatClient(
- api_key="EMPTY", # Local endpoints typically don't require an API key
+ api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
@@ -548,6 +548,10 @@ class OmniLoop(BaseLoop):
img_data = parsed_screen.annotated_image_base64
if "," in img_data:
img_data = img_data.split(",")[1]
+
+                        # Notify screenshot callback handlers before saving the state screenshot
+ await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
+
# Save with a generic "state" action type to indicate this is the current screen state
self._save_screenshot(img_data, action_type="state")
except Exception as e:
@@ -663,6 +667,8 @@ class OmniLoop(BaseLoop):
response=response,
messages=self.message_manager.messages,
model=self.model,
+ parsed_screen=parsed_screen,
+ parser=self.parser
)
# Yield the response to the caller
diff --git a/libs/agent/agent/providers/openai/loop.py b/libs/agent/agent/providers/openai/loop.py
index cfb9a443..8e507a1b 100644
--- a/libs/agent/agent/providers/openai/loop.py
+++ b/libs/agent/agent/providers/openai/loop.py
@@ -194,8 +194,13 @@ class OpenAILoop(BaseLoop):
# Convert to base64 if needed
if isinstance(screenshot, bytes):
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
+ elif isinstance(screenshot, (bytearray, memoryview)):
+ screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
else:
- screenshot_base64 = screenshot
+ screenshot_base64 = str(screenshot)
+
+ # Emit screenshot callbacks
+ await self.handle_screenshot(screenshot_base64, action_type="initial_state")
# Save screenshot if requested
if self.save_trajectory:
@@ -204,8 +209,6 @@ class OpenAILoop(BaseLoop):
logger.warning(
"Converting non-string screenshot_base64 to string for _save_screenshot"
)
- if isinstance(screenshot_base64, (bytearray, memoryview)):
- screenshot_base64 = base64.b64encode(screenshot_base64).decode("utf-8")
self._save_screenshot(screenshot_base64, action_type="state")
logger.info("Screenshot saved to trajectory")
@@ -336,8 +339,14 @@ class OpenAILoop(BaseLoop):
screenshot = await self.computer.interface.screenshot()
if isinstance(screenshot, bytes):
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
+ elif isinstance(screenshot, (bytearray, memoryview)):
+ screenshot_base64 = base64.b64encode(bytes(screenshot)).decode("utf-8")
else:
- screenshot_base64 = screenshot
+ screenshot_base64 = str(screenshot)
+
+ # Process screenshot through hooks
+ action_type = f"after_{action.get('type', 'action')}"
+ await self.handle_screenshot(screenshot_base64, action_type=action_type)
# Create computer_call_output
computer_call_output = {
diff --git a/libs/agent/agent/ui/gradio/app.py b/libs/agent/agent/ui/gradio/app.py
index 44027317..a3a017bd 100644
--- a/libs/agent/agent/ui/gradio/app.py
+++ b/libs/agent/agent/ui/gradio/app.py
@@ -32,9 +32,12 @@ import asyncio
import logging
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
import gradio as gr
+from gradio.components.chatbot import MetadataDict
# Import from agent package
from agent.core.types import AgentResponse
+from agent.core.callbacks import DefaultCallbackHandler
+from agent.providers.omni.parser import ParseResult
from computer import Computer
from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
@@ -43,6 +46,44 @@ from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
global_agent = None
global_computer = None
+# We'll use asyncio.run() instead of a persistent event loop
+
+
+# Custom Screenshot Handler for Gradio chat
+class GradioChatScreenshotHandler(DefaultCallbackHandler):
+ """Custom handler that adds screenshots to the Gradio chatbot and updates annotated image."""
+
+ def __init__(self, chatbot_history: List[gr.ChatMessage]):
+ """Initialize with reference to chat history and annotated image component.
+
+ Args:
+ chatbot_history: Reference to the Gradio chatbot history list
+ annotated_image: Reference to the annotated image component
+ """
+ self.chatbot_history = chatbot_history
+ print("GradioChatScreenshotHandler initialized")
+
+ async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+ """Add screenshot to chatbot when a screenshot is taken and update the annotated image.
+
+ Args:
+ screenshot_base64: Base64 encoded screenshot
+ action_type: Type of action that triggered the screenshot
+
+ Returns:
+ Original screenshot (does not modify it)
+ """
+ # Create a markdown image element for the screenshot
+ image_markdown = f""
+
+ # Simply append the screenshot as a new message
+ if self.chatbot_history is not None:
+ self.chatbot_history.append(gr.ChatMessage(
+ role="assistant",
+ content=image_markdown,
+ metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"}
+ ))
+
# Map model names to specific provider model names
MODEL_MAPPINGS = {
"openai": {
@@ -177,17 +218,18 @@ def get_ollama_models() -> List[str]:
return []
-def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> str:
+def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
"""Extract synthesized text from the agent result."""
synthesized_text = ""
+ metadata = MetadataDict()
if "output" in result and result["output"]:
for output in result["output"]:
if output.get("type") == "reasoning":
+ metadata["title"] = "🧠 Reasoning"
content = output.get("content", "")
if content:
synthesized_text += f"{content}\n"
-
elif output.get("type") == "message":
# Handle message type outputs - can contain rich content
content = output.get("content", [])
@@ -200,7 +242,7 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
if text_value:
synthesized_text += f"{text_value}\n"
- elif output.get("type") == "computer_call":
+ elif output.get("type") == "computer_call":
action = output.get("action", {})
action_type = action.get("type", "")
@@ -223,8 +265,11 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
synthesized_text += f"Pressed key: {key}\n"
else:
synthesized_text += f"Performed {action_type} action.\n"
-
- return synthesized_text.strip()
+
+ metadata["status"] = "done"
+ metadata["title"] = f"🛠️ {synthesized_text.strip().splitlines()[-1]}"
+
+ return synthesized_text.strip(), metadata
def create_computer_instance(verbosity: int = logging.INFO) -> Computer:
@@ -247,6 +292,7 @@ def create_agent(
verbosity: int = logging.INFO,
use_ollama: bool = False,
use_oaicompat: bool = False,
+ provider_base_url: Optional[str] = None,
) -> ComputerAgent:
"""Create or update the global agent with the specified parameters."""
global global_agent
@@ -270,72 +316,55 @@ def create_agent(
elif provider == LLMProvider.ANTHROPIC:
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
- # Create LLM model object with appropriate parameters
- provider_base_url = "http://localhost:1234/v1" if use_oaicompat else None
+ # Use provided provider_base_url if available, otherwise use default
+ default_base_url = "http://localhost:1234/v1" if use_oaicompat else None
+ custom_base_url = provider_base_url or default_base_url
if use_oaicompat:
# Special handling for OAICOMPAT - use OAICOMPAT provider with custom base URL
print(
- f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {provider_base_url}"
+ f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {custom_base_url}"
)
llm = LLM(
provider=LLMProvider.OAICOMPAT, # Set to OAICOMPAT instead of using original provider
name=model_name,
- provider_base_url=provider_base_url,
+ provider_base_url=custom_base_url,
)
print(f"DEBUG - LLM provider is now: {llm.provider}, base URL: {llm.provider_base_url}")
# Note: Don't pass use_oaicompat to the agent, as it doesn't accept this parameter
elif provider == LLMProvider.OAICOMPAT:
# This path is unlikely to be taken with our current approach
- llm = LLM(provider=provider, name=model_name, provider_base_url=provider_base_url)
+ llm = LLM(provider=provider, name=model_name, provider_base_url=custom_base_url)
else:
# For other providers, just use standard parameters
llm = LLM(provider=provider, name=model_name)
# Create or update the agent
- if global_agent is None:
- global_agent = ComputerAgent(
- computer=computer,
- loop=agent_loop,
- model=llm,
- api_key=api_key,
- save_trajectory=save_trajectory,
- only_n_most_recent_images=only_n_most_recent_images,
- verbosity=verbosity,
- **extra_config,
- )
- else:
- # Update the existing agent's parameters
- global_agent._loop = None # Force recreation of the loop
- global_agent.provider = provider
- global_agent.loop = agent_loop
- global_agent.model = llm
- global_agent.api_key = api_key
-
- # Explicitly update these settings to ensure they take effect
- global_agent.save_trajectory = save_trajectory
- global_agent.only_n_most_recent_images = only_n_most_recent_images
-
- # Update Ollama settings if applicable
- if use_ollama:
- global_agent.use_ollama = True
- global_agent.ollama_model = model_name
- else:
- global_agent.use_ollama = False
- global_agent.ollama_model = None
-
- # Log the updated settings
- logging.info(
- f"Updated agent settings: save_trajectory={save_trajectory}, recent_images={only_n_most_recent_images}"
- )
+ global_agent = ComputerAgent(
+ computer=computer,
+ loop=agent_loop,
+ model=llm,
+ api_key=api_key,
+ save_trajectory=save_trajectory,
+ only_n_most_recent_images=only_n_most_recent_images,
+ verbosity=verbosity,
+ **extra_config,
+ )
return global_agent
-def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
+def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
"""Process agent results for the Gradio UI."""
# Extract text content
text_obj = result.get("text", {})
+ metadata = result.get("metadata", {})
+
+ # Create a properly typed MetadataDict
+ metadata_dict = MetadataDict()
+ metadata_dict["title"] = metadata.get("title", "")
+ metadata_dict["status"] = "done"
+ metadata = metadata_dict
# For OpenAI's Computer-Use Agent, text field is an object with format property
if (
@@ -344,8 +373,11 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
and "format" in text_obj
and not text_obj.get("value", "")
):
- content = extract_synthesized_text(result)
+ content, metadata = extract_synthesized_text(result)
else:
+ if not text_obj:
+ text_obj = result
+
# For other types of results, try to get text directly
if isinstance(text_obj, dict):
if "value" in text_obj:
@@ -378,180 +410,7 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
if not isinstance(content, str):
content = str(content) if content else ""
- return content
-
-
-def respond(
- message: str,
- history: List[Tuple[str, str]],
- model_choice, # Accept Gradio Dropdown component
- agent_loop, # Accept Gradio Dropdown component
- save_trajectory, # Accept Gradio Checkbox component
- recent_images, # Accept Gradio Slider component
- openai_api_key: Optional[str] = None,
- anthropic_api_key: Optional[str] = None,
-) -> str:
- """Process a message with the Computer-Use Agent and return the response."""
- import asyncio
-
- # Get actual values from Gradio components
- model_choice_value = model_choice.value if hasattr(model_choice, "value") else model_choice
- agent_loop_value = agent_loop.value if hasattr(agent_loop, "value") else agent_loop
- save_trajectory_value = (
- save_trajectory.value if hasattr(save_trajectory, "value") else save_trajectory
- )
- recent_images_value = int(
- recent_images.value if hasattr(recent_images, "value") else recent_images
- )
-
- # Debug logging
- print(f"DEBUG - Model choice object: {type(model_choice)}")
- print(f"DEBUG - Model choice value: {model_choice_value}")
- print(f"DEBUG - Agent loop value: {agent_loop_value}")
-
- # Create a new event loop for this function call
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- async def _async_respond():
- # Extract the loop type and model from the selection
- loop_provider = "OPENAI"
- if isinstance(model_choice_value, str):
- # This is the case for a custom text input from textbox
- if agent_loop_value == "OMNI":
- loop_provider = "OMNI"
- # Use the custom model name as is
- model_id = model_choice_value
- print(f"DEBUG - Using custom model: {model_id}")
- else:
- # Handle regular dropdown value as string
- if model_choice_value.startswith("OpenAI:"):
- loop_provider = "OPENAI"
- model_id = model_choice_value.replace("OpenAI: ", "").lower()
- elif model_choice_value.startswith("Anthropic:"):
- loop_provider = "ANTHROPIC"
- model_id = model_choice_value.replace("Anthropic: ", "").lower()
- elif model_choice_value.startswith("OMNI:"):
- loop_provider = "OMNI"
- if "GPT" in model_choice_value:
- model_id = model_choice_value.replace("OMNI: OpenAI ", "").lower()
- elif "Claude" in model_choice_value:
- model_id = model_choice_value.replace("OMNI: ", "").lower()
- elif "Ollama" in model_choice_value:
- loop_provider = "OMNI-OLLAMA"
- # Extract everything after "OMNI: Ollama " which is the full model name (e.g., phi3:latest)
- model_id = model_choice_value.replace("OMNI: Ollama ", "")
- print(f"DEBUG - Ollama model ID: {model_id}")
- else:
- model_id = "default"
- else:
- # Default case
- loop_provider = agent_loop_value
- model_id = "default"
- else:
- # Model choice is not a string (shouldn't happen, but handle anyway)
- loop_provider = agent_loop_value
- model_id = "default"
-
- print(f"DEBUG - Using loop provider: {loop_provider}, model_id: {model_id}")
-
- # Use the mapping function to get provider, model name and agent loop
- provider, model_name, agent_loop_type = get_provider_and_model(model_id, loop_provider)
- print(
- f"DEBUG - After mapping: provider={provider}, model_name={model_name}, agent_loop={agent_loop_type}"
- )
-
- # Special handling for OAICOMPAT to bypass provider-specific errors
- # Creates the agent with OPENAI provider but using custom model name and provider base URL
- is_oaicompat = str(provider) == "oaicompat"
-
- # Don't override the provider for OAICOMPAT - instead pass it through
- # if is_oaicompat:
- # provider = LLMProvider.OPENAI
-
- # Get API key based on provider
- if provider == LLMProvider.OPENAI:
- api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
- elif provider == LLMProvider.ANTHROPIC:
- api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
- else:
- api_key = ""
-
- # Check for settings changes if agent already exists
- settings_changed = False
- settings_message = ""
- if global_agent is not None:
- # Safely check if save_trajectory setting changed
- current_save_traj = getattr(global_agent, "save_trajectory", None)
- if current_save_traj is not None and current_save_traj != save_trajectory_value:
- settings_changed = True
- settings_message += f"Save trajectory set to: {save_trajectory_value}. "
-
- # Safely check if recent_images setting changed
- current_recent_images = getattr(global_agent, "only_n_most_recent_images", None)
- if current_recent_images is not None and current_recent_images != recent_images_value:
- settings_changed = True
- settings_message += f"Recent images set to: {recent_images_value}. "
-
- # Create or update the agent
- try:
- create_agent(
- provider=provider,
- agent_loop=agent_loop_type,
- model_name=model_name,
- api_key=api_key,
- save_trajectory=save_trajectory_value,
- only_n_most_recent_images=recent_images_value,
- use_ollama=loop_provider == "OMNI-OLLAMA",
- use_oaicompat=is_oaicompat,
- )
-
- if global_agent is None:
- return "Failed to create agent. Check API keys and configuration."
- except Exception as e:
- return f"Error creating agent: {str(e)}"
-
- # Notify about settings changes if needed
- if settings_changed:
- return f"Settings updated: {settings_message}"
-
- # Collect all responses
- response_text = []
-
- # Run the agent
- try:
- async for result in global_agent.run(message):
- # Process result
- content = process_agent_result(result)
-
- # Skip empty content
- if not content:
- continue
-
- # Add content to response list
- response_text.append(content)
-
- # Return the full response as a single string
- return "\n".join(response_text) if response_text else "Task completed."
-
- except Exception as e:
- import traceback
-
- traceback.print_exc()
- return f"Error: {str(e)}"
-
- # Run the async function and get the result
- try:
- result = loop.run_until_complete(_async_respond())
- loop.close()
- return result
- except Exception as e:
- loop.close()
- import traceback
-
- traceback.print_exc()
- return f"Error executing async operation: {str(e)}"
-
+ return content, metadata
def create_gradio_ui(
provider_name: str = "openai",
@@ -725,48 +584,68 @@ def create_gradio_ui(
"""
)
- # Configuration options
- agent_loop = gr.Dropdown(
- choices=["OPENAI", "ANTHROPIC", "OMNI"],
- label="Agent Loop",
- value=initial_loop,
- info="Select the agent loop provider",
- )
+ with gr.Accordion("Configuration", open=True):
+ # Configuration options
+ agent_loop = gr.Dropdown(
+ choices=["OPENAI", "ANTHROPIC", "OMNI"],
+ label="Agent Loop",
+ value=initial_loop,
+ info="Select the agent loop provider",
+ )
- # Create model selection dropdown with custom value support for OMNI
- model_choice = gr.Dropdown(
- choices=provider_to_models.get(initial_loop, ["No models available"]),
- label="LLM Provider and Model",
- value=initial_model,
- info="Select model or choose 'Custom model...' to enter a custom name",
- interactive=True,
- )
+ # Create model selection dropdown with custom value support for OMNI
+ model_choice = gr.Dropdown(
+ choices=provider_to_models.get(initial_loop, ["No models available"]),
+ label="LLM Provider and Model",
+ value=initial_model,
+ info="Select model or choose 'Custom model...' to enter a custom name",
+ interactive=True,
+ )
- # Add custom model textbox (only visible when "Custom model..." is selected)
- custom_model = gr.Textbox(
- label="Custom Model Name",
- placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
- value="Qwen2.5-VL-7B-Instruct", # Default value
- visible=False, # Initially hidden
- interactive=True,
- )
+ # Add custom model textbox (only visible when "Custom model..." is selected)
+ custom_model = gr.Textbox(
+ label="Custom Model Name",
+ placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
+ value="Qwen2.5-VL-7B-Instruct", # Default value
+ visible=False, # Initially hidden
+ interactive=True,
+ )
+
+ # Add custom provider base URL textbox (only visible when "Custom model..." is selected)
+ provider_base_url = gr.Textbox(
+ label="Provider Base URL",
+ placeholder="Enter provider base URL (e.g., http://localhost:1234/v1)",
+ value="http://localhost:1234/v1", # Default value
+ visible=False, # Initially hidden
+ interactive=True,
+ )
+
+ # Add custom API key textbox (only visible when "Custom model..." is selected)
+ provider_api_key = gr.Textbox(
+ label="Provider API Key",
+ placeholder="Enter provider API key (if required)",
+ value="", # Default empty value
+ visible=False, # Initially hidden
+ interactive=True,
+ type="password", # Hide the API key
+ )
- save_trajectory = gr.Checkbox(
- label="Save Trajectory",
- value=True,
- info="Save the agent's trajectory for debugging",
- interactive=True,
- )
+ save_trajectory = gr.Checkbox(
+ label="Save Trajectory",
+ value=True,
+ info="Save the agent's trajectory for debugging",
+ interactive=True,
+ )
- recent_images = gr.Slider(
- label="Recent Images",
- minimum=1,
- maximum=10,
- value=3,
- step=1,
- info="Number of recent images to keep in context",
- interactive=True,
- )
+ recent_images = gr.Slider(
+ label="Recent Images",
+ minimum=1,
+ maximum=10,
+ value=3,
+ step=1,
+ info="Number of recent images to keep in context",
+ interactive=True,
+ )
# Right column for chat interface
with gr.Column(scale=2):
@@ -775,7 +654,7 @@ def create_gradio_ui(
"Ask me to perform tasks in a virtual macOS environment.
Built with github.com/trycua/cua."
)
- chatbot = gr.Chatbot()
+ chatbot_history = gr.Chatbot(type='messages')
msg = gr.Textbox(
placeholder="Ask me to perform tasks in a virtual macOS environment"
)
@@ -787,24 +666,27 @@ def create_gradio_ui(
# Function to handle chat submission
def chat_submit(message, history):
# Add user message to history
- history = history + [(message, None)]
+ history.append(gr.ChatMessage(role="user", content=message))
return "", history
# Function to process agent response after user input
- def process_response(
+ async def process_response(
history,
model_choice_value,
custom_model_value,
agent_loop_choice,
save_traj,
recent_imgs,
+ custom_url_value=None,
+ custom_api_key=None,
):
if not history:
- return history
+ yield history
+ return
# Get the last user message
- last_user_message = history[-1][0]
-
+ last_user_message = history[-1]['content']
+
# Use custom model value if "Custom model..." is selected
model_to_use = (
custom_model_value
@@ -812,38 +694,94 @@ def create_gradio_ui(
else model_choice_value
)
- # Process with agent
- response = respond(
- last_user_message,
- history[:-1], # History without the last message
- model_to_use,
- agent_loop_choice,
- save_traj,
- recent_imgs,
- openai_api_key,
- anthropic_api_key,
- )
-
- # Update the last assistant message
- history[-1] = (last_user_message, response)
- return history
-
+ try:
+ # Get the model, agent loop, and provider
+ provider, model_name, agent_loop_type = get_provider_and_model(
+ model_to_use, agent_loop_choice
+ )
+
+ # Special handling for OAICOMPAT
+ is_oaicompat = str(provider) == "oaicompat"
+
+ # Get API key based on provider
+ if model_choice_value == "Custom model..." and custom_api_key:
+ # Use custom API key if provided for custom model
+ api_key = custom_api_key
+ print(f"DEBUG - Using custom API key for model: {model_name}")
+ elif provider == LLMProvider.OPENAI:
+ api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
+ elif provider == LLMProvider.ANTHROPIC:
+ api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
+ else:
+ api_key = ""
+
+ # Create or update the agent
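+                # Note: provider_base_url is only forwarded for custom OAICOMPAT models; other providers ignore it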
+ create_agent(
+ provider=provider,
+ agent_loop=agent_loop_type,
+ model_name=model_name,
+ api_key=api_key,
+ save_trajectory=save_traj,
+ only_n_most_recent_images=recent_imgs,
+ use_ollama=agent_loop_choice == "OMNI-OLLAMA",
+ use_oaicompat=is_oaicompat,
+ provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
+ )
+
+ if global_agent is None:
+ # Add initial empty assistant message
+ history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
+ yield history
+ return
+
+ # Add the screenshot handler to the agent's loop if available
+ if global_agent and hasattr(global_agent, "_loop"):
+ print("DEBUG - Adding screenshot handler to agent loop")
+
+                    # Create the screenshot handler with a reference to the chat history
+                    screenshot_handler = GradioChatScreenshotHandler(history)
+
+ # Add the handler to the callback manager if it exists
+ if hasattr(global_agent._loop, "callback_manager"):
+ global_agent._loop.callback_manager.add_handler(screenshot_handler)
+ print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
+
+ # Stream responses from the agent
+ async for result in global_agent.run(last_user_message):
+ # Process result
+ content, metadata = process_agent_result(result)
+
+                    # Only append messages that have content or a metadata title
+                    if content or metadata.get("title"):
+ history.append(gr.ChatMessage(role="assistant", content=content, metadata=metadata))
+ yield history
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ # Update with error message
+ history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
+ yield history
+
# Connect the components
- msg.submit(chat_submit, [msg, chatbot], [msg, chatbot]).then(
+ msg.submit(chat_submit, [msg, chatbot_history], [msg, chatbot_history]).then(
process_response,
[
- chatbot,
+ chatbot_history,
model_choice,
custom_model,
agent_loop,
save_trajectory,
recent_images,
+ provider_base_url,
+ provider_api_key,
],
- [chatbot],
+ [chatbot_history],
)
# Clear button functionality
- clear.click(lambda: None, None, chatbot, queue=False)
+ clear.click(lambda: None, None, chatbot_history, queue=False)
# Connect agent_loop changes to model selection
agent_loop.change(
@@ -853,14 +791,15 @@ def create_gradio_ui(
queue=False, # Process immediately without queueing
)
- # Show/hide custom model textbox based on dropdown selection
+ # Show/hide custom model, provider base URL, and API key textboxes based on dropdown selection
def update_custom_model_visibility(model_value):
- return gr.update(visible=model_value == "Custom model...")
+ is_custom = model_value == "Custom model..."
+ return gr.update(visible=is_custom), gr.update(visible=is_custom), gr.update(visible=is_custom)
model_choice.change(
fn=update_custom_model_visibility,
inputs=[model_choice],
- outputs=[custom_model],
+ outputs=[custom_model, provider_base_url, provider_api_key],
queue=False, # Process immediately without queueing
)
diff --git a/libs/som/som/ocr.py b/libs/som/som/ocr.py
index 6d10e85a..32f15bd1 100644
--- a/libs/som/som/ocr.py
+++ b/libs/som/som/ocr.py
@@ -17,17 +17,28 @@ class TimeoutException(Exception):
@contextmanager
def timeout(seconds: int):
- def timeout_handler(signum, frame):
- raise TimeoutException("OCR process timed out")
+ import threading
+
+ # Check if we're in the main thread
+ if threading.current_thread() is threading.main_thread():
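+        # SIGALRM handlers can only be installed from the main thread of the main interpreter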
+ def timeout_handler(signum, frame):
+ raise TimeoutException("OCR process timed out")
- original_handler = signal.signal(signal.SIGALRM, timeout_handler)
- signal.alarm(seconds)
+ original_handler = signal.signal(signal.SIGALRM, timeout_handler)
+ signal.alarm(seconds)
- try:
- yield
- finally:
- signal.alarm(0)
- signal.signal(signal.SIGALRM, original_handler)
+ try:
+ yield
+ finally:
+ signal.alarm(0)
+ signal.signal(signal.SIGALRM, original_handler)
+ else:
+ # In a non-main thread, we can't use signal
+ logger.warning("Timeout function called from non-main thread; signal-based timeout disabled")
+        yield
class OCRProcessor: