diff --git a/Cargo.lock b/Cargo.lock index e5e00f1a..6e5fe532 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4648,6 +4648,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "record_api_rs" version = "0.1.0" @@ -6567,6 +6580,7 @@ dependencies = [ "prost-reflect-build", "quoted_printable", "rand 0.8.5", + "rcgen", "regex", "reqwest", "rusqlite", @@ -7529,6 +7543,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.7.5" diff --git a/examples/custom-binary/src/main.rs b/examples/custom-binary/src/main.rs index abb83742..bd384f0a 100644 --- a/examples/custom-binary/src/main.rs +++ b/examples/custom-binary/src/main.rs @@ -6,7 +6,7 @@ use axum::{ use tracing_subscriber::{filter, prelude::*}; use trailbase::{AppState, DataDir, Server, ServerOptions, User}; -type BoxError = Box; +type BoxError = Box; pub async fn handler(State(_state): State, user: Option) -> Response { Html(format!( @@ -35,6 +35,7 @@ async fn main() -> Result<(), BoxError> { disable_auth_ui: false, cors_allowed_origins: vec![], js_runtime_threads: None, + ..Default::default() }, Some(custom_routes), |state: AppState| async move { diff --git a/trailbase-cli/src/bin/trail.rs b/trailbase-cli/src/bin/trail.rs index 5ba3c2b1..e3042e10 100644 --- a/trailbase-cli/src/bin/trail.rs +++ b/trailbase-cli/src/bin/trail.rs @@ -20,7 +20,7 @@ use trailbase_cli::{ AdminSubCommands, DefaultCommandLineArgs, JsonSchemaModeArg, SubCommands, UserSubCommands, }; -type BoxError = Box; +type BoxError = Box; fn init_logger(dev: bool) { // SWC is very spammy in in debug builds and complaints about source maps when compiling @@ -88,6 +88,8 @@ async fn async_main() -> Result<(), BoxError> { disable_auth_ui: cmd.disable_auth_ui, cors_allowed_origins: cmd.cors_allowed_origins, js_runtime_threads: cmd.js_runtime_threads, + tls_key: None, + tls_cert: None, }) .await?; diff --git a/trailbase-core/Cargo.toml b/trailbase-core/Cargo.toml index 4bb85a37..a649a2e3 100644 --- a/trailbase-core/Cargo.toml +++ b/trailbase-core/Cargo.toml @@ -103,3 +103,4 @@ quoted_printable = "0.5.1" schemars = "0.8.21" temp-dir = "0.1.13" tower = { version = "0.5.0", features = ["util"] } +rcgen = "0.13.2" diff --git a/trailbase-core/src/server/mod.rs b/trailbase-core/src/server/mod.rs index a9f81ed1..d70127f8 100644 --- a/trailbase-core/src/server/mod.rs +++ b/trailbase-core/src/server/mod.rs @@ -14,10 +14,7 @@ use std::sync::Arc; use tokio::signal; use tokio::task::JoinSet; use tokio_rustls::{ - rustls::pki_types::{ - pem::{Error as TlsPemError, PemObject}, - CertificateDer, PrivateKeyDer, - }, + rustls::pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, rustls::ServerConfig, TlsAcceptor, }; @@ -40,7 +37,7 @@ pub use init::{init_app_state, InitArgs, InitError}; /// A set of options to configure serving behaviors. Changing any of these options /// requires a server restart, which makes them a natural fit for being exposed as command line /// arguments. -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default)] pub struct ServerOptions { /// Optional path to static assets that will be served at the HTTP root. pub data_dir: DataDir, @@ -69,6 +66,11 @@ pub struct ServerOptions { /// Number of V8 worker threads. If set to None, default of num available cores will be used. pub js_runtime_threads: Option, + + /// TLS certificate path. + pub tls_cert: Option>, + /// TLS key path. + pub tls_key: Option>, } pub struct Server { @@ -77,6 +79,11 @@ pub struct Server { // Routers. main_router: (String, Router), admin_router: Option<(String, Router)>, + + /// TLS certificate path. + pub tls_cert: Option>, + /// TLS key path. + pub tls_key: Option>, } impl Server { @@ -138,6 +145,8 @@ impl Server { state, main_router, admin_router, + tls_key: opts.tls_key, + tls_cert: opts.tls_cert, }) } @@ -149,38 +158,38 @@ impl Server { return &self.main_router.1; } - pub async fn serve(&self) -> Result<(), Box> { - // This declares **where** tracing is being logged to, e.g. stderr, file, sqlite. - // - // NOTE: it's ok to fail. Just means someone else already initialize the tracing sub-system. - // { - // use tracing_subscriber::{filter, prelude::*}; - // let _ = tracing_subscriber::registry() - // .with( - // logging::SqliteLogLayer::new(&self.state).with_filter( - // filter::Targets::new() - // .with_target("tower_http::trace::on_response", filter::LevelFilter::DEBUG) - // .with_target("tower_http::trace::on_request", filter::LevelFilter::DEBUG) - // .with_target("tower_http::trace::make_span", filter::LevelFilter::DEBUG) - // .with_default(filter::LevelFilter::INFO), - // ), - // ) - // .try_init(); - // } - + pub async fn serve(&self) -> Result<(), Box> { let _raii_tasks = scheduler::start_periodic_tasks(&self.state); - let mut set = JoinSet::new(); + // NOTE: We panic if a key/cert that was explicitly specified cannot be loaded. + let data_dir = self.state.data_dir(); + let tls_key = self.tls_key.as_ref().map_or_else( + || { + std::fs::read(data_dir.secrets_path().join("certs").join("key.pem")) + .ok() + .and_then(|key| PrivateKeyDer::from_pem_slice(&key).ok()) + }, + |key| Some(key.clone_key()), + ); + let tls_cert = self.tls_cert.clone().map_or_else( + || { + std::fs::read(data_dir.secrets_path().join("certs").join("cert.pem")) + .ok() + .and_then(|cert| CertificateDer::from_pem_slice(&cert).ok()) + }, + Some, + ); + let mut set = JoinSet::new(); { let (addr, router) = self.main_router.clone(); - let data_dir = self.state.data_dir().clone(); - set.spawn(async move { Self::start_listen(&addr, router, data_dir).await }); + let (tls_key, tls_cert) = (tls_key.as_ref().map(|k| k.clone_key()), tls_cert.clone()); + + set.spawn(async move { Self::start_listen(&addr, router, tls_key, tls_cert).await }); } if let Some((addr, router)) = self.admin_router.clone() { - let data_dir = self.state.data_dir().clone(); - set.spawn(async move { Self::start_listen(&addr, router, data_dir).await }); + set.spawn(async move { Self::start_listen(&addr, router, tls_key, tls_cert).await }); } log::info!( @@ -197,15 +206,14 @@ impl Server { return Ok(()); } - async fn start_listen(addr: &str, router: Router<()>, data_dir: DataDir) { - let key_path = data_dir.secrets_path().join("certs").join("key.pem"); - let cert_path = data_dir.secrets_path().join("certs").join("cert.pem"); - - let key = tokio::fs::read(key_path).await; - let cert = tokio::fs::read(cert_path).await; - - match (key, cert) { - (Ok(key), Ok(cert)) => { + async fn start_listen( + addr: &str, + router: Router<()>, + tls_key: Option>, + tls_cert: Option>, + ) { + match (tls_key, tls_cert) { + (Some(key), Some(cert)) => { let tcp_listener = match tokio::net::TcpListener::bind(addr).await { Ok(listener) => listener, Err(err) => { @@ -214,8 +222,10 @@ impl Server { } }; - let server_config = rustls_server_config(key, cert) - .expect("Found TLS key and cert but failed to build valid server config."); + let server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert], key) + .expect("Failed to build server config"); let listener = serve::TlsListener { listener: tcp_listener, @@ -480,20 +490,6 @@ async fn shutdown_signal() { } } -fn rustls_server_config(key: Vec, cert: Vec) -> Result { - let key = PrivateKeyDer::from_pem_slice(&key)?; - let certs = CertificateDer::from_pem_slice(&cert)?; - - let mut config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(vec![certs], key) - .expect("Failed to build server config"); - - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - return Ok(config); -} - #[derive(RustEmbed, Clone)] #[folder = "js/admin/dist/"] struct AdminAssets; diff --git a/trailbase-core/tests/admin_permissions_test.rs b/trailbase-core/tests/admin_permissions_test.rs index 2900186d..d933b6d8 100644 --- a/trailbase-core/tests/admin_permissions_test.rs +++ b/trailbase-core/tests/admin_permissions_test.rs @@ -1,17 +1,14 @@ use axum::http::StatusCode; use axum_test::TestServer; -use std::rc::Rc; use trailbase::{DataDir, Server, ServerOptions}; #[test] fn test_admin_permissions() { - let runtime = Rc::new( - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(), - ); + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); let data_dir = temp_dir::TempDir::new().unwrap(); @@ -25,9 +22,11 @@ fn test_admin_permissions() { disable_auth_ui: false, cors_allowed_origins: vec![], js_runtime_threads: None, + ..Default::default() }) .await .unwrap(); + let server = TestServer::new(app.router().clone()).unwrap(); assert_eq!( diff --git a/trailbase-core/tests/integration_test.rs b/trailbase-core/tests/integration_test.rs index 3ccf0542..7414fa1c 100644 --- a/trailbase-core/tests/integration_test.rs +++ b/trailbase-core/tests/integration_test.rs @@ -2,7 +2,6 @@ use axum::extract::{Json, State}; use axum::http::StatusCode; use axum_test::multipart::MultipartForm; use axum_test::TestServer; -use std::rc::Rc; use tower_cookies::Cookie; use tracing_subscriber::prelude::*; use trailbase_sqlite::params; @@ -17,12 +16,10 @@ use trailbase::{DataDir, Server, ServerOptions}; #[test] fn integration_tests() { - let runtime = Rc::new( - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(), - ); + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); let _ = runtime.block_on(test_record_apis()); } @@ -39,6 +36,7 @@ async fn test_record_apis() { disable_auth_ui: false, cors_allowed_origins: vec![], js_runtime_threads: None, + ..Default::default() }) .await .unwrap(); diff --git a/trailbase-core/tests/tls_test.rs b/trailbase-core/tests/tls_test.rs new file mode 100644 index 00000000..be30dfe3 --- /dev/null +++ b/trailbase-core/tests/tls_test.rs @@ -0,0 +1,73 @@ +use rcgen::{generate_simple_self_signed, CertifiedKey}; +use tokio_rustls::rustls::pki_types::{pem::PemObject, PrivateKeyDer}; +use trailbase::{DataDir, Server, ServerOptions}; + +#[test] +fn test_https_serving() { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let data_dir = temp_dir::TempDir::new().unwrap(); + + // Generate a certificate valid for "trailbase.io" and "localhost". + let subject_alt_names = vec!["trailbase.io".to_string(), "localhost".to_string()]; + + let CertifiedKey { cert, key_pair } = generate_simple_self_signed(subject_alt_names).unwrap(); + + let _ = runtime.block_on(async move { + let port = 4025; + let address = format!("127.0.0.1:{port}"); + + let tls_pem = key_pair.serialize_pem(); + let tls_key = PrivateKeyDer::from_pem_slice(tls_pem.as_bytes()).unwrap(); + + let app = Server::init(ServerOptions { + data_dir: DataDir(data_dir.path().to_path_buf()), + address: address.to_string(), + admin_address: None, + public_dir: None, + dev: false, + disable_auth_ui: false, + cors_allowed_origins: vec![], + js_runtime_threads: None, + tls_key: Some(tls_key), + tls_cert: Some(cert.der().clone()), + ..Default::default() + }) + .await + .unwrap(); + + let _server = tokio::spawn(async move { + app.serve().await.unwrap(); + }); + + let client = reqwest::ClientBuilder::new() + .add_root_certificate(reqwest::Certificate::from_pem(cert.pem().as_bytes()).unwrap()) + .use_rustls_tls() + .min_tls_version(reqwest::tls::Version::TLS_1_3) + .build() + .unwrap(); + + 'success: { + for _ in 0..100 { + let response = client + .get(&format!("https://localhost:{port}/api/healthcheck")) + .send() + .await; + + log::debug!("{response:?}"); + + if let Ok(response) = response { + assert_eq!(response.text().await.unwrap(), "Ok"); + break 'success; + } + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + + panic!("Timed out"); + } + }); +}