mirror of
https://github.com/trycua/computer.git
synced 2026-05-12 03:21:18 -05:00
Merge pull request #660 from trycua/fix/sandboxed-dependency-resolution
[Computer] fix @sandboxed decorator to handle dependencies automatically
This commit is contained in:
@@ -1121,19 +1121,11 @@ class Computer:
|
||||
The result of the function execution, or raises any exception that occurred
|
||||
"""
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
try:
|
||||
# Get function source code using inspect.getsource
|
||||
source = inspect.getsource(python_func)
|
||||
# Remove common leading whitespace (dedent)
|
||||
func_source = textwrap.dedent(source).strip()
|
||||
|
||||
# Remove decorators
|
||||
while func_source.lstrip().startswith("@"):
|
||||
func_source = func_source.split("\n", 1)[1].strip()
|
||||
func_source = helpers.generate_source_code(python_func)
|
||||
|
||||
# Get function name for execution
|
||||
func_name = python_func.__name__
|
||||
@@ -1259,16 +1251,12 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
|
||||
Uses a short launcher Python that spawns a detached child and exits immediately.
|
||||
"""
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
import time as _time
|
||||
|
||||
try:
|
||||
source = inspect.getsource(python_func)
|
||||
func_source = textwrap.dedent(source).strip()
|
||||
while func_source.lstrip().startswith("@"):
|
||||
func_source = func_source.split("\n", 1)[1].strip()
|
||||
func_source = helpers.generate_source_code(python_func)
|
||||
func_name = python_func.__name__
|
||||
args_json = json.dumps(args, default=str)
|
||||
kwargs_json = json.dumps(kwargs, default=str)
|
||||
@@ -1366,15 +1354,11 @@ print(p.pid)
|
||||
remote traceback context appended.
|
||||
"""
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
try:
|
||||
source = inspect.getsource(python_func)
|
||||
func_source = textwrap.dedent(source).strip()
|
||||
while func_source.lstrip().startswith("@"):
|
||||
func_source = func_source.split("\n", 1)[1].strip()
|
||||
func_source = helpers.generate_source_code(python_func)
|
||||
func_name = python_func.__name__
|
||||
args_json = json.dumps(args, default=str)
|
||||
kwargs_json = json.dumps(kwargs, default=str)
|
||||
@@ -1482,16 +1466,12 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
|
||||
Uses a short launcher Python that spawns a detached child and exits immediately.
|
||||
"""
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
import time as _time
|
||||
|
||||
try:
|
||||
source = inspect.getsource(python_func)
|
||||
func_source = textwrap.dedent(source).strip()
|
||||
while func_source.lstrip().startswith("@"):
|
||||
func_source = func_source.split("\n", 1)[1].strip()
|
||||
func_source = helpers.generate_source_code(python_func)
|
||||
func_name = python_func.__name__
|
||||
args_json = json.dumps(args, default=str)
|
||||
kwargs_json = json.dumps(kwargs, default=str)
|
||||
|
||||
@@ -2,10 +2,19 @@
|
||||
Helper functions and decorators for the Computer module.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import builtins
|
||||
import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, Optional, TypeVar
|
||||
from inspect import getsource
|
||||
from textwrap import dedent
|
||||
from types import FunctionType, ModuleType
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Set, TypedDict, TypeVar
|
||||
|
||||
try:
|
||||
# Python 3.12+ has ParamSpec in typing
|
||||
@@ -17,9 +26,18 @@ except ImportError: # pragma: no cover
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class DependencyInfo(TypedDict):
|
||||
import_statements: List[str]
|
||||
definitions: List[tuple[str, Any]]
|
||||
|
||||
|
||||
# Global reference to the default computer instance
|
||||
_default_computer = None
|
||||
|
||||
# Global cache for function dependency analysis
|
||||
_function_dependency_map: Dict[FunctionType, DependencyInfo] = {}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -42,6 +60,9 @@ def sandboxed(
|
||||
"""
|
||||
Decorator that wraps a function to be executed remotely via computer.venv_exec
|
||||
|
||||
The function is automatically analyzed for dependencies (imports, helper functions,
|
||||
constants, etc.) and reconstructed with all necessary code in the remote sandbox.
|
||||
|
||||
Args:
|
||||
venv_name: Name of the virtual environment to execute in
|
||||
computer: The computer instance to use, or "default" to use the globally set default
|
||||
@@ -74,3 +95,396 @@ def sandboxed(
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _extract_import_statement(name: str, module: ModuleType) -> str:
|
||||
"""Extract the original import statement for a module."""
|
||||
module_name = module.__name__
|
||||
|
||||
if name == module_name.split(".")[0]:
|
||||
return f"import {module_name}"
|
||||
else:
|
||||
return f"import {module_name} as {name}"
|
||||
|
||||
|
||||
def _is_third_party_module(module_name: str) -> bool:
|
||||
"""Check if a module is a third-party module."""
|
||||
stdlib_modules = set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
|
||||
|
||||
if module_name in stdlib_modules:
|
||||
return False
|
||||
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None:
|
||||
return False
|
||||
|
||||
if spec.origin and ("site-packages" in spec.origin or "dist-packages" in spec.origin):
|
||||
return True
|
||||
|
||||
return False
|
||||
except (ImportError, ModuleNotFoundError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
def _is_project_import(module_name: str) -> bool:
|
||||
"""Check if a module is a project-level import."""
|
||||
if module_name.startswith("__relative_import_level_"):
|
||||
return True
|
||||
|
||||
if module_name in sys.modules:
|
||||
module = sys.modules[module_name]
|
||||
if hasattr(module, "__file__") and module.__file__:
|
||||
if "site-packages" not in module.__file__ and "dist-packages" not in module.__file__:
|
||||
cwd = os.getcwd()
|
||||
if module.__file__.startswith(cwd):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _categorize_module(module_name: str) -> str:
|
||||
"""Categorize a module as stdlib, third-party, or project."""
|
||||
if module_name.startswith("__relative_import_level_"):
|
||||
return "project"
|
||||
elif module_name in (
|
||||
set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
|
||||
):
|
||||
return "stdlib"
|
||||
elif _is_third_party_module(module_name):
|
||||
return "third_party"
|
||||
elif _is_project_import(module_name):
|
||||
return "project"
|
||||
else:
|
||||
return "unknown"
|
||||
|
||||
|
||||
class _DependencyVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to extract imports and name references from a function."""
|
||||
|
||||
def __init__(self, function_name: str) -> None:
|
||||
self.function_name = function_name
|
||||
self.internal_imports: Set[str] = set()
|
||||
self.internal_import_statements: List[str] = []
|
||||
self.name_references: Set[str] = set()
|
||||
self.local_names: Set[str] = set()
|
||||
self.inside_function = False
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
if node.name == self.function_name and not self.inside_function:
|
||||
self.inside_function = True
|
||||
|
||||
for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs:
|
||||
self.local_names.add(arg.arg)
|
||||
if node.args.vararg:
|
||||
self.local_names.add(node.args.vararg.arg)
|
||||
if node.args.kwarg:
|
||||
self.local_names.add(node.args.kwarg.arg)
|
||||
|
||||
for child in node.body:
|
||||
self.visit(child)
|
||||
|
||||
self.inside_function = False
|
||||
else:
|
||||
if self.inside_function:
|
||||
self.local_names.add(node.name)
|
||||
for child in node.body:
|
||||
self.visit(child)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
|
||||
self.visit_FunctionDef(node) # type: ignore
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> None:
|
||||
if self.inside_function:
|
||||
for alias in node.names:
|
||||
module_name = alias.name.split(".")[0]
|
||||
self.internal_imports.add(module_name)
|
||||
imported_as = alias.asname if alias.asname else alias.name.split(".")[0]
|
||||
self.local_names.add(imported_as)
|
||||
self.internal_import_statements.append(ast.unparse(node))
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
if self.inside_function:
|
||||
if node.level == 0 and node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
self.internal_imports.add(module_name)
|
||||
elif node.level > 0:
|
||||
self.internal_imports.add(f"__relative_import_level_{node.level}__")
|
||||
|
||||
for alias in node.names:
|
||||
imported_as = alias.asname if alias.asname else alias.name
|
||||
self.local_names.add(imported_as)
|
||||
self.internal_import_statements.append(ast.unparse(node))
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> None:
|
||||
if self.inside_function:
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
self.name_references.add(node.id)
|
||||
elif isinstance(node.ctx, ast.Store):
|
||||
self.local_names.add(node.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
||||
if self.inside_function:
|
||||
self.local_names.add(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_For(self, node: ast.For) -> None:
|
||||
if self.inside_function and isinstance(node.target, ast.Name):
|
||||
self.local_names.add(node.target.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_comprehension(self, node: ast.comprehension) -> None:
|
||||
if self.inside_function and isinstance(node.target, ast.Name):
|
||||
self.local_names.add(node.target.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
|
||||
if self.inside_function and node.name:
|
||||
self.local_names.add(node.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_With(self, node: ast.With) -> None:
|
||||
if self.inside_function:
|
||||
for item in node.items:
|
||||
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
|
||||
self.local_names.add(item.optional_vars.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def _traverse_and_collect_dependencies(func: FunctionType) -> DependencyInfo:
|
||||
"""
|
||||
Traverse a function and collect its dependencies.
|
||||
|
||||
Returns a dict with:
|
||||
- import_statements: List of import statements needed
|
||||
- definitions: List of (name, obj) tuples for helper functions/classes/constants
|
||||
"""
|
||||
source = dedent(getsource(func))
|
||||
tree = ast.parse(source)
|
||||
|
||||
visitor = _DependencyVisitor(func.__name__)
|
||||
visitor.visit(tree)
|
||||
|
||||
builtin_names = set(dir(builtins))
|
||||
external_refs = (visitor.name_references - visitor.local_names) - builtin_names
|
||||
|
||||
import_statements = []
|
||||
definitions = []
|
||||
visited = set()
|
||||
|
||||
# Include all internal import statements
|
||||
import_statements.extend(visitor.internal_import_statements)
|
||||
|
||||
# Analyze external references recursively
|
||||
def analyze_object(obj: Any, name: str, depth: int = 0) -> None:
|
||||
if depth > 20:
|
||||
return
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in visited:
|
||||
return
|
||||
visited.add(obj_id)
|
||||
|
||||
# Handle modules
|
||||
if inspect.ismodule(obj):
|
||||
import_stmt = _extract_import_statement(name, obj)
|
||||
import_statements.append(import_stmt)
|
||||
return
|
||||
|
||||
# Handle functions and classes
|
||||
if (
|
||||
inspect.isfunction(obj)
|
||||
or inspect.isclass(obj)
|
||||
or inspect.isbuiltin(obj)
|
||||
or inspect.ismethod(obj)
|
||||
):
|
||||
obj_module = getattr(obj, "__module__", None)
|
||||
if obj_module:
|
||||
base_module = obj_module.split(".")[0]
|
||||
module_category = _categorize_module(base_module)
|
||||
|
||||
# If from stdlib/third-party, just add import
|
||||
if module_category in ("stdlib", "third_party"):
|
||||
obj_name = getattr(obj, "__name__", name)
|
||||
|
||||
# Check if object is accessible by 'name' (in globals or closures)
|
||||
is_accessible = False
|
||||
if name in func.__globals__ and func.__globals__[name] is obj:
|
||||
is_accessible = True
|
||||
elif func.__closure__ and hasattr(func, "__code__"):
|
||||
freevars = func.__code__.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == name and i < len(func.__closure__):
|
||||
try:
|
||||
if func.__closure__[i].cell_contents is obj:
|
||||
is_accessible = True
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
if is_accessible and name == obj_name:
|
||||
# Direct import: from requests import get, from math import sqrt
|
||||
import_statements.append(f"from {base_module} import {name}")
|
||||
else:
|
||||
# Module import: import requests
|
||||
import_statements.append(f"import {base_module}")
|
||||
return
|
||||
|
||||
try:
|
||||
obj_tree = ast.parse(dedent(getsource(obj)))
|
||||
obj_visitor = _DependencyVisitor(obj.__name__)
|
||||
obj_visitor.visit(obj_tree)
|
||||
|
||||
obj_external_refs = obj_visitor.name_references - obj_visitor.local_names
|
||||
obj_external_refs = obj_external_refs - builtin_names
|
||||
|
||||
# Add internal imports from this object
|
||||
import_statements.extend(obj_visitor.internal_import_statements)
|
||||
|
||||
# Recursively analyze its dependencies
|
||||
obj_globals = getattr(obj, "__globals__", None)
|
||||
obj_closure = getattr(obj, "__closure__", None)
|
||||
obj_code = getattr(obj, "__code__", None)
|
||||
if obj_globals:
|
||||
for ref_name in obj_external_refs:
|
||||
ref_obj = None
|
||||
|
||||
# Check globals first
|
||||
if ref_name in obj_globals:
|
||||
ref_obj = obj_globals[ref_name]
|
||||
# Check closure variables using co_freevars
|
||||
elif obj_closure and obj_code:
|
||||
freevars = obj_code.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == ref_name and i < len(obj_closure):
|
||||
try:
|
||||
ref_obj = obj_closure[i].cell_contents
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
if ref_obj is not None:
|
||||
analyze_object(ref_obj, ref_name, depth + 1)
|
||||
|
||||
# Add this object to definitions
|
||||
if not inspect.ismodule(obj):
|
||||
ref_module = getattr(obj, "__module__", None)
|
||||
if ref_module:
|
||||
ref_base_module = ref_module.split(".")[0]
|
||||
ref_category = _categorize_module(ref_base_module)
|
||||
if ref_category not in ("stdlib", "third_party"):
|
||||
definitions.append((name, obj))
|
||||
else:
|
||||
definitions.append((name, obj))
|
||||
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
return
|
||||
|
||||
if isinstance(obj, (int, float, str, bool, list, dict, tuple, set, frozenset, type(None))):
|
||||
definitions.append((name, obj))
|
||||
|
||||
# Analyze all external references
|
||||
for name in external_refs:
|
||||
obj = None
|
||||
|
||||
# First check globals
|
||||
if name in func.__globals__:
|
||||
obj = func.__globals__[name]
|
||||
# Then check closure variables (sibling functions in enclosing scope)
|
||||
elif func.__closure__ and func.__code__.co_freevars:
|
||||
# Match closure variable names with cell contents
|
||||
freevars = func.__code__.co_freevars
|
||||
for i, var_name in enumerate(freevars):
|
||||
if var_name == name and i < len(func.__closure__):
|
||||
try:
|
||||
obj = func.__closure__[i].cell_contents
|
||||
break
|
||||
except (ValueError, AttributeError):
|
||||
# Cell is empty or doesn't have contents
|
||||
pass
|
||||
|
||||
if obj is not None:
|
||||
analyze_object(obj, name)
|
||||
|
||||
# Remove duplicate import statements
|
||||
unique_imports = []
|
||||
seen = set()
|
||||
for stmt in import_statements:
|
||||
if stmt not in seen:
|
||||
seen.add(stmt)
|
||||
unique_imports.append(stmt)
|
||||
|
||||
# Remove duplicate definitions
|
||||
unique_definitions = []
|
||||
seen_names = set()
|
||||
for name, obj in definitions:
|
||||
if name not in seen_names:
|
||||
seen_names.add(name)
|
||||
unique_definitions.append((name, obj))
|
||||
|
||||
return {
|
||||
"import_statements": unique_imports,
|
||||
"definitions": unique_definitions,
|
||||
}
|
||||
|
||||
|
||||
def generate_source_code(func: FunctionType) -> str:
|
||||
"""
|
||||
Generate complete source code for a function with all dependencies.
|
||||
|
||||
Args:
|
||||
func: The function to generate source code for
|
||||
|
||||
Returns:
|
||||
Complete Python source code as a string
|
||||
"""
|
||||
|
||||
if func in _function_dependency_map:
|
||||
info = _function_dependency_map[func]
|
||||
else:
|
||||
info = _traverse_and_collect_dependencies(func)
|
||||
_function_dependency_map[func] = info
|
||||
|
||||
# Build source code
|
||||
parts = []
|
||||
|
||||
# 1. Add imports
|
||||
if info["import_statements"]:
|
||||
parts.append("\n".join(info["import_statements"]))
|
||||
|
||||
# 2. Add definitions
|
||||
for name, obj in info["definitions"]:
|
||||
try:
|
||||
if inspect.isfunction(obj):
|
||||
source = dedent(getsource(obj))
|
||||
tree = ast.parse(source)
|
||||
if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
tree.body[0].decorator_list = []
|
||||
source = ast.unparse(tree)
|
||||
parts.append(source)
|
||||
elif inspect.isclass(obj):
|
||||
source = dedent(getsource(obj))
|
||||
tree = ast.parse(source)
|
||||
if tree.body and isinstance(tree.body[0], ast.ClassDef):
|
||||
tree.body[0].decorator_list = []
|
||||
source = ast.unparse(tree)
|
||||
parts.append(source)
|
||||
else:
|
||||
parts.append(f"{name} = {repr(obj)}")
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. Add main function (without decorators)
|
||||
func_source = dedent(getsource(func))
|
||||
tree = ast.parse(func_source)
|
||||
if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
tree.body[0].decorator_list = []
|
||||
func_source = ast.unparse(tree)
|
||||
parts.append(func_source)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
@@ -0,0 +1,478 @@
|
||||
from computer.helpers import generate_source_code
|
||||
|
||||
|
||||
class TestSimpleCases:
|
||||
"""Test simple dependency cases"""
|
||||
|
||||
def test_simple_function_no_dependencies(self):
|
||||
"""Test simple function with no dependencies"""
|
||||
|
||||
def simple_func():
|
||||
return 42
|
||||
|
||||
code = generate_source_code(simple_func)
|
||||
|
||||
assert "def simple_func():" in code
|
||||
assert "return 42" in code
|
||||
|
||||
def test_function_with_parameters(self):
|
||||
"""Test function with parameters"""
|
||||
|
||||
def add(a, b):
|
||||
return a + b
|
||||
|
||||
code = generate_source_code(add)
|
||||
|
||||
assert "def add(a, b):" in code
|
||||
assert "return a + b" in code
|
||||
|
||||
def test_function_with_docstring(self):
|
||||
"""Test function with docstring"""
|
||||
|
||||
def documented():
|
||||
"""This is a docstring"""
|
||||
return True
|
||||
|
||||
code = generate_source_code(documented)
|
||||
|
||||
assert "def documented():" in code
|
||||
assert "This is a docstring" in code
|
||||
|
||||
|
||||
class TestImports:
|
||||
"""Test import handling"""
|
||||
|
||||
def test_stdlib_import_inside_function(self):
|
||||
"""Test function with stdlib import inside"""
|
||||
|
||||
def with_stdlib():
|
||||
import math
|
||||
|
||||
return math.sqrt(16)
|
||||
|
||||
code = generate_source_code(with_stdlib)
|
||||
|
||||
assert "import math" in code
|
||||
assert "def with_stdlib():" in code
|
||||
assert "math.sqrt(16)" in code
|
||||
|
||||
def test_multiple_stdlib_imports(self):
|
||||
"""Test function with multiple stdlib imports"""
|
||||
|
||||
def with_multiple():
|
||||
import json
|
||||
import math
|
||||
|
||||
return json.dumps({"value": math.pi})
|
||||
|
||||
code = generate_source_code(with_multiple)
|
||||
|
||||
assert "import json" in code
|
||||
assert "import math" in code
|
||||
|
||||
def test_from_import_stdlib(self):
|
||||
"""Test from X import Y style"""
|
||||
|
||||
def with_from_import():
|
||||
from math import pi, sqrt
|
||||
|
||||
return sqrt(16) + pi
|
||||
|
||||
code = generate_source_code(with_from_import)
|
||||
|
||||
assert "from math import" in code
|
||||
assert "sqrt, pi" in code or ("sqrt" in code and "pi" in code)
|
||||
|
||||
def test_third_party_global_import(self):
|
||||
"""Test globally imported third-party module"""
|
||||
from requests import get
|
||||
|
||||
def with_requests():
|
||||
return get
|
||||
|
||||
code = generate_source_code(with_requests)
|
||||
|
||||
assert "from requests import get" in code
|
||||
assert "def with_requests():" in code
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Test helper function dependencies"""
|
||||
|
||||
def test_single_helper_function(self):
|
||||
"""Test function with one helper"""
|
||||
|
||||
def helper(x):
|
||||
return x * 2
|
||||
|
||||
def main():
|
||||
return helper(21)
|
||||
|
||||
code = generate_source_code(main)
|
||||
|
||||
assert "def helper(x):" in code
|
||||
assert "return x * 2" in code
|
||||
assert "def main():" in code
|
||||
assert "return helper(21)" in code
|
||||
|
||||
def test_nested_helper_functions(self):
|
||||
"""Test function with nested helpers"""
|
||||
|
||||
def helper_a(x):
|
||||
return x + 1
|
||||
|
||||
def helper_b(x):
|
||||
return helper_a(x) * 2
|
||||
|
||||
def main():
|
||||
return helper_b(10)
|
||||
|
||||
code = generate_source_code(main)
|
||||
|
||||
assert "def helper_a(x):" in code
|
||||
assert "def helper_b(x):" in code
|
||||
assert "def main():" in code
|
||||
|
||||
def test_helper_dependency_ordering(self):
|
||||
"""Test that helpers are ordered correctly (dependencies first)"""
|
||||
|
||||
def level_3(x):
|
||||
return x + 1
|
||||
|
||||
def level_2(x):
|
||||
return level_3(x) * 2
|
||||
|
||||
def level_1(x):
|
||||
return level_2(x) + 5
|
||||
|
||||
code = generate_source_code(level_1)
|
||||
|
||||
# Check ordering
|
||||
pos_3 = code.find("def level_3")
|
||||
pos_2 = code.find("def level_2")
|
||||
pos_1 = code.find("def level_1")
|
||||
|
||||
assert pos_3 < pos_2 < pos_1, "Functions not in correct dependency order"
|
||||
|
||||
def test_helper_with_import(self):
|
||||
"""Test helper function that uses imports"""
|
||||
import math
|
||||
|
||||
def helper(x):
|
||||
return math.sqrt(x)
|
||||
|
||||
def main():
|
||||
return helper(16)
|
||||
|
||||
code = generate_source_code(main)
|
||||
|
||||
assert "import math" in code
|
||||
assert "def helper(x):" in code
|
||||
assert "def main():" in code
|
||||
|
||||
|
||||
class TestGlobalConstants:
|
||||
"""Test global constant handling"""
|
||||
|
||||
def test_simple_constant(self):
|
||||
"""Test function using simple constant"""
|
||||
MY_CONSTANT = 42
|
||||
|
||||
def with_constant():
|
||||
return MY_CONSTANT * 2
|
||||
|
||||
code = generate_source_code(with_constant)
|
||||
|
||||
assert "MY_CONSTANT = 42" in code
|
||||
assert "def with_constant():" in code
|
||||
|
||||
def test_string_constant(self):
|
||||
"""Test function using string constant"""
|
||||
API_URL = "https://api.example.com"
|
||||
|
||||
def with_string():
|
||||
return API_URL
|
||||
|
||||
code = generate_source_code(with_string)
|
||||
|
||||
assert "API_URL" in code
|
||||
assert "https://api.example.com" in code
|
||||
|
||||
def test_multiple_constants(self):
|
||||
"""Test function using multiple constants"""
|
||||
BASE_URL = "https://example.com"
|
||||
TIMEOUT = 30
|
||||
|
||||
def with_multiple():
|
||||
return f"{BASE_URL}?timeout={TIMEOUT}"
|
||||
|
||||
code = generate_source_code(with_multiple)
|
||||
|
||||
assert "BASE_URL" in code
|
||||
assert "TIMEOUT" in code
|
||||
|
||||
|
||||
class TestClassDefinitions:
|
||||
"""Test class definition handling"""
|
||||
|
||||
def test_simple_class(self):
|
||||
"""Test function using simple class"""
|
||||
|
||||
class Calculator:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def add(self, x):
|
||||
return self.value + x
|
||||
|
||||
def with_class():
|
||||
calc = Calculator(10)
|
||||
return calc.add(5)
|
||||
|
||||
code = generate_source_code(with_class)
|
||||
|
||||
assert "class Calculator:" in code
|
||||
assert "__init__" in code
|
||||
assert "def add" in code
|
||||
assert "def with_class():" in code
|
||||
|
||||
def test_class_with_methods(self):
|
||||
"""Test class with multiple methods"""
|
||||
|
||||
class Math:
|
||||
def add(self, a, b):
|
||||
return a + b
|
||||
|
||||
def multiply(self, a, b):
|
||||
return a * b
|
||||
|
||||
def use_math():
|
||||
m = Math()
|
||||
return m.add(2, 3) + m.multiply(4, 5)
|
||||
|
||||
code = generate_source_code(use_math)
|
||||
|
||||
assert "class Math:" in code
|
||||
assert "def add" in code
|
||||
assert "def multiply" in code
|
||||
|
||||
|
||||
class TestDecoratorRemoval:
|
||||
"""Test decorator removal"""
|
||||
|
||||
def test_simple_decorator_removed(self):
|
||||
"""Test that decorators are removed"""
|
||||
|
||||
def my_decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@my_decorator
|
||||
def decorated():
|
||||
return 42
|
||||
|
||||
code = generate_source_code(decorated)
|
||||
|
||||
assert "@my_decorator" not in code
|
||||
assert "def decorated():" in code
|
||||
assert "return 42" in code
|
||||
|
||||
def test_multiple_decorators_removed(self):
|
||||
"""Test that multiple decorators are removed"""
|
||||
|
||||
def decorator1(f):
|
||||
return f
|
||||
|
||||
def decorator2(f):
|
||||
return f
|
||||
|
||||
@decorator1
|
||||
@decorator2
|
||||
def multi_decorated():
|
||||
return True
|
||||
|
||||
code = generate_source_code(multi_decorated)
|
||||
|
||||
assert "@decorator1" not in code
|
||||
assert "@decorator2" not in code
|
||||
assert "def multi_decorated():" in code
|
||||
|
||||
|
||||
class TestComplexScenarios:
|
||||
"""Test complex real-world scenarios"""
|
||||
|
||||
def test_api_client_pattern(self):
|
||||
"""Test typical API client pattern"""
|
||||
import json
|
||||
|
||||
BASE_URL = "https://api.example.com"
|
||||
|
||||
def make_request(endpoint):
|
||||
return f"{BASE_URL}/{endpoint}"
|
||||
|
||||
def parse_response(response):
|
||||
return json.loads(response)
|
||||
|
||||
def get_user(user_id):
|
||||
url = make_request(f"users/{user_id}")
|
||||
response = f'{{"id": {user_id}}}'
|
||||
return parse_response(response)
|
||||
|
||||
code = generate_source_code(get_user)
|
||||
|
||||
assert "import json" in code
|
||||
assert "BASE_URL" in code
|
||||
assert "def make_request" in code
|
||||
assert "def parse_response" in code
|
||||
assert "def get_user" in code
|
||||
|
||||
def test_data_processing_pipeline(self):
|
||||
"""Test data processing pipeline"""
|
||||
|
||||
def validate_data(data):
|
||||
return [x for x in data if x > 0]
|
||||
|
||||
def transform_data(data):
|
||||
return [x * 2 for x in data]
|
||||
|
||||
def process_pipeline(raw_data):
|
||||
valid = validate_data(raw_data)
|
||||
return transform_data(valid)
|
||||
|
||||
code = generate_source_code(process_pipeline)
|
||||
|
||||
assert "def validate_data" in code
|
||||
assert "def transform_data" in code
|
||||
assert "def process_pipeline" in code
|
||||
|
||||
# Check ordering
|
||||
validate_pos = code.find("def validate_data")
|
||||
transform_pos = code.find("def transform_data")
|
||||
pipeline_pos = code.find("def process_pipeline")
|
||||
|
||||
assert validate_pos < pipeline_pos
|
||||
assert transform_pos < pipeline_pos
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases"""
|
||||
|
||||
def test_function_with_default_args(self):
|
||||
"""Test function with default arguments"""
|
||||
|
||||
def with_defaults(a, b=10, c=20):
|
||||
return a + b + c
|
||||
|
||||
code = generate_source_code(with_defaults)
|
||||
|
||||
assert "def with_defaults(a, b=10, c=20):" in code
|
||||
|
||||
def test_function_with_kwargs(self):
|
||||
"""Test function with *args and **kwargs"""
|
||||
|
||||
def with_varargs(*args, **kwargs):
|
||||
return sum(args)
|
||||
|
||||
code = generate_source_code(with_varargs)
|
||||
|
||||
assert "def with_varargs(*args, **kwargs):" in code
|
||||
|
||||
def test_async_function(self):
|
||||
"""Test async function"""
|
||||
|
||||
async def async_func():
|
||||
return 42
|
||||
|
||||
code = generate_source_code(async_func)
|
||||
|
||||
assert "async def async_func():" in code
|
||||
assert "return 42" in code
|
||||
|
||||
def test_lambda_works_in_files(self):
|
||||
"""Test that lambda functions work when defined in files"""
|
||||
lambda_func = lambda x: x * 2
|
||||
|
||||
# Lambda functions defined in files CAN have their source extracted
|
||||
code = generate_source_code(lambda_func)
|
||||
|
||||
# Should include the lambda
|
||||
assert "lambda" in code.lower() or "lambda_func" in code
|
||||
|
||||
|
||||
class TestCaching:
|
||||
"""Test caching mechanism"""
|
||||
|
||||
def test_cache_returns_same_result(self):
|
||||
"""Test that cache returns identical results"""
|
||||
|
||||
def cached_func():
|
||||
return 42
|
||||
|
||||
code1 = generate_source_code(cached_func)
|
||||
code2 = generate_source_code(cached_func)
|
||||
|
||||
assert code1 == code2
|
||||
|
||||
def test_different_functions_different_cache(self):
|
||||
"""Test that different functions have different cache entries"""
|
||||
|
||||
def func1():
|
||||
return 1
|
||||
|
||||
def func2():
|
||||
return 2
|
||||
|
||||
code1 = generate_source_code(func1)
|
||||
code2 = generate_source_code(func2)
|
||||
|
||||
assert code1 != code2
|
||||
assert "return 1" in code1
|
||||
assert "return 2" in code2
|
||||
|
||||
|
||||
class TestBuiltinHandling:
|
||||
"""Test that builtins are handled correctly"""
|
||||
|
||||
def test_builtin_not_included(self):
|
||||
"""Test that builtin functions are not included as dependencies"""
|
||||
|
||||
def uses_builtins():
|
||||
data = list(range(10))
|
||||
return len(data)
|
||||
|
||||
code = generate_source_code(uses_builtins)
|
||||
|
||||
# Should not include definitions for list, range, len
|
||||
assert "def list" not in code
|
||||
assert "def range" not in code
|
||||
assert "def len" not in code
|
||||
assert "def uses_builtins():" in code
|
||||
|
||||
|
||||
class TestImportStylePreservation:
|
||||
"""Test that import styles are preserved"""
|
||||
|
||||
def test_from_import_preserved(self):
|
||||
"""Test from X import Y is preserved"""
|
||||
from math import sqrt
|
||||
|
||||
def use_sqrt():
|
||||
return sqrt(16)
|
||||
|
||||
code = generate_source_code(use_sqrt)
|
||||
|
||||
assert "from math import sqrt" in code
|
||||
|
||||
def test_import_as_preserved(self):
|
||||
"""Test import X as Y is preserved"""
|
||||
import json as j
|
||||
|
||||
def use_json():
|
||||
return j.dumps({"key": "value"})
|
||||
|
||||
code = generate_source_code(use_json)
|
||||
|
||||
# The import style should be preserved
|
||||
assert "import json" in code
|
||||
Reference in New Issue
Block a user