Add omniparser predict_click

This commit is contained in:
Dillon DuPont
2025-08-04 18:11:21 -04:00
parent f87b8eaea5
commit 3e7bc0aa79
2 changed files with 81 additions and 9 deletions

View File

@@ -115,7 +115,7 @@ def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str
return None
@register_agent(r".*\+.*", priority=10)
@register_agent(r".*\+.*", priority=1)
class ComposedGroundedConfig:
"""
Composed-grounded agent configuration that uses both grounding and thinking models.

View File

@@ -249,13 +249,13 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
return [item]
@register_agent(models=r"omniparser\+.*|omni\+.*")
class OmniparsrConfig(AsyncAgentConfig):
@register_agent(models=r"omniparser\+.*|omni\+.*", priority=2)
class OmniparserConfig(AsyncAgentConfig):
"""Omniparser agent configuration implementing AsyncAgentConfig protocol."""
async def predict_step(
self,
messages: Messages,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
@@ -284,7 +284,7 @@ class OmniparsrConfig(AsyncAgentConfig):
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
# Find last computer_call_output
last_computer_call_output = get_last_computer_call_output(messages)
last_computer_call_output = get_last_computer_call_output(messages) # type: ignore
if last_computer_call_output:
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
image_data = image_url.split(",")[-1]
@@ -301,7 +301,7 @@ class OmniparsrConfig(AsyncAgentConfig):
for message in messages:
if not isinstance(message, dict):
message = message.__dict__
new_messages += await replace_computer_call_with_function(message, id2xy)
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
messages = new_messages
# Prepare API call kwargs
@@ -331,7 +331,7 @@ class OmniparsrConfig(AsyncAgentConfig):
# Extract usage information
usage = {
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0),
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
}
if _on_usage:
await _on_usage(usage)
@@ -339,7 +339,7 @@ class OmniparsrConfig(AsyncAgentConfig):
# handle som function calls -> xy computer calls
new_output = []
for i in range(len(response.output)): # type: ignore
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
return {
"output": new_output,
@@ -353,7 +353,79 @@ class OmniparsrConfig(AsyncAgentConfig):
instruction: str,
**kwargs
) -> Optional[Tuple[float, float]]:
"""Omniparser does not support click prediction."""
"""
Predict click coordinates using OmniParser and LLM.
Uses OmniParser to annotate the image with element IDs, then uses LLM
to identify the correct element ID based on the instruction.
"""
if not OMNIPARSER_AVAILABLE:
return None
# Parse the image with OmniParser to get annotated image and elements
parser = get_parser()
result = parser.parse(image_b64)
# Extract the LLM model from composed model string
llm_model = model.split('+')[-1]
# Create system prompt for element ID prediction
SYSTEM_PROMPT = f'''
You are an expert UI element locator. Given a GUI image annotated with numerical IDs over each interactable element, along with a user's element description, provide the ID of the specified element.
The image shows UI elements with numbered overlays. Each number corresponds to a clickable/interactable element.
Output only the element ID as a single integer.
'''.strip()
# Prepare messages for LLM
messages = [
{
"role": "system",
"content": SYSTEM_PROMPT
},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{result.annotated_image_base64}"
}
},
{
"type": "text",
"text": f"Find the element: {instruction}"
}
]
}
]
# Call LLM to predict element ID
response = await litellm.acompletion(
model=llm_model,
messages=messages,
max_tokens=10,
temperature=0.1
)
# Extract element ID from response
response_text = response.choices[0].message.content.strip() # type: ignore
# Try to parse the element ID
try:
element_id = int(response_text)
# Find the element with this ID and return its center coordinates
for element in result.elements:
if element.id == element_id:
center_x = (element.bbox.x1 + element.bbox.x2) / 2
center_y = (element.bbox.y1 + element.bbox.y2) / 2
return (center_x, center_y)
except ValueError:
# If we can't parse the ID, return None
pass
return None
def get_capabilities(self) -> List[AgentCapability]: