This commit is contained in:
Sebastian Jeltsch
2025-12-17 22:24:39 +01:00
parent 381b855b95
commit 7fd2b64938
6 changed files with 83 additions and 15 deletions
+4
View File
@@ -284,6 +284,10 @@ impl AppState {
return self.state.auth.value();
}
pub(crate) fn update_auth_options(&self, f: impl FnOnce(&AuthOptions) -> AuthOptions) {
return self.state.auth.update(move |o| Arc::new(f(o)));
}
pub fn site_url(&self) -> Arc<Option<url::Url>> {
return self.state.site_url.value();
}
+5 -3
View File
@@ -1,3 +1,5 @@
use std::sync::Arc;
use axum::{
extract::{Path, Query, State},
response::Redirect,
@@ -117,7 +119,7 @@ pub(crate) async fn callback_from_external_auth_provider(
async fn callback_from_oauth_provider_setting_token_cookies(
state: &AppState,
cookies: &Cookies,
provider: &OAuthProviderType,
provider: &Arc<OAuthProviderType>,
redirect: Option<String>,
auth_code: String,
server_pkce_code_verifier: String,
@@ -174,7 +176,7 @@ async fn callback_from_oauth_provider_setting_token_cookies(
async fn callback_from_oauth_provider_using_auth_code_flow(
state: &AppState,
cookies: &Cookies,
provider: &OAuthProviderType,
provider: &Arc<OAuthProviderType>,
redirect: Option<String>,
auth_code: String,
server_pkce_code_verifier: String,
@@ -234,7 +236,7 @@ async fn callback_from_oauth_provider_using_auth_code_flow(
async fn get_or_create_user(
state: &AppState,
provider: &OAuthProviderType,
provider: &Arc<OAuthProviderType>,
auth_code: String,
server_pkce_code_verifier: String,
) -> Result<DbUser, AuthError> {
+7 -6
View File
@@ -10,7 +10,7 @@ mod oidc;
pub(crate) mod test;
use std::collections::hash_map::HashMap;
use std::sync::LazyLock;
use std::sync::{Arc, LazyLock};
use thiserror::Error;
use crate::auth::oauth::OAuthProvider;
@@ -22,9 +22,10 @@ pub enum OAuthProviderError {
Missing(String),
}
pub type OAuthProviderType = Box<dyn OAuthProvider + Send + Sync>;
type OAuthFactoryType =
dyn Fn(&str, &OAuthProviderConfig) -> Result<OAuthProviderType, OAuthProviderError> + Send + Sync;
pub type OAuthProviderType = dyn OAuthProvider + Send + Sync;
type OAuthFactoryType = dyn Fn(&str, &OAuthProviderConfig) -> Result<Box<OAuthProviderType>, OAuthProviderError>
+ Send
+ Sync;
pub(crate) struct OAuthProviderFactory {
pub id: OAuthProviderId,
@@ -56,7 +57,7 @@ pub(crate) fn oauth_providers_static_registry() -> &'static [OAuthProviderFactor
pub(crate) fn build_oauth_providers_from_config(
config: AuthConfig,
) -> Result<HashMap<String, OAuthProviderType>, OAuthProviderError> {
) -> Result<HashMap<String, Arc<OAuthProviderType>>, OAuthProviderError> {
return config
.oauth_providers
.iter()
@@ -72,7 +73,7 @@ pub(crate) fn build_oauth_providers_from_config(
};
let provider = (entry.factory)(key, config)?;
return Ok((provider.name().to_string(), provider));
return Ok((provider.name().to_string(), provider.into()));
})
.collect();
}
+26 -5
View File
@@ -1,14 +1,32 @@
use log::*;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::auth::oauth::providers::{OAuthProviderType, build_oauth_providers_from_config};
use crate::auth::password::PasswordOptions;
use crate::config::proto::AuthConfig;
#[derive(Default)]
pub struct AuthOptions {
#[derive(Clone, Default)]
pub(crate) struct AuthOptions {
password_options: PasswordOptions,
oauth_providers: HashMap<String, OAuthProviderType>,
oauth_providers: HashMap<String, Arc<OAuthProviderType>>,
pub has_login_ui: bool,
pub has_register_ui: bool,
pub has_profile_ui: bool,
}
impl PartialEq for AuthOptions {
fn eq(&self, other: &Self) -> bool {
let p0: HashSet<&String> = self.oauth_providers.keys().collect();
let p1: HashSet<&String> = other.oauth_providers.keys().collect();
return p0 == p1
&& self.password_options == other.password_options
&& self.has_login_ui == other.has_login_ui
&& self.has_register_ui == other.has_register_ui
&& self.has_profile_ui == other.has_profile_ui;
}
}
#[derive(Default)]
@@ -35,6 +53,9 @@ impl AuthOptions {
error!("Failed to derive configured OAuth providers from config: {err}");
return Default::default();
}),
has_login_ui: false,
has_register_ui: false,
has_profile_ui: false,
};
}
@@ -42,7 +63,7 @@ impl AuthOptions {
return &self.password_options;
}
pub fn lookup_oauth_provider(&self, name: &str) -> Option<&OAuthProviderType> {
pub fn lookup_oauth_provider(&self, name: &str) -> Option<&Arc<OAuthProviderType>> {
if let Some(entry) = self.oauth_providers.get(name) {
return Some(entry);
}
+1 -1
View File
@@ -5,7 +5,7 @@ use std::sync::LazyLock;
use crate::auth::AuthError;
use crate::auth::user::DbUser;
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct PasswordOptions {
pub min_length: usize,
pub max_length: usize,
+40
View File
@@ -162,6 +162,9 @@ pub(crate) async fn install_routes_and_jobs(
for (method, path) in init_result.http_handlers {
debug!("Installing WASM route: {method:?}: {path}");
// TODO: Check for presence of login/register/profile auth UIs.
scan_for_auth_ui(state, method, &path);
let runtime = runtime.clone();
let registered_path = path.clone();
@@ -239,6 +242,43 @@ pub(crate) async fn install_routes_and_jobs(
return Ok(Some(router));
}
fn scan_for_auth_ui(
state: &AppState,
method: trailbase_wasm_runtime_host::HttpMethodType,
path: &str,
) {
use trailbase_wasm_runtime_host::HttpMethodType;
if method != HttpMethodType::Get {
return;
}
match path {
crate::auth::REGISTER_USER_UI => {
state.update_auth_options(|old| {
let mut new = old.clone();
new.has_register_ui = true;
return new;
});
}
crate::auth::LOGIN_UI => {
state.update_auth_options(|old| {
let mut new = old.clone();
new.has_login_ui = true;
return new;
});
}
crate::auth::PROFILE_UI => {
state.update_auth_options(|old| {
let mut new = old.clone();
new.has_profile_ui = true;
return new;
});
}
_ => {}
};
}
#[inline]
fn axum_method(method: trailbase_wasm_runtime_host::HttpMethodType) -> axum::routing::MethodFilter {
use trailbase_wasm_runtime_host::HttpMethodType;