Merge pull request #152 from trycua/feature/agent/drag-cursor-fix

[Agent] Improved drag tool and standardized UITARS response
Authored by f-trycua on 2025-05-01 01:31:32 +02:00, committed by GitHub
11 changed files with 302 additions and 186 deletions

View File

@@ -161,15 +161,17 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
             self.logger.info(f"Moving cursor to ({x}, {y})")
             await self.computer.interface.move_cursor(x, y)
         elif action == "left_click_drag":
-            self.logger.info(f"Dragging from ({x}, {y})")
-            # First move to the position
-            await self.computer.interface.move_cursor(x, y)
-            # Then perform drag operation - check if drag_to exists or we need to use other methods
-            try:
-                await self.computer.interface.drag_to(x, y)
-            except Exception as e:
-                self.logger.error(f"Error during drag operation: {str(e)}")
-                raise ToolError(f"Failed to perform drag: {str(e)}")
+            # Get the start coordinate from kwargs
+            start_coordinate = kwargs.get("start_coordinate")
+            if not start_coordinate:
+                raise ToolError("start_coordinate is required for left_click_drag action")
+            start_x, start_y = start_coordinate
+            end_x, end_y = x, y
+            self.logger.info(f"Dragging from ({start_x}, {start_y}) to ({end_x}, {end_y})")
+            await self.computer.interface.move_cursor(start_x, start_y)
+            await self.computer.interface.drag_to(end_x, end_y)
+            # Wait briefly for any UI changes
+            await asyncio.sleep(0.5)
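
The substance of the fix: `left_click_drag` previously moved to the destination and dragged to the same point, a no-op. It now reads an explicit `start_coordinate` and drags to `(x, y)`. A minimal sketch of a tool call that exercises the new path (the tool instance and the kwarg plumbing that resolves `coordinate` into `x`/`y` are illustrative, not shown in this hunk):

# Hypothetical invocation of the Anthropic computer tool:
result = await computer_tool(
    action="left_click_drag",
    coordinate=[400, 300],         # destination; resolved to (x, y) upstream
    start_coordinate=[100, 200],   # now required; omitting it raises ToolError
)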

View File

@@ -44,6 +44,7 @@ Action = Literal[
     "double_click",
     "screenshot",
     "scroll",
+    "drag",
 ]
@@ -165,6 +166,11 @@ class ComputerTool(BaseComputerTool, BaseOpenAITool):
             scroll_x = kwargs.get("scroll_x", 0) // 50
             scroll_y = kwargs.get("scroll_y", 0) // 50
             return await self.handle_scroll(x, y, scroll_x, scroll_y)
+        elif type == "drag":
+            path = kwargs.get("path")
+            if not path or not isinstance(path, list) or len(path) < 2:
+                raise ToolError("path is required for drag action and must contain at least 2 points")
+            return await self.handle_drag(path)
         elif type == "screenshot":
            return await self.screenshot()
        elif type == "wait":
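
On the OpenAI side a drag arrives as a list of points rather than a start/end pair. A payload in the shape below (point format taken from the `handle_drag` docstring in the next hunk; the surrounding envelope is assumed) is routed through the new `drag` branch:

action_kwargs = {
    "type": "drag",
    "path": [
        {"x": 100, "y": 200},  # start
        {"x": 180, "y": 240},  # optional waypoint
        {"x": 300, "y": 400},  # end
    ],
}
# Fewer than 2 points -> ToolError("path is required for drag action ...")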
@@ -302,6 +308,41 @@ class ComputerTool(BaseComputerTool, BaseOpenAITool):
             self.logger.error(f"Error in handle_scroll: {str(e)}")
             raise ToolError(f"Failed to scroll at ({x}, {y}): {str(e)}")

+    async def handle_drag(self, path: List[Dict[str, int]]) -> ToolResult:
+        """Handle mouse drag operation using a path of coordinates.
+
+        Args:
+            path: List of coordinate points {"x": int, "y": int} defining the drag path
+
+        Returns:
+            ToolResult with the operation result and screenshot
+        """
+        try:
+            # Convert from [{"x": x, "y": y}, ...] format to [(x, y), ...] format
+            points = [(p["x"], p["y"]) for p in path]
+
+            # Perform drag action
+            if len(points) == 2:
+                await self.computer.interface.move_cursor(points[0][0], points[0][1])
+                await self.computer.interface.drag_to(points[1][0], points[1][1])
+            else:
+                await self.computer.interface.drag(points, button="left")
+
+            # Wait for UI to update
+            await asyncio.sleep(0.5)
+
+            # Take screenshot after action
+            screenshot = await self.computer.interface.screenshot()
+            base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
+
+            return ToolResult(
+                output=f"Dragged from ({path[0]['x']}, {path[0]['y']}) to ({path[-1]['x']}, {path[-1]['y']})",
+                base64_image=base64_screenshot,
+            )
+        except Exception as e:
+            self.logger.error(f"Error in handle_drag: {str(e)}")
+            raise ToolError(f"Failed to perform drag operation: {str(e)}")

     async def screenshot(self) -> ToolResult:
         """Take a screenshot."""
         try:
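
The path conversion at the top of `handle_drag` is a plain reshaping, and the branch below it picks the cheapest primitive that fits. A standalone check of the conversion (pure Python, no tool classes needed):

path = [{"x": 100, "y": 200}, {"x": 300, "y": 400}]
points = [(p["x"], p["y"]) for p in path]
assert points == [(100, 200), (300, 400)]
# Exactly 2 points -> move_cursor + drag_to; 3 or more -> interface.drag(points)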

View File

@@ -190,25 +190,21 @@ class OAICompatClient(BaseUITarsClient):
                     response_text = await response.text()
                     logger.debug(f"Response content: {response_text}")

                     # if 503, then the endpoint is still warming up
                     if response.status == 503:
                         logger.error(f"Endpoint is still warming up, please try again later")
                         raise Exception(f"Endpoint is still warming up: {response_text}")

                     # Try to parse as JSON if the content type is appropriate
                     if "application/json" in response.headers.get('Content-Type', ''):
                         response_json = await response.json()
                     else:
                         raise Exception(f"Response is not JSON format")
                         # # Optionally try to parse it anyway
                         # try:
                         #     import json
                         #     response_json = json.loads(response_text)
                         # except json.JSONDecodeError as e:
                         #     print(f"Failed to parse response as JSON: {e}")

                     if response.status != 200:
-                        error_msg = response_json.get("error", {}).get(
-                            "message", str(response_json)
-                        )
-                        logger.error(f"Error in API call: {error_msg}")
-                        raise Exception(f"API error: {error_msg}")
+                        logger.error(f"Error in API call: {response_text}")
+                        raise Exception(f"API error: {response_text}")

                     return response_json
                 except Exception as e:
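
Because the 503 branch surfaces a distinct "warming up" error instead of a JSON-parse failure, callers can treat it as retryable. A minimal sketch under that assumption (the client method name is illustrative, not from this diff):

import asyncio

async def call_with_warmup_retry(client, payload, attempts=5, delay=10.0):
    """Retry while the UI-TARS endpoint reports it is still warming up."""
    for attempt in range(attempts):
        try:
            return await client.run_interleaved(payload)  # hypothetical client method
        except Exception as e:
            if "warming up" not in str(e) or attempt == attempts - 1:
                raise
            await asyncio.sleep(delay)  # give the model time to load, then retry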

View File

@@ -17,7 +17,7 @@ from ...core.types import AgentResponse, LLMProvider
from ...core.visualization import VisualizationHelper
from computer import Computer
-from .utils import add_box_token, parse_actions, parse_action_parameters
+from .utils import add_box_token, parse_actions, parse_action_parameters, to_agent_response_format
from .tools.manager import ToolManager
from .tools.computer import ToolResult
from .prompts import COMPUTER_USE, SYSTEM_PROMPT, MAC_SPECIFIC_NOTES
@@ -507,41 +507,14 @@ class UITARSLoop(BaseLoop):
                     # Update whether an action screenshot was saved this turn
                     action_screenshot_saved = action_screenshot_saved or new_screenshot_saved

-                    # Parse actions from the raw response
-                    raw_response = response["choices"][0]["message"]["content"]
-                    parsed_actions = parse_actions(raw_response)
-
-                    # Extract thought content if available
-                    thought = ""
-                    if "Thought:" in raw_response:
-                        thought_match = re.search(r"Thought: (.*?)(?=\s*Action:|$)", raw_response, re.DOTALL)
-                        if thought_match:
-                            thought = thought_match.group(1).strip()
+                    agent_response = await to_agent_response_format(
+                        response,
+                        messages,
+                        model=self.model,
+                    )
+                    yield agent_response

-                    # Create standardized thought response format
-                    thought_response = {
-                        "role": "assistant",
-                        "content": thought or raw_response,
-                        "metadata": {
-                            "title": "🧠 UI-TARS Thoughts"
-                        }
-                    }
-
-                    # Create action response format
-                    action_response = {
-                        "role": "assistant",
-                        "content": str(parsed_actions),
-                        "metadata": {
-                            "title": "🖱️ UI-TARS Actions",
-                        }
-                    }
-
-                    # Yield both responses to the caller (thoughts first, then actions)
-                    yield thought_response
-                    if parsed_actions:
-                        yield action_response

                     # Check if we should continue this conversation
                     running = should_continue
@@ -562,7 +535,8 @@ class UITARSLoop(BaseLoop):
                 logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
                 yield {
+                    "error": str(e),
                     "role": "assistant",
                     "content": f"Error: {str(e)}",
                     "metadata": {"title": "❌ Error"},
                 }
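
With this change the loop yields one standardized response object per turn (plus plain error dicts on failure) instead of separate thought and action dicts. A sketch of how a consumer iterates the new stream (the `run` entry point is assumed from `BaseLoop`):

async for response in loop.run(messages):   # 'loop' is a UITARSLoop instance
    if "error" in response:
        print("agent error:", response["error"])
        continue
    for item in response.get("output", []):
        if item["type"] == "reasoning":
            print("thought:", item["text"])
        elif item["type"] == "computer_call":
            print("action:", item["action"])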

View File

@@ -4,9 +4,114 @@ import logging
 import base64
 import re
 from typing import Any, Dict, List, Optional, Union, Tuple
+from datetime import datetime

 logger = logging.getLogger(__name__)

+from ...core.types import AgentResponse
+
+
+async def to_agent_response_format(
+    response: Dict[str, Any],
+    messages: List[Dict[str, Any]],
+    model: Optional[str] = None,
+) -> AgentResponse:
+    """Convert raw UI-TARS response to agent response format.
+
+    Args:
+        response: Raw UI-TARS response
+        messages: List of messages in standard format
+        model: Optional model name
+
+    Returns:
+        AgentResponse: Standardized agent response format
+    """
+    # Create unique IDs for this response
+    response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
+    reasoning_id = f"rs_{response_id}"
+    action_id = f"cu_{response_id}"
+    call_id = f"call_{response_id}"
+
+    # Parse actions from the raw response
+    content = response["choices"][0]["message"]["content"]
+    actions = parse_actions(content)
+
+    # Extract thought content if available
+    reasoning_text = ""
+    if "Thought:" in content:
+        thought_match = re.search(r"Thought: (.*?)(?=\s*Action:|$)", content, re.DOTALL)
+        if thought_match:
+            reasoning_text = thought_match.group(1).strip()
+
+    # Create output items
+    output_items = []
+    if reasoning_text:
+        output_items.append({
+            "type": "reasoning",
+            "id": reasoning_id,
+            "text": reasoning_text
+        })
+    if actions:
+        for i, action in enumerate(actions):
+            action_name, tool_args = parse_action_parameters(action)
+            if action_name == "finished":
+                output_items.append({
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [{
+                        "type": "output_text",
+                        "text": tool_args["content"]
+                    }],
+                    "id": f"action_{i}_{action_id}",
+                    "status": "completed"
+                })
+            else:
+                if tool_args.get("action") == action_name:
+                    del tool_args["action"]
+                output_items.append({
+                    "type": "computer_call",
+                    "id": f"{action}_{i}_{action_id}",
+                    "call_id": f"call_{i}_{action_id}",
+                    "action": { "type": action_name, **tool_args },
+                    "pending_safety_checks": [],
+                    "status": "completed"
+                })
+
+    # Create agent response
+    agent_response = AgentResponse(
+        id=response_id,
+        object="response",
+        created_at=int(datetime.now().timestamp()),
+        status="completed",
+        error=None,
+        incomplete_details=None,
+        instructions=None,
+        max_output_tokens=None,
+        model=model or response["model"],
+        output=output_items,
+        parallel_tool_calls=True,
+        previous_response_id=None,
+        reasoning={"effort": "medium"},
+        store=True,
+        temperature=0.0,
+        top_p=0.7,
+        text={"format": {"type": "text"}},
+        tool_choice="auto",
+        tools=[
+            {
+                "type": "computer_use_preview",
+                "display_height": 768,
+                "display_width": 1024,
+                "environment": "mac",
+            }
+        ],
+        truncation="auto",
+        usage=response["usage"],
+        user=None,
+        metadata={},
+        response=response
+    )
+    return agent_response
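
A concrete example of the mapping (the raw completion text is representative, not taken from the PR, and IDs are abbreviated): given a model turn like

Thought: I need to drag the file to the trash.
Action: drag(start_box='(100,200)', end_box='(300,400)')

the function emits output items shaped roughly like:

output = [
    {"type": "reasoning", "id": "rs_resp_...", "text": "I need to drag the file to the trash."},
    {
        "type": "computer_call",
        "id": "..._0_cu_resp_...",
        "call_id": "call_0_cu_resp_...",
        "action": {"type": "drag", "start_box": "(100,200)", "end_box": "(300,400)"},
        "pending_safety_checks": [],
        "status": "completed",
    },
]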
 def add_box_token(input_string: str) -> str:
     """Add box tokens to the coordinates in the model response.

@@ -74,7 +179,13 @@ def parse_action_parameters(action: str) -> Tuple[str, Dict[str, Any]]:
     """
     # Handle "finished" action
     if action.startswith("finished"):
-        return "finished", {}
+        # Parse content if it exists
+        content_match = re.search(r"content='([^']*)'", action)
+        if content_match:
+            content = content_match.group(1)
+            return "finished", {"content": content}
+        else:
+            return "finished", {}

     # Parse action parameters
     action_match = re.match(r'(\w+)\((.*)\)', action)
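
As a quick check of the new branch (return shapes read directly off the code above):

parse_action_parameters("finished(content='Opened the file')")
# -> ("finished", {"content": "Opened the file"})

parse_action_parameters("finished()")
# -> ("finished", {})   # bare finished still degrades to an empty dict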

View File

@@ -35,6 +35,7 @@ from pathlib import Path
 from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
 import gradio as gr
 from gradio.components.chatbot import MetadataDict
+from typing import cast

 # Import from agent package
 from agent.core.types import AgentResponse
@@ -322,63 +323,6 @@ def get_ollama_models() -> List[str]:
         logging.error(f"Error getting Ollama models: {e}")
         return []

-def extract_synthesized_text(
-    result: Union[AgentResponse, Dict[str, Any]],
-) -> Tuple[str, MetadataDict]:
-    """Extract synthesized text from the agent result."""
-    synthesized_text = ""
-    metadata = MetadataDict()
-    if "output" in result and result["output"]:
-        for output in result["output"]:
-            if output.get("type") == "reasoning":
-                metadata["title"] = "🧠 Reasoning"
-                content = output.get("content", "")
-                if content:
-                    synthesized_text += f"{content}\n"
-            elif output.get("type") == "message":
-                # Handle message type outputs - can contain rich content
-                content = output.get("content", [])
-                # Content is usually an array of content blocks
-                if isinstance(content, list):
-                    for block in content:
-                        if isinstance(block, dict) and block.get("type") == "output_text":
-                            text_value = block.get("text", "")
-                            if text_value:
-                                synthesized_text += f"{text_value}\n"
-            elif output.get("type") == "computer_call":
-                action = output.get("action", {})
-                action_type = action.get("type", "")
-                # Create a descriptive text about the action
-                if action_type == "click":
-                    button = action.get("button", "")
-                    x = action.get("x", "")
-                    y = action.get("y", "")
-                    synthesized_text += f"Clicked {button} at position ({x}, {y}).\n"
-                elif action_type == "type":
-                    text = action.get("text", "")
-                    synthesized_text += f"Typed: {text}.\n"
-                elif action_type == "keypress":
-                    # Extract key correctly from either keys array or key field
-                    if isinstance(action.get("keys"), list):
-                        key = ", ".join(action.get("keys"))
-                    else:
-                        key = action.get("key", "")
-                    synthesized_text += f"Pressed key: {key}\n"
-                else:
-                    synthesized_text += f"Performed {action_type} action.\n"
-    metadata["status"] = "done"
-    metadata["title"] = f"🛠️ {synthesized_text.strip().splitlines()[-1]}"
-    return synthesized_text.strip(), metadata

 def create_computer_instance(verbosity: int = logging.INFO) -> Computer:
     """Create or get the global Computer instance."""
     global global_computer
@@ -447,66 +391,6 @@
     return global_agent

-def process_agent_result(result: Union[AgentResponse, Dict[str, Any]]) -> Tuple[str, MetadataDict]:
-    """Process agent results for the Gradio UI."""
-    # Extract text content
-    text_obj = result.get("text", {})
-    metadata = result.get("metadata", {})
-
-    # Create a properly typed MetadataDict
-    metadata_dict = MetadataDict()
-    metadata_dict["title"] = metadata.get("title", "")
-    metadata_dict["status"] = "done"
-    metadata = metadata_dict
-
-    # For OpenAI's Computer-Use Agent, text field is an object with format property
-    if (
-        text_obj
-        and isinstance(text_obj, dict)
-        and "format" in text_obj
-        and not text_obj.get("value", "")
-    ):
-        content, metadata = extract_synthesized_text(result)
-    else:
-        if not text_obj:
-            text_obj = result
-        # For other types of results, try to get text directly
-        if isinstance(text_obj, dict):
-            if "value" in text_obj:
-                content = text_obj["value"]
-            elif "text" in text_obj:
-                content = text_obj["text"]
-            elif "content" in text_obj:
-                content = text_obj["content"]
-            else:
-                content = ""
-        else:
-            content = str(text_obj) if text_obj else ""
-
-    # If still no content but we have outputs, create a summary
-    if not content and "output" in result and result["output"]:
-        output = result["output"]
-        for out in output:
-            if out.get("type") == "reasoning":
-                content = out.get("content", "")
-                if content:
-                    break
-            elif out.get("type") == "computer_call":
-                action = out.get("action", {})
-                action_type = action.get("type", "")
-                if action_type:
-                    content = f"Performing action: {action_type}"
-                    break
-
-    # Clean up the text - ensure content is a string
-    if not isinstance(content, str):
-        content = str(content) if content else ""
-
-    return content, metadata

 def create_gradio_ui(
     provider_name: str = "openai",
     model_name: str = "gpt-4o",
@@ -907,17 +791,64 @@
                 # Stream responses from the agent
                 async for result in global_agent.run(last_user_message):
-                    # Process result
-                    content, metadata = process_agent_result(result)
-                    # Skip empty content
-                    if content or metadata.get("title"):
-                        history.append(
-                            gr.ChatMessage(
-                                role="assistant", content=content, metadata=metadata
-                            )
-                        )
-                        yield history
+                    print(f"DEBUG - Agent response ------- START")
+                    from pprint import pprint
+                    pprint(result)
+                    print(f"DEBUG - Agent response ------- END")
+
+                    def generate_gradio_messages():
+                        if result.get("content"):
+                            yield gr.ChatMessage(
+                                role="assistant",
+                                content=result.get("content", ""),
+                                metadata=cast(MetadataDict, result.get("metadata", {}))
+                            )
+                        else:
+                            outputs = result.get("output", [])
+                            for output in outputs:
+                                if output.get("type") == "message":
+                                    content = output.get("content", [])
+                                    for content_part in content:
+                                        if content_part.get("text"):
+                                            yield gr.ChatMessage(
+                                                role=output.get("role", "assistant"),
+                                                content=content_part.get("text", ""),
+                                                metadata=content_part.get("metadata", {})
+                                            )
+                                elif output.get("type") == "reasoning":
+                                    # if it's openAI, we only have access to a summary of the reasoning
+                                    summary_content = output.get("summary", [])
+                                    if summary_content:
+                                        for summary_part in summary_content:
+                                            if summary_part.get("type") == "summary_text":
+                                                yield gr.ChatMessage(
+                                                    role="assistant",
+                                                    content=summary_part.get("text", "")
+                                                )
+                                    else:
+                                        summary_content = output.get("text", "")
+                                        if summary_content:
+                                            yield gr.ChatMessage(
+                                                role="assistant",
+                                                content=summary_content,
+                                            )
+                                elif output.get("type") == "computer_call":
+                                    action = output.get("action", {})
+                                    action_type = action.get("type", "")
+                                    if action_type:
+                                        action_title = f"🛠️ Performing {action_type}"
+                                        if action.get("x") and action.get("y"):
+                                            action_title += f" at ({action['x']}, {action['y']})"
+                                        yield gr.ChatMessage(
+                                            role="assistant",
+                                            content=f"```json\n{json.dumps(action)}\n```",
+                                            metadata={"title": action_title}
+                                        )
+
+                    for message in generate_gradio_messages():
+                        history.append(message)
+                        yield history
             except Exception as e:
                 import traceback

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List, Tuple

 class BaseAccessibilityHandler(ABC):
     """Abstract base class for OS-specific accessibility handlers."""

@@ -59,6 +59,17 @@ class BaseAutomationHandler(ABC):
            duration: How long the drag should take in seconds
         """
         pass

+    @abstractmethod
+    async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
+        """Drag the cursor along a path of coordinates.
+
+        Args:
+            path: A list of (x, y) coordinate tuples defining the drag path
+            button: The mouse button to use ('left', 'middle', 'right')
+            duration: How long the drag should take in seconds
+        """
+        pass

     # Keyboard Actions
     @abstractmethod

View File

@@ -1,7 +1,7 @@
 import pyautogui
 import base64
 from io import BytesIO
-from typing import Optional, Dict, Any, List
+from typing import Optional, Dict, Any, List, Tuple
 from ctypes import byref, c_void_p, POINTER
 from AppKit import NSWorkspace  # type: ignore
 import AppKit

@@ -563,6 +563,39 @@ class MacOSAutomationHandler(BaseAutomationHandler):
         except Exception as e:
             return {"success": False, "error": str(e)}

+    async def drag(
+        self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
+    ) -> Dict[str, Any]:
+        try:
+            if not path or len(path) < 2:
+                return {"success": False, "error": "Path must contain at least 2 points"}
+
+            # Move to the first point
+            start_x, start_y = path[0]
+            pyautogui.moveTo(start_x, start_y)
+
+            # Press the mouse button
+            pyautogui.mouseDown(button=button)
+
+            # Calculate time between points to distribute duration evenly
+            step_duration = duration / (len(path) - 1) if len(path) > 1 else duration
+
+            # Move through each subsequent point
+            for x, y in path[1:]:
+                pyautogui.moveTo(x, y, duration=step_duration)
+
+            # Release the mouse button
+            pyautogui.mouseUp(button=button)
+
+            return {"success": True}
+        except Exception as e:
+            # Make sure to release the mouse button if an error occurs
+            try:
+                pyautogui.mouseUp(button=button)
+            except:
+                pass
+            return {"success": False, "error": str(e)}

     # Keyboard Actions
     async def type_text(self, text: str) -> Dict[str, Any]:
         try:
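
To exercise the handler directly (assumes a macOS session where pyautogui may control the pointer; the L-shaped path is arbitrary):

import asyncio

async def demo(handler):  # handler: a MacOSAutomationHandler instance
    # Drag right 200 px, then down 100 px; 0.6 s split evenly over the two segments
    result = await handler.drag(
        path=[(100, 100), (300, 100), (300, 200)],
        button="left",
        duration=0.6,
    )
    print(result)  # {"success": True} on success

# asyncio.run(demo(MacOSAutomationHandler()))  # construction details omitted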

View File

@@ -65,6 +65,7 @@ async def websocket_endpoint(websocket: WebSocket):
         "type_text": manager.automation_handler.type_text,
         "press_key": manager.automation_handler.press_key,
         "drag_to": manager.automation_handler.drag_to,
+        "drag": manager.automation_handler.drag,
         "hotkey": manager.automation_handler.hotkey,
         "get_cursor_position": manager.automation_handler.get_cursor_position,
         "get_screen_size": manager.automation_handler.get_screen_size,

View File

@@ -79,6 +79,17 @@ class BaseComputerInterface(ABC):
         """
         pass

+    @abstractmethod
+    async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> None:
+        """Drag the cursor along a path of coordinates.
+
+        Args:
+            path: List of (x, y) coordinate tuples defining the drag path
+            button: The mouse button to use ('left', 'middle', 'right')
+            duration: Total time in seconds that the drag operation should take
+        """
+        pass

     # Keyboard Actions
     @abstractmethod
     async def type_text(self, text: str) -> None:

View File

@@ -328,6 +328,11 @@ class MacOSComputerInterface(BaseComputerInterface):
             "drag_to", {"x": x, "y": y, "button": button, "duration": duration}
         )

+    async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> None:
+        await self._send_command(
+            "drag", {"path": path, "button": button, "duration": duration}
+        )

     # Keyboard Actions
     async def type_text(self, text: str) -> None:
         await self._send_command("type_text", {"text": text})