From 3e7bc0aa7941994bf0a87bf3303c9a427df3935f Mon Sep 17 00:00:00 2001 From: Dillon DuPont Date: Mon, 4 Aug 2025 18:11:21 -0400 Subject: [PATCH] Add omniparser predict_click --- .../agent/agent/loops/composed_grounded.py | 2 +- libs/python/agent/agent/loops/omniparser.py | 88 +++++++++++++++++-- 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/libs/python/agent/agent/loops/composed_grounded.py b/libs/python/agent/agent/loops/composed_grounded.py index 31b29372..1371ff3f 100644 --- a/libs/python/agent/agent/loops/composed_grounded.py +++ b/libs/python/agent/agent/loops/composed_grounded.py @@ -115,7 +115,7 @@ def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str return None -@register_agent(r".*\+.*", priority=10) +@register_agent(r".*\+.*", priority=1) class ComposedGroundedConfig: """ Composed-grounded agent configuration that uses both grounding and thinking models. diff --git a/libs/python/agent/agent/loops/omniparser.py b/libs/python/agent/agent/loops/omniparser.py index 2cf2d2c7..d85d07de 100644 --- a/libs/python/agent/agent/loops/omniparser.py +++ b/libs/python/agent/agent/loops/omniparser.py @@ -249,13 +249,13 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[ return [item] -@register_agent(models=r"omniparser\+.*|omni\+.*") -class OmniparsrConfig(AsyncAgentConfig): +@register_agent(models=r"omniparser\+.*|omni\+.*", priority=2) +class OmniparserConfig(AsyncAgentConfig): """Omniparser agent configuration implementing AsyncAgentConfig protocol.""" async def predict_step( self, - messages: Messages, + messages: List[Dict[str, Any]], model: str, tools: Optional[List[Dict[str, Any]]] = None, max_retries: Optional[int] = None, @@ -284,7 +284,7 @@ class OmniparsrConfig(AsyncAgentConfig): openai_tools, id2xy = _prepare_tools_for_omniparser(tools) # Find last computer_call_output - last_computer_call_output = get_last_computer_call_output(messages) + last_computer_call_output = get_last_computer_call_output(messages) # type: ignore if last_computer_call_output: image_url = last_computer_call_output.get("output", {}).get("image_url", "") image_data = image_url.split(",")[-1] @@ -301,7 +301,7 @@ class OmniparsrConfig(AsyncAgentConfig): for message in messages: if not isinstance(message, dict): message = message.__dict__ - new_messages += await replace_computer_call_with_function(message, id2xy) + new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore messages = new_messages # Prepare API call kwargs @@ -331,7 +331,7 @@ class OmniparsrConfig(AsyncAgentConfig): # Extract usage information usage = { **response.usage.model_dump(), # type: ignore - "response_cost": response._hidden_params.get("response_cost", 0.0), + "response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore } if _on_usage: await _on_usage(usage) @@ -339,7 +339,7 @@ class OmniparsrConfig(AsyncAgentConfig): # handle som function calls -> xy computer calls new_output = [] for i in range(len(response.output)): # type: ignore - new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) + new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore return { "output": new_output, @@ -353,7 +353,79 @@ class OmniparsrConfig(AsyncAgentConfig): instruction: str, **kwargs ) -> Optional[Tuple[float, float]]: - """Omniparser does not support click prediction.""" + """ + Predict click coordinates using OmniParser and LLM. + + Uses OmniParser to annotate the image with element IDs, then uses LLM + to identify the correct element ID based on the instruction. + """ + if not OMNIPARSER_AVAILABLE: + return None + + # Parse the image with OmniParser to get annotated image and elements + parser = get_parser() + result = parser.parse(image_b64) + + # Extract the LLM model from composed model string + llm_model = model.split('+')[-1] + + # Create system prompt for element ID prediction + SYSTEM_PROMPT = f''' +You are an expert UI element locator. Given a GUI image annotated with numerical IDs over each interactable element, along with a user's element description, provide the ID of the specified element. + +The image shows UI elements with numbered overlays. Each number corresponds to a clickable/interactable element. + +Output only the element ID as a single integer. +'''.strip() + + # Prepare messages for LLM + messages = [ + { + "role": "system", + "content": SYSTEM_PROMPT + }, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{result.annotated_image_base64}" + } + }, + { + "type": "text", + "text": f"Find the element: {instruction}" + } + ] + } + ] + + # Call LLM to predict element ID + response = await litellm.acompletion( + model=llm_model, + messages=messages, + max_tokens=10, + temperature=0.1 + ) + + # Extract element ID from response + response_text = response.choices[0].message.content.strip() # type: ignore + + # Try to parse the element ID + try: + element_id = int(response_text) + + # Find the element with this ID and return its center coordinates + for element in result.elements: + if element.id == element_id: + center_x = (element.bbox.x1 + element.bbox.x2) / 2 + center_y = (element.bbox.y1 + element.bbox.y2) / 2 + return (center_x, center_y) + except ValueError: + # If we can't parse the ID, return None + pass + return None def get_capabilities(self) -> List[AgentCapability]: