diff --git a/src/oidc.rs b/src/oidc.rs index 702afe6..3ae5f31 100644 --- a/src/oidc.rs +++ b/src/oidc.rs @@ -1,10 +1,13 @@ use anyhow::{anyhow, Result}; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, - ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse, TokenUrl, + ClientSecret, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, TokenUrl, }; use reqwest::Client; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Mutex; +use std::time::{Duration, Instant}; use url::Url; use crate::config::Config; @@ -25,11 +28,16 @@ pub struct OidcUserInfo { pub preferred_username: Option, } +// Storage for PKCE verifiers (csrf_token -> (verifier, expiry)) +type PkceStore = Mutex>; + #[derive(Debug)] pub struct OidcClient { oauth_client: BasicClient, discovery: OidcDiscovery, http_client: Client, + is_public_client: bool, + pkce_store: PkceStore, } impl OidcClient { @@ -42,10 +50,11 @@ impl OidcClient { .oidc_client_id .as_ref() .ok_or_else(|| anyhow!("OIDC client ID not configured"))?; - let client_secret = config - .oidc_client_secret - .as_ref() - .ok_or_else(|| anyhow!("OIDC client secret not configured"))?; + + // Client secret is optional - if not provided, this is a public client + let client_secret_opt = config.oidc_client_secret.as_ref(); + let is_public_client = client_secret_opt.is_none(); + let issuer_url = config .oidc_issuer_url .as_ref() @@ -63,7 +72,7 @@ impl OidcClient { // Create OAuth2 client let oauth_client = BasicClient::new( ClientId::new(client_id.clone()), - Some(ClientSecret::new(client_secret.clone())), + client_secret_opt.map(|s| ClientSecret::new(s.clone())), AuthUrl::new(discovery.authorization_endpoint.clone())?, Some(TokenUrl::new(discovery.token_endpoint.clone())?), ) @@ -73,6 +82,8 @@ impl OidcClient { oauth_client, discovery, http_client, + is_public_client, + pkce_store: Mutex::new(HashMap::new()), }) } @@ -101,21 +112,61 @@ impl OidcClient { } pub fn get_authorization_url(&self) -> (Url, CsrfToken) { - let (pkce_challenge, _pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + // Clean up expired PKCE verifiers (older than 10 minutes) + self.cleanup_expired_verifiers(); - self.oauth_client + let mut auth_request = self.oauth_client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("openid".to_string())) .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("profile".to_string())) - .set_pkce_challenge(pkce_challenge) - .url() + .add_scope(Scope::new("profile".to_string())); + + // For public clients (no client_secret), PKCE is required for security + // For confidential clients, PKCE is optional but we don't use it to avoid state management + if self.is_public_client { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + auth_request = auth_request.set_pkce_challenge(pkce_challenge); + + // Store the verifier for later use in token exchange + let (url, csrf_token) = auth_request.url(); + let mut store = self.pkce_store.lock().unwrap(); + store.insert( + csrf_token.secret().clone(), + (pkce_verifier, Instant::now() + Duration::from_secs(600)), // 10 minute expiry + ); + (url, csrf_token) + } else { + // Confidential client - no PKCE needed + auth_request.url() + } } - pub async fn exchange_code(&self, code: &str) -> Result { - let token_result = self + fn cleanup_expired_verifiers(&self) { + let mut store = self.pkce_store.lock().unwrap(); + let now = Instant::now(); + store.retain(|_, (_, expiry)| *expiry > now); + } + + pub async fn exchange_code(&self, code: &str, state: Option<&str>) -> Result { + let mut token_request = self .oauth_client - .exchange_code(AuthorizationCode::new(code.to_string())) + .exchange_code(AuthorizationCode::new(code.to_string())); + + // For public clients, retrieve and use the PKCE verifier + if self.is_public_client { + if let Some(state_token) = state { + let mut store = self.pkce_store.lock().unwrap(); + if let Some((verifier, _)) = store.remove(state_token) { + token_request = token_request.set_pkce_verifier(verifier); + } else { + return Err(anyhow!("PKCE verifier not found for state token (expired or invalid)")); + } + } else { + return Err(anyhow!("State parameter required for public client PKCE flow")); + } + } + + let token_result = token_request .request_async(async_http_client) .await .map_err(|e| anyhow!("Failed to exchange authorization code: {}", e))?; diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 9f41e84..2ea11f2 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -207,7 +207,7 @@ async fn oidc_callback( // Exchange authorization code for access token let access_token = oidc_client - .exchange_code(&code) + .exchange_code(&code, params.state.as_deref()) .await .map_err(|e| { tracing::error!("Failed to exchange code: {}", e);