diff --git a/libs/python/agent/agent/computers/custom.py b/libs/python/agent/agent/computers/custom.py
index 19079bad..7ee027fd 100644
--- a/libs/python/agent/agent/computers/custom.py
+++ b/libs/python/agent/agent/computers/custom.py
@@ -29,6 +29,29 @@ class CustomComputerHandler(ComputerHandler):
         self.functions = functions
         self._last_screenshot_size: Optional[tuple[int, int]] = None
 
+    async def _call_function(self, func, *args, **kwargs):
+        """
+        Call a function, handling both async and sync functions.
+
+        Args:
+            func: The function to call; a non-callable is returned as-is
+            *args: Positional arguments to pass to the function
+            **kwargs: Keyword arguments to pass to the function
+
+        Returns:
+            The result of the function call (awaited if it is awaitable)
+        """
+        import inspect
+
+        if not callable(func):
+            return func
+
+        # Decide on the *result*, not the callable: lambdas, partials and
+        # objects with an async __call__ are not coroutine functions, yet
+        # they still hand back an awaitable that must be awaited.
+        result = func(*args, **kwargs)
+        return await result if inspect.isawaitable(result) else result
+
     async def _get_value(self, attribute: str):
         """
         Get value for an attribute, checking both 'get_{attribute}' and '{attribute}' keys.
@@ -42,13 +65,11 @@ class CustomComputerHandler(ComputerHandler):
         # Check for 'get_{attribute}' first
         get_key = f"get_{attribute}"
         if get_key in self.functions:
-            value = self.functions[get_key]
-            return await value() if callable(value) else value
+            return await self._call_function(self.functions[get_key])
 
         # Check for '{attribute}'
         if attribute in self.functions:
-            value = self.functions[attribute]
-            return await value() if callable(value) else value
+            return await self._call_function(self.functions[attribute])
 
         return None
 
@@ -81,13 +102,18 @@ class CustomComputerHandler(ComputerHandler):
     async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
         """Get the current environment type."""
         result = await self._get_value('environment')
-        return result if result is not None else "linux"
+        if result is None:
+            return "linux"
+        # Raise instead of assert: assertions are stripped under `python -O`.
+        if result not in ("windows", "mac", "linux", "browser"):
+            raise ValueError(f"Unsupported environment: {result!r}")
+        return result  # type: ignore
 
     async def get_dimensions(self) -> tuple[int, int]:
         """Get screen dimensions as (width, height)."""
         result = await self._get_value('dimensions')
         if result is not None:
-            return result
+            return result  # type: ignore
 
         # Fallback: use last screenshot size if available
         if not self._last_screenshot_size:
@@ -98,8 +124,8 @@ class CustomComputerHandler(ComputerHandler):
 
     async def screenshot(self) -> str:
         """Take a screenshot and return as base64 string."""
-        result = await self.functions['screenshot']()
-        b64_str = self._to_b64_str(result)
+        result = await self._call_function(self.functions['screenshot'])
+        b64_str = self._to_b64_str(result)  # type: ignore
 
         # Try to extract dimensions for fallback use
         try:
@@ -118,31 +144,31 @@ class CustomComputerHandler(ComputerHandler):
     async def click(self, x: int, y: int, button: str = "left") -> None:
         """Click at coordinates with specified button."""
         if 'click' in self.functions:
-            await self.functions['click'](x, y, button)
+            await self._call_function(self.functions['click'], x, y, button)
         # No-op if not implemented
 
     async def double_click(self, x: int, y: int) -> None:
         """Double click at coordinates."""
         if 'double_click' in self.functions:
-            await self.functions['double_click'](x, y)
+            await self._call_function(self.functions['double_click'], x, y)
         # No-op if not implemented
 
     async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
         """Scroll at coordinates with specified scroll amounts."""
         if 'scroll' in self.functions:
-            await self.functions['scroll'](x, y, scroll_x, scroll_y)
+            await self._call_function(self.functions['scroll'], x, y, scroll_x, scroll_y)
         # No-op if not implemented
 
     async def type(self, text: str) -> None:
         """Type text."""
         if 'type' in self.functions:
-            await self.functions['type'](text)
+            await self._call_function(self.functions['type'], text)
         # No-op if not implemented
 
     async def wait(self, ms: int = 1000) -> None:
         """Wait for specified milliseconds."""
         if 'wait' in self.functions:
-            await self.functions['wait'](ms)
+            await self._call_function(self.functions['wait'], ms)
         else:
             # Default implementation
             import asyncio
@@ -151,35 +177,35 @@ class CustomComputerHandler(ComputerHandler):
     async def move(self, x: int, y: int) -> None:
         """Move cursor to coordinates."""
         if 'move' in self.functions:
-            await self.functions['move'](x, y)
+            await self._call_function(self.functions['move'], x, y)
         # No-op if not implemented
 
     async def keypress(self, keys: Union[List[str], str]) -> None:
         """Press key combination."""
         if 'keypress' in self.functions:
-            await self.functions['keypress'](keys)
+            await self._call_function(self.functions['keypress'], keys)
         # No-op if not implemented
 
     async def drag(self, path: List[Dict[str, int]]) -> None:
         """Drag along specified path."""
         if 'drag' in self.functions:
-            await self.functions['drag'](path)
+            await self._call_function(self.functions['drag'], path)
         # No-op if not implemented
 
     async def get_current_url(self) -> str:
         """Get current URL (for browser environments)."""
-        if 'get_current_url' in self.functions:
-            return await self.functions['get_current_url']()
-        return ""  # Default fallback
+        # _get_value accepts a 'get_current_url' or a 'current_url' entry.
+        result = await self._get_value('current_url')
+        return "" if result is None else result  # type: ignore
 
     async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
         """Left mouse down at coordinates."""
         if 'left_mouse_down' in self.functions:
-            await self.functions['left_mouse_down'](x, y)
+            await self._call_function(self.functions['left_mouse_down'], x, y)
         # No-op if not implemented
 
     async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
         """Left mouse up at coordinates."""
         if 'left_mouse_up' in self.functions:
-            await self.functions['left_mouse_up'](x, y)
+            await self._call_function(self.functions['left_mouse_up'], x, y)
         # No-op if not implemented