diff --git a/libs/python/agent/agent/cli.py b/libs/python/agent/agent/cli.py
index c0434d02..0ea840d2 100644
--- a/libs/python/agent/agent/cli.py
+++ b/libs/python/agent/agent/cli.py
@@ -167,7 +167,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
 
         # Process and display the output
         for item in result.get("output", []):
-            if item.get("type") == "message":
+            if item.get("type") == "message" and item.get("role") == "assistant":
                 # Display agent text response
                 content = item.get("content", [])
                 for content_part in content:
diff --git a/libs/python/agent/agent/loops/__init__.py b/libs/python/agent/agent/loops/__init__.py
index 406f14ca..d9014259 100644
--- a/libs/python/agent/agent/loops/__init__.py
+++ b/libs/python/agent/agent/loops/__init__.py
@@ -13,6 +13,7 @@ from . import glm45v
 from . import opencua
 from . import internvl
 from . import holo
+from . import moondream3
 
 __all__ = [
     "anthropic",
@@ -25,4 +26,5 @@ __all__ = [
     "opencua",
     "internvl",
     "holo",
+    "moondream3",
 ]
\ No newline at end of file
diff --git a/libs/python/agent/agent/loops/moondream3.py b/libs/python/agent/agent/loops/moondream3.py
new file mode 100644
index 00000000..6c775c8f
--- /dev/null
+++ b/libs/python/agent/agent/loops/moondream3.py
@@ -0,0 +1,415 @@
+"""
+Moondream3+ composed-grounded agent loop implementation.
+Grounding is handled by a local Moondream3 preview model via Transformers.
+Thinking is delegated to the trailing LLM in the composed model string: "moondream3+<thinking_model>".
+
+Differences from composed_grounded:
+- Provides a singleton Moondream3 client outside the class.
+- predict_click uses model.point(image, instruction, settings={"max_objects": 1}) and returns pixel coordinates.
+- If the last image was a screenshot (or we take one), run model.detect(image, "all ui elements") to get bboxes,
+  then caption each cropped bbox (via model.query) to label it. Overlay labels on the screenshot and emit via
+  _on_screenshot.
+- Add a user message listing all detected form UI names so the thinker can reference them.
+- If the thinking model doesn't support vision, filter out image content before calling litellm.
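+
+Example composed model strings (illustrative; anything litellm can route works after the "+"):
+  "moondream3+openai/gpt-4o", "moondream3+anthropic/claude-3-5-sonnet-20241022".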
+""" + +from __future__ import annotations + +import uuid +import base64 +import io +from typing import Dict, List, Any, Optional, Tuple, Any + +from PIL import Image, ImageDraw, ImageFont +import torch +from transformers import AutoModelForCausalLM +import litellm + +from ..decorators import register_agent +from ..types import AgentCapability +from ..loops.base import AsyncAgentConfig +from ..responses import ( + convert_computer_calls_xy2desc, + convert_responses_items_to_completion_messages, + convert_completion_messages_to_responses_items, + convert_computer_calls_desc2xy, + get_all_element_descriptions, +) + +_MOONDREAM_SINGLETON = None + +def get_moondream_model() -> Any: + """Get a singleton instance of the Moondream3 preview model.""" + global _MOONDREAM_SINGLETON + if _MOONDREAM_SINGLETON is None: + _MOONDREAM_SINGLETON = AutoModelForCausalLM.from_pretrained( + "moondream/moondream3-preview", + trust_remote_code=True, + torch_dtype=torch.bfloat16, + device_map="cuda", + ) + return _MOONDREAM_SINGLETON + + +def _decode_image_b64(image_b64: str) -> Image.Image: + data = base64.b64decode(image_b64) + return Image.open(io.BytesIO(data)).convert("RGB") + + +def _image_to_b64(img: Image.Image) -> str: + buf = io.BytesIO() + img.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("utf-8") + + +def _supports_vision(model: str) -> bool: + """Heuristic vision support detection for thinking model.""" + m = model.lower() + vision_markers = [ + "gpt-4o", + "gpt-4.1", + "o1", + "o3", + "claude-3", + "claude-3.5", + "sonnet", + "haiku", + "opus", + "gemini-1.5", + "llava", + ] + return any(v in m for v in vision_markers) + + +def _filter_images_from_completion_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + filtered: List[Dict[str, Any]] = [] + for msg in messages: + msg_copy = {**msg} + content = msg_copy.get("content") + if isinstance(content, list): + msg_copy["content"] = [c for c in content if c.get("type") != "image_url"] + filtered.append(msg_copy) + return filtered + +GROUNDED_COMPUTER_TOOL_SCHEMA = { + "type": "function", + "function": { + "name": "computer", + "description": ( + "Control a computer by taking screenshots and interacting with UI elements. " + "The screenshot action will include a list of detected form UI element names when available. " + "Use element descriptions to locate and interact with UI elements on the screen." 
+ ), + "parameters": { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": [ + "screenshot", + "click", + "double_click", + "drag", + "type", + "keypress", + "scroll", + "move", + "wait", + "get_current_url", + "get_dimensions", + "get_environment", + ], + "description": "The action to perform (required for all actions)", + }, + "element_description": { + "type": "string", + "description": "Description of the element to interact with (required for click/double_click/move/scroll)", + }, + "start_element_description": { + "type": "string", + "description": "Description of the element to start dragging from (required for drag)", + }, + "end_element_description": { + "type": "string", + "description": "Description of the element to drag to (required for drag)", + }, + "text": { + "type": "string", + "description": "The text to type (required for type)", + }, + "keys": { + "type": "array", + "items": {"type": "string"}, + "description": "Key(s) to press (required for keypress)", + }, + "button": { + "type": "string", + "enum": ["left", "right", "wheel", "back", "forward"], + "description": "The mouse button to use for click/double_click", + }, + "scroll_x": { + "type": "integer", + "description": "Horizontal scroll amount (required for scroll)", + }, + "scroll_y": { + "type": "integer", + "description": "Vertical scroll amount (required for scroll)", + }, + }, + "required": ["action"], + }, + }, +} + +@register_agent(r"moondream3\+.*", priority=2) +class Moondream3PlusConfig(AsyncAgentConfig): + def __init__(self): + self.desc2xy: Dict[str, Tuple[float, float]] = {} + + async def predict_step( + self, + messages: List[Dict[str, Any]], + model: str, + tools: Optional[List[Dict[str, Any]]] = None, + max_retries: Optional[int] = None, + stream: bool = False, + computer_handler=None, + use_prompt_caching: Optional[bool] = False, + _on_api_start=None, + _on_api_end=None, + _on_usage=None, + _on_screenshot=None, + **kwargs, + ) -> Dict[str, Any]: + # Parse composed model: moondream3+ + if "+" not in model: + raise ValueError(f"Composed model must be 'moondream3+', got: {model}") + _, thinking_model = model.split("+", 1) + + pre_output_items: List[Dict[str, Any]] = [] + + # Acquire last screenshot; if missing, take one + last_image_b64: Optional[str] = None + for message in reversed(messages): + if ( + isinstance(message, dict) + and message.get("type") == "computer_call_output" + and isinstance(message.get("output"), dict) + and message["output"].get("type") == "input_image" + ): + image_url = message["output"].get("image_url", "") + if image_url.startswith("data:image/png;base64,"): + last_image_b64 = image_url.split(",", 1)[1] + break + + if last_image_b64 is None and computer_handler is not None: + # Take a screenshot + screenshot_b64 = await computer_handler.screenshot() # type: ignore + if screenshot_b64: + call_id = uuid.uuid4().hex + pre_output_items += [ + { + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "Taking a screenshot to analyze the current screen."} + ], + }, + {"type": "computer_call", "call_id": call_id, "status": "completed", "action": {"type": "screenshot"}}, + { + "type": "computer_call_output", + "call_id": call_id, + "output": {"type": "input_image", "image_url": f"data:image/png;base64,{screenshot_b64}"}, + }, + ] + last_image_b64 = screenshot_b64 + if _on_screenshot: + await _on_screenshot(screenshot_b64) + + # If we have a last screenshot, run Moondream detection and labeling + detected_names: 
+        # If we have a last screenshot, run Moondream detection and labeling
+        detected_names: List[str] = []
+        if last_image_b64 is not None:
+            base_img = _decode_image_b64(last_image_b64)
+            W, H = base_img.width, base_img.height
+            model_md = get_moondream_model()
+
+            try:
+                detect_result = model_md.detect(base_img, "all ui elements")
+                objects = detect_result.get("objects", []) if isinstance(detect_result, dict) else []
+            except Exception:
+                objects = []
+
+            draw = ImageDraw.Draw(base_img)
+            try:
+                font = ImageFont.load_default()
+            except Exception:
+                font = None
+
+            # For each object, crop and caption
+            for i, obj in enumerate(objects):
+                try:
+                    x_min = max(0.0, min(1.0, float(obj.get("x_min", 0.0))))
+                    y_min = max(0.0, min(1.0, float(obj.get("y_min", 0.0))))
+                    x_max = max(0.0, min(1.0, float(obj.get("x_max", 0.0))))
+                    y_max = max(0.0, min(1.0, float(obj.get("y_max", 0.0))))
+                    left, top, right, bottom = int(x_min * W), int(y_min * H), int(x_max * W), int(y_max * H)
+                    left, top = max(0, left), max(0, top)
+                    right, bottom = min(W - 1, right), min(H - 1, bottom)
+                    crop = base_img.crop((left, top, right, bottom))
+
+                    # # Streaming caption
+                    # caption_text = ""
+                    # try:
+                    #     cap_result = model_md.caption(crop, length="short", stream=True, settings={"temperature": 0.3})
+                    #     for chunk in cap_result.get("caption", []):
+                    #         caption_text += str(chunk)
+                    # except Exception:
+                    #     caption_text = ""
+
+                    # Prompted caption
+                    result = model_md.query(
+                        crop, "Caption this UI element in few words."
+                    )
+                    caption_text = result["answer"]
+
+                    name = caption_text.strip() or f"element_{i+1}"
+                    detected_names.append(name)
+
+                    # Draw bbox and label
+                    draw.rectangle([left, top, right, bottom], outline=(255, 215, 0), width=2)
+                    label = f"{i+1}. {name}"
+                    text_xy = (left + 3, max(0, top - 12))
+                    if font:
+                        draw.text(text_xy, label, fill=(255, 215, 0), font=font)
+                    else:
+                        draw.text(text_xy, label, fill=(255, 215, 0))
+                except Exception:
+                    continue
+
+            # Emit annotated screenshot
+            annotated_b64 = _image_to_b64(base_img)
+            if _on_screenshot:
+                await _on_screenshot(annotated_b64, "annotated_form_ui")
+
+            # Also push a user message listing all detected names
+            if detected_names:
+                names_text = "\n".join(f"- {n}" for n in detected_names)
+                pre_output_items.append(
+                    {
+                        "type": "message",
+                        "role": "user",
+                        "content": [
+                            {"type": "input_text", "text": "Detected form UI elements on screen:"},
+                            {"type": "input_text", "text": names_text},
+                            {"type": "input_text", "text": "Please continue with the next action needed to perform your task."},
+                        ],
+                    }
+                )
+
+        tool_schemas = []
+        for schema in (tools or []):
+            if schema.get("type") == "computer":
+                tool_schemas.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
+            else:
+                tool_schemas.append(schema)
+
+        # Step 1: Convert computer calls from xy to descriptions
+        input_messages = messages + pre_output_items
+        messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
+
+        # Step 2: Convert responses items to completion messages
+        completion_messages = convert_responses_items_to_completion_messages(
+            messages_with_descriptions,
+            allow_images_in_tool_results=False,
+        )
+
+        # Optionally filter images if the thinking model lacks vision
+        if not _supports_vision(thinking_model):
+            completion_messages = _filter_images_from_completion_messages(completion_messages)
+
+        # Step 3: Call thinking model with litellm.acompletion
+        api_kwargs = {
+            "model": thinking_model,
+            "messages": completion_messages,
+            "tools": tool_schemas,
+            "max_retries": max_retries,
+            "stream": stream,
+            **kwargs,
+        }
+        if use_prompt_caching:
+            api_kwargs["use_prompt_caching"] = use_prompt_caching
+
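+        # Note: any extra **kwargs passed to predict_step (e.g. temperature, max_tokens)
+        # are forwarded to litellm.acompletion unchanged via the **kwargs spread above.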
+        if _on_api_start:
+            await _on_api_start(api_kwargs)
+
+        response = await litellm.acompletion(**api_kwargs)
+
+        if _on_api_end:
+            await _on_api_end(api_kwargs, response)
+
+        usage = {
+            **response.usage.model_dump(),  # type: ignore
+            "response_cost": response._hidden_params.get("response_cost", 0.0),
+        }
+        if _on_usage:
+            await _on_usage(usage)
+
+        # Step 4: Convert completion messages back to responses items format
+        response_dict = response.model_dump()  # type: ignore
+        choice_messages = [choice["message"] for choice in response_dict["choices"]]
+        thinking_output_items: List[Dict[str, Any]] = []
+        for choice_message in choice_messages:
+            thinking_output_items.extend(
+                convert_completion_messages_to_responses_items([choice_message])
+            )
+
+        # Step 5: Use Moondream to get coordinates for each description
+        element_descriptions = get_all_element_descriptions(thinking_output_items)
+        if element_descriptions and last_image_b64:
+            for desc in element_descriptions:
+                for _ in range(3):  # try 3 times
+                    coords = await self.predict_click(
+                        model=model,
+                        image_b64=last_image_b64,
+                        instruction=desc,
+                    )
+                    if coords:
+                        self.desc2xy[desc] = coords
+                        break
+
+        # Step 6: Convert computer calls from descriptions back to xy coordinates
+        final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
+
+        # Step 7: Return output and usage
+        return {"output": pre_output_items + final_output_items, "usage": usage}
+
+    async def predict_click(
+        self,
+        model: str,
+        image_b64: str,
+        instruction: str,
+        **kwargs,
+    ) -> Optional[Tuple[float, float]]:
+        """Predict click coordinates using Moondream3's point API.
+
+        Returns pixel coordinates (x, y) as floats.
+        """
+        img = _decode_image_b64(image_b64)
+        W, H = img.width, img.height
+        model_md = get_moondream_model()
+        try:
+            result = model_md.point(img, instruction, settings={"max_objects": 1})
+        except Exception:
+            return None
+
+        try:
+            pt = (result or {}).get("points", [])[0]
+            x_norm = float(pt.get("x", 0.0))
+            y_norm = float(pt.get("y", 0.0))
+            x_px = max(0.0, min(float(W - 1), x_norm * W))
+            y_px = max(0.0, min(float(H - 1), y_norm * H))
+            return (x_px, y_px)
+        except Exception:
+            return None
+
+    def get_capabilities(self) -> List[AgentCapability]:
+        return ["click", "step"]
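+
+# Note: predict_click() is self-contained (it only needs the Moondream singleton), so
+# grounding accuracy can be probed without any thinking model, e.g. (illustrative):
+#   coords = await Moondream3PlusConfig().predict_click(
+#       model="moondream3+<thinking_model>", image_b64=screenshot_b64, instruction="the Submit button"
+#   )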