mirror of
https://github.com/trycua/computer.git
synced 2026-02-20 05:19:38 -06:00
simplified event loop
This commit is contained in:
@@ -32,6 +32,7 @@ import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
|
||||
import gradio as gr
|
||||
from gradio.components.chatbot import MetadataDict
|
||||
|
||||
# Import from agent package
|
||||
from agent.core.types import AgentResponse
|
||||
@@ -45,6 +46,9 @@ from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
|
||||
global_agent = None
|
||||
global_computer = None
|
||||
|
||||
# We'll use asyncio.run() instead of a persistent event loop
|
||||
|
||||
|
||||
# Custom Screenshot Handler for Gradio chat
|
||||
class GradioChatScreenshotHandler(DefaultCallbackHandler):
|
||||
"""Custom handler that adds screenshots to the Gradio chatbot and updates annotated image."""
|
||||
@@ -57,9 +61,6 @@ class GradioChatScreenshotHandler(DefaultCallbackHandler):
|
||||
annotated_image: Reference to the annotated image component
|
||||
"""
|
||||
self.chatbot_history = chatbot_history
|
||||
self.latest_image = None
|
||||
self.latest_annotations = []
|
||||
logging.info("GradioChatScreenshotHandler initialized with chat history and annotated image")
|
||||
print("GradioChatScreenshotHandler initialized")
|
||||
|
||||
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
|
||||
@@ -77,7 +78,11 @@ class GradioChatScreenshotHandler(DefaultCallbackHandler):
|
||||
|
||||
# Simply append the screenshot as a new message
|
||||
if self.chatbot_history is not None:
|
||||
self.chatbot_history.append(gr.ChatMessage(role="assistant", content=image_markdown))
|
||||
self.chatbot_history.append(gr.ChatMessage(
|
||||
role="assistant",
|
||||
content=image_markdown,
|
||||
metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"}
|
||||
))
|
||||
|
||||
# Map model names to specific provider model names
|
||||
MODEL_MAPPINGS = {
|
||||
@@ -213,17 +218,18 @@ def get_ollama_models() -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> str:
|
||||
def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
|
||||
"""Extract synthesized text from the agent result."""
|
||||
synthesized_text = ""
|
||||
metadata = MetadataDict()
|
||||
|
||||
if "output" in result and result["output"]:
|
||||
for output in result["output"]:
|
||||
if output.get("type") == "reasoning":
|
||||
metadata["title"] = "🧠 Reasoning"
|
||||
content = output.get("content", "")
|
||||
if content:
|
||||
synthesized_text += f"{content}\n"
|
||||
|
||||
elif output.get("type") == "message":
|
||||
# Handle message type outputs - can contain rich content
|
||||
content = output.get("content", [])
|
||||
@@ -237,6 +243,7 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
|
||||
synthesized_text += f"{text_value}\n"
|
||||
|
||||
elif output.get("type") == "computer_call":
|
||||
metadata["title"] = "🛠️ Used Computer"
|
||||
action = output.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
|
||||
@@ -259,8 +266,10 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
|
||||
synthesized_text += f"Pressed key: {key}\n"
|
||||
else:
|
||||
synthesized_text += f"Performed {action_type} action.\n"
|
||||
|
||||
return synthesized_text.strip()
|
||||
|
||||
metadata["title"] = f"🛠️ {synthesized_text.strip().splitlines()[-1]}"
|
||||
|
||||
return synthesized_text.strip(), metadata
|
||||
|
||||
|
||||
def create_computer_instance(verbosity: int = logging.INFO) -> Computer:
|
||||
@@ -345,10 +354,17 @@ def create_agent(
|
||||
return global_agent
|
||||
|
||||
|
||||
def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
|
||||
def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
|
||||
"""Process agent results for the Gradio UI."""
|
||||
# Extract text content
|
||||
text_obj = result.get("text", {})
|
||||
metadata = result.get("metadata", {})
|
||||
|
||||
# Create a properly typed MetadataDict
|
||||
metadata_dict = MetadataDict()
|
||||
metadata_dict["title"] = metadata.get("title", "")
|
||||
metadata_dict["status"] = "done"
|
||||
metadata = metadata_dict
|
||||
|
||||
# For OpenAI's Computer-Use Agent, text field is an object with format property
|
||||
if (
|
||||
@@ -357,8 +373,11 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
|
||||
and "format" in text_obj
|
||||
and not text_obj.get("value", "")
|
||||
):
|
||||
content = extract_synthesized_text(result)
|
||||
content, metadata = extract_synthesized_text(result)
|
||||
else:
|
||||
if not text_obj:
|
||||
text_obj = result
|
||||
|
||||
# For other types of results, try to get text directly
|
||||
if isinstance(text_obj, dict):
|
||||
if "value" in text_obj:
|
||||
@@ -391,7 +410,10 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
|
||||
if not isinstance(content, str):
|
||||
content = str(content) if content else ""
|
||||
|
||||
return content
|
||||
print(content)
|
||||
print(metadata)
|
||||
|
||||
return content, metadata
|
||||
|
||||
def create_gradio_ui(
|
||||
provider_name: str = "openai",
|
||||
@@ -565,7 +587,7 @@ def create_gradio_ui(
|
||||
"""
|
||||
)
|
||||
|
||||
with gr.Accordion("Configuration", open=False):
|
||||
with gr.Accordion("Configuration", open=True):
|
||||
# Configuration options
|
||||
agent_loop = gr.Dropdown(
|
||||
choices=["OPENAI", "ANTHROPIC", "OMNI"],
|
||||
@@ -651,7 +673,7 @@ def create_gradio_ui(
|
||||
return "", history
|
||||
|
||||
# Function to process agent response after user input
|
||||
def process_response(
|
||||
async def process_response(
|
||||
history,
|
||||
model_choice_value,
|
||||
custom_model_value,
|
||||
@@ -674,106 +696,77 @@ def create_gradio_ui(
|
||||
if model_choice_value == "Custom model..."
|
||||
else model_choice_value
|
||||
)
|
||||
|
||||
# Create a new async event loop for this function call
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
async def _stream_agent_responses():
|
||||
try:
|
||||
# Get the model, agent loop, and provider
|
||||
provider, model_name, agent_loop_type = get_provider_and_model(
|
||||
model_to_use, agent_loop_choice
|
||||
)
|
||||
|
||||
# Special handling for OAICOMPAT
|
||||
is_oaicompat = str(provider) == "oaicompat"
|
||||
|
||||
# Get API key based on provider
|
||||
if model_choice_value == "Custom model..." and custom_api_key:
|
||||
# Use custom API key if provided for custom model
|
||||
api_key = custom_api_key
|
||||
print(f"DEBUG - Using custom API key for model: {model_name}")
|
||||
elif provider == LLMProvider.OPENAI:
|
||||
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
|
||||
elif provider == LLMProvider.ANTHROPIC:
|
||||
api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
|
||||
else:
|
||||
api_key = ""
|
||||
|
||||
# Create or update the agent
|
||||
create_agent(
|
||||
provider=provider,
|
||||
agent_loop=agent_loop_type,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
save_trajectory=save_traj,
|
||||
only_n_most_recent_images=recent_imgs,
|
||||
use_ollama=agent_loop_choice == "OMNI-OLLAMA",
|
||||
use_oaicompat=is_oaicompat,
|
||||
provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
|
||||
)
|
||||
|
||||
if global_agent is None:
|
||||
# Add initial empty assistant message
|
||||
history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
|
||||
yield history
|
||||
return
|
||||
|
||||
# Add the screenshot handler to the agent's loop if available
|
||||
if global_agent and hasattr(global_agent, "_loop"):
|
||||
print("DEBUG - Adding screenshot handler to agent loop")
|
||||
|
||||
# Create the screenshot handler with references to UI components
|
||||
screenshot_handler = GradioChatScreenshotHandler(
|
||||
history
|
||||
)
|
||||
|
||||
# Add the handler to the callback manager if it exists
|
||||
if hasattr(global_agent._loop, "callback_manager"):
|
||||
global_agent._loop.callback_manager.add_handler(screenshot_handler)
|
||||
print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
|
||||
|
||||
# Stream responses from the agent
|
||||
async for result in global_agent.run(last_user_message):
|
||||
# Process result
|
||||
content = process_agent_result(result)
|
||||
|
||||
# # Skip empty content
|
||||
if content:
|
||||
history.append(gr.ChatMessage(role="assistant", content=content))
|
||||
yield history
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Update with error message
|
||||
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
|
||||
yield history
|
||||
|
||||
# Create an async function to run the generator
|
||||
async def run_generator():
|
||||
async for update in _stream_agent_responses():
|
||||
yield update
|
||||
|
||||
# Run the wrapper function
|
||||
try:
|
||||
# Create a generator by running the async function
|
||||
generator = run_generator()
|
||||
# Push the first element to start the generator
|
||||
first_item = loop.run_until_complete(generator.__anext__())
|
||||
yield first_item
|
||||
|
||||
# Keep iterating until StopAsyncIteration
|
||||
while True:
|
||||
try:
|
||||
item = loop.run_until_complete(generator.__anext__())
|
||||
yield item
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
try:
|
||||
# Get the model, agent loop, and provider
|
||||
provider, model_name, agent_loop_type = get_provider_and_model(
|
||||
model_to_use, agent_loop_choice
|
||||
)
|
||||
|
||||
# Special handling for OAICOMPAT
|
||||
is_oaicompat = str(provider) == "oaicompat"
|
||||
|
||||
# Get API key based on provider
|
||||
if model_choice_value == "Custom model..." and custom_api_key:
|
||||
# Use custom API key if provided for custom model
|
||||
api_key = custom_api_key
|
||||
print(f"DEBUG - Using custom API key for model: {model_name}")
|
||||
elif provider == LLMProvider.OPENAI:
|
||||
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
|
||||
elif provider == LLMProvider.ANTHROPIC:
|
||||
api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
|
||||
else:
|
||||
api_key = ""
|
||||
|
||||
# Create or update the agent
|
||||
create_agent(
|
||||
provider=provider,
|
||||
agent_loop=agent_loop_type,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
save_trajectory=save_traj,
|
||||
only_n_most_recent_images=recent_imgs,
|
||||
use_ollama=agent_loop_choice == "OMNI-OLLAMA",
|
||||
use_oaicompat=is_oaicompat,
|
||||
provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
|
||||
)
|
||||
|
||||
if global_agent is None:
|
||||
# Add initial empty assistant message
|
||||
history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
|
||||
yield history
|
||||
return
|
||||
|
||||
# Add the screenshot handler to the agent's loop if available
|
||||
if global_agent and hasattr(global_agent, "_loop"):
|
||||
print("DEBUG - Adding screenshot handler to agent loop")
|
||||
|
||||
# Create the screenshot handler with references to UI components
|
||||
screenshot_handler = GradioChatScreenshotHandler(
|
||||
history
|
||||
)
|
||||
|
||||
# Add the handler to the callback manager if it exists
|
||||
if hasattr(global_agent._loop, "callback_manager"):
|
||||
global_agent._loop.callback_manager.add_handler(screenshot_handler)
|
||||
print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
|
||||
|
||||
# Stream responses from the agent
|
||||
async for result in global_agent.run(last_user_message):
|
||||
# Process result
|
||||
content, metadata = process_agent_result(result)
|
||||
|
||||
# Skip empty content
|
||||
if content or metadata.get("title"):
|
||||
history.append(gr.ChatMessage(role="assistant", content=content, metadata=metadata))
|
||||
yield history
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Update with error message
|
||||
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
|
||||
yield history
|
||||
|
||||
# Connect the components
|
||||
msg.submit(chat_submit, [msg, chatbot_history], [msg, chatbot_history]).then(
|
||||
process_response,
|
||||
|
||||
Reference in New Issue
Block a user