Merge pull request #270 from trycua/feat/computer/cloud

Add Cloud Provider Support
This commit is contained in:
f-trycua
2025-05-27 05:30:40 +02:00
committed by GitHub
12 changed files with 418 additions and 184 deletions

View File

@@ -326,4 +326,4 @@ Thank you to all our supporters!
<!-- markdownlint-restore -->
<!-- prettier-ignore-end -->
<!-- ALL-CONTRIBUTORS-LIST:END -->
<!-- ALL-CONTRIBUTORS-LIST:END -->

View File

@@ -42,11 +42,22 @@ async def main():
ephemeral=False,
)
# computer = Computer(
# os_type="linux",
# api_key=os.getenv("CUA_API_KEY"),
# name=os.getenv("CONTAINER_NAME"),
# provider_type=VMProviderType.CLOUD,
# )
try:
# Run the computer with default parameters
await computer.run()
await computer.interface.hotkey("command", "space")
screenshot = await computer.interface.screenshot()
with open(Path("~/cua/examples/screenshot.png").expanduser(), "wb") as f:
f.write(screenshot)
# await computer.interface.hotkey("command", "space")
# res = await computer.interface.run_command("touch ./Downloads/empty_file")
# print(f"Run command result: {res}")

View File

@@ -290,7 +290,7 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
model_name_to_use = cleaned_model_name
# agent_loop remains AgentLoop.OMNI
elif agent_loop == AgentLoop.UITARS:
# For UITARS, use MLXVLM provider for the MLX models, OAICOMPAT for custom
# For UITARS, use MLXVLM for mlx-community models, OAICOMPAT for custom
if model_name == "Custom model (OpenAI compatible API)":
provider = LLMProvider.OAICOMPAT
model_name_to_use = "tgi"
@@ -333,12 +333,25 @@ def get_ollama_models() -> List[str]:
logging.error(f"Error getting Ollama models: {e}")
return []
def create_computer_instance(verbosity: int = logging.INFO) -> Computer:
def create_computer_instance(
verbosity: int = logging.INFO,
os_type: str = "macos",
provider_type: str = "lume",
name: Optional[str] = None,
api_key: Optional[str] = None
) -> Computer:
"""Create or get the global Computer instance."""
global global_computer
if global_computer is None:
global_computer = Computer(verbosity=verbosity)
global_computer = Computer(
verbosity=verbosity,
os_type=os_type,
provider_type=provider_type,
name=name if name else "",
api_key=api_key
)
return global_computer
@@ -353,12 +366,22 @@ def create_agent(
verbosity: int = logging.INFO,
use_oaicompat: bool = False,
provider_base_url: Optional[str] = None,
computer_os: str = "macos",
computer_provider: str = "lume",
computer_name: Optional[str] = None,
computer_api_key: Optional[str] = None,
) -> ComputerAgent:
"""Create or update the global agent with the specified parameters."""
global global_agent
# Create the computer if not already done
computer = create_computer_instance(verbosity=verbosity)
computer = create_computer_instance(
verbosity=verbosity,
os_type=computer_os,
provider_type=computer_provider,
name=computer_name,
api_key=computer_api_key
)
# Get API key from environment if not provided
if api_key is None:
@@ -401,6 +424,7 @@ def create_agent(
return global_agent
def create_gradio_ui(
provider_name: str = "openai",
model_name: str = "gpt-4o",
@@ -439,6 +463,9 @@ def create_gradio_ui(
# Check if API keys are available
has_openai_key = bool(openai_api_key)
has_anthropic_key = bool(anthropic_api_key)
print("has_openai_key", has_openai_key)
print("has_anthropic_key", has_anthropic_key)
# Get Ollama models for OMNI
ollama_models = get_ollama_models()
@@ -473,7 +500,7 @@ def create_gradio_ui(
elif initial_loop == "ANTHROPIC":
initial_model = anthropic_models[0] if anthropic_models else "No models available"
else: # OMNI
initial_model = omni_models[0] if omni_models else "No models available"
initial_model = omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
if "Custom model (OpenAI compatible API)" in available_models_for_loop:
initial_model = (
"Custom model (OpenAI compatible API)" # Default to custom if available and no other default fits
@@ -494,7 +521,7 @@ def create_gradio_ui(
]
# Function to generate Python code based on configuration and tasks
def generate_python_code(agent_loop_choice, provider, model_name, tasks, provider_url, recent_images=3, save_trajectory=True):
def generate_python_code(agent_loop_choice, provider, model_name, tasks, provider_url, recent_images=3, save_trajectory=True, computer_os="macos", computer_provider="lume", container_name="", cua_cloud_api_key=""):
"""Generate Python code for the current configuration and tasks.
Args:
@@ -505,6 +532,10 @@ def create_gradio_ui(
provider_url: The provider base URL for OAICOMPAT providers
recent_images: Number of recent images to keep in context
save_trajectory: Whether to save the agent trajectory
computer_os: Operating system type for the computer
computer_provider: Provider type for the computer
container_name: Optional VM name
cua_cloud_api_key: Optional CUA Cloud API key
Returns:
Formatted Python code as a string
@@ -515,13 +546,29 @@ def create_gradio_ui(
if task and task.strip():
tasks_str += f' "{task}",\n'
# Create the Python code template
# Create the Python code template with computer configuration
computer_args = []
if computer_os != "macos":
computer_args.append(f'os_type="{computer_os}"')
if computer_provider != "lume":
computer_args.append(f'provider_type="{computer_provider}"')
if container_name:
computer_args.append(f'name="{container_name}"')
if cua_cloud_api_key:
computer_args.append(f'api_key="{cua_cloud_api_key}"')
computer_args_str = ", ".join(computer_args)
if computer_args_str:
computer_args_str = f"({computer_args_str})"
else:
computer_args_str = "()"
code = f'''import asyncio
from computer import Computer
from agent import ComputerAgent, LLM, AgentLoop, LLMProvider
async def main():
async with Computer() as macos_computer:
async with Computer{computer_args_str} as macos_computer:
agent = ComputerAgent(
computer=macos_computer,
loop=AgentLoop.{agent_loop_choice},
@@ -660,12 +707,49 @@ if __name__ == "__main__":
LLMProvider.OPENAI,
"gpt-4o",
[],
"https://openrouter.ai/api/v1"
"https://openrouter.ai/api/v1",
3, # recent_images default
True, # save_trajectory default
"macos",
"lume",
"",
""
),
interactive=False,
)
with gr.Accordion("Configuration", open=True):
with gr.Accordion("Computer Configuration", open=True):
# Computer configuration options
computer_os = gr.Radio(
choices=["macos", "linux"],
label="Operating System",
value="macos",
info="Select the operating system for the computer",
)
computer_provider = gr.Radio(
choices=["cloud", "lume"],
label="Provider",
value="lume",
info="Select the computer provider",
)
container_name = gr.Textbox(
label="Container Name",
placeholder="Enter container name (optional)",
value="",
info="Optional name for the container",
)
cua_cloud_api_key = gr.Textbox(
label="CUA Cloud API Key",
placeholder="Enter your CUA Cloud API key",
value="",
type="password",
info="Required for cloud provider",
)
with gr.Accordion("Agent Configuration", open=True):
# Configuration options
agent_loop = gr.Dropdown(
choices=["OPENAI", "ANTHROPIC", "OMNI", "UITARS"],
@@ -986,6 +1070,10 @@ if __name__ == "__main__":
custom_api_key=None,
openai_key_input=None,
anthropic_key_input=None,
computer_os="macos",
computer_provider="lume",
container_name="",
cua_cloud_api_key="",
):
if not history:
yield history
@@ -1092,6 +1180,10 @@ if __name__ == "__main__":
"provider_base_url": custom_url_value,
"save_trajectory": save_traj,
"recent_images": recent_imgs,
"computer_os": computer_os,
"computer_provider": computer_provider,
"container_name": container_name,
"cua_cloud_api_key": cua_cloud_api_key,
}
save_settings(current_settings)
# --- End Save Settings ---
@@ -1109,6 +1201,10 @@ if __name__ == "__main__":
use_oaicompat=is_oaicompat, # Set flag if custom model was selected
# Pass custom URL only if custom model was selected
provider_base_url=custom_url_value if is_oaicompat else None,
computer_os=computer_os,
computer_provider=computer_provider,
computer_name=container_name,
computer_api_key=cua_cloud_api_key,
verbosity=logging.DEBUG, # Added verbosity here
)
@@ -1235,6 +1331,10 @@ if __name__ == "__main__":
provider_api_key,
openai_api_key_input,
anthropic_api_key_input,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
],
outputs=[chatbot_history],
queue=True,
@@ -1253,82 +1353,20 @@ if __name__ == "__main__":
# Function to update the code display based on configuration and chat history
def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, provider_base_url, recent_images_val, save_trajectory_val):
def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, provider_base_url, recent_images_val, save_trajectory_val, computer_os, computer_provider, container_name, cua_cloud_api_key):
# Extract messages from chat history
messages = []
if chat_history:
for msg in chat_history:
if msg.get("role") == "user":
if isinstance(msg, dict) and msg.get("role") == "user":
messages.append(msg.get("content", ""))
# Determine if this is a custom model selection and which type
is_custom_openai_api = model_choice_val == "Custom model (OpenAI compatible API)"
is_custom_ollama = model_choice_val == "Custom model (ollama)"
is_custom_model_selected = is_custom_openai_api or is_custom_ollama
# Determine provider and model based on current selection
provider, model_name, _ = get_provider_and_model(
model_choice_val or custom_model_val or "gpt-4o",
agent_loop
)
# Determine provider and model name based on agent loop
if agent_loop == "OPENAI":
# For OPENAI loop, always use OPENAI provider with computer-use-preview
provider = LLMProvider.OPENAI
model_name = "computer-use-preview"
elif agent_loop == "ANTHROPIC":
# For ANTHROPIC loop, always use ANTHROPIC provider
provider = LLMProvider.ANTHROPIC
# Extract model name from the UI string
if model_choice_val.startswith("Anthropic: Claude "):
# Extract the model name based on the UI string
model_parts = model_choice_val.replace("Anthropic: Claude ", "").split(" (")
version = model_parts[0] # e.g., "3.7 Sonnet"
date = model_parts[1].replace(")", "") if len(model_parts) > 1 else "" # e.g., "20250219"
# Format as claude-3-7-sonnet-20250219 or claude-3-5-sonnet-20240620
version = version.replace(".", "-").replace(" ", "-").lower()
model_name = f"claude-{version}-{date}"
else:
# Use the model_choice_val directly if it doesn't match the expected format
model_name = model_choice_val
elif agent_loop == "UITARS":
# For UITARS, use MLXVLM for mlx-community models, OAICOMPAT for custom
if model_choice_val == "Custom model (OpenAI compatible API)":
provider = LLMProvider.OAICOMPAT
model_name = custom_model_val
else:
provider = LLMProvider.MLXVLM
model_name = model_choice_val
elif agent_loop == "OMNI":
# For OMNI, provider can be OPENAI, ANTHROPIC, OLLAMA, or OAICOMPAT
if is_custom_openai_api:
provider = LLMProvider.OAICOMPAT
model_name = custom_model_val
elif is_custom_ollama:
provider = LLMProvider.OLLAMA
model_name = custom_model_val
elif model_choice_val.startswith("OMNI: OpenAI "):
provider = LLMProvider.OPENAI
# Extract model name from UI string (e.g., "OMNI: OpenAI GPT-4o" -> "gpt-4o")
model_name = model_choice_val.replace("OMNI: OpenAI ", "").lower().replace(" ", "-")
elif model_choice_val.startswith("OMNI: Claude "):
provider = LLMProvider.ANTHROPIC
# Extract model name from UI string (similar to ANTHROPIC loop case)
model_parts = model_choice_val.replace("OMNI: Claude ", "").split(" (")
version = model_parts[0] # e.g., "3.7 Sonnet"
date = model_parts[1].replace(")", "") if len(model_parts) > 1 else "" # e.g., "20250219"
# Format as claude-3-7-sonnet-20250219 or claude-3-5-sonnet-20240620
version = version.replace(".", "-").replace(" ", "-").lower()
model_name = f"claude-{version}-{date}"
elif model_choice_val.startswith("OMNI: Ollama "):
provider = LLMProvider.OLLAMA
# Extract model name from UI string (e.g., "OMNI: Ollama llama3" -> "llama3")
model_name = model_choice_val.replace("OMNI: Ollama ", "")
else:
# Fallback to get_provider_and_model for any other cases
provider, model_name, _ = get_provider_and_model(model_choice_val, agent_loop)
else:
# Fallback for any other agent loop
provider, model_name, _ = get_provider_and_model(model_choice_val, agent_loop)
# Generate and return the code
return generate_python_code(
agent_loop,
provider,
@@ -1336,38 +1374,62 @@ if __name__ == "__main__":
messages,
provider_base_url,
recent_images_val,
save_trajectory_val
save_trajectory_val,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key
)
# Update code display when configuration changes
agent_loop.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
model_choice.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
custom_model.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
chatbot_history.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
recent_images.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
save_trajectory.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
computer_os.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
computer_provider.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
container_name.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)
cua_cloud_api_key.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
outputs=[code_display]
)

View File

@@ -8,11 +8,11 @@ import traceback
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp
# Set up logging with more detail
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configure WebSocket with larger message size
@@ -48,6 +48,112 @@ manager = ConnectionManager()
async def websocket_endpoint(websocket: WebSocket):
# WebSocket message size is configured at the app or endpoint level, not on the instance
await manager.connect(websocket)
# Check if CONTAINER_NAME is set (indicating cloud provider)
container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication handshake
if container_name:
try:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {container_name}. Waiting for authentication...")
# Wait for authentication message
auth_data = await websocket.receive_json()
# Validate auth message format
if auth_data.get("command") != "authenticate":
await websocket.send_json({
"success": False,
"error": "First message must be authentication"
})
await websocket.close()
manager.disconnect(websocket)
return
# Extract credentials
client_api_key = auth_data.get("params", {}).get("api_key")
client_container_name = auth_data.get("params", {}).get("container_name")
# Layer 1: VM Identity Verification
if client_container_name != container_name:
logger.warning(f"VM name mismatch. Expected: {container_name}, Got: {client_container_name}")
await websocket.send_json({
"success": False,
"error": "VM name mismatch"
})
await websocket.close()
manager.disconnect(websocket)
return
# Layer 2: API Key Validation with TryCUA API
if not client_api_key:
await websocket.send_json({
"success": False,
"error": "API key required"
})
await websocket.close()
manager.disconnect(websocket)
return
# Validate with TryCUA API
try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {client_api_key}"
}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
headers=headers,
) as resp:
if resp.status != 200:
error_msg = await resp.text()
logger.warning(f"API validation failed: {error_msg}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
# If we get a 200 response with VNC URL, the VM exists and user has access
vnc_url = (await resp.text()).strip()
if not vnc_url:
logger.warning(f"No VNC URL returned for VM: {container_name}")
await websocket.send_json({
"success": False,
"error": "VM not found"
})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {container_name}")
await websocket.send_json({
"success": True,
"message": "Authenticated"
})
except Exception as e:
logger.error(f"Error validating with TryCUA API: {e}")
await websocket.send_json({
"success": False,
"error": "Authentication service unavailable"
})
await websocket.close()
manager.disconnect(websocket)
return
except Exception as e:
logger.error(f"Authentication error: {e}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
# Map commands to appropriate handler methods
handlers = {

View File

@@ -17,7 +17,8 @@ dependencies = [
"uvicorn[standard]>=0.27.0",
"pydantic>=2.0.0",
"pyautogui>=0.9.54",
"pillow>=10.2.0"
"pillow>=10.2.0",
"aiohttp>=3.9.1"
]
[project.optional-dependencies]

View File

@@ -38,7 +38,8 @@ class Computer:
noVNC_port: Optional[int] = 8006,
host: str = os.environ.get("PYLUME_HOST", "localhost"),
storage: Optional[str] = None,
ephemeral: bool = False
ephemeral: bool = False,
api_key: Optional[str] = None
):
"""Initialize a new Computer instance.
@@ -77,6 +78,8 @@ class Computer:
self.os_type = os_type
self.provider_type = provider_type
self.ephemeral = ephemeral
self.api_key = api_key
# The default is currently to use non-ephemeral storage
if storage and ephemeral and storage != "ephemeral":
@@ -256,9 +259,7 @@ class Computer:
elif self.provider_type == VMProviderType.CLOUD:
self.config.vm_provider = VMProviderFactory.create_provider(
self.provider_type,
port=port,
host=host,
storage=storage,
api_key=self.api_key,
verbose=verbose,
)
else:
@@ -392,12 +393,25 @@ class Computer:
self.logger.info(f"Initializing interface for {self.os_type} at {ip_address}")
from .interface.base import BaseComputerInterface
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type, ip_address=ip_address # type: ignore[arg-type]
),
)
# Pass authentication credentials if using cloud provider
if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name
),
)
else:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address
),
)
# Wait for the WebSocket interface to be ready
self.logger.info("Connecting to WebSocket interface...")
@@ -492,6 +506,11 @@ class Computer:
# Call the provider's get_ip method which will wait indefinitely
storage_param = "ephemeral" if self.ephemeral else self.storage
# Log the image being used
self.logger.info(f"Running VM using image: {self.image}")
# Call provider.get_ip with explicit image parameter
ip = await self.config.vm_provider.get_ip(
name=self.config.name,
storage=storage_param,

View File

@@ -8,17 +8,21 @@ from ..logger import Logger, LogLevel
class BaseComputerInterface(ABC):
"""Base class for computer control interfaces."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
"""Initialize interface.
Args:
ip_address: IP address of the computer to control
username: Username for authentication
password: Password for authentication
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication
"""
self.ip_address = ip_address
self.username = username
self.password = password
self.api_key = api_key
self.vm_name = vm_name
self.logger = Logger("cua.interface", LogLevel.NORMAL)
@abstractmethod

View File

@@ -1,6 +1,6 @@
"""Factory for creating computer interfaces."""
from typing import Literal
from typing import Literal, Optional
from .base import BaseComputerInterface
class InterfaceFactory:
@@ -9,13 +9,17 @@ class InterfaceFactory:
@staticmethod
def create_interface_for_os(
os: Literal['macos', 'linux'],
ip_address: str
ip_address: str,
api_key: Optional[str] = None,
vm_name: Optional[str] = None
) -> BaseComputerInterface:
"""Create an interface for the specified OS.
Args:
os: Operating system type ('macos' or 'linux')
ip_address: IP address of the computer to control
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication
Returns:
BaseComputerInterface: The appropriate interface for the OS
@@ -28,8 +32,8 @@ class InterfaceFactory:
from .linux import LinuxComputerInterface
if os == 'macos':
return MacOSComputerInterface(ip_address)
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
elif os == 'linux':
return LinuxComputerInterface(ip_address)
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
else:
raise ValueError(f"Unsupported OS type: {os}")

View File

@@ -15,8 +15,8 @@ from .models import Key, KeyType
class LinuxComputerInterface(BaseComputerInterface):
"""Interface for Linux."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
super().__init__(ip_address, username, password)
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
self._reconnect_task = None
self._closed = False
@@ -37,7 +37,8 @@ class LinuxComputerInterface(BaseComputerInterface):
Returns:
WebSocket URI for the Computer API Server
"""
return f"ws://{self.ip_address}:8000/ws"
protocol = "wss" if self.api_key else "ws"
return f"{protocol}://{self.ip_address}:8000/ws"
async def _keep_alive(self):
"""Keep the WebSocket connection alive with automatic reconnection."""
@@ -86,6 +87,32 @@ class LinuxComputerInterface(BaseComputerInterface):
timeout=30,
)
self.logger.info("WebSocket connection established")
# If api_key and vm_name are provided, perform authentication handshake
if self.api_key and self.vm_name:
self.logger.info("Performing authentication handshake...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": self.api_key,
"vm_name": self.vm_name
}
}
await self._ws.send(json.dumps(auth_message))
# Wait for authentication response
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
auth_result = json.loads(auth_response)
if not auth_result.get("success"):
error_msg = auth_result.get("error", "Authentication failed")
self.logger.error(f"Authentication failed: {error_msg}")
await self._ws.close()
self._ws = None
raise ConnectionError(f"Authentication failed: {error_msg}")
self.logger.info("Authentication successful")
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
self._last_ping = time.time()
retry_count = 0 # Reset retry count on successful connection

View File

@@ -13,10 +13,10 @@ from .models import Key, KeyType
class MacOSComputerInterface(BaseComputerInterface):
"""Interface for MacOS."""
"""Interface for macOS."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
super().__init__(ip_address, username, password)
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
self._reconnect_task = None
self._closed = False
@@ -27,7 +27,7 @@ class MacOSComputerInterface(BaseComputerInterface):
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
self._log_connection_attempts = True # Flag to control connection attempt logging
# Set logger name for MacOS interface
# Set logger name for macOS interface
self.logger = Logger("cua.interface.macos", LogLevel.NORMAL)
@property
@@ -37,7 +37,8 @@ class MacOSComputerInterface(BaseComputerInterface):
Returns:
WebSocket URI for the Computer API Server
"""
return f"ws://{self.ip_address}:8000/ws"
protocol = "wss" if self.api_key else "ws"
return f"{protocol}://{self.ip_address}:8000/ws"
async def _keep_alive(self):
"""Keep the WebSocket connection alive with automatic reconnection."""
@@ -86,6 +87,32 @@ class MacOSComputerInterface(BaseComputerInterface):
timeout=30,
)
self.logger.info("WebSocket connection established")
# If api_key and vm_name are provided, perform authentication handshake
if self.api_key and self.vm_name:
self.logger.info("Performing authentication handshake...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": self.api_key,
"vm_name": self.vm_name
}
}
await self._ws.send(json.dumps(auth_message))
# Wait for authentication response
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
auth_result = json.loads(auth_response)
if not auth_result.get("success"):
error_msg = auth_result.get("error", "Authentication failed")
self.logger.error(f"Authentication failed: {error_msg}")
await self._ws.close()
self._ws = None
raise ConnectionError(f"Authentication failed: {error_msg}")
self.logger.info("Authentication successful")
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
self._last_ping = time.time()
retry_count = 0 # Reset retry count on successful connection

View File

@@ -11,90 +11,65 @@ from ..base import BaseVMProvider, VMProviderType
# Setup logging
logger = logging.getLogger(__name__)
import asyncio
import aiohttp
from urllib.parse import urlparse
class CloudProvider(BaseVMProvider):
"""Cloud VM Provider stub implementation.
This is a placeholder for a future cloud VM provider implementation.
"""
"""Cloud VM Provider implementation."""
def __init__(
self,
host: str = "localhost",
port: int = 7777,
storage: Optional[str] = None,
self,
api_key: str,
verbose: bool = False,
**kwargs,
):
"""Initialize the Cloud provider.
"""
Args:
host: Host to use for API connections (default: localhost)
port: Port for the API server (default: 7777)
storage: Path to store VM data
api_key: API key for authentication
name: Name of the VM
verbose: Enable verbose logging
"""
self.host = host
self.port = port
self.storage = storage
assert api_key, "api_key required for CloudProvider"
self.api_key = api_key
self.verbose = verbose
logger.warning("CloudProvider is not yet implemented")
@property
def provider_type(self) -> VMProviderType:
"""Get the provider type."""
return VMProviderType.CLOUD
async def __aenter__(self):
"""Enter async context manager."""
logger.debug("Entering CloudProvider context")
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit async context manager."""
logger.debug("Exiting CloudProvider context")
pass
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM information by name."""
logger.warning("CloudProvider.get_vm is not implemented")
return {
"name": name,
"status": "unavailable",
"message": "CloudProvider is not implemented"
}
"""Get VM VNC URL by name using the cloud API."""
return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}
async def list_vms(self) -> List[Dict[str, Any]]:
"""List all available VMs."""
logger.warning("CloudProvider.list_vms is not implemented")
return []
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Run a VM with the given options."""
logger.warning("CloudProvider.run_vm is not implemented")
return {
"name": name,
"status": "unavailable",
"message": "CloudProvider is not implemented"
}
return {"name": name, "status": "unavailable", "message": "CloudProvider is not implemented"}
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Stop a running VM."""
logger.warning("CloudProvider.stop_vm is not implemented")
return {
"name": name,
"status": "stopped",
"message": "CloudProvider is not implemented"
}
return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Update VM configuration."""
logger.warning("CloudProvider.update_vm is not implemented")
return {
"name": name,
"status": "unchanged",
"message": "CloudProvider is not implemented"
}
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM."""
logger.warning("CloudProvider.get_ip is not implemented")
raise NotImplementedError("CloudProvider.get_ip is not implemented")
return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""
Return the VM's IP address as '{container_name}.containers.cloud.trycua.com'.
Uses the provided 'name' argument (the VM name requested by the caller),
falling back to self.name only if 'name' is None.
Retries up to 3 times with retry_delay seconds if hostname is not available.
"""
if name is None:
raise ValueError("VM name is required for CloudProvider.get_ip")
return f"{name}.containers.cloud.trycua.com"

View File

@@ -22,7 +22,8 @@ class VMProviderFactory:
image: Optional[str] = None,
verbose: bool = False,
ephemeral: bool = False,
noVNC_port: Optional[int] = None
noVNC_port: Optional[int] = None,
**kwargs,
) -> BaseVMProvider:
"""Create a VM provider of the specified type.
@@ -101,12 +102,9 @@ class VMProviderFactory:
elif provider_type == VMProviderType.CLOUD:
try:
from .cloud import CloudProvider
# Return the stub implementation of CloudProvider
return CloudProvider(
host=host,
port=port,
storage=storage,
verbose=verbose
verbose=verbose,
**kwargs,
)
except ImportError as e:
logger.error(f"Failed to import CloudProvider: {e}")