mirror of
https://github.com/trycua/computer.git
synced 2025-12-31 18:40:04 -06:00
186 lines
6.9 KiB
Python
186 lines
6.9 KiB
Python
"""
|
|
Custom computer handler implementation that accepts a dictionary of functions.
|
|
"""
|
|
|
|
import base64
|
|
from typing import Dict, List, Any, Literal, Union, Optional, Callable
|
|
from PIL import Image
|
|
import io
|
|
from .base import ComputerHandler
|
|
|
|
|
|
class CustomComputerHandler(ComputerHandler):
|
|
"""Computer handler that implements the Computer protocol using a dictionary of custom functions."""
|
|
|
|
def __init__(self, functions: Dict[str, Callable]):
|
|
"""
|
|
Initialize with a dictionary of functions.
|
|
|
|
Args:
|
|
functions: Dictionary where keys are method names and values are callable functions.
|
|
Only 'screenshot' is required, all others are optional.
|
|
|
|
Raises:
|
|
ValueError: If required 'screenshot' function is not provided.
|
|
"""
|
|
if 'screenshot' not in functions:
|
|
raise ValueError("'screenshot' function is required in functions dictionary")
|
|
|
|
self.functions = functions
|
|
self._last_screenshot_size: Optional[tuple[int, int]] = None
|
|
|
|
async def _get_value(self, attribute: str):
|
|
"""
|
|
Get value for an attribute, checking both 'get_{attribute}' and '{attribute}' keys.
|
|
|
|
Args:
|
|
attribute: The attribute name to look for
|
|
|
|
Returns:
|
|
The value from the functions dict, called if callable, returned directly if not
|
|
"""
|
|
# Check for 'get_{attribute}' first
|
|
get_key = f"get_{attribute}"
|
|
if get_key in self.functions:
|
|
value = self.functions[get_key]
|
|
return await value() if callable(value) else value
|
|
|
|
# Check for '{attribute}'
|
|
if attribute in self.functions:
|
|
value = self.functions[attribute]
|
|
return await value() if callable(value) else value
|
|
|
|
return None
|
|
|
|
def _to_b64_str(self, img: Union[bytes, Image.Image, str]) -> str:
|
|
"""
|
|
Convert image to base64 string.
|
|
|
|
Args:
|
|
img: Image as bytes, PIL Image, or base64 string
|
|
|
|
Returns:
|
|
str: Base64 encoded image string
|
|
"""
|
|
if isinstance(img, str):
|
|
# Already a base64 string
|
|
return img
|
|
elif isinstance(img, bytes):
|
|
# Raw bytes
|
|
return base64.b64encode(img).decode('utf-8')
|
|
elif isinstance(img, Image.Image):
|
|
# PIL Image
|
|
buffer = io.BytesIO()
|
|
img.save(buffer, format='PNG')
|
|
return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
|
else:
|
|
raise ValueError(f"Unsupported image type: {type(img)}")
|
|
|
|
# ==== Computer-Use-Preview Action Space ====
|
|
|
|
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
|
|
"""Get the current environment type."""
|
|
result = await self._get_value('environment')
|
|
return result if result is not None else "linux"
|
|
|
|
async def get_dimensions(self) -> tuple[int, int]:
|
|
"""Get screen dimensions as (width, height)."""
|
|
result = await self._get_value('dimensions')
|
|
if result is not None:
|
|
return result
|
|
|
|
# Fallback: use last screenshot size if available
|
|
if not self._last_screenshot_size:
|
|
await self.screenshot()
|
|
assert self._last_screenshot_size is not None, "Failed to get screenshot size"
|
|
|
|
return self._last_screenshot_size
|
|
|
|
async def screenshot(self) -> str:
|
|
"""Take a screenshot and return as base64 string."""
|
|
result = await self.functions['screenshot']()
|
|
b64_str = self._to_b64_str(result)
|
|
|
|
# Try to extract dimensions for fallback use
|
|
try:
|
|
if isinstance(result, Image.Image):
|
|
self._last_screenshot_size = result.size
|
|
elif isinstance(result, bytes):
|
|
# Try to decode bytes to get dimensions
|
|
img = Image.open(io.BytesIO(result))
|
|
self._last_screenshot_size = img.size
|
|
except Exception:
|
|
# If we can't get dimensions, that's okay
|
|
pass
|
|
|
|
return b64_str
|
|
|
|
async def click(self, x: int, y: int, button: str = "left") -> None:
|
|
"""Click at coordinates with specified button."""
|
|
if 'click' in self.functions:
|
|
await self.functions['click'](x, y, button)
|
|
# No-op if not implemented
|
|
|
|
async def double_click(self, x: int, y: int) -> None:
|
|
"""Double click at coordinates."""
|
|
if 'double_click' in self.functions:
|
|
await self.functions['double_click'](x, y)
|
|
# No-op if not implemented
|
|
|
|
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
|
|
"""Scroll at coordinates with specified scroll amounts."""
|
|
if 'scroll' in self.functions:
|
|
await self.functions['scroll'](x, y, scroll_x, scroll_y)
|
|
# No-op if not implemented
|
|
|
|
async def type(self, text: str) -> None:
|
|
"""Type text."""
|
|
if 'type' in self.functions:
|
|
await self.functions['type'](text)
|
|
# No-op if not implemented
|
|
|
|
async def wait(self, ms: int = 1000) -> None:
|
|
"""Wait for specified milliseconds."""
|
|
if 'wait' in self.functions:
|
|
await self.functions['wait'](ms)
|
|
else:
|
|
# Default implementation
|
|
import asyncio
|
|
await asyncio.sleep(ms / 1000.0)
|
|
|
|
async def move(self, x: int, y: int) -> None:
|
|
"""Move cursor to coordinates."""
|
|
if 'move' in self.functions:
|
|
await self.functions['move'](x, y)
|
|
# No-op if not implemented
|
|
|
|
async def keypress(self, keys: Union[List[str], str]) -> None:
|
|
"""Press key combination."""
|
|
if 'keypress' in self.functions:
|
|
await self.functions['keypress'](keys)
|
|
# No-op if not implemented
|
|
|
|
async def drag(self, path: List[Dict[str, int]]) -> None:
|
|
"""Drag along specified path."""
|
|
if 'drag' in self.functions:
|
|
await self.functions['drag'](path)
|
|
# No-op if not implemented
|
|
|
|
async def get_current_url(self) -> str:
|
|
"""Get current URL (for browser environments)."""
|
|
if 'get_current_url' in self.functions:
|
|
return await self.functions['get_current_url']()
|
|
return "" # Default fallback
|
|
|
|
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
|
"""Left mouse down at coordinates."""
|
|
if 'left_mouse_down' in self.functions:
|
|
await self.functions['left_mouse_down'](x, y)
|
|
# No-op if not implemented
|
|
|
|
async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
|
"""Left mouse up at coordinates."""
|
|
if 'left_mouse_up' in self.functions:
|
|
await self.functions['left_mouse_up'](x, y)
|
|
# No-op if not implemented
|