Merge pull request #660 from trycua/fix/sandboxed-dependency-resolution

[Computer] fix @sandboxed decorator to handle dependencies automatically
This commit is contained in:
ddupont
2025-12-16 12:32:07 -08:00
committed by GitHub
3 changed files with 897 additions and 25 deletions
+4 -24
View File
@@ -1121,19 +1121,11 @@ class Computer:
The result of the function execution, or raises any exception that occurred
"""
import base64
import inspect
import json
import textwrap
try:
# Get function source code using inspect.getsource
source = inspect.getsource(python_func)
# Remove common leading whitespace (dedent)
func_source = textwrap.dedent(source).strip()
# Remove decorators
while func_source.lstrip().startswith("@"):
func_source = func_source.split("\n", 1)[1].strip()
func_source = helpers.generate_source_code(python_func)
# Get function name for execution
func_name = python_func.__name__
@@ -1259,16 +1251,12 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
Uses a short launcher Python that spawns a detached child and exits immediately.
"""
import base64
import inspect
import json
import textwrap
import time as _time
try:
source = inspect.getsource(python_func)
func_source = textwrap.dedent(source).strip()
while func_source.lstrip().startswith("@"):
func_source = func_source.split("\n", 1)[1].strip()
func_source = helpers.generate_source_code(python_func)
func_name = python_func.__name__
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)
@@ -1366,15 +1354,11 @@ print(p.pid)
remote traceback context appended.
"""
import base64
import inspect
import json
import textwrap
try:
source = inspect.getsource(python_func)
func_source = textwrap.dedent(source).strip()
while func_source.lstrip().startswith("@"):
func_source = func_source.split("\n", 1)[1].strip()
func_source = helpers.generate_source_code(python_func)
func_name = python_func.__name__
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)
@@ -1482,16 +1466,12 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
Uses a short launcher Python that spawns a detached child and exits immediately.
"""
import base64
import inspect
import json
import textwrap
import time as _time
try:
source = inspect.getsource(python_func)
func_source = textwrap.dedent(source).strip()
while func_source.lstrip().startswith("@"):
func_source = func_source.split("\n", 1)[1].strip()
func_source = helpers.generate_source_code(python_func)
func_name = python_func.__name__
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)
+415 -1
View File
@@ -2,10 +2,19 @@
Helper functions and decorators for the Computer module.
"""
import ast
import asyncio
import builtins
import importlib.util
import inspect
import logging
import os
import sys
from functools import wraps
from typing import Any, Awaitable, Callable, Optional, TypeVar
from inspect import getsource
from textwrap import dedent
from types import FunctionType, ModuleType
from typing import Any, Awaitable, Callable, Dict, List, Set, TypedDict, TypeVar
try:
# Python 3.12+ has ParamSpec in typing
@@ -17,9 +26,18 @@ except ImportError: # pragma: no cover
P = ParamSpec("P")
R = TypeVar("R")
class DependencyInfo(TypedDict):
    """Result of analysing one function's dependencies.

    Built by ``_traverse_and_collect_dependencies`` and consumed by
    ``generate_source_code``.
    """

    # Import statements, as source text, to place before the generated code.
    import_statements: List[str]
    # (name, object) pairs for helper functions, classes and constants.
    definitions: List[tuple[str, Any]]
# Global reference to the default computer instance
_default_computer = None
# Global cache for function dependency analysis
_function_dependency_map: Dict[FunctionType, DependencyInfo] = {}
logger = logging.getLogger(__name__)
@@ -42,6 +60,9 @@ def sandboxed(
"""
Decorator that wraps a function to be executed remotely via computer.venv_exec
The function is automatically analyzed for dependencies (imports, helper functions,
constants, etc.) and reconstructed with all necessary code in the remote sandbox.
Args:
venv_name: Name of the virtual environment to execute in
computer: The computer instance to use, or "default" to use the globally set default
@@ -74,3 +95,396 @@ def sandboxed(
return wrapper
return decorator
def _extract_import_statement(name: str, module: ModuleType) -> str:
    """Build an import statement that binds ``name`` to ``module``.

    Args:
        name: The identifier the calling code uses for the module.
        module: The module object that identifier refers to.

    Returns:
        ``import X`` when the identifier is the module's own name, otherwise
        ``import X as name``.
    """
    module_name = module.__name__
    if name == module_name:
        return f"import {module_name}"
    # A plain "import a.b" binds "a" to the top-level package, not to a.b, so
    # any name that differs from the full module name needs an explicit alias
    # (this includes the case where name equals the first dotted component).
    return f"import {module_name} as {name}"
def _is_third_party_module(module_name: str) -> bool:
    """Report whether ``module_name`` resolves to an installed third-party package.

    A module counts as third-party when it is not listed in
    ``sys.stdlib_module_names`` (where that attribute exists) and its import
    spec's origin lies under a ``site-packages`` or ``dist-packages`` directory.
    Modules that cannot be located are reported as not third-party.
    """
    if module_name in getattr(sys, "stdlib_module_names", ()):
        return False

    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ModuleNotFoundError, ValueError):
        return False

    origin = spec.origin if spec is not None else None
    if not origin:
        return False
    return "site-packages" in origin or "dist-packages" in origin
def _is_project_import(module_name: str) -> bool:
    """Report whether ``module_name`` is a project-level import.

    Relative-import placeholders (``__relative_import_level_N__``) are always
    project imports. Otherwise the module must already be loaded, have a
    ``__file__`` outside ``site-packages``/``dist-packages``, and that file must
    live inside the current working directory.

    Args:
        module_name: Module name, or a relative-import placeholder.

    Returns:
        True for project modules, False otherwise (including modules that are
        not loaded or have no ``__file__``).
    """
    if module_name.startswith("__relative_import_level_"):
        return True

    module_file = getattr(sys.modules.get(module_name), "__file__", None)
    if not module_file:
        return False
    if "site-packages" in module_file or "dist-packages" in module_file:
        return False

    cwd = os.getcwd()
    try:
        # Compare whole path components: a plain string prefix test would
        # treat "/work/app_other/x.py" as being inside "/work/app".
        return os.path.commonpath([cwd, module_file]) == os.path.normpath(cwd)
    except ValueError:
        # Relative __file__ or a different drive: not under the working dir.
        return False
def _categorize_module(module_name: str) -> str:
"""Categorize a module as stdlib, third-party, or project."""
if module_name.startswith("__relative_import_level_"):
return "project"
elif module_name in (
set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
):
return "stdlib"
elif _is_third_party_module(module_name):
return "third_party"
elif _is_project_import(module_name):
return "project"
else:
return "unknown"
class _DependencyVisitor(ast.NodeVisitor):
    """AST visitor that extracts imports and name usage from one function.

    Only nodes belonging to the function called ``function_name`` are recorded.
    After ``visit(tree)`` the caller can compute the function's external
    references as ``name_references - local_names``.

    Attributes:
        internal_imports: Top-level module names imported inside the function
            (relative imports appear as ``__relative_import_level_N__``).
        internal_import_statements: Those imports as source text.
        name_references: Every name read inside the function.
        local_names: Names bound inside the function (parameters, assignments,
            nested definitions, imports, loop/with/except targets).
    """

    def __init__(self, function_name: str) -> None:
        self.function_name = function_name
        self.internal_imports: Set[str] = set()
        self.internal_import_statements: List[str] = []
        self.name_references: Set[str] = set()
        self.local_names: Set[str] = set()
        self.inside_function = False

    def _visit_default_values(self, args: ast.arguments) -> None:
        """Visit default-value expressions so the names they read are recorded."""
        for default in args.defaults:
            self.visit(default)
        for default in args.kw_defaults:
            # kw_defaults holds None for keyword-only parameters without a default.
            if default is not None:
                self.visit(default)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        if node.name == self.function_name and not self.inside_function:
            self.inside_function = True
            for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs:
                self.local_names.add(arg.arg)
            if node.args.vararg:
                self.local_names.add(node.args.vararg.arg)
            if node.args.kwarg:
                self.local_names.add(node.args.kwarg.arg)
            # Defaults are evaluated when the "def" statement runs, so a global
            # used only as a default is still a dependency of the function.
            self._visit_default_values(node.args)
            for child in node.body:
                self.visit(child)
            self.inside_function = False
        else:
            if self.inside_function:
                self.local_names.add(node.name)
                # A nested def's decorators and defaults are evaluated while
                # the enclosing function runs, so they count as its references.
                for decorator in node.decorator_list:
                    self.visit(decorator)
                self._visit_default_values(node.args)
            for child in node.body:
                self.visit(child)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self.visit_FunctionDef(node)  # type: ignore

    def visit_Import(self, node: ast.Import) -> None:
        if self.inside_function:
            for alias in node.names:
                module_name = alias.name.split(".")[0]
                self.internal_imports.add(module_name)
                imported_as = alias.asname if alias.asname else alias.name.split(".")[0]
                self.local_names.add(imported_as)
            self.internal_import_statements.append(ast.unparse(node))
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        if self.inside_function:
            if node.level == 0 and node.module:
                module_name = node.module.split(".")[0]
                self.internal_imports.add(module_name)
            elif node.level > 0:
                self.internal_imports.add(f"__relative_import_level_{node.level}__")
            for alias in node.names:
                imported_as = alias.asname if alias.asname else alias.name
                self.local_names.add(imported_as)
            self.internal_import_statements.append(ast.unparse(node))
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        if self.inside_function:
            if isinstance(node.ctx, ast.Load):
                self.name_references.add(node.id)
            elif isinstance(node.ctx, ast.Store):
                self.local_names.add(node.id)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        if self.inside_function:
            self.local_names.add(node.name)
        self.generic_visit(node)

    def visit_For(self, node: ast.For) -> None:
        if self.inside_function and isinstance(node.target, ast.Name):
            self.local_names.add(node.target.id)
        self.generic_visit(node)

    def visit_comprehension(self, node: ast.comprehension) -> None:
        if self.inside_function and isinstance(node.target, ast.Name):
            self.local_names.add(node.target.id)
        self.generic_visit(node)

    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
        if self.inside_function and node.name:
            self.local_names.add(node.name)
        self.generic_visit(node)

    def visit_With(self, node: ast.With) -> None:
        if self.inside_function:
            for item in node.items:
                if item.optional_vars and isinstance(item.optional_vars, ast.Name):
                    self.local_names.add(item.optional_vars.id)
        self.generic_visit(node)
# Marks "name could not be resolved". None cannot be used for this because a
# referenced constant may legitimately hold the value None.
_UNRESOLVED = object()


def _resolve_reference(owner: Any, name: str) -> Any:
    """Look up ``name`` in ``owner``'s globals, then in its closure cells.

    Returns ``_UNRESOLVED`` when the name is found in neither place or when the
    matching closure cell is still empty.
    """
    owner_globals = getattr(owner, "__globals__", None)
    if owner_globals is not None and name in owner_globals:
        return owner_globals[name]

    closure = getattr(owner, "__closure__", None)
    code = getattr(owner, "__code__", None)
    if closure and code is not None:
        # Closure cells line up positionally with co_freevars.
        for var_name, cell in zip(code.co_freevars, closure):
            if var_name == name:
                try:
                    return cell.cell_contents
                except ValueError:
                    return _UNRESOLVED
    return _UNRESOLVED


def _import_for_external_object(obj: Any, name: str, obj_module: str) -> str:
    """Build the import statement that binds ``name`` to a stdlib/third-party object.

    The top-level package is preferred when it really exposes the object (so
    ``requests.api.get`` becomes ``from requests import get``); otherwise the
    defining module is used (``from collections.abc import Mapping``). An alias
    is emitted when the calling code uses a different name than the object's own.
    If the object cannot be verified on either module the statement falls back
    to an import of the top-level package.
    """
    base_module = obj_module.split(".")[0]
    obj_name = getattr(obj, "__name__", name)

    for candidate in (base_module, obj_module):
        module = sys.modules.get(candidate)
        if module is not None and getattr(module, obj_name, None) is obj:
            if name == obj_name:
                return f"from {candidate} import {name}"
            return f"from {candidate} import {obj_name} as {name}"

    if name == obj_name:
        return f"from {base_module} import {name}"
    return f"import {base_module}"


def _traverse_and_collect_dependencies(func: FunctionType) -> DependencyInfo:
    """
    Traverse a function and collect its dependencies.

    The function's source is parsed, every name it reads but does not bind is
    resolved through its globals and closure, and each resolved object is
    classified:

    - modules become import statements;
    - functions/classes from stdlib or third-party modules become imports;
    - other functions/classes are analysed recursively (at most 20 levels deep)
      and added to the definitions after their own dependencies;
    - plain values (numbers, strings, bool, None, list/dict/tuple/set/frozenset)
      are added to the definitions.

    Objects of any other kind, and names that cannot be resolved, are ignored.

    Returns a dict with:
    - import_statements: List of import statements needed (duplicates removed)
    - definitions: List of (name, obj) tuples for helper functions/classes/constants
      (first occurrence of each name wins)

    Raises:
        OSError, TypeError: if the source of ``func`` itself cannot be retrieved.
    """
    source = dedent(getsource(func))
    tree = ast.parse(source)

    visitor = _DependencyVisitor(func.__name__)
    visitor.visit(tree)

    builtin_names = set(dir(builtins))
    external_refs = (visitor.name_references - visitor.local_names) - builtin_names

    # Imports written inside the function body are always needed.
    import_statements: List[str] = list(visitor.internal_import_statements)
    definitions: List[Any] = []
    # Keyed by (identity, name): identity alone would drop a second constant
    # that shares an interned value (A = 30; B = 30) or a second module alias.
    visited: Set[Any] = set()

    def analyze_object(obj: Any, name: str, depth: int = 0) -> None:
        if depth > 20:
            return

        key = (id(obj), name)
        if key in visited:
            return
        visited.add(key)

        # Handle modules
        if inspect.ismodule(obj):
            import_statements.append(_extract_import_statement(name, obj))
            return

        # Handle functions and classes
        if (
            inspect.isfunction(obj)
            or inspect.isclass(obj)
            or inspect.isbuiltin(obj)
            or inspect.ismethod(obj)
        ):
            obj_module = getattr(obj, "__module__", None)
            if obj_module:
                category = _categorize_module(obj_module.split(".")[0])
                # Installed code is imported remotely rather than copied.
                if category in ("stdlib", "third_party"):
                    import_statements.append(_import_for_external_object(obj, name, obj_module))
                    return

            try:
                obj_tree = ast.parse(dedent(getsource(obj)))
                obj_visitor = _DependencyVisitor(obj.__name__)
                obj_visitor.visit(obj_tree)

                obj_external_refs = obj_visitor.name_references - obj_visitor.local_names
                obj_external_refs = obj_external_refs - builtin_names

                # Add internal imports from this object
                import_statements.extend(obj_visitor.internal_import_statements)

                # Recursively analyze its dependencies (only objects that
                # carry globals, i.e. functions, can be resolved further).
                if getattr(obj, "__globals__", None):
                    for ref_name in sorted(obj_external_refs):
                        ref_obj = _resolve_reference(obj, ref_name)
                        if ref_obj is not _UNRESOLVED:
                            analyze_object(ref_obj, ref_name, depth + 1)

                # Appended after its dependencies so they are defined first.
                definitions.append((name, obj))
            except (OSError, TypeError):
                # Source not available (e.g. builtins); nothing to ship.
                pass
            return

        if isinstance(obj, (int, float, str, bool, list, dict, tuple, set, frozenset, type(None))):
            definitions.append((name, obj))

    # Analyze all external references; sorted so the generated code is
    # reproducible from run to run.
    for name in sorted(external_refs):
        obj = _resolve_reference(func, name)
        if obj is not _UNRESOLVED:
            analyze_object(obj, name)

    # Remove duplicate import statements, keeping first-seen order
    unique_imports = list(dict.fromkeys(import_statements))

    # Remove duplicate definitions (first definition of a name wins)
    unique_definitions = []
    seen_names = set()
    for name, obj in definitions:
        if name not in seen_names:
            seen_names.add(name)
            unique_definitions.append((name, obj))

    return {
        "import_statements": unique_imports,
        "definitions": unique_definitions,
    }
def _strip_leading_decorators(source: str, node_types: Any) -> str:
    """Return ``source`` with the decorators of its first definition removed.

    The source is only rewritten (via ``ast.unparse``) when its first statement
    is one of ``node_types``; otherwise it is returned unchanged.
    """
    tree = ast.parse(source)
    if tree.body and isinstance(tree.body[0], node_types):
        tree.body[0].decorator_list = []
        return ast.unparse(tree)
    return source


def generate_source_code(func: FunctionType) -> str:
    """
    Generate complete source code for a function with all dependencies.

    The output is, in order: the collected import statements, the helper
    definitions (functions and classes without decorators, constants as
    ``name = repr(value)``), and finally the function itself without
    decorators. Sections are separated by blank lines. Dependency analysis is
    cached per function object in ``_function_dependency_map``.

    Args:
        func: The function to generate source code for

    Returns:
        Complete Python source code as a string
    """
    info = _function_dependency_map.get(func)
    if info is None:
        info = _traverse_and_collect_dependencies(func)
        _function_dependency_map[func] = info

    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
    parts: List[str] = []

    # 1. Add imports
    if info["import_statements"]:
        parts.append("\n".join(info["import_statements"]))

    # 2. Add definitions
    for name, obj in info["definitions"]:
        try:
            if inspect.isfunction(obj) or inspect.isclass(obj):
                node_types = ast.ClassDef if inspect.isclass(obj) else function_nodes
                parts.append(_strip_leading_decorators(dedent(getsource(obj)), node_types))
                # The source defines the object under its own __name__; when
                # the caller refers to it by another name, bind that name too.
                defined_as = getattr(obj, "__name__", name)
                if defined_as != name and defined_as.isidentifier():
                    parts.append(f"{name} = {defined_as}")
            else:
                parts.append(f"{name} = {obj!r}")
        except (OSError, TypeError):
            # Source unavailable for this helper; skip it (best effort).
            pass

    # 3. Add main function (without decorators)
    parts.append(_strip_leading_decorators(dedent(getsource(func)), function_nodes))

    return "\n\n".join(parts)
+478
View File
@@ -0,0 +1,478 @@
from computer.helpers import generate_source_code
class TestSimpleCases:
    """Functions that need nothing beyond their own source."""

    def test_simple_function_no_dependencies(self):
        """A dependency-free function keeps its signature and body."""

        def simple_func():
            return 42

        generated = generate_source_code(simple_func)
        for fragment in ("def simple_func():", "return 42"):
            assert fragment in generated

    def test_function_with_parameters(self):
        """Positional parameters survive code generation."""

        def add(a, b):
            return a + b

        generated = generate_source_code(add)
        for fragment in ("def add(a, b):", "return a + b"):
            assert fragment in generated

    def test_function_with_docstring(self):
        """The function's docstring is carried into the generated code."""

        def documented():
            """This is a docstring"""
            return True

        generated = generate_source_code(documented)
        for fragment in ("def documented():", "This is a docstring"):
            assert fragment in generated
class TestImports:
    """Test import handling"""

    def test_stdlib_import_inside_function(self):
        """Test function with stdlib import inside"""

        def with_stdlib():
            import math

            return math.sqrt(16)

        code = generate_source_code(with_stdlib)
        assert "import math" in code
        assert "def with_stdlib():" in code
        assert "math.sqrt(16)" in code

    def test_multiple_stdlib_imports(self):
        """Test function with multiple stdlib imports"""

        def with_multiple():
            import json
            import math

            return json.dumps({"value": math.pi})

        code = generate_source_code(with_multiple)
        assert "import json" in code
        assert "import math" in code

    def test_from_import_stdlib(self):
        """Test from X import Y style"""

        def with_from_import():
            from math import pi, sqrt

            return sqrt(16) + pi

        code = generate_source_code(with_from_import)
        # Checking for "sqrt" and "pi" on their own always succeeds because
        # both appear in the function body; pin the import statement itself.
        assert "from math import pi, sqrt" in code

    def test_third_party_global_import(self):
        """Test globally imported third-party module"""
        from requests import get

        def with_requests():
            return get

        code = generate_source_code(with_requests)
        assert "from requests import get" in code
        assert "def with_requests():" in code
class TestHelperFunctions:
    """Test helper function dependencies"""

    def test_single_helper_function(self):
        """Test function with one helper"""

        def helper(x):
            return x * 2

        def main():
            return helper(21)

        code = generate_source_code(main)
        assert "def helper(x):" in code
        assert "return x * 2" in code
        assert "def main():" in code
        assert "return helper(21)" in code

    def test_nested_helper_functions(self):
        """Test function with nested helpers"""

        def helper_a(x):
            return x + 1

        def helper_b(x):
            return helper_a(x) * 2

        def main():
            return helper_b(10)

        code = generate_source_code(main)
        assert "def helper_a(x):" in code
        assert "def helper_b(x):" in code
        assert "def main():" in code

    def test_helper_dependency_ordering(self):
        """Test that helpers are ordered correctly (dependencies first)"""

        def level_3(x):
            return x + 1

        def level_2(x):
            return level_3(x) * 2

        def level_1(x):
            return level_2(x) + 5

        code = generate_source_code(level_1)

        # Check ordering
        pos_3 = code.find("def level_3")
        pos_2 = code.find("def level_2")
        pos_1 = code.find("def level_1")

        # str.find returns -1 for a missing helper, which would otherwise
        # satisfy the ordering comparison below.
        assert -1 not in (pos_3, pos_2, pos_1), "A function is missing from the generated code"
        assert pos_3 < pos_2 < pos_1, "Functions not in correct dependency order"

    def test_helper_with_import(self):
        """Test helper function that uses imports"""
        import math

        def helper(x):
            return math.sqrt(x)

        def main():
            return helper(16)

        code = generate_source_code(main)
        assert "import math" in code
        assert "def helper(x):" in code
        assert "def main():" in code
class TestGlobalConstants:
    """Constants referenced by the function are shipped with it."""

    def test_simple_constant(self):
        """An integer constant is emitted as an assignment."""
        MY_CONSTANT = 42

        def with_constant():
            return MY_CONSTANT * 2

        generated = generate_source_code(with_constant)
        for fragment in ("MY_CONSTANT = 42", "def with_constant():"):
            assert fragment in generated

    def test_string_constant(self):
        """A string constant's name and value both appear."""
        API_URL = "https://api.example.com"

        def with_string():
            return API_URL

        generated = generate_source_code(with_string)
        for fragment in ("API_URL", "https://api.example.com"):
            assert fragment in generated

    def test_multiple_constants(self):
        """Several constants used by one function all appear."""
        BASE_URL = "https://example.com"
        TIMEOUT = 30

        def with_multiple():
            return f"{BASE_URL}?timeout={TIMEOUT}"

        generated = generate_source_code(with_multiple)
        for fragment in ("BASE_URL", "TIMEOUT"):
            assert fragment in generated
class TestClassDefinitions:
    """Classes used by the function are shipped with it."""

    def test_simple_class(self):
        """A class with a constructor and one method is included."""

        class Calculator:
            def __init__(self, value):
                self.value = value

            def add(self, x):
                return self.value + x

        def with_class():
            calc = Calculator(10)
            return calc.add(5)

        generated = generate_source_code(with_class)
        expected = ("class Calculator:", "__init__", "def add", "def with_class():")
        for fragment in expected:
            assert fragment in generated

    def test_class_with_methods(self):
        """Every method of an included class is present."""

        class Math:
            def add(self, a, b):
                return a + b

            def multiply(self, a, b):
                return a * b

        def use_math():
            m = Math()
            return m.add(2, 3) + m.multiply(4, 5)

        generated = generate_source_code(use_math)
        for fragment in ("class Math:", "def add", "def multiply"):
            assert fragment in generated
class TestDecoratorRemoval:
    """Decorator lines must not appear in the generated code."""

    def test_simple_decorator_removed(self):
        """A single decorator line is absent from the output."""

        def my_decorator(f):
            def wrapper(*args, **kwargs):
                return f(*args, **kwargs)

            return wrapper

        @my_decorator
        def decorated():
            return 42

        generated = generate_source_code(decorated)
        assert "@my_decorator" not in generated
        for fragment in ("def decorated():", "return 42"):
            assert fragment in generated

    def test_multiple_decorators_removed(self):
        """Stacked decorator lines are all absent from the output."""

        def decorator1(f):
            return f

        def decorator2(f):
            return f

        @decorator1
        @decorator2
        def multi_decorated():
            return True

        generated = generate_source_code(multi_decorated)
        for marker in ("@decorator1", "@decorator2"):
            assert marker not in generated
        assert "def multi_decorated():" in generated
class TestComplexScenarios:
    """Scenarios combining imports, constants and several helpers."""

    def test_api_client_pattern(self):
        """An API-client style function pulls in its helpers, constant and import."""
        import json

        BASE_URL = "https://api.example.com"

        def make_request(endpoint):
            return f"{BASE_URL}/{endpoint}"

        def parse_response(response):
            return json.loads(response)

        def get_user(user_id):
            url = make_request(f"users/{user_id}")
            response = f'{{"id": {user_id}}}'
            return parse_response(response)

        generated = generate_source_code(get_user)
        expected = (
            "import json",
            "BASE_URL",
            "def make_request",
            "def parse_response",
            "def get_user",
        )
        for fragment in expected:
            assert fragment in generated

    def test_data_processing_pipeline(self):
        """Pipeline stages are included and defined before the pipeline itself."""

        def validate_data(data):
            return [x for x in data if x > 0]

        def transform_data(data):
            return [x * 2 for x in data]

        def process_pipeline(raw_data):
            valid = validate_data(raw_data)
            return transform_data(valid)

        generated = generate_source_code(process_pipeline)
        stages = ("def validate_data", "def transform_data", "def process_pipeline")
        for fragment in stages:
            assert fragment in generated

        # Each stage must be defined ahead of the function that calls it.
        pipeline_pos = generated.find("def process_pipeline")
        assert generated.find("def validate_data") < pipeline_pos
        assert generated.find("def transform_data") < pipeline_pos
class TestEdgeCases:
    """Less common function shapes."""

    def test_function_with_default_args(self):
        """Default argument values are kept in the signature."""

        def with_defaults(a, b=10, c=20):
            return a + b + c

        generated = generate_source_code(with_defaults)
        assert "def with_defaults(a, b=10, c=20):" in generated

    def test_function_with_kwargs(self):
        """``*args`` and ``**kwargs`` are kept in the signature."""

        def with_varargs(*args, **kwargs):
            return sum(args)

        generated = generate_source_code(with_varargs)
        assert "def with_varargs(*args, **kwargs):" in generated

    def test_async_function(self):
        """Coroutine functions stay ``async def``."""

        async def async_func():
            return 42

        generated = generate_source_code(async_func)
        for fragment in ("async def async_func():", "return 42"):
            assert fragment in generated

    def test_lambda_works_in_files(self):
        """A lambda defined in a file yields source mentioning the lambda."""
        lambda_func = lambda x: x * 2

        # The source of a lambda written in a file can be retrieved, so the
        # output should mention either the lambda or the name it is bound to.
        generated = generate_source_code(lambda_func)
        mentions_lambda = "lambda" in generated.lower()
        assert mentions_lambda or "lambda_func" in generated
class TestCaching:
    """Repeated and distinct calls to the generator."""

    def test_cache_returns_same_result(self):
        """Generating code twice for one function gives equal output."""

        def cached_func():
            return 42

        first = generate_source_code(cached_func)
        second = generate_source_code(cached_func)
        assert first == second

    def test_different_functions_different_cache(self):
        """Two different functions do not share generated output."""

        def func1():
            return 1

        def func2():
            return 2

        first = generate_source_code(func1)
        second = generate_source_code(func2)

        assert first != second
        assert "return 1" in first
        assert "return 2" in second
class TestBuiltinHandling:
    """Builtins must not be treated as dependencies."""

    def test_builtin_not_included(self):
        """No definitions are generated for builtin callables."""

        def uses_builtins():
            data = list(range(10))
            return len(data)

        generated = generate_source_code(uses_builtins)
        # list, range and len are builtins and need no definition of their own.
        for unwanted in ("def list", "def range", "def len"):
            assert unwanted not in generated
        assert "def uses_builtins():" in generated
class TestImportStylePreservation:
    """Test that import styles are preserved"""

    def test_from_import_preserved(self):
        """Test from X import Y is preserved"""
        from math import sqrt

        def use_sqrt():
            return sqrt(16)

        code = generate_source_code(use_sqrt)
        assert "from math import sqrt" in code

    def test_import_as_preserved(self):
        """Test import X as Y is preserved"""
        import json as j

        def use_json():
            return j.dumps({"key": "value"})

        code = generate_source_code(use_json)
        # The function body calls j.dumps, so the alias must be kept;
        # a bare "import json" would leave "j" undefined.
        assert "import json as j" in code