Mirror of https://github.com/trycua/computer.git (synced 2026-02-18 12:28:51 -06:00)
Merge pull request #270 from trycua/feat/computer/cloud
Add Cloud Provider Support
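This change adds a `cloud` VM provider, an API-key authentication handshake on the Computer API Server WebSocket, and `api_key`/`name` pass-through from `Computer` down to the provider and interface layers. As orientation, here is a minimal sketch of connecting to a cloud container with the new options, modeled on the commented-out example added in this commit; the `VMProviderType` import path and the `CUA_API_KEY`/`CONTAINER_NAME` environment variable names are taken from that example and are not verified against the released package.

```python
import asyncio
import os

from computer import Computer, VMProviderType  # assumption: VMProviderType is re-exported here


async def main():
    # Cloud container, mirroring the commented-out example added in this commit.
    computer = Computer(
        os_type="linux",
        provider_type=VMProviderType.CLOUD,
        api_key=os.getenv("CUA_API_KEY"),
        name=os.getenv("CONTAINER_NAME"),
    )

    await computer.run()
    screenshot = await computer.interface.screenshot()
    with open("screenshot.png", "wb") as f:
        f.write(screenshot)


if __name__ == "__main__":
    asyncio.run(main())
```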
@@ -326,4 +326,4 @@ Thank you to all our supporters!
<!-- markdownlint-restore -->
<!-- prettier-ignore-end -->

<!-- ALL-CONTRIBUTORS-LIST:END -->
<!-- ALL-CONTRIBUTORS-LIST:END -->
@@ -42,11 +42,22 @@ async def main():
        ephemeral=False,
    )

    # computer = Computer(
    # os_type="linux",
    # api_key=os.getenv("CUA_API_KEY"),
    # name=os.getenv("CONTAINER_NAME"),
    # provider_type=VMProviderType.CLOUD,
    # )

    try:
        # Run the computer with default parameters
        await computer.run()

        await computer.interface.hotkey("command", "space")
        screenshot = await computer.interface.screenshot()
        with open(Path("~/cua/examples/screenshot.png").expanduser(), "wb") as f:
            f.write(screenshot)

        # await computer.interface.hotkey("command", "space")

        # res = await computer.interface.run_command("touch ./Downloads/empty_file")
        # print(f"Run command result: {res}")

@@ -290,7 +290,7 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
        model_name_to_use = cleaned_model_name
        # agent_loop remains AgentLoop.OMNI
    elif agent_loop == AgentLoop.UITARS:
        # For UITARS, use MLXVLM provider for the MLX models, OAICOMPAT for custom
        # For UITARS, use MLXVLM for mlx-community models, OAICOMPAT for custom
        if model_name == "Custom model (OpenAI compatible API)":
            provider = LLMProvider.OAICOMPAT
            model_name_to_use = "tgi"
@@ -333,12 +333,25 @@ def get_ollama_models() -> List[str]:
        logging.error(f"Error getting Ollama models: {e}")
        return []

def create_computer_instance(verbosity: int = logging.INFO) -> Computer:

def create_computer_instance(
    verbosity: int = logging.INFO,
    os_type: str = "macos",
    provider_type: str = "lume",
    name: Optional[str] = None,
    api_key: Optional[str] = None
) -> Computer:
    """Create or get the global Computer instance."""
    global global_computer

    if global_computer is None:
        global_computer = Computer(verbosity=verbosity)
        global_computer = Computer(
            verbosity=verbosity,
            os_type=os_type,
            provider_type=provider_type,
            name=name if name else "",
            api_key=api_key
        )

    return global_computer

@@ -353,12 +366,22 @@ def create_agent(
    verbosity: int = logging.INFO,
    use_oaicompat: bool = False,
    provider_base_url: Optional[str] = None,
    computer_os: str = "macos",
    computer_provider: str = "lume",
    computer_name: Optional[str] = None,
    computer_api_key: Optional[str] = None,
) -> ComputerAgent:
    """Create or update the global agent with the specified parameters."""
    global global_agent

    # Create the computer if not already done
    computer = create_computer_instance(verbosity=verbosity)
    computer = create_computer_instance(
        verbosity=verbosity,
        os_type=computer_os,
        provider_type=computer_provider,
        name=computer_name,
        api_key=computer_api_key
    )

    # Get API key from environment if not provided
    if api_key is None:
@@ -401,6 +424,7 @@ def create_agent(

    return global_agent


def create_gradio_ui(
    provider_name: str = "openai",
    model_name: str = "gpt-4o",
@@ -439,6 +463,9 @@ def create_gradio_ui(
    # Check if API keys are available
    has_openai_key = bool(openai_api_key)
    has_anthropic_key = bool(anthropic_api_key)

    print("has_openai_key", has_openai_key)
    print("has_anthropic_key", has_anthropic_key)

    # Get Ollama models for OMNI
    ollama_models = get_ollama_models()
@@ -473,7 +500,7 @@ def create_gradio_ui(
    elif initial_loop == "ANTHROPIC":
        initial_model = anthropic_models[0] if anthropic_models else "No models available"
    else: # OMNI
        initial_model = omni_models[0] if omni_models else "No models available"
        initial_model = omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
        if "Custom model (OpenAI compatible API)" in available_models_for_loop:
            initial_model = (
                "Custom model (OpenAI compatible API)" # Default to custom if available and no other default fits
@@ -494,7 +521,7 @@ def create_gradio_ui(
    ]

    # Function to generate Python code based on configuration and tasks
    def generate_python_code(agent_loop_choice, provider, model_name, tasks, provider_url, recent_images=3, save_trajectory=True):
    def generate_python_code(agent_loop_choice, provider, model_name, tasks, provider_url, recent_images=3, save_trajectory=True, computer_os="macos", computer_provider="lume", container_name="", cua_cloud_api_key=""):
        """Generate Python code for the current configuration and tasks.

        Args:
@@ -505,6 +532,10 @@ def create_gradio_ui(
            provider_url: The provider base URL for OAICOMPAT providers
            recent_images: Number of recent images to keep in context
            save_trajectory: Whether to save the agent trajectory
            computer_os: Operating system type for the computer
            computer_provider: Provider type for the computer
            container_name: Optional VM name
            cua_cloud_api_key: Optional CUA Cloud API key

        Returns:
            Formatted Python code as a string
@@ -515,13 +546,29 @@ def create_gradio_ui(
            if task and task.strip():
                tasks_str += f' "{task}",\n'

        # Create the Python code template
        # Create the Python code template with computer configuration
        computer_args = []
        if computer_os != "macos":
            computer_args.append(f'os_type="{computer_os}"')
        if computer_provider != "lume":
            computer_args.append(f'provider_type="{computer_provider}"')
        if container_name:
            computer_args.append(f'name="{container_name}"')
        if cua_cloud_api_key:
            computer_args.append(f'api_key="{cua_cloud_api_key}"')

        computer_args_str = ", ".join(computer_args)
        if computer_args_str:
            computer_args_str = f"({computer_args_str})"
        else:
            computer_args_str = "()"

        code = f'''import asyncio
from computer import Computer
from agent import ComputerAgent, LLM, AgentLoop, LLMProvider

async def main():
    async with Computer() as macos_computer:
    async with Computer{computer_args_str} as macos_computer:
        agent = ComputerAgent(
            computer=macos_computer,
            loop=AgentLoop.{agent_loop_choice},
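For reference, a standalone sketch of the argument assembly above, showing the string that ends up in the generated `async with Computer{computer_args_str}` line; the variable values here are placeholders, not output captured from the app.

```python
# Standalone sketch of the argument assembly used by generate_python_code above.
computer_os = "linux"          # non-default, so it is emitted
computer_provider = "cloud"    # non-default, so it is emitted
container_name = "my-vm"       # placeholder value
cua_cloud_api_key = "sk-test"  # placeholder value

computer_args = []
if computer_os != "macos":
    computer_args.append(f'os_type="{computer_os}"')
if computer_provider != "lume":
    computer_args.append(f'provider_type="{computer_provider}"')
if container_name:
    computer_args.append(f'name="{container_name}"')
if cua_cloud_api_key:
    computer_args.append(f'api_key="{cua_cloud_api_key}"')

computer_args_str = f"({', '.join(computer_args)})" if computer_args else "()"
print(f"async with Computer{computer_args_str} as macos_computer:")
# -> async with Computer(os_type="linux", provider_type="cloud", name="my-vm", api_key="sk-test") as macos_computer:
```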
@@ -660,12 +707,49 @@ if __name__ == "__main__":
                    LLMProvider.OPENAI,
                    "gpt-4o",
                    [],
                    "https://openrouter.ai/api/v1"
                    "https://openrouter.ai/api/v1",
                    3, # recent_images default
                    True, # save_trajectory default
                    "macos",
                    "lume",
                    "",
                    ""
                ),
                interactive=False,
            )

            with gr.Accordion("Configuration", open=True):
            with gr.Accordion("Computer Configuration", open=True):
                # Computer configuration options
                computer_os = gr.Radio(
                    choices=["macos", "linux"],
                    label="Operating System",
                    value="macos",
                    info="Select the operating system for the computer",
                )

                computer_provider = gr.Radio(
                    choices=["cloud", "lume"],
                    label="Provider",
                    value="lume",
                    info="Select the computer provider",
                )

                container_name = gr.Textbox(
                    label="Container Name",
                    placeholder="Enter container name (optional)",
                    value="",
                    info="Optional name for the container",
                )

                cua_cloud_api_key = gr.Textbox(
                    label="CUA Cloud API Key",
                    placeholder="Enter your CUA Cloud API key",
                    value="",
                    type="password",
                    info="Required for cloud provider",
                )

            with gr.Accordion("Agent Configuration", open=True):
                # Configuration options
                agent_loop = gr.Dropdown(
                    choices=["OPENAI", "ANTHROPIC", "OMNI", "UITARS"],
@@ -986,6 +1070,10 @@ if __name__ == "__main__":
            custom_api_key=None,
            openai_key_input=None,
            anthropic_key_input=None,
            computer_os="macos",
            computer_provider="lume",
            container_name="",
            cua_cloud_api_key="",
        ):
            if not history:
                yield history
@@ -1092,6 +1180,10 @@ if __name__ == "__main__":
                "provider_base_url": custom_url_value,
                "save_trajectory": save_traj,
                "recent_images": recent_imgs,
                "computer_os": computer_os,
                "computer_provider": computer_provider,
                "container_name": container_name,
                "cua_cloud_api_key": cua_cloud_api_key,
            }
            save_settings(current_settings)
            # --- End Save Settings ---
@@ -1109,6 +1201,10 @@ if __name__ == "__main__":
                use_oaicompat=is_oaicompat, # Set flag if custom model was selected
                # Pass custom URL only if custom model was selected
                provider_base_url=custom_url_value if is_oaicompat else None,
                computer_os=computer_os,
                computer_provider=computer_provider,
                computer_name=container_name,
                computer_api_key=cua_cloud_api_key,
                verbosity=logging.DEBUG, # Added verbosity here
            )

@@ -1235,6 +1331,10 @@ if __name__ == "__main__":
                provider_api_key,
                openai_api_key_input,
                anthropic_api_key_input,
                computer_os,
                computer_provider,
                container_name,
                cua_cloud_api_key,
            ],
            outputs=[chatbot_history],
            queue=True,
@@ -1253,82 +1353,20 @@ if __name__ == "__main__":


        # Function to update the code display based on configuration and chat history
        def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, provider_base_url, recent_images_val, save_trajectory_val):
        def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, provider_base_url, recent_images_val, save_trajectory_val, computer_os, computer_provider, container_name, cua_cloud_api_key):
            # Extract messages from chat history
            messages = []
            if chat_history:
                for msg in chat_history:
                    if msg.get("role") == "user":
                    if isinstance(msg, dict) and msg.get("role") == "user":
                        messages.append(msg.get("content", ""))

            # Determine if this is a custom model selection and which type
            is_custom_openai_api = model_choice_val == "Custom model (OpenAI compatible API)"
            is_custom_ollama = model_choice_val == "Custom model (ollama)"
            is_custom_model_selected = is_custom_openai_api or is_custom_ollama
            # Determine provider and model based on current selection
            provider, model_name, _ = get_provider_and_model(
                model_choice_val or custom_model_val or "gpt-4o",
                agent_loop
            )

            # Determine provider and model name based on agent loop
            if agent_loop == "OPENAI":
                # For OPENAI loop, always use OPENAI provider with computer-use-preview
                provider = LLMProvider.OPENAI
                model_name = "computer-use-preview"
            elif agent_loop == "ANTHROPIC":
                # For ANTHROPIC loop, always use ANTHROPIC provider
                provider = LLMProvider.ANTHROPIC
                # Extract model name from the UI string
                if model_choice_val.startswith("Anthropic: Claude "):
                    # Extract the model name based on the UI string
                    model_parts = model_choice_val.replace("Anthropic: Claude ", "").split(" (")
                    version = model_parts[0] # e.g., "3.7 Sonnet"
                    date = model_parts[1].replace(")", "") if len(model_parts) > 1 else "" # e.g., "20250219"

                    # Format as claude-3-7-sonnet-20250219 or claude-3-5-sonnet-20240620
                    version = version.replace(".", "-").replace(" ", "-").lower()
                    model_name = f"claude-{version}-{date}"
                else:
                    # Use the model_choice_val directly if it doesn't match the expected format
                    model_name = model_choice_val
            elif agent_loop == "UITARS":
                # For UITARS, use MLXVLM for mlx-community models, OAICOMPAT for custom
                if model_choice_val == "Custom model (OpenAI compatible API)":
                    provider = LLMProvider.OAICOMPAT
                    model_name = custom_model_val
                else:
                    provider = LLMProvider.MLXVLM
                    model_name = model_choice_val
            elif agent_loop == "OMNI":
                # For OMNI, provider can be OPENAI, ANTHROPIC, OLLAMA, or OAICOMPAT
                if is_custom_openai_api:
                    provider = LLMProvider.OAICOMPAT
                    model_name = custom_model_val
                elif is_custom_ollama:
                    provider = LLMProvider.OLLAMA
                    model_name = custom_model_val
                elif model_choice_val.startswith("OMNI: OpenAI "):
                    provider = LLMProvider.OPENAI
                    # Extract model name from UI string (e.g., "OMNI: OpenAI GPT-4o" -> "gpt-4o")
                    model_name = model_choice_val.replace("OMNI: OpenAI ", "").lower().replace(" ", "-")
                elif model_choice_val.startswith("OMNI: Claude "):
                    provider = LLMProvider.ANTHROPIC
                    # Extract model name from UI string (similar to ANTHROPIC loop case)
                    model_parts = model_choice_val.replace("OMNI: Claude ", "").split(" (")
                    version = model_parts[0] # e.g., "3.7 Sonnet"
                    date = model_parts[1].replace(")", "") if len(model_parts) > 1 else "" # e.g., "20250219"

                    # Format as claude-3-7-sonnet-20250219 or claude-3-5-sonnet-20240620
                    version = version.replace(".", "-").replace(" ", "-").lower()
                    model_name = f"claude-{version}-{date}"
                elif model_choice_val.startswith("OMNI: Ollama "):
                    provider = LLMProvider.OLLAMA
                    # Extract model name from UI string (e.g., "OMNI: Ollama llama3" -> "llama3")
                    model_name = model_choice_val.replace("OMNI: Ollama ", "")
                else:
                    # Fallback to get_provider_and_model for any other cases
                    provider, model_name, _ = get_provider_and_model(model_choice_val, agent_loop)
            else:
                # Fallback for any other agent loop
                provider, model_name, _ = get_provider_and_model(model_choice_val, agent_loop)

            # Generate and return the code
            return generate_python_code(
                agent_loop,
                provider,
@@ -1336,38 +1374,62 @@ if __name__ == "__main__":
                messages,
                provider_base_url,
                recent_images_val,
                save_trajectory_val
                save_trajectory_val,
                computer_os,
                computer_provider,
                container_name,
                cua_cloud_api_key
            )

        # Update code display when configuration changes
        agent_loop.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        model_choice.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        custom_model.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        chatbot_history.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        recent_images.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        save_trajectory.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory],
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        computer_os.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        computer_provider.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        container_name.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )
        cua_cloud_api_key.change(
            update_code_display,
            inputs=[agent_loop, model_choice, custom_model, chatbot_history, provider_base_url, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key],
            outputs=[code_display]
        )


@@ -8,11 +8,11 @@ import traceback
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp

# Set up logging with more detail
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure WebSocket with larger message size
@@ -48,6 +48,112 @@ manager = ConnectionManager()
async def websocket_endpoint(websocket: WebSocket):
    # WebSocket message size is configured at the app or endpoint level, not on the instance
    await manager.connect(websocket)

    # Check if CONTAINER_NAME is set (indicating cloud provider)
    container_name = os.environ.get("CONTAINER_NAME")

    # If cloud provider, perform authentication handshake
    if container_name:
        try:
            logger.info(f"Cloud provider detected. CONTAINER_NAME: {container_name}. Waiting for authentication...")

            # Wait for authentication message
            auth_data = await websocket.receive_json()

            # Validate auth message format
            if auth_data.get("command") != "authenticate":
                await websocket.send_json({
                    "success": False,
                    "error": "First message must be authentication"
                })
                await websocket.close()
                manager.disconnect(websocket)
                return

            # Extract credentials
            client_api_key = auth_data.get("params", {}).get("api_key")
            client_container_name = auth_data.get("params", {}).get("container_name")

            # Layer 1: VM Identity Verification
            if client_container_name != container_name:
                logger.warning(f"VM name mismatch. Expected: {container_name}, Got: {client_container_name}")
                await websocket.send_json({
                    "success": False,
                    "error": "VM name mismatch"
                })
                await websocket.close()
                manager.disconnect(websocket)
                return

            # Layer 2: API Key Validation with TryCUA API
            if not client_api_key:
                await websocket.send_json({
                    "success": False,
                    "error": "API key required"
                })
                await websocket.close()
                manager.disconnect(websocket)
                return

            # Validate with TryCUA API
            try:
                async with aiohttp.ClientSession() as session:
                    headers = {
                        "Authorization": f"Bearer {client_api_key}"
                    }

                    async with session.get(
                        f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
                        headers=headers,
                    ) as resp:
                        if resp.status != 200:
                            error_msg = await resp.text()
                            logger.warning(f"API validation failed: {error_msg}")
                            await websocket.send_json({
                                "success": False,
                                "error": "Authentication failed"
                            })
                            await websocket.close()
                            manager.disconnect(websocket)
                            return

                        # If we get a 200 response with VNC URL, the VM exists and user has access
                        vnc_url = (await resp.text()).strip()
                        if not vnc_url:
                            logger.warning(f"No VNC URL returned for VM: {container_name}")
                            await websocket.send_json({
                                "success": False,
                                "error": "VM not found"
                            })
                            await websocket.close()
                            manager.disconnect(websocket)
                            return

                        logger.info(f"Authentication successful for VM: {container_name}")
                        await websocket.send_json({
                            "success": True,
                            "message": "Authenticated"
                        })

            except Exception as e:
                logger.error(f"Error validating with TryCUA API: {e}")
                await websocket.send_json({
                    "success": False,
                    "error": "Authentication service unavailable"
                })
                await websocket.close()
                manager.disconnect(websocket)
                return

        except Exception as e:
            logger.error(f"Authentication error: {e}")
            await websocket.send_json({
                "success": False,
                "error": "Authentication failed"
            })
            await websocket.close()
            manager.disconnect(websocket)
            return

    # Map commands to appropriate handler methods
    handlers = {

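On the wire, the server above expects a single JSON `authenticate` message before any other command and replies with a success/error object. A minimal sketch of a compliant client exchange follows; the `websockets` client usage is an assumption and the hostname/key are placeholders. Note that the server reads `params["container_name"]`, while the macOS/Linux interfaces later in this diff send the name under `vm_name`, so the two sides as written do not agree on the field name.

```python
import asyncio
import json

import websockets  # assumption: same client library the interfaces use


async def authenticate(uri: str, api_key: str, container_name: str) -> None:
    async with websockets.connect(uri) as ws:
        # First message must be the authentication command.
        await ws.send(json.dumps({
            "command": "authenticate",
            "params": {
                "api_key": api_key,
                # The server reads params["container_name"]; the interfaces in this
                # diff send "vm_name" instead.
                "container_name": container_name,
            },
        }))
        reply = json.loads(await ws.recv())
        if not reply.get("success"):
            raise ConnectionError(reply.get("error", "Authentication failed"))


# Example (placeholder values):
# asyncio.run(authenticate("wss://my-vm.containers.cloud.trycua.com:8000/ws", "sk-...", "my-vm"))
```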
@@ -17,7 +17,8 @@ dependencies = [
    "uvicorn[standard]>=0.27.0",
    "pydantic>=2.0.0",
    "pyautogui>=0.9.54",
    "pillow>=10.2.0"
    "pillow>=10.2.0",
    "aiohttp>=3.9.1"
]

[project.optional-dependencies]

@@ -38,7 +38,8 @@ class Computer:
        noVNC_port: Optional[int] = 8006,
        host: str = os.environ.get("PYLUME_HOST", "localhost"),
        storage: Optional[str] = None,
        ephemeral: bool = False
        ephemeral: bool = False,
        api_key: Optional[str] = None
    ):
        """Initialize a new Computer instance.

@@ -77,6 +78,8 @@ class Computer:
        self.os_type = os_type
        self.provider_type = provider_type
        self.ephemeral = ephemeral

        self.api_key = api_key

        # The default is currently to use non-ephemeral storage
        if storage and ephemeral and storage != "ephemeral":
@@ -256,9 +259,7 @@ class Computer:
        elif self.provider_type == VMProviderType.CLOUD:
            self.config.vm_provider = VMProviderFactory.create_provider(
                self.provider_type,
                port=port,
                host=host,
                storage=storage,
                api_key=self.api_key,
                verbose=verbose,
            )
        else:
@@ -392,12 +393,25 @@ class Computer:
        self.logger.info(f"Initializing interface for {self.os_type} at {ip_address}")
        from .interface.base import BaseComputerInterface

        self._interface = cast(
            BaseComputerInterface,
            InterfaceFactory.create_interface_for_os(
                os=self.os_type, ip_address=ip_address # type: ignore[arg-type]
            ),
        )
        # Pass authentication credentials if using cloud provider
        if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
            self._interface = cast(
                BaseComputerInterface,
                InterfaceFactory.create_interface_for_os(
                    os=self.os_type,
                    ip_address=ip_address,
                    api_key=self.api_key,
                    vm_name=self.config.name
                ),
            )
        else:
            self._interface = cast(
                BaseComputerInterface,
                InterfaceFactory.create_interface_for_os(
                    os=self.os_type,
                    ip_address=ip_address
                ),
            )

        # Wait for the WebSocket interface to be ready
        self.logger.info("Connecting to WebSocket interface...")
@@ -492,6 +506,11 @@ class Computer:

        # Call the provider's get_ip method which will wait indefinitely
        storage_param = "ephemeral" if self.ephemeral else self.storage

        # Log the image being used
        self.logger.info(f"Running VM using image: {self.image}")

        # Call provider.get_ip with explicit image parameter
        ip = await self.config.vm_provider.get_ip(
            name=self.config.name,
            storage=storage_param,

@@ -8,17 +8,21 @@ from ..logger import Logger, LogLevel
class BaseComputerInterface(ABC):
    """Base class for computer control interfaces."""

    def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
    def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
        """Initialize interface.

        Args:
            ip_address: IP address of the computer to control
            username: Username for authentication
            password: Password for authentication
            api_key: Optional API key for cloud authentication
            vm_name: Optional VM name for cloud authentication
        """
        self.ip_address = ip_address
        self.username = username
        self.password = password
        self.api_key = api_key
        self.vm_name = vm_name
        self.logger = Logger("cua.interface", LogLevel.NORMAL)

    @abstractmethod

@@ -1,6 +1,6 @@
"""Factory for creating computer interfaces."""

from typing import Literal
from typing import Literal, Optional
from .base import BaseComputerInterface

class InterfaceFactory:
@@ -9,13 +9,17 @@ class InterfaceFactory:
    @staticmethod
    def create_interface_for_os(
        os: Literal['macos', 'linux'],
        ip_address: str
        ip_address: str,
        api_key: Optional[str] = None,
        vm_name: Optional[str] = None
    ) -> BaseComputerInterface:
        """Create an interface for the specified OS.

        Args:
            os: Operating system type ('macos' or 'linux')
            ip_address: IP address of the computer to control
            api_key: Optional API key for cloud authentication
            vm_name: Optional VM name for cloud authentication

        Returns:
            BaseComputerInterface: The appropriate interface for the OS
@@ -28,8 +32,8 @@ class InterfaceFactory:
        from .linux import LinuxComputerInterface

        if os == 'macos':
            return MacOSComputerInterface(ip_address)
            return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
        elif os == 'linux':
            return LinuxComputerInterface(ip_address)
            return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
        else:
            raise ValueError(f"Unsupported OS type: {os}")
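A short sketch of the updated factory call; the module path is an assumption. When `api_key`/`vm_name` are supplied, the interfaces in the following hunks switch the WebSocket scheme to `wss://` and perform the `authenticate` handshake on connect.

```python
from computer.interface.factory import InterfaceFactory  # assumption: module path

# Cloud-style interface: credentials are forwarded to the OS-specific interface,
# which will connect over wss:// and authenticate on connect.
cloud_interface = InterfaceFactory.create_interface_for_os(
    os="linux",
    ip_address="my-vm.containers.cloud.trycua.com",  # placeholder hostname
    api_key="sk-...",                                # placeholder key
    vm_name="my-vm",
)

# Local lume-style interface: no credentials, plain ws:// connection.
local_interface = InterfaceFactory.create_interface_for_os(os="macos", ip_address="192.168.64.2")
```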
@@ -15,8 +15,8 @@ from .models import Key, KeyType
class LinuxComputerInterface(BaseComputerInterface):
    """Interface for Linux."""

    def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
        super().__init__(ip_address, username, password)
    def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
        super().__init__(ip_address, username, password, api_key, vm_name)
        self._ws = None
        self._reconnect_task = None
        self._closed = False
@@ -37,7 +37,8 @@ class LinuxComputerInterface(BaseComputerInterface):
        Returns:
            WebSocket URI for the Computer API Server
        """
        return f"ws://{self.ip_address}:8000/ws"
        protocol = "wss" if self.api_key else "ws"
        return f"{protocol}://{self.ip_address}:8000/ws"

    async def _keep_alive(self):
        """Keep the WebSocket connection alive with automatic reconnection."""
@@ -86,6 +87,32 @@ class LinuxComputerInterface(BaseComputerInterface):
                        timeout=30,
                    )
                    self.logger.info("WebSocket connection established")

                    # If api_key and vm_name are provided, perform authentication handshake
                    if self.api_key and self.vm_name:
                        self.logger.info("Performing authentication handshake...")
                        auth_message = {
                            "command": "authenticate",
                            "params": {
                                "api_key": self.api_key,
                                "vm_name": self.vm_name
                            }
                        }
                        await self._ws.send(json.dumps(auth_message))

                        # Wait for authentication response
                        auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
                        auth_result = json.loads(auth_response)

                        if not auth_result.get("success"):
                            error_msg = auth_result.get("error", "Authentication failed")
                            self.logger.error(f"Authentication failed: {error_msg}")
                            await self._ws.close()
                            self._ws = None
                            raise ConnectionError(f"Authentication failed: {error_msg}")

                        self.logger.info("Authentication successful")

                    self._reconnect_delay = 1 # Reset reconnect delay on successful connection
                    self._last_ping = time.time()
                    retry_count = 0 # Reset retry count on successful connection

@@ -13,10 +13,10 @@ from .models import Key, KeyType


class MacOSComputerInterface(BaseComputerInterface):
    """Interface for MacOS."""
    """Interface for macOS."""

    def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
        super().__init__(ip_address, username, password)
    def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
        super().__init__(ip_address, username, password, api_key, vm_name)
        self._ws = None
        self._reconnect_task = None
        self._closed = False
@@ -27,7 +27,7 @@ class MacOSComputerInterface(BaseComputerInterface):
        self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
        self._log_connection_attempts = True # Flag to control connection attempt logging

        # Set logger name for MacOS interface
        # Set logger name for macOS interface
        self.logger = Logger("cua.interface.macos", LogLevel.NORMAL)

    @property
@@ -37,7 +37,8 @@ class MacOSComputerInterface(BaseComputerInterface):
        Returns:
            WebSocket URI for the Computer API Server
        """
        return f"ws://{self.ip_address}:8000/ws"
        protocol = "wss" if self.api_key else "ws"
        return f"{protocol}://{self.ip_address}:8000/ws"

    async def _keep_alive(self):
        """Keep the WebSocket connection alive with automatic reconnection."""
@@ -86,6 +87,32 @@ class MacOSComputerInterface(BaseComputerInterface):
                        timeout=30,
                    )
                    self.logger.info("WebSocket connection established")

                    # If api_key and vm_name are provided, perform authentication handshake
                    if self.api_key and self.vm_name:
                        self.logger.info("Performing authentication handshake...")
                        auth_message = {
                            "command": "authenticate",
                            "params": {
                                "api_key": self.api_key,
                                "vm_name": self.vm_name
                            }
                        }
                        await self._ws.send(json.dumps(auth_message))

                        # Wait for authentication response
                        auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
                        auth_result = json.loads(auth_response)

                        if not auth_result.get("success"):
                            error_msg = auth_result.get("error", "Authentication failed")
                            self.logger.error(f"Authentication failed: {error_msg}")
                            await self._ws.close()
                            self._ws = None
                            raise ConnectionError(f"Authentication failed: {error_msg}")

                        self.logger.info("Authentication successful")

                    self._reconnect_delay = 1 # Reset reconnect delay on successful connection
                    self._last_ping = time.time()
                    retry_count = 0 # Reset retry count on successful connection

@@ -11,90 +11,65 @@ from ..base import BaseVMProvider, VMProviderType
# Setup logging
logger = logging.getLogger(__name__)

import asyncio
import aiohttp
from urllib.parse import urlparse

class CloudProvider(BaseVMProvider):
    """Cloud VM Provider stub implementation.

    This is a placeholder for a future cloud VM provider implementation.
    """

    """Cloud VM Provider implementation."""
    def __init__(
        self,
        host: str = "localhost",
        port: int = 7777,
        storage: Optional[str] = None,
        self,
        api_key: str,
        verbose: bool = False,
        **kwargs,
    ):
        """Initialize the Cloud provider.

        """
        Args:
            host: Host to use for API connections (default: localhost)
            port: Port for the API server (default: 7777)
            storage: Path to store VM data
            api_key: API key for authentication
            name: Name of the VM
            verbose: Enable verbose logging
        """
        self.host = host
        self.port = port
        self.storage = storage
        assert api_key, "api_key required for CloudProvider"
        self.api_key = api_key
        self.verbose = verbose

        logger.warning("CloudProvider is not yet implemented")


    @property
    def provider_type(self) -> VMProviderType:
        """Get the provider type."""
        return VMProviderType.CLOUD


    async def __aenter__(self):
        """Enter async context manager."""
        logger.debug("Entering CloudProvider context")
        return self


    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit async context manager."""
        logger.debug("Exiting CloudProvider context")

        pass

    async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
        """Get VM information by name."""
        logger.warning("CloudProvider.get_vm is not implemented")
        return {
            "name": name,
            "status": "unavailable",
            "message": "CloudProvider is not implemented"
        }

        """Get VM VNC URL by name using the cloud API."""
        return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}

    async def list_vms(self) -> List[Dict[str, Any]]:
        """List all available VMs."""
        logger.warning("CloudProvider.list_vms is not implemented")
        return []


    async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
        """Run a VM with the given options."""
        logger.warning("CloudProvider.run_vm is not implemented")
        return {
            "name": name,
            "status": "unavailable",
            "message": "CloudProvider is not implemented"
        }

        return {"name": name, "status": "unavailable", "message": "CloudProvider is not implemented"}

    async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
        """Stop a running VM."""
        logger.warning("CloudProvider.stop_vm is not implemented")
        return {
            "name": name,
            "status": "stopped",
            "message": "CloudProvider is not implemented"
        }

        return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}

    async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
        """Update VM configuration."""
        logger.warning("CloudProvider.update_vm is not implemented")
        return {
            "name": name,
            "status": "unchanged",
            "message": "CloudProvider is not implemented"
        }

    async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
        """Get the IP address of a VM."""
        logger.warning("CloudProvider.get_ip is not implemented")
        raise NotImplementedError("CloudProvider.get_ip is not implemented")
        return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}

    async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
        """
        Return the VM's IP address as '{container_name}.containers.cloud.trycua.com'.
        Uses the provided 'name' argument (the VM name requested by the caller),
        falling back to self.name only if 'name' is None.
        Retries up to 3 times with retry_delay seconds if hostname is not available.
        """
        if name is None:
            raise ValueError("VM name is required for CloudProvider.get_ip")
        return f"{name}.containers.cloud.trycua.com"

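A minimal sketch of using the rewritten provider directly; the module path is an assumption and the API key/VM name are placeholders. Through `VMProviderFactory.create_provider`, the same keyword arguments are forwarded via `**kwargs` (see the factory hunk below).

```python
import asyncio

from computer.providers.cloud import CloudProvider  # assumption: module path


async def main():
    # The provider requires an API key (it asserts on it) and maps VM names onto
    # the cloud hostname scheme instead of talking to a local VM manager.
    provider = CloudProvider(api_key="sk-...", verbose=False)  # placeholder key
    async with provider:
        hostname = await provider.get_ip(name="my-vm")  # placeholder VM name
        print(hostname)  # -> my-vm.containers.cloud.trycua.com


asyncio.run(main())
```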
@@ -22,7 +22,8 @@ class VMProviderFactory:
        image: Optional[str] = None,
        verbose: bool = False,
        ephemeral: bool = False,
        noVNC_port: Optional[int] = None
        noVNC_port: Optional[int] = None,
        **kwargs,
    ) -> BaseVMProvider:
        """Create a VM provider of the specified type.

@@ -101,12 +102,9 @@ class VMProviderFactory:
        elif provider_type == VMProviderType.CLOUD:
            try:
                from .cloud import CloudProvider
                # Return the stub implementation of CloudProvider
                return CloudProvider(
                    host=host,
                    port=port,
                    storage=storage,
                    verbose=verbose
                    verbose=verbose,
                    **kwargs,
                )
            except ImportError as e:
                logger.error(f"Failed to import CloudProvider: {e}")
