mirror of
https://github.com/trycua/lume.git
synced 2026-01-06 04:20:03 -06:00
Replaced agent loop func with agent config class
This commit is contained in:
@@ -3,12 +3,12 @@ ComputerAgent - Main agent class that selects and runs agent loops
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple
|
||||
|
||||
from litellm.responses.utils import Usage
|
||||
|
||||
from .types import Messages, Computer
|
||||
from .decorators import find_agent_loop
|
||||
from .types import Messages, Computer, AgentCapability
|
||||
from .decorators import find_agent_config
|
||||
from .computer_handler import OpenAIComputerHandler, acknowledge_safety_check_callback, check_blocklisted_url
|
||||
import json
|
||||
import litellm
|
||||
@@ -213,13 +213,14 @@ class ComputerAgent:
|
||||
# Find the appropriate agent loop
|
||||
if custom_loop:
|
||||
self.agent_loop = custom_loop
|
||||
self.agent_loop_info = None
|
||||
self.agent_config_info = None
|
||||
else:
|
||||
loop_info = find_agent_loop(model)
|
||||
if not loop_info:
|
||||
raise ValueError(f"No agent loop found for model: {model}")
|
||||
self.agent_loop = loop_info.func
|
||||
self.agent_loop_info = loop_info
|
||||
config_info = find_agent_config(model)
|
||||
if not config_info:
|
||||
raise ValueError(f"No agent config found for model: {model}")
|
||||
# Instantiate the agent config class
|
||||
self.agent_loop = config_info.agent_class()
|
||||
self.agent_config_info = config_info
|
||||
|
||||
self.tool_schemas = []
|
||||
self.computer_handler = None
|
||||
@@ -511,6 +512,9 @@ class ComputerAgent:
|
||||
Returns:
|
||||
AsyncGenerator that yields response chunks
|
||||
"""
|
||||
capabilities = self.get_capabilities()
|
||||
if "step" not in capabilities:
|
||||
raise ValueError(f"Agent loop {self.agent_loop.__name__} does not support step predictions")
|
||||
|
||||
await self._initialize_computers()
|
||||
|
||||
@@ -555,7 +559,7 @@ class ComputerAgent:
|
||||
}
|
||||
|
||||
# Run agent loop iteration
|
||||
result = await self.agent_loop(
|
||||
result = await self.agent_loop.predict_step(
|
||||
**loop_kwargs,
|
||||
_on_api_start=self._on_api_start,
|
||||
_on_api_end=self._on_api_end,
|
||||
@@ -591,4 +595,45 @@ class ComputerAgent:
|
||||
)
|
||||
}
|
||||
|
||||
await self._on_run_end(loop_kwargs, old_items, new_items)
|
||||
await self._on_run_end(loop_kwargs, old_items, new_items)
|
||||
|
||||
async def predict_click(
    self,
    instruction: str,
    image_b64: Optional[str] = None
) -> Optional[Tuple[int, int]]:
    """
    Predict click coordinates based on image and instruction.

    Args:
        instruction: Instruction for where to click
        image_b64: Base64 encoded image (optional, will take screenshot if not provided)

    Returns:
        None or tuple with (x, y) coordinates. None is returned when the
        agent loop object has no ``predict_click`` attribute.

    Raises:
        ValueError: If "click" is not among the reported capabilities, or if
            the agent loop offers ``predict_click`` but no computer handler
            is set.
    """
    capabilities = self.get_capabilities()
    if "click" not in capabilities:
        # The agent loop is either a plain callable (custom loop) or an
        # instance of an agent config class. Instances carry no __name__,
        # so fall back to the class name rather than dying with an
        # AttributeError while building the error message.
        loop_name = getattr(
            self.agent_loop, "__name__", type(self.agent_loop).__name__
        )
        raise ValueError(f"Agent loop {loop_name} does not support click predictions")

    if not hasattr(self.agent_loop, 'predict_click'):
        return None

    if not self.computer_handler:
        raise ValueError("Computer tool is required for predict_click")

    # Only capture the screen when the caller did not supply an image.
    if not image_b64:
        image_b64 = await self.computer_handler.screenshot()

    return await self.agent_loop.predict_click(
        model=self.model,
        image_b64=image_b64,
        instruction=instruction
    )
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
    """
    Report which prediction capabilities the active agent loop offers.

    Returns:
        Whatever the agent loop's own ``get_capabilities()`` returns, or
        ``["step"]`` when the loop object does not define that method.
    """
    loop = self.agent_loop
    try:
        reporter = loop.get_capabilities
    except AttributeError:
        # Loop objects without the method are treated as step-only.
        return ["step"]
    return reporter()
|
||||
@@ -260,7 +260,12 @@ Examples:
|
||||
help="Show total cost of the agent runs"
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"-r", "--max-retries",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of retries for the LLM API calls"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -327,6 +332,7 @@ Examples:
|
||||
"model": args.model,
|
||||
"tools": [computer],
|
||||
"verbosity": 20 if args.verbose else 30, # DEBUG vs WARNING
|
||||
"max_retries": args.max_retries
|
||||
}
|
||||
|
||||
if args.images > 0:
|
||||
|
||||
@@ -7,84 +7,51 @@ import inspect
|
||||
from typing import Dict, List, Any, Callable, Optional
|
||||
from functools import wraps
|
||||
|
||||
from .types import AgentLoopInfo
|
||||
from .types import AgentConfigInfo
|
||||
from .loops.base import AsyncAgentConfig
|
||||
|
||||
# Global registry
|
||||
_agent_loops: List[AgentLoopInfo] = []
|
||||
_agent_configs: List[AgentConfigInfo] = []
|
||||
|
||||
def agent_loop(models: str, priority: int = 0):
|
||||
def register_agent(models: str, priority: int = 0):
|
||||
"""
|
||||
Decorator to register an agent loop function.
|
||||
Decorator to register an AsyncAgentConfig class.
|
||||
|
||||
Args:
|
||||
models: Regex pattern to match supported models
|
||||
priority: Priority for loop selection (higher = more priority)
|
||||
priority: Priority for agent selection (higher = more priority)
|
||||
"""
|
||||
def decorator(func: Callable):
|
||||
# Validate function signature
|
||||
sig = inspect.signature(func)
|
||||
required_params = {'messages', 'model'}
|
||||
func_params = set(sig.parameters.keys())
|
||||
def decorator(agent_class: type):
|
||||
# Validate that the class implements AsyncAgentConfig protocol
|
||||
if not hasattr(agent_class, 'predict_step'):
|
||||
raise ValueError(f"Agent class {agent_class.__name__} must implement predict_step method")
|
||||
if not hasattr(agent_class, 'predict_click'):
|
||||
raise ValueError(f"Agent class {agent_class.__name__} must implement predict_click method")
|
||||
if not hasattr(agent_class, 'get_capabilities'):
|
||||
raise ValueError(f"Agent class {agent_class.__name__} must implement get_capabilities method")
|
||||
|
||||
if not required_params.issubset(func_params):
|
||||
missing = required_params - func_params
|
||||
raise ValueError(f"Agent loop function must have parameters: {missing}")
|
||||
|
||||
# Register the loop
|
||||
loop_info = AgentLoopInfo(
|
||||
func=func,
|
||||
# Register the agent config
|
||||
config_info = AgentConfigInfo(
|
||||
agent_class=agent_class,
|
||||
models_regex=models,
|
||||
priority=priority
|
||||
)
|
||||
_agent_loops.append(loop_info)
|
||||
_agent_configs.append(config_info)
|
||||
|
||||
# Sort by priority (highest first)
|
||||
_agent_loops.sort(key=lambda x: x.priority, reverse=True)
|
||||
_agent_configs.sort(key=lambda x: x.priority, reverse=True)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Wrap the function in an asyncio.Queue for cancellation support
|
||||
queue = asyncio.Queue()
|
||||
task = None
|
||||
|
||||
try:
|
||||
# Create a task that can be cancelled
|
||||
async def run_loop():
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
await queue.put(('result', result))
|
||||
except Exception as e:
|
||||
await queue.put(('error', e))
|
||||
|
||||
task = asyncio.create_task(run_loop())
|
||||
|
||||
# Wait for result or cancellation
|
||||
event_type, data = await queue.get()
|
||||
|
||||
if event_type == 'error':
|
||||
raise data
|
||||
return data
|
||||
|
||||
except asyncio.CancelledError:
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
return agent_class
|
||||
|
||||
return decorator
|
||||
|
||||
def get_agent_loops() -> List[AgentLoopInfo]:
|
||||
"""Get all registered agent loops"""
|
||||
return _agent_loops.copy()
|
||||
def get_agent_configs() -> List[AgentConfigInfo]:
    """Return a shallow copy of the agent config registry.

    A new list is handed out so callers cannot mutate the registry itself.
    """
    snapshot = list(_agent_configs)
    return snapshot
|
||||
|
||||
def find_agent_loop(model: str) -> Optional[AgentLoopInfo]:
|
||||
"""Find the best matching agent loop for a model"""
|
||||
for loop_info in _agent_loops:
|
||||
if loop_info.matches_model(model):
|
||||
return loop_info
|
||||
def find_agent_config(model: str) -> Optional[AgentConfigInfo]:
    """Return the first registered agent config whose pattern matches ``model``.

    Args:
        model: Model name to look up.

    Returns:
        The first entry in registry order for which ``matches_model(model)``
        is truthy, or None when nothing matches.
    """
    matching = (entry for entry in _agent_configs if entry.matches_model(model))
    return next(matching, None)
|
||||
|
||||
@@ -4,12 +4,13 @@ Anthropic hosted tools agent loop implementation using liteLLM
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
|
||||
import litellm
|
||||
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
|
||||
|
||||
from ..decorators import agent_loop
|
||||
from ..types import Messages, AgentResponse, Tools
|
||||
from ..decorators import register_agent
|
||||
from ..types import Messages, AgentResponse, Tools, AgentCapability
|
||||
from ..loops.base import AsyncAgentConfig
|
||||
from ..responses import (
|
||||
make_reasoning_item,
|
||||
make_output_text_item,
|
||||
@@ -1284,84 +1285,100 @@ def _merge_consecutive_text(content_list: List[Dict[str, Any]]) -> List[Dict[str
|
||||
|
||||
return merged
|
||||
|
||||
@agent_loop(models=r".*claude-.*", priority=5)
|
||||
async def anthropic_hosted_tools_loop(
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
|
||||
"""
|
||||
Anthropic hosted tools agent loop using liteLLM acompletion.
|
||||
@register_agent(models=r".*claude-.*", priority=5)
|
||||
class AnthropicHostedToolsConfig(AsyncAgentConfig):
|
||||
"""Anthropic hosted tools agent configuration implementing AsyncAgentConfig protocol."""
|
||||
|
||||
Supports Anthropic's computer use models with hosted tools.
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Get tool configuration for this model
|
||||
tool_config = _get_tool_config_for_model(model)
|
||||
|
||||
# Prepare tools for Anthropic API
|
||||
anthropic_tools = _prepare_tools_for_anthropic(tools, model)
|
||||
|
||||
# Convert responses_items messages to completion format
|
||||
completion_messages = _convert_responses_items_to_completion_messages(messages)
|
||||
if use_prompt_caching:
|
||||
# First combine messages to reduce number of blocks
|
||||
completion_messages = _combine_completion_messages(completion_messages)
|
||||
# Then add cache control, anthropic requires explicit "cache_control" dicts
|
||||
completion_messages = _add_cache_control(completion_messages)
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": completion_messages,
|
||||
"tools": anthropic_tools if anthropic_tools else None,
|
||||
"stream": stream,
|
||||
"num_retries": max_retries,
|
||||
async def predict_step(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Add beta header for computer use
|
||||
if anthropic_tools:
|
||||
api_kwargs["headers"] = {
|
||||
"anthropic-beta": tool_config["beta_flag"]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Anthropic hosted tools agent loop using liteLLM acompletion.
|
||||
|
||||
Supports Anthropic's computer use models with hosted tools.
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Get tool configuration for this model
|
||||
tool_config = _get_tool_config_for_model(model)
|
||||
|
||||
# Prepare tools for Anthropic API
|
||||
anthropic_tools = _prepare_tools_for_anthropic(tools, model)
|
||||
|
||||
# Convert responses_items messages to completion format
|
||||
completion_messages = _convert_responses_items_to_completion_messages(messages)
|
||||
if use_prompt_caching:
|
||||
# First combine messages to reduce number of blocks
|
||||
completion_messages = _combine_completion_messages(completion_messages)
|
||||
# Then add cache control, anthropic requires explicit "cache_control" dicts
|
||||
completion_messages = _add_cache_control(completion_messages)
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": completion_messages,
|
||||
"tools": anthropic_tools if anthropic_tools else None,
|
||||
"stream": stream,
|
||||
"num_retries": max_retries,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Add beta header for computer use
|
||||
if anthropic_tools:
|
||||
api_kwargs["headers"] = {
|
||||
"anthropic-beta": tool_config["beta_flag"]
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
# Use liteLLM acompletion
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Convert response to responses_items format
|
||||
responses_items = _convert_completion_to_responses_items(response)
|
||||
|
||||
# Extract usage information
|
||||
responses_usage = {
|
||||
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(responses_usage)
|
||||
|
||||
# Return in AsyncAgentConfig format
|
||||
return {
|
||||
"output": responses_items,
|
||||
"usage": responses_usage
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
async def predict_click(
    self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[float, float]]:
    """Click prediction stub for the Anthropic hosted tools agent.

    Click prediction is not supported by this agent, so every argument is
    ignored.

    Returns:
        Always None.
    """
    return None
|
||||
|
||||
# Use liteLLM acompletion
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Convert response to responses_items format
|
||||
responses_items = _convert_completion_to_responses_items(response)
|
||||
|
||||
# Extract usage information
|
||||
responses_usage = {
|
||||
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(responses_usage)
|
||||
|
||||
# Create agent response
|
||||
agent_response = {
|
||||
"output": responses_items,
|
||||
"usage": responses_usage
|
||||
}
|
||||
|
||||
return agent_response
|
||||
def get_capabilities(self) -> List[AgentCapability]:
    """List what this agent can do: step prediction only."""
    supported = ["step"]
    return supported
|
||||
|
||||
@@ -9,8 +9,9 @@ import litellm
|
||||
import inspect
|
||||
import base64
|
||||
|
||||
from ..decorators import agent_loop
|
||||
from ..types import Messages, AgentResponse, Tools
|
||||
from ..decorators import register_agent
|
||||
from ..types import Messages, AgentResponse, Tools, AgentCapability
|
||||
from ..loops.base import AsyncAgentConfig
|
||||
|
||||
SOM_TOOL_SCHEMA = {
|
||||
"type": "function",
|
||||
@@ -246,94 +247,114 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
|
||||
return [item]
|
||||
|
||||
|
||||
@agent_loop(models=r"omniparser\+.*|omni\+.*", priority=10)
|
||||
async def omniparser_loop(
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
|
||||
"""
|
||||
OpenAI computer-use-preview agent loop using liteLLM responses.
|
||||
@register_agent(models=r"omniparser\+.*|omni\+.*", priority=10)
|
||||
class OmniparsrConfig(AsyncAgentConfig):
|
||||
"""Omniparser agent configuration implementing AsyncAgentConfig protocol."""
|
||||
|
||||
Supports OpenAI's computer use preview models.
|
||||
"""
|
||||
if not OMNIPARSER_AVAILABLE:
|
||||
raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")
|
||||
|
||||
tools = tools or []
|
||||
|
||||
llm_model = model.split('+')[-1]
|
||||
|
||||
# Prepare tools for OpenAI API
|
||||
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
|
||||
|
||||
# Find last computer_call_output
|
||||
last_computer_call_output = get_last_computer_call_output(messages)
|
||||
if last_computer_call_output:
|
||||
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
|
||||
image_data = image_url.split(",")[-1]
|
||||
if image_data:
|
||||
parser = get_parser()
|
||||
result = parser.parse(image_data)
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(result.annotated_image_base64, "annotated_image")
|
||||
for element in result.elements:
|
||||
id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)
|
||||
|
||||
# handle computer calls -> function calls
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
message = message.__dict__
|
||||
new_messages += await replace_computer_call_with_function(message, id2xy)
|
||||
messages = new_messages
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": llm_model,
|
||||
"input": messages,
|
||||
"tools": openai_tools if openai_tools else None,
|
||||
"stream": stream,
|
||||
"reasoning": {"summary": "concise"},
|
||||
"truncation": "auto",
|
||||
"num_retries": max_retries,
|
||||
async def predict_step(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
}
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI computer-use-preview agent loop using liteLLM responses.
|
||||
|
||||
Supports OpenAI's computer use preview models.
|
||||
"""
|
||||
if not OMNIPARSER_AVAILABLE:
|
||||
raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")
|
||||
|
||||
tools = tools or []
|
||||
|
||||
llm_model = model.split('+')[-1]
|
||||
|
||||
# Prepare tools for OpenAI API
|
||||
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
|
||||
|
||||
# Find last computer_call_output
|
||||
last_computer_call_output = get_last_computer_call_output(messages)
|
||||
if last_computer_call_output:
|
||||
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
|
||||
image_data = image_url.split(",")[-1]
|
||||
if image_data:
|
||||
parser = get_parser()
|
||||
result = parser.parse(image_data)
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(result.annotated_image_base64, "annotated_image")
|
||||
for element in result.elements:
|
||||
id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)
|
||||
|
||||
# handle computer calls -> function calls
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
message = message.__dict__
|
||||
new_messages += await replace_computer_call_with_function(message, id2xy)
|
||||
messages = new_messages
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": llm_model,
|
||||
"input": messages,
|
||||
"tools": openai_tools if openai_tools else None,
|
||||
"stream": stream,
|
||||
"reasoning": {"summary": "concise"},
|
||||
"truncation": "auto",
|
||||
"num_retries": max_retries,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
print(str(api_kwargs)[:1000])
|
||||
|
||||
# Use liteLLM responses
|
||||
response = await litellm.aresponses(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Extract usage information
|
||||
usage = {
|
||||
**response.usage.model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(usage)
|
||||
|
||||
# handle som function calls -> xy computer calls
|
||||
new_output = []
|
||||
for i in range(len(response.output)):
|
||||
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)
|
||||
|
||||
return {
|
||||
"output": new_output,
|
||||
"usage": usage
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
async def predict_click(
    self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[float, float]]:
    """Click prediction stub for the Omniparser agent.

    Click prediction is not supported by this agent, so every argument is
    ignored.

    Returns:
        Always None.
    """
    return None
|
||||
|
||||
print(str(api_kwargs)[:1000])
|
||||
|
||||
# Use liteLLM responses
|
||||
response = await litellm.aresponses(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Extract usage information
|
||||
response.usage = {
|
||||
**response.usage.model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(response.usage)
|
||||
|
||||
# handle som function calls -> xy computer calls
|
||||
new_output = []
|
||||
for i in range(len(response.output)):
|
||||
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)
|
||||
response.output = new_output
|
||||
|
||||
return response
|
||||
def get_capabilities(self) -> List[AgentCapability]:
    """List what this agent can do: step prediction only."""
    supported = ["step"]
    return supported
|
||||
|
||||
@@ -4,11 +4,11 @@ OpenAI computer-use-preview agent loop implementation using liteLLM
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
|
||||
import litellm
|
||||
|
||||
from ..decorators import agent_loop
|
||||
from ..types import Messages, AgentResponse, Tools
|
||||
from ..decorators import register_agent
|
||||
from ..types import Messages, AgentResponse, Tools, AgentCapability
|
||||
|
||||
def _map_computer_tool_to_openai(computer_tool: Any) -> Dict[str, Any]:
|
||||
"""Map a computer tool to OpenAI's computer-use-preview tool schema"""
|
||||
@@ -36,60 +36,116 @@ def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
|
||||
return openai_tools
|
||||
|
||||
|
||||
@agent_loop(models=r".*computer-use-preview.*", priority=10)
|
||||
async def openai_computer_use_loop(
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
|
||||
@register_agent(models=r".*computer-use-preview.*", priority=10)
|
||||
class OpenAIComputerUseConfig:
|
||||
"""
|
||||
OpenAI computer-use-preview agent loop using liteLLM responses.
|
||||
OpenAI computer-use-preview agent configuration using liteLLM responses.
|
||||
|
||||
Supports OpenAI's computer use preview models.
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Prepare tools for OpenAI API
|
||||
openai_tools = _prepare_tools_for_openai(tools)
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"input": messages,
|
||||
"tools": openai_tools if openai_tools else None,
|
||||
"stream": stream,
|
||||
"reasoning": {"summary": "concise"},
|
||||
"truncation": "auto",
|
||||
"num_retries": max_retries,
|
||||
async def predict_step(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
# Use liteLLM responses
|
||||
response = await litellm.aresponses(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Predict the next step based on input items.
|
||||
|
||||
Args:
|
||||
messages: Input items following Responses format
|
||||
model: Model name to use
|
||||
tools: Optional list of tool schemas
|
||||
max_retries: Maximum number of retries
|
||||
stream: Whether to stream responses
|
||||
computer_handler: Computer handler instance
|
||||
_on_api_start: Callback for API start
|
||||
_on_api_end: Callback for API end
|
||||
_on_usage: Callback for usage tracking
|
||||
_on_screenshot: Callback for screenshot events
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with "output" (output items) and "usage" array
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Prepare tools for OpenAI API
|
||||
openai_tools = _prepare_tools_for_openai(tools)
|
||||
|
||||
# Extract usage information
|
||||
response.usage = {
|
||||
**response.usage.model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(response.usage)
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"input": messages,
|
||||
"tools": openai_tools if openai_tools else None,
|
||||
"stream": stream,
|
||||
"reasoning": {"summary": "concise"},
|
||||
"truncation": "auto",
|
||||
"num_retries": max_retries,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
# Use liteLLM responses
|
||||
response = await litellm.aresponses(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Extract usage information
|
||||
usage = {
|
||||
**response.usage.model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(usage)
|
||||
|
||||
# Return in the expected format
|
||||
output_dict = response.model_dump()
|
||||
output_dict["usage"] = usage
|
||||
return output_dict
|
||||
|
||||
return response
|
||||
async def predict_click(
    self,
    model: str,
    image_b64: str,
    instruction: str,
    **kwargs
) -> Optional[Tuple[int, int]]:
    """
    Predict click coordinates based on image and instruction.

    Note: OpenAI computer-use-preview doesn't support direct click prediction,
    so this returns None.

    Args:
        model: Model name to use
        image_b64: Base64 encoded image
        instruction: Instruction for where to click
        **kwargs: Additional arguments, accepted and ignored so the signature
            matches the other agent configs' predict_click methods

    Returns:
        None (not supported by OpenAI computer-use-preview)
    """
    return None
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
    """
    Report the capabilities this agent config supports.

    Returns:
        A new list containing only "step".
    """
    supported = ["step"]
    return supported
|
||||
|
||||
@@ -9,7 +9,7 @@ import base64
|
||||
import math
|
||||
import re
|
||||
import ast
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import litellm
|
||||
@@ -21,8 +21,8 @@ from openai.types.responses.response_input_param import ComputerCallOutput
|
||||
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
|
||||
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
|
||||
|
||||
from ..decorators import agent_loop
|
||||
from ..types import Messages, AgentResponse, Tools
|
||||
from ..decorators import register_agent
|
||||
from ..types import Messages, AgentResponse, Tools, AgentCapability
|
||||
from ..responses import (
|
||||
make_reasoning_item,
|
||||
make_output_text_item,
|
||||
@@ -501,188 +501,298 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
|
||||
|
||||
return litellm_messages
|
||||
|
||||
@agent_loop(models=r"(?i).*ui-?tars.*", priority=10)
|
||||
async def uitars_loop(
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
|
||||
@register_agent(models=r"(?i).*ui-?tars.*", priority=10)
|
||||
class UITARSConfig:
|
||||
"""
|
||||
UITARS agent loop using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.
|
||||
UITARS agent configuration using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.
|
||||
|
||||
Supports UITARS vision-language models for computer control.
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Create response items
|
||||
response_items = []
|
||||
|
||||
# Find computer tool for screen dimensions
|
||||
computer_tool = None
|
||||
for tool_schema in tools:
|
||||
if tool_schema["type"] == "computer":
|
||||
computer_tool = tool_schema["computer"]
|
||||
break
|
||||
|
||||
# Get screen dimensions
|
||||
screen_width, screen_height = 1024, 768
|
||||
if computer_tool:
|
||||
try:
|
||||
screen_width, screen_height = await computer_tool.get_dimensions()
|
||||
except:
|
||||
pass
|
||||
|
||||
# Process messages to extract instruction and image
|
||||
instruction = ""
|
||||
image_data = None
|
||||
|
||||
# Convert messages to list if string
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# Extract instruction and latest screenshot
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content", "")
|
||||
async def predict_step(
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Predict the next step based on input messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages following Responses format
|
||||
model: Model name to use
|
||||
tools: Optional list of tool schemas
|
||||
max_retries: Maximum number of retries
|
||||
stream: Whether to stream responses
|
||||
computer_handler: Computer handler instance
|
||||
_on_api_start: Callback for API start
|
||||
_on_api_end: Callback for API end
|
||||
_on_usage: Callback for usage tracking
|
||||
_on_screenshot: Callback for screenshot events
|
||||
**kwargs: Additional arguments
|
||||
|
||||
# Handle different content formats
|
||||
if isinstance(content, str):
|
||||
if not instruction and message.get("role") == "user":
|
||||
instruction = content
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text" and not instruction:
|
||||
instruction = item.get("text", "")
|
||||
elif item.get("type") == "image_url" and not image_data:
|
||||
image_url = item.get("image_url", {})
|
||||
if isinstance(image_url, dict):
|
||||
image_data = image_url.get("url", "")
|
||||
else:
|
||||
image_data = image_url
|
||||
Returns:
|
||||
Dictionary with "output" (output items) and "usage" array
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Also check for computer_call_output with screenshots
|
||||
if message.get("type") == "computer_call_output" and not image_data:
|
||||
output = message.get("output", {})
|
||||
if isinstance(output, dict) and output.get("type") == "input_image":
|
||||
image_data = output.get("image_url", "")
|
||||
# Create response items
|
||||
response_items = []
|
||||
|
||||
if instruction and image_data:
|
||||
break
|
||||
|
||||
if not instruction:
|
||||
instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."
|
||||
|
||||
# Create prompt
|
||||
user_prompt = UITARS_PROMPT_TEMPLATE.format(
|
||||
instruction=instruction,
|
||||
action_space=UITARS_ACTION_SPACE,
|
||||
language="English"
|
||||
)
|
||||
|
||||
# Convert conversation history to LiteLLM format
|
||||
history_messages = convert_uitars_messages_to_litellm(messages)
|
||||
|
||||
# Prepare messages for liteLLM
|
||||
litellm_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}
|
||||
]
|
||||
|
||||
# Add current user instruction with screenshot
|
||||
current_user_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": user_prompt},
|
||||
# Find computer tool for screen dimensions
|
||||
computer_tool = None
|
||||
for tool_schema in tools:
|
||||
if tool_schema["type"] == "computer":
|
||||
computer_tool = tool_schema["computer"]
|
||||
break
|
||||
|
||||
# Get screen dimensions
|
||||
screen_width, screen_height = 1024, 768
|
||||
if computer_tool:
|
||||
try:
|
||||
screen_width, screen_height = await computer_tool.get_dimensions()
|
||||
except:
|
||||
pass
|
||||
|
||||
# Process messages to extract instruction and image
|
||||
instruction = ""
|
||||
image_data = None
|
||||
|
||||
# Convert messages to list if string
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# Extract instruction and latest screenshot
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content", "")
|
||||
|
||||
# Handle different content formats
|
||||
if isinstance(content, str):
|
||||
if not instruction and message.get("role") == "user":
|
||||
instruction = content
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text" and not instruction:
|
||||
instruction = item.get("text", "")
|
||||
elif item.get("type") == "image_url" and not image_data:
|
||||
image_url = item.get("image_url", {})
|
||||
if isinstance(image_url, dict):
|
||||
image_data = image_url.get("url", "")
|
||||
else:
|
||||
image_data = image_url
|
||||
|
||||
# Also check for computer_call_output with screenshots
|
||||
if message.get("type") == "computer_call_output" and not image_data:
|
||||
output = message.get("output", {})
|
||||
if isinstance(output, dict) and output.get("type") == "input_image":
|
||||
image_data = output.get("image_url", "")
|
||||
|
||||
if instruction and image_data:
|
||||
break
|
||||
|
||||
if not instruction:
|
||||
instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."
|
||||
|
||||
# Create prompt
|
||||
user_prompt = UITARS_PROMPT_TEMPLATE.format(
|
||||
instruction=instruction,
|
||||
action_space=UITARS_ACTION_SPACE,
|
||||
language="English"
|
||||
)
|
||||
|
||||
# Convert conversation history to LiteLLM format
|
||||
history_messages = convert_uitars_messages_to_litellm(messages)
|
||||
|
||||
# Prepare messages for liteLLM
|
||||
litellm_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}
|
||||
]
|
||||
}
|
||||
litellm_messages.append(current_user_message)
|
||||
|
||||
# Process image for UITARS
|
||||
if not image_data:
|
||||
# Take screenshot if none found in messages
|
||||
if computer_handler:
|
||||
image_data = await computer_handler.screenshot()
|
||||
await _on_screenshot(image_data, "screenshot_before")
|
||||
|
||||
# Add screenshot to output items so it can be retained in history
|
||||
response_items.append(make_input_image_item(image_data))
|
||||
else:
|
||||
raise ValueError("No screenshot found in messages and no computer_handler provided")
|
||||
processed_image, original_width, original_height = process_image_for_uitars(image_data)
|
||||
encoded_image = pil_to_base64(processed_image)
|
||||
|
||||
# Add conversation history
|
||||
if history_messages:
|
||||
litellm_messages.extend(history_messages)
|
||||
else:
|
||||
litellm_messages.append({
|
||||
"role": "user",
|
||||
# Add current user instruction with screenshot
|
||||
current_user_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
|
||||
{"type": "text", "text": user_prompt},
|
||||
]
|
||||
})
|
||||
}
|
||||
litellm_messages.append(current_user_message)
|
||||
|
||||
# Process image for UITARS
|
||||
if not image_data:
|
||||
# Take screenshot if none found in messages
|
||||
if computer_handler:
|
||||
image_data = await computer_handler.screenshot()
|
||||
await _on_screenshot(image_data, "screenshot_before")
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": litellm_messages,
|
||||
"max_tokens": kwargs.get("max_tokens", 500),
|
||||
"temperature": kwargs.get("temperature", 0.0),
|
||||
"do_sample": kwargs.get("temperature", 0.0) > 0.0,
|
||||
"num_retries": max_retries,
|
||||
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
# Call liteLLM with UITARS model
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Extract response content
|
||||
response_content = response.choices[0].message.content.strip() # type: ignore
|
||||
|
||||
# Parse UITARS response
|
||||
parsed_responses = parse_uitars_response(response_content, original_width, original_height)
|
||||
|
||||
# Convert to computer actions
|
||||
computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)
|
||||
|
||||
# Add computer actions to response items
|
||||
thought = parsed_responses[0].get("thought", "")
|
||||
if thought:
|
||||
response_items.append(make_reasoning_item(thought))
|
||||
response_items.extend(computer_actions)
|
||||
|
||||
# Extract usage information
|
||||
response_usage = {
|
||||
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(response_usage)
|
||||
# Add screenshot to output items so it can be retained in history
|
||||
response_items.append(make_input_image_item(image_data))
|
||||
else:
|
||||
raise ValueError("No screenshot found in messages and no computer_handler provided")
|
||||
processed_image, original_width, original_height = process_image_for_uitars(image_data)
|
||||
encoded_image = pil_to_base64(processed_image)
|
||||
|
||||
# Add conversation history
|
||||
if history_messages:
|
||||
litellm_messages.extend(history_messages)
|
||||
else:
|
||||
litellm_messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
|
||||
]
|
||||
})
|
||||
|
||||
# Create agent response
|
||||
agent_response = {
|
||||
"output": response_items,
|
||||
"usage": response_usage
|
||||
}
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": litellm_messages,
|
||||
"max_tokens": kwargs.get("max_tokens", 500),
|
||||
"temperature": kwargs.get("temperature", 0.0),
|
||||
"do_sample": kwargs.get("temperature", 0.0) > 0.0,
|
||||
"num_retries": max_retries,
|
||||
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
# Call liteLLM with UITARS model
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Extract response content
|
||||
response_content = response.choices[0].message.content.strip() # type: ignore
|
||||
|
||||
# Parse UITARS response
|
||||
parsed_responses = parse_uitars_response(response_content, original_width, original_height)
|
||||
|
||||
# Convert to computer actions
|
||||
computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)
|
||||
|
||||
# Add computer actions to response items
|
||||
thought = parsed_responses[0].get("thought", "")
|
||||
if thought:
|
||||
response_items.append(make_reasoning_item(thought))
|
||||
response_items.extend(computer_actions)
|
||||
|
||||
# Extract usage information
|
||||
response_usage = {
|
||||
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(response_usage)
|
||||
|
||||
# Create agent response
|
||||
agent_response = {
|
||||
"output": response_items,
|
||||
"usage": response_usage
|
||||
}
|
||||
|
||||
return agent_response
|
||||
|
||||
return agent_response
|
||||
async def predict_click(
    self,
    model: str,
    image_b64: str,
    instruction: str
) -> Optional[Tuple[int, int]]:
    """
    Predict click coordinates based on image and instruction.

    Builds a single "Click on: <instruction>" prompt, sends it together with
    the image to the model through liteLLM, parses the reply with
    parse_uitars_response and, when the first parsed action is a click that
    carries a start_box, scales the first two values of that box to pixels.

    Args:
        model: Model name to use
        image_b64: Base64 encoded image (embedded as a PNG data URL)
        instruction: Instruction for where to click

    Returns:
        Tuple with (x, y) coordinates, or None when no click could be
        extracted or any error occurred (errors are printed, not raised).
    """
    # Local import so this method does not depend on a module-level import.
    import ast

    # NOTE(review): coordinates are scaled against a fixed 1024x768 canvas
    # rather than the real size of image_b64 -- confirm callers expect that.
    default_width, default_height = 1024, 768

    try:
        # Create a simple click instruction for UITARS
        user_prompt = UITARS_PROMPT_TEMPLATE.format(
            instruction=f"Click on: {instruction}",
            action_space=UITARS_ACTION_SPACE,
            language="English"
        )

        # Prepare messages for liteLLM
        litellm_messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
                ]
            }
        ]

        # Call liteLLM with UITARS model
        response = await litellm.acompletion(
            model=model,
            messages=litellm_messages,
            max_tokens=100,
            temperature=0.0
        )

        # Extract response content
        response_content = response.choices[0].message.content.strip()  # type: ignore

        # Parse UITARS response to extract click coordinates
        parsed_responses = parse_uitars_response(response_content, default_width, default_height)
        if not parsed_responses:
            return None

        # Only the first parsed action is considered, and only if it is a click
        first_action = parsed_responses[0]
        if first_action.get("action_type") != "click":
            return None

        start_box = first_action.get("action_inputs", {}).get("start_box")
        if not start_box:
            return None

        try:
            # start_box is text derived from model output: literal_eval only
            # accepts Python literals, so it can never execute code (eval could).
            coords = ast.literal_eval(start_box)
            if len(coords) >= 2:
                # Presumably the box values are normalized to [0, 1]; scale
                # them back to pixel coordinates on the default canvas.
                x = int(coords[0] * default_width)
                y = int(coords[1] * default_height)
                return (x, y)
        except (ValueError, SyntaxError, TypeError, KeyError, IndexError):
            # Malformed or non-numeric box: treat it as "no click found"
            pass

        return None

    except Exception as e:
        # Best-effort API: report the failure and signal "no prediction"
        print(f"Error in UITARS predict_click: {e}")
        return None
|
||||
|
||||
def get_capabilities(self) -> List[AgentCapability]:
    """
    Report which prediction modes this agent config implements.

    Returns:
        A freshly built list holding the capability strings
        "step" and "click", in that order.
    """
    supported: List[AgentCapability] = ["step", "click"]
    return supported
|
||||
@@ -14,16 +14,18 @@ Tools = Optional[Iterable[ToolParam]]
|
||||
|
||||
# Agent output types
|
||||
AgentResponse = ResponsesAPIResponse
|
||||
AgentCapability = Literal["step", "click"]
|
||||
|
||||
# Agent loop registration
|
||||
class AgentLoopInfo(BaseModel):
|
||||
"""Information about a registered agent loop"""
|
||||
func: Callable
|
||||
|
||||
# Agent config registration
|
||||
class AgentConfigInfo(BaseModel):
    """Information about a registered agent config"""
    # The agent config class this entry describes.
    agent_class: type
    # Regular expression tested against model names by matches_model().
    models_regex: str
    # Defaults to 0.
    # NOTE(review): presumably a higher value wins when several configs
    # match the same model -- confirm against the registry lookup.
    priority: int = 0

    def matches_model(self, model: str) -> bool:
        """Check if this agent config matches the given model.

        The pattern is applied with re.match, so it is anchored at the start
        of the model name only; it does not have to consume the whole name.

        Args:
            model: Model name to test against models_regex

        Returns:
            True when models_regex matches at the start of model
        """
        return bool(re.match(self.models_regex, model))
|
||||
|
||||
# Computer tool interface
|
||||
|
||||
Reference in New Issue
Block a user