Add omniparser predict_click
@@ -115,7 +115,7 @@ def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str
     return None
 
 
-@register_agent(r".*\+.*", priority=10)
+@register_agent(r".*\+.*", priority=1)
 class ComposedGroundedConfig:
     """
     Composed-grounded agent configuration that uses both grounding and thinking models.
@@ -249,13 +249,13 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
     return [item]
 
 
-@register_agent(models=r"omniparser\+.*|omni\+.*")
-class OmniparsrConfig(AsyncAgentConfig):
+@register_agent(models=r"omniparser\+.*|omni\+.*", priority=2)
+class OmniparserConfig(AsyncAgentConfig):
     """Omniparser agent configuration implementing AsyncAgentConfig protocol."""
 
     async def predict_step(
         self,
-        messages: Messages,
+        messages: List[Dict[str, Any]],
         model: str,
         tools: Optional[List[Dict[str, Any]]] = None,
         max_retries: Optional[int] = None,
@@ -284,7 +284,7 @@ class OmniparsrConfig(AsyncAgentConfig):
         openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
 
         # Find last computer_call_output
-        last_computer_call_output = get_last_computer_call_output(messages)
+        last_computer_call_output = get_last_computer_call_output(messages)  # type: ignore
         if last_computer_call_output:
             image_url = last_computer_call_output.get("output", {}).get("image_url", "")
             image_data = image_url.split(",")[-1]
@@ -301,7 +301,7 @@ class OmniparsrConfig(AsyncAgentConfig):
         for message in messages:
             if not isinstance(message, dict):
                 message = message.__dict__
-            new_messages += await replace_computer_call_with_function(message, id2xy)
+            new_messages += await replace_computer_call_with_function(message, id2xy)  # type: ignore
         messages = new_messages
 
         # Prepare API call kwargs
@@ -331,7 +331,7 @@ class OmniparsrConfig(AsyncAgentConfig):
         # Extract usage information
         usage = {
             **response.usage.model_dump(),  # type: ignore
-            "response_cost": response._hidden_params.get("response_cost", 0.0),
+            "response_cost": response._hidden_params.get("response_cost", 0.0),  # type: ignore
         }
         if _on_usage:
             await _on_usage(usage)
@@ -339,7 +339,7 @@ class OmniparsrConfig(AsyncAgentConfig):
         # handle som function calls -> xy computer calls
         new_output = []
         for i in range(len(response.output)):  # type: ignore
-            new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)
+            new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)  # type: ignore
 
         return {
             "output": new_output,
@@ -353,7 +353,79 @@ class OmniparsrConfig(AsyncAgentConfig):
         instruction: str,
         **kwargs
     ) -> Optional[Tuple[float, float]]:
-        """Omniparser does not support click prediction."""
+        """
+        Predict click coordinates using OmniParser and LLM.
+
+        Uses OmniParser to annotate the image with element IDs, then uses LLM
+        to identify the correct element ID based on the instruction.
+        """
+        if not OMNIPARSER_AVAILABLE:
+            return None
+
+        # Parse the image with OmniParser to get annotated image and elements
+        parser = get_parser()
+        result = parser.parse(image_b64)
+
+        # Extract the LLM model from composed model string
+        llm_model = model.split('+')[-1]
+
+        # Create system prompt for element ID prediction
+        SYSTEM_PROMPT = f'''
+You are an expert UI element locator. Given a GUI image annotated with numerical IDs over each interactable element, along with a user's element description, provide the ID of the specified element.
+
+The image shows UI elements with numbered overlays. Each number corresponds to a clickable/interactable element.
+
+Output only the element ID as a single integer.
+'''.strip()
+
+        # Prepare messages for LLM
+        messages = [
+            {
+                "role": "system",
+                "content": SYSTEM_PROMPT
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/png;base64,{result.annotated_image_base64}"
+                        }
+                    },
+                    {
+                        "type": "text",
+                        "text": f"Find the element: {instruction}"
+                    }
+                ]
+            }
+        ]
+
+        # Call LLM to predict element ID
+        response = await litellm.acompletion(
+            model=llm_model,
+            messages=messages,
+            max_tokens=10,
+            temperature=0.1
+        )
+
+        # Extract element ID from response
+        response_text = response.choices[0].message.content.strip()  # type: ignore
+
+        # Try to parse the element ID
+        try:
+            element_id = int(response_text)
+
+            # Find the element with this ID and return its center coordinates
+            for element in result.elements:
+                if element.id == element_id:
+                    center_x = (element.bbox.x1 + element.bbox.x2) / 2
+                    center_y = (element.bbox.y1 + element.bbox.y2) / 2
+                    return (center_x, center_y)
+        except ValueError:
+            # If we can't parse the ID, return None
+            pass
+
+        return None
 
     def get_capabilities(self) -> List[AgentCapability]:
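
The two priority changes above work as a pair: the generic composed pattern `.*\+.*` drops to priority 1 so that the new omniparser pattern, registered at priority 2, wins for `omniparser+...` / `omni+...` model strings. Below is a minimal sketch of how such priority-based resolution can behave; the registry and `find_agent_config` are hypothetical, and the repo's actual `register_agent` implementation may differ.

# Hypothetical sketch of priority-based agent resolution -- not the repo's code.
import re
from typing import Callable, List, Optional, Tuple

_REGISTRY: List[Tuple[str, int, type]] = []

def register_agent(models: str, priority: int = 0) -> Callable[[type], type]:
    """Decorator: map a model-name regex to an agent config class."""
    def decorator(cls: type) -> type:
        _REGISTRY.append((models, priority, cls))
        return cls
    return decorator

def find_agent_config(model: str) -> Optional[type]:
    """Return the matching config class with the highest priority."""
    matches = [(prio, cls) for pattern, prio, cls in _REGISTRY
               if re.match(pattern, model)]
    return max(matches, key=lambda m: m[0])[1] if matches else None

@register_agent(r".*\+.*", priority=1)
class ComposedGroundedConfig:
    pass

@register_agent(models=r"omniparser\+.*|omni\+.*", priority=2)
class OmniparserConfig:
    pass

# "omniparser+gpt-4o" matches both patterns; the higher priority (2) wins.
assert find_agent_config("omniparser+gpt-4o") is OmniparserConfig
assert find_agent_config("uitars+gpt-4o") is ComposedGroundedConfig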
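For context, a hypothetical driver for the new predict_click. The `model` and `image_b64` names appear in the method body, but the full parameter list sits outside the hunk, so the keyword call and the import below are assumptions, not the repo's confirmed API.

# Hypothetical usage sketch -- not part of the commit.
import asyncio
import base64

from agent import OmniparserConfig  # assumed import path

async def main() -> None:
    config = OmniparserConfig()

    # predict_click only uses the text after '+' for the LLM call
    # (llm_model = model.split('+')[-1]); the prefix selects this config.
    with open("screenshot.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()

    coords = await config.predict_click(
        model="omniparser+gpt-4o",
        image_b64=image_b64,
        instruction="the blue Submit button",
    )
    if coords is None:
        # OmniParser unavailable, or the LLM reply wasn't a parseable integer ID.
        print("no click target found")
    else:
        x, y = coords
        print(f"click at ({x:.0f}, {y:.0f})")

asyncio.run(main())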