Format codebase with uv run pre-commit run --all-files

This commit is contained in:
James Murdza
2025-10-22 11:11:02 -07:00
parent 759ff4703e
commit ddc5a5de91
234 changed files with 10127 additions and 8467 deletions

View File

@@ -8,10 +8,11 @@
</picture>
</div>
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer-server?color=333333)](https://pypi.org/project/cua-computer-server/)
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer-server?color=333333)](https://pypi.org/project/cua-computer-server/)
</h1>
</div>
@@ -42,4 +43,4 @@ Refer to this notebook for a step-by-step guide on how to use the Computer-Use S
- [Commands](https://trycua.com/docs/libraries/computer-server/Commands)
- [REST-API](https://trycua.com/docs/libraries/computer-server/REST-API)
- [WebSocket-API](https://trycua.com/docs/libraries/computer-server/WebSocket-API)
- [Index](https://trycua.com/docs/libraries/computer-server/index)
- [Index](https://trycua.com/docs/libraries/computer-server/index)

View File

@@ -4,6 +4,7 @@ This allows the server to be started with `python -m computer_server`.
"""
import sys
from .cli import main
if __name__ == "__main__":

View File

@@ -36,7 +36,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
help="Path to SSL private key file (enables HTTPS)",
)
parser.add_argument(
"--ssl-certfile",
"--ssl-certfile",
type=str,
help="Path to SSL certificate file (enables HTTPS)",
)
@@ -73,16 +73,18 @@ def main() -> None:
# Check if watchdog should be enabled
container_name = os.environ.get("CONTAINER_NAME")
enable_watchdog = (args.watchdog or bool(container_name)) and not sys.platform.startswith("win")
if container_name:
logger.info(f"Container environment detected (CONTAINER_NAME={container_name}), enabling watchdog")
logger.info(
f"Container environment detected (CONTAINER_NAME={container_name}), enabling watchdog"
)
elif args.watchdog:
logger.info("Watchdog explicitly enabled via --watchdog flag")
# Start watchdog if enabled
if enable_watchdog:
logger.info(f"Starting watchdog monitoring with {args.watchdog_interval}s interval")
def run_watchdog_thread():
"""Run watchdog in a separate thread."""
loop = asyncio.new_event_loop()
@@ -90,38 +92,32 @@ def main() -> None:
try:
# Create CLI args dict for watchdog
cli_args = {
'host': args.host,
'port': args.port,
'log_level': args.log_level,
'ssl_keyfile': args.ssl_keyfile,
'ssl_certfile': args.ssl_certfile
"host": args.host,
"port": args.port,
"log_level": args.log_level,
"ssl_keyfile": args.ssl_keyfile,
"ssl_certfile": args.ssl_certfile,
}
# Create watchdog with restart settings
from .watchdog import Watchdog
watchdog = Watchdog(
cli_args=cli_args,
ping_interval=args.watchdog_interval
)
watchdog = Watchdog(cli_args=cli_args, ping_interval=args.watchdog_interval)
watchdog.restart_enabled = not args.no_restart
loop.run_until_complete(watchdog.start_monitoring())
except Exception as e:
logger.error(f"Watchdog error: {e}")
finally:
loop.close()
# Start watchdog in background thread
watchdog_thread = threading.Thread(
target=run_watchdog_thread,
daemon=True,
name="watchdog"
)
watchdog_thread = threading.Thread(target=run_watchdog_thread, daemon=True, name="watchdog")
watchdog_thread.start()
# Create and start the server
logger.info(f"Starting CUA Computer API server on {args.host}:{args.port}...")
# Handle SSL configuration
ssl_args = {}
if args.ssl_keyfile and args.ssl_certfile:
@@ -131,10 +127,12 @@ def main() -> None:
}
logger.info("HTTPS mode enabled with SSL certificates")
elif args.ssl_keyfile or args.ssl_certfile:
logger.warning("Both --ssl-keyfile and --ssl-certfile are required for HTTPS. Running in HTTP mode.")
logger.warning(
"Both --ssl-keyfile and --ssl-certfile are required for HTTPS. Running in HTTP mode."
)
else:
logger.info("HTTP mode (no SSL certificates provided)")
server = Server(host=args.host, port=args.port, log_level=args.log_level, **ssl_args)
try:

View File

@@ -1,4 +1,5 @@
class BaseDioramaHandler:
"""Base Diorama handler for unsupported OSes."""
async def diorama_cmd(self, action: str, arguments: dict = None) -> dict:
return {"success": False, "error": "Diorama is not supported on this OS yet."}

View File

@@ -1,31 +1,38 @@
#!/usr/bin/env python3
"""Diorama: A virtual desktop manager for macOS"""
import os
import asyncio
import logging
import sys
import io
import logging
import os
import sys
from typing import Union
from PIL import Image, ImageDraw
from computer_server.diorama.draw import capture_all_apps, AppActivationContext, get_frontmost_and_active_app, get_all_windows, get_running_apps
from computer_server.diorama.diorama_computer import DioramaComputer
from computer_server.diorama.draw import (
AppActivationContext,
capture_all_apps,
get_all_windows,
get_frontmost_and_active_app,
get_running_apps,
)
from computer_server.handlers.macos import *
from PIL import Image, ImageDraw
# simple, nicely formatted logging
logger = logging.getLogger(__name__)
automation_handler = MacOSAutomationHandler()
class Diorama:
"""Virtual desktop manager that provides automation capabilities for macOS applications.
Manages application windows and provides an interface for taking screenshots,
mouse interactions, keyboard input, and coordinate transformations between
screenshot space and screen space.
"""
_scheduler_queue = None
_scheduler_task = None
_loop = None
@@ -34,10 +41,10 @@ class Diorama:
@classmethod
def create_from_apps(cls, *args) -> DioramaComputer:
"""Create a DioramaComputer instance from a list of application names.
Args:
*args: Variable number of application names to include in the desktop
Returns:
DioramaComputer: A computer interface for the specified applications
"""
@@ -46,10 +53,10 @@ class Diorama:
# Dictionary to store cursor positions for each unique app_list hash
_cursor_positions = {}
def __init__(self, app_list):
"""Initialize a Diorama instance for the specified applications.
Args:
app_list: List of application names to manage
"""
@@ -57,10 +64,10 @@ class Diorama:
self.interface = self.Interface(self)
self.computer = DioramaComputer(self)
self.focus_context = None
# Create a hash for this app_list to use as a key
self.app_list_hash = hash(tuple(sorted(app_list)))
# Initialize cursor position for this app_list if it doesn't exist
if self.app_list_hash not in Diorama._cursor_positions:
Diorama._cursor_positions[self.app_list_hash] = (0, 0)
@@ -68,7 +75,7 @@ class Diorama:
@classmethod
def _ensure_scheduler(cls):
"""Ensure the async scheduler loop is running.
Creates and starts the scheduler task if it hasn't been started yet.
"""
if not cls._scheduler_started:
@@ -81,7 +88,7 @@ class Diorama:
@classmethod
async def _scheduler_loop(cls):
"""Main scheduler loop that processes automation commands.
Continuously processes commands from the scheduler queue, handling
screenshots, mouse actions, keyboard input, and scrolling operations.
"""
@@ -91,31 +98,37 @@ class Diorama:
args = cmd.get("arguments", {})
future = cmd.get("future")
logger.info(f"Processing command: {action} | args={args}")
app_whitelist = args.get("app_list", [])
all_windows = get_all_windows()
running_apps = get_running_apps()
frontmost_app, active_app_to_use, active_app_pid = get_frontmost_and_active_app(all_windows, running_apps, app_whitelist)
frontmost_app, active_app_to_use, active_app_pid = get_frontmost_and_active_app(
all_windows, running_apps, app_whitelist
)
focus_context = AppActivationContext(active_app_pid, active_app_to_use, logger)
with focus_context:
try:
if action == "screenshot":
logger.info(f"Taking screenshot for apps: {app_whitelist}")
result, img = capture_all_apps(
app_whitelist=app_whitelist,
save_to_disk=False,
take_focus=False
app_whitelist=app_whitelist, save_to_disk=False, take_focus=False
)
logger.info("Screenshot complete.")
if future:
future.set_result((result, img))
# Mouse actions
elif action in ["left_click", "right_click", "double_click", "move_cursor", "drag_to"]:
elif action in [
"left_click",
"right_click",
"double_click",
"move_cursor",
"drag_to",
]:
x = args.get("x")
y = args.get("y")
duration = args.get("duration", 0.5)
if action == "left_click":
await automation_handler.left_click(x, y)
@@ -134,7 +147,7 @@ class Diorama:
y = args.get("y")
if x is not None and y is not None:
await automation_handler.move_cursor(x, y)
clicks = args.get("clicks", 1)
if action == "scroll_up":
await automation_handler.scroll_up(clicks)
@@ -171,31 +184,31 @@ class Diorama:
if future:
future.set_exception(e)
class Interface():
class Interface:
"""Interface for interacting with the virtual desktop.
Provides methods for taking screenshots, mouse interactions, keyboard input,
and coordinate transformations between screenshot and screen coordinates.
"""
def __init__(self, diorama):
"""Initialize the interface with a reference to the parent Diorama instance.
Args:
diorama: The parent Diorama instance
"""
self._diorama = diorama
self._scene_hitboxes = []
self._scene_size = None
async def _send_cmd(self, action, arguments=None):
"""Send a command to the scheduler queue.
Args:
action (str): The action to perform
arguments (dict, optional): Arguments for the action
Returns:
The result of the command execution
"""
@@ -203,11 +216,13 @@ class Diorama:
loop = asyncio.get_event_loop()
future = loop.create_future()
logger.info(f"Enqueuing {action} command for apps: {self._diorama.app_list}")
await Diorama._scheduler_queue.put({
"action": action,
"arguments": {"app_list": self._diorama.app_list, **(arguments or {})},
"future": future
})
await Diorama._scheduler_queue.put(
{
"action": action,
"arguments": {"app_list": self._diorama.app_list, **(arguments or {})},
"future": future,
}
)
try:
return await future
except asyncio.CancelledError:
@@ -216,21 +231,23 @@ class Diorama:
async def screenshot(self, as_bytes: bool = True) -> Union[str, Image.Image]:
"""Take a screenshot of the managed applications.
Args:
as_bytes (bool): If True, return base64-encoded bytes; if False, return PIL Image
Returns:
Union[str, Image.Image]: Base64-encoded PNG bytes or PIL Image object
"""
import base64
result, img = await self._send_cmd("screenshot")
self._scene_hitboxes = result.get("hitboxes", [])
self._scene_size = img.size
if as_bytes:
# PIL Image to bytes, then base64 encode for JSON
import io
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_bytes = img_byte_arr.getvalue()
@@ -241,7 +258,7 @@ class Diorama:
async def left_click(self, x, y):
"""Perform a left mouse click at the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -258,7 +275,7 @@ class Diorama:
async def right_click(self, x, y):
"""Perform a right mouse click at the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -269,13 +286,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)
sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("right_click", {"x": sx, "y": sy})
async def double_click(self, x, y):
"""Perform a double mouse click at the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -286,13 +303,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)
sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("double_click", {"x": sx, "y": sy})
async def move_cursor(self, x, y):
"""Move the mouse cursor to the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -303,13 +320,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)
sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("move_cursor", {"x": sx, "y": sy})
async def drag_to(self, x, y, duration=0.5):
"""Drag the mouse from current position to the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -321,13 +338,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)
sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("drag_to", {"x": sx, "y": sy, "duration": duration})
async def get_cursor_position(self):
"""Get the current cursor position in screen coordinates.
Returns:
tuple: (x, y) coordinates of the cursor in screen space
"""
@@ -335,7 +352,7 @@ class Diorama:
async def type_text(self, text):
"""Type the specified text using the keyboard.
Args:
text (str): The text to type
"""
@@ -343,7 +360,7 @@ class Diorama:
async def press_key(self, key):
"""Press a single key on the keyboard.
Args:
key (str): The key to press
"""
@@ -351,7 +368,7 @@ class Diorama:
async def hotkey(self, keys):
"""Press a combination of keys simultaneously.
Args:
keys (list): List of keys to press together
"""
@@ -359,7 +376,7 @@ class Diorama:
async def scroll_up(self, clicks: int = 1):
"""Scroll up at the current cursor position.
Args:
clicks (int): Number of scroll clicks to perform
"""
@@ -367,12 +384,12 @@ class Diorama:
app_list_hash = hash(tuple(sorted(self._diorama.app_list)))
last_pos = Diorama._cursor_positions.get(app_list_hash, (0, 0))
x, y = last_pos[0], last_pos[1]
await self._send_cmd("scroll_up", {"clicks": clicks, "x": x, "y": y})
async def scroll_down(self, clicks: int = 1):
"""Scroll down at the current cursor position.
Args:
clicks (int): Number of scroll clicks to perform
"""
@@ -380,18 +397,18 @@ class Diorama:
app_list_hash = hash(tuple(sorted(self._diorama.app_list)))
last_pos = Diorama._cursor_positions.get(app_list_hash, (0, 0))
x, y = last_pos[0], last_pos[1]
await self._send_cmd("scroll_down", {"clicks": clicks, "x": x, "y": y})
async def get_screen_size(self) -> dict[str, int]:
"""Get the size of the screenshot area.
Returns:
dict[str, int]: Dictionary with 'width' and 'height' keys
"""
if not self._scene_size:
await self.screenshot()
return { "width": self._scene_size[0], "height": self._scene_size[1] }
return {"width": self._scene_size[0], "height": self._scene_size[1]}
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.
@@ -404,29 +421,29 @@ class Diorama:
tuple[float, float]: (x, y) absolute coordinates in screen space
"""
if not self._scene_hitboxes:
await self.screenshot() # get hitboxes
await self.screenshot() # get hitboxes
# Try all hitboxes
for h in self._scene_hitboxes[::-1]:
rect_from = h.get("hitbox")
rect_to = h.get("target")
if not rect_from or len(rect_from) != 4:
continue
# check if (x, y) is inside rect_from
x0, y0, x1, y1 = rect_from
if x0 <= x <= x1 and y0 <= y <= y1:
logger.info(f"Found hitbox: {h}")
# remap (x, y) to rect_to
tx0, ty0, tx1, ty1 = rect_to
# calculate offset from x0, y0
offset_x = x - x0
offset_y = y - y0
# remap offset to rect_to
tx = tx0 + offset_x
ty = ty0 + offset_y
return tx, ty
return x, y
@@ -441,34 +458,37 @@ class Diorama:
tuple[float, float]: (x, y) absolute coordinates in screenshot space
"""
if not self._scene_hitboxes:
await self.screenshot() # get hitboxes
await self.screenshot() # get hitboxes
# Try all hitboxes
for h in self._scene_hitboxes[::-1]:
rect_from = h.get("target")
rect_to = h.get("hitbox")
if not rect_from or len(rect_from) != 4:
continue
# check if (x, y) is inside rect_from
x0, y0, x1, y1 = rect_from
if x0 <= x <= x1 and y0 <= y <= y1:
# remap (x, y) to rect_to
tx0, ty0, tx1, ty1 = rect_to
# calculate offset from x0, y0
offset_x = x - x0
offset_y = y - y0
# remap offset to rect_to
tx = tx0 + offset_x
ty = ty0 + offset_y
return tx, ty
return x, y
import pyautogui
import time
import pyautogui
async def main():
"""Main function demonstrating Diorama usage with multiple desktops and mouse tracking."""
desktop1 = Diorama.create_from_apps(["Discord", "Notes"])
@@ -511,7 +531,7 @@ async def main():
# Draw on a copy of the screenshot
frame = base_img.copy()
frame_draw = ImageDraw.Draw(frame)
frame_draw.ellipse((sx-5, sy-5, sx+5, sy+5), fill="blue", outline="blue")
frame_draw.ellipse((sx - 5, sy - 5, sx + 5, sy + 5), fill="blue", outline="blue")
# Save the frame
frame.save("app_screenshots/desktop3_mouse.png")
print(f"Mouse at screen ({mouse_x}, {mouse_y}) -> screenshot ({sx:.1f}, {sy:.1f})")
@@ -520,15 +540,13 @@ async def main():
print("Stopped tracking.")
draw.text((rect[0], rect[1]), str(idx), fill="red")
canvas.save("app_screenshots/desktop3_hitboxes.png")
# move mouse in a square spiral around the screen
import math
import random
step = 20 # pixels per move
dot_radius = 10
width = screen_size["width"]
@@ -539,11 +557,12 @@ async def main():
await desktop3.interface.move_cursor(x, y)
img = await desktop3.interface.screenshot(as_bytes=False)
draw = ImageDraw.Draw(img)
draw.ellipse((x-dot_radius, y-dot_radius, x+dot_radius, y+dot_radius), fill="red")
draw.ellipse((x - dot_radius, y - dot_radius, x + dot_radius, y + dot_radius), fill="red")
img.save("current.png")
await asyncio.sleep(0.03)
x += step
y = math.sin(x / width * math.pi * 2) * 50 + 25
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,14 +1,16 @@
import asyncio
class DioramaComputer:
"""
A minimal Computer-like interface for Diorama, compatible with ComputerAgent.
Implements _initialized, run(), and __aenter__ for agent compatibility.
"""
def __init__(self, diorama):
"""
Initialize the DioramaComputer with a diorama instance.
Args:
diorama: The diorama instance to wrap with a computer-like interface.
"""
@@ -19,10 +21,10 @@ class DioramaComputer:
async def __aenter__(self):
"""
Async context manager entry method for compatibility with ComputerAgent.
Ensures an event loop is running and marks the instance as initialized.
Creates a new event loop if none is currently running.
Returns:
DioramaComputer: The initialized instance.
"""
@@ -37,10 +39,10 @@ class DioramaComputer:
async def run(self):
"""
Run method stub for compatibility with ComputerAgent interface.
Ensures the instance is initialized before returning. If not already
initialized, calls __aenter__ to perform initialization.
Returns:
DioramaComputer: The initialized instance.
"""

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,15 @@
import inspect
import platform
import sys
import platform
import inspect
from computer_server.diorama.diorama import Diorama
from computer_server.diorama.base import BaseDioramaHandler
from typing import Optional
from computer_server.diorama.base import BaseDioramaHandler
from computer_server.diorama.diorama import Diorama
class MacOSDioramaHandler(BaseDioramaHandler):
"""Handler for Diorama commands on macOS, using local diorama module."""
async def diorama_cmd(self, action: str, arguments: Optional[dict] = None) -> dict:
if platform.system().lower() != "darwin":
return {"success": False, "error": "Diorama is only supported on macOS."}
@@ -30,4 +32,5 @@ class MacOSDioramaHandler(BaseDioramaHandler):
return {"success": True, "result": result}
except Exception as e:
import traceback
return {"success": False, "error": str(e), "trace": traceback.format_exc()}

View File

@@ -8,31 +8,31 @@ like the menubar and dock, which are needed for proper screenshot composition.
import sys
import time
from typing import Dict, Any, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
# Import Objective-C bridge libraries
try:
import AppKit
import Foundation
from AppKit import NSRunningApplication, NSWorkspace
from ApplicationServices import (
AXUIElementCreateSystemWide,
AXUIElementCreateApplication,
AXUIElementCopyAttributeValue,
AXUIElementCopyAttributeValues,
kAXChildrenAttribute,
kAXRoleAttribute,
kAXTitleAttribute,
kAXPositionAttribute,
kAXSizeAttribute,
kAXErrorSuccess,
AXValueGetType,
kAXValueCGSizeType,
kAXValueCGPointType,
AXUIElementCreateApplication,
AXUIElementCreateSystemWide,
AXUIElementGetTypeID,
AXValueGetType,
AXValueGetValue,
kAXChildrenAttribute,
kAXErrorSuccess,
kAXMenuBarAttribute,
kAXPositionAttribute,
kAXRoleAttribute,
kAXSizeAttribute,
kAXTitleAttribute,
kAXValueCGPointType,
kAXValueCGSizeType,
)
from AppKit import NSWorkspace, NSRunningApplication
import Foundation
except ImportError:
print("Error: This script requires PyObjC to be installed.")
print("Please install it with: pip install pyobjc")
@@ -74,13 +74,8 @@ def element_value(element, type):
def get_element_bounds(element):
"""Get the bounds of an accessibility element"""
bounds = {
"x": 0,
"y": 0,
"width": 0,
"height": 0
}
bounds = {"x": 0, "y": 0, "width": 0, "height": 0}
# Get position
position_value = element_attribute(element, kAXPositionAttribute)
if position_value:
@@ -88,7 +83,7 @@ def get_element_bounds(element):
if position_value:
bounds["x"] = position_value.x
bounds["y"] = position_value.y
# Get size
size_value = element_attribute(element, kAXSizeAttribute)
if size_value:
@@ -96,7 +91,7 @@ def get_element_bounds(element):
if size_value:
bounds["width"] = size_value.width
bounds["height"] = size_value.height
return bounds
@@ -111,13 +106,13 @@ def find_dock_process():
def get_menubar_bounds():
"""Get the bounds of the macOS menubar
Returns:
Dictionary with x, y, width, height of the menubar
"""
# Get the system-wide accessibility element
system_element = AXUIElementCreateSystemWide()
# Try to find the menubar
menubar = element_attribute(system_element, kAXMenuBarAttribute)
if menubar is None:
@@ -127,19 +122,19 @@ def get_menubar_bounds():
app_pid = frontmost_app.processIdentifier()
app_element = AXUIElementCreateApplication(app_pid)
menubar = element_attribute(app_element, kAXMenuBarAttribute)
if menubar is None:
print("Error: Could not get menubar")
# Return default menubar bounds as fallback
return {"x": 0, "y": 0, "width": 1800, "height": 24}
# Get menubar bounds
return get_element_bounds(menubar)
def get_dock_bounds():
"""Get the bounds of the macOS Dock
Returns:
Dictionary with x, y, width, height of the Dock
"""
@@ -148,19 +143,19 @@ def get_dock_bounds():
print("Error: Could not find Dock process")
# Return empty bounds as fallback
return {"x": 0, "y": 0, "width": 0, "height": 0}
# Create an accessibility element for the Dock
dock_element = AXUIElementCreateApplication(dock_pid)
if dock_element is None:
print(f"Error: Could not create accessibility element for Dock (PID {dock_pid})")
return {"x": 0, "y": 0, "width": 0, "height": 0}
# Get the Dock's children
children = element_attribute(dock_element, kAXChildrenAttribute)
if not children or len(children) == 0:
print("Error: Could not get Dock children")
return {"x": 0, "y": 0, "width": 0, "height": 0}
# Find the Dock's list (first child is usually the main dock list)
dock_list = None
for child in children:
@@ -168,28 +163,25 @@ def get_dock_bounds():
if role == "AXList":
dock_list = child
break
if dock_list is None:
print("Error: Could not find Dock list")
return {"x": 0, "y": 0, "width": 0, "height": 0}
# Get the bounds of the dock list
return get_element_bounds(dock_list)
def get_ui_element_bounds():
"""Get the bounds of important UI elements like menubar and dock
Returns:
Dictionary with menubar and dock bounds
"""
menubar_bounds = get_menubar_bounds()
dock_bounds = get_dock_bounds()
return {
"menubar": menubar_bounds,
"dock": dock_bounds
}
return {"menubar": menubar_bounds, "dock": dock_bounds}
if __name__ == "__main__":

View File

@@ -1,24 +1,26 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
class BaseAccessibilityHandler(ABC):
"""Abstract base class for OS-specific accessibility handlers."""
@abstractmethod
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current window."""
pass
@abstractmethod
async def find_element(self, role: Optional[str] = None,
title: Optional[str] = None,
value: Optional[str] = None) -> Dict[str, Any]:
async def find_element(
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an element in the accessibility tree by criteria."""
pass
class BaseFileHandler(ABC):
"""Abstract base class for OS-specific file handlers."""
@abstractmethod
async def file_exists(self, path: str) -> Dict[str, Any]:
"""Check if a file exists at the specified path."""
@@ -43,7 +45,7 @@ class BaseFileHandler(ABC):
async def write_text(self, path: str, content: str) -> Dict[str, Any]:
"""Write text content to a file."""
pass
@abstractmethod
async def write_bytes(self, path: str, content_b64: str) -> Dict[str, Any]:
"""Write binary content to a file. Sent over the websocket as a base64 string."""
@@ -65,9 +67,11 @@ class BaseFileHandler(ABC):
pass
@abstractmethod
async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> Dict[str, Any]:
async def read_bytes(
self, path: str, offset: int = 0, length: Optional[int] = None
) -> Dict[str, Any]:
"""Read the binary contents of a file. Sent over the websocket as a base64 string.
Args:
path: Path to the file
offset: Byte offset to start reading from (default: 0)
@@ -80,9 +84,10 @@ class BaseFileHandler(ABC):
"""Get the size of a file in bytes."""
pass
class BaseAutomationHandler(ABC):
"""Abstract base class for OS-specific automation handlers.
Categories:
- Mouse Actions: Methods for mouse control
- Keyboard Actions: Methods for keyboard input
@@ -90,18 +95,22 @@ class BaseAutomationHandler(ABC):
- Screen Actions: Methods for screen interaction
- Clipboard Actions: Methods for clipboard operations
"""
# Mouse Actions
@abstractmethod
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse down at the current or specified position."""
pass
@abstractmethod
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse up at the current or specified position."""
pass
@abstractmethod
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left click at the current or specified position."""
@@ -113,7 +122,9 @@ class BaseAutomationHandler(ABC):
pass
@abstractmethod
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double click at the current or specified position."""
pass
@@ -123,9 +134,11 @@ class BaseAutomationHandler(ABC):
pass
@abstractmethod
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the cursor from current position to specified coordinates.
Args:
x: The x coordinate to drag to
y: The y coordinate to drag to
@@ -133,11 +146,13 @@ class BaseAutomationHandler(ABC):
duration: How long the drag should take in seconds
"""
pass
@abstractmethod
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the cursor from current position to specified coordinates.
Args:
path: A list of tuples of x and y coordinates to drag to
button: The mouse button to use ('left', 'middle', 'right')
@@ -150,12 +165,12 @@ class BaseAutomationHandler(ABC):
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold the specified key."""
pass
@abstractmethod
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release the specified key."""
pass
@abstractmethod
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text."""
@@ -176,7 +191,7 @@ class BaseAutomationHandler(ABC):
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll the specified amount."""
pass
@abstractmethod
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks."""
@@ -212,9 +227,9 @@ class BaseAutomationHandler(ABC):
@abstractmethod
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the clipboard content."""
pass
pass
@abstractmethod
async def run_command(self, command: str) -> Dict[str, Any]:
"""Run a command and return the output."""
pass
pass

View File

@@ -1,68 +1,89 @@
import platform
import subprocess
from typing import Tuple, Type
from .base import BaseAccessibilityHandler, BaseAutomationHandler, BaseFileHandler
from computer_server.diorama.base import BaseDioramaHandler
from .base import BaseAccessibilityHandler, BaseAutomationHandler, BaseFileHandler
# Conditionally import platform-specific handlers
system = platform.system().lower()
if system == 'darwin':
from .macos import MacOSAccessibilityHandler, MacOSAutomationHandler
if system == "darwin":
from computer_server.diorama.macos import MacOSDioramaHandler
elif system == 'linux':
from .macos import MacOSAccessibilityHandler, MacOSAutomationHandler
elif system == "linux":
from .linux import LinuxAccessibilityHandler, LinuxAutomationHandler
elif system == 'windows':
elif system == "windows":
from .windows import WindowsAccessibilityHandler, WindowsAutomationHandler
from .generic import GenericFileHandler
class HandlerFactory:
"""Factory for creating OS-specific handlers."""
@staticmethod
def _get_current_os() -> str:
"""Determine the current OS.
Returns:
str: The OS type ('darwin' for macOS, 'linux' for Linux, or 'windows' for Windows)
Raises:
RuntimeError: If unable to determine the current OS
"""
try:
# Use platform.system() as primary method
system = platform.system().lower()
if system in ['darwin', 'linux', 'windows']:
if system in ["darwin", "linux", "windows"]:
return system
# Fallback to uname if platform.system() doesn't return expected values (Unix-like systems only)
result = subprocess.run(['uname', '-s'], capture_output=True, text=True)
result = subprocess.run(["uname", "-s"], capture_output=True, text=True)
if result.returncode == 0:
return result.stdout.strip().lower()
raise RuntimeError(f"Unsupported OS: {system}")
except Exception as e:
raise RuntimeError(f"Failed to determine current OS: {str(e)}")
@staticmethod
def create_handlers() -> Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]:
def create_handlers() -> (
Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]
):
"""Create and return appropriate handlers for the current OS.
Returns:
Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]: A tuple containing
the appropriate accessibility, automation, diorama, and file handlers for the current OS.
Raises:
NotImplementedError: If the current OS is not supported
RuntimeError: If unable to determine the current OS
"""
os_type = HandlerFactory._get_current_os()
if os_type == 'darwin':
return MacOSAccessibilityHandler(), MacOSAutomationHandler(), MacOSDioramaHandler(), GenericFileHandler()
elif os_type == 'linux':
return LinuxAccessibilityHandler(), LinuxAutomationHandler(), BaseDioramaHandler(), GenericFileHandler()
elif os_type == 'windows':
return WindowsAccessibilityHandler(), WindowsAutomationHandler(), BaseDioramaHandler(), GenericFileHandler()
if os_type == "darwin":
return (
MacOSAccessibilityHandler(),
MacOSAutomationHandler(),
MacOSDioramaHandler(),
GenericFileHandler(),
)
elif os_type == "linux":
return (
LinuxAccessibilityHandler(),
LinuxAutomationHandler(),
BaseDioramaHandler(),
GenericFileHandler(),
)
elif os_type == "windows":
return (
WindowsAccessibilityHandler(),
WindowsAutomationHandler(),
BaseDioramaHandler(),
GenericFileHandler(),
)
else:
raise NotImplementedError(f"OS '{os_type}' is not supported")

View File

@@ -6,38 +6,41 @@ Includes:
"""
from pathlib import Path
from typing import Dict, Any, Optional
from .base import BaseFileHandler
import base64
from pathlib import Path
from typing import Any, Dict, Optional
from .base import BaseFileHandler
def resolve_path(path: str) -> Path:
"""Resolve a path to its absolute path. Expand ~ to the user's home directory.
Args:
path: The file or directory path to resolve
Returns:
Path: The resolved absolute path
"""
return Path(path).expanduser().resolve()
class GenericFileHandler(BaseFileHandler):
"""
Generic file handler that provides file system operations for all operating systems.
This class implements the BaseFileHandler interface and provides methods for
file and directory operations including reading, writing, creating, and deleting
files and directories.
"""
async def file_exists(self, path: str) -> Dict[str, Any]:
"""
Check if a file exists at the specified path.
Args:
path: The file path to check
Returns:
Dict containing 'success' boolean and either 'exists' boolean or 'error' string
"""
@@ -49,10 +52,10 @@ class GenericFileHandler(BaseFileHandler):
async def directory_exists(self, path: str) -> Dict[str, Any]:
"""
Check if a directory exists at the specified path.
Args:
path: The directory path to check
Returns:
Dict containing 'success' boolean and either 'exists' boolean or 'error' string
"""
@@ -64,25 +67,30 @@ class GenericFileHandler(BaseFileHandler):
async def list_dir(self, path: str) -> Dict[str, Any]:
"""
List all files and directories in the specified directory.
Args:
path: The directory path to list
Returns:
Dict containing 'success' boolean and either 'files' list of names or 'error' string
"""
try:
return {"success": True, "files": [p.name for p in resolve_path(path).iterdir() if p.is_file() or p.is_dir()]}
return {
"success": True,
"files": [
p.name for p in resolve_path(path).iterdir() if p.is_file() or p.is_dir()
],
}
except Exception as e:
return {"success": False, "error": str(e)}
async def read_text(self, path: str) -> Dict[str, Any]:
"""
Read the contents of a text file.
Args:
path: The file path to read from
Returns:
Dict containing 'success' boolean and either 'content' string or 'error' string
"""
@@ -94,11 +102,11 @@ class GenericFileHandler(BaseFileHandler):
async def write_text(self, path: str, content: str) -> Dict[str, Any]:
"""
Write text content to a file.
Args:
path: The file path to write to
content: The text content to write
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -108,60 +116,64 @@ class GenericFileHandler(BaseFileHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def write_bytes(self, path: str, content_b64: str, append: bool = False) -> Dict[str, Any]:
async def write_bytes(
self, path: str, content_b64: str, append: bool = False
) -> Dict[str, Any]:
"""
Write binary content to a file from base64 encoded string.
Args:
path: The file path to write to
content_b64: Base64 encoded binary content
append: If True, append to existing file; if False, overwrite
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
try:
mode = 'ab' if append else 'wb'
mode = "ab" if append else "wb"
with open(resolve_path(path), mode) as f:
f.write(base64.b64decode(content_b64))
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> Dict[str, Any]:
async def read_bytes(
self, path: str, offset: int = 0, length: Optional[int] = None
) -> Dict[str, Any]:
"""
Read binary content from a file and return as base64 encoded string.
Args:
path: The file path to read from
offset: Byte offset to start reading from
length: Number of bytes to read; if None, read entire file from offset
Returns:
Dict containing 'success' boolean and either 'content_b64' string or 'error' string
"""
try:
file_path = resolve_path(path)
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
if offset > 0:
f.seek(offset)
if length is not None:
content = f.read(length)
else:
content = f.read()
return {"success": True, "content_b64": base64.b64encode(content).decode('utf-8')}
return {"success": True, "content_b64": base64.b64encode(content).decode("utf-8")}
except Exception as e:
return {"success": False, "error": str(e)}
async def get_file_size(self, path: str) -> Dict[str, Any]:
"""
Get the size of a file in bytes.
Args:
path: The file path to get size for
Returns:
Dict containing 'success' boolean and either 'size' integer or 'error' string
"""
@@ -175,10 +187,10 @@ class GenericFileHandler(BaseFileHandler):
async def delete_file(self, path: str) -> Dict[str, Any]:
"""
Delete a file at the specified path.
Args:
path: The file path to delete
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -191,13 +203,13 @@ class GenericFileHandler(BaseFileHandler):
async def create_dir(self, path: str) -> Dict[str, Any]:
"""
Create a directory at the specified path.
Creates parent directories if they don't exist and doesn't raise an error
if the directory already exists.
Args:
path: The directory path to create
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -210,10 +222,10 @@ class GenericFileHandler(BaseFileHandler):
async def delete_dir(self, path: str) -> Dict[str, Any]:
"""
Delete an empty directory at the specified path.
Args:
path: The directory path to delete
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""

View File

@@ -7,14 +7,15 @@ To use GUI automation in a headless environment:
1. Install Xvfb: sudo apt-get install xvfb
2. Run with virtual display: xvfb-run python -m computer_server
"""
from typing import Dict, Any, List, Tuple, Optional
import logging
import subprocess
import asyncio
import base64
import os
import json
import logging
import os
import subprocess
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
# Configure logger
logger = logging.getLogger(__name__)
@@ -23,30 +24,36 @@ logger = logging.getLogger(__name__)
# This allows the server to run in headless environments
try:
import pyautogui
pyautogui.FAILSAFE = False
logger.info("pyautogui successfully imported, GUI automation available")
except Exception as e:
logger.warning(f"pyautogui import failed: {str(e)}. GUI operations will be simulated.")
from pynput.mouse import Button, Controller as MouseController
from pynput.keyboard import Key, Controller as KeyboardController
from pynput.keyboard import Controller as KeyboardController
from pynput.keyboard import Key
from pynput.mouse import Button
from pynput.mouse import Controller as MouseController
from .base import BaseAccessibilityHandler, BaseAutomationHandler
class LinuxAccessibilityHandler(BaseAccessibilityHandler):
"""Linux implementation of accessibility handler."""
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current window.
Returns:
Dict[str, Any]: A dictionary containing success status and a simulated tree structure
since Linux doesn't have equivalent accessibility API like macOS.
"""
# Linux doesn't have equivalent accessibility API like macOS
# Return a minimal dummy tree
logger.info("Getting accessibility tree (simulated, no accessibility API available on Linux)")
logger.info(
"Getting accessibility tree (simulated, no accessibility API available on Linux)"
)
return {
"success": True,
"tree": {
@@ -54,32 +61,31 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
"title": "Linux Window",
"position": {"x": 0, "y": 0},
"size": {"width": 1920, "height": 1080},
"children": []
}
"children": [],
},
}
async def find_element(self, role: Optional[str] = None,
title: Optional[str] = None,
value: Optional[str] = None) -> Dict[str, Any]:
async def find_element(
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an element in the accessibility tree by criteria.
Args:
role: The role of the element to find.
title: The title of the element to find.
value: The value of the element to find.
Returns:
Dict[str, Any]: A dictionary indicating that element search is not supported on Linux.
"""
logger.info(f"Finding element with role={role}, title={title}, value={value} (not supported on Linux)")
return {
"success": False,
"message": "Element search not supported on Linux"
}
logger.info(
f"Finding element with role={role}, title={title}, value={value} (not supported on Linux)"
)
return {"success": False, "message": "Element search not supported on Linux"}
def get_cursor_position(self) -> Tuple[int, int]:
"""Get the current cursor position.
Returns:
Tuple[int, int]: The x and y coordinates of the cursor position.
Returns (0, 0) if pyautogui is not available.
@@ -89,13 +95,13 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
return pos.x, pos.y
except Exception as e:
logger.warning(f"Failed to get cursor position with pyautogui: {e}")
logger.info("Getting cursor position (simulated)")
return 0, 0
def get_screen_size(self) -> Tuple[int, int]:
"""Get the screen size.
Returns:
Tuple[int, int]: The width and height of the screen in pixels.
Returns (1920, 1080) if pyautogui is not available.
@@ -105,24 +111,28 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
return size.width, size.height
except Exception as e:
logger.warning(f"Failed to get screen size with pyautogui: {e}")
logger.info("Getting screen size (simulated)")
return 1920, 1080
class LinuxAutomationHandler(BaseAutomationHandler):
"""Linux implementation of automation handler using pyautogui."""
keyboard = KeyboardController()
mouse = MouseController()
# Mouse Actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Press and hold a mouse button at the specified coordinates.
Args:
x: The x coordinate to move to before pressing. If None, uses current position.
y: The y coordinate to move to before pressing. If None, uses current position.
button: The mouse button to press ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -133,15 +143,17 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Release a mouse button at the specified coordinates.
Args:
x: The x coordinate to move to before releasing. If None, uses current position.
y: The y coordinate to move to before releasing. If None, uses current position.
button: The mouse button to release ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -152,14 +164,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
"""Move the cursor to the specified coordinates.
Args:
x: The x coordinate to move to.
y: The y coordinate to move to.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -171,11 +183,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left mouse click at the specified coordinates.
Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -189,11 +201,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a right mouse click at the specified coordinates.
Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -205,13 +217,15 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double click at the specified coordinates.
Args:
x: The x coordinate to double click at. If None, clicks at current position.
y: The y coordinate to double click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -223,14 +237,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def click(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def click(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse click with the specified button at the given coordinates.
Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.
button: The mouse button to click ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -242,15 +258,17 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag from the current position to the specified coordinates.
Args:
x: The x coordinate to drag to.
y: The y coordinate to drag to.
button: The mouse button to use for dragging.
duration: The time in seconds to take for the drag operation.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -260,16 +278,18 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def drag(self, start_x: int, start_y: int, end_x: int, end_y: int, button: str = "left") -> Dict[str, Any]:
async def drag(
self, start_x: int, start_y: int, end_x: int, end_y: int, button: str = "left"
) -> Dict[str, Any]:
"""Drag from start coordinates to end coordinates.
Args:
start_x: The starting x coordinate.
start_y: The starting y coordinate.
end_x: The ending x coordinate.
end_y: The ending y coordinate.
button: The mouse button to use for dragging.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -280,14 +300,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def drag_path(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_path(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag along a path defined by a list of coordinates.
Args:
path: A list of (x, y) coordinate tuples defining the drag path.
button: The mouse button to use for dragging.
duration: The time in seconds to take for each segment of the drag.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -304,10 +326,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Keyboard Actions
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold a key.
Args:
key: The key to press down.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -316,13 +338,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release a key.
Args:
key: The key to release.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -331,13 +353,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text using the keyboard.
Args:
text: The text to type.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -350,10 +372,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def press_key(self, key: str) -> Dict[str, Any]:
"""Press and release a key.
Args:
key: The key to press.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -365,10 +387,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
"""Press a combination of keys simultaneously.
Args:
keys: A list of keys to press together as a hotkey combination.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -381,11 +403,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Scrolling Actions
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll the mouse wheel.
Args:
x: The horizontal scroll amount.
y: The vertical scroll amount.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -394,13 +416,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks.
Args:
clicks: The number of scroll clicks to perform downward.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -412,10 +434,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll up by the specified number of clicks.
Args:
clicks: The number of scroll clicks to perform upward.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -428,13 +450,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Screen Actions
async def screenshot(self) -> Dict[str, Any]:
"""Take a screenshot of the current screen.
Returns:
Dict[str, Any]: A dictionary containing success status and base64-encoded image data,
or error message if failed.
"""
try:
from PIL import Image
screenshot = pyautogui.screenshot()
if not isinstance(screenshot, Image.Image):
return {"success": False, "error": "Failed to capture screenshot"}
@@ -448,7 +471,7 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def get_screen_size(self) -> Dict[str, Any]:
"""Get the size of the screen.
Returns:
Dict[str, Any]: A dictionary containing success status and screen dimensions,
or error message if failed.
@@ -461,7 +484,7 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def get_cursor_position(self) -> Dict[str, Any]:
"""Get the current position of the cursor.
Returns:
Dict[str, Any]: A dictionary containing success status and cursor coordinates,
or error message if failed.
@@ -475,13 +498,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Clipboard Actions
async def copy_to_clipboard(self) -> Dict[str, Any]:
"""Get the current content of the clipboard.
Returns:
Dict[str, Any]: A dictionary containing success status and clipboard content,
or error message if failed.
"""
try:
import pyperclip
content = pyperclip.paste()
return {"success": True, "content": content}
except Exception as e:
@@ -489,15 +513,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the clipboard content to the specified text.
Args:
text: The text to copy to the clipboard.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
try:
import pyperclip
pyperclip.copy(text)
return {"success": True}
except Exception as e:
@@ -506,10 +531,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Command Execution
async def run_command(self, command: str) -> Dict[str, Any]:
"""Execute a shell command asynchronously.
Args:
command: The shell command to execute.
Returns:
Dict[str, Any]: A dictionary containing success status, stdout, stderr,
and return code, or error message if failed.
@@ -517,18 +542,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
try:
# Create subprocess
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
# Wait for the subprocess to finish
stdout, stderr = await process.communicate()
# Return decoded output
return {
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode,
}
except Exception as e:
return {"success": False, "error": str(e)}

View File

@@ -1,54 +1,57 @@
import pyautogui
pyautogui.FAILSAFE = False
from pynput.mouse import Button, Controller as MouseController
from pynput.keyboard import Key, Controller as KeyboardController
import time
import asyncio
import base64
import copy
import json
import logging
import re
import time
from ctypes import POINTER, byref, c_void_p
from io import BytesIO
from typing import Optional, Dict, Any, List, Tuple
from ctypes import byref, c_void_p, POINTER
from AppKit import NSWorkspace # type: ignore
from typing import Any, Dict, List, Optional, Tuple
import AppKit
import Foundation
import objc
from AppKit import NSWorkspace # type: ignore
from ApplicationServices import AXUIElementCopyAttributeValue # type: ignore
from ApplicationServices import AXUIElementCopyAttributeValues # type: ignore
from ApplicationServices import AXUIElementCreateApplication # type: ignore
from ApplicationServices import AXUIElementCreateSystemWide # type: ignore
from ApplicationServices import AXUIElementGetTypeID # type: ignore
from ApplicationServices import AXValueGetType # type: ignore
from ApplicationServices import AXValueGetValue # type: ignore
from ApplicationServices import kAXChildrenAttribute # type: ignore
from ApplicationServices import kAXDescriptionAttribute # type: ignore
from ApplicationServices import kAXEnabledAttribute # type: ignore
from ApplicationServices import kAXErrorSuccess # type: ignore
from ApplicationServices import kAXFocusedApplicationAttribute # type: ignore
from ApplicationServices import kAXFocusedUIElementAttribute # type: ignore
from ApplicationServices import kAXFocusedWindowAttribute # type: ignore
from ApplicationServices import kAXMainWindowAttribute # type: ignore
from ApplicationServices import kAXPositionAttribute # type: ignore
from ApplicationServices import kAXRoleAttribute # type: ignore
from ApplicationServices import kAXRoleDescriptionAttribute # type: ignore
from ApplicationServices import kAXSelectedTextAttribute # type: ignore
from ApplicationServices import kAXSelectedTextRangeAttribute # type: ignore
from ApplicationServices import kAXSizeAttribute # type: ignore
from ApplicationServices import kAXTitleAttribute # type: ignore
from ApplicationServices import kAXValueAttribute # type: ignore
from ApplicationServices import kAXValueCFRangeType # type: ignore
from ApplicationServices import kAXValueCGPointType # type: ignore
from ApplicationServices import kAXValueCGSizeType # type: ignore
from ApplicationServices import kAXVisibleChildrenAttribute # type: ignore
from ApplicationServices import kAXWindowsAttribute # type: ignore
from pynput.keyboard import Controller as KeyboardController
from pynput.keyboard import Key
from pynput.mouse import Button
from pynput.mouse import Controller as MouseController
from Quartz.CoreGraphics import * # type: ignore
from Quartz.CoreGraphics import CGPoint, CGSize # type: ignore
import Foundation
from ApplicationServices import (
AXUIElementCreateSystemWide, # type: ignore
AXUIElementCreateApplication, # type: ignore
AXUIElementCopyAttributeValue, # type: ignore
AXUIElementCopyAttributeValues, # type: ignore
kAXFocusedWindowAttribute, # type: ignore
kAXWindowsAttribute, # type: ignore
kAXMainWindowAttribute, # type: ignore
kAXChildrenAttribute, # type: ignore
kAXRoleAttribute, # type: ignore
kAXTitleAttribute, # type: ignore
kAXValueAttribute, # type: ignore
kAXDescriptionAttribute, # type: ignore
kAXEnabledAttribute, # type: ignore
kAXPositionAttribute, # type: ignore
kAXSizeAttribute, # type: ignore
kAXErrorSuccess, # type: ignore
AXValueGetType, # type: ignore
kAXValueCGSizeType, # type: ignore
kAXValueCGPointType, # type: ignore
kAXValueCFRangeType, # type: ignore
AXUIElementGetTypeID, # type: ignore
AXValueGetValue, # type: ignore
kAXVisibleChildrenAttribute, # type: ignore
kAXRoleDescriptionAttribute, # type: ignore
kAXFocusedApplicationAttribute, # type: ignore
kAXFocusedUIElementAttribute, # type: ignore
kAXSelectedTextAttribute, # type: ignore
kAXSelectedTextRangeAttribute, # type: ignore
)
import objc
import re
import json
import copy
import asyncio
from .base import BaseAccessibilityHandler, BaseAutomationHandler
import logging
logger = logging.getLogger(__name__)
@@ -73,24 +76,26 @@ kCGWindowAlpha = "kCGWindowAlpha" # Window opacity
NSApplicationActivationOptions = {
"regular": 0, # Default activation
"bringing_all_windows_forward": 1 << 0, # NSApplicationActivateAllWindows
"ignoring_other_apps": 1 << 1 # NSApplicationActivateIgnoringOtherApps
"ignoring_other_apps": 1 << 1, # NSApplicationActivateIgnoringOtherApps
}
def CFAttributeToPyObject(attrValue):
"""Convert Core Foundation attribute values to Python objects.
Args:
attrValue: Core Foundation attribute value to convert
Returns:
Converted Python object or None if conversion fails
"""
def list_helper(list_value):
"""Helper function to convert CF arrays to Python lists.
Args:
list_value: Core Foundation array to convert
Returns:
Python list containing converted items
"""
@@ -101,10 +106,10 @@ def CFAttributeToPyObject(attrValue):
def number_helper(number_value):
"""Helper function to convert CF numbers to Python numbers.
Args:
number_value: Core Foundation number to convert
Returns:
Python int or float, or None if conversion fails
"""
@@ -123,10 +128,10 @@ def CFAttributeToPyObject(attrValue):
def axuielement_helper(element_value):
"""Helper function to handle AX UI elements.
Args:
element_value: Accessibility UI element to process
Returns:
The element value unchanged
"""
@@ -164,11 +169,11 @@ def CFAttributeToPyObject(attrValue):
def element_attribute(element, attribute):
"""Get an attribute value from an accessibility element.
Args:
element: The accessibility element
attribute: The attribute name to retrieve
Returns:
The attribute value or None if not found
"""
@@ -190,11 +195,11 @@ def element_attribute(element, attribute):
def element_value(element, type):
"""Extract a typed value from an accessibility element.
Args:
element: The accessibility element containing the value
type: The expected value type
Returns:
The extracted value or None if extraction fails
"""
@@ -206,10 +211,10 @@ def element_value(element, type):
class UIElement:
"""Represents a UI element in the accessibility tree with position, size, and hierarchy information."""
def __init__(self, element, offset_x=0, offset_y=0, max_depth=None, parents_visible_bbox=None):
"""Initialize a UIElement from an accessibility element.
Args:
element: The accessibility element to wrap
offset_x: X offset for position calculations
@@ -297,7 +302,7 @@ class UIElement:
def _set_bboxes(self, parents_visible_bbox):
"""Set bounding box and visible bounding box for the element.
Args:
parents_visible_bbox: Parent's visible bounding box for intersection calculation
"""
@@ -332,13 +337,13 @@ class UIElement:
def _get_children(self, element, start_position, offset_x, offset_y):
"""Get child elements from the accessibility element.
Args:
element: The parent accessibility element
start_position: Starting position for offset calculations
offset_x: X offset for child positioning
offset_y: Y offset for child positioning
Returns:
List of UIElement children
"""
@@ -371,7 +376,7 @@ class UIElement:
def component_hash(self):
"""Generate a hash identifier for this component based on its properties.
Returns:
MD5 hash string of component properties
"""
@@ -388,10 +393,10 @@ class UIElement:
def hash_from_string(self, string):
"""Generate MD5 hash from a string.
Args:
string: Input string to hash
Returns:
MD5 hash hexdigest or empty string if input is None/empty
"""
@@ -403,10 +408,10 @@ class UIElement:
def children_content_hash(self, children):
"""Generate a hash representing the content and structure of child elements.
Args:
children: List of child UIElement objects
Returns:
Combined hash of children content and structure
"""
@@ -426,16 +431,17 @@ class UIElement:
def to_dict(self):
"""Convert the UIElement to a dictionary representation.
Returns:
Dictionary containing all element properties and children
"""
def children_to_dict(children):
"""Convert list of children to dictionary format.
Args:
children: List of UIElement children to convert
Returns:
List of dictionaries representing the children
"""
@@ -464,7 +470,7 @@ class UIElement:
size = f"{self.size.width:.0f};{self.size.height:.0f}"
else:
size = ""
return {
"id": self.identifier,
"name": self.name,
@@ -482,36 +488,38 @@ class UIElement:
}
import Quartz
from AppKit import NSWorkspace, NSRunningApplication
from pathlib import Path
import Quartz
from AppKit import NSRunningApplication, NSWorkspace
def get_all_windows_zorder():
"""Get all windows in the system with their z-order information.
Returns:
List of window dictionaries sorted by z-index, containing window properties
like id, name, pid, owner, bounds, layer, and opacity
"""
window_list = Quartz.CGWindowListCopyWindowInfo(
Quartz.kCGWindowListOptionOnScreenOnly,
Quartz.kCGNullWindowID
Quartz.kCGWindowListOptionOnScreenOnly, Quartz.kCGNullWindowID
)
z_order = {window['kCGWindowNumber']: z_index for z_index, window in enumerate(window_list[::-1])}
z_order = {
window["kCGWindowNumber"]: z_index for z_index, window in enumerate(window_list[::-1])
}
window_list_all = Quartz.CGWindowListCopyWindowInfo(
Quartz.kCGWindowListOptionAll,
Quartz.kCGNullWindowID
Quartz.kCGWindowListOptionAll, Quartz.kCGNullWindowID
)
windows = []
for window in window_list_all:
window_id = window.get('kCGWindowNumber', 0)
window_name = window.get('kCGWindowName', '')
window_pid = window.get('kCGWindowOwnerPID', 0)
window_bounds = window.get('kCGWindowBounds', {})
window_owner = window.get('kCGWindowOwnerName', '')
window_is_on_screen = window.get('kCGWindowIsOnscreen', False)
layer = window.get('kCGWindowLayer', 0)
opacity = window.get('kCGWindowAlpha', 1.0)
window_id = window.get("kCGWindowNumber", 0)
window_name = window.get("kCGWindowName", "")
window_pid = window.get("kCGWindowOwnerPID", 0)
window_bounds = window.get("kCGWindowBounds", {})
window_owner = window.get("kCGWindowOwnerName", "")
window_is_on_screen = window.get("kCGWindowIsOnscreen", False)
layer = window.get("kCGWindowLayer", 0)
opacity = window.get("kCGWindowAlpha", 1.0)
z_index = z_order.get(window_id, -1)
if window_name == "Dock" and window_owner == "Dock":
role = "dock"
@@ -522,32 +530,35 @@ def get_all_windows_zorder():
else:
role = "app"
if window_bounds:
windows.append({
"id": window_id,
"name": window_name or "Unnamed Window",
"pid": window_pid,
"owner": window_owner,
"role": role,
"is_on_screen": window_is_on_screen,
"bounds": {
"x": window_bounds.get('X', 0),
"y": window_bounds.get('Y', 0),
"width": window_bounds.get('Width', 0),
"height": window_bounds.get('Height', 0)
},
"layer": layer,
"z_index": z_index,
"opacity": opacity
})
windows.append(
{
"id": window_id,
"name": window_name or "Unnamed Window",
"pid": window_pid,
"owner": window_owner,
"role": role,
"is_on_screen": window_is_on_screen,
"bounds": {
"x": window_bounds.get("X", 0),
"y": window_bounds.get("Y", 0),
"width": window_bounds.get("Width", 0),
"height": window_bounds.get("Height", 0),
},
"layer": layer,
"z_index": z_index,
"opacity": opacity,
}
)
windows = sorted(windows, key=lambda x: x["z_index"])
return windows
def get_app_info(app):
"""Extract information from an NSRunningApplication object.
Args:
app: NSRunningApplication instance
Returns:
Dictionary containing app name, bundle ID, PID, and status flags
"""
@@ -560,12 +571,13 @@ def get_app_info(app):
"terminated": app.isTerminated(),
}
def get_menubar_items(active_app_pid=None):
"""Get menubar items for the active application.
Args:
active_app_pid: Process ID of the active application, or None to use frontmost app
Returns:
List of menubar item dictionaries with title, bounds, index, and app_pid
"""
@@ -591,26 +603,24 @@ def get_menubar_items(active_app_pid=None):
position_value = element_attribute(item, kAXPositionAttribute)
if position_value:
position_value = element_value(position_value, kAXValueCGPointType)
bounds["x"] = getattr(position_value, 'x', 0)
bounds["y"] = getattr(position_value, 'y', 0)
bounds["x"] = getattr(position_value, "x", 0)
bounds["y"] = getattr(position_value, "y", 0)
size_value = element_attribute(item, kAXSizeAttribute)
if size_value:
size_value = element_value(size_value, kAXValueCGSizeType)
bounds["width"] = getattr(size_value, 'width', 0)
bounds["height"] = getattr(size_value, 'height', 0)
menubar_items.append({
"title": title,
"bounds": bounds,
"index": i,
"app_pid": active_app_pid
})
bounds["width"] = getattr(size_value, "width", 0)
bounds["height"] = getattr(size_value, "height", 0)
menubar_items.append(
{"title": title, "bounds": bounds, "index": i, "app_pid": active_app_pid}
)
return menubar_items
def get_dock_items():
"""Get all items in the macOS Dock.
Returns:
List of dock item dictionaries with title, description, bounds, index,
List of dock item dictionaries with title, description, bounds, index,
type, role, and subrole information
"""
dock_items = []
@@ -648,13 +658,13 @@ def get_dock_items():
position_value = element_attribute(item, kAXPositionAttribute)
if position_value:
position_value = element_value(position_value, kAXValueCGPointType)
bounds["x"] = getattr(position_value, 'x', 0)
bounds["y"] = getattr(position_value, 'y', 0)
bounds["x"] = getattr(position_value, "x", 0)
bounds["y"] = getattr(position_value, "y", 0)
size_value = element_attribute(item, kAXSizeAttribute)
if size_value:
size_value = element_value(size_value, kAXValueCGSizeType)
bounds["width"] = getattr(size_value, 'width', 0)
bounds["height"] = getattr(size_value, 'height', 0)
bounds["width"] = getattr(size_value, "width", 0)
bounds["height"] = getattr(size_value, "height", 0)
item_type = "unknown"
if subrole == "AXApplicationDockItem":
item_type = "application"
@@ -666,23 +676,26 @@ def get_dock_items():
item_type = "separator"
elif "trash" in title.lower():
item_type = "trash"
dock_items.append({
"title": title,
"description": description,
"bounds": bounds,
"index": i,
"type": item_type,
"role": role,
"subrole": subrole
})
dock_items.append(
{
"title": title,
"description": description,
"bounds": bounds,
"index": i,
"type": item_type,
"role": role,
"subrole": subrole,
}
)
return dock_items
class MacOSAccessibilityHandler(BaseAccessibilityHandler):
"""Handler for macOS accessibility features and UI element inspection."""
def get_desktop_state(self):
"""Get the current state of the desktop including windows, apps, menubar, and dock.
Returns:
Dictionary containing applications, windows, menubar_items, and dock_items
"""
@@ -696,7 +709,9 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
pid = app.processIdentifier()
try:
app_elem = AXUIElementCreateApplication(pid)
err, app_windows = AXUIElementCopyAttributeValue(app_elem, kAXWindowsAttribute, None)
err, app_windows = AXUIElementCopyAttributeValue(
app_elem, kAXWindowsAttribute, None
)
trees = []
if err == kAXErrorSuccess and app_windows:
for ax_win in app_windows:
@@ -713,31 +728,32 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
pid = win["pid"]
idx = pid_to_idx.get(pid, 0)
ax_trees = pid_to_ax_trees.get(pid, [])
win["children"] = ax_trees[idx]["children"] if idx < len(ax_trees) and "children" in ax_trees[idx] else []
win["children"] = (
ax_trees[idx]["children"]
if idx < len(ax_trees) and "children" in ax_trees[idx]
else []
)
pid_to_idx[pid] = idx + 1
pid_to_window_ids.setdefault(pid, []).append(win["id"])
for app in running_apps:
info = get_app_info(app)
app_pid = info["pid"]
applications.append({
"info": info,
"windows": pid_to_window_ids.get(app_pid, [])
})
applications.append({"info": info, "windows": pid_to_window_ids.get(app_pid, [])})
menubar_items = get_menubar_items()
dock_items = get_dock_items()
return {
"applications": applications,
"windows": windows,
"menubar_items": menubar_items,
"dock_items": dock_items
"dock_items": dock_items,
}
def get_application_windows(self, pid: int):
"""Get all windows for a specific application.
Args:
pid: Process ID of the application
Returns:
List of accessibility window elements or empty list if none found
"""
@@ -753,7 +769,7 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def get_all_windows(self):
"""Get all visible windows in the system.
Returns:
List of window dictionaries with app information and window details
"""
@@ -791,7 +807,7 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def get_running_apps(self):
"""Get all currently running applications.
Returns:
List of NSRunningApplication objects
"""
@@ -803,11 +819,11 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def get_ax_attribute(self, element, attribute):
"""Get an accessibility attribute from an element.
Args:
element: The accessibility element
attribute: The attribute name to retrieve
Returns:
The attribute value or None if not found
"""
@@ -815,10 +831,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def serialize_node(self, element):
"""Create a serializable dictionary representation of an accessibility element.
Args:
element: The accessibility element to serialize
Returns:
Dictionary containing element properties like role, title, value, position, and size
"""
@@ -851,16 +867,13 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the complete accessibility tree for the current desktop state.
Returns:
Dictionary containing success status and desktop state information
"""
"""
try:
desktop_state = self.get_desktop_state()
return {
"success": True,
**desktop_state
}
return {"success": True, **desktop_state}
except Exception as e:
return {"success": False, "error": str(e)}
@@ -869,12 +882,12 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an accessibility element matching the specified criteria.
Args:
role: The accessibility role to match (optional)
title: The title to match (optional)
value: The value to match (optional)
Returns:
Dictionary containing success status and the found element or error message
"""
@@ -883,10 +896,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def match_element(element):
"""Check if an element matches the search criteria.
Args:
element: The accessibility element to check
Returns:
True if element matches all specified criteria, False otherwise
"""
@@ -900,10 +913,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def search_tree(element):
"""Recursively search the accessibility tree for matching elements.
Args:
element: The accessibility element to search from
Returns:
Serialized element dictionary if match found, None otherwise
"""
@@ -924,58 +937,71 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
except Exception as e:
return {"success": False, "error": str(e)}
class MacOSAutomationHandler(BaseAutomationHandler):
"""Handler for macOS automation including mouse, keyboard, and screen operations."""
# Mouse Actions
mouse = MouseController()
keyboard = KeyboardController()
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Press and hold a mouse button at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
button: Mouse button to press ("left", "right", or "middle")
Returns:
Dictionary containing success status and error message if failed
"""
try:
if x is not None and y is not None:
self.mouse.position = (x, y)
self.mouse.press(Button.left if button == "left" else Button.right if button == "right" else Button.middle)
self.mouse.press(
Button.left
if button == "left"
else Button.right if button == "right" else Button.middle
)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Release a mouse button at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
button: Mouse button to release ("left", "right", or "middle")
Returns:
Dictionary containing success status and error message if failed
"""
try:
if x is not None and y is not None:
self.mouse.position = (x, y)
self.mouse.release(Button.left if button == "left" else Button.right if button == "right" else Button.middle)
self.mouse.release(
Button.left
if button == "left"
else Button.right if button == "right" else Button.middle
)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left mouse click at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -989,11 +1015,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a right mouse click at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1009,11 +1035,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double left mouse click at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1027,11 +1053,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
"""Move the mouse cursor to the specified coordinates.
Args:
x: Target X coordinate
y: Target Y coordinate
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1045,18 +1071,22 @@ class MacOSAutomationHandler(BaseAutomationHandler):
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag from current position to target coordinates.
Args:
x: Target X coordinate
y: Target Y coordinate
button: Mouse button to use for dragging ("left", "right", or "middle")
duration: Duration of the drag operation in seconds
Returns:
Dictionary containing success status and error message if failed
"""
try:
btn = Button.left if button == "left" else Button.right if button == "right" else Button.middle
btn = (
Button.left
if button == "left"
else Button.right if button == "right" else Button.middle
)
# Press
self.mouse.press(btn)
# Move with sleep to simulate drag duration
@@ -1082,19 +1112,23 @@ class MacOSAutomationHandler(BaseAutomationHandler):
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the mouse along a specified path of coordinates.
Args:
path: List of (x, y) coordinate tuples defining the drag path
button: Mouse button to use for dragging ("left", "right", or "middle")
duration: Total duration of the drag operation in seconds
Returns:
Dictionary containing success status and error message if failed
"""
try:
if not path or len(path) < 2:
return {"success": False, "error": "Path must contain at least 2 points"}
btn = Button.left if button == "left" else Button.right if button == "right" else Button.middle
btn = (
Button.left
if button == "left"
else Button.right if button == "right" else Button.middle
)
# Move to the first point
self.mouse.position = path[0]
self.mouse.press(btn)
@@ -1114,10 +1148,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
# Keyboard Actions
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold a keyboard key.
Args:
key: Key name to press (using pyautogui key names)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1127,13 +1161,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release a keyboard key.
Args:
key: Key name to release (using pyautogui key names)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1143,13 +1177,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type text using the keyboard with Unicode support.
Args:
text: Text string to type
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1162,10 +1196,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def press_key(self, key: str) -> Dict[str, Any]:
"""Press and release a keyboard key.
Args:
key: Key name to press (using pyautogui key names)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1178,10 +1212,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
"""Press a combination of keys simultaneously.
Args:
keys: List of key names to press together (using pyautogui key names)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1195,11 +1229,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
# Scrolling Actions
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll the mouse wheel in the specified direction.
Args:
x: Horizontal scroll amount
y: Vertical scroll amount (positive for up, negative for down)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1208,13 +1242,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks.
Args:
clicks: Number of scroll clicks to perform
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1226,10 +1260,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll up by the specified number of clicks.
Args:
clicks: Number of scroll clicks to perform
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1242,7 +1276,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
# Screen Actions
async def screenshot(self) -> Dict[str, Any]:
"""Capture a screenshot of the current screen.
Returns:
Dictionary containing success status and base64-encoded image data or error message
"""
@@ -1263,7 +1297,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def get_screen_size(self) -> Dict[str, Any]:
"""Get the dimensions of the current screen.
Returns:
Dictionary containing success status and screen size or error message
"""
@@ -1275,7 +1309,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def get_cursor_position(self) -> Dict[str, Any]:
"""Get the current position of the mouse cursor.
Returns:
Dictionary containing success status and cursor position or error message
"""
@@ -1288,7 +1322,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
# Clipboard Actions
async def copy_to_clipboard(self) -> Dict[str, Any]:
"""Get the current content of the system clipboard.
Returns:
Dictionary containing success status and clipboard content or error message
"""
@@ -1302,10 +1336,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the content of the system clipboard.
Args:
text: Text to copy to the clipboard
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1319,28 +1353,26 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def run_command(self, command: str) -> Dict[str, Any]:
"""Run a shell command and return its output.
Args:
command: Shell command to execute
Returns:
Dictionary containing success status, stdout, stderr, and return code
"""
try:
# Create subprocess
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
# Wait for the subprocess to finish
stdout, stderr = await process.communicate()
# Return decoded output
return {
"success": True,
"stdout": stdout.decode() if stdout else "",
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode
"return_code": process.returncode,
}
except Exception as e:
return {"success": False, "error": str(e)}

View File

@@ -4,15 +4,17 @@ Windows implementation of automation and accessibility handlers.
This implementation uses pyautogui for GUI automation and Windows-specific APIs
for accessibility and system operations.
"""
from typing import Dict, Any, List, Tuple, Optional
import logging
import subprocess
import asyncio
import base64
import logging
import os
import subprocess
from io import BytesIO
from pynput.mouse import Controller as MouseController
from typing import Any, Dict, List, Optional, Tuple
from pynput.keyboard import Controller as KeyboardController
from pynput.mouse import Controller as MouseController
# Configure logger
logger = logging.getLogger(__name__)
@@ -20,6 +22,7 @@ logger = logging.getLogger(__name__)
# Try to import pyautogui
try:
import pyautogui
pyautogui.FAILSAFE = False
logger.info("pyautogui successfully imported, GUI automation available")
except Exception as e:
@@ -28,58 +31,62 @@ except Exception as e:
# Try to import Windows-specific modules
try:
import win32gui
import win32con
import win32api
import win32con
import win32gui
logger.info("Windows API modules successfully imported")
WINDOWS_API_AVAILABLE = True
except Exception as e:
logger.error(f"Windows API modules import failed: {str(e)}. Some Windows-specific features will be unavailable.")
logger.error(
f"Windows API modules import failed: {str(e)}. Some Windows-specific features will be unavailable."
)
WINDOWS_API_AVAILABLE = False
from .base import BaseAccessibilityHandler, BaseAutomationHandler
class WindowsAccessibilityHandler(BaseAccessibilityHandler):
"""Windows implementation of accessibility handler."""
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current window.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
the accessibility tree or an error message.
Structure: {"success": bool, "tree": dict} or
Structure: {"success": bool, "tree": dict} or
{"success": bool, "error": str}
"""
if not WINDOWS_API_AVAILABLE:
return {"success": False, "error": "Windows API not available"}
try:
# Get the foreground window
hwnd = win32gui.GetForegroundWindow()
if not hwnd:
return {"success": False, "error": "No foreground window found"}
# Get window information
window_text = win32gui.GetWindowText(hwnd)
rect = win32gui.GetWindowRect(hwnd)
tree = {
"role": "Window",
"title": window_text,
"position": {"x": rect[0], "y": rect[1]},
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
"children": []
"children": [],
}
# Enumerate child windows
def enum_child_proc(hwnd_child, children_list):
"""Callback function to enumerate child windows and collect their information.
Args:
hwnd_child: Handle to the child window being enumerated.
children_list: List to append child window information to.
Returns:
bool: True to continue enumeration, False to stop.
"""
@@ -87,46 +94,49 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
child_text = win32gui.GetWindowText(hwnd_child)
child_rect = win32gui.GetWindowRect(hwnd_child)
child_class = win32gui.GetClassName(hwnd_child)
child_info = {
"role": child_class,
"title": child_text,
"position": {"x": child_rect[0], "y": child_rect[1]},
"size": {"width": child_rect[2] - child_rect[0], "height": child_rect[3] - child_rect[1]},
"children": []
"size": {
"width": child_rect[2] - child_rect[0],
"height": child_rect[3] - child_rect[1],
},
"children": [],
}
children_list.append(child_info)
except Exception as e:
logger.debug(f"Error getting child window info: {e}")
return True
win32gui.EnumChildWindows(hwnd, enum_child_proc, tree["children"])
return {"success": True, "tree": tree}
except Exception as e:
logger.error(f"Error getting accessibility tree: {e}")
return {"success": False, "error": str(e)}
async def find_element(self, role: Optional[str] = None,
title: Optional[str] = None,
value: Optional[str] = None) -> Dict[str, Any]:
async def find_element(
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an element in the accessibility tree by criteria.
Args:
role (Optional[str]): The role or class name of the element to find.
title (Optional[str]): The title or text of the element to find.
value (Optional[str]): The value of the element (not used in Windows implementation).
Returns:
Dict[str, Any]: A dictionary containing the success status and either
the found element or an error message.
Structure: {"success": bool, "element": dict} or
Structure: {"success": bool, "element": dict} or
{"success": bool, "error": str}
"""
if not WINDOWS_API_AVAILABLE:
return {"success": False, "error": "Windows API not available"}
try:
# Find window by title if specified
if title:
@@ -139,10 +149,10 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
"role": "Window",
"title": title,
"position": {"x": rect[0], "y": rect[1]},
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]}
}
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
},
}
# Find window by class name if role is specified
if role:
hwnd = win32gui.FindWindow(role, None)
@@ -155,37 +165,40 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
"role": role,
"title": window_text,
"position": {"x": rect[0], "y": rect[1]},
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]}
}
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
},
}
return {"success": False, "error": "Element not found"}
except Exception as e:
logger.error(f"Error finding element: {e}")
return {"success": False, "error": str(e)}
class WindowsAutomationHandler(BaseAutomationHandler):
"""Windows implementation of automation handler using pyautogui and Windows APIs."""
mouse = MouseController()
keyboard = KeyboardController()
# Mouse Actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Press and hold a mouse button at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to move to before pressing. If None, uses current position.
y (Optional[int]): The y-coordinate to move to before pressing. If None, uses current position.
button (str): The mouse button to press ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -193,21 +206,23 @@ class WindowsAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Release a mouse button at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to move to before releasing. If None, uses current position.
y (Optional[int]): The y-coordinate to move to before releasing. If None, uses current position.
button (str): The mouse button to release ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -215,20 +230,20 @@ class WindowsAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
"""Move the mouse cursor to the specified coordinates.
Args:
x (int): The x-coordinate to move to.
y (int): The y-coordinate to move to.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.moveTo(x, y)
return {"success": True}
@@ -237,17 +252,17 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left mouse click at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to click at. If None, clicks at current position.
y (Optional[int]): The y-coordinate to click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -258,17 +273,17 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a right mouse click at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to click at. If None, clicks at current position.
y (Optional[int]): The y-coordinate to click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -277,19 +292,21 @@ class WindowsAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double left mouse click at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to double-click at. If None, clicks at current position.
y (Optional[int]): The y-coordinate to double-click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -298,52 +315,56 @@ class WindowsAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag from the current position to the specified coordinates.
Args:
x (int): The x-coordinate to drag to.
y (int): The y-coordinate to drag to.
button (str): The mouse button to use for dragging ("left", "right", or "middle").
duration (float): The time in seconds to take for the drag operation.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.dragTo(x, y, duration=duration, button=button)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the mouse through a series of coordinates.
Args:
path (List[Tuple[int, int]]): A list of (x, y) coordinate tuples to drag through.
button (str): The mouse button to use for dragging ("left", "right", or "middle").
duration (float): The total time in seconds for the entire drag operation.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if not path:
return {"success": False, "error": "Path is empty"}
# Move to first position
pyautogui.moveTo(*path[0])
# Drag through all positions
for x, y in path[1:]:
pyautogui.dragTo(x, y, duration=duration/len(path), button=button)
pyautogui.dragTo(x, y, duration=duration / len(path), button=button)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
@@ -351,46 +372,46 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Keyboard Actions
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold a keyboard key.
Args:
key (str): The key to press down (e.g., 'ctrl', 'shift', 'a').
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.keyDown(key)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release a keyboard key.
Args:
key (str): The key to release (e.g., 'ctrl', 'shift', 'a').
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.keyUp(key)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text.
Args:
text (str): The text to type.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
@@ -403,16 +424,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def press_key(self, key: str) -> Dict[str, Any]:
"""Press and release a keyboard key.
Args:
key (str): The key to press (e.g., 'enter', 'space', 'tab').
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.press(key)
return {"success": True}
@@ -421,16 +442,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
"""Press a combination of keys simultaneously.
Args:
keys (List[str]): The keys to press together (e.g., ['ctrl', 'c'], ['alt', 'tab']).
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.hotkey(*keys)
return {"success": True}
@@ -440,35 +461,35 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Scrolling Actions
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll vertically at the current cursor position.
Args:
x (int): Horizontal scroll amount (not used in pyautogui implementation).
y (int): Vertical scroll amount. Positive values scroll up, negative values scroll down.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
self.mouse.scroll(x, y)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks.
Args:
clicks (int): The number of scroll clicks to perform downward.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.scroll(-clicks)
return {"success": True}
@@ -477,16 +498,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll up by the specified number of clicks.
Args:
clicks (int): The number of scroll clicks to perform upward.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.scroll(clicks)
return {"success": True}
@@ -496,22 +517,23 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Screen Actions
async def screenshot(self) -> Dict[str, Any]:
"""Capture a screenshot of the entire screen.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
base64-encoded image data or an error message.
Structure: {"success": bool, "image_data": str} or
Structure: {"success": bool, "image_data": str} or
{"success": bool, "error": str}
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
from PIL import Image
screenshot = pyautogui.screenshot()
if not isinstance(screenshot, Image.Image):
return {"success": False, "error": "Failed to capture screenshot"}
buffered = BytesIO()
screenshot.save(buffered, format="PNG", optimize=True)
buffered.seek(0)
@@ -522,11 +544,11 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def get_screen_size(self) -> Dict[str, Any]:
"""Get the size of the screen in pixels.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
screen size information or an error message.
Structure: {"success": bool, "size": {"width": int, "height": int}} or
Structure: {"success": bool, "size": {"width": int, "height": int}} or
{"success": bool, "error": str}
"""
try:
@@ -545,11 +567,11 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def get_cursor_position(self) -> Dict[str, Any]:
"""Get the current position of the mouse cursor.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
cursor position or an error message.
Structure: {"success": bool, "position": {"x": int, "y": int}} or
Structure: {"success": bool, "position": {"x": int, "y": int}} or
{"success": bool, "error": str}
"""
try:
@@ -568,15 +590,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Clipboard Actions
async def copy_to_clipboard(self) -> Dict[str, Any]:
"""Get the current content of the clipboard.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
clipboard content or an error message.
Structure: {"success": bool, "content": str} or
Structure: {"success": bool, "content": str} or
{"success": bool, "error": str}
"""
try:
import pyperclip
content = pyperclip.paste()
return {"success": True, "content": content}
except Exception as e:
@@ -584,15 +607,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the clipboard content to the specified text.
Args:
text (str): The text to copy to the clipboard.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
try:
import pyperclip
pyperclip.copy(text)
return {"success": True}
except Exception as e:
@@ -601,31 +625,29 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Command Execution
async def run_command(self, command: str) -> Dict[str, Any]:
"""Execute a shell command asynchronously.
Args:
command (str): The shell command to execute.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
command output or an error message.
Structure: {"success": bool, "stdout": str, "stderr": str, "return_code": int} or
Structure: {"success": bool, "stdout": str, "stderr": str, "return_code": int} or
{"success": bool, "error": str}
"""
try:
# Create subprocess
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
# Wait for the subprocess to finish
stdout, stderr = await process.communicate()
# Return decoded output
return {
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode,
}
except Exception as e:
return {"success": False, "error": str(e)}

View File

@@ -1,27 +1,37 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Header
from fastapi.responses import StreamingResponse, JSONResponse
from typing import List, Dict, Any, Optional, Union, Literal, cast
import uvicorn
import logging
import asyncio
import json
import traceback
import inspect
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp
import hashlib
import time
import inspect
import json
import logging
import os
import platform
import time
import traceback
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Dict, List, Literal, Optional, Union, cast
import aiohttp
import uvicorn
from fastapi import (
FastAPI,
Header,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from .handlers.factory import HandlerFactory
# Authentication session TTL (in seconds). Override via env var CUA_AUTH_TTL_SECONDS. Default: 60s
AUTH_SESSION_TTL_SECONDS: int = int(os.environ.get("CUA_AUTH_TTL_SECONDS", "60"))
try:
from agent import ComputerAgent
HAS_AGENT = True
except ImportError:
HAS_AGENT = False
@@ -54,16 +64,20 @@ app.add_middleware(
protocol_version = 1
try:
from importlib.metadata import version
package_version = version("cua-computer-server")
except Exception:
# Fallback for cases where package is not installed or importlib.metadata is not available
try:
import pkg_resources
package_version = pkg_resources.get_distribution("cua-computer-server").version
except Exception:
package_version = "unknown"
accessibility_handler, automation_handler, diorama_handler, file_handler = HandlerFactory.create_handlers()
accessibility_handler, automation_handler, diorama_handler, file_handler = (
HandlerFactory.create_handlers()
)
handlers = {
"version": lambda: {"protocol": protocol_version, "package": package_version},
# App-Use commands
@@ -118,87 +132,91 @@ class AuthenticationManager:
def __init__(self):
self.sessions: Dict[str, Dict[str, Any]] = {}
self.container_name = os.environ.get("CONTAINER_NAME")
def _hash_credentials(self, container_name: str, api_key: str) -> str:
"""Create a hash of container name and API key for session identification"""
combined = f"{container_name}:{api_key}"
return hashlib.sha256(combined.encode()).hexdigest()
def _is_session_valid(self, session_data: Dict[str, Any]) -> bool:
"""Check if a session is still valid based on expiration time"""
if not session_data.get('valid', False):
if not session_data.get("valid", False):
return False
expires_at = session_data.get('expires_at', 0)
expires_at = session_data.get("expires_at", 0)
return time.time() < expires_at
async def auth(self, container_name: str, api_key: str) -> bool:
"""Authenticate container name and API key, using cached sessions when possible"""
# If no CONTAINER_NAME is set, always allow access (local development)
if not self.container_name:
logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)")
logger.info(
"No CONTAINER_NAME set in environment. Allowing access (local development mode)"
)
return True
# Layer 1: VM Identity Verification
if container_name != self.container_name:
logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}")
logger.warning(
f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}"
)
return False
# Create hash for session lookup
session_hash = self._hash_credentials(container_name, api_key)
# Check if we have a valid cached session
if session_hash in self.sessions:
session_data = self.sessions[session_hash]
if self._is_session_valid(session_data):
logger.info(f"Using cached authentication for container: {container_name}")
return session_data['valid']
return session_data["valid"]
else:
# Remove expired session
del self.sessions[session_hash]
# No valid cached session, authenticate with API
logger.info(f"Authenticating with TryCUA API for container: {container_name}")
try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {api_key}"
}
headers = {"Authorization": f"Bearer {api_key}"}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
headers=headers,
) as resp:
is_valid = resp.status == 200 and bool((await resp.text()).strip())
# Cache the result with configurable expiration
self.sessions[session_hash] = {
'valid': is_valid,
'expires_at': time.time() + AUTH_SESSION_TTL_SECONDS
"valid": is_valid,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}
if is_valid:
logger.info(f"Authentication successful for container: {container_name}")
else:
logger.warning(f"Authentication failed for container: {container_name}. Status: {resp.status}")
logger.warning(
f"Authentication failed for container: {container_name}. Status: {resp.status}"
)
return is_valid
except aiohttp.ClientError as e:
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
'valid': False,
'expires_at': time.time() + AUTH_SESSION_TTL_SECONDS
"valid": False,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}
return False
except Exception as e:
logger.error(f"Unexpected error during authentication: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
'valid': False,
'expires_at': time.time() + AUTH_SESSION_TTL_SECONDS
"valid": False,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}
return False
@@ -218,6 +236,7 @@ class ConnectionManager:
manager = ConnectionManager()
auth_manager = AuthenticationManager()
@app.get("/status")
async def status():
sys = platform.system().lower()
@@ -234,80 +253,67 @@ async def status():
features.append("agent")
return {"status": "ok", "os_type": os_type, "features": features}
@app.websocket("/ws", name="websocket_endpoint")
async def websocket_endpoint(websocket: WebSocket):
global handlers
# WebSocket message size is configured at the app or endpoint level, not on the instance
await manager.connect(websocket)
# Check if CONTAINER_NAME is set (indicating cloud provider)
server_container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication handshake
if server_container_name:
try:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...")
logger.info(
f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication..."
)
# Wait for authentication message
auth_data = await websocket.receive_json()
# Validate auth message format
if auth_data.get("command") != "authenticate":
await websocket.send_json({
"success": False,
"error": "First message must be authentication"
})
await websocket.send_json(
{"success": False, "error": "First message must be authentication"}
)
await websocket.close()
manager.disconnect(websocket)
return
# Extract credentials
client_api_key = auth_data.get("params", {}).get("api_key")
client_container_name = auth_data.get("params", {}).get("container_name")
# Validate credentials using AuthenticationManager
if not client_api_key:
await websocket.send_json({
"success": False,
"error": "API key required"
})
await websocket.send_json({"success": False, "error": "API key required"})
await websocket.close()
manager.disconnect(websocket)
return
if not client_container_name:
await websocket.send_json({
"success": False,
"error": "Container name required"
})
await websocket.send_json({"success": False, "error": "Container name required"})
await websocket.close()
manager.disconnect(websocket)
return
# Use AuthenticationManager for validation
is_authenticated = await auth_manager.auth(client_container_name, client_api_key)
if not is_authenticated:
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.send_json({"success": False, "error": "Authentication failed"})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {client_container_name}")
await websocket.send_json({
"success": True,
"message": "Authentication successful"
})
await websocket.send_json({"success": True, "message": "Authentication successful"})
except Exception as e:
logger.error(f"Error during authentication handshake: {str(e)}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.send_json({"success": False, "error": "Authentication failed"})
await websocket.close()
manager.disconnect(websocket)
return
@@ -330,7 +336,7 @@ async def websocket_endpoint(websocket: WebSocket):
handler_func = handlers[command]
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
# Handle both sync and async functions
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(**filtered_params)
@@ -367,20 +373,21 @@ async def websocket_endpoint(websocket: WebSocket):
pass
manager.disconnect(websocket)
@app.post("/cmd")
async def cmd_endpoint(
request: Request,
container_name: Optional[str] = Header(None, alias="X-Container-Name"),
api_key: Optional[str] = Header(None, alias="X-API-Key")
api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
"""
Backup endpoint for when WebSocket connections fail.
Accepts commands via HTTP POST with streaming response.
Headers:
- X-Container-Name: Container name for cloud authentication
- X-API-Key: API key for cloud authentication
Body:
{
"command": "command_name",
@@ -388,7 +395,7 @@ async def cmd_endpoint(
}
"""
global handlers
# Parse request body
try:
body = await request.json()
@@ -396,32 +403,34 @@ async def cmd_endpoint(
params = body.get("params", {})
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")
if not command:
raise HTTPException(status_code=400, detail="Command is required")
# Check if CONTAINER_NAME is set (indicating cloud provider)
server_container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication
if server_container_name:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")
logger.info(
f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication..."
)
# Validate required headers
if not container_name:
raise HTTPException(status_code=401, detail="Container name required")
if not api_key:
raise HTTPException(status_code=401, detail="API key required")
# Validate with AuthenticationManager
is_authenticated = await auth_manager.auth(container_name, api_key)
if not is_authenticated:
raise HTTPException(status_code=401, detail="Authentication failed")
if command not in handlers:
raise HTTPException(status_code=400, detail=f"Unknown command: {command}")
async def generate_response():
"""Generate streaming response for the command execution"""
try:
@@ -429,35 +438,36 @@ async def cmd_endpoint(
handler_func = handlers[command]
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
# Handle both sync and async functions
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(**filtered_params)
else:
# Run sync functions in thread pool to avoid blocking event loop
result = await asyncio.to_thread(handler_func, **filtered_params)
# Stream the successful result
response_data = {"success": True, **result}
yield f"data: {json.dumps(response_data)}\n\n"
except Exception as cmd_error:
logger.error(f"Error executing command {command}: {str(cmd_error)}")
logger.error(traceback.format_exc())
# Stream the error result
error_data = {"success": False, "error": str(cmd_error)}
yield f"data: {json.dumps(error_data)}\n\n"
return StreamingResponse(
generate_response(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
}
},
)
@app.post("/responses")
async def agent_response_endpoint(
request: Request,
@@ -480,11 +490,17 @@ async def agent_response_endpoint(
"""
if not HAS_AGENT:
raise HTTPException(status_code=501, detail="ComputerAgent not available")
# Authenticate via AuthenticationManager if running in cloud (CONTAINER_NAME set)
container_name = os.environ.get("CONTAINER_NAME")
if container_name:
is_public = os.environ.get("CUA_ENABLE_PUBLIC_PROXY", "").lower().strip() in ["1", "true", "yes", "y", "on"]
is_public = os.environ.get("CUA_ENABLE_PUBLIC_PROXY", "").lower().strip() in [
"1",
"true",
"yes",
"y",
"on",
]
if not is_public:
if not api_key:
raise HTTPException(status_code=401, detail="Missing AGENT PROXY auth headers")
@@ -511,10 +527,12 @@ async def agent_response_endpoint(
def __init__(self, overrides: Dict[str, str]):
self.overrides = overrides
self._original: Dict[str, Optional[str]] = {}
def __enter__(self):
for k, v in (self.overrides or {}).items():
self._original[k] = os.environ.get(k)
os.environ[k] = str(v)
def __exit__(self, exc_type, exc, tb):
for k, old in self._original.items():
if old is None:
@@ -598,9 +616,9 @@ async def agent_response_endpoint(
start = path[0]
await self._auto.mouse_down(start["x"], start["y"])
for pt in path[1:]:
await self._auto.move_cursor(pt["x"], pt["y"])
await self._auto.move_cursor(pt["x"], pt["y"])
end = path[-1]
await self._auto.mouse_up(end["x"], end["y"])
await self._auto.mouse_up(end["x"], end["y"])
async def get_current_url(self) -> str:
# Not available in this server context
@@ -667,7 +685,11 @@ async def agent_response_endpoint(
async for result in agent.run(messages):
total_output += result["output"]
# Try to collect usage if present
if isinstance(result, dict) and "usage" in result and isinstance(result["usage"], dict):
if (
isinstance(result, dict)
and "usage" in result
and isinstance(result["usage"], dict)
):
# Merge usage counters
for k, v in result["usage"].items():
if isinstance(v, (int, float)):
@@ -686,14 +708,14 @@ async def agent_response_endpoint(
logger.error(f"Error running agent: {str(e)}")
logger.error(traceback.format_exc())
error = str(e)
# Build response payload
payload = {
"model": model,
"error": error,
"output": total_output,
"usage": total_usage,
"status": "completed" if not error else "failed"
"status": "completed" if not error else "failed",
}
# CORS: allow any origin

View File

@@ -5,8 +5,9 @@ Provides a clean API for starting and stopping the server.
import asyncio
import logging
import uvicorn
from typing import Optional
import uvicorn
from fastapi import FastAPI
from .main import app as fastapi_app
@@ -32,8 +33,14 @@ class Server:
await server.stop() # Stop the server
"""
def __init__(self, host: str = "0.0.0.0", port: int = 8000, log_level: str = "info",
ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None):
def __init__(
self,
host: str = "0.0.0.0",
port: int = 8000,
log_level: str = "info",
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
):
"""
Initialize the server.
@@ -58,12 +65,12 @@ class Server:
Start the server synchronously. This will block until the server is stopped.
"""
uvicorn.run(
self.app,
host=self.host,
port=self.port,
self.app,
host=self.host,
port=self.port,
log_level=self.log_level,
ssl_keyfile=self.ssl_keyfile,
ssl_certfile=self.ssl_certfile
ssl_certfile=self.ssl_certfile,
)
async def start_async(self) -> None:
@@ -72,12 +79,12 @@ class Server:
will run in the background.
"""
server_config = uvicorn.Config(
self.app,
host=self.host,
port=self.port,
self.app,
host=self.host,
port=self.port,
log_level=self.log_level,
ssl_keyfile=self.ssl_keyfile,
ssl_certfile=self.ssl_certfile
ssl_certfile=self.ssl_certfile,
)
self._should_exit.clear()

View File

@@ -12,9 +12,10 @@ import platform
import subprocess
import sys
import time
import websockets
from typing import Optional
import websockets
logger = logging.getLogger(__name__)
@@ -45,62 +46,62 @@ class Watchdog:
"""Watchdog class to monitor server health via WebSocket connection.
Unix/Linux only - provides restart capabilities.
"""
def __init__(self, cli_args: Optional[dict] = None, ping_interval: int = 30):
"""
Initialize the watchdog.
Args:
cli_args: Dictionary of CLI arguments to replicate when restarting
ping_interval: Interval between ping checks in seconds
"""
# Check if running on Unix/Linux
if platform.system() not in ['Linux', 'Darwin']:
if platform.system() not in ["Linux", "Darwin"]:
raise RuntimeError("Watchdog is only supported on Unix/Linux systems")
# Store CLI arguments for restart
self.cli_args = cli_args or {}
self.host = self.cli_args.get('host', 'localhost')
self.port = self.cli_args.get('port', 8000)
self.host = self.cli_args.get("host", "localhost")
self.port = self.cli_args.get("port", 8000)
self.ping_interval = ping_interval
self.container_name = os.environ.get("CONTAINER_NAME")
self.running = False
self.restart_enabled = True
@property
def ws_uri(self) -> str:
"""Get the WebSocket URI using the current IP address.
Returns:
WebSocket URI for the Computer API Server
"""
ip_address = "localhost" if not self.container_name else f"{self.container_name}.containers.cloud.trycua.com"
ip_address = (
"localhost"
if not self.container_name
else f"{self.container_name}.containers.cloud.trycua.com"
)
protocol = "wss" if self.container_name else "ws"
port = "8443" if self.container_name else "8000"
return f"{protocol}://{ip_address}:{port}/ws"
async def ping(self) -> bool:
"""
Test connection to the WebSocket endpoint.
Returns:
True if connection successful, False otherwise
"""
try:
# Create a simple ping message
ping_message = {
"command": "get_screen_size",
"params": {}
}
ping_message = {"command": "get_screen_size", "params": {}}
# Try to connect to the WebSocket
async with websockets.connect(
self.ws_uri,
max_size=1024 * 1024 * 10 # 10MB limit to match server
self.ws_uri, max_size=1024 * 1024 * 10 # 10MB limit to match server
) as websocket:
# Send ping message
await websocket.send(json.dumps(ping_message))
# Wait for any response or just close
try:
response = await asyncio.wait_for(websocket.recv(), timeout=5)
@@ -111,30 +112,27 @@ class Watchdog:
except Exception as e:
logger.warning(f"Ping failed: {e}")
return False
def kill_processes_on_port(self, port: int) -> bool:
"""
Kill any processes using the specified port.
Args:
port: Port number to check and kill processes on
Returns:
True if processes were killed or none found, False on error
"""
try:
# Find processes using the port
result = subprocess.run(
["lsof", "-ti", f":{port}"],
capture_output=True,
text=True,
timeout=10
["lsof", "-ti", f":{port}"], capture_output=True, text=True, timeout=10
)
if result.returncode == 0 and result.stdout.strip():
pids = result.stdout.strip().split('\n')
pids = result.stdout.strip().split("\n")
logger.info(f"Found {len(pids)} processes using port {port}: {pids}")
# Kill each process
for pid in pids:
if pid.strip():
@@ -145,42 +143,42 @@ class Watchdog:
logger.warning(f"Timeout killing process {pid}")
except Exception as e:
logger.warning(f"Error killing process {pid}: {e}")
return True
else:
logger.debug(f"No processes found using port {port}")
return True
except subprocess.TimeoutExpired:
logger.error(f"Timeout finding processes on port {port}")
return False
except Exception as e:
logger.error(f"Error finding processes on port {port}: {e}")
return False
def restart_server(self) -> bool:
"""
Attempt to restart the server by killing existing processes and starting new one.
Returns:
True if restart was attempted, False on error
"""
if not self.restart_enabled:
logger.info("Server restart is disabled")
return False
try:
logger.info("Attempting to restart server...")
# Kill processes on the port
port_to_kill = 8443 if self.container_name else self.port
if not self.kill_processes_on_port(port_to_kill):
logger.error("Failed to kill processes on port, restart aborted")
return False
# Wait a moment for processes to die
time.sleep(2)
# Try to restart the server
# In container mode, we can't easily restart, so just log
if self.container_name:
@@ -190,50 +188,50 @@ class Watchdog:
else:
# For local mode, try to restart the CLI
logger.info("Attempting to restart local server...")
# Get the current Python executable and script
python_exe = sys.executable
# Try to find the CLI module
try:
# Build command with all original CLI arguments
cmd = [python_exe, "-m", "computer_server.cli"]
# Add all CLI arguments except watchdog-related ones
for key, value in self.cli_args.items():
if key in ['watchdog', 'watchdog_interval', 'no_restart']:
if key in ["watchdog", "watchdog_interval", "no_restart"]:
continue # Skip watchdog args to avoid recursive watchdog
# Convert underscores to hyphens for CLI args
arg_name = f"--{key.replace('_', '-')}"
if isinstance(value, bool):
if value: # Only add flag if True
cmd.append(arg_name)
else:
cmd.extend([arg_name, str(value)])
logger.info(f"Starting server with command: {' '.join(cmd)}")
# Start process in background
subprocess.Popen(
cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True
start_new_session=True,
)
logger.info("Server restart initiated")
return True
except Exception as e:
logger.error(f"Failed to restart server: {e}")
return False
except Exception as e:
logger.error(f"Error during server restart: {e}")
return False
async def start_monitoring(self) -> None:
"""Start the watchdog monitoring loop."""
self.running = True
@@ -241,14 +239,14 @@ class Watchdog:
logger.info(f"Ping interval: {self.ping_interval} seconds")
if self.container_name:
logger.info(f"Container mode detected: {self.container_name}")
consecutive_failures = 0
max_failures = 3
while self.running:
try:
success = await self.ping()
if success:
if consecutive_failures > 0:
logger.info("Server connection restored")
@@ -257,15 +255,17 @@ class Watchdog:
else:
consecutive_failures += 1
logger.warning(f"Ping failed ({consecutive_failures}/{max_failures})")
if consecutive_failures >= max_failures:
logger.error(f"Server appears to be down after {max_failures} consecutive failures")
logger.error(
f"Server appears to be down after {max_failures} consecutive failures"
)
# Attempt to restart the server
if self.restart_enabled:
logger.info("Attempting automatic server restart...")
restart_success = self.restart_server()
if restart_success:
logger.info("Server restart initiated, waiting before next ping...")
# Wait longer after restart attempt
@@ -275,17 +275,17 @@ class Watchdog:
logger.error("Server restart failed")
else:
logger.warning("Automatic restart is disabled")
# Wait for next ping interval
await asyncio.sleep(self.ping_interval)
except asyncio.CancelledError:
logger.info("Watchdog monitoring cancelled")
break
except Exception as e:
logger.error(f"Unexpected error in watchdog loop: {e}")
await asyncio.sleep(self.ping_interval)
def stop_monitoring(self) -> None:
"""Stop the watchdog monitoring."""
self.running = False
@@ -295,13 +295,13 @@ class Watchdog:
async def run_watchdog(cli_args: Optional[dict] = None, ping_interval: int = 30) -> None:
"""
Run the watchdog monitoring.
Args:
cli_args: Dictionary of CLI arguments to replicate when restarting
ping_interval: Interval between ping checks in seconds
"""
watchdog = Watchdog(cli_args=cli_args, ping_interval=ping_interval)
try:
await watchdog.start_monitoring()
except KeyboardInterrupt:
@@ -313,21 +313,18 @@ async def run_watchdog(cli_args: Optional[dict] = None, ping_interval: int = 30)
if __name__ == "__main__":
# For testing the watchdog standalone
import argparse
parser = argparse.ArgumentParser(description="Run Computer API server watchdog")
parser.add_argument("--host", default="localhost", help="Server host to monitor")
parser.add_argument("--port", type=int, default=8000, help="Server port to monitor")
parser.add_argument("--ping-interval", type=int, default=30, help="Ping interval in seconds")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
cli_args = {
'host': args.host,
'port': args.port
}
cli_args = {"host": args.host, "port": args.port}
asyncio.run(run_watchdog(cli_args, args.ping_interval))

View File

@@ -10,6 +10,7 @@ Usage:
"""
import sys
from computer_server.cli import main
if __name__ == "__main__":

View File

@@ -6,18 +6,22 @@ This script tests both WebSocket (/ws) and REST (/cmd) connections to the Comput
and keeps it alive, allowing you to verify the server is running correctly.
"""
import argparse
import asyncio
import json
import websockets
import argparse
import sys
import aiohttp
import os
import sys
import aiohttp
import dotenv
import websockets
dotenv.load_dotenv()
async def test_websocket_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
async def test_websocket_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None
):
"""Test WebSocket connection to the Computer Server."""
if container_name:
# Container mode: use WSS with container domain and port 8443
@@ -37,19 +41,16 @@ async def test_websocket_connection(host="localhost", port=8000, keep_alive=Fals
if not api_key:
print("Error: API key required for container connections")
return False
print("Sending authentication...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": api_key,
"container_name": container_name
}
"params": {"api_key": api_key, "container_name": container_name},
}
await websocket.send(json.dumps(auth_message))
auth_response = await websocket.recv()
print(f"Authentication response: {auth_response}")
# Check if authentication was successful
auth_data = json.loads(auth_response)
if not auth_data.get("success", False):
@@ -90,7 +91,9 @@ async def test_websocket_connection(host="localhost", port=8000, keep_alive=Fals
return True
async def test_rest_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
async def test_rest_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None
):
"""Test REST connection to the Computer Server."""
if container_name:
# Container mode: use HTTPS with container domain and port 8443
@@ -113,13 +116,11 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
return False
headers["X-Container-Name"] = container_name
headers["X-API-Key"] = api_key
print(f"Using container authentication headers")
print("Using container authentication headers")
# Test screenshot endpoint
async with session.post(
f"{base_url}/cmd",
json={"command": "screenshot", "params": {}},
headers=headers
f"{base_url}/cmd", json={"command": "screenshot", "params": {}}, headers=headers
) as response:
if response.status == 200:
text = await response.text()
@@ -133,7 +134,7 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
async with session.post(
f"{base_url}/cmd",
json={"command": "get_screen_size", "params": {}},
headers=headers
headers=headers,
) as response:
if response.status == 200:
text = await response.text()
@@ -151,7 +152,7 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
async with session.post(
f"{base_url}/cmd",
json={"command": "get_cursor_position", "params": {}},
headers=headers
headers=headers,
) as response:
if response.status == 200:
text = await response.text()
@@ -171,7 +172,9 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
return True
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None):
async def test_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None
):
"""Test connection to the Computer Server using WebSocket or REST."""
if use_rest:
return await test_rest_connection(host, port, keep_alive, container_name, api_key)
@@ -183,40 +186,50 @@ def parse_args():
parser = argparse.ArgumentParser(description="Test connection to Computer Server")
parser.add_argument("--host", default="localhost", help="Host address (default: localhost)")
parser.add_argument("-p", "--port", type=int, default=8000, help="Port number (default: 8000)")
parser.add_argument("-c", "--container-name", help="Container name for cloud connection (uses WSS/HTTPS and port 8443)")
parser.add_argument("--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)")
parser.add_argument(
"-c",
"--container-name",
help="Container name for cloud connection (uses WSS/HTTPS and port 8443)",
)
parser.add_argument(
"--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)"
)
parser.add_argument("--keep-alive", action="store_true", help="Keep connection alive")
parser.add_argument("--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)")
parser.add_argument(
"--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)"
)
return parser.parse_args()
async def main():
args = parse_args()
# Convert hyphenated argument to underscore for function parameter
container_name = getattr(args, 'container_name', None)
container_name = getattr(args, "container_name", None)
# Get API key from argument or environment variable
api_key = getattr(args, 'api_key', None) or os.environ.get('CUA_API_KEY')
api_key = getattr(args, "api_key", None) or os.environ.get("CUA_API_KEY")
# Check if container name is provided but API key is missing
if container_name and not api_key:
print("Warning: Container name provided but no API key found.")
print("Please provide --api-key argument or set CUA_API_KEY environment variable.")
return 1
print(f"Testing {'REST' if args.rest else 'WebSocket'} connection...")
if container_name:
print(f"Container: {container_name}")
print(f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}")
print(
f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}"
)
success = await test_connection(
host=args.host,
port=args.port,
host=args.host,
port=args.port,
keep_alive=args.keep_alive,
container_name=container_name,
use_rest=args.rest,
api_key=api_key
api_key=api_key,
)
return 0 if success else 1