Add gemini CUA loop

This commit is contained in:
Dillon DuPont
2025-10-07 17:23:33 -04:00
parent 9f18f9eeaa
commit ef28a64b8f

View File

@@ -0,0 +1,391 @@
"""
Gemini 2.5 Computer Use agent loop
Maps internal Agent SDK message format to Google's Gemini Computer Use API and back.
Key features:
- Lazy import of google.genai
- Configure Computer Use tool with excluded browser-specific predefined functions
- Optional custom function declarations hook for computer-call specific functions
- Convert Gemini function_call parts into internal computer_call actions
"""
from __future__ import annotations
import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability
def _lazy_import_genai():
"""Import google.genai lazily to avoid hard dependency unless used."""
try:
from google import genai # type: ignore
from google.genai import types # type: ignore
return genai, types
except Exception as e: # pragma: no cover
raise RuntimeError(
"google.genai is required for the Gemini Computer Use loop. Install the Google Gemini SDK."
) from e
def _data_url_to_bytes(data_url: str) -> Tuple[bytes, str]:
"""Convert a data URL to raw bytes and mime type."""
if not data_url.startswith("data:"):
# Assume it's base64 png payload
try:
return base64.b64decode(data_url), "image/png"
except Exception:
return b"", "application/octet-stream"
header, b64 = data_url.split(",", 1)
mime = "image/png"
if ";" in header:
mime = header.split(";")[0].split(":", 1)[1] or "image/png"
return base64.b64decode(b64), mime
def _bytes_image_size(img_bytes: bytes) -> Tuple[int, int]:
try:
img = Image.open(io.BytesIO(img_bytes))
return img.size
except Exception:
return (1024, 768)
def _find_last_user_text(messages: List[Dict[str, Any]]) -> List[str]:
texts: List[str] = []
for msg in reversed(messages):
if msg.get("type") in (None, "message") and msg.get("role") == "user":
content = msg.get("content")
if isinstance(content, str):
return [content]
elif isinstance(content, list):
for c in content:
if c.get("type") in ("input_text", "output_text") and c.get("text"):
texts.append(c["text"]) # newest first
if texts:
return list(reversed(texts))
return []
def _find_last_screenshot(messages: List[Dict[str, Any]]) -> Optional[bytes]:
for msg in reversed(messages):
if msg.get("type") == "computer_call_output":
out = msg.get("output", {})
if isinstance(out, dict) and out.get("type") in ("input_image", "computer_screenshot"):
image_url = out.get("image_url", "")
if image_url:
data, _ = _data_url_to_bytes(image_url)
return data
return None
def _denormalize(v: int, size: int) -> int:
# Gemini returns 0-999 normalized
try:
return max(0, min(size - 1, int(round(v / 1000 * size))))
except Exception:
return 0
def _map_gemini_fc_to_computer_call(
fc: Dict[str, Any],
screen_w: int,
screen_h: int,
) -> Optional[Dict[str, Any]]:
name = fc.get("name")
args = fc.get("args", {}) or {}
action: Dict[str, Any] = {}
if name == "click_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
action = {"type": "click", "x": x, "y": y, "button": "left"}
elif name == "type_text_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
text = args.get("text", "")
if args.get("press_enter") == True:
text += "\n"
action = {"type": "type", "x": x, "y": y, "text": text}
elif name == "hover_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
action = {"type": "move", "x": x, "y": y}
elif name == "key_combination":
keys = str(args.get("keys", ""))
action = {"type": "keypress", "keys": keys}
elif name == "scroll_document":
direction = args.get("direction", "down")
magnitude = 800
dx, dy = 0, 0
if direction == "down":
dy = magnitude
elif direction == "up":
dy = -magnitude
elif direction == "right":
dx = magnitude
elif direction == "left":
dx = -magnitude
action = {"type": "scroll", "scroll_x": dx, "scroll_y": dy, "x": int(screen_w / 2), "y": int(screen_h / 2)}
elif name == "scroll_at":
x = _denormalize(int(args.get("x", 500)), screen_w)
y = _denormalize(int(args.get("y", 500)), screen_h)
direction = args.get("direction", "down")
magnitude = int(args.get("magnitude", 800))
dx, dy = 0, 0
if direction == "down":
dy = magnitude
elif direction == "up":
dy = -magnitude
elif direction == "right":
dx = magnitude
elif direction == "left":
dx = -magnitude
action = {"type": "scroll", "scroll_x": dx, "scroll_y": dy, "x": x, "y": y}
elif name == "drag_and_drop":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
dx = _denormalize(int(args.get("destination_x", x)), screen_w)
dy = _denormalize(int(args.get("destination_y", y)), screen_h)
action = {"type": "drag", "start_x": x, "start_y": y, "end_x": dx, "end_y": dy, "button": "left"}
elif name == "wait_5_seconds":
action = {"type": "wait"}
else:
# Unsupported / excluded browser-specific or custom function; ignore
return None
return {
"type": "computer_call",
"call_id": uuid.uuid4().hex,
"status": "completed",
"action": action,
}
@register_agent(models=r"^gemini-2\.5-computer-use-preview-10-2025$")
class GeminiComputerUseConfig(AsyncAgentConfig):
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
) -> Dict[str, Any]:
genai, types = _lazy_import_genai()
client = genai.Client()
# Build excluded predefined functions for browser-specific behavior
excluded = [
"open_web_browser",
"search",
"navigate",
"go_forward",
"go_back",
"scroll_document",
]
# Optional custom functions: can be extended by host code via `tools` parameter later if desired
CUSTOM_FUNCTION_DECLARATIONS: List[Any] = []
# Compose tools config
generate_content_config = types.GenerateContentConfig(
tools=[
types.Tool(
computer_use=types.ComputerUse(
environment=types.Environment.ENVIRONMENT_BROWSER,
excluded_predefined_functions=excluded,
)
),
# types.Tool(function_declarations=CUSTOM_FUNCTION_DECLARATIONS), # enable when custom functions needed
]
)
# Prepare contents: last user text + latest screenshot
user_texts = _find_last_user_text(messages)
screenshot_bytes = _find_last_screenshot(messages)
parts: List[Any] = []
for t in user_texts:
parts.append(types.Part(text=t))
screen_w, screen_h = 1024, 768
if screenshot_bytes:
screen_w, screen_h = _bytes_image_size(screenshot_bytes)
parts.append(types.Part.from_bytes(data=screenshot_bytes, mime_type="image/png"))
# If we don't have any content, at least pass an empty user part to prompt reasoning
if not parts:
parts = [types.Part(text="Proceed to the next action.")]
contents = [types.Content(role="user", parts=parts)]
api_kwargs = {
"model": model,
"contents": contents,
"config": generate_content_config,
}
if _on_api_start:
await _on_api_start({
"model": api_kwargs["model"],
# "contents": api_kwargs["contents"], # Disabled for now
"config": api_kwargs["config"],
})
response = client.models.generate_content(**api_kwargs)
if _on_api_end:
await _on_api_end({
"model": api_kwargs["model"],
# "contents": api_kwargs["contents"], # Disabled for now
"config": api_kwargs["config"],
}, response)
# Usage (Gemini SDK may not always provide token usage; populate when available)
usage: Dict[str, Any] = {}
try:
# Some SDKs expose response.usage; if available, copy
if getattr(response, "usage_metadata", None):
md = response.usage_metadata
usage = {
"prompt_tokens": getattr(md, "prompt_token_count", None) or 0,
"completion_tokens": getattr(md, "candidates_token_count", None) or 0,
"total_tokens": getattr(md, "total_token_count", None) or 0,
}
except Exception:
pass
if _on_usage and usage:
await _on_usage(usage)
# Parse output into internal items
output_items: List[Dict[str, Any]] = []
candidate = response.candidates[0]
# Text parts from the model (assistant message)
text_parts: List[str] = []
function_calls: List[Dict[str, Any]] = []
for p in candidate.content.parts:
if getattr(p, "text", None):
text_parts.append(p.text)
if getattr(p, "function_call", None):
# p.function_call has name and args
fc = {
"name": getattr(p.function_call, "name", None),
"args": dict(getattr(p.function_call, "args", {}) or {}),
}
function_calls.append(fc)
if text_parts:
output_items.append(
{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "\n".join(text_parts)}],
}
)
# Map function calls to internal computer_call actions
for fc in function_calls:
item = _map_gemini_fc_to_computer_call(fc, screen_w, screen_h)
if item is not None:
output_items.append(item)
return {"output": output_items, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs,
) -> Optional[Tuple[float, float]]:
"""Ask Gemini CUA to output a single click action for the given instruction.
Excludes all predefined tools except `click_at` and sends the screenshot.
Returns pixel (x, y) if a click is proposed, else None.
"""
genai, types = _lazy_import_genai()
client = genai.Client()
# Exclude all but click_at
exclude_all_but_click = [
"open_web_browser",
"wait_5_seconds",
"go_back",
"go_forward",
"search",
"navigate",
"hover_at",
"type_text_at",
"key_combination",
"scroll_document",
"scroll_at",
"drag_and_drop",
]
config = types.GenerateContentConfig(
tools=[
types.Tool(
computer_use=types.ComputerUse(
environment=types.Environment.ENVIRONMENT_BROWSER,
excluded_predefined_functions=exclude_all_but_click,
)
)
]
)
# Prepare prompt parts
try:
img_bytes = base64.b64decode(image_b64)
except Exception:
img_bytes = b""
w, h = _bytes_image_size(img_bytes) if img_bytes else (1024, 768)
parts: List[Any] = [types.Part(text=f"Click {instruction}.")]
if img_bytes:
parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))
contents = [types.Content(role="user", parts=parts)]
response = client.models.generate_content(
model=model,
contents=contents,
config=config,
)
# Parse first click_at
try:
candidate = response.candidates[0]
for p in candidate.content.parts:
fc = getattr(p, "function_call", None)
if fc and getattr(fc, "name", None) == "click_at":
args = dict(getattr(fc, "args", {}) or {})
x = _denormalize(int(args.get("x", 0)), w)
y = _denormalize(int(args.get("y", 0)), h)
return float(x), float(y)
except Exception:
return None
return None
def get_capabilities(self) -> List[AgentCapability]:
return ["click", "step"]