From b8a6a42baad16b9691904dab3c969304a0e7c73e Mon Sep 17 00:00:00 2001 From: "synacktra.work@gmail.com" Date: Wed, 17 Dec 2025 01:07:40 +0530 Subject: [PATCH 1/3] feat(helpers): add automatic dependency resolution for @sandboxed decorator - Implement AST-based dependency analysis with closure support - Automatically detect and include helper functions, imports, and constants - Handle nested dependencies recursively - Add caching with function object as key --- libs/python/computer/computer/helpers.py | 416 ++++++++++++++++++++++- 1 file changed, 415 insertions(+), 1 deletion(-) diff --git a/libs/python/computer/computer/helpers.py b/libs/python/computer/computer/helpers.py index 29231b56..fab882c7 100644 --- a/libs/python/computer/computer/helpers.py +++ b/libs/python/computer/computer/helpers.py @@ -2,10 +2,19 @@ Helper functions and decorators for the Computer module. """ +import ast import asyncio +import builtins +import importlib.util +import inspect import logging +import os +import sys from functools import wraps -from typing import Any, Awaitable, Callable, Optional, TypeVar +from inspect import getsource +from textwrap import dedent +from types import FunctionType, ModuleType +from typing import Any, Awaitable, Callable, Dict, List, Set, TypedDict, TypeVar try: # Python 3.12+ has ParamSpec in typing @@ -17,9 +26,18 @@ except ImportError: # pragma: no cover P = ParamSpec("P") R = TypeVar("R") + +class DependencyInfo(TypedDict): + import_statements: List[str] + definitions: List[tuple[str, Any]] + + # Global reference to the default computer instance _default_computer = None +# Global cache for function dependency analysis +_function_dependency_map: Dict[FunctionType, DependencyInfo] = {} + logger = logging.getLogger(__name__) @@ -42,6 +60,9 @@ def sandboxed( """ Decorator that wraps a function to be executed remotely via computer.venv_exec + The function is automatically analyzed for dependencies (imports, helper functions, + constants, etc.) and reconstructed with all necessary code in the remote sandbox. + Args: venv_name: Name of the virtual environment to execute in computer: The computer instance to use, or "default" to use the globally set default @@ -74,3 +95,396 @@ def sandboxed( return wrapper return decorator + + +def _extract_import_statement(name: str, module: ModuleType) -> str: + """Extract the original import statement for a module.""" + module_name = module.__name__ + + if name == module_name.split(".")[0]: + return f"import {module_name}" + else: + return f"import {module_name} as {name}" + + +def _is_third_party_module(module_name: str) -> bool: + """Check if a module is a third-party module.""" + stdlib_modules = set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set() + + if module_name in stdlib_modules: + return False + + try: + spec = importlib.util.find_spec(module_name) + if spec is None: + return False + + if spec.origin and ("site-packages" in spec.origin or "dist-packages" in spec.origin): + return True + + return False + except (ImportError, ModuleNotFoundError, ValueError): + return False + + +def _is_project_import(module_name: str) -> bool: + """Check if a module is a project-level import.""" + if module_name.startswith("__relative_import_level_"): + return True + + if module_name in sys.modules: + module = sys.modules[module_name] + if hasattr(module, "__file__") and module.__file__: + if "site-packages" not in module.__file__ and "dist-packages" not in module.__file__: + cwd = os.getcwd() + if module.__file__.startswith(cwd): + return True + + return False + + +def _categorize_module(module_name: str) -> str: + """Categorize a module as stdlib, third-party, or project.""" + if module_name.startswith("__relative_import_level_"): + return "project" + elif module_name in ( + set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set() + ): + return "stdlib" + elif _is_third_party_module(module_name): + return "third_party" + elif _is_project_import(module_name): + return "project" + else: + return "unknown" + + +class _DependencyVisitor(ast.NodeVisitor): + """AST visitor to extract imports and name references from a function.""" + + def __init__(self, function_name: str) -> None: + self.function_name = function_name + self.internal_imports: Set[str] = set() + self.internal_import_statements: List[str] = [] + self.name_references: Set[str] = set() + self.local_names: Set[str] = set() + self.inside_function = False + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if node.name == self.function_name and not self.inside_function: + self.inside_function = True + + for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs: + self.local_names.add(arg.arg) + if node.args.vararg: + self.local_names.add(node.args.vararg.arg) + if node.args.kwarg: + self.local_names.add(node.args.kwarg.arg) + + for child in node.body: + self.visit(child) + + self.inside_function = False + else: + if self.inside_function: + self.local_names.add(node.name) + for child in node.body: + self.visit(child) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self.visit_FunctionDef(node) # type: ignore + + def visit_Import(self, node: ast.Import) -> None: + if self.inside_function: + for alias in node.names: + module_name = alias.name.split(".")[0] + self.internal_imports.add(module_name) + imported_as = alias.asname if alias.asname else alias.name.split(".")[0] + self.local_names.add(imported_as) + self.internal_import_statements.append(ast.unparse(node)) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if self.inside_function: + if node.level == 0 and node.module: + module_name = node.module.split(".")[0] + self.internal_imports.add(module_name) + elif node.level > 0: + self.internal_imports.add(f"__relative_import_level_{node.level}__") + + for alias in node.names: + imported_as = alias.asname if alias.asname else alias.name + self.local_names.add(imported_as) + self.internal_import_statements.append(ast.unparse(node)) + + self.generic_visit(node) + + def visit_Name(self, node: ast.Name) -> None: + if self.inside_function: + if isinstance(node.ctx, ast.Load): + self.name_references.add(node.id) + elif isinstance(node.ctx, ast.Store): + self.local_names.add(node.id) + self.generic_visit(node) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + if self.inside_function: + self.local_names.add(node.name) + self.generic_visit(node) + + def visit_For(self, node: ast.For) -> None: + if self.inside_function and isinstance(node.target, ast.Name): + self.local_names.add(node.target.id) + self.generic_visit(node) + + def visit_comprehension(self, node: ast.comprehension) -> None: + if self.inside_function and isinstance(node.target, ast.Name): + self.local_names.add(node.target.id) + self.generic_visit(node) + + def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None: + if self.inside_function and node.name: + self.local_names.add(node.name) + self.generic_visit(node) + + def visit_With(self, node: ast.With) -> None: + if self.inside_function: + for item in node.items: + if item.optional_vars and isinstance(item.optional_vars, ast.Name): + self.local_names.add(item.optional_vars.id) + self.generic_visit(node) + + +def _traverse_and_collect_dependencies(func: FunctionType) -> DependencyInfo: + """ + Traverse a function and collect its dependencies. + + Returns a dict with: + - import_statements: List of import statements needed + - definitions: List of (name, obj) tuples for helper functions/classes/constants + """ + source = dedent(getsource(func)) + tree = ast.parse(source) + + visitor = _DependencyVisitor(func.__name__) + visitor.visit(tree) + + builtin_names = set(dir(builtins)) + external_refs = (visitor.name_references - visitor.local_names) - builtin_names + + import_statements = [] + definitions = [] + visited = set() + + # Include all internal import statements + import_statements.extend(visitor.internal_import_statements) + + # Analyze external references recursively + def analyze_object(obj: Any, name: str, depth: int = 0) -> None: + if depth > 20: + return + + obj_id = id(obj) + if obj_id in visited: + return + visited.add(obj_id) + + # Handle modules + if inspect.ismodule(obj): + import_stmt = _extract_import_statement(name, obj) + import_statements.append(import_stmt) + return + + # Handle functions and classes + if ( + inspect.isfunction(obj) + or inspect.isclass(obj) + or inspect.isbuiltin(obj) + or inspect.ismethod(obj) + ): + obj_module = getattr(obj, "__module__", None) + if obj_module: + base_module = obj_module.split(".")[0] + module_category = _categorize_module(base_module) + + # If from stdlib/third-party, just add import + if module_category in ("stdlib", "third_party"): + obj_name = getattr(obj, "__name__", name) + + # Check if object is accessible by 'name' (in globals or closures) + is_accessible = False + if name in func.__globals__ and func.__globals__[name] is obj: + is_accessible = True + elif func.__closure__ and hasattr(func, "__code__"): + freevars = func.__code__.co_freevars + for i, var_name in enumerate(freevars): + if var_name == name and i < len(func.__closure__): + try: + if func.__closure__[i].cell_contents is obj: + is_accessible = True + break + except (ValueError, AttributeError): + pass + + if is_accessible and name == obj_name: + # Direct import: from requests import get, from math import sqrt + import_statements.append(f"from {base_module} import {name}") + else: + # Module import: import requests + import_statements.append(f"import {base_module}") + return + + try: + obj_tree = ast.parse(dedent(getsource(obj))) + obj_visitor = _DependencyVisitor(obj.__name__) + obj_visitor.visit(obj_tree) + + obj_external_refs = obj_visitor.name_references - obj_visitor.local_names + obj_external_refs = obj_external_refs - builtin_names + + # Add internal imports from this object + import_statements.extend(obj_visitor.internal_import_statements) + + # Recursively analyze its dependencies + obj_globals = getattr(obj, "__globals__", None) + obj_closure = getattr(obj, "__closure__", None) + obj_code = getattr(obj, "__code__", None) + if obj_globals: + for ref_name in obj_external_refs: + ref_obj = None + + # Check globals first + if ref_name in obj_globals: + ref_obj = obj_globals[ref_name] + # Check closure variables using co_freevars + elif obj_closure and obj_code: + freevars = obj_code.co_freevars + for i, var_name in enumerate(freevars): + if var_name == ref_name and i < len(obj_closure): + try: + ref_obj = obj_closure[i].cell_contents + break + except (ValueError, AttributeError): + pass + + if ref_obj is not None: + analyze_object(ref_obj, ref_name, depth + 1) + + # Add this object to definitions + if not inspect.ismodule(obj): + ref_module = getattr(obj, "__module__", None) + if ref_module: + ref_base_module = ref_module.split(".")[0] + ref_category = _categorize_module(ref_base_module) + if ref_category not in ("stdlib", "third_party"): + definitions.append((name, obj)) + else: + definitions.append((name, obj)) + + except (OSError, TypeError): + pass + return + + if isinstance(obj, (int, float, str, bool, list, dict, tuple, set, frozenset, type(None))): + definitions.append((name, obj)) + + # Analyze all external references + for name in external_refs: + obj = None + + # First check globals + if name in func.__globals__: + obj = func.__globals__[name] + # Then check closure variables (sibling functions in enclosing scope) + elif func.__closure__ and func.__code__.co_freevars: + # Match closure variable names with cell contents + freevars = func.__code__.co_freevars + for i, var_name in enumerate(freevars): + if var_name == name and i < len(func.__closure__): + try: + obj = func.__closure__[i].cell_contents + break + except (ValueError, AttributeError): + # Cell is empty or doesn't have contents + pass + + if obj is not None: + analyze_object(obj, name) + + # Remove duplicate import statements + unique_imports = [] + seen = set() + for stmt in import_statements: + if stmt not in seen: + seen.add(stmt) + unique_imports.append(stmt) + + # Remove duplicate definitions + unique_definitions = [] + seen_names = set() + for name, obj in definitions: + if name not in seen_names: + seen_names.add(name) + unique_definitions.append((name, obj)) + + return { + "import_statements": unique_imports, + "definitions": unique_definitions, + } + + +def generate_source_code(func: FunctionType) -> str: + """ + Generate complete source code for a function with all dependencies. + + Args: + func: The function to generate source code for + + Returns: + Complete Python source code as a string + """ + + if func in _function_dependency_map: + info = _function_dependency_map[func] + else: + info = _traverse_and_collect_dependencies(func) + _function_dependency_map[func] = info + + # Build source code + parts = [] + + # 1. Add imports + if info["import_statements"]: + parts.append("\n".join(info["import_statements"])) + + # 2. Add definitions + for name, obj in info["definitions"]: + try: + if inspect.isfunction(obj): + source = dedent(getsource(obj)) + tree = ast.parse(source) + if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)): + tree.body[0].decorator_list = [] + source = ast.unparse(tree) + parts.append(source) + elif inspect.isclass(obj): + source = dedent(getsource(obj)) + tree = ast.parse(source) + if tree.body and isinstance(tree.body[0], ast.ClassDef): + tree.body[0].decorator_list = [] + source = ast.unparse(tree) + parts.append(source) + else: + parts.append(f"{name} = {repr(obj)}") + except (OSError, TypeError): + pass + + # 3. Add main function (without decorators) + func_source = dedent(getsource(func)) + tree = ast.parse(func_source) + if tree.body and isinstance(tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef)): + tree.body[0].decorator_list = [] + func_source = ast.unparse(tree) + parts.append(func_source) + + return "\n\n".join(parts) From e7be69d5cd1bc3c98388c446fc38344e4ecb686e Mon Sep 17 00:00:00 2001 From: "synacktra.work@gmail.com" Date: Wed, 17 Dec 2025 01:09:46 +0530 Subject: [PATCH 2/3] refactor(computer): update venv_exec to use generate source code utility --- libs/python/computer/computer/computer.py | 28 ++++------------------- 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/libs/python/computer/computer/computer.py b/libs/python/computer/computer/computer.py index b1f08cd2..8b08dace 100644 --- a/libs/python/computer/computer/computer.py +++ b/libs/python/computer/computer/computer.py @@ -1121,19 +1121,11 @@ class Computer: The result of the function execution, or raises any exception that occurred """ import base64 - import inspect import json import textwrap try: - # Get function source code using inspect.getsource - source = inspect.getsource(python_func) - # Remove common leading whitespace (dedent) - func_source = textwrap.dedent(source).strip() - - # Remove decorators - while func_source.lstrip().startswith("@"): - func_source = func_source.split("\n", 1)[1].strip() + func_source = helpers.generate_source_code(python_func) # Get function name for execution func_name = python_func.__name__ @@ -1259,16 +1251,12 @@ print(f"<<>>{{output_json}}<<>>") Uses a short launcher Python that spawns a detached child and exits immediately. """ import base64 - import inspect import json import textwrap import time as _time try: - source = inspect.getsource(python_func) - func_source = textwrap.dedent(source).strip() - while func_source.lstrip().startswith("@"): - func_source = func_source.split("\n", 1)[1].strip() + func_source = helpers.generate_source_code(python_func) func_name = python_func.__name__ args_json = json.dumps(args, default=str) kwargs_json = json.dumps(kwargs, default=str) @@ -1366,15 +1354,11 @@ print(p.pid) remote traceback context appended. """ import base64 - import inspect import json import textwrap try: - source = inspect.getsource(python_func) - func_source = textwrap.dedent(source).strip() - while func_source.lstrip().startswith("@"): - func_source = func_source.split("\n", 1)[1].strip() + func_source = helpers.generate_source_code(python_func) func_name = python_func.__name__ args_json = json.dumps(args, default=str) kwargs_json = json.dumps(kwargs, default=str) @@ -1482,16 +1466,12 @@ print(f"<<>>{{output_json}}<<>>") Uses a short launcher Python that spawns a detached child and exits immediately. """ import base64 - import inspect import json import textwrap import time as _time try: - source = inspect.getsource(python_func) - func_source = textwrap.dedent(source).strip() - while func_source.lstrip().startswith("@"): - func_source = func_source.split("\n", 1)[1].strip() + func_source = helpers.generate_source_code(python_func) func_name = python_func.__name__ args_json = json.dumps(args, default=str) kwargs_json = json.dumps(kwargs, default=str) From efcb370a554e4f4b2220efb13dd44c9f39c1d023 Mon Sep 17 00:00:00 2001 From: "synacktra.work@gmail.com" Date: Wed, 17 Dec 2025 01:12:38 +0530 Subject: [PATCH 3/3] test(helpers): add comprehensive teists for dependency collection - Test closure variable detection (helpers, constants, imports) - Test nested dependency resolution and ordering - Test class definitions and decorator removal - Test import style preservation - Test caching mechanism --- libs/python/computer/tests/test_helpers.py | 478 +++++++++++++++++++++ 1 file changed, 478 insertions(+) create mode 100644 libs/python/computer/tests/test_helpers.py diff --git a/libs/python/computer/tests/test_helpers.py b/libs/python/computer/tests/test_helpers.py new file mode 100644 index 00000000..1af455ef --- /dev/null +++ b/libs/python/computer/tests/test_helpers.py @@ -0,0 +1,478 @@ +from computer.helpers import generate_source_code + + +class TestSimpleCases: + """Test simple dependency cases""" + + def test_simple_function_no_dependencies(self): + """Test simple function with no dependencies""" + + def simple_func(): + return 42 + + code = generate_source_code(simple_func) + + assert "def simple_func():" in code + assert "return 42" in code + + def test_function_with_parameters(self): + """Test function with parameters""" + + def add(a, b): + return a + b + + code = generate_source_code(add) + + assert "def add(a, b):" in code + assert "return a + b" in code + + def test_function_with_docstring(self): + """Test function with docstring""" + + def documented(): + """This is a docstring""" + return True + + code = generate_source_code(documented) + + assert "def documented():" in code + assert "This is a docstring" in code + + +class TestImports: + """Test import handling""" + + def test_stdlib_import_inside_function(self): + """Test function with stdlib import inside""" + + def with_stdlib(): + import math + + return math.sqrt(16) + + code = generate_source_code(with_stdlib) + + assert "import math" in code + assert "def with_stdlib():" in code + assert "math.sqrt(16)" in code + + def test_multiple_stdlib_imports(self): + """Test function with multiple stdlib imports""" + + def with_multiple(): + import json + import math + + return json.dumps({"value": math.pi}) + + code = generate_source_code(with_multiple) + + assert "import json" in code + assert "import math" in code + + def test_from_import_stdlib(self): + """Test from X import Y style""" + + def with_from_import(): + from math import pi, sqrt + + return sqrt(16) + pi + + code = generate_source_code(with_from_import) + + assert "from math import" in code + assert "sqrt, pi" in code or ("sqrt" in code and "pi" in code) + + def test_third_party_global_import(self): + """Test globally imported third-party module""" + from requests import get + + def with_requests(): + return get + + code = generate_source_code(with_requests) + + assert "from requests import get" in code + assert "def with_requests():" in code + + +class TestHelperFunctions: + """Test helper function dependencies""" + + def test_single_helper_function(self): + """Test function with one helper""" + + def helper(x): + return x * 2 + + def main(): + return helper(21) + + code = generate_source_code(main) + + assert "def helper(x):" in code + assert "return x * 2" in code + assert "def main():" in code + assert "return helper(21)" in code + + def test_nested_helper_functions(self): + """Test function with nested helpers""" + + def helper_a(x): + return x + 1 + + def helper_b(x): + return helper_a(x) * 2 + + def main(): + return helper_b(10) + + code = generate_source_code(main) + + assert "def helper_a(x):" in code + assert "def helper_b(x):" in code + assert "def main():" in code + + def test_helper_dependency_ordering(self): + """Test that helpers are ordered correctly (dependencies first)""" + + def level_3(x): + return x + 1 + + def level_2(x): + return level_3(x) * 2 + + def level_1(x): + return level_2(x) + 5 + + code = generate_source_code(level_1) + + # Check ordering + pos_3 = code.find("def level_3") + pos_2 = code.find("def level_2") + pos_1 = code.find("def level_1") + + assert pos_3 < pos_2 < pos_1, "Functions not in correct dependency order" + + def test_helper_with_import(self): + """Test helper function that uses imports""" + import math + + def helper(x): + return math.sqrt(x) + + def main(): + return helper(16) + + code = generate_source_code(main) + + assert "import math" in code + assert "def helper(x):" in code + assert "def main():" in code + + +class TestGlobalConstants: + """Test global constant handling""" + + def test_simple_constant(self): + """Test function using simple constant""" + MY_CONSTANT = 42 + + def with_constant(): + return MY_CONSTANT * 2 + + code = generate_source_code(with_constant) + + assert "MY_CONSTANT = 42" in code + assert "def with_constant():" in code + + def test_string_constant(self): + """Test function using string constant""" + API_URL = "https://api.example.com" + + def with_string(): + return API_URL + + code = generate_source_code(with_string) + + assert "API_URL" in code + assert "https://api.example.com" in code + + def test_multiple_constants(self): + """Test function using multiple constants""" + BASE_URL = "https://example.com" + TIMEOUT = 30 + + def with_multiple(): + return f"{BASE_URL}?timeout={TIMEOUT}" + + code = generate_source_code(with_multiple) + + assert "BASE_URL" in code + assert "TIMEOUT" in code + + +class TestClassDefinitions: + """Test class definition handling""" + + def test_simple_class(self): + """Test function using simple class""" + + class Calculator: + def __init__(self, value): + self.value = value + + def add(self, x): + return self.value + x + + def with_class(): + calc = Calculator(10) + return calc.add(5) + + code = generate_source_code(with_class) + + assert "class Calculator:" in code + assert "__init__" in code + assert "def add" in code + assert "def with_class():" in code + + def test_class_with_methods(self): + """Test class with multiple methods""" + + class Math: + def add(self, a, b): + return a + b + + def multiply(self, a, b): + return a * b + + def use_math(): + m = Math() + return m.add(2, 3) + m.multiply(4, 5) + + code = generate_source_code(use_math) + + assert "class Math:" in code + assert "def add" in code + assert "def multiply" in code + + +class TestDecoratorRemoval: + """Test decorator removal""" + + def test_simple_decorator_removed(self): + """Test that decorators are removed""" + + def my_decorator(f): + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + + return wrapper + + @my_decorator + def decorated(): + return 42 + + code = generate_source_code(decorated) + + assert "@my_decorator" not in code + assert "def decorated():" in code + assert "return 42" in code + + def test_multiple_decorators_removed(self): + """Test that multiple decorators are removed""" + + def decorator1(f): + return f + + def decorator2(f): + return f + + @decorator1 + @decorator2 + def multi_decorated(): + return True + + code = generate_source_code(multi_decorated) + + assert "@decorator1" not in code + assert "@decorator2" not in code + assert "def multi_decorated():" in code + + +class TestComplexScenarios: + """Test complex real-world scenarios""" + + def test_api_client_pattern(self): + """Test typical API client pattern""" + import json + + BASE_URL = "https://api.example.com" + + def make_request(endpoint): + return f"{BASE_URL}/{endpoint}" + + def parse_response(response): + return json.loads(response) + + def get_user(user_id): + url = make_request(f"users/{user_id}") + response = f'{{"id": {user_id}}}' + return parse_response(response) + + code = generate_source_code(get_user) + + assert "import json" in code + assert "BASE_URL" in code + assert "def make_request" in code + assert "def parse_response" in code + assert "def get_user" in code + + def test_data_processing_pipeline(self): + """Test data processing pipeline""" + + def validate_data(data): + return [x for x in data if x > 0] + + def transform_data(data): + return [x * 2 for x in data] + + def process_pipeline(raw_data): + valid = validate_data(raw_data) + return transform_data(valid) + + code = generate_source_code(process_pipeline) + + assert "def validate_data" in code + assert "def transform_data" in code + assert "def process_pipeline" in code + + # Check ordering + validate_pos = code.find("def validate_data") + transform_pos = code.find("def transform_data") + pipeline_pos = code.find("def process_pipeline") + + assert validate_pos < pipeline_pos + assert transform_pos < pipeline_pos + + +class TestEdgeCases: + """Test edge cases""" + + def test_function_with_default_args(self): + """Test function with default arguments""" + + def with_defaults(a, b=10, c=20): + return a + b + c + + code = generate_source_code(with_defaults) + + assert "def with_defaults(a, b=10, c=20):" in code + + def test_function_with_kwargs(self): + """Test function with *args and **kwargs""" + + def with_varargs(*args, **kwargs): + return sum(args) + + code = generate_source_code(with_varargs) + + assert "def with_varargs(*args, **kwargs):" in code + + def test_async_function(self): + """Test async function""" + + async def async_func(): + return 42 + + code = generate_source_code(async_func) + + assert "async def async_func():" in code + assert "return 42" in code + + def test_lambda_works_in_files(self): + """Test that lambda functions work when defined in files""" + lambda_func = lambda x: x * 2 + + # Lambda functions defined in files CAN have their source extracted + code = generate_source_code(lambda_func) + + # Should include the lambda + assert "lambda" in code.lower() or "lambda_func" in code + + +class TestCaching: + """Test caching mechanism""" + + def test_cache_returns_same_result(self): + """Test that cache returns identical results""" + + def cached_func(): + return 42 + + code1 = generate_source_code(cached_func) + code2 = generate_source_code(cached_func) + + assert code1 == code2 + + def test_different_functions_different_cache(self): + """Test that different functions have different cache entries""" + + def func1(): + return 1 + + def func2(): + return 2 + + code1 = generate_source_code(func1) + code2 = generate_source_code(func2) + + assert code1 != code2 + assert "return 1" in code1 + assert "return 2" in code2 + + +class TestBuiltinHandling: + """Test that builtins are handled correctly""" + + def test_builtin_not_included(self): + """Test that builtin functions are not included as dependencies""" + + def uses_builtins(): + data = list(range(10)) + return len(data) + + code = generate_source_code(uses_builtins) + + # Should not include definitions for list, range, len + assert "def list" not in code + assert "def range" not in code + assert "def len" not in code + assert "def uses_builtins():" in code + + +class TestImportStylePreservation: + """Test that import styles are preserved""" + + def test_from_import_preserved(self): + """Test from X import Y is preserved""" + from math import sqrt + + def use_sqrt(): + return sqrt(16) + + code = generate_source_code(use_sqrt) + + assert "from math import sqrt" in code + + def test_import_as_preserved(self): + """Test import X as Y is preserved""" + import json as j + + def use_json(): + return j.dumps({"key": "value"}) + + code = generate_source_code(use_json) + + # The import style should be preserved + assert "import json" in code