diff --git a/libs/agent/agent/core/base.py b/libs/agent/agent/core/base.py
index fb91d855..702be207 100644
--- a/libs/agent/agent/core/base.py
+++ b/libs/agent/agent/core/base.py
@@ -5,10 +5,12 @@ import asyncio
 from abc import ABC, abstractmethod
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
+from agent.providers.omni.parser import ParseResult
 from computer import Computer
 
 from .messages import StandardMessageManager, ImageRetentionConfig
 from .types import AgentResponse
 from .experiment import ExperimentManager
+from .callbacks import CallbackManager, CallbackHandler
 
 logger = logging.getLogger(__name__)
@@ -27,6 +29,7 @@ class BaseLoop(ABC):
         base_dir: Optional[str] = "trajectories",
         save_trajectory: bool = True,
         only_n_most_recent_images: Optional[int] = 2,
+        callback_handlers: Optional[List[CallbackHandler]] = None,
         **kwargs,
     ):
         """Initialize base agent loop.
@@ -75,6 +78,9 @@
         # Initialize basic tracking
         self.turn_count = 0
+
+        # Initialize callback manager
+        self.callback_manager = CallbackManager(handlers=callback_handlers or [])
 
     async def initialize(self) -> None:
         """Initialize both the API client and computer interface with retries."""
@@ -187,3 +193,17 @@
         """
         if self.experiment_manager:
             self.experiment_manager.save_screenshot(img_base64, action_type)
+
+    ###########################################
+    # EVENT HOOKS / CALLBACKS
+    ###########################################
+
+    async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+        """Process a screenshot through the registered callback handlers.
+
+        Args:
+            screenshot_base64: Base64 encoded screenshot
+            action_type: Type of action that triggered the screenshot
+            parsed_screen: Optional parsed screen content
+        """
+        if hasattr(self, 'callback_manager'):
+            await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)
diff --git a/libs/agent/agent/core/callbacks.py b/libs/agent/agent/core/callbacks.py
index 70eca5ad..59cd0e5a 100644
--- a/libs/agent/agent/core/callbacks.py
+++ b/libs/agent/agent/core/callbacks.py
@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Protocol
 
+from agent.providers.omni.parser import ParseResult
+
 logger = logging.getLogger(__name__)
 
 class ContentCallback(Protocol):
@@ -20,6 +22,10 @@ class APICallback(Protocol):
     """Protocol for API callbacks."""
     def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
 
+class ScreenshotCallback(Protocol):
+    """Protocol for screenshot callbacks."""
+    def __call__(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None: ...
+
 class BaseCallbackManager(ABC):
     """Base class for callback managers."""
@@ -110,7 +116,20 @@ class CallbackManager:
         """
         for handler in self.handlers:
             await handler.on_error(error, **kwargs)
-
+
+    async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+        """Called when a screenshot is taken.
+
+        Args:
+            screenshot_base64: Base64 encoded screenshot
+            action_type: Type of action that triggered the screenshot
+            parsed_screen: Optional output from parsing the screenshot
+        """
+        for handler in self.handlers:
+            await handler.on_screenshot(screenshot_base64, action_type, parsed_screen)
 
 class CallbackHandler(ABC):
     """Base class for callback handlers."""
@@ -144,4 +163,40 @@ class CallbackHandler(ABC):
             error: Exception that occurred
             **kwargs: Additional data
         """
-        pass
\ No newline at end of file
+        pass
+
+    @abstractmethod
+    async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+        """Called when a screenshot is taken.
+
+        Args:
+            screenshot_base64: Base64 encoded screenshot
+            action_type: Type of action that triggered the screenshot
+            parsed_screen: Optional output from parsing the screenshot
+        """
+        pass
+
+class DefaultCallbackHandler(CallbackHandler):
+    """Default implementation of CallbackHandler with no-op methods.
+
+    This class implements all abstract methods from CallbackHandler,
+    allowing subclasses to override only the methods they need.
+    """
+
+    async def on_action_start(self, action: str, **kwargs) -> None:
+        """Default no-op implementation."""
+        pass
+
+    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
+        """Default no-op implementation."""
+        pass
+
+    async def on_error(self, error: Exception, **kwargs) -> None:
+        """Default no-op implementation."""
+        pass
+
+    async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+        """Default no-op implementation."""
+        pass
\ No newline at end of file
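
Usage sketch (editor's illustration, not part of the patch): a minimal subclass of the DefaultCallbackHandler added above, wired in through BaseLoop's new callback_handlers parameter. SavingHandler, out_dir, and MyLoop are hypothetical names; the imports assume the module layout shown in this patch.

import base64
from pathlib import Path
from typing import Optional

from agent.core.callbacks import DefaultCallbackHandler
from agent.providers.omni.parser import ParseResult


class SavingHandler(DefaultCallbackHandler):
    """Hypothetical handler that persists every screenshot the loop emits."""

    def __init__(self, out_dir: str = "screenshots"):
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.count = 0

    async def on_screenshot(
        self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None
    ) -> None:
        # Decode the base64 payload and write it to disk; all other hooks stay no-ops.
        self.count += 1
        path = self.out_dir / f"{self.count:04d}_{action_type or 'screenshot'}.png"
        path.write_bytes(base64.b64decode(screenshot_base64))

# A concrete loop would receive it as MyLoop(..., callback_handlers=[SavingHandler()]);
# every handle_screenshot() call then fans out via CallbackManager.on_screenshot().
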
diff --git a/libs/agent/agent/providers/omni/clients/oaicompat.py b/libs/agent/agent/providers/omni/clients/oaicompat.py
index bddc95e2..6a95896a 100644
--- a/libs/agent/agent/providers/omni/clients/oaicompat.py
+++ b/libs/agent/agent/providers/omni/clients/oaicompat.py
@@ -45,8 +45,8 @@ class OAICompatClient(BaseOmniClient):
             max_tokens: Maximum tokens to generate
             temperature: Generation temperature
         """
-        super().__init__(api_key="EMPTY", model=model)
-        self.api_key = "EMPTY"  # Local endpoints typically don't require an API key
+        super().__init__(api_key=api_key or "EMPTY", model=model)
+        self.api_key = api_key or "EMPTY"  # Local endpoints typically don't require an API key
         self.model = model
         self.provider_base_url = (
             provider_base_url or "http://localhost:8000/v1"
         )
@@ -146,10 +146,18 @@ class OAICompatClient(BaseOmniClient):
         base_url = self.provider_base_url or "http://localhost:8000/v1"
 
         # Check if the base URL already includes the chat/completions endpoint
+        endpoint_url = base_url
         if not endpoint_url.endswith("/chat/completions"):
+            # If URL is RunPod format, make it OpenAI compatible
+            if endpoint_url.startswith("https://api.runpod.ai/v2/"):
+                # Extract RunPod endpoint ID
+                parts = endpoint_url.split("/")
+                if len(parts) >= 5:
+                    runpod_id = parts[4]
+                    endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
             # If the URL ends with /v1, append /chat/completions
-            if endpoint_url.endswith("/v1"):
+            elif endpoint_url.endswith("/v1"):
                 endpoint_url = f"{endpoint_url}/chat/completions"
             # If the URL doesn't end with /v1, make sure it has a proper structure
             elif not endpoint_url.endswith("/"):
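
Illustration (not part of the patch): the endpoint normalization above, restated as a standalone function for clarity. normalize_endpoint is a hypothetical name; the RunPod branch mirrors the added code and the /v1 branch mirrors the pre-existing handling.

def normalize_endpoint(base_url: str) -> str:
    """Sketch of the URL rewriting performed before POSTing a chat completion."""
    if base_url.endswith("/chat/completions"):
        return base_url
    if base_url.startswith("https://api.runpod.ai/v2/"):
        # https://api.runpod.ai/v2/<id> -> https://api.runpod.ai/v2/<id>/openai/v1/chat/completions
        parts = base_url.split("/")
        if len(parts) >= 5:
            return f"https://api.runpod.ai/v2/{parts[4]}/openai/v1/chat/completions"
    if base_url.endswith("/v1"):
        # http://localhost:8000/v1 -> http://localhost:8000/v1/chat/completions
        return f"{base_url}/chat/completions"
    return base_url  # remaining structure checks stay in the module

# normalize_endpoint("https://api.runpod.ai/v2/abc123")
#   -> "https://api.runpod.ai/v2/abc123/openai/v1/chat/completions"
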
diff --git a/libs/agent/agent/providers/omni/loop.py b/libs/agent/agent/providers/omni/loop.py
index 7fb80654..b53c120c 100644
--- a/libs/agent/agent/providers/omni/loop.py
+++ b/libs/agent/agent/providers/omni/loop.py
@@ -147,7 +147,7 @@ class OmniLoop(BaseLoop):
             )
         elif self.provider == LLMProvider.OAICOMPAT:
             self.client = OAICompatClient(
-                api_key="EMPTY",  # Local endpoints typically don't require an API key
+                api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
                 model=self.model,
                 provider_base_url=self.provider_base_url,
             )
@@ -183,7 +183,7 @@ class OmniLoop(BaseLoop):
             )
         elif self.provider == LLMProvider.OAICOMPAT:
             self.client = OAICompatClient(
-                api_key="EMPTY",  # Local endpoints typically don't require an API key
+                api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
                 model=self.model,
                 provider_base_url=self.provider_base_url,
             )
@@ -548,6 +548,10 @@ class OmniLoop(BaseLoop):
                     img_data = parsed_screen.annotated_image_base64
                     if "," in img_data:
                         img_data = img_data.split(",")[1]
+
+                # Process screenshot through hooks and save if needed
+                await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
+
                 # Save with a generic "state" action type to indicate this is the current screen state
                 self._save_screenshot(img_data, action_type="state")
             except Exception as e:
@@ -663,6 +667,8 @@ class OmniLoop(BaseLoop):
                 response=response,
                 messages=self.message_manager.messages,
                 model=self.model,
+                parsed_screen=parsed_screen,
+                parser=self.parser
             )
 
             # Yield the response to the caller
diff --git a/libs/agent/agent/providers/openai/loop.py b/libs/agent/agent/providers/openai/loop.py
index cfb9a443..8e507a1b 100644
--- a/libs/agent/agent/providers/openai/loop.py
+++ b/libs/agent/agent/providers/openai/loop.py
@@ -194,8 +194,13 @@ class OpenAILoop(BaseLoop):
             # Convert to base64 if needed
             if isinstance(screenshot, bytes):
                 screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
+            elif isinstance(screenshot, (bytearray, memoryview)):
+                screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
             else:
-                screenshot_base64 = screenshot
+                screenshot_base64 = str(screenshot)
+
+            # Emit screenshot callbacks
+            await self.handle_screenshot(screenshot_base64, action_type="initial_state")
 
             # Save screenshot if requested
             if self.save_trajectory:
@@ -204,8 +209,6 @@ class OpenAILoop(BaseLoop):
                 logger.warning(
                     "Converting non-string screenshot_base64 to string for _save_screenshot"
                 )
-                if isinstance(screenshot_base64, (bytearray, memoryview)):
-                    screenshot_base64 = base64.b64encode(screenshot_base64).decode("utf-8")
                 self._save_screenshot(screenshot_base64, action_type="state")
                 logger.info("Screenshot saved to trajectory")
@@ -336,8 +339,14 @@ class OpenAILoop(BaseLoop):
                     screenshot = await self.computer.interface.screenshot()
                     if isinstance(screenshot, bytes):
                         screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
+                    elif isinstance(screenshot, (bytearray, memoryview)):
+                        screenshot_base64 = base64.b64encode(bytes(screenshot)).decode("utf-8")
                    else:
-                        screenshot_base64 = screenshot
+                        screenshot_base64 = str(screenshot)
+
+                    # Process screenshot through hooks
+                    action_type = f"after_{action.get('type', 'action')}"
+                    await self.handle_screenshot(screenshot_base64, action_type=action_type)
 
                     # Create computer_call_output
                     computer_call_output = {
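
Aside (not part of the patch): the bytes/bytearray/memoryview/str normalization now appears twice in openai/loop.py; a small helper like this hypothetical to_base64_str would keep both call sites identical.

import base64
from typing import Union

def to_base64_str(screenshot: Union[bytes, bytearray, memoryview, str]) -> str:
    """Normalize any screenshot payload to a base64 string."""
    if isinstance(screenshot, (bytes, bytearray, memoryview)):
        return base64.b64encode(bytes(screenshot)).decode("utf-8")
    return str(screenshot)
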
diff --git a/libs/agent/agent/ui/gradio/app.py b/libs/agent/agent/ui/gradio/app.py
index 44027317..a3a017bd 100644
--- a/libs/agent/agent/ui/gradio/app.py
+++ b/libs/agent/agent/ui/gradio/app.py
@@ -32,9 +32,12 @@ import asyncio
 import logging
 from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
 import gradio as gr
+from gradio.components.chatbot import MetadataDict
 
 # Import from agent package
 from agent.core.types import AgentResponse
+from agent.core.callbacks import DefaultCallbackHandler
+from agent.providers.omni.parser import ParseResult
 from computer import Computer
 from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
@@ -43,6 +46,44 @@ from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
 global_agent = None
 global_computer = None
 
+# We'll use asyncio.run() instead of a persistent event loop
+
+
+# Custom screenshot handler for Gradio chat
+class GradioChatScreenshotHandler(DefaultCallbackHandler):
+    """Custom handler that appends screenshots to the Gradio chatbot history."""
+
+    def __init__(self, chatbot_history: List[gr.ChatMessage]):
+        """Initialize with a reference to the chat history.
+
+        Args:
+            chatbot_history: Reference to the Gradio chatbot history list
+        """
+        self.chatbot_history = chatbot_history
+        print("GradioChatScreenshotHandler initialized")
+
+    async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+        """Append a screenshot message to the chatbot when a screenshot is taken.
+
+        Args:
+            screenshot_base64: Base64 encoded screenshot
+            action_type: Type of action that triggered the screenshot
+            parsed_screen: Optional parsed screen content
+        """
+        # Create a markdown image element for the screenshot
+        image_markdown = f"![Screenshot after {action_type}](data:image/png;base64,{screenshot_base64})"
+
+        # Simply append the screenshot as a new message
+        if self.chatbot_history is not None:
+            self.chatbot_history.append(gr.ChatMessage(
+                role="assistant",
+                content=image_markdown,
+                metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"}
+            ))
+
 # Map model names to specific provider model names
 MODEL_MAPPINGS = {
     "openai": {
@@ -177,17 +218,18 @@ def get_ollama_models() -> List[str]:
         return []
 
 
-def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> str:
+def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
     """Extract synthesized text from the agent result."""
     synthesized_text = ""
+    metadata = MetadataDict()
 
     if "output" in result and result["output"]:
         for output in result["output"]:
             if output.get("type") == "reasoning":
+                metadata["title"] = "🧠 Reasoning"
                 content = output.get("content", "")
                 if content:
                     synthesized_text += f"{content}\n"
-
             elif output.get("type") == "message":
                 # Handle message type outputs - can contain rich content
                 content = output.get("content", [])
@@ -200,7 +242,7 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> str
                             if text_value:
                                 synthesized_text += f"{text_value}\n"
 
-            elif output.get("type") == "computer_call": 
+            elif output.get("type") == "computer_call":
                 action = output.get("action", {})
                 action_type = action.get("type", "")
@@ -223,8 +265,11 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> str
                         synthesized_text += f"Pressed key: {key}\n"
                     else:
                         synthesized_text += f"Performed {action_type} action.\n"
-
-    return synthesized_text.strip()
+
+    lines = synthesized_text.strip().splitlines()
+    metadata["status"] = "done"
+    if lines:
+        metadata["title"] = f"🛠️ {lines[-1]}"
+
+    return synthesized_text.strip(), metadata
 
 
 def create_computer_instance(verbosity: int = logging.INFO) -> Computer:
@@ -247,6 +292,7 @@ def create_agent(
     verbosity: int = logging.INFO,
     use_ollama: bool = False,
     use_oaicompat: bool = False,
+    provider_base_url: Optional[str] = None,
 ) -> ComputerAgent:
     """Create or update the global agent with the specified parameters."""
     global global_agent
@@ -270,72 +316,55 @@ def create_agent(
     elif provider == LLMProvider.ANTHROPIC:
         api_key = os.environ.get("ANTHROPIC_API_KEY", "")
 
-    # Create LLM model object with appropriate parameters
-    provider_base_url = "http://localhost:1234/v1" if use_oaicompat else None
+    # Use provided provider_base_url if available, otherwise use default
+    default_base_url = "http://localhost:1234/v1" if use_oaicompat else None
+    custom_base_url = provider_base_url or default_base_url
 
     if use_oaicompat:
         # Special handling for OAICOMPAT - use OAICOMPAT provider with custom base URL
         print(
-            f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {provider_base_url}"
+            f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {custom_base_url}"
        )
         llm = LLM(
             provider=LLMProvider.OAICOMPAT,  # Set to OAICOMPAT instead of using original provider
             name=model_name,
-            provider_base_url=provider_base_url,
+            provider_base_url=custom_base_url,
         )
         print(f"DEBUG - LLM provider is now: {llm.provider}, base URL: {llm.provider_base_url}")
         # Note: Don't pass use_oaicompat to the agent, as it doesn't accept this parameter
     elif provider == LLMProvider.OAICOMPAT:
         # This path is unlikely to be taken with our current approach
-        llm = LLM(provider=provider, name=model_name, provider_base_url=provider_base_url)
+        llm = LLM(provider=provider, name=model_name, provider_base_url=custom_base_url)
     else:
         # For other providers, just use standard parameters
         llm = LLM(provider=provider, name=model_name)
 
     # Create or update the agent
-    if global_agent is None:
-        global_agent = ComputerAgent(
-            computer=computer,
-            loop=agent_loop,
-            model=llm,
-            api_key=api_key,
-            save_trajectory=save_trajectory,
-            only_n_most_recent_images=only_n_most_recent_images,
-            verbosity=verbosity,
-            **extra_config,
-        )
-    else:
-        # Update the existing agent's parameters
-        global_agent._loop = None  # Force recreation of the loop
-        global_agent.provider = provider
-        global_agent.loop = agent_loop
-        global_agent.model = llm
-        global_agent.api_key = api_key
-
-        # Explicitly update these settings to ensure they take effect
-        global_agent.save_trajectory = save_trajectory
-        global_agent.only_n_most_recent_images = only_n_most_recent_images
-
-        # Update Ollama settings if applicable
-        if use_ollama:
-            global_agent.use_ollama = True
-            global_agent.ollama_model = model_name
-        else:
-            global_agent.use_ollama = False
-            global_agent.ollama_model = None
-
-        # Log the updated settings
-        logging.info(
-            f"Updated agent settings: save_trajectory={save_trajectory}, recent_images={only_n_most_recent_images}"
-        )
+    global_agent = ComputerAgent(
+        computer=computer,
+        loop=agent_loop,
+        model=llm,
+        api_key=api_key,
+        save_trajectory=save_trajectory,
+        only_n_most_recent_images=only_n_most_recent_images,
+        verbosity=verbosity,
+        **extra_config,
+    )
 
     return global_agent
 
 
-def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
+def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
     """Process agent results for the Gradio UI."""
     # Extract text content
     text_obj = result.get("text", {})
+    metadata = result.get("metadata", {})
+
+    # Create a properly typed MetadataDict
+    metadata_dict = MetadataDict()
+    metadata_dict["title"] = metadata.get("title", "")
+    metadata_dict["status"] = "done"
+    metadata = metadata_dict
 
     # For OpenAI's Computer-Use Agent, text field is an object with format property
     if (
@@ -344,8 +373,11 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str
         and "format" in text_obj
         and not text_obj.get("value", "")
     ):
-        content = extract_synthesized_text(result)
+        content, metadata = extract_synthesized_text(result)
     else:
+        if not text_obj:
+            text_obj = result
+
         # For other types of results, try to get text directly
         if isinstance(text_obj, dict):
             if "value" in text_obj:
@@ -378,180 +410,7 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str
     if not isinstance(content, str):
         content = str(content) if content else ""
 
-    return content
-
-
-def respond(
-    message: str,
-    history: List[Tuple[str, str]],
-    model_choice,  # Accept Gradio Dropdown component
-    agent_loop,  # Accept Gradio Dropdown component
-    save_trajectory,  # Accept Gradio Checkbox component
-    recent_images,  # Accept Gradio Slider component
-    openai_api_key: Optional[str] = None,
-    anthropic_api_key: Optional[str] = None,
-) -> str:
-    """Process a message with the Computer-Use Agent and return the response."""
-    import asyncio
-
-    # Get actual values from Gradio components
-    model_choice_value = model_choice.value if hasattr(model_choice, "value") else model_choice
-    agent_loop_value = agent_loop.value if hasattr(agent_loop, "value") else agent_loop
-    save_trajectory_value = (
-        save_trajectory.value if hasattr(save_trajectory, "value") else save_trajectory
-    )
-    recent_images_value = int(
-        recent_images.value if hasattr(recent_images, "value") else recent_images
-    )
-
-    # Debug logging
-    print(f"DEBUG - Model choice object: {type(model_choice)}")
-    print(f"DEBUG - Model choice value: {model_choice_value}")
-    print(f"DEBUG - Agent loop value: {agent_loop_value}")
-
-    # Create a new event loop for this function call
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-
-    async def _async_respond():
-        # Extract the loop type and model from the selection
-        loop_provider = "OPENAI"
-        if isinstance(model_choice_value, str):
-            # This is the case for a custom text input from textbox
-            if agent_loop_value == "OMNI":
-                loop_provider = "OMNI"
-                # Use the custom model name as is
-                model_id = model_choice_value
-                print(f"DEBUG - Using custom model: {model_id}")
-            else:
-                # Handle regular dropdown value as string
-                if model_choice_value.startswith("OpenAI:"):
-                    loop_provider = "OPENAI"
-                    model_id = model_choice_value.replace("OpenAI: ", "").lower()
-                elif model_choice_value.startswith("Anthropic:"):
-                    loop_provider = "ANTHROPIC"
-                    model_id = model_choice_value.replace("Anthropic: ", "").lower()
-                elif model_choice_value.startswith("OMNI:"):
-                    loop_provider = "OMNI"
-                    if "GPT" in model_choice_value:
-                        model_id = model_choice_value.replace("OMNI: OpenAI ", "").lower()
-                    elif "Claude" in model_choice_value:
-                        model_id = model_choice_value.replace("OMNI: ", "").lower()
-                    elif "Ollama" in model_choice_value:
-                        loop_provider = "OMNI-OLLAMA"
-                        # Extract everything after "OMNI: Ollama " which is the full model name (e.g., phi3:latest)
-                        model_id = model_choice_value.replace("OMNI: Ollama ", "")
-                        print(f"DEBUG - Ollama model ID: {model_id}")
-                    else:
-                        model_id = "default"
-                else:
-                    # Default case
-                    loop_provider = agent_loop_value
-                    model_id = "default"
-        else:
-            # Model choice is not a string (shouldn't happen, but handle anyway)
-            loop_provider = agent_loop_value
-            model_id = "default"
-
-        print(f"DEBUG - Using loop provider: {loop_provider}, model_id: {model_id}")
-
-        # Use the mapping function to get provider, model name and agent loop
-        provider, model_name, agent_loop_type = get_provider_and_model(model_id, loop_provider)
-        print(
-            f"DEBUG - After mapping: provider={provider}, model_name={model_name}, agent_loop={agent_loop_type}"
-        )
-
-        # Special handling for OAICOMPAT to bypass provider-specific errors
-        # Creates the agent with OPENAI provider but using custom model name and provider base URL
-        is_oaicompat = str(provider) == "oaicompat"
-
-        # Don't override the provider for OAICOMPAT - instead pass it through
-        # if is_oaicompat:
-        #     provider = LLMProvider.OPENAI
-
-        # Get API key based on provider
-        if provider == LLMProvider.OPENAI:
-            api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
-        elif provider == LLMProvider.ANTHROPIC:
-            api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
-        else:
-            api_key = ""
-
-        # Check for settings changes if agent already exists
-        settings_changed = False
-        settings_message = ""
-        if global_agent is not None:
-            # Safely check if save_trajectory setting changed
-            current_save_traj = getattr(global_agent, "save_trajectory", None)
-            if current_save_traj is not None and current_save_traj != save_trajectory_value:
-                settings_changed = True
-                settings_message += f"Save trajectory set to: {save_trajectory_value}. "
-
-            # Safely check if recent_images setting changed
-            current_recent_images = getattr(global_agent, "only_n_most_recent_images", None)
-            if current_recent_images is not None and current_recent_images != recent_images_value:
-                settings_changed = True
-                settings_message += f"Recent images set to: {recent_images_value}. "
-
-        # Create or update the agent
-        try:
-            create_agent(
-                provider=provider,
-                agent_loop=agent_loop_type,
-                model_name=model_name,
-                api_key=api_key,
-                save_trajectory=save_trajectory_value,
-                only_n_most_recent_images=recent_images_value,
-                use_ollama=loop_provider == "OMNI-OLLAMA",
-                use_oaicompat=is_oaicompat,
-            )
-
-            if global_agent is None:
-                return "Failed to create agent. Check API keys and configuration."
-        except Exception as e:
-            return f"Error creating agent: {str(e)}"
-
-        # Notify about settings changes if needed
-        if settings_changed:
-            return f"Settings updated: {settings_message}"
-
-        # Collect all responses
-        response_text = []
-
-        # Run the agent
-        try:
-            async for result in global_agent.run(message):
-                # Process result
-                content = process_agent_result(result)
-
-                # Skip empty content
-                if not content:
-                    continue
-
-                # Add content to response list
-                response_text.append(content)
-
-            # Return the full response as a single string
-            return "\n".join(response_text) if response_text else "Task completed."
-
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            return f"Error: {str(e)}"
-
-    # Run the async function and get the result
-    try:
-        result = loop.run_until_complete(_async_respond())
-        loop.close()
-        return result
-    except Exception as e:
-        loop.close()
-        import traceback
-
-        traceback.print_exc()
-        return f"Error executing async operation: {str(e)}"
-
+    return content, metadata
 
 def create_gradio_ui(
     provider_name: str = "openai",
@@ -725,48 +584,68 @@ def create_gradio_ui(
             """
             )
 
-            # Configuration options
-            agent_loop = gr.Dropdown(
-                choices=["OPENAI", "ANTHROPIC", "OMNI"],
-                label="Agent Loop",
-                value=initial_loop,
-                info="Select the agent loop provider",
-            )
+            with gr.Accordion("Configuration", open=True):
+                # Configuration options
+                agent_loop = gr.Dropdown(
+                    choices=["OPENAI", "ANTHROPIC", "OMNI"],
+                    label="Agent Loop",
+                    value=initial_loop,
+                    info="Select the agent loop provider",
+                )
 
-            # Create model selection dropdown with custom value support for OMNI
-            model_choice = gr.Dropdown(
-                choices=provider_to_models.get(initial_loop, ["No models available"]),
-                label="LLM Provider and Model",
-                value=initial_model,
-                info="Select model or choose 'Custom model...' to enter a custom name",
-                interactive=True,
-            )
+                # Create model selection dropdown with custom value support for OMNI
+                model_choice = gr.Dropdown(
+                    choices=provider_to_models.get(initial_loop, ["No models available"]),
+                    label="LLM Provider and Model",
+                    value=initial_model,
+                    info="Select model or choose 'Custom model...' to enter a custom name",
+                    interactive=True,
+                )
 
-            # Add custom model textbox (only visible when "Custom model..." is selected)
-            custom_model = gr.Textbox(
-                label="Custom Model Name",
-                placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
-                value="Qwen2.5-VL-7B-Instruct",  # Default value
-                visible=False,  # Initially hidden
-                interactive=True,
-            )
+                # Add custom model textbox (only visible when "Custom model..." is selected)
+                custom_model = gr.Textbox(
+                    label="Custom Model Name",
+                    placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
+                    value="Qwen2.5-VL-7B-Instruct",  # Default value
+                    visible=False,  # Initially hidden
+                    interactive=True,
+                )
+
+                # Add custom provider base URL textbox (only visible when "Custom model..." is selected)
+                provider_base_url = gr.Textbox(
+                    label="Provider Base URL",
+                    placeholder="Enter provider base URL (e.g., http://localhost:1234/v1)",
+                    value="http://localhost:1234/v1",  # Default value
+                    visible=False,  # Initially hidden
+                    interactive=True,
+                )
+
+                # Add custom API key textbox (only visible when "Custom model..." is selected)
+                provider_api_key = gr.Textbox(
+                    label="Provider API Key",
+                    placeholder="Enter provider API key (if required)",
+                    value="",  # Default empty value
+                    visible=False,  # Initially hidden
+                    interactive=True,
+                    type="password",  # Hide the API key
+                )
 
-            save_trajectory = gr.Checkbox(
-                label="Save Trajectory",
-                value=True,
-                info="Save the agent's trajectory for debugging",
-                interactive=True,
-            )
+                save_trajectory = gr.Checkbox(
+                    label="Save Trajectory",
+                    value=True,
+                    info="Save the agent's trajectory for debugging",
+                    interactive=True,
+                )
 
-            recent_images = gr.Slider(
-                label="Recent Images",
-                minimum=1,
-                maximum=10,
-                value=3,
-                step=1,
-                info="Number of recent images to keep in context",
-                interactive=True,
-            )
+                recent_images = gr.Slider(
+                    label="Recent Images",
+                    minimum=1,
+                    maximum=10,
+                    value=3,
+                    step=1,
+                    info="Number of recent images to keep in context",
+                    interactive=True,
+                )
 
         # Right column for chat interface
         with gr.Column(scale=2):
@@ -775,7 +654,7 @@ def create_gradio_ui(
                 "Ask me to perform tasks in a virtual macOS environment. Built with github.com/trycua/cua."
             )
 
-            chatbot = gr.Chatbot()
+            chatbot_history = gr.Chatbot(type='messages')
             msg = gr.Textbox(
                 placeholder="Ask me to perform tasks in a virtual macOS environment"
             )
@@ -787,24 +666,27 @@ def create_gradio_ui(
             # Function to handle chat submission
             def chat_submit(message, history):
                 # Add user message to history
-                history = history + [(message, None)]
+                history.append(gr.ChatMessage(role="user", content=message))
                 return "", history
 
             # Function to process agent response after user input
-            def process_response(
+            async def process_response(
                 history,
                 model_choice_value,
                 custom_model_value,
                 agent_loop_choice,
                 save_traj,
                 recent_imgs,
+                custom_url_value=None,
+                custom_api_key=None,
             ):
                 if not history:
-                    return history
+                    yield history
+                    return
 
                 # Get the last user message
-                last_user_message = history[-1][0]
-
+                last_user_message = history[-1]['content']
+
                 # Use custom model value if "Custom model..." is selected
                 model_to_use = (
                     custom_model_value
@@ -812,38 +694,94 @@ def create_gradio_ui(
                     else model_choice_value
                 )
 
-                # Process with agent
-                response = respond(
-                    last_user_message,
-                    history[:-1],  # History without the last message
-                    model_to_use,
-                    agent_loop_choice,
-                    save_traj,
-                    recent_imgs,
-                    openai_api_key,
-                    anthropic_api_key,
-                )
-
-                # Update the last assistant message
-                history[-1] = (last_user_message, response)
-                return history
-
+                try:
+                    # Get the model, agent loop, and provider
+                    provider, model_name, agent_loop_type = get_provider_and_model(
+                        model_to_use, agent_loop_choice
+                    )
+
+                    # Special handling for OAICOMPAT
+                    is_oaicompat = str(provider) == "oaicompat"
+
+                    # Get API key based on provider
+                    if model_choice_value == "Custom model..." and custom_api_key:
+                        # Use custom API key if provided for custom model
+                        api_key = custom_api_key
+                        print(f"DEBUG - Using custom API key for model: {model_name}")
+                    elif provider == LLMProvider.OPENAI:
+                        api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
+                    elif provider == LLMProvider.ANTHROPIC:
+                        api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
+                    else:
+                        api_key = ""
+
+                    # Create or update the agent
+                    create_agent(
+                        provider=provider,
+                        agent_loop=agent_loop_type,
+                        model_name=model_name,
+                        api_key=api_key,
+                        save_trajectory=save_traj,
+                        only_n_most_recent_images=recent_imgs,
+                        use_ollama=agent_loop_choice == "OMNI-OLLAMA",
+                        use_oaicompat=is_oaicompat,
+                        provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
+                    )
+
+                    if global_agent is None:
+                        # Report agent creation failure in the chat
+                        history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
+                        yield history
+                        return
+
+                    # Add the screenshot handler to the agent's loop if available
+                    if global_agent and hasattr(global_agent, "_loop"):
+                        print("DEBUG - Adding screenshot handler to agent loop")
+
+                        # Create the screenshot handler with a reference to the chat history
+                        screenshot_handler = GradioChatScreenshotHandler(
+                            history
+                        )
+
+                        # Add the handler to the callback manager if it exists
+                        if hasattr(global_agent._loop, "callback_manager"):
+                            global_agent._loop.callback_manager.add_handler(screenshot_handler)
+                            print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
+
+                    # Stream responses from the agent
+                    async for result in global_agent.run(last_user_message):
+                        # Process result
+                        content, metadata = process_agent_result(result)
+
+                        # Skip empty content
+                        if content or metadata.get("title"):
+                            history.append(gr.ChatMessage(role="assistant", content=content, metadata=metadata))
+                            yield history
+                except Exception as e:
+                    import traceback
+                    traceback.print_exc()
+                    # Update with error message
+                    history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
+                    yield history
+
             # Connect the components
-            msg.submit(chat_submit, [msg, chatbot], [msg, chatbot]).then(
+            msg.submit(chat_submit, [msg, chatbot_history], [msg, chatbot_history]).then(
                 process_response,
                 [
-                    chatbot,
+                    chatbot_history,
                     model_choice,
                     custom_model,
                     agent_loop,
                     save_trajectory,
                     recent_images,
+                    provider_base_url,
+                    provider_api_key,
                 ],
-                [chatbot],
+                [chatbot_history],
             )
 
             # Clear button functionality
-            clear.click(lambda: None, None, chatbot, queue=False)
+            clear.click(lambda: None, None, chatbot_history, queue=False)
 
             # Connect agent_loop changes to model selection
             agent_loop.change(
@@ -853,14 +791,15 @@ def create_gradio_ui(
                 queue=False,  # Process immediately without queueing
             )
 
-            # Show/hide custom model textbox based on dropdown selection
+            # Show/hide custom model, provider base URL, and API key textboxes based on dropdown selection
             def update_custom_model_visibility(model_value):
-                return gr.update(visible=model_value == "Custom model...")
+                is_custom = model_value == "Custom model..."
+                return gr.update(visible=is_custom), gr.update(visible=is_custom), gr.update(visible=is_custom)
 
             model_choice.change(
                 fn=update_custom_model_visibility,
                 inputs=[model_choice],
-                outputs=[custom_model],
+                outputs=[custom_model, provider_base_url, provider_api_key],
                 queue=False,  # Process immediately without queueing
             )
diff --git a/libs/som/som/ocr.py b/libs/som/som/ocr.py
index 6d10e85a..32f15bd1 100644
--- a/libs/som/som/ocr.py
+++ b/libs/som/som/ocr.py
@@ -17,17 +17,28 @@ class TimeoutException(Exception):
 
 @contextmanager
 def timeout(seconds: int):
-    def timeout_handler(signum, frame):
-        raise TimeoutException("OCR process timed out")
+    import threading
+
+    # Check if we're in the main thread
+    if threading.current_thread() is threading.main_thread():
+        def timeout_handler(signum, frame):
+            raise TimeoutException("OCR process timed out")
 
-    original_handler = signal.signal(signal.SIGALRM, timeout_handler)
-    signal.alarm(seconds)
+        original_handler = signal.signal(signal.SIGALRM, timeout_handler)
+        signal.alarm(seconds)
 
-    try:
-        yield
-    finally:
-        signal.alarm(0)
-        signal.signal(signal.SIGALRM, original_handler)
+        try:
+            yield
+        finally:
+            signal.alarm(0)
+            signal.signal(signal.SIGALRM, original_handler)
+    else:
+        # In a non-main thread, we can't use signal-based timeouts
+        logger.warning("Timeout function called from non-main thread; signal-based timeout disabled")
+        try:
+            yield
+        finally:
+            pass
 
 class OCRProcessor:
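
Aside (not part of the patch): in the non-main-thread branch above the timeout is simply disabled. If a hard timeout is ever needed off the main thread, running the OCR call under an executor gives one without SIGALRM. ocr_with_timeout and run_ocr are hypothetical names; TimeoutException is the module's own exception class.

import concurrent.futures

from som.ocr import TimeoutException  # the exception defined in this module

def ocr_with_timeout(run_ocr, *args, seconds: int = 30):
    """Sketch of a thread-safe alternative to the signal-based timeout."""
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(run_ocr, *args)
    try:
        return future.result(timeout=seconds)
    except concurrent.futures.TimeoutError:
        raise TimeoutException("OCR process timed out")
    finally:
        # Don't block on the worker; it cannot be killed, only abandoned.
        executor.shutdown(wait=False)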