Merge pull request #104 from ddupont808/feature/gradio-upgrade

[Agent] Improved Gradio UI
This commit is contained in:
f-trycua
2025-04-14 03:54:41 +02:00
committed by GitHub
7 changed files with 356 additions and 308 deletions
+20
View File
@@ -5,10 +5,12 @@ import asyncio
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional
from agent.providers.omni.parser import ParseResult
from computer import Computer
from .messages import StandardMessageManager, ImageRetentionConfig
from .types import AgentResponse
from .experiment import ExperimentManager
from .callbacks import CallbackManager, CallbackHandler
logger = logging.getLogger(__name__)
@@ -27,6 +29,7 @@ class BaseLoop(ABC):
base_dir: Optional[str] = "trajectories",
save_trajectory: bool = True,
only_n_most_recent_images: Optional[int] = 2,
callback_handlers: Optional[List[CallbackHandler]] = None,
**kwargs,
):
"""Initialize base agent loop.
@@ -75,6 +78,9 @@ class BaseLoop(ABC):
# Initialize basic tracking
self.turn_count = 0
# Initialize callback manager
self.callback_manager = CallbackManager(handlers=callback_handlers or [])
async def initialize(self) -> None:
"""Initialize both the API client and computer interface with retries."""
@@ -187,3 +193,17 @@ class BaseLoop(ABC):
"""
if self.experiment_manager:
self.experiment_manager.save_screenshot(img_base64, action_type)
###########################################
# EVENT HOOKS / CALLBACKS
###########################################
async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
"""Forward a screenshot to the loop's callback manager.
Args:
screenshot_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
parsed_screen: Optional parse result for the screenshot; passed to the callback manager unchanged
"""
# Dispatch is skipped entirely when this instance has no callback_manager attribute.
if hasattr(self, 'callback_manager'):
await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)
+57 -2
View File
@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol
from agent.providers.omni.parser import ParseResult
logger = logging.getLogger(__name__)
class ContentCallback(Protocol):
@@ -20,6 +22,10 @@ class APICallback(Protocol):
"""Protocol for API callbacks."""
def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
class ScreenshotCallback(Protocol):
"""Protocol for screenshot callbacks.
A conforming callable takes the base64 encoded screenshot plus an optional action type and returns an optional string.
"""
# NOTE(review): the on_screenshot handler methods in this module also accept parsed_screen and return None - confirm whether this protocol should match them.
def __call__(self, screenshot_base64: str, action_type: str = "") -> Optional[str]: ...
class BaseCallbackManager(ABC):
"""Base class for callback managers."""
@@ -110,7 +116,20 @@ class CallbackManager:
"""
for handler in self.handlers:
await handler.on_error(error, **kwargs)
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
"""Notify every registered handler that a screenshot was taken.
Handlers are awaited one at a time, in the order they appear in self.handlers.
Args:
screenshot_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
parsed_screen: Optional output from parsing the screenshot
Returns:
None; whatever a handler returns is discarded
"""
for handler in self.handlers:
await handler.on_screenshot(screenshot_base64, action_type, parsed_screen)
class CallbackHandler(ABC):
"""Base class for callback handlers."""
@@ -144,4 +163,40 @@ class CallbackHandler(ABC):
error: Exception that occurred
**kwargs: Additional data
"""
pass
pass
@abstractmethod
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
"""Called when a screenshot is taken.
Args:
screenshot_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
parsed_screen: Optional parse result accompanying the screenshot (None when not supplied)
Returns:
None
"""
pass
class DefaultCallbackHandler(CallbackHandler):
    """Default implementation of CallbackHandler with no-op methods.

    This class implements all abstract methods from CallbackHandler,
    allowing subclasses to override only the methods they need.
    """

    async def on_action_start(self, action: str, **kwargs) -> None:
        """Default no-op implementation."""
        pass

    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Default no-op implementation."""
        pass

    async def on_error(self, error: Exception, **kwargs) -> None:
        """Default no-op implementation."""
        pass

    async def on_screenshot(
        self,
        screenshot_base64: str,
        action_type: str = "",
        parsed_screen: Optional[ParseResult] = None,
    ) -> None:
        """Default no-op implementation.

        The signature mirrors the abstract ``CallbackHandler.on_screenshot``:
        the callback manager passes ``parsed_screen`` as a third positional
        argument, so a handler that keeps this default must accept it.

        Args:
            screenshot_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
            parsed_screen: Optional parse result accompanying the screenshot
        """
        pass
@@ -45,8 +45,8 @@ class OAICompatClient(BaseOmniClient):
max_tokens: Maximum tokens to generate
temperature: Generation temperature
"""
super().__init__(api_key="EMPTY", model=model)
self.api_key = "EMPTY" # Local endpoints typically don't require an API key
super().__init__(api_key=api_key or "EMPTY", model=model)
self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
self.model = model
self.provider_base_url = (
provider_base_url or "http://localhost:8000/v1"
@@ -146,10 +146,18 @@ class OAICompatClient(BaseOmniClient):
base_url = self.provider_base_url or "http://localhost:8000/v1"
# Check if the base URL already includes the chat/completions endpoint
endpoint_url = base_url
if not endpoint_url.endswith("/chat/completions"):
# If URL is RunPod format, make it OpenAI compatible
if endpoint_url.startswith("https://api.runpod.ai/v2/"):
# Extract RunPod endpoint ID
parts = endpoint_url.split("/")
if len(parts) >= 5:
runpod_id = parts[4]
endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
# If the URL ends with /v1, append /chat/completions
if endpoint_url.endswith("/v1"):
elif endpoint_url.endswith("/v1"):
endpoint_url = f"{endpoint_url}/chat/completions"
# If the URL doesn't end with /v1, make sure it has a proper structure
elif not endpoint_url.endswith("/"):
+8 -2
View File
@@ -147,7 +147,7 @@ class OmniLoop(BaseLoop):
)
elif self.provider == LLMProvider.OAICOMPAT:
self.client = OAICompatClient(
api_key="EMPTY", # Local endpoints typically don't require an API key
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
@@ -183,7 +183,7 @@ class OmniLoop(BaseLoop):
)
elif self.provider == LLMProvider.OAICOMPAT:
self.client = OAICompatClient(
api_key="EMPTY", # Local endpoints typically don't require an API key
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
@@ -548,6 +548,10 @@ class OmniLoop(BaseLoop):
img_data = parsed_screen.annotated_image_base64
if "," in img_data:
img_data = img_data.split(",")[1]
# Process screenshot through hooks and save if needed
await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
# Save with a generic "state" action type to indicate this is the current screen state
self._save_screenshot(img_data, action_type="state")
except Exception as e:
@@ -663,6 +667,8 @@ class OmniLoop(BaseLoop):
response=response,
messages=self.message_manager.messages,
model=self.model,
parsed_screen=parsed_screen,
parser=self.parser
)
# Yield the response to the caller
+13 -4
View File
@@ -194,8 +194,13 @@ class OpenAILoop(BaseLoop):
# Convert to base64 if needed
if isinstance(screenshot, bytes):
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
elif isinstance(screenshot, (bytearray, memoryview)):
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
else:
screenshot_base64 = screenshot
screenshot_base64 = str(screenshot)
# Emit screenshot callbacks
await self.handle_screenshot(screenshot_base64, action_type="initial_state")
# Save screenshot if requested
if self.save_trajectory:
@@ -204,8 +209,6 @@ class OpenAILoop(BaseLoop):
logger.warning(
"Converting non-string screenshot_base64 to string for _save_screenshot"
)
if isinstance(screenshot_base64, (bytearray, memoryview)):
screenshot_base64 = base64.b64encode(screenshot_base64).decode("utf-8")
self._save_screenshot(screenshot_base64, action_type="state")
logger.info("Screenshot saved to trajectory")
@@ -336,8 +339,14 @@ class OpenAILoop(BaseLoop):
screenshot = await self.computer.interface.screenshot()
if isinstance(screenshot, bytes):
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
elif isinstance(screenshot, (bytearray, memoryview)):
screenshot_base64 = base64.b64encode(bytes(screenshot)).decode("utf-8")
else:
screenshot_base64 = screenshot
screenshot_base64 = str(screenshot)
# Process screenshot through hooks
action_type = f"after_{action.get('type', 'action')}"
await self.handle_screenshot(screenshot_base64, action_type=action_type)
# Create computer_call_output
computer_call_output = {
+227 -288
View File
@@ -32,9 +32,12 @@ import asyncio
import logging
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
import gradio as gr
from gradio.components.chatbot import MetadataDict
# Import from agent package
from agent.core.types import AgentResponse
from agent.core.callbacks import DefaultCallbackHandler
from agent.providers.omni.parser import ParseResult
from computer import Computer
from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
@@ -43,6 +46,44 @@ from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
global_agent = None
global_computer = None
# We'll use asyncio.run() instead of a persistent event loop
# Custom Screenshot Handler for Gradio chat
class GradioChatScreenshotHandler(DefaultCallbackHandler):
"""Custom handler that appends each screenshot to a Gradio chat history list as an assistant message."""
def __init__(self, chatbot_history: List[gr.ChatMessage]):
"""Initialize with a reference to the chat history.
Args:
chatbot_history: Reference to the Gradio chatbot history list; the handler appends to this same list object
"""
self.chatbot_history = chatbot_history
print("GradioChatScreenshotHandler initialized")
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
"""Append the screenshot to the chat history as an inline markdown image.
Args:
screenshot_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot; used in the image alt text and the message title
parsed_screen: Accepted for interface compatibility; not used by this handler
Returns:
None; the screenshot itself is not modified
"""
# Create a markdown image element for the screenshot (embedded as a PNG data URI)
image_markdown = f"![Screenshot after {action_type}](data:image/png;base64,{screenshot_base64})"
# Simply append the screenshot as a new message
if self.chatbot_history is not None:
self.chatbot_history.append(gr.ChatMessage(
role="assistant",
content=image_markdown,
metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"}
))
# Map model names to specific provider model names
MODEL_MAPPINGS = {
"openai": {
@@ -177,17 +218,18 @@ def get_ollama_models() -> List[str]:
return []
def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> str:
def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
"""Extract synthesized text from the agent result."""
synthesized_text = ""
metadata = MetadataDict()
if "output" in result and result["output"]:
for output in result["output"]:
if output.get("type") == "reasoning":
metadata["title"] = "🧠 Reasoning"
content = output.get("content", "")
if content:
synthesized_text += f"{content}\n"
elif output.get("type") == "message":
# Handle message type outputs - can contain rich content
content = output.get("content", [])
@@ -200,7 +242,7 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
if text_value:
synthesized_text += f"{text_value}\n"
elif output.get("type") == "computer_call":
elif output.get("type") == "computer_call":
action = output.get("action", {})
action_type = action.get("type", "")
@@ -223,8 +265,11 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
synthesized_text += f"Pressed key: {key}\n"
else:
synthesized_text += f"Performed {action_type} action.\n"
return synthesized_text.strip()
metadata["status"] = "done"
metadata["title"] = f"🛠️ {synthesized_text.strip().splitlines()[-1]}"
return synthesized_text.strip(), metadata
def create_computer_instance(verbosity: int = logging.INFO) -> Computer:
@@ -247,6 +292,7 @@ def create_agent(
verbosity: int = logging.INFO,
use_ollama: bool = False,
use_oaicompat: bool = False,
provider_base_url: Optional[str] = None,
) -> ComputerAgent:
"""Create or update the global agent with the specified parameters."""
global global_agent
@@ -270,72 +316,55 @@ def create_agent(
elif provider == LLMProvider.ANTHROPIC:
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
# Create LLM model object with appropriate parameters
provider_base_url = "http://localhost:1234/v1" if use_oaicompat else None
# Use provided provider_base_url if available, otherwise use default
default_base_url = "http://localhost:1234/v1" if use_oaicompat else None
custom_base_url = provider_base_url or default_base_url
if use_oaicompat:
# Special handling for OAICOMPAT - use OAICOMPAT provider with custom base URL
print(
f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {provider_base_url}"
f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {custom_base_url}"
)
llm = LLM(
provider=LLMProvider.OAICOMPAT, # Set to OAICOMPAT instead of using original provider
name=model_name,
provider_base_url=provider_base_url,
provider_base_url=custom_base_url,
)
print(f"DEBUG - LLM provider is now: {llm.provider}, base URL: {llm.provider_base_url}")
# Note: Don't pass use_oaicompat to the agent, as it doesn't accept this parameter
elif provider == LLMProvider.OAICOMPAT:
# This path is unlikely to be taken with our current approach
llm = LLM(provider=provider, name=model_name, provider_base_url=provider_base_url)
llm = LLM(provider=provider, name=model_name, provider_base_url=custom_base_url)
else:
# For other providers, just use standard parameters
llm = LLM(provider=provider, name=model_name)
# Create or update the agent
if global_agent is None:
global_agent = ComputerAgent(
computer=computer,
loop=agent_loop,
model=llm,
api_key=api_key,
save_trajectory=save_trajectory,
only_n_most_recent_images=only_n_most_recent_images,
verbosity=verbosity,
**extra_config,
)
else:
# Update the existing agent's parameters
global_agent._loop = None # Force recreation of the loop
global_agent.provider = provider
global_agent.loop = agent_loop
global_agent.model = llm
global_agent.api_key = api_key
# Explicitly update these settings to ensure they take effect
global_agent.save_trajectory = save_trajectory
global_agent.only_n_most_recent_images = only_n_most_recent_images
# Update Ollama settings if applicable
if use_ollama:
global_agent.use_ollama = True
global_agent.ollama_model = model_name
else:
global_agent.use_ollama = False
global_agent.ollama_model = None
# Log the updated settings
logging.info(
f"Updated agent settings: save_trajectory={save_trajectory}, recent_images={only_n_most_recent_images}"
)
global_agent = ComputerAgent(
computer=computer,
loop=agent_loop,
model=llm,
api_key=api_key,
save_trajectory=save_trajectory,
only_n_most_recent_images=only_n_most_recent_images,
verbosity=verbosity,
**extra_config,
)
return global_agent
def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
"""Process agent results for the Gradio UI."""
# Extract text content
text_obj = result.get("text", {})
metadata = result.get("metadata", {})
# Create a properly typed MetadataDict
metadata_dict = MetadataDict()
metadata_dict["title"] = metadata.get("title", "")
metadata_dict["status"] = "done"
metadata = metadata_dict
# For OpenAI's Computer-Use Agent, text field is an object with format property
if (
@@ -344,8 +373,11 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
and "format" in text_obj
and not text_obj.get("value", "")
):
content = extract_synthesized_text(result)
content, metadata = extract_synthesized_text(result)
else:
if not text_obj:
text_obj = result
# For other types of results, try to get text directly
if isinstance(text_obj, dict):
if "value" in text_obj:
@@ -378,180 +410,7 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
if not isinstance(content, str):
content = str(content) if content else ""
return content
def respond(
message: str,
history: List[Tuple[str, str]],
model_choice, # Accept Gradio Dropdown component
agent_loop, # Accept Gradio Dropdown component
save_trajectory, # Accept Gradio Checkbox component
recent_images, # Accept Gradio Slider component
openai_api_key: Optional[str] = None,
anthropic_api_key: Optional[str] = None,
) -> str:
"""Process a message with the Computer-Use Agent and return the response."""
import asyncio
# Get actual values from Gradio components
model_choice_value = model_choice.value if hasattr(model_choice, "value") else model_choice
agent_loop_value = agent_loop.value if hasattr(agent_loop, "value") else agent_loop
save_trajectory_value = (
save_trajectory.value if hasattr(save_trajectory, "value") else save_trajectory
)
recent_images_value = int(
recent_images.value if hasattr(recent_images, "value") else recent_images
)
# Debug logging
print(f"DEBUG - Model choice object: {type(model_choice)}")
print(f"DEBUG - Model choice value: {model_choice_value}")
print(f"DEBUG - Agent loop value: {agent_loop_value}")
# Create a new event loop for this function call
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def _async_respond():
# Extract the loop type and model from the selection
loop_provider = "OPENAI"
if isinstance(model_choice_value, str):
# This is the case for a custom text input from textbox
if agent_loop_value == "OMNI":
loop_provider = "OMNI"
# Use the custom model name as is
model_id = model_choice_value
print(f"DEBUG - Using custom model: {model_id}")
else:
# Handle regular dropdown value as string
if model_choice_value.startswith("OpenAI:"):
loop_provider = "OPENAI"
model_id = model_choice_value.replace("OpenAI: ", "").lower()
elif model_choice_value.startswith("Anthropic:"):
loop_provider = "ANTHROPIC"
model_id = model_choice_value.replace("Anthropic: ", "").lower()
elif model_choice_value.startswith("OMNI:"):
loop_provider = "OMNI"
if "GPT" in model_choice_value:
model_id = model_choice_value.replace("OMNI: OpenAI ", "").lower()
elif "Claude" in model_choice_value:
model_id = model_choice_value.replace("OMNI: ", "").lower()
elif "Ollama" in model_choice_value:
loop_provider = "OMNI-OLLAMA"
# Extract everything after "OMNI: Ollama " which is the full model name (e.g., phi3:latest)
model_id = model_choice_value.replace("OMNI: Ollama ", "")
print(f"DEBUG - Ollama model ID: {model_id}")
else:
model_id = "default"
else:
# Default case
loop_provider = agent_loop_value
model_id = "default"
else:
# Model choice is not a string (shouldn't happen, but handle anyway)
loop_provider = agent_loop_value
model_id = "default"
print(f"DEBUG - Using loop provider: {loop_provider}, model_id: {model_id}")
# Use the mapping function to get provider, model name and agent loop
provider, model_name, agent_loop_type = get_provider_and_model(model_id, loop_provider)
print(
f"DEBUG - After mapping: provider={provider}, model_name={model_name}, agent_loop={agent_loop_type}"
)
# Special handling for OAICOMPAT to bypass provider-specific errors
# Creates the agent with OPENAI provider but using custom model name and provider base URL
is_oaicompat = str(provider) == "oaicompat"
# Don't override the provider for OAICOMPAT - instead pass it through
# if is_oaicompat:
# provider = LLMProvider.OPENAI
# Get API key based on provider
if provider == LLMProvider.OPENAI:
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
elif provider == LLMProvider.ANTHROPIC:
api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
else:
api_key = ""
# Check for settings changes if agent already exists
settings_changed = False
settings_message = ""
if global_agent is not None:
# Safely check if save_trajectory setting changed
current_save_traj = getattr(global_agent, "save_trajectory", None)
if current_save_traj is not None and current_save_traj != save_trajectory_value:
settings_changed = True
settings_message += f"Save trajectory set to: {save_trajectory_value}. "
# Safely check if recent_images setting changed
current_recent_images = getattr(global_agent, "only_n_most_recent_images", None)
if current_recent_images is not None and current_recent_images != recent_images_value:
settings_changed = True
settings_message += f"Recent images set to: {recent_images_value}. "
# Create or update the agent
try:
create_agent(
provider=provider,
agent_loop=agent_loop_type,
model_name=model_name,
api_key=api_key,
save_trajectory=save_trajectory_value,
only_n_most_recent_images=recent_images_value,
use_ollama=loop_provider == "OMNI-OLLAMA",
use_oaicompat=is_oaicompat,
)
if global_agent is None:
return "Failed to create agent. Check API keys and configuration."
except Exception as e:
return f"Error creating agent: {str(e)}"
# Notify about settings changes if needed
if settings_changed:
return f"Settings updated: {settings_message}"
# Collect all responses
response_text = []
# Run the agent
try:
async for result in global_agent.run(message):
# Process result
content = process_agent_result(result)
# Skip empty content
if not content:
continue
# Add content to response list
response_text.append(content)
# Return the full response as a single string
return "\n".join(response_text) if response_text else "Task completed."
except Exception as e:
import traceback
traceback.print_exc()
return f"Error: {str(e)}"
# Run the async function and get the result
try:
result = loop.run_until_complete(_async_respond())
loop.close()
return result
except Exception as e:
loop.close()
import traceback
traceback.print_exc()
return f"Error executing async operation: {str(e)}"
return content, metadata
def create_gradio_ui(
provider_name: str = "openai",
@@ -725,48 +584,68 @@ def create_gradio_ui(
"""
)
# Configuration options
agent_loop = gr.Dropdown(
choices=["OPENAI", "ANTHROPIC", "OMNI"],
label="Agent Loop",
value=initial_loop,
info="Select the agent loop provider",
)
with gr.Accordion("Configuration", open=True):
# Configuration options
agent_loop = gr.Dropdown(
choices=["OPENAI", "ANTHROPIC", "OMNI"],
label="Agent Loop",
value=initial_loop,
info="Select the agent loop provider",
)
# Create model selection dropdown with custom value support for OMNI
model_choice = gr.Dropdown(
choices=provider_to_models.get(initial_loop, ["No models available"]),
label="LLM Provider and Model",
value=initial_model,
info="Select model or choose 'Custom model...' to enter a custom name",
interactive=True,
)
# Create model selection dropdown with custom value support for OMNI
model_choice = gr.Dropdown(
choices=provider_to_models.get(initial_loop, ["No models available"]),
label="LLM Provider and Model",
value=initial_model,
info="Select model or choose 'Custom model...' to enter a custom name",
interactive=True,
)
# Add custom model textbox (only visible when "Custom model..." is selected)
custom_model = gr.Textbox(
label="Custom Model Name",
placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
value="Qwen2.5-VL-7B-Instruct", # Default value
visible=False, # Initially hidden
interactive=True,
)
# Add custom model textbox (only visible when "Custom model..." is selected)
custom_model = gr.Textbox(
label="Custom Model Name",
placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
value="Qwen2.5-VL-7B-Instruct", # Default value
visible=False, # Initially hidden
interactive=True,
)
# Add custom provider base URL textbox (only visible when "Custom model..." is selected)
provider_base_url = gr.Textbox(
label="Provider Base URL",
placeholder="Enter provider base URL (e.g., http://localhost:1234/v1)",
value="http://localhost:1234/v1", # Default value
visible=False, # Initially hidden
interactive=True,
)
# Add custom API key textbox (only visible when "Custom model..." is selected)
provider_api_key = gr.Textbox(
label="Provider API Key",
placeholder="Enter provider API key (if required)",
value="", # Default empty value
visible=False, # Initially hidden
interactive=True,
type="password", # Hide the API key
)
save_trajectory = gr.Checkbox(
label="Save Trajectory",
value=True,
info="Save the agent's trajectory for debugging",
interactive=True,
)
save_trajectory = gr.Checkbox(
label="Save Trajectory",
value=True,
info="Save the agent's trajectory for debugging",
interactive=True,
)
recent_images = gr.Slider(
label="Recent Images",
minimum=1,
maximum=10,
value=3,
step=1,
info="Number of recent images to keep in context",
interactive=True,
)
recent_images = gr.Slider(
label="Recent Images",
minimum=1,
maximum=10,
value=3,
step=1,
info="Number of recent images to keep in context",
interactive=True,
)
# Right column for chat interface
with gr.Column(scale=2):
@@ -775,7 +654,7 @@ def create_gradio_ui(
"Ask me to perform tasks in a virtual macOS environment.<br>Built with <a href='https://github.com/trycua/cua' target='_blank'>github.com/trycua/cua</a>."
)
chatbot = gr.Chatbot()
chatbot_history = gr.Chatbot(type='messages')
msg = gr.Textbox(
placeholder="Ask me to perform tasks in a virtual macOS environment"
)
@@ -787,24 +666,27 @@ def create_gradio_ui(
# Function to handle chat submission
def chat_submit(message, history):
    """Record the submitted user message and clear the input textbox.

    The chatbot is created with ``type='messages'``, so the user turn is
    stored as a single ``gr.ChatMessage`` rather than a
    ``(user, assistant)`` tuple; adding both forms would duplicate the
    message and mix history formats.

    Args:
        message: Text the user submitted.
        history: Current chat history; an empty or missing history is
            treated as an empty list.

    Returns:
        A tuple of the new textbox value (an empty string) and the history
        with the user message appended.
    """
    # Copy so a missing history does not raise and the caller's list is left untouched.
    updated_history = list(history or [])
    updated_history.append(gr.ChatMessage(role="user", content=message))
    return "", updated_history
# Function to process agent response after user input
def process_response(
async def process_response(
history,
model_choice_value,
custom_model_value,
agent_loop_choice,
save_traj,
recent_imgs,
custom_url_value=None,
custom_api_key=None,
):
if not history:
return history
yield history
return
# Get the last user message
last_user_message = history[-1][0]
last_user_message = history[-1]['content']
# Use custom model value if "Custom model..." is selected
model_to_use = (
custom_model_value
@@ -812,38 +694,94 @@ def create_gradio_ui(
else model_choice_value
)
# Process with agent
response = respond(
last_user_message,
history[:-1], # History without the last message
model_to_use,
agent_loop_choice,
save_traj,
recent_imgs,
openai_api_key,
anthropic_api_key,
)
# Update the last assistant message
history[-1] = (last_user_message, response)
return history
try:
# Get the model, agent loop, and provider
provider, model_name, agent_loop_type = get_provider_and_model(
model_to_use, agent_loop_choice
)
# Special handling for OAICOMPAT
is_oaicompat = str(provider) == "oaicompat"
# Get API key based on provider
if model_choice_value == "Custom model..." and custom_api_key:
# Use custom API key if provided for custom model
api_key = custom_api_key
print(f"DEBUG - Using custom API key for model: {model_name}")
elif provider == LLMProvider.OPENAI:
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
elif provider == LLMProvider.ANTHROPIC:
api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
else:
api_key = ""
# Create or update the agent
create_agent(
provider=provider,
agent_loop=agent_loop_type,
model_name=model_name,
api_key=api_key,
save_trajectory=save_traj,
only_n_most_recent_images=recent_imgs,
use_ollama=agent_loop_choice == "OMNI-OLLAMA",
use_oaicompat=is_oaicompat,
provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
)
if global_agent is None:
# Add initial empty assistant message
history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
yield history
return
# Add the screenshot handler to the agent's loop if available
if global_agent and hasattr(global_agent, "_loop"):
print("DEBUG - Adding screenshot handler to agent loop")
# Create the screenshot handler with references to UI components
screenshot_handler = GradioChatScreenshotHandler(
history
)
# Add the handler to the callback manager if it exists
if hasattr(global_agent._loop, "callback_manager"):
global_agent._loop.callback_manager.add_handler(screenshot_handler)
print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
# Stream responses from the agent
async for result in global_agent.run(last_user_message):
# Process result
content, metadata = process_agent_result(result)
# Skip empty content
if content or metadata.get("title"):
history.append(gr.ChatMessage(role="assistant", content=content, metadata=metadata))
yield history
except Exception as e:
import traceback
traceback.print_exc()
# Update with error message
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
yield history
# Connect the components
msg.submit(chat_submit, [msg, chatbot], [msg, chatbot]).then(
msg.submit(chat_submit, [msg, chatbot_history], [msg, chatbot_history]).then(
process_response,
[
chatbot,
chatbot_history,
model_choice,
custom_model,
agent_loop,
save_trajectory,
recent_images,
provider_base_url,
provider_api_key,
],
[chatbot],
[chatbot_history],
)
# Clear button functionality
clear.click(lambda: None, None, chatbot, queue=False)
clear.click(lambda: None, None, chatbot_history, queue=False)
# Connect agent_loop changes to model selection
agent_loop.change(
@@ -853,14 +791,15 @@ def create_gradio_ui(
queue=False, # Process immediately without queueing
)
# Show/hide custom model textbox based on dropdown selection
# Show/hide custom model, provider base URL, and API key textboxes based on dropdown selection
def update_custom_model_visibility(model_value):
    """Show or hide the custom-model textboxes when the model dropdown changes.

    Args:
        model_value: Current value of the model dropdown.

    Returns:
        Three ``gr.update`` objects, one per textbox wired as an output
        (custom model name, provider base URL, provider API key). Each is
        visible only when "Custom model..." is selected.
    """
    is_custom = model_value == "Custom model..."
    # One update per output component; a single return value would not match the three outputs.
    return (
        gr.update(visible=is_custom),
        gr.update(visible=is_custom),
        gr.update(visible=is_custom),
    )
model_choice.change(
fn=update_custom_model_visibility,
inputs=[model_choice],
outputs=[custom_model],
outputs=[custom_model, provider_base_url, provider_api_key],
queue=False, # Process immediately without queueing
)
+20 -9
View File
@@ -17,17 +17,28 @@ class TimeoutException(Exception):
@contextmanager
def timeout(seconds: int):
    """Abort the wrapped block with TimeoutException after ``seconds`` seconds.

    The limit is enforced with ``SIGALRM``. Python only allows signal
    handlers to be installed from the main thread, and ``SIGALRM`` does not
    exist on every platform; in either case a warning is logged and the
    wrapped block runs without a time limit.

    Args:
        seconds: Alarm delay in whole seconds, passed to ``signal.alarm``.

    Raises:
        TimeoutException: Inside the wrapped block, when the alarm fires.
    """
    import threading

    if not hasattr(signal, "SIGALRM"):
        logger.warning("SIGALRM is not available on this platform; signal-based timeout disabled")
        yield
        return

    if threading.current_thread() is not threading.main_thread():
        # In a non-main thread, we can't use signal
        logger.warning("Timeout function called from non-main thread; signal-based timeout disabled")
        yield
        return

    def timeout_handler(signum, frame):
        raise TimeoutException("OCR process timed out")

    # Install the handler exactly once and remember the previous one for restoration.
    original_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        # Cancel any pending alarm before putting the previous handler back.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, original_handler)
class OCRProcessor: