mirror of
https://github.com/btouchard/ackify.git
synced 2026-05-08 16:09:23 -05:00
feat: add silent OAuth login with auto-authentication
- Add ACKIFY_OAUTH_AUTO_LOGIN config flag (default: false) - Implement /api/auth/check endpoint for session validation - Add silent login flow with prompt=none OAuth parameter - Implement localStorage-based retry prevention (5min cooldown) - Add comprehensive OAuth flow debugging logs - Handle OAuth errors gracefully (login_required, interaction_required) - Update templates with silent login JavaScript - Add login button in header when not authenticated - Fix /health endpoint documentation (remove /healthz alias) - Extend tests to include autoLogin parameter
This commit is contained in:
@@ -72,19 +72,25 @@ func NewOAuthService(config Config) *OauthService {
|
||||
func (s *OauthService) GetUser(r *http.Request) (*models.User, error) {
|
||||
session, err := s.sessionStore.Get(r, sessionName)
|
||||
if err != nil {
|
||||
logger.Logger.Debug("GetUser: failed to get session", "error", err.Error())
|
||||
return nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
userJSON, ok := session.Values["user"].(string)
|
||||
if !ok || userJSON == "" {
|
||||
logger.Logger.Debug("GetUser: no user in session",
|
||||
"user_key_exists", ok,
|
||||
"user_json_empty", userJSON == "")
|
||||
return nil, models.ErrUnauthorized
|
||||
}
|
||||
|
||||
var user models.User
|
||||
if err := json.Unmarshal([]byte(userJSON), &user); err != nil {
|
||||
logger.Logger.Error("GetUser: failed to unmarshal user", "error", err.Error())
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("GetUser: user found", "email", user.Email)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -93,9 +99,14 @@ func (s *OauthService) SetUser(w http.ResponseWriter, r *http.Request, user *mod
|
||||
|
||||
userJSON, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
logger.Logger.Error("SetUser: failed to marshal user", "error", err.Error())
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("SetUser: saving user to session",
|
||||
"email", user.Email,
|
||||
"secure_cookies", s.secureCookies)
|
||||
|
||||
session.Values["user"] = string(userJSON)
|
||||
session.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
@@ -105,9 +116,11 @@ func (s *OauthService) SetUser(w http.ResponseWriter, r *http.Request, user *mod
|
||||
}
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
logger.Logger.Error("SetUser: failed to save session", "error", err.Error())
|
||||
return fmt.Errorf("failed to save session: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("SetUser: session saved successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -148,26 +161,57 @@ func (s *OauthService) CreateAuthURL(w http.ResponseWriter, r *http.Request, nex
|
||||
token := base64.RawURLEncoding.EncodeToString(randPart)
|
||||
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
|
||||
|
||||
logger.Logger.Debug("CreateAuthURL: generating OAuth state",
|
||||
"token_length", len(token),
|
||||
"next_url", nextURL)
|
||||
|
||||
session, _ := s.sessionStore.Get(r, sessionName)
|
||||
session.Values["oauth_state"] = token
|
||||
session.Options = &sessions.Options{Path: "/", HttpOnly: true, Secure: s.secureCookies, SameSite: http.SameSiteLaxMode}
|
||||
_ = session.Save(r, w)
|
||||
err := session.Save(r, w)
|
||||
if err != nil {
|
||||
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
|
||||
}
|
||||
|
||||
return s.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", "select_account"))
|
||||
// Check if silent login is requested
|
||||
promptParam := "select_account"
|
||||
if r.URL.Query().Get("silent") == "true" {
|
||||
promptParam = "none"
|
||||
logger.Logger.Debug("CreateAuthURL: using silent login (prompt=none)")
|
||||
}
|
||||
|
||||
authURL := s.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", promptParam))
|
||||
logger.Logger.Debug("CreateAuthURL: generated auth URL",
|
||||
"prompt", promptParam,
|
||||
"url_length", len(authURL))
|
||||
|
||||
return authURL
|
||||
}
|
||||
|
||||
// VerifyState Clear single-use state on success to prevent replay; compare in constant time to avoid timing leaks.
|
||||
func (s *OauthService) VerifyState(w http.ResponseWriter, r *http.Request, stateToken string) bool {
|
||||
session, _ := s.sessionStore.Get(r, sessionName)
|
||||
stored, _ := session.Values["oauth_state"].(string)
|
||||
|
||||
logger.Logger.Debug("VerifyState: validating OAuth state",
|
||||
"stored_length", len(stored),
|
||||
"token_length", len(stateToken),
|
||||
"stored_empty", stored == "",
|
||||
"token_empty", stateToken == "")
|
||||
|
||||
if stored == "" || stateToken == "" {
|
||||
logger.Logger.Warn("VerifyState: empty state tokens")
|
||||
return false
|
||||
}
|
||||
|
||||
if subtleConstantTimeCompare(stored, stateToken) {
|
||||
logger.Logger.Debug("VerifyState: state valid, clearing token")
|
||||
delete(session.Values, "oauth_state")
|
||||
_ = session.Save(r, w)
|
||||
return true
|
||||
}
|
||||
|
||||
logger.Logger.Warn("VerifyState: state mismatch")
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ type OAuthConfig struct {
|
||||
Scopes []string
|
||||
AllowedDomain string
|
||||
CookieSecret []byte
|
||||
AutoLogin bool
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -63,6 +64,7 @@ func Load() (*Config, error) {
|
||||
config.OAuth.ClientID = mustGetEnv("ACKIFY_OAUTH_CLIENT_ID")
|
||||
config.OAuth.ClientSecret = mustGetEnv("ACKIFY_OAUTH_CLIENT_SECRET")
|
||||
config.OAuth.AllowedDomain = os.Getenv("ACKIFY_OAUTH_ALLOWED_DOMAIN")
|
||||
config.OAuth.AutoLogin = strings.ToLower(getEnv("ACKIFY_OAUTH_AUTO_LOGIN", "false")) == "true"
|
||||
|
||||
provider := strings.ToLower(getEnv("ACKIFY_OAUTH_PROVIDER", ""))
|
||||
switch provider {
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -27,8 +29,16 @@ func (h *AuthHandlers) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
next = h.baseURL + "/"
|
||||
}
|
||||
|
||||
logger.Logger.Debug("HandleLogin: starting OAuth flow",
|
||||
"next_url", next,
|
||||
"query_params", r.URL.Query().Encode())
|
||||
|
||||
// Persist CSRF state in session when generating auth URL
|
||||
authURL := h.authService.CreateAuthURL(w, r, next)
|
||||
|
||||
logger.Logger.Debug("HandleLogin: redirecting to OAuth provider",
|
||||
"auth_url", authURL)
|
||||
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -45,11 +55,74 @@ func (h *AuthHandlers) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *AuthHandlers) HandleAuthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := h.authService.GetUser(r)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"authenticated":false}`))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]interface{}{
|
||||
"authenticated": true,
|
||||
"user": map[string]string{
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
},
|
||||
}
|
||||
|
||||
if jsonBytes, err := json.Marshal(response); err == nil {
|
||||
w.Write(jsonBytes)
|
||||
} else {
|
||||
w.Write([]byte(`{"authenticated":false}`))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AuthHandlers) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
oauthError := r.URL.Query().Get("error")
|
||||
errorDescription := r.URL.Query().Get("error_description")
|
||||
|
||||
logger.Logger.Debug("HandleOAuthCallback: received callback",
|
||||
"code_present", code != "",
|
||||
"state_present", state != "",
|
||||
"error", oauthError,
|
||||
"query_params", r.URL.Query().Encode())
|
||||
|
||||
// Gérer les erreurs OAuth (ex: prompt=none sans session active)
|
||||
if oauthError != "" {
|
||||
logger.Logger.Debug("HandleOAuthCallback: OAuth error received",
|
||||
"error", oauthError,
|
||||
"description", errorDescription)
|
||||
|
||||
// Si c'est une erreur de silent login (prompt=none), rediriger silencieusement
|
||||
if oauthError == "login_required" || oauthError == "interaction_required" || oauthError == "consent_required" {
|
||||
// Extraire next_url du state
|
||||
parts := strings.SplitN(state, ":", 2)
|
||||
nextURL := "/"
|
||||
if len(parts) == 2 {
|
||||
if nb, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
|
||||
nextURL = string(nb)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Debug("HandleOAuthCallback: silent login failed, redirecting to original URL",
|
||||
"next_url", nextURL)
|
||||
http.Redirect(w, r, nextURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Pour d'autres erreurs, afficher un message
|
||||
http.Error(w, "OAuth error: "+oauthError, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
logger.Logger.Warn("HandleOAuthCallback: missing authorization code")
|
||||
http.Error(w, "Missing authorization code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -60,7 +133,14 @@ func (h *AuthHandlers) HandleOAuthCallback(w http.ResponseWriter, r *http.Reques
|
||||
if len(parts) > 0 {
|
||||
token = parts[0]
|
||||
}
|
||||
|
||||
logger.Logger.Debug("HandleOAuthCallback: validating state",
|
||||
"token_length", len(token),
|
||||
"state_parts", len(parts))
|
||||
|
||||
if token == "" || !h.authService.VerifyState(w, r, token) {
|
||||
logger.Logger.Warn("HandleOAuthCallback: invalid OAuth state",
|
||||
"token_empty", token == "")
|
||||
http.Error(w, "Invalid OAuth state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -73,7 +153,12 @@ func (h *AuthHandlers) HandleOAuthCallback(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Debug("HandleOAuthCallback: user authenticated",
|
||||
"user_email", user.Email,
|
||||
"next_url", nextURL)
|
||||
|
||||
if err := h.authService.SetUser(w, r, user); err != nil {
|
||||
logger.Logger.Error("HandleOAuthCallback: failed to set user session", "error", err.Error())
|
||||
http.Error(w, "Failed to set user session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -84,8 +169,14 @@ func (h *AuthHandlers) HandleOAuthCallback(w http.ResponseWriter, r *http.Reques
|
||||
|
||||
if parsedURL, err := url.Parse(nextURL); err != nil ||
|
||||
(parsedURL.Host != "" && parsedURL.Host != r.Host) {
|
||||
logger.Logger.Debug("HandleOAuthCallback: invalid nextURL, using /",
|
||||
"original_next", nextURL,
|
||||
"parse_error", err != nil)
|
||||
nextURL = "/"
|
||||
}
|
||||
|
||||
logger.Logger.Debug("HandleOAuthCallback: redirecting user",
|
||||
"final_next_url", nextURL)
|
||||
|
||||
http.Redirect(w, r, nextURL, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
type fakeAuthService struct {
|
||||
shouldFailSetUser bool
|
||||
shouldFailCallback bool
|
||||
shouldFailGetUser bool
|
||||
setUserError error
|
||||
getUserError error
|
||||
callbackUser *models.User
|
||||
callbackNextURL string
|
||||
callbackError error
|
||||
@@ -29,6 +31,7 @@ type fakeAuthService struct {
|
||||
|
||||
verifyStateResult bool
|
||||
lastVerifyToken string
|
||||
currentUser *models.User
|
||||
}
|
||||
|
||||
func newFakeAuthService() *fakeAuthService {
|
||||
@@ -40,15 +43,24 @@ func newFakeAuthService() *fakeAuthService {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) SetUser(_ http.ResponseWriter, _ *http.Request, _ *models.User) error {
|
||||
func (f *fakeAuthService) GetUser(_ *http.Request) (*models.User, error) {
|
||||
if f.shouldFailGetUser {
|
||||
return nil, f.getUserError
|
||||
}
|
||||
return f.currentUser, nil
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) SetUser(_ http.ResponseWriter, _ *http.Request, user *models.User) error {
|
||||
if f.shouldFailSetUser {
|
||||
return f.setUserError
|
||||
}
|
||||
f.currentUser = user
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) Logout(_ http.ResponseWriter, _ *http.Request) {
|
||||
f.logoutCalled = true
|
||||
f.currentUser = nil
|
||||
}
|
||||
|
||||
func (f *fakeAuthService) GetLogoutURL() string {
|
||||
@@ -428,7 +440,8 @@ func TestSignatureHandlers_NewSignatureHandlers(t *testing.T) {
|
||||
organisation := "Organisation"
|
||||
adminEmails := []string{"admin@example.com"}
|
||||
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, baseURL, organisation, adminEmails)
|
||||
autoLogin := false
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, baseURL, organisation, adminEmails, autoLogin)
|
||||
|
||||
if handlers == nil {
|
||||
t.Error("NewSignatureHandlers should not return nil")
|
||||
@@ -442,6 +455,8 @@ func TestSignatureHandlers_NewSignatureHandlers(t *testing.T) {
|
||||
t.Error("BaseURL not set correctly")
|
||||
} else if handlers.organisation != organisation {
|
||||
t.Error("Organisation not set correctly")
|
||||
} else if handlers.autoLogin != autoLogin {
|
||||
t.Error("AutoLogin not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -449,7 +464,7 @@ func TestSignatureHandlers_HandleIndex(t *testing.T) {
|
||||
signatureService := newFakeSignatureService()
|
||||
userService := newFakeUserService()
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{})
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{}, false)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -538,7 +553,7 @@ func TestSignatureHandlers_HandleSignGET(t *testing.T) {
|
||||
tt.setupSig(signatureService)
|
||||
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{})
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{}, false)
|
||||
|
||||
req := httptest.NewRequest("GET", "/sign", nil)
|
||||
if tt.docParam != "" {
|
||||
@@ -638,7 +653,7 @@ func TestSignatureHandlers_HandleSignPOST(t *testing.T) {
|
||||
tt.setupSig(signatureService)
|
||||
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{})
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{}, false)
|
||||
|
||||
form := url.Values{}
|
||||
for key, value := range tt.formData {
|
||||
@@ -702,7 +717,7 @@ func TestSignatureHandlers_HandleStatusJSON(t *testing.T) {
|
||||
tt.setupSig(signatureService)
|
||||
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{})
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{}, false)
|
||||
|
||||
req := httptest.NewRequest("GET", "/status", nil)
|
||||
if tt.docParam != "" {
|
||||
@@ -769,7 +784,7 @@ func TestSignatureHandlers_HandleUserSignatures(t *testing.T) {
|
||||
tt.setupSig(signatureService)
|
||||
|
||||
tmpl := createTestTemplate()
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{})
|
||||
handlers := NewSignatureHandlers(signatureService, userService, tmpl, "https://example.com", "Organisation", []string{}, false)
|
||||
|
||||
req := httptest.NewRequest("GET", "/signatures", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type authService interface {
|
||||
GetUser(r *http.Request) (*models.User, error)
|
||||
SetUser(w http.ResponseWriter, r *http.Request, user *models.User) error
|
||||
Logout(w http.ResponseWriter, r *http.Request)
|
||||
GetLogoutURL() string
|
||||
|
||||
@@ -24,13 +24,28 @@ func NewAuthMiddleware(userService userService, baseURL string) *AuthMiddleware
|
||||
|
||||
func (m *AuthMiddleware) RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := m.userService.GetUser(r)
|
||||
user, err := m.userService.GetUser(r)
|
||||
if err != nil {
|
||||
logger.Logger.Debug("RequireAuth: user not authenticated",
|
||||
"error", err.Error(),
|
||||
"path", r.URL.Path,
|
||||
"query", r.URL.Query().Encode())
|
||||
|
||||
nextURL := m.baseURL + r.URL.RequestURI()
|
||||
loginURL := buildLoginURL(nextURL)
|
||||
|
||||
logger.Logger.Debug("RequireAuth: redirecting to login",
|
||||
"next_url", nextURL,
|
||||
"login_url", loginURL)
|
||||
|
||||
http.Redirect(w, r, loginURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Logger.Debug("RequireAuth: user authenticated",
|
||||
"user_email", user.Email,
|
||||
"path", r.URL.Path)
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,9 +32,10 @@ type SignatureHandlers struct {
|
||||
baseURL string
|
||||
organisation string
|
||||
adminEmails []string
|
||||
autoLogin bool
|
||||
}
|
||||
|
||||
func NewSignatureHandlers(signatureService signatureService, userService userService, tmpl *template.Template, baseURL, organisation string, adminEmails []string) *SignatureHandlers {
|
||||
func NewSignatureHandlers(signatureService signatureService, userService userService, tmpl *template.Template, baseURL, organisation string, adminEmails []string, autoLogin bool) *SignatureHandlers {
|
||||
return &SignatureHandlers{
|
||||
signatureService: signatureService,
|
||||
userService: userService,
|
||||
@@ -42,6 +43,7 @@ func NewSignatureHandlers(signatureService signatureService, userService userSer
|
||||
baseURL: baseURL,
|
||||
organisation: organisation,
|
||||
adminEmails: adminEmails,
|
||||
autoLogin: autoLogin,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,6 +61,7 @@ type PageData struct {
|
||||
IsAdmin bool
|
||||
Lang string
|
||||
T map[string]string
|
||||
AutoLogin bool
|
||||
ServiceInfo *struct {
|
||||
Name string
|
||||
Icon string
|
||||
@@ -283,6 +286,9 @@ func (h *SignatureHandlers) render(w http.ResponseWriter, r *http.Request, templ
|
||||
data.IsAdmin = admin.IsAdminUser(data.User, h.adminEmails)
|
||||
}
|
||||
|
||||
// Set AutoLogin from handler config
|
||||
data.AutoLogin = h.autoLogin
|
||||
|
||||
// Get language and translations from context
|
||||
ctx := r.Context()
|
||||
if data.Lang == "" {
|
||||
@@ -305,6 +311,7 @@ func (h *SignatureHandlers) render(w http.ResponseWriter, r *http.Request, templ
|
||||
"IsAdmin": data.IsAdmin,
|
||||
"Lang": data.Lang,
|
||||
"T": data.T,
|
||||
"AutoLogin": data.AutoLogin,
|
||||
}
|
||||
|
||||
if err := h.template.ExecuteTemplate(w, "base", templateData); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user