mirror of
https://github.com/trycua/computer.git
synced 2026-01-04 04:19:57 -06:00
feat(helpers): add automatic dependency resolution for @sandboxed decorator
- Implement AST-based dependency analysis with closure support - Automatically detect and include helper functions, imports, and constants - Handle nested dependencies recursively - Add caching with function object as key
This commit is contained in:
@@ -2,10 +2,19 @@
|
||||
Helper functions and decorators for the Computer module.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import builtins
|
||||
import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, Optional, TypeVar
|
||||
from inspect import getsource
|
||||
from textwrap import dedent
|
||||
from types import FunctionType, ModuleType
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Set, TypedDict, TypeVar
|
||||
|
||||
try:
|
||||
# Python 3.12+ has ParamSpec in typing
|
||||
@@ -17,9 +26,18 @@ except ImportError: # pragma: no cover
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class DependencyInfo(TypedDict):
|
||||
import_statements: List[str]
|
||||
definitions: List[tuple[str, Any]]
|
||||
|
||||
|
||||
# Global reference to the default computer instance
|
||||
_default_computer = None
|
||||
|
||||
# Global cache for function dependency analysis
|
||||
_function_dependency_map: Dict[FunctionType, DependencyInfo] = {}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -42,6 +60,9 @@ def sandboxed(
|
||||
"""
|
||||
Decorator that wraps a function to be executed remotely via computer.venv_exec
|
||||
|
||||
The function is automatically analyzed for dependencies (imports, helper functions,
|
||||
constants, etc.) and reconstructed with all necessary code in the remote sandbox.
|
||||
|
||||
Args:
|
||||
venv_name: Name of the virtual environment to execute in
|
||||
computer: The computer instance to use, or "default" to use the globally set default
|
||||
@@ -74,3 +95,396 @@ def sandboxed(
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _extract_import_statement(name: str, module: ModuleType) -> str:
|
||||
"""Extract the original import statement for a module."""
|
||||
module_name = module.__name__
|
||||
|
||||
if name == module_name.split(".")[0]:
|
||||
return f"import {module_name}"
|
||||
else:
|
||||
return f"import {module_name} as {name}"
|
||||
|
||||
|
||||
def _is_third_party_module(module_name: str) -> bool:
|
||||
"""Check if a module is a third-party module."""
|
||||
stdlib_modules = set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
|
||||
|
||||
if module_name in stdlib_modules:
|
||||
return False
|
||||
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None:
|
||||
return False
|
||||
|
||||
if spec.origin and ("site-packages" in spec.origin or "dist-packages" in spec.origin):
|
||||
return True
|
||||
|
||||
return False
|
||||
except (ImportError, ModuleNotFoundError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
def _is_project_import(module_name: str) -> bool:
|
||||
"""Check if a module is a project-level import."""
|
||||
if module_name.startswith("__relative_import_level_"):
|
||||
return True
|
||||
|
||||
if module_name in sys.modules:
|
||||
module = sys.modules[module_name]
|
||||
if hasattr(module, "__file__") and module.__file__:
|
||||
if "site-packages" not in module.__file__ and "dist-packages" not in module.__file__:
|
||||
cwd = os.getcwd()
|
||||
if module.__file__.startswith(cwd):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _categorize_module(module_name: str) -> str:
|
||||
"""Categorize a module as stdlib, third-party, or project."""
|
||||
if module_name.startswith("__relative_import_level_"):
|
||||
return "project"
|
||||
elif module_name in (
|
||||
set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
|
||||
):
|
||||
return "stdlib"
|
||||
elif _is_third_party_module(module_name):
|
||||
return "third_party"
|
||||
elif _is_project_import(module_name):
|
||||
return "project"
|
||||
else:
|
||||
return "unknown"
|
||||
|
||||
|
||||
class _DependencyVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to extract imports and name references from a function."""
|
||||
|
||||
def __init__(self, function_name: str) -> None:
|
||||
self.function_name = function_name
|
||||
self.internal_imports: Set[str] = set()
|
||||
self.internal_import_statements: List[str] = []
|
||||
self.name_references: Set[str] = set()
|
||||
self.local_names: Set[str] = set()
|
||||
self.inside_function = False
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
if node.name == self.function_name and not self.inside_function:
|
||||
self.inside_function = True
|
||||
|
||||
for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs:
|
||||
self.local_names.add(arg.arg)
|
||||
if node.args.vararg:
|
||||
self.local_names.add(node.args.vararg.arg)
|
||||
if node.args.kwarg:
|
||||
self.local_names.add(node.args.kwarg.arg)
|
||||
|
||||
for child in node.body:
|
||||
self.visit(child)
|
||||
|
||||
self.inside_function = False
|
||||
else:
|
||||
if self.inside_function:
|
||||
self.local_names.add(node.name)
|
||||
for child in node.body:
|
||||
self.visit(child)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
|
||||
self.visit_FunctionDef(node) # type: ignore
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> None:
|
||||
if self.inside_function:
|
||||
for alias in node.names:
|
||||
module_name = alias.name.split(".")[0]
|
||||
self.internal_imports.add(module_name)
|
||||
imported_as = alias.asname if alias.asname else alias.name.split(".")[0]
|
||||
self.local_names.add(imported_as)
|
||||
self.internal_import_statements.append(ast.unparse(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
if self.inside_function:
|
||||
if node.level == 0 and node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
self.internal_imports.add(module_name)
|
||||
elif node.level > 0:
|
||||
self.internal_imports.add(f"__relative_import_level_{node.level}__")
|
||||
|
||||
for alias in node.names:
|
||||
imported_as = alias.asname if alias.asname else alias.name
|
||||
self.local_names.add(imported_as)
|
||||
self.internal_import_statements.append(ast.unparse(node))
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> None:
|
||||
if self.inside_function:
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
self.name_references.add(node.id)
|
||||
elif isinstance(node.ctx, ast.Store):
|
||||
self.local_names.add(node.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
||||
if self.inside_function:
|
||||
self.local_names.add(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_For(self, node: ast.For) -> None:
|
||||
if self.inside_function and isinstance(node.target, ast.Name):
|
||||
self.local_names.add(node.target.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_comprehension(self, node: ast.comprehension) -> None:
|
||||
if self.inside_function and isinstance(node.target, ast.Name):
|
||||
self.local_names.add(node.target.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
|
||||
if self.inside_function and node.name:
|
||||
self.local_names.add(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_With(self, node: ast.With) -> None:
|
||||
if self.inside_function:
|
||||
for item in node.items:
|
||||
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
|
||||
self.local_names.add(item.optional_vars.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def _traverse_and_collect_dependencies(func: FunctionType) -> DependencyInfo:
|
||||
"""
|
||||
Traverse a function and collect its dependencies.
|
||||
|
||||
Returns a dict with:
|
||||
- import_statements: List of import statements needed
|
||||
- definitions: List of (name, obj) tuples for helper functions/classes/constants
|
||||
"""
|
||||
source = dedent(getsource(func))
|
||||
tree = ast.parse(source)
|
||||
|
||||
visitor = _DependencyVisitor(func.__name__)
|
||||
visitor.visit(tree)
|
||||
|
||||
builtin_names = set(dir(builtins))
|
||||
external_refs = (visitor.name_references - visitor.local_names) - builtin_names
|
||||
|
||||
import_statements = []
|
||||
definitions = []
|
||||
visited = set()
|
||||
|
||||
# Include all internal import statements
|
||||
import_statements.extend(visitor.internal_import_statements)
|
||||
|
||||
# Analyze external references recursively
|
||||
def analyze_object(obj: Any, name: str, depth: int = 0) -> None:
|
||||
if depth > 20:
|
||||
return
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in visited:
|
||||
return
|
||||
visited.add(obj_id)
|
||||
|
||||
# Handle modules
|
||||
if inspect.ismodule(obj):
|
||||
import_stmt = _extract_import_statement(name, obj)
|
||||
import_statements.append(import_stmt)
|
||||
return
|
||||
|
||||
# Handle functions and classes
|
||||
if (
|
||||
inspect.isfunction(obj)
|
||||
or inspect.isclass(obj)
|
||||
or inspect.isbuiltin(obj)
|
||||
or inspect.ismethod(obj)
|
||||
):
|
||||
obj_module = getattr(obj, "__module__", None)
|
||||
if obj_module:
|
||||
base_module = obj_module.split(".")[0]
|
||||
module_category = _categorize_module(base_module)
|
||||
|
||||
# If from stdlib/third-party, just add import
|
||||
if module_category in ("stdlib", "third_party"):
|
||||
obj_name = getattr(obj, "__name__", name)
|
||||
|
||||
# Check if object is accessible by 'name' (in globals or closures)
|
||||
is_accessible = False
|
||||
if name in func.__globals__ and func.__globals__[name] is obj:
|
||||
is_accessible = True
|
||||
elif func.__closure__ and hasattr(func, "__code__"):
|
||||
freevars = func.__code__.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == name and i < len(func.__closure__):
|
||||
try:
|
||||
if func.__closure__[i].cell_contents is obj:
|
||||
is_accessible = True
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
if is_accessible and name == obj_name:
|
||||
# Direct import: from requests import get, from math import sqrt
|
||||
import_statements.append(f"from {base_module} import {name}")
|
||||
else:
|
||||
# Module import: import requests
|
||||
import_statements.append(f"import {base_module}")
|
||||
return
|
||||
|
||||
try:
|
||||
obj_tree = ast.parse(dedent(getsource(obj)))
|
||||
obj_visitor = _DependencyVisitor(obj.__name__)
|
||||
obj_visitor.visit(obj_tree)
|
||||
|
||||
obj_external_refs = obj_visitor.name_references - obj_visitor.local_names
|
||||
obj_external_refs = obj_external_refs - builtin_names
|
||||
|
||||
# Add internal imports from this object
|
||||
import_statements.extend(obj_visitor.internal_import_statements)
|
||||
|
||||
# Recursively analyze its dependencies
|
||||
obj_globals = getattr(obj, "__globals__", None)
|
||||
obj_closure = getattr(obj, "__closure__", None)
|
||||
obj_code = getattr(obj, "__code__", None)
|
||||
if obj_globals:
|
||||
for ref_name in obj_external_refs:
|
||||
ref_obj = None
|
||||
|
||||
# Check globals first
|
||||
if ref_name in obj_globals:
|
||||
ref_obj = obj_globals[ref_name]
|
||||
# Check closure variables using co_freevars
|
||||
elif obj_closure and obj_code:
|
||||
freevars = obj_code.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == ref_name and i < len(obj_closure):
|
||||
try:
|
||||
ref_obj = obj_closure[i].cell_contents
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
if ref_obj is not None:
|
||||
analyze_object(ref_obj, ref_name, depth + 1)
|
||||
|
||||
# Add this object to definitions
|
||||
if not inspect.ismodule(obj):
|
||||
ref_module = getattr(obj, "__module__", None)
|
||||
if ref_module:
|
||||
ref_base_module = ref_module.split(".")[0]
|
||||
ref_category = _categorize_module(ref_base_module)
|
||||
if ref_category not in ("stdlib", "third_party"):
|
||||
definitions.append((name, obj))
|
||||
else:
|
||||
definitions.append((name, obj))
|
||||
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
return
|
||||
|
||||
if isinstance(obj, (int, float, str, bool, list, dict, tuple, set, frozenset, type(None))):
|
||||
definitions.append((name, obj))
|
||||
|
||||
# Analyze all external references
|
||||
for name in external_refs:
|
||||
obj = None
|
||||
|
||||
# First check globals
|
||||
if name in func.__globals__:
|
||||
obj = func.__globals__[name]
|
||||
# Then check closure variables (sibling functions in enclosing scope)
|
||||
elif func.__closure__ and func.__code__.co_freevars:
|
||||
# Match closure variable names with cell contents
|
||||
freevars = func.__code__.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == name and i < len(func.__closure__):
|
||||
try:
|
||||
obj = func.__closure__[i].cell_contents
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
# Cell is empty or doesn't have contents
|
||||
pass
|
||||
|
||||
if obj is not None:
|
||||
analyze_object(obj, name)
|
||||
|
||||
# Remove duplicate import statements
|
||||
unique_imports = []
|
||||
seen = set()
|
||||
for stmt in import_statements:
|
||||
if stmt not in seen:
|
||||
seen.add(stmt)
|
||||
unique_imports.append(stmt)
|
||||
|
||||
# Remove duplicate definitions
|
||||
unique_definitions = []
|
||||
seen_names = set()
|
||||
for name, obj in definitions:
|
||||
if name not in seen_names:
|
||||
seen_names.add(name)
|
||||
unique_definitions.append((name, obj))
|
||||
|
||||
return {
|
||||
"import_statements": unique_imports,
|
||||
"definitions": unique_definitions,
|
||||
}
|
||||
|
||||
|
||||
def generate_source_code(func: FunctionType) -> str:
|
||||
"""
|
||||
Generate complete source code for a function with all dependencies.
|
||||
|
||||
Args:
|
||||
func: The function to generate source code for
|
||||
|
||||
Returns:
|
||||
Complete Python source code as a string
|
||||
"""
|
||||
|
||||
if func in _function_dependency_map:
|
||||
info = _function_dependency_map[func]
|
||||
else:
|
||||
info = _traverse_and_collect_dependencies(func)
|
||||
_function_dependency_map[func] = info
|
||||
|
||||
# Build source code
|
||||
parts = []
|
||||
|
||||
# 1. Add imports
|
||||
if info["import_statements"]:
|
||||
parts.append("\n".join(info["import_statements"]))
|
||||
|
||||
# 2. Add definitions
|
||||
for name, obj in info["definitions"]:
|
||||
try:
|
||||
if inspect.isfunction(obj):
|
||||
source = dedent(getsource(obj))
|
||||
tree = ast.parse(source)
|
||||
if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
tree.body[0].decorator_list = []
|
||||
source = ast.unparse(tree)
|
||||
parts.append(source)
|
||||
elif inspect.isclass(obj):
|
||||
source = dedent(getsource(obj))
|
||||
tree = ast.parse(source)
|
||||
if tree.body and isinstance(tree.body[0], ast.ClassDef):
|
||||
tree.body[0].decorator_list = []
|
||||
source = ast.unparse(tree)
|
||||
parts.append(source)
|
||||
else:
|
||||
parts.append(f"{name} = {repr(obj)}")
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. Add main function (without decorators)
|
||||
func_source = dedent(getsource(func))
|
||||
tree = ast.parse(func_source)
|
||||
if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
tree.body[0].decorator_list = []
|
||||
func_source = ast.unparse(tree)
|
||||
parts.append(func_source)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
Reference in New Issue
Block a user