Merge branch 'main' into feat/fara-browser-use

This commit is contained in:
ddupont
2025-12-17 14:38:48 -08:00
committed by GitHub
131 changed files with 7061 additions and 3643 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.5.1
current_version = 0.5.2
commit = True
tag = True
tag_name = agent-v{new_version}

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-agent"
version = "0.5.1"
version = "0.5.2"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [

View File

@@ -0,0 +1,26 @@
# CUA Bench UI
Lightweight webUI window controller for CUA bench environments using pywebview
## Usage
```python
from bench_ui import launch_window, get_element_rect, execute_javascript
# Launch a window with inline HTML content
pid = launch_window(html="<html><body><h1>Hello</h1></body></html>")
# Get element rect in screen space
rect = get_element_rect(pid, "h1", space="screen")
print(rect)
# Execute arbitrary JavaScript
text = execute_javascript(pid, "document.querySelector('h1')?.textContent")
print(text)
```
## Installation
```bash
pip install cua-bench-ui
```

View File

@@ -0,0 +1,3 @@
from .api import execute_javascript, get_element_rect, launch_window
__all__ = ["launch_window", "get_element_rect", "execute_javascript"]

View File

@@ -0,0 +1,181 @@
import json
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, Optional
from urllib import request
from urllib.error import HTTPError, URLError
import psutil
# Map child PID -> listening port
_pid_to_port: Dict[int, int] = {}
def _post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
data = json.dumps(payload).encode("utf-8")
req = request.Request(
url, data=data, headers={"Content-Type": "application/json"}, method="POST"
)
try:
with request.urlopen(req, timeout=5) as resp:
text = resp.read().decode("utf-8")
return json.loads(text)
except HTTPError as e:
try:
body = (e.read() or b"").decode("utf-8", errors="ignore")
return json.loads(body)
except Exception:
return {"error": "http_error", "status": getattr(e, "code", None)}
except URLError as e:
return {"error": "url_error", "reason": str(e.reason)}
def _detect_port_for_pid(pid: int) -> int:
"""Detect a listening local TCP port for the given PID using psutil.
Fails fast if psutil is unavailable or if no suitable port is found.
"""
if psutil is None:
raise RuntimeError("psutil is required for PID->port detection. Please install psutil.")
# Scan system-wide connections and filter by PID
for c in psutil.net_connections(kind="tcp"):
if getattr(c, "pid", None) != pid:
continue
laddr = getattr(c, "laddr", None)
status = str(getattr(c, "status", ""))
if not laddr or not isinstance(laddr, tuple) or len(laddr) < 2:
continue
lip, lport = laddr[0], int(laddr[1])
if status.upper() != "LISTEN":
continue
if lip in ("127.0.0.1", "::1", "0.0.0.0", "::"):
return lport
raise RuntimeError(f"Could not detect listening port for pid {pid}")
def launch_window(
url: Optional[str] = None,
*,
html: Optional[str] = None,
folder: Optional[str] = None,
title: str = "Window",
x: Optional[int] = None,
y: Optional[int] = None,
width: int = 600,
height: int = 400,
icon: Optional[str] = None,
use_inner_size: bool = False,
title_bar_style: str = "default",
) -> int:
"""Create a pywebview window in a child process and return its PID.
Preferred input is a URL via the positional `url` parameter.
To load inline HTML instead, pass `html=...`.
To serve a static folder, pass `folder=...` (path to directory).
Spawns `python -m bench_ui.child` with a JSON config passed via a temp file.
The child prints a single JSON line: {"pid": <pid>, "port": <port>}.
We cache pid->port for subsequent control calls like get_element_rect.
"""
if not url and not html and not folder:
raise ValueError("launch_window requires either a url, html, or folder")
config = {
"url": url,
"html": html,
"folder": folder,
"title": title,
"x": x,
"y": y,
"width": width,
"height": height,
"icon": icon,
"use_inner_size": use_inner_size,
"title_bar_style": title_bar_style,
}
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
json.dump(config, f)
cfg_path = f.name
try:
# Launch child process
proc = subprocess.Popen(
[sys.executable, "-m", "bench_ui.child", cfg_path],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
assert proc.stdout is not None
# Read first line with startup info
line = proc.stdout.readline().strip()
info = json.loads(line)
pid = int(info["pid"]) if "pid" in info else proc.pid
port = int(info["port"]) # required
_pid_to_port[pid] = port
return pid
finally:
try:
os.unlink(cfg_path)
except Exception:
pass
def get_element_rect(pid: int, selector: str, *, space: str = "window"):
"""Ask the child process to compute element client rect via injected JS.
Returns a dict like {"x": float, "y": float, "width": float, "height": float} or None if not found.
"""
if pid not in _pid_to_port:
_pid_to_port[pid] = _detect_port_for_pid(pid)
port = _pid_to_port[pid]
url = f"http://127.0.0.1:{port}/rect"
last: Dict[str, Any] = {}
for _ in range(30): # ~3s total
resp = _post_json(url, {"selector": selector, "space": space})
last = resp or {}
rect = last.get("rect") if isinstance(last, dict) else None
err = last.get("error") if isinstance(last, dict) else None
if rect is not None:
return rect
if err in ("window_not_ready", "invalid_json"):
time.sleep(0.1)
continue
# If other transient errors, brief retry
if err:
time.sleep(0.1)
continue
time.sleep(0.1)
raise RuntimeError(f"Failed to get element rect: {last}")
def execute_javascript(pid: int, javascript: str):
"""Execute arbitrary JavaScript in the window and return its result.
Retries briefly while the window is still becoming ready.
"""
if pid not in _pid_to_port:
_pid_to_port[pid] = _detect_port_for_pid(pid)
port = _pid_to_port[pid]
url = f"http://127.0.0.1:{port}/eval"
last: Dict[str, Any] = {}
for _ in range(30): # ~3s total
resp = _post_json(url, {"javascript": javascript})
last = resp or {}
if isinstance(last, dict):
if "result" in last:
return last["result"]
if last.get("error") in ("window_not_ready", "invalid_json"):
time.sleep(0.1)
continue
if last.get("error"):
time.sleep(0.1)
continue
time.sleep(0.1)
raise RuntimeError(f"Failed to execute JavaScript: {last}")

View File

@@ -0,0 +1,221 @@
import asyncio
import json
import os
import random
import socket
import sys
import threading
from pathlib import Path
from typing import Optional
import webview
from aiohttp import web
def _get_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _start_http_server(
window: webview.Window,
port: int,
ready_event: threading.Event,
html_content: str | None = None,
folder_path: str | None = None,
):
async def rect_handler(request: web.Request):
try:
data = await request.json()
except Exception:
return web.json_response({"error": "invalid_json"}, status=400)
selector = data.get("selector")
space = data.get("space", "window")
if not isinstance(selector, str):
return web.json_response({"error": "selector_required"}, status=400)
# Ensure window content is loaded
if not ready_event.is_set():
# give it a short chance to finish loading
ready_event.wait(timeout=2.0)
if not ready_event.is_set():
return web.json_response({"error": "window_not_ready"}, status=409)
# Safely embed selector into JS
selector_js = json.dumps(selector)
if space == "screen":
# Compute approximate screen coordinates using window metrics
js = (
"(function(){"
f"const s = {selector_js};"
"const el = document.querySelector(s);"
"if(!el){return null;}"
"const r = el.getBoundingClientRect();"
"const sx = (window.screenX ?? window.screenLeft ?? 0);"
"const syRaw = (window.screenY ?? window.screenTop ?? 0);"
"const frameH = (window.outerHeight - window.innerHeight) || 0;"
"const sy = syRaw + frameH;"
"return {x:sx + r.left, y:sy + r.top, width:r.width, height:r.height};"
"})()"
)
else:
js = (
"(function(){"
f"const s = {selector_js};"
"const el = document.querySelector(s);"
"if(!el){return null;}"
"const r = el.getBoundingClientRect();"
"return {x:r.left,y:r.top,width:r.width,height:r.height};"
"})()"
)
try:
# Evaluate JS on the target window; this call is thread-safe in pywebview
result = window.evaluate_js(js)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
return web.json_response({"rect": result})
async def eval_handler(request: web.Request):
try:
data = await request.json()
except Exception:
return web.json_response({"error": "invalid_json"}, status=400)
code = data.get("javascript") or data.get("code")
if not isinstance(code, str):
return web.json_response({"error": "javascript_required"}, status=400)
if not ready_event.is_set():
ready_event.wait(timeout=2.0)
if not ready_event.is_set():
return web.json_response({"error": "window_not_ready"}, status=409)
try:
result = window.evaluate_js(code)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
return web.json_response({"result": result})
async def index_handler(request: web.Request):
if html_content is None:
return web.json_response({"status": "ok", "message": "bench-ui control server"})
return web.Response(text=html_content, content_type="text/html")
app = web.Application()
# If serving a folder, add static file routes
if folder_path:
app.router.add_static("/", folder_path, show_index=True)
else:
app.router.add_get("/", index_handler)
app.router.add_post("/rect", rect_handler)
app.router.add_post("/eval", eval_handler)
loop = asyncio.new_event_loop()
def run_loop():
asyncio.set_event_loop(loop)
runner = web.AppRunner(app)
loop.run_until_complete(runner.setup())
site = web.TCPSite(runner, "127.0.0.1", port)
loop.run_until_complete(site.start())
loop.run_forever()
t = threading.Thread(target=run_loop, daemon=True)
t.start()
def main():
if len(sys.argv) < 2:
print("Usage: python -m bench_ui.child <config.json>", file=sys.stderr)
sys.exit(2)
cfg_path = Path(sys.argv[1])
cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
html: Optional[str] = cfg.get("html") or ""
url: Optional[str] = cfg.get("url")
folder: Optional[str] = cfg.get("folder")
title: str = cfg.get("title", "Window")
x: Optional[int] = cfg.get("x")
y: Optional[int] = cfg.get("y")
width: int = int(cfg.get("width", 600))
height: int = int(cfg.get("height", 400))
icon: Optional[str] = cfg.get("icon")
use_inner_size: bool = bool(cfg.get("use_inner_size", False))
title_bar_style: str = cfg.get("title_bar_style", "default")
# Choose port early so we can point the window to it when serving inline HTML or folder
port = _get_free_port()
# Create window
if url:
window = webview.create_window(
title,
url=url,
width=width,
height=height,
x=x,
y=y,
confirm_close=False,
text_select=True,
background_color="#FFFFFF",
)
html_for_server = None
folder_for_server = None
elif folder:
# Serve static folder at control server root and point window to index.html
resolved_url = f"http://127.0.0.1:{port}/index.html"
window = webview.create_window(
title,
url=resolved_url,
width=width,
height=height,
x=x,
y=y,
confirm_close=False,
text_select=True,
background_color="#FFFFFF",
)
html_for_server = None
folder_for_server = folder
else:
# Serve inline HTML at control server root and point window to it
resolved_url = f"http://127.0.0.1:{port}/"
window = webview.create_window(
title,
url=resolved_url,
width=width,
height=height,
x=x,
y=y,
confirm_close=False,
text_select=True,
background_color="#FFFFFF",
)
html_for_server = html
folder_for_server = None
# Track when the page is loaded so JS execution succeeds
window_ready = threading.Event()
def _on_loaded():
window_ready.set()
window.events.loaded += _on_loaded # type: ignore[attr-defined]
# Start HTTP server for control (and optionally serve inline HTML or static folder)
_start_http_server(
window, port, window_ready, html_content=html_for_server, folder_path=folder_for_server
)
# Print startup info for parent to read
print(json.dumps({"pid": os.getpid(), "port": port}), flush=True)
# Start GUI (blocking)
webview.start(debug=os.environ.get("CUA_BENCH_UI_DEBUG", "false").lower() in ("true", "1"))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import time
from bench_ui import launch_window, get_element_rect, execute_javascript
from pathlib import Path
import os
def main():
os.environ["CUA_BENCH_UI_DEBUG"] = "1"
# Get the path to the gui folder
gui_folder = Path(__file__).parent / "gui"
# Launch a window serving the static folder
pid = launch_window(
folder=str(gui_folder),
title="Static Folder Example",
width=800,
height=600,
)
print(f"Launched window with PID: {pid}")
print(f"Serving folder: {gui_folder}")
# Give the window a moment to render
time.sleep(1.5)
# Query the client rect of the button element
rect = get_element_rect(pid, "#testButton", space="window")
print("Button rect (window space):", rect)
# Check if button has been clicked
clicked = execute_javascript(pid, "document.getElementById('testButton').disabled")
print("Button clicked:", clicked)
# Get the page title
title = execute_javascript(pid, "document.title")
print("Page title:", title)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,42 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Static Folder Example</title>
<link rel="stylesheet" href="styles.css">
</head>
<body>
<div class="container">
<h1>Static Folder Example</h1>
<p>This page is served from a static folder using bench-ui!</p>
<div class="image-container">
<img src="logo.svg" alt="Example SVG Logo" class="logo">
</div>
<div class="info">
<p>This example demonstrates:</p>
<ul>
<li>Serving a static folder with bench-ui</li>
<li>Loading external CSS files (styles.css)</li>
<li>Loading SVG images (logo.svg)</li>
</ul>
</div>
<button id="testButton" class="btn">Click Me!</button>
<p id="status"></p>
</div>
<script>
document.getElementById('testButton').addEventListener('click', function () {
document.getElementById('status').textContent = 'Button clicked! ✓';
this.disabled = true;
this.textContent = 'Clicked!';
});
</script>
</body>
</html>

View File

@@ -0,0 +1,24 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 200 200">
<defs>
<linearGradient id="grad1" x1="0%" y1="0%" x2="100%" y2="100%">
<stop offset="0%" style="stop-color:#667eea;stop-opacity:1" />
<stop offset="100%" style="stop-color:#764ba2;stop-opacity:1" />
</linearGradient>
</defs>
<!-- Background circle -->
<circle cx="100" cy="100" r="95" fill="url(#grad1)" />
<!-- Window icon -->
<rect x="50" y="50" width="100" height="100" rx="8" fill="white" opacity="0.9" />
<!-- Window panes -->
<line x1="100" y1="50" x2="100" y2="150" stroke="url(#grad1)" stroke-width="4" />
<line x1="50" y1="100" x2="150" y2="100" stroke="url(#grad1)" stroke-width="4" />
<!-- Decorative dots -->
<circle cx="75" cy="75" r="8" fill="url(#grad1)" />
<circle cx="125" cy="75" r="8" fill="url(#grad1)" />
<circle cx="75" cy="125" r="8" fill="url(#grad1)" />
<circle cx="125" cy="125" r="8" fill="url(#grad1)" />
</svg>

After

Width:  |  Height:  |  Size: 963 B

View File

@@ -0,0 +1,92 @@
body {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
background: white;
border-radius: 12px;
padding: 40px;
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
max-width: 600px;
width: 100%;
}
h1 {
color: #333;
margin-top: 0;
font-size: 2em;
}
p {
color: #666;
line-height: 1.6;
}
.image-container {
display: flex;
justify-content: center;
margin: 30px 0;
}
.logo {
width: 150px;
height: 150px;
}
.info {
background: #f8f9fa;
border-left: 4px solid #667eea;
padding: 20px;
margin: 20px 0;
border-radius: 4px;
}
.info ul {
margin: 10px 0;
padding-left: 20px;
}
.info li {
color: #555;
margin: 8px 0;
}
.btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 12px 30px;
font-size: 16px;
border-radius: 6px;
cursor: pointer;
transition: transform 0.2s, box-shadow 0.2s;
font-weight: 600;
}
.btn:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
.btn:active:not(:disabled) {
transform: translateY(0);
}
.btn:disabled {
opacity: 0.6;
cursor: not-allowed;
}
#status {
margin-top: 15px;
font-weight: 600;
color: #28a745;
font-size: 18px;
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 743 KiB

View File

@@ -0,0 +1,80 @@
from __future__ import annotations
import os
import time
from pathlib import Path
from bench_ui import execute_javascript, get_element_rect, launch_window
HTML = """
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>Bench UI Example</title>
<style>
body { font-family: system-ui, sans-serif; margin: 24px; }
#target { width: 220px; height: 120px; background: #4f46e5; color: white; display: flex; align-items: center; justify-content: center; border-radius: 8px; }
</style>
</head>
<body>
<h1>Bench UI Example</h1>
<div id="target">Hello from pywebview</div>
<h1>Click the button</h1>
<button id="submit" class="btn" data-instruction="the button">Submit</button>
<script>
window.__submitted = false;
document.getElementById('submit').addEventListener('click', function() {
window.__submitted = true;
this.textContent = 'Submitted!';
this.disabled = true;
});
</script>
</body>
</html>
"""
def main():
os.environ["CUA_BENCH_UI_DEBUG"] = "1"
# Launch a window with inline HTML content
pid = launch_window(
html=HTML,
title="Bench UI Example",
width=800,
height=600,
)
print(f"Launched window with PID: {pid}")
# Give the window a brief moment to render
time.sleep(1.0)
# Query the client rect of an element via CSS selector in SCREEN space
rect = get_element_rect(pid, "#target", space="screen")
print("Element rect (screen space):", rect)
# Take a screenshot and overlay the bbox
try:
from PIL import ImageDraw, ImageGrab
img = ImageGrab.grab() # full screen
draw = ImageDraw.Draw(img)
x, y, w, h = rect["x"], rect["y"], rect["width"], rect["height"]
box = (x, y, x + w, y + h)
draw.rectangle(box, outline=(255, 0, 0), width=3)
out_path = Path(__file__).parent / "output_overlay.png"
img.save(out_path)
print(f"Saved overlay screenshot to: {out_path}")
except Exception as e:
print(f"Failed to capture/annotate screenshot: {e}")
# Execute arbitrary JavaScript
text = execute_javascript(pid, "window.__submitted")
print("text:", text)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,25 @@
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
[project]
name = "cua-bench-ui"
version = "0.7.0"
description = "Lightweight webUI window controller for CUA bench using pywebview"
readme = "README.md"
authors = [
{ name = "TryCua", email = "gh@trycua.com" }
]
dependencies = [
"pywebview>=5.3",
"aiohttp>=3.9.0",
"psutil>=5.9",
]
requires-python = ">=3.12"
[tool.pdm]
distribution = true
[tool.pdm.build]
includes = ["bench_ui/"]
source-includes = ["README.md"]

View File

@@ -0,0 +1,50 @@
import time
import psutil
import pytest
from bench_ui import execute_javascript, launch_window
from bench_ui.api import _pid_to_port
HTML = """
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>Bench UI Test</title>
</head>
<body>
<div id="t">hello-world</div>
</body>
</html>
"""
def test_execute_js_after_clearing_port_mapping():
# Skip if pywebview backend is unavailable on this machine
pywebview = pytest.importorskip("webview")
pid = launch_window(html=HTML, title="Bench UI Test", width=400, height=300)
try:
# Give a brief moment for window to render and server to start
time.sleep(1.0)
# Sanity: mapping should exist initially
assert pid in _pid_to_port
# Clear the cached mapping to simulate a fresh process lookup
del _pid_to_port[pid]
# Now execute JS; this should succeed by detecting the port via psutil
result = execute_javascript(pid, "document.querySelector('#t')?.textContent")
assert result == "hello-world"
finally:
# Best-effort cleanup of the child process
try:
p = psutil.Process(pid)
p.terminate()
try:
p.wait(timeout=3)
except psutil.TimeoutExpired:
p.kill()
except Exception:
pass

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.30
current_version = 0.1.31
commit = True
tag = True
tag_name = computer-server-v{new_version}

View File

@@ -40,7 +40,7 @@ Refer to this notebook for a step-by-step guide on how to use the Computer-Use S
## Docs
- [Commands](https://cua.ai/docs/libraries/computer-server/Commands)
- [REST-API](https://cua.ai/docs/libraries/computer-server/REST-API)
- [WebSocket-API](https://cua.ai/docs/libraries/computer-server/WebSocket-API)
- [Index](https://cua.ai/docs/libraries/computer-server)
- [Commands](https://cua.ai/docs/computer-sdk/computer-server/Commands)
- [REST-API](https://cua.ai/docs/computer-sdk/computer-server/REST-API)
- [WebSocket-API](https://cua.ai/docs/computer-sdk/computer-server/WebSocket-API)
- [Index](https://cua.ai/docs/computer-sdk/computer-server)

View File

@@ -24,8 +24,8 @@ from fastapi import (
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from .handlers.factory import HandlerFactory
from .browser import get_browser_manager
from .handlers.factory import HandlerFactory
# Authentication session TTL (in seconds). Override via env var CUA_AUTH_TTL_SECONDS. Default: 60s
AUTH_SESSION_TTL_SECONDS: int = int(os.environ.get("CUA_AUTH_TTL_SECONDS", "60"))
@@ -805,7 +805,7 @@ async def playwright_exec_endpoint(
try:
browser_manager = get_browser_manager()
result = await browser_manager.execute_command(command, params)
if result.get("success"):
return JSONResponse(content=result)
else:

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-computer-server"
version = "0.1.30"
version = "0.1.31"
description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
authors = [

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.17
current_version = 0.4.18
commit = True
tag = True
tag_name = computer-v{new_version}

View File

@@ -7,7 +7,28 @@ import platform
import re
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
TypeVar,
Union,
cast,
)
try:
from typing import ParamSpec
except Exception: # pragma: no cover
from typing_extensions import ParamSpec # type: ignore
P = ParamSpec("P")
R = TypeVar("R")
from core.telemetry import is_telemetry_enabled, record_event
from PIL import Image
@@ -66,8 +87,9 @@ class Computer:
verbosity: Union[int, LogLevel] = logging.INFO,
telemetry_enabled: bool = True,
provider_type: Union[str, VMProviderType] = VMProviderType.LUME,
port: Optional[int] = 7777,
provider_port: Optional[int] = 7777,
noVNC_port: Optional[int] = 8006,
api_port: Optional[int] = None,
host: str = os.environ.get("PYLUME_HOST", "localhost"),
storage: Optional[str] = None,
ephemeral: bool = False,
@@ -118,14 +140,19 @@ class Computer:
# Store original parameters
self.image = image
self.port = port
self.provider_port = provider_port
self.noVNC_port = noVNC_port
self.api_port = api_port
self.host = host
self.os_type = os_type
self.provider_type = provider_type
self.ephemeral = ephemeral
self.api_key = api_key if self.provider_type == VMProviderType.CLOUD else None
# Set default API port if not specified
if self.api_port is None:
self.api_port = 8443 if self.api_key else 8000
self.api_key = api_key
self.experiments = experiments or []
if "app-use" in self.experiments:
@@ -273,7 +300,7 @@ class Computer:
interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type, ip_address=ip_address # type: ignore[arg-type]
os=self.os_type, ip_address=ip_address, api_port=self.api_port # type: ignore[arg-type]
),
)
self._interface = interface
@@ -300,7 +327,7 @@ class Computer:
storage = "ephemeral" if self.ephemeral else self.storage
verbose = self.verbosity >= LogLevel.DEBUG
ephemeral = self.ephemeral
port = self.port if self.port is not None else 7777
port = self.provider_port if self.provider_port is not None else 7777
host = self.host if self.host else "localhost"
image = self.image
shared_path = self.shared_path
@@ -365,6 +392,7 @@ class Computer:
verbose=verbose,
ephemeral=ephemeral,
noVNC_port=noVNC_port,
api_port=self.api_port,
)
else:
raise ValueError(f"Unsupported provider type: {self.provider_type}")
@@ -513,13 +541,14 @@ class Computer:
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name,
api_port=self.api_port,
),
)
else:
interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type, ip_address=ip_address
os=self.os_type, ip_address=ip_address, api_port=self.api_port
),
)
@@ -533,15 +562,13 @@ class Computer:
# Use a single timeout for the entire connection process
# The VM should already be ready at this point, so we're just establishing the connection
await self._interface.wait_for_ready(timeout=30)
self.logger.info("WebSocket interface connected successfully")
self.logger.info("Sandbox interface connected successfully")
except TimeoutError as e:
self.logger.error(f"Failed to connect to WebSocket interface at {ip_address}")
port = getattr(self._interface, "_api_port", 8000) # Default to 8000 if not set
self.logger.error(f"Failed to connect to sandbox interface at {ip_address}:{port}")
raise TimeoutError(
f"Could not connect to WebSocket interface at {ip_address}:8000/ws: {str(e)}"
f"Could not connect to sandbox interface at {ip_address}:{port}: {str(e)}"
)
# self.logger.warning(
# f"Could not connect to WebSocket interface at {ip_address}:8000/ws: {str(e)}, expect missing functionality"
# )
# Create an event to keep the VM running in background if needed
if not self.use_host_computer_server:
@@ -688,6 +715,7 @@ class Computer:
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name,
api_port=self.api_port,
),
)
else:
@@ -696,6 +724,7 @@ class Computer:
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
api_port=self.api_port,
),
)
@@ -1013,7 +1042,7 @@ class Computer:
else:
# POSIX (macOS/Linux)
venv_path = f"$HOME/.venvs/{venv_name}"
create_cmd = f'mkdir -p "$HOME/.venvs" && python3 -m venv "{venv_path}"'
create_cmd = f'mkdir -p "$HOME/.venvs" && python -m venv "{venv_path}"'
# Check if venv exists, if not create it
check_cmd = f'test -d "{venv_path}" || ({create_cmd})'
_ = await self.interface.run_command(check_cmd)
@@ -1024,7 +1053,25 @@ class Computer:
if requirements_str
else "echo No requirements to install"
)
return await self.interface.run_command(install_cmd)
return await self.interface.run_command(install_cmd)
async def pip_install(self, requirements: list[str]):
"""Install packages using the system Python/pip (no venv).
Args:
requirements: List of package requirements to install globally/user site.
Returns:
Tuple of (stdout, stderr) from the installation command
"""
requirements = requirements or []
if not requirements:
return await self.interface.run_command("echo No requirements to install")
# Use python -m pip for cross-platform consistency
reqs = " ".join(requirements)
install_cmd = f"python -m pip install {reqs}"
return await self.interface.run_command(install_cmd)
async def venv_cmd(self, venv_name: str, command: str):
"""Execute a shell command in a virtual environment.
@@ -1074,19 +1121,11 @@ class Computer:
The result of the function execution, or raises any exception that occurred
"""
import base64
import inspect
import json
import textwrap
try:
# Get function source code using inspect.getsource
source = inspect.getsource(python_func)
# Remove common leading whitespace (dedent)
func_source = textwrap.dedent(source).strip()
# Remove decorators
while func_source.lstrip().startswith("@"):
func_source = func_source.split("\n", 1)[1].strip()
func_source = helpers.generate_source_code(python_func)
# Get function name for execution
func_name = python_func.__name__
@@ -1101,19 +1140,23 @@ class Computer:
raise Exception(f"Failed to reconstruct function source: {e}")
# Create Python code that will define and execute the function
args_b64 = base64.b64encode(args_json.encode("utf-8")).decode("ascii")
kwargs_b64 = base64.b64encode(kwargs_json.encode("utf-8")).decode("ascii")
python_code = f'''
import json
import traceback
import base64
try:
# Define the function from source
{textwrap.indent(func_source, " ")}
# Deserialize args and kwargs from JSON
args_json = """{args_json}"""
kwargs_json = """{kwargs_json}"""
args = json.loads(args_json)
kwargs = json.loads(kwargs_json)
# Deserialize args and kwargs from base64 JSON
_args_b64 = """{args_b64}"""
_kwargs_b64 = """{kwargs_b64}"""
args = json.loads(base64.b64decode(_args_b64).decode('utf-8'))
kwargs = json.loads(base64.b64decode(_kwargs_b64).decode('utf-8'))
# Execute the function
result = {func_name}(*args, **kwargs)
@@ -1177,10 +1220,21 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
if output_payload["success"]:
return output_payload["result"]
else:
import builtins
# Recreate and raise the original exception
error_info = output_payload["error"]
error_class = eval(error_info["type"])
raise error_class(error_info["message"])
error_info = output_payload.get("error", {}) or {}
err_type = error_info.get("type") or "Exception"
err_msg = error_info.get("message") or ""
err_tb = error_info.get("traceback") or ""
exc_cls = getattr(builtins, err_type, None)
if isinstance(exc_cls, type) and issubclass(exc_cls, BaseException):
# Built-in exception: rethrow with remote traceback appended
raise exc_cls(f"{err_msg}\n\nRemote traceback:\n{err_tb}")
else:
# Non built-in: raise a safe local error carrying full remote context
raise RuntimeError(f"{err_type}: {err_msg}\n\nRemote traceback:\n{err_tb}")
else:
raise Exception("Invalid output format: markers found but no content between them")
else:
@@ -1188,3 +1242,345 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
raise Exception(
f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}"
)
async def venv_exec_background(
self, venv_name: str, python_func, *args, requirements: Optional[List[str]] = None, **kwargs
) -> int:
"""Run the Python function in the venv in the background and return the PID.
Uses a short launcher Python that spawns a detached child and exits immediately.
"""
import base64
import json
import textwrap
import time as _time
try:
func_source = helpers.generate_source_code(python_func)
func_name = python_func.__name__
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)
except OSError as e:
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
except Exception as e:
raise Exception(f"Failed to reconstruct function source: {e}")
reqs_list = requirements or []
reqs_json = json.dumps(reqs_list)
# Create Python code that will define and execute the function
args_b64 = base64.b64encode(args_json.encode("utf-8")).decode("ascii")
kwargs_b64 = base64.b64encode(kwargs_json.encode("utf-8")).decode("ascii")
payload_code = (
f'''
import json
import traceback
import base64
try:
# Define the function from source
{textwrap.indent(func_source, " ")}
# Deserialize args and kwargs from base64 JSON
_args_b64 = """{args_b64}"""
_kwargs_b64 = """{kwargs_b64}"""
args = json.loads(base64.b64decode(_args_b64).decode('utf-8'))
kwargs = json.loads(base64.b64decode(_kwargs_b64).decode('utf-8'))
# Ensure requirements inside the active venv
for pkg in json.loads('''
+ repr(reqs_json)
+ """):
if pkg:
import subprocess, sys
subprocess.run([sys.executable, '-m', 'pip', 'install', pkg], check=False)
_ = {func_name}(*args, **kwargs)
except Exception:
import sys
sys.stderr.write(traceback.format_exc())
"""
)
payload_b64 = base64.b64encode(payload_code.encode("utf-8")).decode("ascii")
if self.os_type == "windows":
# Launcher spawns detached child and prints its PID
launcher_code = f"""
import base64, subprocess, os, sys
DETACHED_PROCESS = 0x00000008
CREATE_NEW_PROCESS_GROUP = 0x00000200
creationflags = DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP
code = base64.b64decode("{payload_b64}").decode("utf-8")
p = subprocess.Popen(["python", "-c", code], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, creationflags=creationflags)
print(p.pid)
"""
launcher_b64 = base64.b64encode(launcher_code.encode("utf-8")).decode("ascii")
venv_path = f"%USERPROFILE%\\.venvs\\{venv_name}"
cmd = (
'cmd /c "'
f'call "{venv_path}\\Scripts\\activate.bat" && '
f"python -c \"import base64; exec(base64.b64decode('{launcher_b64}').decode('utf-8'))\""
'"'
)
result = await self.interface.run_command(cmd)
pid_str = (result.stdout or "").strip().splitlines()[-1].strip()
return int(pid_str)
else:
log = f"/tmp/cua_bg_{int(_time.time())}.log"
launcher_code = f"""
import base64, subprocess, os, sys
code = base64.b64decode("{payload_b64}").decode("utf-8")
with open("{log}", "ab", buffering=0) as f:
p = subprocess.Popen(["python", "-c", code], stdout=f, stderr=subprocess.STDOUT, preexec_fn=getattr(os, "setsid", None))
print(p.pid)
"""
launcher_b64 = base64.b64encode(launcher_code.encode("utf-8")).decode("ascii")
venv_path = f"$HOME/.venvs/{venv_name}"
shell = (
f'. "{venv_path}/bin/activate" && '
f"python -c \"import base64; exec(base64.b64decode('{launcher_b64}').decode('utf-8'))\""
)
result = await self.interface.run_command(shell)
pid_str = (result.stdout or "").strip().splitlines()[-1].strip()
return int(pid_str)
async def python_exec(self, python_func, *args, **kwargs):
"""Execute a Python function using the system Python (no venv).
Uses source extraction and base64 transport, mirroring venv_exec but
without virtual environment activation.
Returns the function result or raises a reconstructed exception with
remote traceback context appended.
"""
import base64
import json
import textwrap
try:
func_source = helpers.generate_source_code(python_func)
func_name = python_func.__name__
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)
except OSError as e:
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
except Exception as e:
raise Exception(f"Failed to reconstruct function source: {e}")
# Create Python code that will define and execute the function
args_b64 = base64.b64encode(args_json.encode("utf-8")).decode("ascii")
kwargs_b64 = base64.b64encode(kwargs_json.encode("utf-8")).decode("ascii")
python_code = f'''
import json
import traceback
import base64
try:
# Define the function from source
{textwrap.indent(func_source, " ")}
# Deserialize args and kwargs from base64 JSON
_args_b64 = """{args_b64}"""
_kwargs_b64 = """{kwargs_b64}"""
args = json.loads(base64.b64decode(_args_b64).decode('utf-8'))
kwargs = json.loads(base64.b64decode(_kwargs_b64).decode('utf-8'))
# Execute the function
result = {func_name}(*args, **kwargs)
# Create success output payload
output_payload = {{
"success": True,
"result": result,
"error": None
}}
except Exception as e:
# Create error output payload
output_payload = {{
"success": False,
"result": None,
"error": {{
"type": type(e).__name__,
"message": str(e),
"traceback": traceback.format_exc()
}}
}}
# Serialize the output payload as JSON
import json
output_json = json.dumps(output_payload, default=str)
# Print the JSON output with markers
print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
'''
encoded_code = base64.b64encode(python_code.encode("utf-8")).decode("ascii")
python_command = (
f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
)
result = await self.interface.run_command(python_command)
start_marker = "<<<VENV_EXEC_START>>>"
end_marker = "<<<VENV_EXEC_END>>>"
print(result.stdout[: result.stdout.find(start_marker)])
if start_marker in result.stdout and end_marker in result.stdout:
start_idx = result.stdout.find(start_marker) + len(start_marker)
end_idx = result.stdout.find(end_marker)
if start_idx < end_idx:
output_json = result.stdout[start_idx:end_idx]
try:
output_payload = json.loads(output_json)
except Exception as e:
raise Exception(f"Failed to decode output payload: {e}")
if output_payload["success"]:
return output_payload["result"]
else:
import builtins
error_info = output_payload.get("error", {}) or {}
err_type = error_info.get("type") or "Exception"
err_msg = error_info.get("message") or ""
err_tb = error_info.get("traceback") or ""
exc_cls = getattr(builtins, err_type, None)
if isinstance(exc_cls, type) and issubclass(exc_cls, BaseException):
raise exc_cls(f"{err_msg}\n\nRemote traceback:\n{err_tb}")
else:
raise RuntimeError(f"{err_type}: {err_msg}\n\nRemote traceback:\n{err_tb}")
else:
raise Exception("Invalid output format: markers found but no content between them")
else:
raise Exception(
f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}"
)
async def python_exec_background(
self, python_func, *args, requirements: Optional[List[str]] = None, **kwargs
) -> int:
"""Run a Python function with the system interpreter in the background and return PID.
Uses a short launcher Python that spawns a detached child and exits immediately.
"""
import base64
import json
import textwrap
import time as _time
try:
func_source = helpers.generate_source_code(python_func)
func_name = python_func.__name__
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)
except OSError as e:
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
except Exception as e:
raise Exception(f"Failed to reconstruct function source: {e}")
# Create Python code that will define and execute the function
args_b64 = base64.b64encode(args_json.encode("utf-8")).decode("ascii")
kwargs_b64 = base64.b64encode(kwargs_json.encode("utf-8")).decode("ascii")
payload_code = f'''
import json
import traceback
import base64
try:
# Define the function from source
{textwrap.indent(func_source, " ")}
# Deserialize args and kwargs from base64 JSON
_args_b64 = """{args_b64}"""
_kwargs_b64 = """{kwargs_b64}"""
args = json.loads(base64.b64decode(_args_b64).decode('utf-8'))
kwargs = json.loads(base64.b64decode(_kwargs_b64).decode('utf-8'))
_ = {func_name}(*args, **kwargs)
except Exception:
import sys
sys.stderr.write(traceback.format_exc())
'''
payload_b64 = base64.b64encode(payload_code.encode("utf-8")).decode("ascii")
if self.os_type == "windows":
launcher_code = f"""
import base64, subprocess, os, sys
DETACHED_PROCESS = 0x00000008
CREATE_NEW_PROCESS_GROUP = 0x00000200
creationflags = DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP
code = base64.b64decode("{payload_b64}").decode("utf-8")
p = subprocess.Popen(["python", "-c", code], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, creationflags=creationflags)
print(p.pid)
"""
launcher_b64 = base64.b64encode(launcher_code.encode("utf-8")).decode("ascii")
cmd = f"python -c \"import base64; exec(base64.b64decode('{launcher_b64}').decode('utf-8'))\""
result = await self.interface.run_command(cmd)
pid_str = (result.stdout or "").strip().splitlines()[-1].strip()
return int(pid_str)
else:
log = f"/tmp/cua_bg_{int(_time.time())}.log"
launcher_code = f"""
import base64, subprocess, os, sys
code = base64.b64decode("{payload_b64}").decode("utf-8")
with open("{log}", "ab", buffering=0) as f:
p = subprocess.Popen(["python", "-c", code], stdout=f, stderr=subprocess.STDOUT, preexec_fn=getattr(os, "setsid", None))
print(p.pid)
"""
launcher_b64 = base64.b64encode(launcher_code.encode("utf-8")).decode("ascii")
cmd = f"python -c \"import base64; exec(base64.b64decode('{launcher_b64}').decode('utf-8'))\""
result = await self.interface.run_command(cmd)
pid_str = (result.stdout or "").strip().splitlines()[-1].strip()
return int(pid_str)
def python_command(
self,
requirements: Optional[List[str]] = None,
*,
venv_name: str = "default",
use_system_python: bool = False,
background: bool = False,
) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]:
"""Decorator to execute a Python function remotely in this Computer's venv.
This mirrors `computer.helpers.sandboxed()` but binds to this instance and
optionally ensures required packages are installed before execution.
Args:
requirements: Packages to install in the virtual environment.
venv_name: Name of the virtual environment to use.
use_system_python: If True, use the system Python/pip instead of a venv.
background: If True, run the function detached and return the child PID immediately.
Returns:
A decorator that turns a local function into an async callable which
runs remotely and returns the function's result.
"""
reqs = list(requirements or [])
def decorator(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if use_system_python:
# For background, avoid blocking installs; install inside child process
if background:
return await self.python_exec_background(func, *args, requirements=reqs, **kwargs) # type: ignore[return-value]
# Foreground: install first, then execute
if reqs:
await self.pip_install(reqs)
return await self.python_exec(func, *args, **kwargs)
else:
# For background, avoid blocking installs; install inside child process under venv
if background:
return await self.venv_exec_background(venv_name, func, *args, requirements=reqs, **kwargs) # type: ignore[return-value]
# Foreground: ensure venv and install, then execute
await self.venv_install(venv_name, reqs)
return await self.venv_exec(venv_name, func, *args, **kwargs)
return wrapper
return decorator

View File

@@ -2,18 +2,46 @@
Helper functions and decorators for the Computer module.
"""
import ast
import asyncio
import builtins
import importlib.util
import inspect
import logging
import os
import sys
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
from inspect import getsource
from textwrap import dedent
from types import FunctionType, ModuleType
from typing import Any, Awaitable, Callable, Dict, List, Set, TypedDict, TypeVar
try:
# Python 3.12+ has ParamSpec in typing
from typing import ParamSpec
except ImportError: # pragma: no cover
# Fallback for environments without ParamSpec in typing
from typing_extensions import ParamSpec # type: ignore
P = ParamSpec("P")
R = TypeVar("R")
class DependencyInfo(TypedDict):
import_statements: List[str]
definitions: List[tuple[str, Any]]
# Global reference to the default computer instance
_default_computer = None
# Global cache for function dependency analysis
_function_dependency_map: Dict[FunctionType, DependencyInfo] = {}
logger = logging.getLogger(__name__)
def set_default_computer(computer):
def set_default_computer(computer: Any) -> None:
"""
Set the default computer instance to be used by the remote decorator.
@@ -24,19 +52,26 @@ def set_default_computer(computer):
_default_computer = computer
def sandboxed(venv_name: str = "default", computer: str = "default", max_retries: int = 3):
def sandboxed(
venv_name: str = "default",
computer: str = "default",
max_retries: int = 3,
) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]:
"""
Decorator that wraps a function to be executed remotely via computer.venv_exec
The function is automatically analyzed for dependencies (imports, helper functions,
constants, etc.) and reconstructed with all necessary code in the remote sandbox.
Args:
venv_name: Name of the virtual environment to execute in
computer: The computer instance to use, or "default" to use the globally set default
max_retries: Maximum number of retries for the remote execution
"""
def decorator(func):
def decorator(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
@wraps(func)
async def wrapper(*args, **kwargs):
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Determine which computer instance to use
comp = computer if computer != "default" else _default_computer
@@ -54,6 +89,402 @@ def sandboxed(venv_name: str = "default", computer: str = "default", max_retries
if i == max_retries - 1:
raise e
# Should be unreachable because we either returned or raised
raise RuntimeError("sandboxed wrapper reached unreachable code path")
return wrapper
return decorator
def _extract_import_statement(name: str, module: ModuleType) -> str:
"""Extract the original import statement for a module."""
module_name = module.__name__
if name == module_name.split(".")[0]:
return f"import {module_name}"
else:
return f"import {module_name} as {name}"
def _is_third_party_module(module_name: str) -> bool:
"""Check if a module is a third-party module."""
stdlib_modules = set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
if module_name in stdlib_modules:
return False
try:
spec = importlib.util.find_spec(module_name)
if spec is None:
return False
if spec.origin and ("site-packages" in spec.origin or "dist-packages" in spec.origin):
return True
return False
except (ImportError, ModuleNotFoundError, ValueError):
return False
def _is_project_import(module_name: str) -> bool:
"""Check if a module is a project-level import."""
if module_name.startswith("__relative_import_level_"):
return True
if module_name in sys.modules:
module = sys.modules[module_name]
if hasattr(module, "__file__") and module.__file__:
if "site-packages" not in module.__file__ and "dist-packages" not in module.__file__:
cwd = os.getcwd()
if module.__file__.startswith(cwd):
return True
return False
def _categorize_module(module_name: str) -> str:
"""Categorize a module as stdlib, third-party, or project."""
if module_name.startswith("__relative_import_level_"):
return "project"
elif module_name in (
set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
):
return "stdlib"
elif _is_third_party_module(module_name):
return "third_party"
elif _is_project_import(module_name):
return "project"
else:
return "unknown"
class _DependencyVisitor(ast.NodeVisitor):
"""AST visitor to extract imports and name references from a function."""
def __init__(self, function_name: str) -> None:
self.function_name = function_name
self.internal_imports: Set[str] = set()
self.internal_import_statements: List[str] = []
self.name_references: Set[str] = set()
self.local_names: Set[str] = set()
self.inside_function = False
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
if node.name == self.function_name and not self.inside_function:
self.inside_function = True
for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs:
self.local_names.add(arg.arg)
if node.args.vararg:
self.local_names.add(node.args.vararg.arg)
if node.args.kwarg:
self.local_names.add(node.args.kwarg.arg)
for child in node.body:
self.visit(child)
self.inside_function = False
else:
if self.inside_function:
self.local_names.add(node.name)
for child in node.body:
self.visit(child)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self.visit_FunctionDef(node) # type: ignore
def visit_Import(self, node: ast.Import) -> None:
if self.inside_function:
for alias in node.names:
module_name = alias.name.split(".")[0]
self.internal_imports.add(module_name)
imported_as = alias.asname if alias.asname else alias.name.split(".")[0]
self.local_names.add(imported_as)
self.internal_import_statements.append(ast.unparse(node))
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if self.inside_function:
if node.level == 0 and node.module:
module_name = node.module.split(".")[0]
self.internal_imports.add(module_name)
elif node.level > 0:
self.internal_imports.add(f"__relative_import_level_{node.level}__")
for alias in node.names:
imported_as = alias.asname if alias.asname else alias.name
self.local_names.add(imported_as)
self.internal_import_statements.append(ast.unparse(node))
self.generic_visit(node)
def visit_Name(self, node: ast.Name) -> None:
if self.inside_function:
if isinstance(node.ctx, ast.Load):
self.name_references.add(node.id)
elif isinstance(node.ctx, ast.Store):
self.local_names.add(node.id)
self.generic_visit(node)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
if self.inside_function:
self.local_names.add(node.name)
self.generic_visit(node)
def visit_For(self, node: ast.For) -> None:
if self.inside_function and isinstance(node.target, ast.Name):
self.local_names.add(node.target.id)
self.generic_visit(node)
def visit_comprehension(self, node: ast.comprehension) -> None:
if self.inside_function and isinstance(node.target, ast.Name):
self.local_names.add(node.target.id)
self.generic_visit(node)
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
if self.inside_function and node.name:
self.local_names.add(node.name)
self.generic_visit(node)
def visit_With(self, node: ast.With) -> None:
if self.inside_function:
for item in node.items:
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
self.local_names.add(item.optional_vars.id)
self.generic_visit(node)
def _traverse_and_collect_dependencies(func: FunctionType) -> DependencyInfo:
"""
Traverse a function and collect its dependencies.
Returns a dict with:
- import_statements: List of import statements needed
- definitions: List of (name, obj) tuples for helper functions/classes/constants
"""
source = dedent(getsource(func))
tree = ast.parse(source)
visitor = _DependencyVisitor(func.__name__)
visitor.visit(tree)
builtin_names = set(dir(builtins))
external_refs = (visitor.name_references - visitor.local_names) - builtin_names
import_statements = []
definitions = []
visited = set()
# Include all internal import statements
import_statements.extend(visitor.internal_import_statements)
# Analyze external references recursively
def analyze_object(obj: Any, name: str, depth: int = 0) -> None:
if depth > 20:
return
obj_id = id(obj)
if obj_id in visited:
return
visited.add(obj_id)
# Handle modules
if inspect.ismodule(obj):
import_stmt = _extract_import_statement(name, obj)
import_statements.append(import_stmt)
return
# Handle functions and classes
if (
inspect.isfunction(obj)
or inspect.isclass(obj)
or inspect.isbuiltin(obj)
or inspect.ismethod(obj)
):
obj_module = getattr(obj, "__module__", None)
if obj_module:
base_module = obj_module.split(".")[0]
module_category = _categorize_module(base_module)
# If from stdlib/third-party, just add import
if module_category in ("stdlib", "third_party"):
obj_name = getattr(obj, "__name__", name)
# Check if object is accessible by 'name' (in globals or closures)
is_accessible = False
if name in func.__globals__ and func.__globals__[name] is obj:
is_accessible = True
elif func.__closure__ and hasattr(func, "__code__"):
freevars = func.__code__.co_freevars
for i, var_name in enumerate(freevars):
if var_name == name and i < len(func.__closure__):
try:
if func.__closure__[i].cell_contents is obj:
is_accessible = True
break
except (ValueError, AttributeError):
pass
if is_accessible and name == obj_name:
# Direct import: from requests import get, from math import sqrt
import_statements.append(f"from {base_module} import {name}")
else:
# Module import: import requests
import_statements.append(f"import {base_module}")
return
try:
obj_tree = ast.parse(dedent(getsource(obj)))
obj_visitor = _DependencyVisitor(obj.__name__)
obj_visitor.visit(obj_tree)
obj_external_refs = obj_visitor.name_references - obj_visitor.local_names
obj_external_refs = obj_external_refs - builtin_names
# Add internal imports from this object
import_statements.extend(obj_visitor.internal_import_statements)
# Recursively analyze its dependencies
obj_globals = getattr(obj, "__globals__", None)
obj_closure = getattr(obj, "__closure__", None)
obj_code = getattr(obj, "__code__", None)
if obj_globals:
for ref_name in obj_external_refs:
ref_obj = None
# Check globals first
if ref_name in obj_globals:
ref_obj = obj_globals[ref_name]
# Check closure variables using co_freevars
elif obj_closure and obj_code:
freevars = obj_code.co_freevars
for i, var_name in enumerate(freevars):
if var_name == ref_name and i < len(obj_closure):
try:
ref_obj = obj_closure[i].cell_contents
break
except (ValueError, AttributeError):
pass
if ref_obj is not None:
analyze_object(ref_obj, ref_name, depth + 1)
# Add this object to definitions
if not inspect.ismodule(obj):
ref_module = getattr(obj, "__module__", None)
if ref_module:
ref_base_module = ref_module.split(".")[0]
ref_category = _categorize_module(ref_base_module)
if ref_category not in ("stdlib", "third_party"):
definitions.append((name, obj))
else:
definitions.append((name, obj))
except (OSError, TypeError):
pass
return
if isinstance(obj, (int, float, str, bool, list, dict, tuple, set, frozenset, type(None))):
definitions.append((name, obj))
# Analyze all external references
for name in external_refs:
obj = None
# First check globals
if name in func.__globals__:
obj = func.__globals__[name]
# Then check closure variables (sibling functions in enclosing scope)
elif func.__closure__ and func.__code__.co_freevars:
# Match closure variable names with cell contents
freevars = func.__code__.co_freevars
for i, var_name in enumerate(freevars):
if var_name == name and i < len(func.__closure__):
try:
obj = func.__closure__[i].cell_contents
break
except (ValueError, AttributeError):
# Cell is empty or doesn't have contents
pass
if obj is not None:
analyze_object(obj, name)
# Remove duplicate import statements
unique_imports = []
seen = set()
for stmt in import_statements:
if stmt not in seen:
seen.add(stmt)
unique_imports.append(stmt)
# Remove duplicate definitions
unique_definitions = []
seen_names = set()
for name, obj in definitions:
if name not in seen_names:
seen_names.add(name)
unique_definitions.append((name, obj))
return {
"import_statements": unique_imports,
"definitions": unique_definitions,
}
def generate_source_code(func: FunctionType) -> str:
"""
Generate complete source code for a function with all dependencies.
Args:
func: The function to generate source code for
Returns:
Complete Python source code as a string
"""
if func in _function_dependency_map:
info = _function_dependency_map[func]
else:
info = _traverse_and_collect_dependencies(func)
_function_dependency_map[func] = info
# Build source code
parts = []
# 1. Add imports
if info["import_statements"]:
parts.append("\n".join(info["import_statements"]))
# 2. Add definitions
for name, obj in info["definitions"]:
try:
if inspect.isfunction(obj):
source = dedent(getsource(obj))
tree = ast.parse(source)
if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)):
tree.body[0].decorator_list = []
source = ast.unparse(tree)
parts.append(source)
elif inspect.isclass(obj):
source = dedent(getsource(obj))
tree = ast.parse(source)
if tree.body and isinstance(tree.body[0], ast.ClassDef):
tree.body[0].decorator_list = []
source = ast.unparse(tree)
parts.append(source)
else:
parts.append(f"{name} = {repr(obj)}")
except (OSError, TypeError):
pass
# 3. Add main function (without decorators)
func_source = dedent(getsource(func))
tree = ast.parse(func_source)
if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)):
tree.body[0].decorator_list = []
func_source = ast.unparse(tree)
parts.append(func_source)
return "\n\n".join(parts)

View File

@@ -12,6 +12,7 @@ class InterfaceFactory:
def create_interface_for_os(
os: Literal["macos", "linux", "windows"],
ip_address: str,
api_port: Optional[int] = None,
api_key: Optional[str] = None,
vm_name: Optional[str] = None,
) -> BaseComputerInterface:
@@ -20,6 +21,7 @@ class InterfaceFactory:
Args:
os: Operating system type ('macos', 'linux', or 'windows')
ip_address: IP address of the computer to control
api_port: Optional API port of the computer to control
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication
@@ -35,10 +37,16 @@ class InterfaceFactory:
from .windows import WindowsComputerInterface
if os == "macos":
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
return MacOSComputerInterface(
ip_address, api_key=api_key, vm_name=vm_name, api_port=api_port
)
elif os == "linux":
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
return LinuxComputerInterface(
ip_address, api_key=api_key, vm_name=vm_name, api_port=api_port
)
elif os == "windows":
return WindowsComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
return WindowsComputerInterface(
ip_address, api_key=api_key, vm_name=vm_name, api_port=api_port
)
else:
raise ValueError(f"Unsupported OS type: {os}")

View File

@@ -30,6 +30,7 @@ class GenericComputerInterface(BaseComputerInterface):
api_key: Optional[str] = None,
vm_name: Optional[str] = None,
logger_name: str = "computer.interface.generic",
api_port: Optional[int] = None,
):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
@@ -47,6 +48,9 @@ class GenericComputerInterface(BaseComputerInterface):
# Set logger name for the interface
self.logger = Logger(logger_name, LogLevel.NORMAL)
# Store custom ports
self._api_port = api_port
# Optional default delay time between commands (in seconds)
self.delay = 0.0
@@ -70,7 +74,12 @@ class GenericComputerInterface(BaseComputerInterface):
WebSocket URI for the Computer API Server
"""
protocol = "wss" if self.api_key else "ws"
port = "8443" if self.api_key else "8000"
# Use custom API port if provided, otherwise use defaults based on API key
port = (
str(self._api_port)
if self._api_port is not None
else ("8443" if self.api_key else "8000")
)
return f"{protocol}://{self.ip_address}:{port}/ws"
@property
@@ -81,7 +90,12 @@ class GenericComputerInterface(BaseComputerInterface):
REST URI for the Computer API Server
"""
protocol = "https" if self.api_key else "http"
port = "8443" if self.api_key else "8000"
# Use custom API port if provided, otherwise use defaults based on API key
port = (
str(self._api_port)
if self._api_port is not None
else ("8443" if self.api_key else "8000")
)
return f"{protocol}://{self.ip_address}:{port}/cmd"
# Mouse actions

View File

@@ -13,7 +13,8 @@ class LinuxComputerInterface(GenericComputerInterface):
password: str = "lume",
api_key: Optional[str] = None,
vm_name: Optional[str] = None,
api_port: Optional[int] = None,
):
super().__init__(
ip_address, username, password, api_key, vm_name, "computer.interface.linux"
ip_address, username, password, api_key, vm_name, "computer.interface.linux", api_port
)

View File

@@ -13,9 +13,10 @@ class MacOSComputerInterface(GenericComputerInterface):
password: str = "lume",
api_key: Optional[str] = None,
vm_name: Optional[str] = None,
api_port: Optional[int] = None,
):
super().__init__(
ip_address, username, password, api_key, vm_name, "computer.interface.macos"
ip_address, username, password, api_key, vm_name, "computer.interface.macos", api_port
)
async def diorama_cmd(self, action: str, arguments: Optional[dict] = None) -> dict:

View File

@@ -13,7 +13,8 @@ class WindowsComputerInterface(GenericComputerInterface):
password: str = "lume",
api_key: Optional[str] = None,
vm_name: Optional[str] = None,
api_port: Optional[int] = None,
):
super().__init__(
ip_address, username, password, api_key, vm_name, "computer.interface.windows"
ip_address, username, password, api_key, vm_name, "computer.interface.windows", api_port
)

View File

@@ -37,7 +37,6 @@ class DockerProvider(BaseVMProvider):
def __init__(
self,
port: Optional[int] = 8000,
host: str = "localhost",
storage: Optional[str] = None,
shared_path: Optional[str] = None,
@@ -45,11 +44,11 @@ class DockerProvider(BaseVMProvider):
verbose: bool = False,
ephemeral: bool = False,
vnc_port: Optional[int] = 6901,
api_port: Optional[int] = None,
):
"""Initialize the Docker VM Provider.
Args:
port: Currently unused (VM provider port)
host: Hostname for the API server (default: localhost)
storage: Path for persistent VM storage
shared_path: Path for shared folder between host and container
@@ -60,9 +59,10 @@ class DockerProvider(BaseVMProvider):
verbose: Enable verbose logging
ephemeral: Use ephemeral (temporary) storage
vnc_port: Port for VNC interface (default: 6901)
api_port: Port for API server (default: 8000)
"""
self.host = host
self.api_port = 8000
self.api_port = api_port if api_port is not None else 8000
self.vnc_port = vnc_port
self.ephemeral = ephemeral
@@ -296,6 +296,7 @@ class DockerProvider(BaseVMProvider):
if vnc_port:
cmd.extend(["-p", f"{vnc_port}:6901"]) # VNC port
if api_port:
# Map the API port to container port 8000 (computer-server default)
cmd.extend(["-p", f"{api_port}:8000"]) # computer-server API port
# Add volume mounts if storage is specified

View File

@@ -14,7 +14,7 @@ class VMProviderFactory:
@staticmethod
def create_provider(
provider_type: Union[str, VMProviderType],
port: int = 7777,
provider_port: int = 7777,
host: str = "localhost",
bin_path: Optional[str] = None,
storage: Optional[str] = None,
@@ -23,13 +23,14 @@ class VMProviderFactory:
verbose: bool = False,
ephemeral: bool = False,
noVNC_port: Optional[int] = None,
api_port: Optional[int] = None,
**kwargs,
) -> BaseVMProvider:
"""Create a VM provider of the specified type.
Args:
provider_type: Type of VM provider to create
port: Port for the API server
provider_port: Port for the provider's API server
host: Hostname for the API server
bin_path: Path to provider binary if needed
storage: Path for persistent VM storage
@@ -37,7 +38,8 @@ class VMProviderFactory:
image: VM image to use (for Lumier provider)
verbose: Enable verbose logging
ephemeral: Use ephemeral (temporary) storage
noVNC_port: Specific port for noVNC interface (for Lumier provider)
noVNC_port: Specific port for noVNC interface (for Lumier and Docker provider)
api_port: Specific port for Computer API server (for Docker provider)
Returns:
An instance of the requested VM provider
@@ -63,7 +65,11 @@ class VMProviderFactory:
"Please install it with 'pip install cua-computer[lume]'"
)
return LumeProvider(
port=port, host=host, storage=storage, verbose=verbose, ephemeral=ephemeral
provider_port=provider_port,
host=host,
storage=storage,
verbose=verbose,
ephemeral=ephemeral,
)
except ImportError as e:
logger.error(f"Failed to import LumeProvider: {e}")
@@ -81,7 +87,7 @@ class VMProviderFactory:
"Please install Docker for Apple Silicon and Lume CLI before using this provider."
)
return LumierProvider(
port=port,
provider_port=provider_port,
host=host,
storage=storage,
shared_path=shared_path,
@@ -121,7 +127,6 @@ class VMProviderFactory:
"Please install it with 'pip install -U git+https://github.com/karkason/pywinsandbox.git'"
)
return WinSandboxProvider(
port=port,
host=host,
storage=storage,
verbose=verbose,
@@ -144,7 +149,6 @@ class VMProviderFactory:
"Please install Docker and ensure it is running."
)
return DockerProvider(
port=port,
host=host,
storage=storage,
shared_path=shared_path,
@@ -152,6 +156,7 @@ class VMProviderFactory:
verbose=verbose,
ephemeral=ephemeral,
vnc_port=noVNC_port,
api_port=api_port,
)
except ImportError as e:
logger.error(f"Failed to import DockerProvider: {e}")

View File

@@ -38,7 +38,7 @@ class LumeProvider(BaseVMProvider):
def __init__(
self,
port: int = 7777,
provider_port: int = 7777,
host: str = "localhost",
storage: Optional[str] = None,
verbose: bool = False,
@@ -47,7 +47,7 @@ class LumeProvider(BaseVMProvider):
"""Initialize the Lume provider.
Args:
port: Port for the Lume API server (default: 7777)
provider_port: Port for the Lume API server (default: 7777)
host: Host to use for API connections (default: localhost)
storage: Path to store VM data
verbose: Enable verbose logging
@@ -59,7 +59,7 @@ class LumeProvider(BaseVMProvider):
)
self.host = host
self.port = port # Default port for Lume API
self.port = provider_port # Default port for Lume API
self.storage = storage
self.verbose = verbose
self.ephemeral = ephemeral # If True, VMs will be deleted after stopping

View File

@@ -39,7 +39,7 @@ class LumierProvider(BaseVMProvider):
def __init__(
self,
port: Optional[int] = 7777,
provider_port: Optional[int] = 7777,
host: str = "localhost",
storage: Optional[str] = None, # Can be a path or 'ephemeral'
shared_path: Optional[str] = None,
@@ -51,7 +51,7 @@ class LumierProvider(BaseVMProvider):
"""Initialize the Lumier VM Provider.
Args:
port: Port for the API server (default: 7777)
provider_port: Port for the API server (default: 7777)
host: Hostname for the API server (default: localhost)
storage: Path for persistent VM storage
shared_path: Path for shared folder between host and VM
@@ -61,8 +61,8 @@ class LumierProvider(BaseVMProvider):
noVNC_port: Specific port for noVNC interface (default: 8006)
"""
self.host = host
# Always ensure api_port has a valid value (7777 is the default)
self.api_port = 7777 if port is None else port
# Always ensure lume_port has a valid value (7777 is the default)
self.lume_port = 7777 if provider_port is None else provider_port
self.vnc_port = noVNC_port # User-specified noVNC port, will be set in run_vm if provided
self.ephemeral = ephemeral
@@ -198,7 +198,7 @@ class LumierProvider(BaseVMProvider):
vm_info = lume_api_get(
vm_name=name,
host=self.host,
port=self.api_port,
port=self.lume_port,
storage=storage if storage is not None else self.storage,
debug=self.verbose,
verbose=self.verbose,
@@ -320,7 +320,7 @@ class LumierProvider(BaseVMProvider):
logger.debug(f"Using specified noVNC_port: {self.vnc_port}")
# Set API URL using the API port
self._api_url = f"http://{self.host}:{self.api_port}"
self._api_url = f"http://{self.host}:{self.lume_port}"
# Parse memory setting
memory_mb = self._parse_memory(run_opts.get("memory", "8GB"))
@@ -671,7 +671,7 @@ class LumierProvider(BaseVMProvider):
# Container is running, check if API is responsive
try:
# First check the health endpoint
api_url = f"http://{self.host}:{self.api_port}/health"
api_url = f"http://{self.host}:{self.lume_port}/health"
logger.info(f"Checking API health at: {api_url}")
# Use longer timeout for API health check since it may still be initializing
@@ -685,7 +685,7 @@ class LumierProvider(BaseVMProvider):
else:
# API health check failed, now let's check if the VM status endpoint is responsive
# This covers cases where the health endpoint isn't implemented but the VM API is working
vm_api_url = f"http://{self.host}:{self.api_port}/lume/vms/{container_name}"
vm_api_url = f"http://{self.host}:{self.lume_port}/lume/vms/{container_name}"
if self.storage:
import urllib.parse
@@ -1026,7 +1026,7 @@ class LumierProvider(BaseVMProvider):
# Initialize the API URL with the default value if not already set
# This ensures get_vm can work before run_vm is called
if not hasattr(self, "_api_url") or not self._api_url:
self._api_url = f"http://{self.host}:{self.api_port}"
self._api_url = f"http://{self.host}:{self.lume_port}"
logger.info(f"Initialized default Lumier API URL: {self._api_url}")
return self

View File

@@ -29,7 +29,6 @@ class WinSandboxProvider(BaseVMProvider):
def __init__(
self,
port: int = 7777,
host: str = "localhost",
storage: Optional[str] = None,
verbose: bool = False,
@@ -41,7 +40,6 @@ class WinSandboxProvider(BaseVMProvider):
"""Initialize the Windows Sandbox provider.
Args:
port: Port for the computer server (default: 7777)
host: Host to use for connections (default: localhost)
storage: Storage path (ignored - Windows Sandbox is always ephemeral)
verbose: Enable verbose logging
@@ -56,7 +54,6 @@ class WinSandboxProvider(BaseVMProvider):
)
self.host = host
self.port = port
self.verbose = verbose
self.memory_mb = memory_mb
self.networking = networking

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-computer"
version = "0.4.17"
version = "0.4.18"
description = "Computer-Use Interface (CUI) framework powering Cua"
readme = "README.md"
authors = [

View File

@@ -0,0 +1,478 @@
from computer.helpers import generate_source_code
class TestSimpleCases:
"""Test simple dependency cases"""
def test_simple_function_no_dependencies(self):
"""Test simple function with no dependencies"""
def simple_func():
return 42
code = generate_source_code(simple_func)
assert "def simple_func():" in code
assert "return 42" in code
def test_function_with_parameters(self):
"""Test function with parameters"""
def add(a, b):
return a + b
code = generate_source_code(add)
assert "def add(a, b):" in code
assert "return a + b" in code
def test_function_with_docstring(self):
"""Test function with docstring"""
def documented():
"""This is a docstring"""
return True
code = generate_source_code(documented)
assert "def documented():" in code
assert "This is a docstring" in code
class TestImports:
"""Test import handling"""
def test_stdlib_import_inside_function(self):
"""Test function with stdlib import inside"""
def with_stdlib():
import math
return math.sqrt(16)
code = generate_source_code(with_stdlib)
assert "import math" in code
assert "def with_stdlib():" in code
assert "math.sqrt(16)" in code
def test_multiple_stdlib_imports(self):
"""Test function with multiple stdlib imports"""
def with_multiple():
import json
import math
return json.dumps({"value": math.pi})
code = generate_source_code(with_multiple)
assert "import json" in code
assert "import math" in code
def test_from_import_stdlib(self):
"""Test from X import Y style"""
def with_from_import():
from math import pi, sqrt
return sqrt(16) + pi
code = generate_source_code(with_from_import)
assert "from math import" in code
assert "sqrt, pi" in code or ("sqrt" in code and "pi" in code)
def test_third_party_global_import(self):
"""Test globally imported third-party module"""
from requests import get
def with_requests():
return get
code = generate_source_code(with_requests)
assert "from requests import get" in code
assert "def with_requests():" in code
class TestHelperFunctions:
"""Test helper function dependencies"""
def test_single_helper_function(self):
"""Test function with one helper"""
def helper(x):
return x * 2
def main():
return helper(21)
code = generate_source_code(main)
assert "def helper(x):" in code
assert "return x * 2" in code
assert "def main():" in code
assert "return helper(21)" in code
def test_nested_helper_functions(self):
"""Test function with nested helpers"""
def helper_a(x):
return x + 1
def helper_b(x):
return helper_a(x) * 2
def main():
return helper_b(10)
code = generate_source_code(main)
assert "def helper_a(x):" in code
assert "def helper_b(x):" in code
assert "def main():" in code
def test_helper_dependency_ordering(self):
"""Test that helpers are ordered correctly (dependencies first)"""
def level_3(x):
return x + 1
def level_2(x):
return level_3(x) * 2
def level_1(x):
return level_2(x) + 5
code = generate_source_code(level_1)
# Check ordering
pos_3 = code.find("def level_3")
pos_2 = code.find("def level_2")
pos_1 = code.find("def level_1")
assert pos_3 < pos_2 < pos_1, "Functions not in correct dependency order"
def test_helper_with_import(self):
"""Test helper function that uses imports"""
import math
def helper(x):
return math.sqrt(x)
def main():
return helper(16)
code = generate_source_code(main)
assert "import math" in code
assert "def helper(x):" in code
assert "def main():" in code
class TestGlobalConstants:
"""Test global constant handling"""
def test_simple_constant(self):
"""Test function using simple constant"""
MY_CONSTANT = 42
def with_constant():
return MY_CONSTANT * 2
code = generate_source_code(with_constant)
assert "MY_CONSTANT = 42" in code
assert "def with_constant():" in code
def test_string_constant(self):
"""Test function using string constant"""
API_URL = "https://api.example.com"
def with_string():
return API_URL
code = generate_source_code(with_string)
assert "API_URL" in code
assert "https://api.example.com" in code
def test_multiple_constants(self):
"""Test function using multiple constants"""
BASE_URL = "https://example.com"
TIMEOUT = 30
def with_multiple():
return f"{BASE_URL}?timeout={TIMEOUT}"
code = generate_source_code(with_multiple)
assert "BASE_URL" in code
assert "TIMEOUT" in code
class TestClassDefinitions:
"""Test class definition handling"""
def test_simple_class(self):
"""Test function using simple class"""
class Calculator:
def __init__(self, value):
self.value = value
def add(self, x):
return self.value + x
def with_class():
calc = Calculator(10)
return calc.add(5)
code = generate_source_code(with_class)
assert "class Calculator:" in code
assert "__init__" in code
assert "def add" in code
assert "def with_class():" in code
def test_class_with_methods(self):
"""Test class with multiple methods"""
class Math:
def add(self, a, b):
return a + b
def multiply(self, a, b):
return a * b
def use_math():
m = Math()
return m.add(2, 3) + m.multiply(4, 5)
code = generate_source_code(use_math)
assert "class Math:" in code
assert "def add" in code
assert "def multiply" in code
class TestDecoratorRemoval:
"""Test decorator removal"""
def test_simple_decorator_removed(self):
"""Test that decorators are removed"""
def my_decorator(f):
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
return wrapper
@my_decorator
def decorated():
return 42
code = generate_source_code(decorated)
assert "@my_decorator" not in code
assert "def decorated():" in code
assert "return 42" in code
def test_multiple_decorators_removed(self):
"""Test that multiple decorators are removed"""
def decorator1(f):
return f
def decorator2(f):
return f
@decorator1
@decorator2
def multi_decorated():
return True
code = generate_source_code(multi_decorated)
assert "@decorator1" not in code
assert "@decorator2" not in code
assert "def multi_decorated():" in code
class TestComplexScenarios:
"""Test complex real-world scenarios"""
def test_api_client_pattern(self):
"""Test typical API client pattern"""
import json
BASE_URL = "https://api.example.com"
def make_request(endpoint):
return f"{BASE_URL}/{endpoint}"
def parse_response(response):
return json.loads(response)
def get_user(user_id):
url = make_request(f"users/{user_id}")
response = f'{{"id": {user_id}}}'
return parse_response(response)
code = generate_source_code(get_user)
assert "import json" in code
assert "BASE_URL" in code
assert "def make_request" in code
assert "def parse_response" in code
assert "def get_user" in code
def test_data_processing_pipeline(self):
"""Test data processing pipeline"""
def validate_data(data):
return [x for x in data if x > 0]
def transform_data(data):
return [x * 2 for x in data]
def process_pipeline(raw_data):
valid = validate_data(raw_data)
return transform_data(valid)
code = generate_source_code(process_pipeline)
assert "def validate_data" in code
assert "def transform_data" in code
assert "def process_pipeline" in code
# Check ordering
validate_pos = code.find("def validate_data")
transform_pos = code.find("def transform_data")
pipeline_pos = code.find("def process_pipeline")
assert validate_pos < pipeline_pos
assert transform_pos < pipeline_pos
class TestEdgeCases:
"""Test edge cases"""
def test_function_with_default_args(self):
"""Test function with default arguments"""
def with_defaults(a, b=10, c=20):
return a + b + c
code = generate_source_code(with_defaults)
assert "def with_defaults(a, b=10, c=20):" in code
def test_function_with_kwargs(self):
"""Test function with *args and **kwargs"""
def with_varargs(*args, **kwargs):
return sum(args)
code = generate_source_code(with_varargs)
assert "def with_varargs(*args, **kwargs):" in code
def test_async_function(self):
"""Test async function"""
async def async_func():
return 42
code = generate_source_code(async_func)
assert "async def async_func():" in code
assert "return 42" in code
def test_lambda_works_in_files(self):
"""Test that lambda functions work when defined in files"""
lambda_func = lambda x: x * 2
# Lambda functions defined in files CAN have their source extracted
code = generate_source_code(lambda_func)
# Should include the lambda
assert "lambda" in code.lower() or "lambda_func" in code
class TestCaching:
"""Test caching mechanism"""
def test_cache_returns_same_result(self):
"""Test that cache returns identical results"""
def cached_func():
return 42
code1 = generate_source_code(cached_func)
code2 = generate_source_code(cached_func)
assert code1 == code2
def test_different_functions_different_cache(self):
"""Test that different functions have different cache entries"""
def func1():
return 1
def func2():
return 2
code1 = generate_source_code(func1)
code2 = generate_source_code(func2)
assert code1 != code2
assert "return 1" in code1
assert "return 2" in code2
class TestBuiltinHandling:
"""Test that builtins are handled correctly"""
def test_builtin_not_included(self):
"""Test that builtin functions are not included as dependencies"""
def uses_builtins():
data = list(range(10))
return len(data)
code = generate_source_code(uses_builtins)
# Should not include definitions for list, range, len
assert "def list" not in code
assert "def range" not in code
assert "def len" not in code
assert "def uses_builtins():" in code
class TestImportStylePreservation:
"""Test that import styles are preserved"""
def test_from_import_preserved(self):
"""Test from X import Y is preserved"""
from math import sqrt
def use_sqrt():
return sqrt(16)
code = generate_source_code(use_sqrt)
assert "from math import sqrt" in code
def test_import_as_preserved(self):
"""Test import X as Y is preserved"""
import json as j
def use_json():
return j.dumps({"key": "value"})
code = generate_source_code(use_json)
# The import style should be preserved
assert "import json" in code

View File

@@ -129,12 +129,12 @@ See [desktop-extension/README.md](desktop-extension/README.md) for more details.
## Documentation
- Installation: https://cua.ai/docs/libraries/mcp-server/installation
- Configuration: https://cua.ai/docs/libraries/mcp-server/configuration
- Usage: https://cua.ai/docs/libraries/mcp-server/usage
- Tools: https://cua.ai/docs/libraries/mcp-server/tools
- Client Integrations: https://cua.ai/docs/libraries/mcp-server/client-integrations
- LLM Integrations: https://cua.ai/docs/libraries/mcp-server/llm-integrations
- Installation: https://cua.ai/docs/agent-sdk/mcp-server/installation
- Configuration: https://cua.ai/docs/agent-sdk/mcp-server/configuration
- Usage: https://cua.ai/docs/agent-sdk/mcp-server/usage
- Tools: https://cua.ai/docs/agent-sdk/mcp-server/tools
- Client Integrations: https://cua.ai/docs/agent-sdk/mcp-server/client-integrations
- LLM Integrations: https://cua.ai/docs/agent-sdk/mcp-server/llm-integrations
## Troubleshooting