mirror of
https://github.com/trycua/computer.git
synced 2026-01-01 11:00:31 -06:00
Merge branch 'main' into feat/fara-browser-use
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.5.1
|
||||
current_version = 0.5.2
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = agent-v{new_version}
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-agent"
|
||||
version = "0.5.1"
|
||||
version = "0.5.2"
|
||||
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
||||
26
libs/python/bench-ui/README.md
Normal file
26
libs/python/bench-ui/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# CUA Bench UI
|
||||
|
||||
Lightweight webUI window controller for CUA bench environments using pywebview
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from bench_ui import launch_window, get_element_rect, execute_javascript
|
||||
|
||||
# Launch a window with inline HTML content
|
||||
pid = launch_window(html="<html><body><h1>Hello</h1></body></html>")
|
||||
|
||||
# Get element rect in screen space
|
||||
rect = get_element_rect(pid, "h1", space="screen")
|
||||
print(rect)
|
||||
|
||||
# Execute arbitrary JavaScript
|
||||
text = execute_javascript(pid, "document.querySelector('h1')?.textContent")
|
||||
print(text)
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install cua-bench-ui
|
||||
```
|
||||
3
libs/python/bench-ui/bench_ui/__init__.py
Normal file
3
libs/python/bench-ui/bench_ui/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .api import execute_javascript, get_element_rect, launch_window
|
||||
|
||||
__all__ = ["launch_window", "get_element_rect", "execute_javascript"]
|
||||
181
libs/python/bench-ui/bench_ui/api.py
Normal file
181
libs/python/bench-ui/bench_ui/api.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib import request
|
||||
from urllib.error import HTTPError, URLError
|
||||
|
||||
import psutil
|
||||
|
||||
# Map child PID -> listening port
|
||||
_pid_to_port: Dict[int, int] = {}
|
||||
|
||||
|
||||
def _post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
req = request.Request(
|
||||
url, data=data, headers={"Content-Type": "application/json"}, method="POST"
|
||||
)
|
||||
try:
|
||||
with request.urlopen(req, timeout=5) as resp:
|
||||
text = resp.read().decode("utf-8")
|
||||
return json.loads(text)
|
||||
except HTTPError as e:
|
||||
try:
|
||||
body = (e.read() or b"").decode("utf-8", errors="ignore")
|
||||
return json.loads(body)
|
||||
except Exception:
|
||||
return {"error": "http_error", "status": getattr(e, "code", None)}
|
||||
except URLError as e:
|
||||
return {"error": "url_error", "reason": str(e.reason)}
|
||||
|
||||
|
||||
def _detect_port_for_pid(pid: int) -> int:
|
||||
"""Detect a listening local TCP port for the given PID using psutil.
|
||||
|
||||
Fails fast if psutil is unavailable or if no suitable port is found.
|
||||
"""
|
||||
if psutil is None:
|
||||
raise RuntimeError("psutil is required for PID->port detection. Please install psutil.")
|
||||
|
||||
# Scan system-wide connections and filter by PID
|
||||
for c in psutil.net_connections(kind="tcp"):
|
||||
if getattr(c, "pid", None) != pid:
|
||||
continue
|
||||
laddr = getattr(c, "laddr", None)
|
||||
status = str(getattr(c, "status", ""))
|
||||
if not laddr or not isinstance(laddr, tuple) or len(laddr) < 2:
|
||||
continue
|
||||
lip, lport = laddr[0], int(laddr[1])
|
||||
if status.upper() != "LISTEN":
|
||||
continue
|
||||
if lip in ("127.0.0.1", "::1", "0.0.0.0", "::"):
|
||||
return lport
|
||||
|
||||
raise RuntimeError(f"Could not detect listening port for pid {pid}")
|
||||
|
||||
|
||||
def launch_window(
|
||||
url: Optional[str] = None,
|
||||
*,
|
||||
html: Optional[str] = None,
|
||||
folder: Optional[str] = None,
|
||||
title: str = "Window",
|
||||
x: Optional[int] = None,
|
||||
y: Optional[int] = None,
|
||||
width: int = 600,
|
||||
height: int = 400,
|
||||
icon: Optional[str] = None,
|
||||
use_inner_size: bool = False,
|
||||
title_bar_style: str = "default",
|
||||
) -> int:
|
||||
"""Create a pywebview window in a child process and return its PID.
|
||||
|
||||
Preferred input is a URL via the positional `url` parameter.
|
||||
To load inline HTML instead, pass `html=...`.
|
||||
To serve a static folder, pass `folder=...` (path to directory).
|
||||
|
||||
Spawns `python -m bench_ui.child` with a JSON config passed via a temp file.
|
||||
The child prints a single JSON line: {"pid": <pid>, "port": <port>}.
|
||||
We cache pid->port for subsequent control calls like get_element_rect.
|
||||
"""
|
||||
if not url and not html and not folder:
|
||||
raise ValueError("launch_window requires either a url, html, or folder")
|
||||
|
||||
config = {
|
||||
"url": url,
|
||||
"html": html,
|
||||
"folder": folder,
|
||||
"title": title,
|
||||
"x": x,
|
||||
"y": y,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"icon": icon,
|
||||
"use_inner_size": use_inner_size,
|
||||
"title_bar_style": title_bar_style,
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
|
||||
json.dump(config, f)
|
||||
cfg_path = f.name
|
||||
|
||||
try:
|
||||
# Launch child process
|
||||
proc = subprocess.Popen(
|
||||
[sys.executable, "-m", "bench_ui.child", cfg_path],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
assert proc.stdout is not None
|
||||
# Read first line with startup info
|
||||
line = proc.stdout.readline().strip()
|
||||
info = json.loads(line)
|
||||
pid = int(info["pid"]) if "pid" in info else proc.pid
|
||||
port = int(info["port"]) # required
|
||||
_pid_to_port[pid] = port
|
||||
return pid
|
||||
finally:
|
||||
try:
|
||||
os.unlink(cfg_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_element_rect(pid: int, selector: str, *, space: str = "window"):
|
||||
"""Ask the child process to compute element client rect via injected JS.
|
||||
|
||||
Returns a dict like {"x": float, "y": float, "width": float, "height": float} or None if not found.
|
||||
"""
|
||||
if pid not in _pid_to_port:
|
||||
_pid_to_port[pid] = _detect_port_for_pid(pid)
|
||||
port = _pid_to_port[pid]
|
||||
url = f"http://127.0.0.1:{port}/rect"
|
||||
last: Dict[str, Any] = {}
|
||||
for _ in range(30): # ~3s total
|
||||
resp = _post_json(url, {"selector": selector, "space": space})
|
||||
last = resp or {}
|
||||
rect = last.get("rect") if isinstance(last, dict) else None
|
||||
err = last.get("error") if isinstance(last, dict) else None
|
||||
if rect is not None:
|
||||
return rect
|
||||
if err in ("window_not_ready", "invalid_json"):
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
# If other transient errors, brief retry
|
||||
if err:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
time.sleep(0.1)
|
||||
raise RuntimeError(f"Failed to get element rect: {last}")
|
||||
|
||||
|
||||
def execute_javascript(pid: int, javascript: str):
|
||||
"""Execute arbitrary JavaScript in the window and return its result.
|
||||
|
||||
Retries briefly while the window is still becoming ready.
|
||||
"""
|
||||
if pid not in _pid_to_port:
|
||||
_pid_to_port[pid] = _detect_port_for_pid(pid)
|
||||
port = _pid_to_port[pid]
|
||||
url = f"http://127.0.0.1:{port}/eval"
|
||||
last: Dict[str, Any] = {}
|
||||
for _ in range(30): # ~3s total
|
||||
resp = _post_json(url, {"javascript": javascript})
|
||||
last = resp or {}
|
||||
if isinstance(last, dict):
|
||||
if "result" in last:
|
||||
return last["result"]
|
||||
if last.get("error") in ("window_not_ready", "invalid_json"):
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
if last.get("error"):
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
time.sleep(0.1)
|
||||
raise RuntimeError(f"Failed to execute JavaScript: {last}")
|
||||
221
libs/python/bench-ui/bench_ui/child.py
Normal file
221
libs/python/bench-ui/bench_ui/child.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import webview
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
def _get_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _start_http_server(
|
||||
window: webview.Window,
|
||||
port: int,
|
||||
ready_event: threading.Event,
|
||||
html_content: str | None = None,
|
||||
folder_path: str | None = None,
|
||||
):
|
||||
async def rect_handler(request: web.Request):
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"error": "invalid_json"}, status=400)
|
||||
selector = data.get("selector")
|
||||
space = data.get("space", "window")
|
||||
if not isinstance(selector, str):
|
||||
return web.json_response({"error": "selector_required"}, status=400)
|
||||
|
||||
# Ensure window content is loaded
|
||||
if not ready_event.is_set():
|
||||
# give it a short chance to finish loading
|
||||
ready_event.wait(timeout=2.0)
|
||||
if not ready_event.is_set():
|
||||
return web.json_response({"error": "window_not_ready"}, status=409)
|
||||
|
||||
# Safely embed selector into JS
|
||||
selector_js = json.dumps(selector)
|
||||
if space == "screen":
|
||||
# Compute approximate screen coordinates using window metrics
|
||||
js = (
|
||||
"(function(){"
|
||||
f"const s = {selector_js};"
|
||||
"const el = document.querySelector(s);"
|
||||
"if(!el){return null;}"
|
||||
"const r = el.getBoundingClientRect();"
|
||||
"const sx = (window.screenX ?? window.screenLeft ?? 0);"
|
||||
"const syRaw = (window.screenY ?? window.screenTop ?? 0);"
|
||||
"const frameH = (window.outerHeight - window.innerHeight) || 0;"
|
||||
"const sy = syRaw + frameH;"
|
||||
"return {x:sx + r.left, y:sy + r.top, width:r.width, height:r.height};"
|
||||
"})()"
|
||||
)
|
||||
else:
|
||||
js = (
|
||||
"(function(){"
|
||||
f"const s = {selector_js};"
|
||||
"const el = document.querySelector(s);"
|
||||
"if(!el){return null;}"
|
||||
"const r = el.getBoundingClientRect();"
|
||||
"return {x:r.left,y:r.top,width:r.width,height:r.height};"
|
||||
"})()"
|
||||
)
|
||||
try:
|
||||
# Evaluate JS on the target window; this call is thread-safe in pywebview
|
||||
result = window.evaluate_js(js)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
return web.json_response({"rect": result})
|
||||
|
||||
async def eval_handler(request: web.Request):
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"error": "invalid_json"}, status=400)
|
||||
code = data.get("javascript") or data.get("code")
|
||||
if not isinstance(code, str):
|
||||
return web.json_response({"error": "javascript_required"}, status=400)
|
||||
|
||||
if not ready_event.is_set():
|
||||
ready_event.wait(timeout=2.0)
|
||||
if not ready_event.is_set():
|
||||
return web.json_response({"error": "window_not_ready"}, status=409)
|
||||
|
||||
try:
|
||||
result = window.evaluate_js(code)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
return web.json_response({"result": result})
|
||||
|
||||
async def index_handler(request: web.Request):
|
||||
if html_content is None:
|
||||
return web.json_response({"status": "ok", "message": "bench-ui control server"})
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
app = web.Application()
|
||||
|
||||
# If serving a folder, add static file routes
|
||||
if folder_path:
|
||||
app.router.add_static("/", folder_path, show_index=True)
|
||||
else:
|
||||
app.router.add_get("/", index_handler)
|
||||
|
||||
app.router.add_post("/rect", rect_handler)
|
||||
app.router.add_post("/eval", eval_handler)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
def run_loop():
|
||||
asyncio.set_event_loop(loop)
|
||||
runner = web.AppRunner(app)
|
||||
loop.run_until_complete(runner.setup())
|
||||
site = web.TCPSite(runner, "127.0.0.1", port)
|
||||
loop.run_until_complete(site.start())
|
||||
loop.run_forever()
|
||||
|
||||
t = threading.Thread(target=run_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python -m bench_ui.child <config.json>", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
||||
cfg_path = Path(sys.argv[1])
|
||||
cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
|
||||
|
||||
html: Optional[str] = cfg.get("html") or ""
|
||||
url: Optional[str] = cfg.get("url")
|
||||
folder: Optional[str] = cfg.get("folder")
|
||||
title: str = cfg.get("title", "Window")
|
||||
x: Optional[int] = cfg.get("x")
|
||||
y: Optional[int] = cfg.get("y")
|
||||
width: int = int(cfg.get("width", 600))
|
||||
height: int = int(cfg.get("height", 400))
|
||||
icon: Optional[str] = cfg.get("icon")
|
||||
use_inner_size: bool = bool(cfg.get("use_inner_size", False))
|
||||
title_bar_style: str = cfg.get("title_bar_style", "default")
|
||||
|
||||
# Choose port early so we can point the window to it when serving inline HTML or folder
|
||||
port = _get_free_port()
|
||||
|
||||
# Create window
|
||||
if url:
|
||||
window = webview.create_window(
|
||||
title,
|
||||
url=url,
|
||||
width=width,
|
||||
height=height,
|
||||
x=x,
|
||||
y=y,
|
||||
confirm_close=False,
|
||||
text_select=True,
|
||||
background_color="#FFFFFF",
|
||||
)
|
||||
html_for_server = None
|
||||
folder_for_server = None
|
||||
elif folder:
|
||||
# Serve static folder at control server root and point window to index.html
|
||||
resolved_url = f"http://127.0.0.1:{port}/index.html"
|
||||
window = webview.create_window(
|
||||
title,
|
||||
url=resolved_url,
|
||||
width=width,
|
||||
height=height,
|
||||
x=x,
|
||||
y=y,
|
||||
confirm_close=False,
|
||||
text_select=True,
|
||||
background_color="#FFFFFF",
|
||||
)
|
||||
html_for_server = None
|
||||
folder_for_server = folder
|
||||
else:
|
||||
# Serve inline HTML at control server root and point window to it
|
||||
resolved_url = f"http://127.0.0.1:{port}/"
|
||||
window = webview.create_window(
|
||||
title,
|
||||
url=resolved_url,
|
||||
width=width,
|
||||
height=height,
|
||||
x=x,
|
||||
y=y,
|
||||
confirm_close=False,
|
||||
text_select=True,
|
||||
background_color="#FFFFFF",
|
||||
)
|
||||
html_for_server = html
|
||||
folder_for_server = None
|
||||
|
||||
# Track when the page is loaded so JS execution succeeds
|
||||
window_ready = threading.Event()
|
||||
|
||||
def _on_loaded():
|
||||
window_ready.set()
|
||||
|
||||
window.events.loaded += _on_loaded # type: ignore[attr-defined]
|
||||
|
||||
# Start HTTP server for control (and optionally serve inline HTML or static folder)
|
||||
_start_http_server(
|
||||
window, port, window_ready, html_content=html_for_server, folder_path=folder_for_server
|
||||
)
|
||||
|
||||
# Print startup info for parent to read
|
||||
print(json.dumps({"pid": os.getpid(), "port": port}), flush=True)
|
||||
|
||||
# Start GUI (blocking)
|
||||
webview.start(debug=os.environ.get("CUA_BENCH_UI_DEBUG", "false").lower() in ("true", "1"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
40
libs/python/bench-ui/examples/folder_example.py
Normal file
40
libs/python/bench-ui/examples/folder_example.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
from bench_ui import launch_window, get_element_rect, execute_javascript
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
def main():
|
||||
os.environ["CUA_BENCH_UI_DEBUG"] = "1"
|
||||
|
||||
# Get the path to the gui folder
|
||||
gui_folder = Path(__file__).parent / "gui"
|
||||
|
||||
# Launch a window serving the static folder
|
||||
pid = launch_window(
|
||||
folder=str(gui_folder),
|
||||
title="Static Folder Example",
|
||||
width=800,
|
||||
height=600,
|
||||
)
|
||||
print(f"Launched window with PID: {pid}")
|
||||
print(f"Serving folder: {gui_folder}")
|
||||
|
||||
# Give the window a moment to render
|
||||
time.sleep(1.5)
|
||||
|
||||
# Query the client rect of the button element
|
||||
rect = get_element_rect(pid, "#testButton", space="window")
|
||||
print("Button rect (window space):", rect)
|
||||
|
||||
# Check if button has been clicked
|
||||
clicked = execute_javascript(pid, "document.getElementById('testButton').disabled")
|
||||
print("Button clicked:", clicked)
|
||||
|
||||
# Get the page title
|
||||
title = execute_javascript(pid, "document.title")
|
||||
print("Page title:", title)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
42
libs/python/bench-ui/examples/gui/index.html
Normal file
42
libs/python/bench-ui/examples/gui/index.html
Normal file
@@ -0,0 +1,42 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Static Folder Example</title>
|
||||
<link rel="stylesheet" href="styles.css">
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Static Folder Example</h1>
|
||||
<p>This page is served from a static folder using bench-ui!</p>
|
||||
|
||||
<div class="image-container">
|
||||
<img src="logo.svg" alt="Example SVG Logo" class="logo">
|
||||
</div>
|
||||
|
||||
<div class="info">
|
||||
<p>This example demonstrates:</p>
|
||||
<ul>
|
||||
<li>Serving a static folder with bench-ui</li>
|
||||
<li>Loading external CSS files (styles.css)</li>
|
||||
<li>Loading SVG images (logo.svg)</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<button id="testButton" class="btn">Click Me!</button>
|
||||
<p id="status"></p>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.getElementById('testButton').addEventListener('click', function () {
|
||||
document.getElementById('status').textContent = 'Button clicked! ✓';
|
||||
this.disabled = true;
|
||||
this.textContent = 'Clicked!';
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
24
libs/python/bench-ui/examples/gui/logo.svg
Normal file
24
libs/python/bench-ui/examples/gui/logo.svg
Normal file
@@ -0,0 +1,24 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 200 200">
|
||||
<defs>
|
||||
<linearGradient id="grad1" x1="0%" y1="0%" x2="100%" y2="100%">
|
||||
<stop offset="0%" style="stop-color:#667eea;stop-opacity:1" />
|
||||
<stop offset="100%" style="stop-color:#764ba2;stop-opacity:1" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
|
||||
<!-- Background circle -->
|
||||
<circle cx="100" cy="100" r="95" fill="url(#grad1)" />
|
||||
|
||||
<!-- Window icon -->
|
||||
<rect x="50" y="50" width="100" height="100" rx="8" fill="white" opacity="0.9" />
|
||||
|
||||
<!-- Window panes -->
|
||||
<line x1="100" y1="50" x2="100" y2="150" stroke="url(#grad1)" stroke-width="4" />
|
||||
<line x1="50" y1="100" x2="150" y2="100" stroke="url(#grad1)" stroke-width="4" />
|
||||
|
||||
<!-- Decorative dots -->
|
||||
<circle cx="75" cy="75" r="8" fill="url(#grad1)" />
|
||||
<circle cx="125" cy="75" r="8" fill="url(#grad1)" />
|
||||
<circle cx="75" cy="125" r="8" fill="url(#grad1)" />
|
||||
<circle cx="125" cy="125" r="8" fill="url(#grad1)" />
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 963 B |
92
libs/python/bench-ui/examples/gui/styles.css
Normal file
92
libs/python/bench-ui/examples/gui/styles.css
Normal file
@@ -0,0 +1,92 @@
|
||||
body {
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.container {
|
||||
background: white;
|
||||
border-radius: 12px;
|
||||
padding: 40px;
|
||||
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
|
||||
max-width: 600px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #333;
|
||||
margin-top: 0;
|
||||
font-size: 2em;
|
||||
}
|
||||
|
||||
p {
|
||||
color: #666;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.image-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
margin: 30px 0;
|
||||
}
|
||||
|
||||
.logo {
|
||||
width: 150px;
|
||||
height: 150px;
|
||||
}
|
||||
|
||||
.info {
|
||||
background: #f8f9fa;
|
||||
border-left: 4px solid #667eea;
|
||||
padding: 20px;
|
||||
margin: 20px 0;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.info ul {
|
||||
margin: 10px 0;
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
.info li {
|
||||
color: #555;
|
||||
margin: 8px 0;
|
||||
}
|
||||
|
||||
.btn {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 12px 30px;
|
||||
font-size: 16px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s, box-shadow 0.2s;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.btn:hover:not(:disabled) {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.btn:active:not(:disabled) {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.btn:disabled {
|
||||
opacity: 0.6;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
#status {
|
||||
margin-top: 15px;
|
||||
font-weight: 600;
|
||||
color: #28a745;
|
||||
font-size: 18px;
|
||||
}
|
||||
BIN
libs/python/bench-ui/examples/output_overlay.png
Normal file
BIN
libs/python/bench-ui/examples/output_overlay.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 743 KiB |
80
libs/python/bench-ui/examples/simple_example.py
Normal file
80
libs/python/bench-ui/examples/simple_example.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from bench_ui import execute_javascript, get_element_rect, launch_window
|
||||
|
||||
HTML = """
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>Bench UI Example</title>
|
||||
<style>
|
||||
body { font-family: system-ui, sans-serif; margin: 24px; }
|
||||
#target { width: 220px; height: 120px; background: #4f46e5; color: white; display: flex; align-items: center; justify-content: center; border-radius: 8px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Bench UI Example</h1>
|
||||
<div id="target">Hello from pywebview</div>
|
||||
|
||||
|
||||
<h1>Click the button</h1>
|
||||
<button id="submit" class="btn" data-instruction="the button">Submit</button>
|
||||
<script>
|
||||
window.__submitted = false;
|
||||
document.getElementById('submit').addEventListener('click', function() {
|
||||
window.__submitted = true;
|
||||
this.textContent = 'Submitted!';
|
||||
this.disabled = true;
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
os.environ["CUA_BENCH_UI_DEBUG"] = "1"
|
||||
|
||||
# Launch a window with inline HTML content
|
||||
pid = launch_window(
|
||||
html=HTML,
|
||||
title="Bench UI Example",
|
||||
width=800,
|
||||
height=600,
|
||||
)
|
||||
print(f"Launched window with PID: {pid}")
|
||||
|
||||
# Give the window a brief moment to render
|
||||
time.sleep(1.0)
|
||||
|
||||
# Query the client rect of an element via CSS selector in SCREEN space
|
||||
rect = get_element_rect(pid, "#target", space="screen")
|
||||
print("Element rect (screen space):", rect)
|
||||
|
||||
# Take a screenshot and overlay the bbox
|
||||
try:
|
||||
from PIL import ImageDraw, ImageGrab
|
||||
|
||||
img = ImageGrab.grab() # full screen
|
||||
draw = ImageDraw.Draw(img)
|
||||
x, y, w, h = rect["x"], rect["y"], rect["width"], rect["height"]
|
||||
box = (x, y, x + w, y + h)
|
||||
draw.rectangle(box, outline=(255, 0, 0), width=3)
|
||||
out_path = Path(__file__).parent / "output_overlay.png"
|
||||
img.save(out_path)
|
||||
print(f"Saved overlay screenshot to: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to capture/annotate screenshot: {e}")
|
||||
|
||||
# Execute arbitrary JavaScript
|
||||
text = execute_javascript(pid, "window.__submitted")
|
||||
print("text:", text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
25
libs/python/bench-ui/pyproject.toml
Normal file
25
libs/python/bench-ui/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[build-system]
|
||||
requires = ["pdm-backend"]
|
||||
build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-bench-ui"
|
||||
version = "0.7.0"
|
||||
description = "Lightweight webUI window controller for CUA bench using pywebview"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "TryCua", email = "gh@trycua.com" }
|
||||
]
|
||||
dependencies = [
|
||||
"pywebview>=5.3",
|
||||
"aiohttp>=3.9.0",
|
||||
"psutil>=5.9",
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
|
||||
[tool.pdm]
|
||||
distribution = true
|
||||
|
||||
[tool.pdm.build]
|
||||
includes = ["bench_ui/"]
|
||||
source-includes = ["README.md"]
|
||||
50
libs/python/bench-ui/tests/test_port_detection.py
Normal file
50
libs/python/bench-ui/tests/test_port_detection.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import time
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
from bench_ui import execute_javascript, launch_window
|
||||
from bench_ui.api import _pid_to_port
|
||||
|
||||
HTML = """
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>Bench UI Test</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="t">hello-world</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def test_execute_js_after_clearing_port_mapping():
|
||||
# Skip if pywebview backend is unavailable on this machine
|
||||
pywebview = pytest.importorskip("webview")
|
||||
|
||||
pid = launch_window(html=HTML, title="Bench UI Test", width=400, height=300)
|
||||
try:
|
||||
# Give a brief moment for window to render and server to start
|
||||
time.sleep(1.0)
|
||||
|
||||
# Sanity: mapping should exist initially
|
||||
assert pid in _pid_to_port
|
||||
|
||||
# Clear the cached mapping to simulate a fresh process lookup
|
||||
del _pid_to_port[pid]
|
||||
|
||||
# Now execute JS; this should succeed by detecting the port via psutil
|
||||
result = execute_javascript(pid, "document.querySelector('#t')?.textContent")
|
||||
assert result == "hello-world"
|
||||
finally:
|
||||
# Best-effort cleanup of the child process
|
||||
try:
|
||||
p = psutil.Process(pid)
|
||||
p.terminate()
|
||||
try:
|
||||
p.wait(timeout=3)
|
||||
except psutil.TimeoutExpired:
|
||||
p.kill()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.1.30
|
||||
current_version = 0.1.31
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = computer-server-v{new_version}
|
||||
|
||||
@@ -40,7 +40,7 @@ Refer to this notebook for a step-by-step guide on how to use the Computer-Use S
|
||||
|
||||
## Docs
|
||||
|
||||
- [Commands](https://cua.ai/docs/libraries/computer-server/Commands)
|
||||
- [REST-API](https://cua.ai/docs/libraries/computer-server/REST-API)
|
||||
- [WebSocket-API](https://cua.ai/docs/libraries/computer-server/WebSocket-API)
|
||||
- [Index](https://cua.ai/docs/libraries/computer-server)
|
||||
- [Commands](https://cua.ai/docs/computer-sdk/computer-server/Commands)
|
||||
- [REST-API](https://cua.ai/docs/computer-sdk/computer-server/REST-API)
|
||||
- [WebSocket-API](https://cua.ai/docs/computer-sdk/computer-server/WebSocket-API)
|
||||
- [Index](https://cua.ai/docs/computer-sdk/computer-server)
|
||||
|
||||
@@ -24,8 +24,8 @@ from fastapi import (
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from .handlers.factory import HandlerFactory
|
||||
from .browser import get_browser_manager
|
||||
from .handlers.factory import HandlerFactory
|
||||
|
||||
# Authentication session TTL (in seconds). Override via env var CUA_AUTH_TTL_SECONDS. Default: 60s
|
||||
AUTH_SESSION_TTL_SECONDS: int = int(os.environ.get("CUA_AUTH_TTL_SECONDS", "60"))
|
||||
@@ -805,7 +805,7 @@ async def playwright_exec_endpoint(
|
||||
try:
|
||||
browser_manager = get_browser_manager()
|
||||
result = await browser_manager.execute_command(command, params)
|
||||
|
||||
|
||||
if result.get("success"):
|
||||
return JSONResponse(content=result)
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-computer-server"
|
||||
version = "0.1.30"
|
||||
version = "0.1.31"
|
||||
|
||||
description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
|
||||
authors = [
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.4.17
|
||||
current_version = 0.4.18
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = computer-v{new_version}
|
||||
|
||||
@@ -7,7 +7,28 @@ import platform
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import ParamSpec
|
||||
except Exception: # pragma: no cover
|
||||
from typing_extensions import ParamSpec # type: ignore
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
from core.telemetry import is_telemetry_enabled, record_event
|
||||
from PIL import Image
|
||||
@@ -66,8 +87,9 @@ class Computer:
|
||||
verbosity: Union[int, LogLevel] = logging.INFO,
|
||||
telemetry_enabled: bool = True,
|
||||
provider_type: Union[str, VMProviderType] = VMProviderType.LUME,
|
||||
port: Optional[int] = 7777,
|
||||
provider_port: Optional[int] = 7777,
|
||||
noVNC_port: Optional[int] = 8006,
|
||||
api_port: Optional[int] = None,
|
||||
host: str = os.environ.get("PYLUME_HOST", "localhost"),
|
||||
storage: Optional[str] = None,
|
||||
ephemeral: bool = False,
|
||||
@@ -118,14 +140,19 @@ class Computer:
|
||||
|
||||
# Store original parameters
|
||||
self.image = image
|
||||
self.port = port
|
||||
self.provider_port = provider_port
|
||||
self.noVNC_port = noVNC_port
|
||||
self.api_port = api_port
|
||||
self.host = host
|
||||
self.os_type = os_type
|
||||
self.provider_type = provider_type
|
||||
self.ephemeral = ephemeral
|
||||
self.api_key = api_key if self.provider_type == VMProviderType.CLOUD else None
|
||||
|
||||
# Set default API port if not specified
|
||||
if self.api_port is None:
|
||||
self.api_port = 8443 if self.api_key else 8000
|
||||
|
||||
self.api_key = api_key
|
||||
self.experiments = experiments or []
|
||||
|
||||
if "app-use" in self.experiments:
|
||||
@@ -273,7 +300,7 @@ class Computer:
|
||||
interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type, ip_address=ip_address # type: ignore[arg-type]
|
||||
os=self.os_type, ip_address=ip_address, api_port=self.api_port # type: ignore[arg-type]
|
||||
),
|
||||
)
|
||||
self._interface = interface
|
||||
@@ -300,7 +327,7 @@ class Computer:
|
||||
storage = "ephemeral" if self.ephemeral else self.storage
|
||||
verbose = self.verbosity >= LogLevel.DEBUG
|
||||
ephemeral = self.ephemeral
|
||||
port = self.port if self.port is not None else 7777
|
||||
port = self.provider_port if self.provider_port is not None else 7777
|
||||
host = self.host if self.host else "localhost"
|
||||
image = self.image
|
||||
shared_path = self.shared_path
|
||||
@@ -365,6 +392,7 @@ class Computer:
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral,
|
||||
noVNC_port=noVNC_port,
|
||||
api_port=self.api_port,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider type: {self.provider_type}")
|
||||
@@ -513,13 +541,14 @@ class Computer:
|
||||
ip_address=ip_address,
|
||||
api_key=self.api_key,
|
||||
vm_name=self.config.name,
|
||||
api_port=self.api_port,
|
||||
),
|
||||
)
|
||||
else:
|
||||
interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type, ip_address=ip_address
|
||||
os=self.os_type, ip_address=ip_address, api_port=self.api_port
|
||||
),
|
||||
)
|
||||
|
||||
@@ -533,15 +562,13 @@ class Computer:
|
||||
# Use a single timeout for the entire connection process
|
||||
# The VM should already be ready at this point, so we're just establishing the connection
|
||||
await self._interface.wait_for_ready(timeout=30)
|
||||
self.logger.info("WebSocket interface connected successfully")
|
||||
self.logger.info("Sandbox interface connected successfully")
|
||||
except TimeoutError as e:
|
||||
self.logger.error(f"Failed to connect to WebSocket interface at {ip_address}")
|
||||
port = getattr(self._interface, "_api_port", 8000) # Default to 8000 if not set
|
||||
self.logger.error(f"Failed to connect to sandbox interface at {ip_address}:{port}")
|
||||
raise TimeoutError(
|
||||
f"Could not connect to WebSocket interface at {ip_address}:8000/ws: {str(e)}"
|
||||
f"Could not connect to sandbox interface at {ip_address}:{port}: {str(e)}"
|
||||
)
|
||||
# self.logger.warning(
|
||||
# f"Could not connect to WebSocket interface at {ip_address}:8000/ws: {str(e)}, expect missing functionality"
|
||||
# )
|
||||
|
||||
# Create an event to keep the VM running in background if needed
|
||||
if not self.use_host_computer_server:
|
||||
@@ -688,6 +715,7 @@ class Computer:
|
||||
ip_address=ip_address,
|
||||
api_key=self.api_key,
|
||||
vm_name=self.config.name,
|
||||
api_port=self.api_port,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -696,6 +724,7 @@ class Computer:
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
ip_address=ip_address,
|
||||
api_port=self.api_port,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1013,7 +1042,7 @@ class Computer:
|
||||
else:
|
||||
# POSIX (macOS/Linux)
|
||||
venv_path = f"$HOME/.venvs/{venv_name}"
|
||||
create_cmd = f'mkdir -p "$HOME/.venvs" && python3 -m venv "{venv_path}"'
|
||||
create_cmd = f'mkdir -p "$HOME/.venvs" && python -m venv "{venv_path}"'
|
||||
# Check if venv exists, if not create it
|
||||
check_cmd = f'test -d "{venv_path}" || ({create_cmd})'
|
||||
_ = await self.interface.run_command(check_cmd)
|
||||
@@ -1024,7 +1053,25 @@ class Computer:
|
||||
if requirements_str
|
||||
else "echo No requirements to install"
|
||||
)
|
||||
return await self.interface.run_command(install_cmd)
|
||||
return await self.interface.run_command(install_cmd)
|
||||
|
||||
async def pip_install(self, requirements: list[str]):
|
||||
"""Install packages using the system Python/pip (no venv).
|
||||
|
||||
Args:
|
||||
requirements: List of package requirements to install globally/user site.
|
||||
|
||||
Returns:
|
||||
Tuple of (stdout, stderr) from the installation command
|
||||
"""
|
||||
requirements = requirements or []
|
||||
if not requirements:
|
||||
return await self.interface.run_command("echo No requirements to install")
|
||||
|
||||
# Use python -m pip for cross-platform consistency
|
||||
reqs = " ".join(requirements)
|
||||
install_cmd = f"python -m pip install {reqs}"
|
||||
return await self.interface.run_command(install_cmd)
|
||||
|
||||
async def venv_cmd(self, venv_name: str, command: str):
|
||||
"""Execute a shell command in a virtual environment.
|
||||
@@ -1074,19 +1121,11 @@ class Computer:
|
||||
The result of the function execution, or raises any exception that occurred
|
||||
"""
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
try:
|
||||
# Get function source code using inspect.getsource
|
||||
source = inspect.getsource(python_func)
|
||||
# Remove common leading whitespace (dedent)
|
||||
func_source = textwrap.dedent(source).strip()
|
||||
|
||||
# Remove decorators
|
||||
while func_source.lstrip().startswith("@"):
|
||||
func_source = func_source.split("\n", 1)[1].strip()
|
||||
func_source = helpers.generate_source_code(python_func)
|
||||
|
||||
# Get function name for execution
|
||||
func_name = python_func.__name__
|
||||
@@ -1101,19 +1140,23 @@ class Computer:
|
||||
raise Exception(f"Failed to reconstruct function source: {e}")
|
||||
|
||||
# Create Python code that will define and execute the function
|
||||
args_b64 = base64.b64encode(args_json.encode("utf-8")).decode("ascii")
|
||||
kwargs_b64 = base64.b64encode(kwargs_json.encode("utf-8")).decode("ascii")
|
||||
|
||||
python_code = f'''
|
||||
import json
|
||||
import traceback
|
||||
import base64
|
||||
|
||||
try:
|
||||
# Define the function from source
|
||||
{textwrap.indent(func_source, " ")}
|
||||
|
||||
# Deserialize args and kwargs from JSON
|
||||
args_json = """{args_json}"""
|
||||
kwargs_json = """{kwargs_json}"""
|
||||
args = json.loads(args_json)
|
||||
kwargs = json.loads(kwargs_json)
|
||||
# Deserialize args and kwargs from base64 JSON
|
||||
_args_b64 = """{args_b64}"""
|
||||
_kwargs_b64 = """{kwargs_b64}"""
|
||||
args = json.loads(base64.b64decode(_args_b64).decode('utf-8'))
|
||||
kwargs = json.loads(base64.b64decode(_kwargs_b64).decode('utf-8'))
|
||||
|
||||
# Execute the function
|
||||
result = {func_name}(*args, **kwargs)
|
||||
@@ -1177,10 +1220,21 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
|
||||
if output_payload["success"]:
|
||||
return output_payload["result"]
|
||||
else:
|
||||
import builtins
|
||||
|
||||
# Recreate and raise the original exception
|
||||
error_info = output_payload["error"]
|
||||
error_class = eval(error_info["type"])
|
||||
raise error_class(error_info["message"])
|
||||
error_info = output_payload.get("error", {}) or {}
|
||||
err_type = error_info.get("type") or "Exception"
|
||||
err_msg = error_info.get("message") or ""
|
||||
err_tb = error_info.get("traceback") or ""
|
||||
|
||||
exc_cls = getattr(builtins, err_type, None)
|
||||
if isinstance(exc_cls, type) and issubclass(exc_cls, BaseException):
|
||||
# Built-in exception: rethrow with remote traceback appended
|
||||
raise exc_cls(f"{err_msg}\n\nRemote traceback:\n{err_tb}")
|
||||
else:
|
||||
# Non built-in: raise a safe local error carrying full remote context
|
||||
raise RuntimeError(f"{err_type}: {err_msg}\n\nRemote traceback:\n{err_tb}")
|
||||
else:
|
||||
raise Exception("Invalid output format: markers found but no content between them")
|
||||
else:
|
||||
@@ -1188,3 +1242,345 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
|
||||
raise Exception(
|
||||
f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
async def venv_exec_background(
|
||||
self, venv_name: str, python_func, *args, requirements: Optional[List[str]] = None, **kwargs
|
||||
) -> int:
|
||||
"""Run the Python function in the venv in the background and return the PID.
|
||||
|
||||
Uses a short launcher Python that spawns a detached child and exits immediately.
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import textwrap
|
||||
import time as _time
|
||||
|
||||
try:
|
||||
func_source = helpers.generate_source_code(python_func)
|
||||
func_name = python_func.__name__
|
||||
args_json = json.dumps(args, default=str)
|
||||
kwargs_json = json.dumps(kwargs, default=str)
|
||||
except OSError as e:
|
||||
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to reconstruct function source: {e}")
|
||||
|
||||
reqs_list = requirements or []
|
||||
reqs_json = json.dumps(reqs_list)
|
||||
|
||||
# Create Python code that will define and execute the function
|
||||
args_b64 = base64.b64encode(args_json.encode("utf-8")).decode("ascii")
|
||||
kwargs_b64 = base64.b64encode(kwargs_json.encode("utf-8")).decode("ascii")
|
||||
|
||||
payload_code = (
|
||||
f'''
|
||||
import json
|
||||
import traceback
|
||||
import base64
|
||||
|
||||
try:
|
||||
# Define the function from source
|
||||
{textwrap.indent(func_source, " ")}
|
||||
|
||||
# Deserialize args and kwargs from base64 JSON
|
||||
_args_b64 = """{args_b64}"""
|
||||
_kwargs_b64 = """{kwargs_b64}"""
|
||||
args = json.loads(base64.b64decode(_args_b64).decode('utf-8'))
|
||||
kwargs = json.loads(base64.b64decode(_kwargs_b64).decode('utf-8'))
|
||||
|
||||
# Ensure requirements inside the active venv
|
||||
for pkg in json.loads('''
|
||||
+ repr(reqs_json)
|
||||
+ """):
|
||||
if pkg:
|
||||
import subprocess, sys
|
||||
subprocess.run([sys.executable, '-m', 'pip', 'install', pkg], check=False)
|
||||
_ = {func_name}(*args, **kwargs)
|
||||
except Exception:
|
||||
import sys
|
||||
sys.stderr.write(traceback.format_exc())
|
||||
"""
|
||||
)
|
||||
payload_b64 = base64.b64encode(payload_code.encode("utf-8")).decode("ascii")
|
||||
|
||||
if self.os_type == "windows":
|
||||
# Launcher spawns detached child and prints its PID
|
||||
launcher_code = f"""
|
||||
import base64, subprocess, os, sys
|
||||
DETACHED_PROCESS = 0x00000008
|
||||
CREATE_NEW_PROCESS_GROUP = 0x00000200
|
||||
creationflags = DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP
|
||||
code = base64.b64decode("{payload_b64}").decode("utf-8")
|
||||
p = subprocess.Popen(["python", "-c", code], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, creationflags=creationflags)
|
||||
print(p.pid)
|
||||
"""
|
||||
launcher_b64 = base64.b64encode(launcher_code.encode("utf-8")).decode("ascii")
|
||||
venv_path = f"%USERPROFILE%\\.venvs\\{venv_name}"
|
||||
cmd = (
|
||||
'cmd /c "'
|
||||
f'call "{venv_path}\\Scripts\\activate.bat" && '
|
||||
f"python -c \"import base64; exec(base64.b64decode('{launcher_b64}').decode('utf-8'))\""
|
||||
'"'
|
||||
)
|
||||
result = await self.interface.run_command(cmd)
|
||||
pid_str = (result.stdout or "").strip().splitlines()[-1].strip()
|
||||
return int(pid_str)
|
||||
else:
|
||||
log = f"/tmp/cua_bg_{int(_time.time())}.log"
|
||||
launcher_code = f"""
|
||||
import base64, subprocess, os, sys
|
||||
code = base64.b64decode("{payload_b64}").decode("utf-8")
|
||||
with open("{log}", "ab", buffering=0) as f:
|
||||
p = subprocess.Popen(["python", "-c", code], stdout=f, stderr=subprocess.STDOUT, preexec_fn=getattr(os, "setsid", None))
|
||||
print(p.pid)
|
||||
"""
|
||||
launcher_b64 = base64.b64encode(launcher_code.encode("utf-8")).decode("ascii")
|
||||
venv_path = f"$HOME/.venvs/{venv_name}"
|
||||
shell = (
|
||||
f'. "{venv_path}/bin/activate" && '
|
||||
f"python -c \"import base64; exec(base64.b64decode('{launcher_b64}').decode('utf-8'))\""
|
||||
)
|
||||
result = await self.interface.run_command(shell)
|
||||
pid_str = (result.stdout or "").strip().splitlines()[-1].strip()
|
||||
return int(pid_str)
|
||||
|
||||
async def python_exec(self, python_func, *args, **kwargs):
|
||||
"""Execute a Python function using the system Python (no venv).
|
||||
|
||||
Uses source extraction and base64 transport, mirroring venv_exec but
|
||||
without virtual environment activation.
|
||||
|
||||
Returns the function result or raises a reconstructed exception with
|
||||
remote traceback context appended.
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
try:
|
||||
func_source = helpers.generate_source_code(python_func)
|
||||
func_name = python_func.__name__
|
||||
args_json = json.dumps(args, default=str)
|
||||
kwargs_json = json.dumps(kwargs, default=str)
|
||||
except OSError as e:
|
||||
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to reconstruct function source: {e}")
|
||||
|
||||
# Create Python code that will define and execute the function
|
||||
args_b64 = base64.b64encode(args_json.encode("utf-8")).decode("ascii")
|
||||
kwargs_b64 = base64.b64encode(kwargs_json.encode("utf-8")).decode("ascii")
|
||||
|
||||
python_code = f'''
|
||||
import json
|
||||
import traceback
|
||||
import base64
|
||||
|
||||
try:
|
||||
# Define the function from source
|
||||
{textwrap.indent(func_source, " ")}
|
||||
|
||||
# Deserialize args and kwargs from base64 JSON
|
||||
_args_b64 = """{args_b64}"""
|
||||
_kwargs_b64 = """{kwargs_b64}"""
|
||||
args = json.loads(base64.b64decode(_args_b64).decode('utf-8'))
|
||||
kwargs = json.loads(base64.b64decode(_kwargs_b64).decode('utf-8'))
|
||||
|
||||
# Execute the function
|
||||
result = {func_name}(*args, **kwargs)
|
||||
|
||||
# Create success output payload
|
||||
output_payload = {{
|
||||
"success": True,
|
||||
"result": result,
|
||||
"error": None
|
||||
}}
|
||||
|
||||
except Exception as e:
|
||||
# Create error output payload
|
||||
output_payload = {{
|
||||
"success": False,
|
||||
"result": None,
|
||||
"error": {{
|
||||
"type": type(e).__name__,
|
||||
"message": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}}
|
||||
}}
|
||||
|
||||
# Serialize the output payload as JSON
|
||||
import json
|
||||
output_json = json.dumps(output_payload, default=str)
|
||||
|
||||
# Print the JSON output with markers
|
||||
print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
|
||||
'''
|
||||
|
||||
encoded_code = base64.b64encode(python_code.encode("utf-8")).decode("ascii")
|
||||
python_command = (
|
||||
f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
|
||||
)
|
||||
result = await self.interface.run_command(python_command)
|
||||
|
||||
start_marker = "<<<VENV_EXEC_START>>>"
|
||||
end_marker = "<<<VENV_EXEC_END>>>"
|
||||
|
||||
print(result.stdout[: result.stdout.find(start_marker)])
|
||||
|
||||
if start_marker in result.stdout and end_marker in result.stdout:
|
||||
start_idx = result.stdout.find(start_marker) + len(start_marker)
|
||||
end_idx = result.stdout.find(end_marker)
|
||||
if start_idx < end_idx:
|
||||
output_json = result.stdout[start_idx:end_idx]
|
||||
try:
|
||||
output_payload = json.loads(output_json)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode output payload: {e}")
|
||||
|
||||
if output_payload["success"]:
|
||||
return output_payload["result"]
|
||||
else:
|
||||
import builtins
|
||||
|
||||
error_info = output_payload.get("error", {}) or {}
|
||||
err_type = error_info.get("type") or "Exception"
|
||||
err_msg = error_info.get("message") or ""
|
||||
err_tb = error_info.get("traceback") or ""
|
||||
exc_cls = getattr(builtins, err_type, None)
|
||||
if isinstance(exc_cls, type) and issubclass(exc_cls, BaseException):
|
||||
raise exc_cls(f"{err_msg}\n\nRemote traceback:\n{err_tb}")
|
||||
else:
|
||||
raise RuntimeError(f"{err_type}: {err_msg}\n\nRemote traceback:\n{err_tb}")
|
||||
else:
|
||||
raise Exception("Invalid output format: markers found but no content between them")
|
||||
else:
|
||||
raise Exception(
|
||||
f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
async def python_exec_background(
|
||||
self, python_func, *args, requirements: Optional[List[str]] = None, **kwargs
|
||||
) -> int:
|
||||
"""Run a Python function with the system interpreter in the background and return PID.
|
||||
|
||||
Uses a short launcher Python that spawns a detached child and exits immediately.
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import textwrap
|
||||
import time as _time
|
||||
|
||||
try:
|
||||
func_source = helpers.generate_source_code(python_func)
|
||||
func_name = python_func.__name__
|
||||
args_json = json.dumps(args, default=str)
|
||||
kwargs_json = json.dumps(kwargs, default=str)
|
||||
except OSError as e:
|
||||
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to reconstruct function source: {e}")
|
||||
|
||||
# Create Python code that will define and execute the function
|
||||
args_b64 = base64.b64encode(args_json.encode("utf-8")).decode("ascii")
|
||||
kwargs_b64 = base64.b64encode(kwargs_json.encode("utf-8")).decode("ascii")
|
||||
|
||||
payload_code = f'''
|
||||
import json
|
||||
import traceback
|
||||
import base64
|
||||
|
||||
try:
|
||||
# Define the function from source
|
||||
{textwrap.indent(func_source, " ")}
|
||||
|
||||
# Deserialize args and kwargs from base64 JSON
|
||||
_args_b64 = """{args_b64}"""
|
||||
_kwargs_b64 = """{kwargs_b64}"""
|
||||
args = json.loads(base64.b64decode(_args_b64).decode('utf-8'))
|
||||
kwargs = json.loads(base64.b64decode(_kwargs_b64).decode('utf-8'))
|
||||
|
||||
_ = {func_name}(*args, **kwargs)
|
||||
except Exception:
|
||||
import sys
|
||||
sys.stderr.write(traceback.format_exc())
|
||||
'''
|
||||
payload_b64 = base64.b64encode(payload_code.encode("utf-8")).decode("ascii")
|
||||
|
||||
if self.os_type == "windows":
|
||||
launcher_code = f"""
|
||||
import base64, subprocess, os, sys
|
||||
DETACHED_PROCESS = 0x00000008
|
||||
CREATE_NEW_PROCESS_GROUP = 0x00000200
|
||||
creationflags = DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP
|
||||
code = base64.b64decode("{payload_b64}").decode("utf-8")
|
||||
p = subprocess.Popen(["python", "-c", code], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, creationflags=creationflags)
|
||||
print(p.pid)
|
||||
"""
|
||||
launcher_b64 = base64.b64encode(launcher_code.encode("utf-8")).decode("ascii")
|
||||
cmd = f"python -c \"import base64; exec(base64.b64decode('{launcher_b64}').decode('utf-8'))\""
|
||||
result = await self.interface.run_command(cmd)
|
||||
pid_str = (result.stdout or "").strip().splitlines()[-1].strip()
|
||||
return int(pid_str)
|
||||
else:
|
||||
log = f"/tmp/cua_bg_{int(_time.time())}.log"
|
||||
launcher_code = f"""
|
||||
import base64, subprocess, os, sys
|
||||
code = base64.b64decode("{payload_b64}").decode("utf-8")
|
||||
with open("{log}", "ab", buffering=0) as f:
|
||||
p = subprocess.Popen(["python", "-c", code], stdout=f, stderr=subprocess.STDOUT, preexec_fn=getattr(os, "setsid", None))
|
||||
print(p.pid)
|
||||
"""
|
||||
launcher_b64 = base64.b64encode(launcher_code.encode("utf-8")).decode("ascii")
|
||||
cmd = f"python -c \"import base64; exec(base64.b64decode('{launcher_b64}').decode('utf-8'))\""
|
||||
result = await self.interface.run_command(cmd)
|
||||
pid_str = (result.stdout or "").strip().splitlines()[-1].strip()
|
||||
return int(pid_str)
|
||||
|
||||
def python_command(
|
||||
self,
|
||||
requirements: Optional[List[str]] = None,
|
||||
*,
|
||||
venv_name: str = "default",
|
||||
use_system_python: bool = False,
|
||||
background: bool = False,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]:
|
||||
"""Decorator to execute a Python function remotely in this Computer's venv.
|
||||
|
||||
This mirrors `computer.helpers.sandboxed()` but binds to this instance and
|
||||
optionally ensures required packages are installed before execution.
|
||||
|
||||
Args:
|
||||
requirements: Packages to install in the virtual environment.
|
||||
venv_name: Name of the virtual environment to use.
|
||||
use_system_python: If True, use the system Python/pip instead of a venv.
|
||||
background: If True, run the function detached and return the child PID immediately.
|
||||
|
||||
Returns:
|
||||
A decorator that turns a local function into an async callable which
|
||||
runs remotely and returns the function's result.
|
||||
"""
|
||||
|
||||
reqs = list(requirements or [])
|
||||
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if use_system_python:
|
||||
# For background, avoid blocking installs; install inside child process
|
||||
if background:
|
||||
return await self.python_exec_background(func, *args, requirements=reqs, **kwargs) # type: ignore[return-value]
|
||||
# Foreground: install first, then execute
|
||||
if reqs:
|
||||
await self.pip_install(reqs)
|
||||
return await self.python_exec(func, *args, **kwargs)
|
||||
else:
|
||||
# For background, avoid blocking installs; install inside child process under venv
|
||||
if background:
|
||||
return await self.venv_exec_background(venv_name, func, *args, requirements=reqs, **kwargs) # type: ignore[return-value]
|
||||
# Foreground: ensure venv and install, then execute
|
||||
await self.venv_install(venv_name, reqs)
|
||||
return await self.venv_exec(venv_name, func, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -2,18 +2,46 @@
|
||||
Helper functions and decorators for the Computer module.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import builtins
|
||||
import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, TypeVar, cast
|
||||
from inspect import getsource
|
||||
from textwrap import dedent
|
||||
from types import FunctionType, ModuleType
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Set, TypedDict, TypeVar
|
||||
|
||||
try:
|
||||
# Python 3.12+ has ParamSpec in typing
|
||||
from typing import ParamSpec
|
||||
except ImportError: # pragma: no cover
|
||||
# Fallback for environments without ParamSpec in typing
|
||||
from typing_extensions import ParamSpec # type: ignore
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class DependencyInfo(TypedDict):
|
||||
import_statements: List[str]
|
||||
definitions: List[tuple[str, Any]]
|
||||
|
||||
|
||||
# Global reference to the default computer instance
|
||||
_default_computer = None
|
||||
|
||||
# Global cache for function dependency analysis
|
||||
_function_dependency_map: Dict[FunctionType, DependencyInfo] = {}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_default_computer(computer):
|
||||
def set_default_computer(computer: Any) -> None:
|
||||
"""
|
||||
Set the default computer instance to be used by the remote decorator.
|
||||
|
||||
@@ -24,19 +52,26 @@ def set_default_computer(computer):
|
||||
_default_computer = computer
|
||||
|
||||
|
||||
def sandboxed(venv_name: str = "default", computer: str = "default", max_retries: int = 3):
|
||||
def sandboxed(
|
||||
venv_name: str = "default",
|
||||
computer: str = "default",
|
||||
max_retries: int = 3,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]:
|
||||
"""
|
||||
Decorator that wraps a function to be executed remotely via computer.venv_exec
|
||||
|
||||
The function is automatically analyzed for dependencies (imports, helper functions,
|
||||
constants, etc.) and reconstructed with all necessary code in the remote sandbox.
|
||||
|
||||
Args:
|
||||
venv_name: Name of the virtual environment to execute in
|
||||
computer: The computer instance to use, or "default" to use the globally set default
|
||||
max_retries: Maximum number of retries for the remote execution
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# Determine which computer instance to use
|
||||
comp = computer if computer != "default" else _default_computer
|
||||
|
||||
@@ -54,6 +89,402 @@ def sandboxed(venv_name: str = "default", computer: str = "default", max_retries
|
||||
if i == max_retries - 1:
|
||||
raise e
|
||||
|
||||
# Should be unreachable because we either returned or raised
|
||||
raise RuntimeError("sandboxed wrapper reached unreachable code path")
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _extract_import_statement(name: str, module: ModuleType) -> str:
|
||||
"""Extract the original import statement for a module."""
|
||||
module_name = module.__name__
|
||||
|
||||
if name == module_name.split(".")[0]:
|
||||
return f"import {module_name}"
|
||||
else:
|
||||
return f"import {module_name} as {name}"
|
||||
|
||||
|
||||
def _is_third_party_module(module_name: str) -> bool:
|
||||
"""Check if a module is a third-party module."""
|
||||
stdlib_modules = set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
|
||||
|
||||
if module_name in stdlib_modules:
|
||||
return False
|
||||
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None:
|
||||
return False
|
||||
|
||||
if spec.origin and ("site-packages" in spec.origin or "dist-packages" in spec.origin):
|
||||
return True
|
||||
|
||||
return False
|
||||
except (ImportError, ModuleNotFoundError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
def _is_project_import(module_name: str) -> bool:
|
||||
"""Check if a module is a project-level import."""
|
||||
if module_name.startswith("__relative_import_level_"):
|
||||
return True
|
||||
|
||||
if module_name in sys.modules:
|
||||
module = sys.modules[module_name]
|
||||
if hasattr(module, "__file__") and module.__file__:
|
||||
if "site-packages" not in module.__file__ and "dist-packages" not in module.__file__:
|
||||
cwd = os.getcwd()
|
||||
if module.__file__.startswith(cwd):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _categorize_module(module_name: str) -> str:
|
||||
"""Categorize a module as stdlib, third-party, or project."""
|
||||
if module_name.startswith("__relative_import_level_"):
|
||||
return "project"
|
||||
elif module_name in (
|
||||
set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
|
||||
):
|
||||
return "stdlib"
|
||||
elif _is_third_party_module(module_name):
|
||||
return "third_party"
|
||||
elif _is_project_import(module_name):
|
||||
return "project"
|
||||
else:
|
||||
return "unknown"
|
||||
|
||||
|
||||
class _DependencyVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to extract imports and name references from a function."""
|
||||
|
||||
def __init__(self, function_name: str) -> None:
|
||||
self.function_name = function_name
|
||||
self.internal_imports: Set[str] = set()
|
||||
self.internal_import_statements: List[str] = []
|
||||
self.name_references: Set[str] = set()
|
||||
self.local_names: Set[str] = set()
|
||||
self.inside_function = False
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
if node.name == self.function_name and not self.inside_function:
|
||||
self.inside_function = True
|
||||
|
||||
for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs:
|
||||
self.local_names.add(arg.arg)
|
||||
if node.args.vararg:
|
||||
self.local_names.add(node.args.vararg.arg)
|
||||
if node.args.kwarg:
|
||||
self.local_names.add(node.args.kwarg.arg)
|
||||
|
||||
for child in node.body:
|
||||
self.visit(child)
|
||||
|
||||
self.inside_function = False
|
||||
else:
|
||||
if self.inside_function:
|
||||
self.local_names.add(node.name)
|
||||
for child in node.body:
|
||||
self.visit(child)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
|
||||
self.visit_FunctionDef(node) # type: ignore
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> None:
|
||||
if self.inside_function:
|
||||
for alias in node.names:
|
||||
module_name = alias.name.split(".")[0]
|
||||
self.internal_imports.add(module_name)
|
||||
imported_as = alias.asname if alias.asname else alias.name.split(".")[0]
|
||||
self.local_names.add(imported_as)
|
||||
self.internal_import_statements.append(ast.unparse(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
if self.inside_function:
|
||||
if node.level == 0 and node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
self.internal_imports.add(module_name)
|
||||
elif node.level > 0:
|
||||
self.internal_imports.add(f"__relative_import_level_{node.level}__")
|
||||
|
||||
for alias in node.names:
|
||||
imported_as = alias.asname if alias.asname else alias.name
|
||||
self.local_names.add(imported_as)
|
||||
self.internal_import_statements.append(ast.unparse(node))
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> None:
|
||||
if self.inside_function:
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
self.name_references.add(node.id)
|
||||
elif isinstance(node.ctx, ast.Store):
|
||||
self.local_names.add(node.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
||||
if self.inside_function:
|
||||
self.local_names.add(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_For(self, node: ast.For) -> None:
|
||||
if self.inside_function and isinstance(node.target, ast.Name):
|
||||
self.local_names.add(node.target.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_comprehension(self, node: ast.comprehension) -> None:
|
||||
if self.inside_function and isinstance(node.target, ast.Name):
|
||||
self.local_names.add(node.target.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
|
||||
if self.inside_function and node.name:
|
||||
self.local_names.add(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_With(self, node: ast.With) -> None:
|
||||
if self.inside_function:
|
||||
for item in node.items:
|
||||
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
|
||||
self.local_names.add(item.optional_vars.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def _traverse_and_collect_dependencies(func: FunctionType) -> DependencyInfo:
|
||||
"""
|
||||
Traverse a function and collect its dependencies.
|
||||
|
||||
Returns a dict with:
|
||||
- import_statements: List of import statements needed
|
||||
- definitions: List of (name, obj) tuples for helper functions/classes/constants
|
||||
"""
|
||||
source = dedent(getsource(func))
|
||||
tree = ast.parse(source)
|
||||
|
||||
visitor = _DependencyVisitor(func.__name__)
|
||||
visitor.visit(tree)
|
||||
|
||||
builtin_names = set(dir(builtins))
|
||||
external_refs = (visitor.name_references - visitor.local_names) - builtin_names
|
||||
|
||||
import_statements = []
|
||||
definitions = []
|
||||
visited = set()
|
||||
|
||||
# Include all internal import statements
|
||||
import_statements.extend(visitor.internal_import_statements)
|
||||
|
||||
# Analyze external references recursively
|
||||
def analyze_object(obj: Any, name: str, depth: int = 0) -> None:
|
||||
if depth > 20:
|
||||
return
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in visited:
|
||||
return
|
||||
visited.add(obj_id)
|
||||
|
||||
# Handle modules
|
||||
if inspect.ismodule(obj):
|
||||
import_stmt = _extract_import_statement(name, obj)
|
||||
import_statements.append(import_stmt)
|
||||
return
|
||||
|
||||
# Handle functions and classes
|
||||
if (
|
||||
inspect.isfunction(obj)
|
||||
or inspect.isclass(obj)
|
||||
or inspect.isbuiltin(obj)
|
||||
or inspect.ismethod(obj)
|
||||
):
|
||||
obj_module = getattr(obj, "__module__", None)
|
||||
if obj_module:
|
||||
base_module = obj_module.split(".")[0]
|
||||
module_category = _categorize_module(base_module)
|
||||
|
||||
# If from stdlib/third-party, just add import
|
||||
if module_category in ("stdlib", "third_party"):
|
||||
obj_name = getattr(obj, "__name__", name)
|
||||
|
||||
# Check if object is accessible by 'name' (in globals or closures)
|
||||
is_accessible = False
|
||||
if name in func.__globals__ and func.__globals__[name] is obj:
|
||||
is_accessible = True
|
||||
elif func.__closure__ and hasattr(func, "__code__"):
|
||||
freevars = func.__code__.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == name and i < len(func.__closure__):
|
||||
try:
|
||||
if func.__closure__[i].cell_contents is obj:
|
||||
is_accessible = True
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
if is_accessible and name == obj_name:
|
||||
# Direct import: from requests import get, from math import sqrt
|
||||
import_statements.append(f"from {base_module} import {name}")
|
||||
else:
|
||||
# Module import: import requests
|
||||
import_statements.append(f"import {base_module}")
|
||||
return
|
||||
|
||||
try:
|
||||
obj_tree = ast.parse(dedent(getsource(obj)))
|
||||
obj_visitor = _DependencyVisitor(obj.__name__)
|
||||
obj_visitor.visit(obj_tree)
|
||||
|
||||
obj_external_refs = obj_visitor.name_references - obj_visitor.local_names
|
||||
obj_external_refs = obj_external_refs - builtin_names
|
||||
|
||||
# Add internal imports from this object
|
||||
import_statements.extend(obj_visitor.internal_import_statements)
|
||||
|
||||
# Recursively analyze its dependencies
|
||||
obj_globals = getattr(obj, "__globals__", None)
|
||||
obj_closure = getattr(obj, "__closure__", None)
|
||||
obj_code = getattr(obj, "__code__", None)
|
||||
if obj_globals:
|
||||
for ref_name in obj_external_refs:
|
||||
ref_obj = None
|
||||
|
||||
# Check globals first
|
||||
if ref_name in obj_globals:
|
||||
ref_obj = obj_globals[ref_name]
|
||||
# Check closure variables using co_freevars
|
||||
elif obj_closure and obj_code:
|
||||
freevars = obj_code.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == ref_name and i < len(obj_closure):
|
||||
try:
|
||||
ref_obj = obj_closure[i].cell_contents
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
if ref_obj is not None:
|
||||
analyze_object(ref_obj, ref_name, depth + 1)
|
||||
|
||||
# Add this object to definitions
|
||||
if not inspect.ismodule(obj):
|
||||
ref_module = getattr(obj, "__module__", None)
|
||||
if ref_module:
|
||||
ref_base_module = ref_module.split(".")[0]
|
||||
ref_category = _categorize_module(ref_base_module)
|
||||
if ref_category not in ("stdlib", "third_party"):
|
||||
definitions.append((name, obj))
|
||||
else:
|
||||
definitions.append((name, obj))
|
||||
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
return
|
||||
|
||||
if isinstance(obj, (int, float, str, bool, list, dict, tuple, set, frozenset, type(None))):
|
||||
definitions.append((name, obj))
|
||||
|
||||
# Analyze all external references
|
||||
for name in external_refs:
|
||||
obj = None
|
||||
|
||||
# First check globals
|
||||
if name in func.__globals__:
|
||||
obj = func.__globals__[name]
|
||||
# Then check closure variables (sibling functions in enclosing scope)
|
||||
elif func.__closure__ and func.__code__.co_freevars:
|
||||
# Match closure variable names with cell contents
|
||||
freevars = func.__code__.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == name and i < len(func.__closure__):
|
||||
try:
|
||||
obj = func.__closure__[i].cell_contents
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
# Cell is empty or doesn't have contents
|
||||
pass
|
||||
|
||||
if obj is not None:
|
||||
analyze_object(obj, name)
|
||||
|
||||
# Remove duplicate import statements
|
||||
unique_imports = []
|
||||
seen = set()
|
||||
for stmt in import_statements:
|
||||
if stmt not in seen:
|
||||
seen.add(stmt)
|
||||
unique_imports.append(stmt)
|
||||
|
||||
# Remove duplicate definitions
|
||||
unique_definitions = []
|
||||
seen_names = set()
|
||||
for name, obj in definitions:
|
||||
if name not in seen_names:
|
||||
seen_names.add(name)
|
||||
unique_definitions.append((name, obj))
|
||||
|
||||
return {
|
||||
"import_statements": unique_imports,
|
||||
"definitions": unique_definitions,
|
||||
}
|
||||
|
||||
|
||||
def generate_source_code(func: FunctionType) -> str:
|
||||
"""
|
||||
Generate complete source code for a function with all dependencies.
|
||||
|
||||
Args:
|
||||
func: The function to generate source code for
|
||||
|
||||
Returns:
|
||||
Complete Python source code as a string
|
||||
"""
|
||||
|
||||
if func in _function_dependency_map:
|
||||
info = _function_dependency_map[func]
|
||||
else:
|
||||
info = _traverse_and_collect_dependencies(func)
|
||||
_function_dependency_map[func] = info
|
||||
|
||||
# Build source code
|
||||
parts = []
|
||||
|
||||
# 1. Add imports
|
||||
if info["import_statements"]:
|
||||
parts.append("\n".join(info["import_statements"]))
|
||||
|
||||
# 2. Add definitions
|
||||
for name, obj in info["definitions"]:
|
||||
try:
|
||||
if inspect.isfunction(obj):
|
||||
source = dedent(getsource(obj))
|
||||
tree = ast.parse(source)
|
||||
if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
tree.body[0].decorator_list = []
|
||||
source = ast.unparse(tree)
|
||||
parts.append(source)
|
||||
elif inspect.isclass(obj):
|
||||
source = dedent(getsource(obj))
|
||||
tree = ast.parse(source)
|
||||
if tree.body and isinstance(tree.body[0], ast.ClassDef):
|
||||
tree.body[0].decorator_list = []
|
||||
source = ast.unparse(tree)
|
||||
parts.append(source)
|
||||
else:
|
||||
parts.append(f"{name} = {repr(obj)}")
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. Add main function (without decorators)
|
||||
func_source = dedent(getsource(func))
|
||||
tree = ast.parse(func_source)
|
||||
if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
tree.body[0].decorator_list = []
|
||||
func_source = ast.unparse(tree)
|
||||
parts.append(func_source)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
@@ -12,6 +12,7 @@ class InterfaceFactory:
|
||||
def create_interface_for_os(
|
||||
os: Literal["macos", "linux", "windows"],
|
||||
ip_address: str,
|
||||
api_port: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
vm_name: Optional[str] = None,
|
||||
) -> BaseComputerInterface:
|
||||
@@ -20,6 +21,7 @@ class InterfaceFactory:
|
||||
Args:
|
||||
os: Operating system type ('macos', 'linux', or 'windows')
|
||||
ip_address: IP address of the computer to control
|
||||
api_port: Optional API port of the computer to control
|
||||
api_key: Optional API key for cloud authentication
|
||||
vm_name: Optional VM name for cloud authentication
|
||||
|
||||
@@ -35,10 +37,16 @@ class InterfaceFactory:
|
||||
from .windows import WindowsComputerInterface
|
||||
|
||||
if os == "macos":
|
||||
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
|
||||
return MacOSComputerInterface(
|
||||
ip_address, api_key=api_key, vm_name=vm_name, api_port=api_port
|
||||
)
|
||||
elif os == "linux":
|
||||
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
|
||||
return LinuxComputerInterface(
|
||||
ip_address, api_key=api_key, vm_name=vm_name, api_port=api_port
|
||||
)
|
||||
elif os == "windows":
|
||||
return WindowsComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
|
||||
return WindowsComputerInterface(
|
||||
ip_address, api_key=api_key, vm_name=vm_name, api_port=api_port
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported OS type: {os}")
|
||||
|
||||
@@ -30,6 +30,7 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
api_key: Optional[str] = None,
|
||||
vm_name: Optional[str] = None,
|
||||
logger_name: str = "computer.interface.generic",
|
||||
api_port: Optional[int] = None,
|
||||
):
|
||||
super().__init__(ip_address, username, password, api_key, vm_name)
|
||||
self._ws = None
|
||||
@@ -47,6 +48,9 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
# Set logger name for the interface
|
||||
self.logger = Logger(logger_name, LogLevel.NORMAL)
|
||||
|
||||
# Store custom ports
|
||||
self._api_port = api_port
|
||||
|
||||
# Optional default delay time between commands (in seconds)
|
||||
self.delay = 0.0
|
||||
|
||||
@@ -70,7 +74,12 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
WebSocket URI for the Computer API Server
|
||||
"""
|
||||
protocol = "wss" if self.api_key else "ws"
|
||||
port = "8443" if self.api_key else "8000"
|
||||
# Use custom API port if provided, otherwise use defaults based on API key
|
||||
port = (
|
||||
str(self._api_port)
|
||||
if self._api_port is not None
|
||||
else ("8443" if self.api_key else "8000")
|
||||
)
|
||||
return f"{protocol}://{self.ip_address}:{port}/ws"
|
||||
|
||||
@property
|
||||
@@ -81,7 +90,12 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
REST URI for the Computer API Server
|
||||
"""
|
||||
protocol = "https" if self.api_key else "http"
|
||||
port = "8443" if self.api_key else "8000"
|
||||
# Use custom API port if provided, otherwise use defaults based on API key
|
||||
port = (
|
||||
str(self._api_port)
|
||||
if self._api_port is not None
|
||||
else ("8443" if self.api_key else "8000")
|
||||
)
|
||||
return f"{protocol}://{self.ip_address}:{port}/cmd"
|
||||
|
||||
# Mouse actions
|
||||
|
||||
@@ -13,7 +13,8 @@ class LinuxComputerInterface(GenericComputerInterface):
|
||||
password: str = "lume",
|
||||
api_key: Optional[str] = None,
|
||||
vm_name: Optional[str] = None,
|
||||
api_port: Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
ip_address, username, password, api_key, vm_name, "computer.interface.linux"
|
||||
ip_address, username, password, api_key, vm_name, "computer.interface.linux", api_port
|
||||
)
|
||||
|
||||
@@ -13,9 +13,10 @@ class MacOSComputerInterface(GenericComputerInterface):
|
||||
password: str = "lume",
|
||||
api_key: Optional[str] = None,
|
||||
vm_name: Optional[str] = None,
|
||||
api_port: Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
ip_address, username, password, api_key, vm_name, "computer.interface.macos"
|
||||
ip_address, username, password, api_key, vm_name, "computer.interface.macos", api_port
|
||||
)
|
||||
|
||||
async def diorama_cmd(self, action: str, arguments: Optional[dict] = None) -> dict:
|
||||
|
||||
@@ -13,7 +13,8 @@ class WindowsComputerInterface(GenericComputerInterface):
|
||||
password: str = "lume",
|
||||
api_key: Optional[str] = None,
|
||||
vm_name: Optional[str] = None,
|
||||
api_port: Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
ip_address, username, password, api_key, vm_name, "computer.interface.windows"
|
||||
ip_address, username, password, api_key, vm_name, "computer.interface.windows", api_port
|
||||
)
|
||||
|
||||
@@ -37,7 +37,6 @@ class DockerProvider(BaseVMProvider):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: Optional[int] = 8000,
|
||||
host: str = "localhost",
|
||||
storage: Optional[str] = None,
|
||||
shared_path: Optional[str] = None,
|
||||
@@ -45,11 +44,11 @@ class DockerProvider(BaseVMProvider):
|
||||
verbose: bool = False,
|
||||
ephemeral: bool = False,
|
||||
vnc_port: Optional[int] = 6901,
|
||||
api_port: Optional[int] = None,
|
||||
):
|
||||
"""Initialize the Docker VM Provider.
|
||||
|
||||
Args:
|
||||
port: Currently unused (VM provider port)
|
||||
host: Hostname for the API server (default: localhost)
|
||||
storage: Path for persistent VM storage
|
||||
shared_path: Path for shared folder between host and container
|
||||
@@ -60,9 +59,10 @@ class DockerProvider(BaseVMProvider):
|
||||
verbose: Enable verbose logging
|
||||
ephemeral: Use ephemeral (temporary) storage
|
||||
vnc_port: Port for VNC interface (default: 6901)
|
||||
api_port: Port for API server (default: 8000)
|
||||
"""
|
||||
self.host = host
|
||||
self.api_port = 8000
|
||||
self.api_port = api_port if api_port is not None else 8000
|
||||
self.vnc_port = vnc_port
|
||||
self.ephemeral = ephemeral
|
||||
|
||||
@@ -296,6 +296,7 @@ class DockerProvider(BaseVMProvider):
|
||||
if vnc_port:
|
||||
cmd.extend(["-p", f"{vnc_port}:6901"]) # VNC port
|
||||
if api_port:
|
||||
# Map the API port to container port 8000 (computer-server default)
|
||||
cmd.extend(["-p", f"{api_port}:8000"]) # computer-server API port
|
||||
|
||||
# Add volume mounts if storage is specified
|
||||
|
||||
@@ -14,7 +14,7 @@ class VMProviderFactory:
|
||||
@staticmethod
|
||||
def create_provider(
|
||||
provider_type: Union[str, VMProviderType],
|
||||
port: int = 7777,
|
||||
provider_port: int = 7777,
|
||||
host: str = "localhost",
|
||||
bin_path: Optional[str] = None,
|
||||
storage: Optional[str] = None,
|
||||
@@ -23,13 +23,14 @@ class VMProviderFactory:
|
||||
verbose: bool = False,
|
||||
ephemeral: bool = False,
|
||||
noVNC_port: Optional[int] = None,
|
||||
api_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> BaseVMProvider:
|
||||
"""Create a VM provider of the specified type.
|
||||
|
||||
Args:
|
||||
provider_type: Type of VM provider to create
|
||||
port: Port for the API server
|
||||
provider_port: Port for the provider's API server
|
||||
host: Hostname for the API server
|
||||
bin_path: Path to provider binary if needed
|
||||
storage: Path for persistent VM storage
|
||||
@@ -37,7 +38,8 @@ class VMProviderFactory:
|
||||
image: VM image to use (for Lumier provider)
|
||||
verbose: Enable verbose logging
|
||||
ephemeral: Use ephemeral (temporary) storage
|
||||
noVNC_port: Specific port for noVNC interface (for Lumier provider)
|
||||
noVNC_port: Specific port for noVNC interface (for Lumier and Docker provider)
|
||||
api_port: Specific port for Computer API server (for Docker provider)
|
||||
|
||||
Returns:
|
||||
An instance of the requested VM provider
|
||||
@@ -63,7 +65,11 @@ class VMProviderFactory:
|
||||
"Please install it with 'pip install cua-computer[lume]'"
|
||||
)
|
||||
return LumeProvider(
|
||||
port=port, host=host, storage=storage, verbose=verbose, ephemeral=ephemeral
|
||||
provider_port=provider_port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import LumeProvider: {e}")
|
||||
@@ -81,7 +87,7 @@ class VMProviderFactory:
|
||||
"Please install Docker for Apple Silicon and Lume CLI before using this provider."
|
||||
)
|
||||
return LumierProvider(
|
||||
port=port,
|
||||
provider_port=provider_port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
shared_path=shared_path,
|
||||
@@ -121,7 +127,6 @@ class VMProviderFactory:
|
||||
"Please install it with 'pip install -U git+https://github.com/karkason/pywinsandbox.git'"
|
||||
)
|
||||
return WinSandboxProvider(
|
||||
port=port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
verbose=verbose,
|
||||
@@ -144,7 +149,6 @@ class VMProviderFactory:
|
||||
"Please install Docker and ensure it is running."
|
||||
)
|
||||
return DockerProvider(
|
||||
port=port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
shared_path=shared_path,
|
||||
@@ -152,6 +156,7 @@ class VMProviderFactory:
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral,
|
||||
vnc_port=noVNC_port,
|
||||
api_port=api_port,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import DockerProvider: {e}")
|
||||
|
||||
@@ -38,7 +38,7 @@ class LumeProvider(BaseVMProvider):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: int = 7777,
|
||||
provider_port: int = 7777,
|
||||
host: str = "localhost",
|
||||
storage: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
@@ -47,7 +47,7 @@ class LumeProvider(BaseVMProvider):
|
||||
"""Initialize the Lume provider.
|
||||
|
||||
Args:
|
||||
port: Port for the Lume API server (default: 7777)
|
||||
provider_port: Port for the Lume API server (default: 7777)
|
||||
host: Host to use for API connections (default: localhost)
|
||||
storage: Path to store VM data
|
||||
verbose: Enable verbose logging
|
||||
@@ -59,7 +59,7 @@ class LumeProvider(BaseVMProvider):
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port # Default port for Lume API
|
||||
self.port = provider_port # Default port for Lume API
|
||||
self.storage = storage
|
||||
self.verbose = verbose
|
||||
self.ephemeral = ephemeral # If True, VMs will be deleted after stopping
|
||||
|
||||
@@ -39,7 +39,7 @@ class LumierProvider(BaseVMProvider):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: Optional[int] = 7777,
|
||||
provider_port: Optional[int] = 7777,
|
||||
host: str = "localhost",
|
||||
storage: Optional[str] = None, # Can be a path or 'ephemeral'
|
||||
shared_path: Optional[str] = None,
|
||||
@@ -51,7 +51,7 @@ class LumierProvider(BaseVMProvider):
|
||||
"""Initialize the Lumier VM Provider.
|
||||
|
||||
Args:
|
||||
port: Port for the API server (default: 7777)
|
||||
provider_port: Port for the API server (default: 7777)
|
||||
host: Hostname for the API server (default: localhost)
|
||||
storage: Path for persistent VM storage
|
||||
shared_path: Path for shared folder between host and VM
|
||||
@@ -61,8 +61,8 @@ class LumierProvider(BaseVMProvider):
|
||||
noVNC_port: Specific port for noVNC interface (default: 8006)
|
||||
"""
|
||||
self.host = host
|
||||
# Always ensure api_port has a valid value (7777 is the default)
|
||||
self.api_port = 7777 if port is None else port
|
||||
# Always ensure lume_port has a valid value (7777 is the default)
|
||||
self.lume_port = 7777 if provider_port is None else provider_port
|
||||
self.vnc_port = noVNC_port # User-specified noVNC port, will be set in run_vm if provided
|
||||
self.ephemeral = ephemeral
|
||||
|
||||
@@ -198,7 +198,7 @@ class LumierProvider(BaseVMProvider):
|
||||
vm_info = lume_api_get(
|
||||
vm_name=name,
|
||||
host=self.host,
|
||||
port=self.api_port,
|
||||
port=self.lume_port,
|
||||
storage=storage if storage is not None else self.storage,
|
||||
debug=self.verbose,
|
||||
verbose=self.verbose,
|
||||
@@ -320,7 +320,7 @@ class LumierProvider(BaseVMProvider):
|
||||
logger.debug(f"Using specified noVNC_port: {self.vnc_port}")
|
||||
|
||||
# Set API URL using the API port
|
||||
self._api_url = f"http://{self.host}:{self.api_port}"
|
||||
self._api_url = f"http://{self.host}:{self.lume_port}"
|
||||
|
||||
# Parse memory setting
|
||||
memory_mb = self._parse_memory(run_opts.get("memory", "8GB"))
|
||||
@@ -671,7 +671,7 @@ class LumierProvider(BaseVMProvider):
|
||||
# Container is running, check if API is responsive
|
||||
try:
|
||||
# First check the health endpoint
|
||||
api_url = f"http://{self.host}:{self.api_port}/health"
|
||||
api_url = f"http://{self.host}:{self.lume_port}/health"
|
||||
logger.info(f"Checking API health at: {api_url}")
|
||||
|
||||
# Use longer timeout for API health check since it may still be initializing
|
||||
@@ -685,7 +685,7 @@ class LumierProvider(BaseVMProvider):
|
||||
else:
|
||||
# API health check failed, now let's check if the VM status endpoint is responsive
|
||||
# This covers cases where the health endpoint isn't implemented but the VM API is working
|
||||
vm_api_url = f"http://{self.host}:{self.api_port}/lume/vms/{container_name}"
|
||||
vm_api_url = f"http://{self.host}:{self.lume_port}/lume/vms/{container_name}"
|
||||
if self.storage:
|
||||
import urllib.parse
|
||||
|
||||
@@ -1026,7 +1026,7 @@ class LumierProvider(BaseVMProvider):
|
||||
# Initialize the API URL with the default value if not already set
|
||||
# This ensures get_vm can work before run_vm is called
|
||||
if not hasattr(self, "_api_url") or not self._api_url:
|
||||
self._api_url = f"http://{self.host}:{self.api_port}"
|
||||
self._api_url = f"http://{self.host}:{self.lume_port}"
|
||||
logger.info(f"Initialized default Lumier API URL: {self._api_url}")
|
||||
|
||||
return self
|
||||
|
||||
@@ -29,7 +29,6 @@ class WinSandboxProvider(BaseVMProvider):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: int = 7777,
|
||||
host: str = "localhost",
|
||||
storage: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
@@ -41,7 +40,6 @@ class WinSandboxProvider(BaseVMProvider):
|
||||
"""Initialize the Windows Sandbox provider.
|
||||
|
||||
Args:
|
||||
port: Port for the computer server (default: 7777)
|
||||
host: Host to use for connections (default: localhost)
|
||||
storage: Storage path (ignored - Windows Sandbox is always ephemeral)
|
||||
verbose: Enable verbose logging
|
||||
@@ -56,7 +54,6 @@ class WinSandboxProvider(BaseVMProvider):
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.verbose = verbose
|
||||
self.memory_mb = memory_mb
|
||||
self.networking = networking
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-computer"
|
||||
version = "0.4.17"
|
||||
version = "0.4.18"
|
||||
description = "Computer-Use Interface (CUI) framework powering Cua"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
||||
478
libs/python/computer/tests/test_helpers.py
Normal file
478
libs/python/computer/tests/test_helpers.py
Normal file
@@ -0,0 +1,478 @@
|
||||
from computer.helpers import generate_source_code
|
||||
|
||||
|
||||
class TestSimpleCases:
|
||||
"""Test simple dependency cases"""
|
||||
|
||||
def test_simple_function_no_dependencies(self):
|
||||
"""Test simple function with no dependencies"""
|
||||
|
||||
def simple_func():
|
||||
return 42
|
||||
|
||||
code = generate_source_code(simple_func)
|
||||
|
||||
assert "def simple_func():" in code
|
||||
assert "return 42" in code
|
||||
|
||||
def test_function_with_parameters(self):
|
||||
"""Test function with parameters"""
|
||||
|
||||
def add(a, b):
|
||||
return a + b
|
||||
|
||||
code = generate_source_code(add)
|
||||
|
||||
assert "def add(a, b):" in code
|
||||
assert "return a + b" in code
|
||||
|
||||
def test_function_with_docstring(self):
|
||||
"""Test function with docstring"""
|
||||
|
||||
def documented():
|
||||
"""This is a docstring"""
|
||||
return True
|
||||
|
||||
code = generate_source_code(documented)
|
||||
|
||||
assert "def documented():" in code
|
||||
assert "This is a docstring" in code
|
||||
|
||||
|
||||
class TestImports:
|
||||
"""Test import handling"""
|
||||
|
||||
def test_stdlib_import_inside_function(self):
|
||||
"""Test function with stdlib import inside"""
|
||||
|
||||
def with_stdlib():
|
||||
import math
|
||||
|
||||
return math.sqrt(16)
|
||||
|
||||
code = generate_source_code(with_stdlib)
|
||||
|
||||
assert "import math" in code
|
||||
assert "def with_stdlib():" in code
|
||||
assert "math.sqrt(16)" in code
|
||||
|
||||
def test_multiple_stdlib_imports(self):
|
||||
"""Test function with multiple stdlib imports"""
|
||||
|
||||
def with_multiple():
|
||||
import json
|
||||
import math
|
||||
|
||||
return json.dumps({"value": math.pi})
|
||||
|
||||
code = generate_source_code(with_multiple)
|
||||
|
||||
assert "import json" in code
|
||||
assert "import math" in code
|
||||
|
||||
def test_from_import_stdlib(self):
|
||||
"""Test from X import Y style"""
|
||||
|
||||
def with_from_import():
|
||||
from math import pi, sqrt
|
||||
|
||||
return sqrt(16) + pi
|
||||
|
||||
code = generate_source_code(with_from_import)
|
||||
|
||||
assert "from math import" in code
|
||||
assert "sqrt, pi" in code or ("sqrt" in code and "pi" in code)
|
||||
|
||||
def test_third_party_global_import(self):
|
||||
"""Test globally imported third-party module"""
|
||||
from requests import get
|
||||
|
||||
def with_requests():
|
||||
return get
|
||||
|
||||
code = generate_source_code(with_requests)
|
||||
|
||||
assert "from requests import get" in code
|
||||
assert "def with_requests():" in code
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Test helper function dependencies"""
|
||||
|
||||
def test_single_helper_function(self):
|
||||
"""Test function with one helper"""
|
||||
|
||||
def helper(x):
|
||||
return x * 2
|
||||
|
||||
def main():
|
||||
return helper(21)
|
||||
|
||||
code = generate_source_code(main)
|
||||
|
||||
assert "def helper(x):" in code
|
||||
assert "return x * 2" in code
|
||||
assert "def main():" in code
|
||||
assert "return helper(21)" in code
|
||||
|
||||
def test_nested_helper_functions(self):
|
||||
"""Test function with nested helpers"""
|
||||
|
||||
def helper_a(x):
|
||||
return x + 1
|
||||
|
||||
def helper_b(x):
|
||||
return helper_a(x) * 2
|
||||
|
||||
def main():
|
||||
return helper_b(10)
|
||||
|
||||
code = generate_source_code(main)
|
||||
|
||||
assert "def helper_a(x):" in code
|
||||
assert "def helper_b(x):" in code
|
||||
assert "def main():" in code
|
||||
|
||||
def test_helper_dependency_ordering(self):
|
||||
"""Test that helpers are ordered correctly (dependencies first)"""
|
||||
|
||||
def level_3(x):
|
||||
return x + 1
|
||||
|
||||
def level_2(x):
|
||||
return level_3(x) * 2
|
||||
|
||||
def level_1(x):
|
||||
return level_2(x) + 5
|
||||
|
||||
code = generate_source_code(level_1)
|
||||
|
||||
# Check ordering
|
||||
pos_3 = code.find("def level_3")
|
||||
pos_2 = code.find("def level_2")
|
||||
pos_1 = code.find("def level_1")
|
||||
|
||||
assert pos_3 < pos_2 < pos_1, "Functions not in correct dependency order"
|
||||
|
||||
def test_helper_with_import(self):
|
||||
"""Test helper function that uses imports"""
|
||||
import math
|
||||
|
||||
def helper(x):
|
||||
return math.sqrt(x)
|
||||
|
||||
def main():
|
||||
return helper(16)
|
||||
|
||||
code = generate_source_code(main)
|
||||
|
||||
assert "import math" in code
|
||||
assert "def helper(x):" in code
|
||||
assert "def main():" in code
|
||||
|
||||
|
||||
class TestGlobalConstants:
|
||||
"""Test global constant handling"""
|
||||
|
||||
def test_simple_constant(self):
|
||||
"""Test function using simple constant"""
|
||||
MY_CONSTANT = 42
|
||||
|
||||
def with_constant():
|
||||
return MY_CONSTANT * 2
|
||||
|
||||
code = generate_source_code(with_constant)
|
||||
|
||||
assert "MY_CONSTANT = 42" in code
|
||||
assert "def with_constant():" in code
|
||||
|
||||
def test_string_constant(self):
|
||||
"""Test function using string constant"""
|
||||
API_URL = "https://api.example.com"
|
||||
|
||||
def with_string():
|
||||
return API_URL
|
||||
|
||||
code = generate_source_code(with_string)
|
||||
|
||||
assert "API_URL" in code
|
||||
assert "https://api.example.com" in code
|
||||
|
||||
def test_multiple_constants(self):
|
||||
"""Test function using multiple constants"""
|
||||
BASE_URL = "https://example.com"
|
||||
TIMEOUT = 30
|
||||
|
||||
def with_multiple():
|
||||
return f"{BASE_URL}?timeout={TIMEOUT}"
|
||||
|
||||
code = generate_source_code(with_multiple)
|
||||
|
||||
assert "BASE_URL" in code
|
||||
assert "TIMEOUT" in code
|
||||
|
||||
|
||||
class TestClassDefinitions:
|
||||
"""Test class definition handling"""
|
||||
|
||||
def test_simple_class(self):
|
||||
"""Test function using simple class"""
|
||||
|
||||
class Calculator:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def add(self, x):
|
||||
return self.value + x
|
||||
|
||||
def with_class():
|
||||
calc = Calculator(10)
|
||||
return calc.add(5)
|
||||
|
||||
code = generate_source_code(with_class)
|
||||
|
||||
assert "class Calculator:" in code
|
||||
assert "__init__" in code
|
||||
assert "def add" in code
|
||||
assert "def with_class():" in code
|
||||
|
||||
def test_class_with_methods(self):
|
||||
"""Test class with multiple methods"""
|
||||
|
||||
class Math:
|
||||
def add(self, a, b):
|
||||
return a + b
|
||||
|
||||
def multiply(self, a, b):
|
||||
return a * b
|
||||
|
||||
def use_math():
|
||||
m = Math()
|
||||
return m.add(2, 3) + m.multiply(4, 5)
|
||||
|
||||
code = generate_source_code(use_math)
|
||||
|
||||
assert "class Math:" in code
|
||||
assert "def add" in code
|
||||
assert "def multiply" in code
|
||||
|
||||
|
||||
class TestDecoratorRemoval:
|
||||
"""Test decorator removal"""
|
||||
|
||||
def test_simple_decorator_removed(self):
|
||||
"""Test that decorators are removed"""
|
||||
|
||||
def my_decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@my_decorator
|
||||
def decorated():
|
||||
return 42
|
||||
|
||||
code = generate_source_code(decorated)
|
||||
|
||||
assert "@my_decorator" not in code
|
||||
assert "def decorated():" in code
|
||||
assert "return 42" in code
|
||||
|
||||
def test_multiple_decorators_removed(self):
|
||||
"""Test that multiple decorators are removed"""
|
||||
|
||||
def decorator1(f):
|
||||
return f
|
||||
|
||||
def decorator2(f):
|
||||
return f
|
||||
|
||||
@decorator1
|
||||
@decorator2
|
||||
def multi_decorated():
|
||||
return True
|
||||
|
||||
code = generate_source_code(multi_decorated)
|
||||
|
||||
assert "@decorator1" not in code
|
||||
assert "@decorator2" not in code
|
||||
assert "def multi_decorated():" in code
|
||||
|
||||
|
||||
class TestComplexScenarios:
|
||||
"""Test complex real-world scenarios"""
|
||||
|
||||
def test_api_client_pattern(self):
|
||||
"""Test typical API client pattern"""
|
||||
import json
|
||||
|
||||
BASE_URL = "https://api.example.com"
|
||||
|
||||
def make_request(endpoint):
|
||||
return f"{BASE_URL}/{endpoint}"
|
||||
|
||||
def parse_response(response):
|
||||
return json.loads(response)
|
||||
|
||||
def get_user(user_id):
|
||||
url = make_request(f"users/{user_id}")
|
||||
response = f'{{"id": {user_id}}}'
|
||||
return parse_response(response)
|
||||
|
||||
code = generate_source_code(get_user)
|
||||
|
||||
assert "import json" in code
|
||||
assert "BASE_URL" in code
|
||||
assert "def make_request" in code
|
||||
assert "def parse_response" in code
|
||||
assert "def get_user" in code
|
||||
|
||||
def test_data_processing_pipeline(self):
|
||||
"""Test data processing pipeline"""
|
||||
|
||||
def validate_data(data):
|
||||
return [x for x in data if x > 0]
|
||||
|
||||
def transform_data(data):
|
||||
return [x * 2 for x in data]
|
||||
|
||||
def process_pipeline(raw_data):
|
||||
valid = validate_data(raw_data)
|
||||
return transform_data(valid)
|
||||
|
||||
code = generate_source_code(process_pipeline)
|
||||
|
||||
assert "def validate_data" in code
|
||||
assert "def transform_data" in code
|
||||
assert "def process_pipeline" in code
|
||||
|
||||
# Check ordering
|
||||
validate_pos = code.find("def validate_data")
|
||||
transform_pos = code.find("def transform_data")
|
||||
pipeline_pos = code.find("def process_pipeline")
|
||||
|
||||
assert validate_pos < pipeline_pos
|
||||
assert transform_pos < pipeline_pos
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases"""
|
||||
|
||||
def test_function_with_default_args(self):
|
||||
"""Test function with default arguments"""
|
||||
|
||||
def with_defaults(a, b=10, c=20):
|
||||
return a + b + c
|
||||
|
||||
code = generate_source_code(with_defaults)
|
||||
|
||||
assert "def with_defaults(a, b=10, c=20):" in code
|
||||
|
||||
def test_function_with_kwargs(self):
|
||||
"""Test function with *args and **kwargs"""
|
||||
|
||||
def with_varargs(*args, **kwargs):
|
||||
return sum(args)
|
||||
|
||||
code = generate_source_code(with_varargs)
|
||||
|
||||
assert "def with_varargs(*args, **kwargs):" in code
|
||||
|
||||
def test_async_function(self):
|
||||
"""Test async function"""
|
||||
|
||||
async def async_func():
|
||||
return 42
|
||||
|
||||
code = generate_source_code(async_func)
|
||||
|
||||
assert "async def async_func():" in code
|
||||
assert "return 42" in code
|
||||
|
||||
def test_lambda_works_in_files(self):
|
||||
"""Test that lambda functions work when defined in files"""
|
||||
lambda_func = lambda x: x * 2
|
||||
|
||||
# Lambda functions defined in files CAN have their source extracted
|
||||
code = generate_source_code(lambda_func)
|
||||
|
||||
# Should include the lambda
|
||||
assert "lambda" in code.lower() or "lambda_func" in code
|
||||
|
||||
|
||||
class TestCaching:
|
||||
"""Test caching mechanism"""
|
||||
|
||||
def test_cache_returns_same_result(self):
|
||||
"""Test that cache returns identical results"""
|
||||
|
||||
def cached_func():
|
||||
return 42
|
||||
|
||||
code1 = generate_source_code(cached_func)
|
||||
code2 = generate_source_code(cached_func)
|
||||
|
||||
assert code1 == code2
|
||||
|
||||
def test_different_functions_different_cache(self):
|
||||
"""Test that different functions have different cache entries"""
|
||||
|
||||
def func1():
|
||||
return 1
|
||||
|
||||
def func2():
|
||||
return 2
|
||||
|
||||
code1 = generate_source_code(func1)
|
||||
code2 = generate_source_code(func2)
|
||||
|
||||
assert code1 != code2
|
||||
assert "return 1" in code1
|
||||
assert "return 2" in code2
|
||||
|
||||
|
||||
class TestBuiltinHandling:
|
||||
"""Test that builtins are handled correctly"""
|
||||
|
||||
def test_builtin_not_included(self):
|
||||
"""Test that builtin functions are not included as dependencies"""
|
||||
|
||||
def uses_builtins():
|
||||
data = list(range(10))
|
||||
return len(data)
|
||||
|
||||
code = generate_source_code(uses_builtins)
|
||||
|
||||
# Should not include definitions for list, range, len
|
||||
assert "def list" not in code
|
||||
assert "def range" not in code
|
||||
assert "def len" not in code
|
||||
assert "def uses_builtins():" in code
|
||||
|
||||
|
||||
class TestImportStylePreservation:
|
||||
"""Test that import styles are preserved"""
|
||||
|
||||
def test_from_import_preserved(self):
|
||||
"""Test from X import Y is preserved"""
|
||||
from math import sqrt
|
||||
|
||||
def use_sqrt():
|
||||
return sqrt(16)
|
||||
|
||||
code = generate_source_code(use_sqrt)
|
||||
|
||||
assert "from math import sqrt" in code
|
||||
|
||||
def test_import_as_preserved(self):
|
||||
"""Test import X as Y is preserved"""
|
||||
import json as j
|
||||
|
||||
def use_json():
|
||||
return j.dumps({"key": "value"})
|
||||
|
||||
code = generate_source_code(use_json)
|
||||
|
||||
# The import style should be preserved
|
||||
assert "import json" in code
|
||||
@@ -129,12 +129,12 @@ See [desktop-extension/README.md](desktop-extension/README.md) for more details.
|
||||
|
||||
## Documentation
|
||||
|
||||
- Installation: https://cua.ai/docs/libraries/mcp-server/installation
|
||||
- Configuration: https://cua.ai/docs/libraries/mcp-server/configuration
|
||||
- Usage: https://cua.ai/docs/libraries/mcp-server/usage
|
||||
- Tools: https://cua.ai/docs/libraries/mcp-server/tools
|
||||
- Client Integrations: https://cua.ai/docs/libraries/mcp-server/client-integrations
|
||||
- LLM Integrations: https://cua.ai/docs/libraries/mcp-server/llm-integrations
|
||||
- Installation: https://cua.ai/docs/agent-sdk/mcp-server/installation
|
||||
- Configuration: https://cua.ai/docs/agent-sdk/mcp-server/configuration
|
||||
- Usage: https://cua.ai/docs/agent-sdk/mcp-server/usage
|
||||
- Tools: https://cua.ai/docs/agent-sdk/mcp-server/tools
|
||||
- Client Integrations: https://cua.ai/docs/agent-sdk/mcp-server/client-integrations
|
||||
- LLM Integrations: https://cua.ai/docs/agent-sdk/mcp-server/llm-integrations
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
|
||||
Reference in New Issue
Block a user