Add a TLS test using self-signed certs.

This commit is contained in:
Sebastian Jeltsch
2025-01-29 23:01:42 +01:00
parent a478af9b31
commit 9f1a198441
8 changed files with 163 additions and 70 deletions

23
Cargo.lock generated
View File

@@ -4648,6 +4648,19 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "rcgen"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2"
dependencies = [
"pem",
"ring",
"rustls-pki-types",
"time",
"yasna",
]
[[package]]
name = "record_api_rs"
version = "0.1.0"
@@ -6567,6 +6580,7 @@ dependencies = [
"prost-reflect-build",
"quoted_printable",
"rand 0.8.5",
"rcgen",
"regex",
"reqwest",
"rusqlite",
@@ -7529,6 +7543,15 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "yasna"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
dependencies = [
"time",
]
[[package]]
name = "yoke"
version = "0.7.5"

View File

@@ -6,7 +6,7 @@ use axum::{
use tracing_subscriber::{filter, prelude::*};
use trailbase::{AppState, DataDir, Server, ServerOptions, User};
type BoxError = Box<dyn std::error::Error>;
type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub async fn handler(State(_state): State<AppState>, user: Option<User>) -> Response {
Html(format!(
@@ -35,6 +35,7 @@ async fn main() -> Result<(), BoxError> {
disable_auth_ui: false,
cors_allowed_origins: vec![],
js_runtime_threads: None,
..Default::default()
},
Some(custom_routes),
|state: AppState| async move {

View File

@@ -20,7 +20,7 @@ use trailbase_cli::{
AdminSubCommands, DefaultCommandLineArgs, JsonSchemaModeArg, SubCommands, UserSubCommands,
};
type BoxError = Box<dyn std::error::Error>;
type BoxError = Box<dyn std::error::Error + Send + Sync>;
fn init_logger(dev: bool) {
// SWC is very spammy in in debug builds and complaints about source maps when compiling
@@ -88,6 +88,8 @@ async fn async_main() -> Result<(), BoxError> {
disable_auth_ui: cmd.disable_auth_ui,
cors_allowed_origins: cmd.cors_allowed_origins,
js_runtime_threads: cmd.js_runtime_threads,
tls_key: None,
tls_cert: None,
})
.await?;

View File

@@ -103,3 +103,4 @@ quoted_printable = "0.5.1"
schemars = "0.8.21"
temp-dir = "0.1.13"
tower = { version = "0.5.0", features = ["util"] }
rcgen = "0.13.2"

View File

@@ -14,10 +14,7 @@ use std::sync::Arc;
use tokio::signal;
use tokio::task::JoinSet;
use tokio_rustls::{
rustls::pki_types::{
pem::{Error as TlsPemError, PemObject},
CertificateDer, PrivateKeyDer,
},
rustls::pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer},
rustls::ServerConfig,
TlsAcceptor,
};
@@ -40,7 +37,7 @@ pub use init::{init_app_state, InitArgs, InitError};
/// A set of options to configure serving behaviors. Changing any of these options
/// requires a server restart, which makes them a natural fit for being exposed as command line
/// arguments.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Default)]
pub struct ServerOptions {
/// Optional path to static assets that will be served at the HTTP root.
pub data_dir: DataDir,
@@ -69,6 +66,11 @@ pub struct ServerOptions {
/// Number of V8 worker threads. If set to None, default of num available cores will be used.
pub js_runtime_threads: Option<usize>,
/// TLS certificate path.
pub tls_cert: Option<CertificateDer<'static>>,
/// TLS key path.
pub tls_key: Option<PrivateKeyDer<'static>>,
}
pub struct Server {
@@ -77,6 +79,11 @@ pub struct Server {
// Routers.
main_router: (String, Router),
admin_router: Option<(String, Router)>,
/// TLS certificate path.
pub tls_cert: Option<CertificateDer<'static>>,
/// TLS key path.
pub tls_key: Option<PrivateKeyDer<'static>>,
}
impl Server {
@@ -138,6 +145,8 @@ impl Server {
state,
main_router,
admin_router,
tls_key: opts.tls_key,
tls_cert: opts.tls_cert,
})
}
@@ -149,38 +158,38 @@ impl Server {
return &self.main_router.1;
}
pub async fn serve(&self) -> Result<(), Box<dyn std::error::Error>> {
// This declares **where** tracing is being logged to, e.g. stderr, file, sqlite.
//
// NOTE: it's ok to fail. Just means someone else already initialize the tracing sub-system.
// {
// use tracing_subscriber::{filter, prelude::*};
// let _ = tracing_subscriber::registry()
// .with(
// logging::SqliteLogLayer::new(&self.state).with_filter(
// filter::Targets::new()
// .with_target("tower_http::trace::on_response", filter::LevelFilter::DEBUG)
// .with_target("tower_http::trace::on_request", filter::LevelFilter::DEBUG)
// .with_target("tower_http::trace::make_span", filter::LevelFilter::DEBUG)
// .with_default(filter::LevelFilter::INFO),
// ),
// )
// .try_init();
// }
pub async fn serve(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let _raii_tasks = scheduler::start_periodic_tasks(&self.state);
let mut set = JoinSet::new();
// NOTE: We panic if a key/cert that was explicitly specified cannot be loaded.
let data_dir = self.state.data_dir();
let tls_key = self.tls_key.as_ref().map_or_else(
|| {
std::fs::read(data_dir.secrets_path().join("certs").join("key.pem"))
.ok()
.and_then(|key| PrivateKeyDer::from_pem_slice(&key).ok())
},
|key| Some(key.clone_key()),
);
let tls_cert = self.tls_cert.clone().map_or_else(
|| {
std::fs::read(data_dir.secrets_path().join("certs").join("cert.pem"))
.ok()
.and_then(|cert| CertificateDer::from_pem_slice(&cert).ok())
},
Some,
);
let mut set = JoinSet::new();
{
let (addr, router) = self.main_router.clone();
let data_dir = self.state.data_dir().clone();
set.spawn(async move { Self::start_listen(&addr, router, data_dir).await });
let (tls_key, tls_cert) = (tls_key.as_ref().map(|k| k.clone_key()), tls_cert.clone());
set.spawn(async move { Self::start_listen(&addr, router, tls_key, tls_cert).await });
}
if let Some((addr, router)) = self.admin_router.clone() {
let data_dir = self.state.data_dir().clone();
set.spawn(async move { Self::start_listen(&addr, router, data_dir).await });
set.spawn(async move { Self::start_listen(&addr, router, tls_key, tls_cert).await });
}
log::info!(
@@ -197,15 +206,14 @@ impl Server {
return Ok(());
}
async fn start_listen(addr: &str, router: Router<()>, data_dir: DataDir) {
let key_path = data_dir.secrets_path().join("certs").join("key.pem");
let cert_path = data_dir.secrets_path().join("certs").join("cert.pem");
let key = tokio::fs::read(key_path).await;
let cert = tokio::fs::read(cert_path).await;
match (key, cert) {
(Ok(key), Ok(cert)) => {
async fn start_listen(
addr: &str,
router: Router<()>,
tls_key: Option<PrivateKeyDer<'static>>,
tls_cert: Option<CertificateDer<'static>>,
) {
match (tls_key, tls_cert) {
(Some(key), Some(cert)) => {
let tcp_listener = match tokio::net::TcpListener::bind(addr).await {
Ok(listener) => listener,
Err(err) => {
@@ -214,8 +222,10 @@ impl Server {
}
};
let server_config = rustls_server_config(key, cert)
.expect("Found TLS key and cert but failed to build valid server config.");
let server_config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert], key)
.expect("Failed to build server config");
let listener = serve::TlsListener {
listener: tcp_listener,
@@ -480,20 +490,6 @@ async fn shutdown_signal() {
}
}
fn rustls_server_config(key: Vec<u8>, cert: Vec<u8>) -> Result<ServerConfig, TlsPemError> {
let key = PrivateKeyDer::from_pem_slice(&key)?;
let certs = CertificateDer::from_pem_slice(&cert)?;
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![certs], key)
.expect("Failed to build server config");
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
return Ok(config);
}
#[derive(RustEmbed, Clone)]
#[folder = "js/admin/dist/"]
struct AdminAssets;

View File

@@ -1,17 +1,14 @@
use axum::http::StatusCode;
use axum_test::TestServer;
use std::rc::Rc;
use trailbase::{DataDir, Server, ServerOptions};
#[test]
fn test_admin_permissions() {
let runtime = Rc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap(),
);
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let data_dir = temp_dir::TempDir::new().unwrap();
@@ -25,9 +22,11 @@ fn test_admin_permissions() {
disable_auth_ui: false,
cors_allowed_origins: vec![],
js_runtime_threads: None,
..Default::default()
})
.await
.unwrap();
let server = TestServer::new(app.router().clone()).unwrap();
assert_eq!(

View File

@@ -2,7 +2,6 @@ use axum::extract::{Json, State};
use axum::http::StatusCode;
use axum_test::multipart::MultipartForm;
use axum_test::TestServer;
use std::rc::Rc;
use tower_cookies::Cookie;
use tracing_subscriber::prelude::*;
use trailbase_sqlite::params;
@@ -17,12 +16,10 @@ use trailbase::{DataDir, Server, ServerOptions};
#[test]
fn integration_tests() {
let runtime = Rc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap(),
);
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let _ = runtime.block_on(test_record_apis());
}
@@ -39,6 +36,7 @@ async fn test_record_apis() {
disable_auth_ui: false,
cors_allowed_origins: vec![],
js_runtime_threads: None,
..Default::default()
})
.await
.unwrap();

View File

@@ -0,0 +1,73 @@
use rcgen::{generate_simple_self_signed, CertifiedKey};
use tokio_rustls::rustls::pki_types::{pem::PemObject, PrivateKeyDer};
use trailbase::{DataDir, Server, ServerOptions};
#[test]
fn test_https_serving() {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let data_dir = temp_dir::TempDir::new().unwrap();
// Generate a certificate valid for "trailbase.io" and "localhost".
let subject_alt_names = vec!["trailbase.io".to_string(), "localhost".to_string()];
let CertifiedKey { cert, key_pair } = generate_simple_self_signed(subject_alt_names).unwrap();
let _ = runtime.block_on(async move {
let port = 4025;
let address = format!("127.0.0.1:{port}");
let tls_pem = key_pair.serialize_pem();
let tls_key = PrivateKeyDer::from_pem_slice(tls_pem.as_bytes()).unwrap();
let app = Server::init(ServerOptions {
data_dir: DataDir(data_dir.path().to_path_buf()),
address: address.to_string(),
admin_address: None,
public_dir: None,
dev: false,
disable_auth_ui: false,
cors_allowed_origins: vec![],
js_runtime_threads: None,
tls_key: Some(tls_key),
tls_cert: Some(cert.der().clone()),
..Default::default()
})
.await
.unwrap();
let _server = tokio::spawn(async move {
app.serve().await.unwrap();
});
let client = reqwest::ClientBuilder::new()
.add_root_certificate(reqwest::Certificate::from_pem(cert.pem().as_bytes()).unwrap())
.use_rustls_tls()
.min_tls_version(reqwest::tls::Version::TLS_1_3)
.build()
.unwrap();
'success: {
for _ in 0..100 {
let response = client
.get(&format!("https://localhost:{port}/api/healthcheck"))
.send()
.await;
log::debug!("{response:?}");
if let Ok(response) = response {
assert_eq!(response.text().await.unwrap(), "Ok");
break 'success;
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
panic!("Timed out");
}
});
}