mirror of
https://github.com/trycua/computer.git
synced 2026-05-11 19:12:35 -05:00
Merge pull request #104 from ddupont808/feature/gradio-upgrade
[Agent] Improved Gradio UI
This commit is contained in:
@@ -5,10 +5,12 @@ import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from agent.providers.omni.parser import ParseResult
|
||||
from computer import Computer
|
||||
from .messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .types import AgentResponse
|
||||
from .experiment import ExperimentManager
|
||||
from .callbacks import CallbackManager, CallbackHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,6 +29,7 @@ class BaseLoop(ABC):
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
save_trajectory: bool = True,
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
callback_handlers: Optional[List[CallbackHandler]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize base agent loop.
|
||||
@@ -75,6 +78,9 @@ class BaseLoop(ABC):
|
||||
|
||||
# Initialize basic tracking
|
||||
self.turn_count = 0
|
||||
|
||||
# Initialize callback manager
|
||||
self.callback_manager = CallbackManager(handlers=callback_handlers or [])
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize both the API client and computer interface with retries."""
|
||||
@@ -187,3 +193,17 @@ class BaseLoop(ABC):
|
||||
"""
|
||||
if self.experiment_manager:
|
||||
self.experiment_manager.save_screenshot(img_base64, action_type)
|
||||
|
||||
###########################################
|
||||
# EVENT HOOKS / CALLBACKS
|
||||
###########################################
|
||||
|
||||
async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
|
||||
"""Process a screenshot through callback managers
|
||||
|
||||
Args:
|
||||
screenshot_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
"""
|
||||
if hasattr(self, 'callback_manager'):
|
||||
await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)
|
||||
|
||||
@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
||||
from agent.providers.omni.parser import ParseResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ContentCallback(Protocol):
|
||||
@@ -20,6 +22,10 @@ class APICallback(Protocol):
|
||||
"""Protocol for API callbacks."""
|
||||
def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
|
||||
|
||||
class ScreenshotCallback(Protocol):
|
||||
"""Protocol for screenshot callbacks."""
|
||||
def __call__(self, screenshot_base64: str, action_type: str = "") -> Optional[str]: ...
|
||||
|
||||
class BaseCallbackManager(ABC):
|
||||
"""Base class for callback managers."""
|
||||
|
||||
@@ -110,7 +116,20 @@ class CallbackManager:
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
await handler.on_error(error, **kwargs)
|
||||
|
||||
|
||||
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
|
||||
"""Called when a screenshot is taken.
|
||||
|
||||
Args:
|
||||
screenshot_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
parsed_screen: Optional output from parsing the screenshot
|
||||
|
||||
Returns:
|
||||
Modified screenshot or original if no modifications
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
await handler.on_screenshot(screenshot_base64, action_type, parsed_screen)
|
||||
|
||||
class CallbackHandler(ABC):
|
||||
"""Base class for callback handlers."""
|
||||
@@ -144,4 +163,40 @@ class CallbackHandler(ABC):
|
||||
error: Exception that occurred
|
||||
**kwargs: Additional data
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
|
||||
"""Called when a screenshot is taken.
|
||||
|
||||
Args:
|
||||
screenshot_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
|
||||
Returns:
|
||||
Optional modified screenshot
|
||||
"""
|
||||
pass
|
||||
|
||||
class DefaultCallbackHandler(CallbackHandler):
|
||||
"""Default implementation of CallbackHandler with no-op methods.
|
||||
|
||||
This class implements all abstract methods from CallbackHandler,
|
||||
allowing subclasses to override only the methods they need.
|
||||
"""
|
||||
|
||||
async def on_action_start(self, action: str, **kwargs) -> None:
|
||||
"""Default no-op implementation."""
|
||||
pass
|
||||
|
||||
async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
|
||||
"""Default no-op implementation."""
|
||||
pass
|
||||
|
||||
async def on_error(self, error: Exception, **kwargs) -> None:
|
||||
"""Default no-op implementation."""
|
||||
pass
|
||||
|
||||
async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
|
||||
"""Default no-op implementation."""
|
||||
pass
|
||||
@@ -45,8 +45,8 @@ class OAICompatClient(BaseOmniClient):
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Generation temperature
|
||||
"""
|
||||
super().__init__(api_key="EMPTY", model=model)
|
||||
self.api_key = "EMPTY" # Local endpoints typically don't require an API key
|
||||
super().__init__(api_key=api_key or "EMPTY", model=model)
|
||||
self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
|
||||
self.model = model
|
||||
self.provider_base_url = (
|
||||
provider_base_url or "http://localhost:8000/v1"
|
||||
@@ -146,10 +146,18 @@ class OAICompatClient(BaseOmniClient):
|
||||
base_url = self.provider_base_url or "http://localhost:8000/v1"
|
||||
|
||||
# Check if the base URL already includes the chat/completions endpoint
|
||||
|
||||
endpoint_url = base_url
|
||||
if not endpoint_url.endswith("/chat/completions"):
|
||||
# If URL is RunPod format, make it OpenAI compatible
|
||||
if endpoint_url.startswith("https://api.runpod.ai/v2/"):
|
||||
# Extract RunPod endpoint ID
|
||||
parts = endpoint_url.split("/")
|
||||
if len(parts) >= 5:
|
||||
runpod_id = parts[4]
|
||||
endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
|
||||
# If the URL ends with /v1, append /chat/completions
|
||||
if endpoint_url.endswith("/v1"):
|
||||
elif endpoint_url.endswith("/v1"):
|
||||
endpoint_url = f"{endpoint_url}/chat/completions"
|
||||
# If the URL doesn't end with /v1, make sure it has a proper structure
|
||||
elif not endpoint_url.endswith("/"):
|
||||
|
||||
@@ -147,7 +147,7 @@ class OmniLoop(BaseLoop):
|
||||
)
|
||||
elif self.provider == LLMProvider.OAICOMPAT:
|
||||
self.client = OAICompatClient(
|
||||
api_key="EMPTY", # Local endpoints typically don't require an API key
|
||||
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
|
||||
model=self.model,
|
||||
provider_base_url=self.provider_base_url,
|
||||
)
|
||||
@@ -183,7 +183,7 @@ class OmniLoop(BaseLoop):
|
||||
)
|
||||
elif self.provider == LLMProvider.OAICOMPAT:
|
||||
self.client = OAICompatClient(
|
||||
api_key="EMPTY", # Local endpoints typically don't require an API key
|
||||
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
|
||||
model=self.model,
|
||||
provider_base_url=self.provider_base_url,
|
||||
)
|
||||
@@ -548,6 +548,10 @@ class OmniLoop(BaseLoop):
|
||||
img_data = parsed_screen.annotated_image_base64
|
||||
if "," in img_data:
|
||||
img_data = img_data.split(",")[1]
|
||||
|
||||
# Process screenshot through hooks and save if needed
|
||||
await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
|
||||
|
||||
# Save with a generic "state" action type to indicate this is the current screen state
|
||||
self._save_screenshot(img_data, action_type="state")
|
||||
except Exception as e:
|
||||
@@ -663,6 +667,8 @@ class OmniLoop(BaseLoop):
|
||||
response=response,
|
||||
messages=self.message_manager.messages,
|
||||
model=self.model,
|
||||
parsed_screen=parsed_screen,
|
||||
parser=self.parser
|
||||
)
|
||||
|
||||
# Yield the response to the caller
|
||||
|
||||
@@ -194,8 +194,13 @@ class OpenAILoop(BaseLoop):
|
||||
# Convert to base64 if needed
|
||||
if isinstance(screenshot, bytes):
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
elif isinstance(screenshot, (bytearray, memoryview)):
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
else:
|
||||
screenshot_base64 = screenshot
|
||||
screenshot_base64 = str(screenshot)
|
||||
|
||||
# Emit screenshot callbacks
|
||||
await self.handle_screenshot(screenshot_base64, action_type="initial_state")
|
||||
|
||||
# Save screenshot if requested
|
||||
if self.save_trajectory:
|
||||
@@ -204,8 +209,6 @@ class OpenAILoop(BaseLoop):
|
||||
logger.warning(
|
||||
"Converting non-string screenshot_base64 to string for _save_screenshot"
|
||||
)
|
||||
if isinstance(screenshot_base64, (bytearray, memoryview)):
|
||||
screenshot_base64 = base64.b64encode(screenshot_base64).decode("utf-8")
|
||||
self._save_screenshot(screenshot_base64, action_type="state")
|
||||
logger.info("Screenshot saved to trajectory")
|
||||
|
||||
@@ -336,8 +339,14 @@ class OpenAILoop(BaseLoop):
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
if isinstance(screenshot, bytes):
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
elif isinstance(screenshot, (bytearray, memoryview)):
|
||||
screenshot_base64 = base64.b64encode(bytes(screenshot)).decode("utf-8")
|
||||
else:
|
||||
screenshot_base64 = screenshot
|
||||
screenshot_base64 = str(screenshot)
|
||||
|
||||
# Process screenshot through hooks
|
||||
action_type = f"after_{action.get('type', 'action')}"
|
||||
await self.handle_screenshot(screenshot_base64, action_type=action_type)
|
||||
|
||||
# Create computer_call_output
|
||||
computer_call_output = {
|
||||
|
||||
+227
-288
@@ -32,9 +32,12 @@ import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
|
||||
import gradio as gr
|
||||
from gradio.components.chatbot import MetadataDict
|
||||
|
||||
# Import from agent package
|
||||
from agent.core.types import AgentResponse
|
||||
from agent.core.callbacks import DefaultCallbackHandler
|
||||
from agent.providers.omni.parser import ParseResult
|
||||
from computer import Computer
|
||||
|
||||
from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
|
||||
@@ -43,6 +46,44 @@ from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
|
||||
global_agent = None
|
||||
global_computer = None
|
||||
|
||||
# We'll use asyncio.run() instead of a persistent event loop
|
||||
|
||||
|
||||
# Custom Screenshot Handler for Gradio chat
|
||||
class GradioChatScreenshotHandler(DefaultCallbackHandler):
|
||||
"""Custom handler that adds screenshots to the Gradio chatbot and updates annotated image."""
|
||||
|
||||
def __init__(self, chatbot_history: List[gr.ChatMessage]):
|
||||
"""Initialize with reference to chat history and annotated image component.
|
||||
|
||||
Args:
|
||||
chatbot_history: Reference to the Gradio chatbot history list
|
||||
annotated_image: Reference to the annotated image component
|
||||
"""
|
||||
self.chatbot_history = chatbot_history
|
||||
print("GradioChatScreenshotHandler initialized")
|
||||
|
||||
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
|
||||
"""Add screenshot to chatbot when a screenshot is taken and update the annotated image.
|
||||
|
||||
Args:
|
||||
screenshot_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
|
||||
Returns:
|
||||
Original screenshot (does not modify it)
|
||||
"""
|
||||
# Create a markdown image element for the screenshot
|
||||
image_markdown = f""
|
||||
|
||||
# Simply append the screenshot as a new message
|
||||
if self.chatbot_history is not None:
|
||||
self.chatbot_history.append(gr.ChatMessage(
|
||||
role="assistant",
|
||||
content=image_markdown,
|
||||
metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"}
|
||||
))
|
||||
|
||||
# Map model names to specific provider model names
|
||||
MODEL_MAPPINGS = {
|
||||
"openai": {
|
||||
@@ -177,17 +218,18 @@ def get_ollama_models() -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> str:
|
||||
def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
|
||||
"""Extract synthesized text from the agent result."""
|
||||
synthesized_text = ""
|
||||
metadata = MetadataDict()
|
||||
|
||||
if "output" in result and result["output"]:
|
||||
for output in result["output"]:
|
||||
if output.get("type") == "reasoning":
|
||||
metadata["title"] = "🧠 Reasoning"
|
||||
content = output.get("content", "")
|
||||
if content:
|
||||
synthesized_text += f"{content}\n"
|
||||
|
||||
elif output.get("type") == "message":
|
||||
# Handle message type outputs - can contain rich content
|
||||
content = output.get("content", [])
|
||||
@@ -200,7 +242,7 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
|
||||
if text_value:
|
||||
synthesized_text += f"{text_value}\n"
|
||||
|
||||
elif output.get("type") == "computer_call":
|
||||
elif output.get("type") == "computer_call":
|
||||
action = output.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
|
||||
@@ -223,8 +265,11 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
|
||||
synthesized_text += f"Pressed key: {key}\n"
|
||||
else:
|
||||
synthesized_text += f"Performed {action_type} action.\n"
|
||||
|
||||
return synthesized_text.strip()
|
||||
|
||||
metadata["status"] = "done"
|
||||
metadata["title"] = f"🛠️ {synthesized_text.strip().splitlines()[-1]}"
|
||||
|
||||
return synthesized_text.strip(), metadata
|
||||
|
||||
|
||||
def create_computer_instance(verbosity: int = logging.INFO) -> Computer:
|
||||
@@ -247,6 +292,7 @@ def create_agent(
|
||||
verbosity: int = logging.INFO,
|
||||
use_ollama: bool = False,
|
||||
use_oaicompat: bool = False,
|
||||
provider_base_url: Optional[str] = None,
|
||||
) -> ComputerAgent:
|
||||
"""Create or update the global agent with the specified parameters."""
|
||||
global global_agent
|
||||
@@ -270,72 +316,55 @@ def create_agent(
|
||||
elif provider == LLMProvider.ANTHROPIC:
|
||||
api_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
||||
|
||||
# Create LLM model object with appropriate parameters
|
||||
provider_base_url = "http://localhost:1234/v1" if use_oaicompat else None
|
||||
# Use provided provider_base_url if available, otherwise use default
|
||||
default_base_url = "http://localhost:1234/v1" if use_oaicompat else None
|
||||
custom_base_url = provider_base_url or default_base_url
|
||||
|
||||
if use_oaicompat:
|
||||
# Special handling for OAICOMPAT - use OAICOMPAT provider with custom base URL
|
||||
print(
|
||||
f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {provider_base_url}"
|
||||
f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {custom_base_url}"
|
||||
)
|
||||
llm = LLM(
|
||||
provider=LLMProvider.OAICOMPAT, # Set to OAICOMPAT instead of using original provider
|
||||
name=model_name,
|
||||
provider_base_url=provider_base_url,
|
||||
provider_base_url=custom_base_url,
|
||||
)
|
||||
print(f"DEBUG - LLM provider is now: {llm.provider}, base URL: {llm.provider_base_url}")
|
||||
# Note: Don't pass use_oaicompat to the agent, as it doesn't accept this parameter
|
||||
elif provider == LLMProvider.OAICOMPAT:
|
||||
# This path is unlikely to be taken with our current approach
|
||||
llm = LLM(provider=provider, name=model_name, provider_base_url=provider_base_url)
|
||||
llm = LLM(provider=provider, name=model_name, provider_base_url=custom_base_url)
|
||||
else:
|
||||
# For other providers, just use standard parameters
|
||||
llm = LLM(provider=provider, name=model_name)
|
||||
|
||||
# Create or update the agent
|
||||
if global_agent is None:
|
||||
global_agent = ComputerAgent(
|
||||
computer=computer,
|
||||
loop=agent_loop,
|
||||
model=llm,
|
||||
api_key=api_key,
|
||||
save_trajectory=save_trajectory,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
verbosity=verbosity,
|
||||
**extra_config,
|
||||
)
|
||||
else:
|
||||
# Update the existing agent's parameters
|
||||
global_agent._loop = None # Force recreation of the loop
|
||||
global_agent.provider = provider
|
||||
global_agent.loop = agent_loop
|
||||
global_agent.model = llm
|
||||
global_agent.api_key = api_key
|
||||
|
||||
# Explicitly update these settings to ensure they take effect
|
||||
global_agent.save_trajectory = save_trajectory
|
||||
global_agent.only_n_most_recent_images = only_n_most_recent_images
|
||||
|
||||
# Update Ollama settings if applicable
|
||||
if use_ollama:
|
||||
global_agent.use_ollama = True
|
||||
global_agent.ollama_model = model_name
|
||||
else:
|
||||
global_agent.use_ollama = False
|
||||
global_agent.ollama_model = None
|
||||
|
||||
# Log the updated settings
|
||||
logging.info(
|
||||
f"Updated agent settings: save_trajectory={save_trajectory}, recent_images={only_n_most_recent_images}"
|
||||
)
|
||||
global_agent = ComputerAgent(
|
||||
computer=computer,
|
||||
loop=agent_loop,
|
||||
model=llm,
|
||||
api_key=api_key,
|
||||
save_trajectory=save_trajectory,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
verbosity=verbosity,
|
||||
**extra_config,
|
||||
)
|
||||
|
||||
return global_agent
|
||||
|
||||
|
||||
def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
|
||||
def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
|
||||
"""Process agent results for the Gradio UI."""
|
||||
# Extract text content
|
||||
text_obj = result.get("text", {})
|
||||
metadata = result.get("metadata", {})
|
||||
|
||||
# Create a properly typed MetadataDict
|
||||
metadata_dict = MetadataDict()
|
||||
metadata_dict["title"] = metadata.get("title", "")
|
||||
metadata_dict["status"] = "done"
|
||||
metadata = metadata_dict
|
||||
|
||||
# For OpenAI's Computer-Use Agent, text field is an object with format property
|
||||
if (
|
||||
@@ -344,8 +373,11 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
|
||||
and "format" in text_obj
|
||||
and not text_obj.get("value", "")
|
||||
):
|
||||
content = extract_synthesized_text(result)
|
||||
content, metadata = extract_synthesized_text(result)
|
||||
else:
|
||||
if not text_obj:
|
||||
text_obj = result
|
||||
|
||||
# For other types of results, try to get text directly
|
||||
if isinstance(text_obj, dict):
|
||||
if "value" in text_obj:
|
||||
@@ -378,180 +410,7 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
|
||||
if not isinstance(content, str):
|
||||
content = str(content) if content else ""
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def respond(
|
||||
message: str,
|
||||
history: List[Tuple[str, str]],
|
||||
model_choice, # Accept Gradio Dropdown component
|
||||
agent_loop, # Accept Gradio Dropdown component
|
||||
save_trajectory, # Accept Gradio Checkbox component
|
||||
recent_images, # Accept Gradio Slider component
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Process a message with the Computer-Use Agent and return the response."""
|
||||
import asyncio
|
||||
|
||||
# Get actual values from Gradio components
|
||||
model_choice_value = model_choice.value if hasattr(model_choice, "value") else model_choice
|
||||
agent_loop_value = agent_loop.value if hasattr(agent_loop, "value") else agent_loop
|
||||
save_trajectory_value = (
|
||||
save_trajectory.value if hasattr(save_trajectory, "value") else save_trajectory
|
||||
)
|
||||
recent_images_value = int(
|
||||
recent_images.value if hasattr(recent_images, "value") else recent_images
|
||||
)
|
||||
|
||||
# Debug logging
|
||||
print(f"DEBUG - Model choice object: {type(model_choice)}")
|
||||
print(f"DEBUG - Model choice value: {model_choice_value}")
|
||||
print(f"DEBUG - Agent loop value: {agent_loop_value}")
|
||||
|
||||
# Create a new event loop for this function call
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
async def _async_respond():
|
||||
# Extract the loop type and model from the selection
|
||||
loop_provider = "OPENAI"
|
||||
if isinstance(model_choice_value, str):
|
||||
# This is the case for a custom text input from textbox
|
||||
if agent_loop_value == "OMNI":
|
||||
loop_provider = "OMNI"
|
||||
# Use the custom model name as is
|
||||
model_id = model_choice_value
|
||||
print(f"DEBUG - Using custom model: {model_id}")
|
||||
else:
|
||||
# Handle regular dropdown value as string
|
||||
if model_choice_value.startswith("OpenAI:"):
|
||||
loop_provider = "OPENAI"
|
||||
model_id = model_choice_value.replace("OpenAI: ", "").lower()
|
||||
elif model_choice_value.startswith("Anthropic:"):
|
||||
loop_provider = "ANTHROPIC"
|
||||
model_id = model_choice_value.replace("Anthropic: ", "").lower()
|
||||
elif model_choice_value.startswith("OMNI:"):
|
||||
loop_provider = "OMNI"
|
||||
if "GPT" in model_choice_value:
|
||||
model_id = model_choice_value.replace("OMNI: OpenAI ", "").lower()
|
||||
elif "Claude" in model_choice_value:
|
||||
model_id = model_choice_value.replace("OMNI: ", "").lower()
|
||||
elif "Ollama" in model_choice_value:
|
||||
loop_provider = "OMNI-OLLAMA"
|
||||
# Extract everything after "OMNI: Ollama " which is the full model name (e.g., phi3:latest)
|
||||
model_id = model_choice_value.replace("OMNI: Ollama ", "")
|
||||
print(f"DEBUG - Ollama model ID: {model_id}")
|
||||
else:
|
||||
model_id = "default"
|
||||
else:
|
||||
# Default case
|
||||
loop_provider = agent_loop_value
|
||||
model_id = "default"
|
||||
else:
|
||||
# Model choice is not a string (shouldn't happen, but handle anyway)
|
||||
loop_provider = agent_loop_value
|
||||
model_id = "default"
|
||||
|
||||
print(f"DEBUG - Using loop provider: {loop_provider}, model_id: {model_id}")
|
||||
|
||||
# Use the mapping function to get provider, model name and agent loop
|
||||
provider, model_name, agent_loop_type = get_provider_and_model(model_id, loop_provider)
|
||||
print(
|
||||
f"DEBUG - After mapping: provider={provider}, model_name={model_name}, agent_loop={agent_loop_type}"
|
||||
)
|
||||
|
||||
# Special handling for OAICOMPAT to bypass provider-specific errors
|
||||
# Creates the agent with OPENAI provider but using custom model name and provider base URL
|
||||
is_oaicompat = str(provider) == "oaicompat"
|
||||
|
||||
# Don't override the provider for OAICOMPAT - instead pass it through
|
||||
# if is_oaicompat:
|
||||
# provider = LLMProvider.OPENAI
|
||||
|
||||
# Get API key based on provider
|
||||
if provider == LLMProvider.OPENAI:
|
||||
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
|
||||
elif provider == LLMProvider.ANTHROPIC:
|
||||
api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
|
||||
else:
|
||||
api_key = ""
|
||||
|
||||
# Check for settings changes if agent already exists
|
||||
settings_changed = False
|
||||
settings_message = ""
|
||||
if global_agent is not None:
|
||||
# Safely check if save_trajectory setting changed
|
||||
current_save_traj = getattr(global_agent, "save_trajectory", None)
|
||||
if current_save_traj is not None and current_save_traj != save_trajectory_value:
|
||||
settings_changed = True
|
||||
settings_message += f"Save trajectory set to: {save_trajectory_value}. "
|
||||
|
||||
# Safely check if recent_images setting changed
|
||||
current_recent_images = getattr(global_agent, "only_n_most_recent_images", None)
|
||||
if current_recent_images is not None and current_recent_images != recent_images_value:
|
||||
settings_changed = True
|
||||
settings_message += f"Recent images set to: {recent_images_value}. "
|
||||
|
||||
# Create or update the agent
|
||||
try:
|
||||
create_agent(
|
||||
provider=provider,
|
||||
agent_loop=agent_loop_type,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
save_trajectory=save_trajectory_value,
|
||||
only_n_most_recent_images=recent_images_value,
|
||||
use_ollama=loop_provider == "OMNI-OLLAMA",
|
||||
use_oaicompat=is_oaicompat,
|
||||
)
|
||||
|
||||
if global_agent is None:
|
||||
return "Failed to create agent. Check API keys and configuration."
|
||||
except Exception as e:
|
||||
return f"Error creating agent: {str(e)}"
|
||||
|
||||
# Notify about settings changes if needed
|
||||
if settings_changed:
|
||||
return f"Settings updated: {settings_message}"
|
||||
|
||||
# Collect all responses
|
||||
response_text = []
|
||||
|
||||
# Run the agent
|
||||
try:
|
||||
async for result in global_agent.run(message):
|
||||
# Process result
|
||||
content = process_agent_result(result)
|
||||
|
||||
# Skip empty content
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# Add content to response list
|
||||
response_text.append(content)
|
||||
|
||||
# Return the full response as a single string
|
||||
return "\n".join(response_text) if response_text else "Task completed."
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
# Run the async function and get the result
|
||||
try:
|
||||
result = loop.run_until_complete(_async_respond())
|
||||
loop.close()
|
||||
return result
|
||||
except Exception as e:
|
||||
loop.close()
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return f"Error executing async operation: {str(e)}"
|
||||
|
||||
return content, metadata
|
||||
|
||||
def create_gradio_ui(
|
||||
provider_name: str = "openai",
|
||||
@@ -725,48 +584,68 @@ def create_gradio_ui(
|
||||
"""
|
||||
)
|
||||
|
||||
# Configuration options
|
||||
agent_loop = gr.Dropdown(
|
||||
choices=["OPENAI", "ANTHROPIC", "OMNI"],
|
||||
label="Agent Loop",
|
||||
value=initial_loop,
|
||||
info="Select the agent loop provider",
|
||||
)
|
||||
with gr.Accordion("Configuration", open=True):
|
||||
# Configuration options
|
||||
agent_loop = gr.Dropdown(
|
||||
choices=["OPENAI", "ANTHROPIC", "OMNI"],
|
||||
label="Agent Loop",
|
||||
value=initial_loop,
|
||||
info="Select the agent loop provider",
|
||||
)
|
||||
|
||||
# Create model selection dropdown with custom value support for OMNI
|
||||
model_choice = gr.Dropdown(
|
||||
choices=provider_to_models.get(initial_loop, ["No models available"]),
|
||||
label="LLM Provider and Model",
|
||||
value=initial_model,
|
||||
info="Select model or choose 'Custom model...' to enter a custom name",
|
||||
interactive=True,
|
||||
)
|
||||
# Create model selection dropdown with custom value support for OMNI
|
||||
model_choice = gr.Dropdown(
|
||||
choices=provider_to_models.get(initial_loop, ["No models available"]),
|
||||
label="LLM Provider and Model",
|
||||
value=initial_model,
|
||||
info="Select model or choose 'Custom model...' to enter a custom name",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
# Add custom model textbox (only visible when "Custom model..." is selected)
|
||||
custom_model = gr.Textbox(
|
||||
label="Custom Model Name",
|
||||
placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
|
||||
value="Qwen2.5-VL-7B-Instruct", # Default value
|
||||
visible=False, # Initially hidden
|
||||
interactive=True,
|
||||
)
|
||||
# Add custom model textbox (only visible when "Custom model..." is selected)
|
||||
custom_model = gr.Textbox(
|
||||
label="Custom Model Name",
|
||||
placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
|
||||
value="Qwen2.5-VL-7B-Instruct", # Default value
|
||||
visible=False, # Initially hidden
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
# Add custom provider base URL textbox (only visible when "Custom model..." is selected)
|
||||
provider_base_url = gr.Textbox(
|
||||
label="Provider Base URL",
|
||||
placeholder="Enter provider base URL (e.g., http://localhost:1234/v1)",
|
||||
value="http://localhost:1234/v1", # Default value
|
||||
visible=False, # Initially hidden
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
# Add custom API key textbox (only visible when "Custom model..." is selected)
|
||||
provider_api_key = gr.Textbox(
|
||||
label="Provider API Key",
|
||||
placeholder="Enter provider API key (if required)",
|
||||
value="", # Default empty value
|
||||
visible=False, # Initially hidden
|
||||
interactive=True,
|
||||
type="password", # Hide the API key
|
||||
)
|
||||
|
||||
save_trajectory = gr.Checkbox(
|
||||
label="Save Trajectory",
|
||||
value=True,
|
||||
info="Save the agent's trajectory for debugging",
|
||||
interactive=True,
|
||||
)
|
||||
save_trajectory = gr.Checkbox(
|
||||
label="Save Trajectory",
|
||||
value=True,
|
||||
info="Save the agent's trajectory for debugging",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
recent_images = gr.Slider(
|
||||
label="Recent Images",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
value=3,
|
||||
step=1,
|
||||
info="Number of recent images to keep in context",
|
||||
interactive=True,
|
||||
)
|
||||
recent_images = gr.Slider(
|
||||
label="Recent Images",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
value=3,
|
||||
step=1,
|
||||
info="Number of recent images to keep in context",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
# Right column for chat interface
|
||||
with gr.Column(scale=2):
|
||||
@@ -775,7 +654,7 @@ def create_gradio_ui(
|
||||
"Ask me to perform tasks in a virtual macOS environment.<br>Built with <a href='https://github.com/trycua/cua' target='_blank'>github.com/trycua/cua</a>."
|
||||
)
|
||||
|
||||
chatbot = gr.Chatbot()
|
||||
chatbot_history = gr.Chatbot(type='messages')
|
||||
msg = gr.Textbox(
|
||||
placeholder="Ask me to perform tasks in a virtual macOS environment"
|
||||
)
|
||||
@@ -787,24 +666,27 @@ def create_gradio_ui(
|
||||
# Function to handle chat submission
|
||||
def chat_submit(message, history):
|
||||
# Add user message to history
|
||||
history = history + [(message, None)]
|
||||
history.append(gr.ChatMessage(role="user", content=message))
|
||||
return "", history
|
||||
|
||||
# Function to process agent response after user input
|
||||
def process_response(
|
||||
async def process_response(
|
||||
history,
|
||||
model_choice_value,
|
||||
custom_model_value,
|
||||
agent_loop_choice,
|
||||
save_traj,
|
||||
recent_imgs,
|
||||
custom_url_value=None,
|
||||
custom_api_key=None,
|
||||
):
|
||||
if not history:
|
||||
return history
|
||||
yield history
|
||||
return
|
||||
|
||||
# Get the last user message
|
||||
last_user_message = history[-1][0]
|
||||
|
||||
last_user_message = history[-1]['content']
|
||||
|
||||
# Use custom model value if "Custom model..." is selected
|
||||
model_to_use = (
|
||||
custom_model_value
|
||||
@@ -812,38 +694,94 @@ def create_gradio_ui(
|
||||
else model_choice_value
|
||||
)
|
||||
|
||||
# Process with agent
|
||||
response = respond(
|
||||
last_user_message,
|
||||
history[:-1], # History without the last message
|
||||
model_to_use,
|
||||
agent_loop_choice,
|
||||
save_traj,
|
||||
recent_imgs,
|
||||
openai_api_key,
|
||||
anthropic_api_key,
|
||||
)
|
||||
|
||||
# Update the last assistant message
|
||||
history[-1] = (last_user_message, response)
|
||||
return history
|
||||
|
||||
try:
|
||||
# Get the model, agent loop, and provider
|
||||
provider, model_name, agent_loop_type = get_provider_and_model(
|
||||
model_to_use, agent_loop_choice
|
||||
)
|
||||
|
||||
# Special handling for OAICOMPAT
|
||||
is_oaicompat = str(provider) == "oaicompat"
|
||||
|
||||
# Get API key based on provider
|
||||
if model_choice_value == "Custom model..." and custom_api_key:
|
||||
# Use custom API key if provided for custom model
|
||||
api_key = custom_api_key
|
||||
print(f"DEBUG - Using custom API key for model: {model_name}")
|
||||
elif provider == LLMProvider.OPENAI:
|
||||
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
|
||||
elif provider == LLMProvider.ANTHROPIC:
|
||||
api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
|
||||
else:
|
||||
api_key = ""
|
||||
|
||||
# Create or update the agent
|
||||
create_agent(
|
||||
provider=provider,
|
||||
agent_loop=agent_loop_type,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
save_trajectory=save_traj,
|
||||
only_n_most_recent_images=recent_imgs,
|
||||
use_ollama=agent_loop_choice == "OMNI-OLLAMA",
|
||||
use_oaicompat=is_oaicompat,
|
||||
provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
|
||||
)
|
||||
|
||||
if global_agent is None:
|
||||
# Add initial empty assistant message
|
||||
history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
|
||||
yield history
|
||||
return
|
||||
|
||||
# Add the screenshot handler to the agent's loop if available
|
||||
if global_agent and hasattr(global_agent, "_loop"):
|
||||
print("DEBUG - Adding screenshot handler to agent loop")
|
||||
|
||||
# Create the screenshot handler with references to UI components
|
||||
screenshot_handler = GradioChatScreenshotHandler(
|
||||
history
|
||||
)
|
||||
|
||||
# Add the handler to the callback manager if it exists
|
||||
if hasattr(global_agent._loop, "callback_manager"):
|
||||
global_agent._loop.callback_manager.add_handler(screenshot_handler)
|
||||
print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
|
||||
|
||||
# Stream responses from the agent
|
||||
async for result in global_agent.run(last_user_message):
|
||||
# Process result
|
||||
content, metadata = process_agent_result(result)
|
||||
|
||||
# Skip empty content
|
||||
if content or metadata.get("title"):
|
||||
history.append(gr.ChatMessage(role="assistant", content=content, metadata=metadata))
|
||||
yield history
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Update with error message
|
||||
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
|
||||
yield history
|
||||
|
||||
# Connect the components
|
||||
msg.submit(chat_submit, [msg, chatbot], [msg, chatbot]).then(
|
||||
msg.submit(chat_submit, [msg, chatbot_history], [msg, chatbot_history]).then(
|
||||
process_response,
|
||||
[
|
||||
chatbot,
|
||||
chatbot_history,
|
||||
model_choice,
|
||||
custom_model,
|
||||
agent_loop,
|
||||
save_trajectory,
|
||||
recent_images,
|
||||
provider_base_url,
|
||||
provider_api_key,
|
||||
],
|
||||
[chatbot],
|
||||
[chatbot_history],
|
||||
)
|
||||
|
||||
# Clear button functionality
|
||||
clear.click(lambda: None, None, chatbot, queue=False)
|
||||
clear.click(lambda: None, None, chatbot_history, queue=False)
|
||||
|
||||
# Connect agent_loop changes to model selection
|
||||
agent_loop.change(
|
||||
@@ -853,14 +791,15 @@ def create_gradio_ui(
|
||||
queue=False, # Process immediately without queueing
|
||||
)
|
||||
|
||||
# Show/hide custom model textbox based on dropdown selection
|
||||
# Show/hide custom model, provider base URL, and API key textboxes based on dropdown selection
|
||||
def update_custom_model_visibility(model_value):
|
||||
return gr.update(visible=model_value == "Custom model...")
|
||||
is_custom = model_value == "Custom model..."
|
||||
return gr.update(visible=is_custom), gr.update(visible=is_custom), gr.update(visible=is_custom)
|
||||
|
||||
model_choice.change(
|
||||
fn=update_custom_model_visibility,
|
||||
inputs=[model_choice],
|
||||
outputs=[custom_model],
|
||||
outputs=[custom_model, provider_base_url, provider_api_key],
|
||||
queue=False, # Process immediately without queueing
|
||||
)
|
||||
|
||||
|
||||
+20
-9
@@ -17,17 +17,28 @@ class TimeoutException(Exception):
|
||||
|
||||
@contextmanager
|
||||
def timeout(seconds: int):
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutException("OCR process timed out")
|
||||
import threading
|
||||
|
||||
# Check if we're in the main thread
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutException("OCR process timed out")
|
||||
|
||||
original_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(seconds)
|
||||
original_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(seconds)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, original_handler)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, original_handler)
|
||||
else:
|
||||
# In a non-main thread, we can't use signal
|
||||
logger.warning("Timeout function called from non-main thread; signal-based timeout disabled")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
class OCRProcessor:
|
||||
|
||||
Reference in New Issue
Block a user