Mirror of https://github.com/trycua/computer.git
Fix gradio model selection
@@ -239,4 +239,7 @@ shared
trajectories/

# Installation ID Storage
.storage/
+
+# Gradio settings
+.gradio_settings.json
@@ -59,7 +59,7 @@ If you want to use AI agents with virtualized environments:

2. Pull the latest macOS CUA image:

   ```bash
-   lume pull macos-sequoia-cua:latest --no-cache
+   lume pull macos-sequoia-cua:latest
   ```

3. Start Lume daemon service:
@@ -172,3 +172,26 @@ async for result in agent.run(task):
    print("\nTool Call Output:")
    print(output)
```
+
+### Gradio UI
+
+You can also interact with the agent using a Gradio interface.
+
+```python
+# Ensure environment variables (e.g., API keys) are loaded
+# You might need a helper function like load_dotenv_files() if using .env
+# from utils import load_dotenv_files
+# load_dotenv_files()
+
+from agent.ui.gradio.app import create_gradio_ui
+
+app = create_gradio_ui()
+app.launch(share=False)
+```
+
+**Note on Settings Persistence:**
+
+* The Gradio UI automatically saves your configuration (Agent Loop, Model Choice, Custom Base URL, Save Trajectory state, Recent Images count) to a file named `.gradio_settings.json` in the project's root directory when you successfully run a task.
+* This allows your preferences to persist between sessions.
+* API keys entered into the custom provider field are **not** saved in this file for security reasons. Manage API keys using environment variables (e.g., `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`) or a `.env` file.
+* It's recommended to add `.gradio_settings.json` to your `.gitignore` file.
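For reference, the settings file the UI writes is flat JSON. Here is a minimal sketch of creating one; the keys mirror those handled by the `save_settings()` helper further down in this diff, and the values shown are just the UI defaults, so treat them as illustrative rather than authoritative.

```python
# Illustrative sketch only: keys follow save_settings() below; values are example defaults.
import json
from pathlib import Path

example_settings = {
    "agent_loop": "OMNI",
    "model_choice": "OMNI: OpenAI GPT-4o",
    "custom_model": "Qwen2.5-VL-7B-Instruct",
    "provider_base_url": "http://localhost:1234/v1",
    "save_trajectory": True,
    "recent_images": 3,
}

settings_path = Path(".gradio_settings.json")
settings_path.write_text(json.dumps(example_settings, indent=4))
print(json.loads(settings_path.read_text())["model_choice"])
```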
@@ -3,23 +3,33 @@ import httpx
|
||||
from anthropic.types.beta import BetaContentBlockParam
|
||||
from ..tools import ToolResult
|
||||
|
||||
|
||||
class APICallback(Protocol):
|
||||
"""Protocol for API callbacks."""
|
||||
def __call__(self, request: httpx.Request | None,
|
||||
response: httpx.Response | object | None,
|
||||
error: Exception | None) -> None: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: httpx.Request | None,
|
||||
response: httpx.Response | object | None,
|
||||
error: Exception | None,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class ContentCallback(Protocol):
|
||||
"""Protocol for content callbacks."""
|
||||
|
||||
def __call__(self, content: BetaContentBlockParam) -> None: ...
|
||||
|
||||
|
||||
class ToolCallback(Protocol):
|
||||
"""Protocol for tool callbacks."""
|
||||
|
||||
def __call__(self, result: ToolResult, tool_id: str) -> None: ...
|
||||
|
||||
|
||||
class CallbackManager:
|
||||
"""Manages various callbacks for the agent system."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content_callback: ContentCallback,
|
||||
@@ -27,7 +37,7 @@ class CallbackManager:
|
||||
api_callback: APICallback,
|
||||
):
|
||||
"""Initialize the callback manager.
|
||||
|
||||
|
||||
Args:
|
||||
content_callback: Callback for content updates
|
||||
tool_callback: Callback for tool execution results
|
||||
@@ -36,20 +46,20 @@ class CallbackManager:
|
||||
self.content_callback = content_callback
|
||||
self.tool_callback = tool_callback
|
||||
self.api_callback = api_callback
|
||||
|
||||
|
||||
def on_content(self, content: BetaContentBlockParam) -> None:
|
||||
"""Handle content updates."""
|
||||
self.content_callback(content)
|
||||
|
||||
|
||||
def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
|
||||
"""Handle tool execution results."""
|
||||
self.tool_callback(result, tool_id)
|
||||
|
||||
|
||||
def on_api_interaction(
|
||||
self,
|
||||
request: httpx.Request | None,
|
||||
response: httpx.Response | object | None,
|
||||
error: Exception | None
|
||||
error: Exception | None,
|
||||
) -> None:
|
||||
"""Handle API interactions."""
|
||||
self.api_callback(request, response, error)
|
||||
self.api_callback(request, response, error)
|
||||
|
||||
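Because these callback types are `typing.Protocol` classes, any callables with matching signatures satisfy them; no subclassing is required. A minimal sketch under that assumption follows (the handler bodies and the commented-out `CallbackManager` wiring are illustrative, and the manager's import path is not shown in this hunk, so it is omitted):

```python
# Minimal sketch: plain functions structurally satisfy the callback Protocols above.
# CallbackManager itself is not imported here because its module path is not shown in this hunk.
from typing import Optional

import httpx


def print_content(content) -> None:
    """ContentCallback: receives each content block as it streams."""
    print(f"[content] {content}")


def print_tool_result(result, tool_id: str) -> None:
    """ToolCallback: receives a ToolResult plus the originating tool id."""
    print(f"[tool {tool_id}] {result}")


def print_api_interaction(
    request: Optional[httpx.Request],
    response: object,
    error: Optional[Exception],
) -> None:
    """APICallback: logs either the error or the request/response pair."""
    if error is not None:
        print(f"[api] error: {error}")
    elif request is not None:
        print(f"[api] {request.method} {request.url} -> {type(response).__name__}")


# manager = CallbackManager(
#     content_callback=print_content,
#     tool_callback=print_tool_result,
#     api_callback=print_api_interaction,
# )
```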
@@ -30,6 +30,8 @@ Requirements:
import os
import asyncio
import logging
+import json
+from pathlib import Path
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
import gradio as gr
from gradio.components.chatbot import MetadataDict
@@ -45,44 +47,86 @@ from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
# Global variables
global_agent = None
global_computer = None
+SETTINGS_FILE = Path(".gradio_settings.json")

# We'll use asyncio.run() instead of a persistent event loop


+# --- Settings Load/Save Functions ---
+def load_settings() -> Dict[str, Any]:
+    """Loads settings from the JSON file."""
+    if SETTINGS_FILE.exists():
+        try:
+            with open(SETTINGS_FILE, "r") as f:
+                settings = json.load(f)
+            # Basic validation (can be expanded)
+            if isinstance(settings, dict):
+                print(f"DEBUG - Loaded settings from {SETTINGS_FILE}")
+                return settings
+        except (json.JSONDecodeError, IOError) as e:
+            print(f"Warning: Could not load settings from {SETTINGS_FILE}: {e}")
+    return {}
+
+
+def save_settings(settings: Dict[str, Any]):
+    """Saves settings to the JSON file."""
+    # Ensure sensitive keys are not saved
+    settings.pop("provider_api_key", None)
+    try:
+        with open(SETTINGS_FILE, "w") as f:
+            json.dump(settings, f, indent=4)
+        print(f"DEBUG - Saved settings to {SETTINGS_FILE}")
+    except IOError as e:
+        print(f"Warning: Could not save settings to {SETTINGS_FILE}: {e}")
+
+
+# --- End Settings Load/Save ---
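A quick illustration of what these helpers are meant to guarantee: `save_settings()` drops the `provider_api_key` entry before writing, so a later `load_settings()` never returns an API key. This is hypothetical usage, assuming the two functions above are in scope:

```python
# Hypothetical usage of the helpers above (run in the same module/session).
settings = {
    "agent_loop": "OMNI",
    "model_choice": "Custom model...",
    "provider_api_key": "sk-should-never-be-persisted",  # stripped by save_settings()
}

save_settings(settings)      # writes .gradio_settings.json without the API key
reloaded = load_settings()   # returns {} if the file is missing or unreadable
assert "provider_api_key" not in reloaded
```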
# Custom Screenshot Handler for Gradio chat
class GradioChatScreenshotHandler(DefaultCallbackHandler):
    """Custom handler that adds screenshots to the Gradio chatbot and updates annotated image."""

    def __init__(self, chatbot_history: List[gr.ChatMessage]):
        """Initialize with reference to chat history and annotated image component.

        Args:
            chatbot_history: Reference to the Gradio chatbot history list
            annotated_image: Reference to the annotated image component
        """
        self.chatbot_history = chatbot_history
        print("GradioChatScreenshotHandler initialized")

-    async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
+    async def on_screenshot(
+        self,
+        screenshot_base64: str,
+        action_type: str = "",
+        parsed_screen: Optional[ParseResult] = None,
+    ) -> None:
        """Add screenshot to chatbot when a screenshot is taken and update the annotated image.

        Args:
            screenshot_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot

        Returns:
            Original screenshot (does not modify it)
        """
        # Create a markdown image element for the screenshot
-        image_markdown = f""
+        image_markdown = (
+            f""
+        )

        # Simply append the screenshot as a new message
        if self.chatbot_history is not None:
-            self.chatbot_history.append(gr.ChatMessage(
-                role="assistant",
-                content=image_markdown,
-                metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"}
-            ))
+            self.chatbot_history.append(
+                gr.ChatMessage(
+                    role="assistant",
+                    content=image_markdown,
+                    metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"},
+                )
+            )


# Map model names to specific provider model names
MODEL_MAPPINGS = {
@@ -94,6 +138,7 @@ MODEL_MAPPINGS = {
        "gpt-4o": "computer_use_preview",
        "gpt-4": "computer_use_preview",
        "gpt-4.5-preview": "computer_use_preview",
+        "gpt-4o-mini": "gpt-4o-mini",
    },
    "anthropic": {
        # Default to newest model
@@ -111,6 +156,7 @@ MODEL_MAPPINGS = {
        # OMNI works with any of these models
        "default": "gpt-4o",
        "gpt-4o": "gpt-4o",
+        "gpt-4o-mini": "gpt-4o-mini",
        "gpt-4": "gpt-4",
        "gpt-4.5-preview": "gpt-4.5-preview",
        "claude-3-5-sonnet-20240620": "claude-3-5-sonnet-20240620",
@@ -160,30 +206,82 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
            model_name.lower(), MODEL_MAPPINGS["anthropic"]["default"]
        )
    elif agent_loop == AgentLoop.OMNI:
-        # For OMNI, select provider based on model name or loop_provider
-        if loop_provider == "OMNI-OLLAMA":
-            provider = LLMProvider.OLLAMA
-            # For Ollama models from the UI dropdown, we use the model name as is
-            # No need to parse it - it's already the correct Ollama model name
-            model_name_to_use = model_name
-        elif "claude" in model_name.lower():
-            provider = LLMProvider.ANTHROPIC
-            model_name_to_use = MODEL_MAPPINGS["omni"].get(
-                model_name.lower(), MODEL_MAPPINGS["omni"]["default"]
-            )
-        elif "gpt" in model_name.lower():
-            provider = LLMProvider.OPENAI
-            model_name_to_use = MODEL_MAPPINGS["omni"].get(
-                model_name.lower(), MODEL_MAPPINGS["omni"]["default"]
-            )
-        else:
-            # Handle custom model names - use the OAICOMPAT provider
-            provider = LLMProvider.OAICOMPAT
-            # Use the model name as is without mapping, or use default if empty
-            model_name_to_use = (
-                model_name if model_name.strip() else MODEL_MAPPINGS["oaicompat"]["default"]
-            )
+        # Determine provider and clean model name based on the full string from UI
+        cleaned_model_name = model_name  # Default to using the name as-is (for custom)
+
+        if model_name == "Custom model...":
+            # Actual model name comes from custom_model_value via model_to_use.
+            # Assume OAICOMPAT for custom models unless overridden by URL/key later?
+            # get_provider_and_model determines the *initial* provider/model.
+            # The custom URL/key in process_response ultimately dictates the OAICOMPAT setup.
+            provider = LLMProvider.OAICOMPAT
+            # We set cleaned_model_name below outside the checks based on model_to_use
+            cleaned_model_name = ""  # Placeholder, will be set by custom value later
+        elif model_name.startswith("OMNI: Ollama "):
+            provider = LLMProvider.OLLAMA
+            # Extract the part after "OMNI: Ollama "
+            cleaned_model_name = model_name.split("OMNI: Ollama ", 1)[1]
+        elif model_name.startswith("OMNI: Claude "):
+            provider = LLMProvider.ANTHROPIC
+            # Extract the canonical model name based on the UI string
+            # e.g., "OMNI: Claude 3.7 Sonnet (20250219)" -> "3.7 Sonnet" and "20250219"
+            parts = model_name.split(" (")
+            model_key_part = parts[0].replace("OMNI: Claude ", "")
+            date_part = parts[1].replace(")", "") if len(parts) > 1 else ""
+
+            # Normalize the extracted key part for comparison
+            # "3.7 Sonnet" -> "37sonnet"
+            model_key_part_norm = model_key_part.lower().replace(".", "").replace(" ", "")
+
+            cleaned_model_name = MODEL_MAPPINGS["omni"]["default"]  # Default if not found
+            # Find the canonical name in the main Anthropic map
+            for key_anthropic, val_anthropic in MODEL_MAPPINGS["anthropic"].items():
+                # Normalize the canonical key for comparison
+                # "claude-3-7-sonnet-20250219" -> "claude37sonnet20250219"
+                key_anthropic_norm = key_anthropic.lower().replace("-", "")
+
+                # Check if the normalized canonical key starts with "claude" + normalized extracted part
+                # AND contains the date part.
+                if (
+                    key_anthropic_norm.startswith("claude" + model_key_part_norm)
+                    and date_part in key_anthropic_norm
+                ):
+                    cleaned_model_name = (
+                        val_anthropic  # Use the canonical name like "claude-3-7-sonnet-20250219"
+                    )
+                    break
+        elif model_name.startswith("OMNI: OpenAI "):
+            provider = LLMProvider.OPENAI
+            # Extract the model part, e.g., "GPT-4o mini"
+            model_key_part = model_name.replace("OMNI: OpenAI ", "")
+            # Normalize the extracted part: "gpt4omini"
+            model_key_part_norm = model_key_part.lower().replace("-", "").replace(" ", "")
+
+            cleaned_model_name = MODEL_MAPPINGS["omni"]["default"]  # Default if not found
+            # Find the canonical name in the main OMNI map for OpenAI models
+            for key_omni, val_omni in MODEL_MAPPINGS["omni"].items():
+                # Normalize the omni map key: "gpt-4o-mini" -> "gpt4omini"
+                key_omni_norm = key_omni.lower().replace("-", "").replace(" ", "")
+                # Check if the normalized omni key matches the normalized extracted part
+                if key_omni_norm == model_key_part_norm:
+                    cleaned_model_name = (
+                        val_omni  # Use the value from the OMNI map (e.g., gpt-4o-mini)
+                    )
+                    break
+            # Note: No fallback needed here as we explicitly check against omni keys
+        else:  # Handles unexpected formats or the raw custom name if "Custom model..." selected
+            # Should only happen if user selected "Custom model..."
+            # Or if a model name format isn't caught above
+            provider = LLMProvider.OAICOMPAT
+            cleaned_model_name = (
+                model_name.strip() if model_name.strip() else MODEL_MAPPINGS["oaicompat"]["default"]
+            )
+
+        # Assign the determined model name
+        model_name_to_use = cleaned_model_name
+        # agent_loop remains AgentLoop.OMNI

    else:
        # Default to OpenAI if unrecognized loop
        provider = LLMProvider.OPENAI
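The Claude branch above matches a UI label such as "OMNI: Claude 3.7 Sonnet (20250219)" against the canonical keys by stripping dots, spaces, and dashes before comparing. A small standalone sketch of that normalization idea (hypothetical helper with a trimmed mapping; the real table is `MODEL_MAPPINGS["anthropic"]`):

```python
# Standalone sketch of the UI-label -> canonical-name matching used above.
# The mapping below is a trimmed example; the real table lives in MODEL_MAPPINGS.
ANTHROPIC_MODELS = {
    "claude-3-5-sonnet-20240620": "claude-3-5-sonnet-20240620",
    "claude-3-7-sonnet-20250219": "claude-3-7-sonnet-20250219",
}


def resolve_claude_label(label: str, default: str = "gpt-4o") -> str:
    """Map a UI label like 'OMNI: Claude 3.7 Sonnet (20250219)' to a canonical model name."""
    parts = label.split(" (")
    key_part = parts[0].replace("OMNI: Claude ", "")
    date_part = parts[1].replace(")", "") if len(parts) > 1 else ""
    key_norm = key_part.lower().replace(".", "").replace(" ", "")  # "3.7 Sonnet" -> "37sonnet"

    for canonical, value in ANTHROPIC_MODELS.items():
        canonical_norm = canonical.lower().replace("-", "")  # "claude37sonnet20250219"
        if canonical_norm.startswith("claude" + key_norm) and date_part in canonical_norm:
            return value
    return default


print(resolve_claude_label("OMNI: Claude 3.7 Sonnet (20250219)"))
# -> claude-3-7-sonnet-20250219
```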
@@ -218,7 +316,9 @@ def get_ollama_models() -> List[str]:
        return []


-def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
+def extract_synthesized_text(
+    result: Union[AgentResponse, Dict[str, Any]],
+) -> Tuple[str, MetadataDict]:
    """Extract synthesized text from the agent result."""
    synthesized_text = ""
    metadata = MetadataDict()
@@ -242,7 +342,7 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
                if text_value:
                    synthesized_text += f"{text_value}\n"

            elif output.get("type") == "computer_call":
                action = output.get("action", {})
                action_type = action.get("type", "")

@@ -265,10 +365,10 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
                    synthesized_text += f"Pressed key: {key}\n"
                else:
                    synthesized_text += f"Performed {action_type} action.\n"

    metadata["status"] = "done"
    metadata["title"] = f"🛠️ {synthesized_text.strip().splitlines()[-1]}"

    return synthesized_text.strip(), metadata
@@ -290,7 +390,6 @@ def create_agent(
    save_trajectory: bool = True,
    only_n_most_recent_images: int = 3,
    verbosity: int = logging.INFO,
    use_ollama: bool = False,
    use_oaicompat: bool = False,
    provider_base_url: Optional[str] = None,
) -> ComputerAgent:
@@ -300,15 +399,6 @@ def create_agent(
    # Create the computer if not already done
    computer = create_computer_instance(verbosity=verbosity)

-    # Extra configuration to pass to the agent
-    extra_config = {}
-
-    # For Ollama models, we'll pass use_ollama and the model_name directly
-    if use_ollama:
-        extra_config["use_ollama"] = True
-        extra_config["ollama_model"] = model_name
-        print(f"DEBUG - Using Ollama with model: {model_name}")
-
    # Get API key from environment if not provided
    if api_key is None:
        if provider == LLMProvider.OPENAI:
@@ -322,9 +412,7 @@ def create_agent(

    if use_oaicompat:
        # Special handling for OAICOMPAT - use OAICOMPAT provider with custom base URL
-        print(
-            f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {custom_base_url}"
-        )
+        print(f"DEBUG - Creating OAICOMPAT agent with model: {model_name}, URL: {custom_base_url}")
        llm = LLM(
            provider=LLMProvider.OAICOMPAT,  # Set to OAICOMPAT instead of using original provider
            name=model_name,
@@ -348,7 +436,6 @@ def create_agent(
        save_trajectory=save_trajectory,
        only_n_most_recent_images=only_n_most_recent_images,
        verbosity=verbosity,
-        **extra_config,
    )

    return global_agent
@@ -359,7 +446,7 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[
    # Extract text content
    text_obj = result.get("text", {})
    metadata = result.get("metadata", {})

    # Create a properly typed MetadataDict
    metadata_dict = MetadataDict()
    metadata_dict["title"] = metadata.get("title", "")
@@ -412,6 +499,7 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[

    return content, metadata

def create_gradio_ui(
    provider_name: str = "openai",
    model_name: str = "gpt-4o",
@@ -425,6 +513,10 @@ def create_gradio_ui(
    Returns:
        A Gradio Blocks application
    """
+    # --- Load Settings ---
+    saved_settings = load_settings()
+    # --- End Load Settings ---
+
    # Check for API keys
    openai_api_key = os.environ.get("OPENAI_API_KEY", "")
    anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY", "")
@@ -438,6 +530,7 @@ def create_gradio_ui(
        openai_models = ["OpenAI: Computer-Use Preview"]
        omni_models += [
            "OMNI: OpenAI GPT-4o",
+            "OMNI: OpenAI GPT-4o mini",
            "OMNI: OpenAI GPT-4.5-preview",
        ]

@@ -460,21 +553,33 @@ def create_gradio_ui(
        "OMNI": omni_models + ["Custom model..."],  # Add custom model option
    }

-    # Get initial agent loop and model based on provided parameters
-    if provider_name.lower() == "openai":
-        initial_loop = "OPENAI"
-        initial_model = "OpenAI: Computer-Use Preview" if openai_models else "No models available"
-    elif provider_name.lower() == "anthropic":
-        initial_loop = "ANTHROPIC"
-        initial_model = anthropic_models[0] if anthropic_models else "No models available"
-    else:
-        initial_loop = "OMNI"
-        if model_name == "gpt-4o" and "OMNI: OpenAI GPT-4o" in omni_models:
-            initial_model = "OMNI: OpenAI GPT-4o"
-        elif "claude" in model_name.lower() and omni_models:
-            initial_model = next((m for m in omni_models if "Claude" in m), omni_models[0])
-        else:
-            initial_model = omni_models[0] if omni_models else "No models available"
+    # --- Apply Saved Settings (override defaults if available) ---
+    initial_loop = saved_settings.get("agent_loop", "OMNI")
+    # Ensure the saved model is actually available in the choices for the loaded loop
+    available_models_for_loop = provider_to_models.get(initial_loop, [])
+    saved_model_choice = saved_settings.get("model_choice")
+    if saved_model_choice and saved_model_choice in available_models_for_loop:
+        initial_model = saved_model_choice
+    else:
+        # If saved model isn't valid for the loop, reset to default for that loop
+        if initial_loop == "OPENAI":
+            initial_model = (
+                "OpenAI: Computer-Use Preview" if openai_models else "No models available"
+            )
+        elif initial_loop == "ANTHROPIC":
+            initial_model = anthropic_models[0] if anthropic_models else "No models available"
+        else:  # OMNI
+            initial_model = omni_models[0] if omni_models else "No models available"
+            if "Custom model..." in available_models_for_loop:
+                initial_model = (
+                    "Custom model..."  # Default to custom if available and no other default fits
+                )
+
+    initial_custom_model = saved_settings.get("custom_model", "Qwen2.5-VL-7B-Instruct")
+    initial_provider_base_url = saved_settings.get("provider_base_url", "http://localhost:1234/v1")
+    initial_save_trajectory = saved_settings.get("save_trajectory", True)
+    initial_recent_images = saved_settings.get("recent_images", 3)
+    # --- End Apply Saved Settings ---

    # Example prompts
    example_messages = [
@@ -567,7 +672,7 @@ def create_gradio_ui(
            ### 3. Pull the pre-built macOS image

            ```bash
-            lume pull macos-sequoia-cua:latest --no-cache
+            lume pull macos-sequoia-cua:latest
            ```

            Initial download requires 80GB storage, but reduces to ~30GB after first run due to macOS's sparse file system.
@@ -606,33 +711,33 @@ def create_gradio_ui(
                    custom_model = gr.Textbox(
                        label="Custom Model Name",
                        placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct)",
-                        value="Qwen2.5-VL-7B-Instruct",  # Default value
-                        visible=False,  # Initially hidden
+                        value=initial_custom_model,
+                        visible=(initial_model == "Custom model..."),
                        interactive=True,
                    )

                    # Add custom provider base URL textbox (only visible when "Custom model..." is selected)
                    provider_base_url = gr.Textbox(
                        label="Provider Base URL",
                        placeholder="Enter provider base URL (e.g., http://localhost:1234/v1)",
-                        value="http://localhost:1234/v1",  # Default value
-                        visible=False,  # Initially hidden
+                        value=initial_provider_base_url,
+                        visible=(initial_model == "Custom model..."),
                        interactive=True,
                    )

                    # Add custom API key textbox (only visible when "Custom model..." is selected)
                    provider_api_key = gr.Textbox(
                        label="Provider API Key",
                        placeholder="Enter provider API key (if required)",
-                        value="",  # Default empty value
-                        visible=False,  # Initially hidden
+                        value="",
+                        visible=(initial_model == "Custom model..."),
                        interactive=True,
-                        type="password",  # Hide the API key
+                        type="password",
                    )

                    save_trajectory = gr.Checkbox(
                        label="Save Trajectory",
-                        value=True,
+                        value=initial_save_trajectory,
                        info="Save the agent's trajectory for debugging",
                        interactive=True,
                    )
@@ -641,7 +746,7 @@ def create_gradio_ui(
                        label="Recent Images",
                        minimum=1,
                        maximum=10,
-                        value=3,
+                        value=initial_recent_images,
                        step=1,
                        info="Number of recent images to keep in context",
                        interactive=True,
@@ -654,7 +759,7 @@ def create_gradio_ui(
                        "Ask me to perform tasks in a virtual macOS environment.<br>Built with <a href='https://github.com/trycua/cua' target='_blank'>github.com/trycua/cua</a>."
                    )

-                    chatbot_history = gr.Chatbot(type='messages')
+                    chatbot_history = gr.Chatbot(type="messages")
                    msg = gr.Textbox(
                        placeholder="Ask me to perform tasks in a virtual macOS environment"
                    )
@@ -685,85 +790,132 @@ def create_gradio_ui(
                return

            # Get the last user message
-            last_user_message = history[-1]['content']
+            last_user_message = history[-1]["content"]

-            # Use custom model value if "Custom model..." is selected
-            model_to_use = (
+            # Determine the model name string to analyze: custom or from dropdown
+            model_string_to_analyze = (
                custom_model_value
                if model_choice_value == "Custom model..."
-                else model_choice_value
+                else model_choice_value  # Use the full UI string initially
            )

+            # Determine if this is a custom model selection
+            is_custom_model_selected = model_choice_value == "Custom model..."
+
            try:
-                # Get the model, agent loop, and provider
-                provider, model_name, agent_loop_type = get_provider_and_model(
-                    model_to_use, agent_loop_choice
+                # Get the provider, *cleaned* model name, and agent loop type
+                provider, cleaned_model_name_from_func, agent_loop_type = (
+                    get_provider_and_model(model_string_to_analyze, agent_loop_choice)
                )

-                # Special handling for OAICOMPAT
-                is_oaicompat = str(provider) == "oaicompat"
-
-                # Get API key based on provider
-                if model_choice_value == "Custom model..." and custom_api_key:
+                # Determine the final model name to send to the agent
+                # If custom selected, use the custom text box value, otherwise use the cleaned name
+                final_model_name_to_send = (
+                    custom_model_value
+                    if is_custom_model_selected
+                    else cleaned_model_name_from_func
+                )
+
+                # Determine if OAICOMPAT should be used (only if custom model explicitly selected)
+                is_oaicompat = is_custom_model_selected
+
+                # Get API key based on provider determined by get_provider_and_model
+                if is_oaicompat and custom_api_key:
                    # Use custom API key if provided for custom model
                    api_key = custom_api_key
-                    print(f"DEBUG - Using custom API key for model: {model_name}")
+                    print(
+                        f"DEBUG - Using custom API key for model: {final_model_name_to_send}"
+                    )
                elif provider == LLMProvider.OPENAI:
                    api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
                elif provider == LLMProvider.ANTHROPIC:
                    api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
                else:
                    # For Ollama or default OAICOMPAT (without custom key), no key needed/expected
                    api_key = ""

+                # --- Save Settings Before Running Agent ---
+                current_settings = {
+                    "agent_loop": agent_loop_choice,
+                    "model_choice": model_choice_value,
+                    "custom_model": custom_model_value,
+                    "provider_base_url": custom_url_value,
+                    "save_trajectory": save_traj,
+                    "recent_images": recent_imgs,
+                }
+                save_settings(current_settings)
+                # --- End Save Settings ---
+
                # Create or update the agent
                create_agent(
-                    provider=provider,
+                    # Provider determined by get_provider_and_model unless custom model selected
+                    provider=LLMProvider.OAICOMPAT if is_oaicompat else provider,
                    agent_loop=agent_loop_type,
-                    model_name=model_name,
+                    # Pass the FINAL determined model name (cleaned or custom)
+                    model_name=final_model_name_to_send,
                    api_key=api_key,
                    save_trajectory=save_traj,
                    only_n_most_recent_images=recent_imgs,
                    use_ollama=agent_loop_choice == "OMNI-OLLAMA",
-                    use_oaicompat=is_oaicompat,
-                    provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
+                    use_oaicompat=is_oaicompat,  # Set flag if custom model was selected
+                    # Pass custom URL only if custom model was selected
+                    provider_base_url=custom_url_value if is_oaicompat else None,
                    verbosity=logging.DEBUG,  # Added verbosity here
                )

                if global_agent is None:
                    # Add initial empty assistant message
-                    history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
+                    history.append(
+                        gr.ChatMessage(
+                            role="assistant",
+                            content="Failed to create agent. Check API keys and configuration.",
+                        )
+                    )
                    yield history
                    return

                # Add the screenshot handler to the agent's loop if available
                if global_agent and hasattr(global_agent, "_loop"):
                    print("DEBUG - Adding screenshot handler to agent loop")

                    # Create the screenshot handler with references to UI components
-                    screenshot_handler = GradioChatScreenshotHandler(
-                        history
-                    )
-
-                    # Add the handler to the callback manager if it exists
-                    if hasattr(global_agent._loop, "callback_manager"):
+                    screenshot_handler = GradioChatScreenshotHandler(history)
+
+                    # Add the handler to the callback manager if it exists AND is not None
+                    if (
+                        hasattr(global_agent._loop, "callback_manager")
+                        and global_agent._loop.callback_manager is not None
+                    ):
                        global_agent._loop.callback_manager.add_handler(screenshot_handler)
-                        print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
+                        print(
+                            f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}"
+                        )
+                    else:
+                        # Optional: Log a warning if the callback manager is missing/None for a specific loop
+                        print(
+                            f"WARNING - Callback manager not found or is None for loop type: {type(global_agent._loop)}. Screenshot handler not added."
+                        )

                # Stream responses from the agent
                async for result in global_agent.run(last_user_message):
                    # Process result
                    content, metadata = process_agent_result(result)

                    # Skip empty content
                    if content or metadata.get("title"):
-                        history.append(gr.ChatMessage(role="assistant", content=content, metadata=metadata))
+                        history.append(
+                            gr.ChatMessage(
+                                role="assistant", content=content, metadata=metadata
+                            )
+                        )
                        yield history
            except Exception as e:
                import traceback

                traceback.print_exc()
                # Update with error message
                history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
                yield history

        # Connect the components
        msg.submit(chat_submit, [msg, chatbot_history], [msg, chatbot_history]).then(
            process_response,
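Taken together, the changes above mean that selecting "Custom model..." forces the OAICOMPAT provider, sends the custom text-box value as the model name, and forwards the custom base URL; every other selection uses whatever `get_provider_and_model` resolved. A hypothetical condensation of that decision (names are illustrative, not the actual code):

```python
# Hypothetical condensation of the custom-model branch above; not part of the commit.
def resolve_agent_config(model_choice: str, custom_model: str, custom_url: str,
                         provider: str, cleaned_name: str) -> dict:
    """Return the provider/model/base-URL trio that create_agent() would receive."""
    is_custom = model_choice == "Custom model..."
    return {
        "provider": "oaicompat" if is_custom else provider,
        "model_name": custom_model if is_custom else cleaned_name,
        "provider_base_url": custom_url if is_custom else None,
    }


print(resolve_agent_config(
    "Custom model...", "Qwen2.5-VL-7B-Instruct", "http://localhost:1234/v1",
    "openai", "gpt-4o",
))
# {'provider': 'oaicompat', 'model_name': 'Qwen2.5-VL-7B-Instruct',
#  'provider_base_url': 'http://localhost:1234/v1'}
```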
@@ -794,7 +946,11 @@ def create_gradio_ui(
        # Show/hide custom model, provider base URL, and API key textboxes based on dropdown selection
        def update_custom_model_visibility(model_value):
            is_custom = model_value == "Custom model..."
-            return gr.update(visible=is_custom), gr.update(visible=is_custom), gr.update(visible=is_custom)
+            return (
+                gr.update(visible=is_custom),
+                gr.update(visible=is_custom),
+                gr.update(visible=is_custom),
+            )

        model_choice.change(
            fn=update_custom_model_visibility,
@@ -95,7 +95,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "!lume pull macos-sequoia-cua:latest --no-cache"
+    "!lume pull macos-sequoia-cua:latest"
   ]
  },
  {
@@ -107,13 +107,6 @@
    "VMs are stored in `~/.lume`, and locally cached images are stored in `~/.lume/cache`."
   ]
  },
- {
-  "cell_type": "markdown",
-  "metadata": {},
-  "source": [
-   "You can remove the `--no-cache` flag to also save the image to your local cache during pull (requires double the storage space). This is useful if you plan to use the same image multiple times to create other VMs."
-  ]
- },
 {
  "cell_type": "markdown",
  "metadata": {},