mirror of
https://github.com/trycua/computer.git
synced 2026-01-03 03:49:58 -06:00
Add computer server SSL
This commit is contained in:
@@ -27,6 +27,16 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
|
||||
default="info",
|
||||
help="Logging level (default: info)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-keyfile",
|
||||
type=str,
|
||||
help="Path to SSL private key file (enables HTTPS)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-certfile",
|
||||
type=str,
|
||||
help="Path to SSL certificate file (enables HTTPS)",
|
||||
)
|
||||
|
||||
return parser.parse_args(args)
|
||||
|
||||
@@ -43,7 +53,21 @@ def main() -> None:
|
||||
|
||||
# Create and start the server
|
||||
logger.info(f"Starting CUA Computer API server on {args.host}:{args.port}...")
|
||||
server = Server(host=args.host, port=args.port, log_level=args.log_level)
|
||||
|
||||
# Handle SSL configuration
|
||||
ssl_args = {}
|
||||
if args.ssl_keyfile and args.ssl_certfile:
|
||||
ssl_args = {
|
||||
"ssl_keyfile": args.ssl_keyfile,
|
||||
"ssl_certfile": args.ssl_certfile,
|
||||
}
|
||||
logger.info("HTTPS mode enabled with SSL certificates")
|
||||
elif args.ssl_keyfile or args.ssl_certfile:
|
||||
logger.warning("Both --ssl-keyfile and --ssl-certfile are required for HTTPS. Running in HTTP mode.")
|
||||
else:
|
||||
logger.info("HTTP mode (no SSL certificates provided)")
|
||||
|
||||
server = Server(host=args.host, port=args.port, log_level=args.log_level, **ssl_args)
|
||||
|
||||
try:
|
||||
server.start()
|
||||
|
||||
@@ -32,7 +32,8 @@ class Server:
|
||||
await server.stop() # Stop the server
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000, log_level: str = "info"):
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000, log_level: str = "info",
|
||||
ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None):
|
||||
"""
|
||||
Initialize the server.
|
||||
|
||||
@@ -40,10 +41,14 @@ class Server:
|
||||
host: Host to bind the server to
|
||||
port: Port to bind the server to
|
||||
log_level: Logging level (debug, info, warning, error, critical)
|
||||
ssl_keyfile: Path to SSL private key file (for HTTPS)
|
||||
ssl_certfile: Path to SSL certificate file (for HTTPS)
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.log_level = log_level
|
||||
self.ssl_keyfile = ssl_keyfile
|
||||
self.ssl_certfile = ssl_certfile
|
||||
self.app = fastapi_app
|
||||
self._server_task: Optional[asyncio.Task] = None
|
||||
self._should_exit = asyncio.Event()
|
||||
@@ -52,7 +57,14 @@ class Server:
|
||||
"""
|
||||
Start the server synchronously. This will block until the server is stopped.
|
||||
"""
|
||||
uvicorn.run(self.app, host=self.host, port=self.port, log_level=self.log_level)
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_level=self.log_level,
|
||||
ssl_keyfile=self.ssl_keyfile,
|
||||
ssl_certfile=self.ssl_certfile
|
||||
)
|
||||
|
||||
async def start_async(self) -> None:
|
||||
"""
|
||||
@@ -60,7 +72,12 @@ class Server:
|
||||
will run in the background.
|
||||
"""
|
||||
server_config = uvicorn.Config(
|
||||
self.app, host=self.host, port=self.port, log_level=self.log_level
|
||||
self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_level=self.log_level,
|
||||
ssl_keyfile=self.ssl_keyfile,
|
||||
ssl_certfile=self.ssl_certfile
|
||||
)
|
||||
|
||||
self._should_exit.clear()
|
||||
@@ -72,7 +89,8 @@ class Server:
|
||||
# Wait a short time to ensure the server starts
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
logger.info(f"Server started at http://{self.host}:{self.port}")
|
||||
protocol = "https" if self.ssl_certfile else "http"
|
||||
logger.info(f"Server started at {protocol}://{self.host}:{self.port}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user