simplified event loop

Dillon DuPont
2025-04-13 20:01:38 -04:00
parent e20330c211
commit 43d0180309


@@ -32,6 +32,7 @@ import asyncio
import logging
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
import gradio as gr
from gradio.components.chatbot import MetadataDict
# Import from agent package
from agent.core.types import AgentResponse
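For context on the new import: MetadataDict comes from Gradio's chatbot component, and a metadata dict with a "title" makes the assistant message render as a titled, collapsible block. A minimal sketch of that usage, with the key names as found in current Gradio (treat the values as illustrative):

import gradio as gr
from gradio.components.chatbot import MetadataDict

# "title" labels the collapsible block; "status" ("pending" or "done") controls its spinner.
meta: MetadataDict = {"title": "🧠 Reasoning", "status": "done"}
message = gr.ChatMessage(role="assistant", content="Deciding the next action...", metadata=meta)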
@@ -45,6 +46,9 @@ from agent import ComputerAgent, AgentLoop, LLM, LLMProvider
global_agent = None
global_computer = None
# We'll use asyncio.run() instead of a persistent event loop
# Custom Screenshot Handler for Gradio chat
class GradioChatScreenshotHandler(DefaultCallbackHandler):
"""Custom handler that adds screenshots to the Gradio chatbot and updates annotated image."""
@@ -57,9 +61,6 @@ class GradioChatScreenshotHandler(DefaultCallbackHandler):
annotated_image: Reference to the annotated image component
"""
self.chatbot_history = chatbot_history
self.latest_image = None
self.latest_annotations = []
logging.info("GradioChatScreenshotHandler initialized with chat history and annotated image")
print("GradioChatScreenshotHandler initialized")
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[ParseResult] = None) -> None:
@@ -77,7 +78,11 @@ class GradioChatScreenshotHandler(DefaultCallbackHandler):
# Simply append the screenshot as a new message
if self.chatbot_history is not None:
self.chatbot_history.append(gr.ChatMessage(role="assistant", content=image_markdown))
self.chatbot_history.append(gr.ChatMessage(
role="assistant",
content=image_markdown,
metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"}
))
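image_markdown is assembled earlier in on_screenshot, outside this hunk; presumably it is a Markdown image backed by the base64 screenshot, roughly:

# Hypothetical reconstruction; the actual formatting lives earlier in on_screenshot and may differ.
image_markdown = f"![Screenshot](data:image/png;base64,{screenshot_base64})"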
# Map model names to specific provider model names
MODEL_MAPPINGS = {
@@ -213,17 +218,18 @@ def get_ollama_models() -> List[str]:
return []
def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> str:
def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
"""Extract synthesized text from the agent result."""
synthesized_text = ""
metadata = MetadataDict()
if "output" in result and result["output"]:
for output in result["output"]:
if output.get("type") == "reasoning":
metadata["title"] = "🧠 Reasoning"
content = output.get("content", "")
if content:
synthesized_text += f"{content}\n"
elif output.get("type") == "message":
# Handle message type outputs - can contain rich content
content = output.get("content", [])
@@ -237,6 +243,7 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
synthesized_text += f"{text_value}\n"
elif output.get("type") == "computer_call":
metadata["title"] = "🛠️ Used Computer"
action = output.get("action", {})
action_type = action.get("type", "")
@@ -259,8 +266,10 @@ def extract_synthesized_text(result: Union[AgentResponse, Dict[str, Any]]) -> st
synthesized_text += f"Pressed key: {key}\n"
else:
synthesized_text += f"Performed {action_type} action.\n"
return synthesized_text.strip()
metadata["title"] = f"🛠️ {synthesized_text.strip().splitlines()[-1]}"
return synthesized_text.strip(), metadata
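For reference, the parser above expects result["output"] to be a list of typed items (reasoning, message, computer_call). A hypothetical input and call, using only the keys the visible code actually reads:

# Hypothetical result payload; real AgentResponse objects may carry additional fields.
result = {
    "output": [
        {"type": "reasoning", "content": "I need to open the browser first."}
    ]
}
text, metadata = extract_synthesized_text(result)
# text holds the reasoning content; metadata carries a display title such as "🧠 Reasoning".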
def create_computer_instance(verbosity: int = logging.INFO) -> Computer:
@@ -345,10 +354,17 @@ def create_agent(
return global_agent
def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
"""Process agent results for the Gradio UI."""
# Extract text content
text_obj = result.get("text", {})
metadata = result.get("metadata", {})
# Create a properly typed MetadataDict
metadata_dict = MetadataDict()
metadata_dict["title"] = metadata.get("title", "")
metadata_dict["status"] = "done"
metadata = metadata_dict
# For OpenAI's Computer-Use Agent, text field is an object with format property
if (
@@ -357,8 +373,11 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
and "format" in text_obj
and not text_obj.get("value", "")
):
content = extract_synthesized_text(result)
content, metadata = extract_synthesized_text(result)
else:
if not text_obj:
text_obj = result
# For other types of results, try to get text directly
if isinstance(text_obj, dict):
if "value" in text_obj:
@@ -391,7 +410,10 @@ def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> str:
if not isinstance(content, str):
content = str(content) if content else ""
return content
print(content)
print(metadata)
return content, metadata
def create_gradio_ui(
provider_name: str = "openai",
@@ -565,7 +587,7 @@ def create_gradio_ui(
"""
)
with gr.Accordion("Configuration", open=False):
with gr.Accordion("Configuration", open=True):
# Configuration options
agent_loop = gr.Dropdown(
choices=["OPENAI", "ANTHROPIC", "OMNI"],
@@ -651,7 +673,7 @@ def create_gradio_ui(
return "", history
# Function to process agent response after user input
def process_response(
async def process_response(
history,
model_choice_value,
custom_model_value,
@@ -674,106 +696,77 @@ def create_gradio_ui(
if model_choice_value == "Custom model..."
else model_choice_value
)
# Create a new async event loop for this function call
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def _stream_agent_responses():
try:
# Get the model, agent loop, and provider
provider, model_name, agent_loop_type = get_provider_and_model(
model_to_use, agent_loop_choice
)
# Special handling for OAICOMPAT
is_oaicompat = str(provider) == "oaicompat"
# Get API key based on provider
if model_choice_value == "Custom model..." and custom_api_key:
# Use custom API key if provided for custom model
api_key = custom_api_key
print(f"DEBUG - Using custom API key for model: {model_name}")
elif provider == LLMProvider.OPENAI:
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
elif provider == LLMProvider.ANTHROPIC:
api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
else:
api_key = ""
# Create or update the agent
create_agent(
provider=provider,
agent_loop=agent_loop_type,
model_name=model_name,
api_key=api_key,
save_trajectory=save_traj,
only_n_most_recent_images=recent_imgs,
use_ollama=agent_loop_choice == "OMNI-OLLAMA",
use_oaicompat=is_oaicompat,
provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
)
if global_agent is None:
# Add initial empty assistant message
history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
yield history
return
# Add the screenshot handler to the agent's loop if available
if global_agent and hasattr(global_agent, "_loop"):
print("DEBUG - Adding screenshot handler to agent loop")
# Create the screenshot handler with references to UI components
screenshot_handler = GradioChatScreenshotHandler(
history
)
# Add the handler to the callback manager if it exists
if hasattr(global_agent._loop, "callback_manager"):
global_agent._loop.callback_manager.add_handler(screenshot_handler)
print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
# Stream responses from the agent
async for result in global_agent.run(last_user_message):
# Process result
content = process_agent_result(result)
# # Skip empty content
if content:
history.append(gr.ChatMessage(role="assistant", content=content))
yield history
except Exception as e:
import traceback
traceback.print_exc()
# Update with error message
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
yield history
# Create an async function to run the generator
async def run_generator():
async for update in _stream_agent_responses():
yield update
# Run the wrapper function
try:
# Create a generator by running the async function
generator = run_generator()
# Push the first element to start the generator
first_item = loop.run_until_complete(generator.__anext__())
yield first_item
# Keep iterating until StopAsyncIteration
while True:
try:
item = loop.run_until_complete(generator.__anext__())
yield item
except StopAsyncIteration:
break
finally:
loop.close()
try:
# Get the model, agent loop, and provider
provider, model_name, agent_loop_type = get_provider_and_model(
model_to_use, agent_loop_choice
)
# Special handling for OAICOMPAT
is_oaicompat = str(provider) == "oaicompat"
# Get API key based on provider
if model_choice_value == "Custom model..." and custom_api_key:
# Use custom API key if provided for custom model
api_key = custom_api_key
print(f"DEBUG - Using custom API key for model: {model_name}")
elif provider == LLMProvider.OPENAI:
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY", "")
elif provider == LLMProvider.ANTHROPIC:
api_key = anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY", "")
else:
api_key = ""
# Create or update the agent
create_agent(
provider=provider,
agent_loop=agent_loop_type,
model_name=model_name,
api_key=api_key,
save_trajectory=save_traj,
only_n_most_recent_images=recent_imgs,
use_ollama=agent_loop_choice == "OMNI-OLLAMA",
use_oaicompat=is_oaicompat,
provider_base_url=custom_url_value if is_oaicompat and model_choice_value == "Custom model..." else None,
)
if global_agent is None:
# Add initial empty assistant message
history.append(gr.ChatMessage(role="assistant", content="Failed to create agent. Check API keys and configuration."))
yield history
return
# Add the screenshot handler to the agent's loop if available
if global_agent and hasattr(global_agent, "_loop"):
print("DEBUG - Adding screenshot handler to agent loop")
# Create the screenshot handler with references to UI components
screenshot_handler = GradioChatScreenshotHandler(
history
)
# Add the handler to the callback manager if it exists
if hasattr(global_agent._loop, "callback_manager"):
global_agent._loop.callback_manager.add_handler(screenshot_handler)
print(f"DEBUG - Screenshot handler added to callback manager with history: {id(history)}")
# Stream responses from the agent
async for result in global_agent.run(last_user_message):
# Process result
content, metadata = process_agent_result(result)
# Skip empty content
if content or metadata.get("title"):
history.append(gr.ChatMessage(role="assistant", content=content, metadata=metadata))
yield history
except Exception as e:
import traceback
traceback.print_exc()
# Update with error message
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
yield history
# Connect the components
msg.submit(chat_submit, [msg, chatbot_history], [msg, chatbot_history]).then(
process_response,