feat: add magic link authentication

- Now can activate OIDC and/or MagicLink for user authentication.
- Add page to choose authentication method (if only OIDC is enabled, auto redirecting to login screen)
This commit is contained in:
Benjamin
2025-11-04 18:15:06 +01:00
parent f4b3430f06
commit 32b469f04e
57 changed files with 6895 additions and 729 deletions

View File

@@ -11,7 +11,20 @@ POSTGRES_USER=ackifyr
POSTGRES_PASSWORD=your_secure_password
POSTGRES_DB=ackify
# OAuth2 Configuration - Generic Provider
# ============================================================================
# Authentication Configuration
# ============================================================================
# At least ONE authentication method must be enabled (OAuth or MagicLink)
#
# AUTO-DETECTION:
# - OAuth is enabled if ACKIFY_OAUTH_CLIENT_ID and ACKIFY_OAUTH_CLIENT_SECRET are set
# - MagicLink is enabled if ACKIFY_MAIL_HOST is configured
#
# You can override auto-detection with these variables:
# ACKIFY_AUTH_OAUTH_ENABLED=true
# ACKIFY_AUTH_MAGICLINK_ENABLED=true
# OAuth2 Configuration (OPTIONAL - remove if using MagicLink only)
ACKIFY_OAUTH_CLIENT_ID=your_oauth_client_id
ACKIFY_OAUTH_CLIENT_SECRET=your_oauth_client_secret
ACKIFY_OAUTH_ALLOWED_DOMAIN=your-organization.com
@@ -35,6 +48,17 @@ ACKIFY_OAUTH_PROVIDER=google
# GitLab specific (if using gitlab as provider and self-hosted)
# ACKIFY_OAUTH_GITLAB_URL=https://gitlab.your-company.com
# Email Configuration for MagicLink Authentication (OPTIONAL - required for MagicLink)
# If configured, enables passwordless authentication via email
# ACKIFY_MAIL_HOST=smtp.example.com
# ACKIFY_MAIL_PORT=587
# ACKIFY_MAIL_USERNAME=your_smtp_username
# ACKIFY_MAIL_PASSWORD=your_smtp_password
# ACKIFY_MAIL_FROM=noreply@example.com
# ACKIFY_MAIL_FROM_NAME=Ackify
# ACKIFY_MAIL_TLS=true
# ACKIFY_MAIL_STARTTLS=true
# Security Configuration
ACKIFY_OAUTH_COOKIE_SECRET=your_base64_encoded_secret_key
ACKIFY_ED25519_PRIVATE_KEY=your_base64_encoded_ed25519_private_key

139
.github/workflows/e2e-tests.yml vendored Normal file
View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
name: E2E Tests
on:
push:
branches: [main, develop]
pull_request:
branches: [main, develop]
jobs:
cypress-run:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:16-alpine
env:
POSTGRES_USER: ackify
POSTGRES_PASSWORD: testpassword
POSTGRES_DB: ackify_test
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
mailhog:
image: mailhog/mailhog:latest
ports:
- 1025:1025
- 8025:8025
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '1.24.5'
cache: true
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
cache: 'npm'
cache-dependency-path: webapp/package-lock.json
- name: Install frontend dependencies
working-directory: webapp
run: npm ci
- name: Build frontend
working-directory: webapp
run: npm run build
- name: Install backend dependencies
run: go mod download
- name: Run database migrations
env:
ACKIFY_DB_DSN: "postgres://ackify:testpassword@localhost:5432/ackify_test?sslmode=disable"
run: |
go run ./backend/cmd/migrate/main.go up
- name: Generate Ed25519 keys
run: |
go run ./backend/cmd/community/keygen.go > /tmp/ed25519.key || true
if [ ! -f /tmp/ed25519.key ]; then
echo "Generating Ed25519 key for testing"
# Generate a test key if keygen doesn't exist
echo "test_private_key_base64_encoded_here" > /tmp/ed25519.key
fi
- name: Start Ackify server
env:
ACKIFY_DB_DSN: "postgres://ackify:testpassword@localhost:5432/ackify_test?sslmode=disable"
ACKIFY_BASE_URL: "http://localhost:8080"
ACKIFY_ORGANISATION: "Ackify Test"
ACKIFY_OAUTH_PROVIDER: "github"
ACKIFY_OAUTH_CLIENT_ID: "test_client_id"
ACKIFY_OAUTH_CLIENT_SECRET: "test_client_secret"
ACKIFY_OAUTH_COOKIE_SECRET: "dGVzdF9jb29raWVfc2VjcmV0X2Zvcl90ZXN0aW5nXzEyMzQ1Njc4OTA="
ACKIFY_ED25519_PRIVATE_KEY: "dGVzdF9wcml2YXRlX2tleV9mb3JfdGVzdGluZ19vbmx5XzEyMzQ1Njc4OTA="
ACKIFY_LISTEN_ADDR: ":8080"
ACKIFY_ADMIN_EMAILS: "admin@test.com"
ACKIFY_MAIL_HOST: "localhost"
ACKIFY_MAIL_PORT: "1025"
ACKIFY_MAIL_TLS: "false"
ACKIFY_MAIL_STARTTLS: "false"
ACKIFY_MAIL_FROM: "noreply@ackify.test"
ACKIFY_MAIL_FROM_NAME: "Ackify Test"
ACKIFY_LOG_LEVEL: "debug"
run: |
go build -o ackify ./backend/cmd/community
./ackify &
echo $! > /tmp/ackify.pid
# Wait for server to be ready
timeout 30 bash -c 'until curl -s http://localhost:8080/api/v1/health > /dev/null; do sleep 1; done'
- name: Run Cypress tests
uses: cypress-io/github-action@v6
with:
working-directory: webapp
install: false
wait-on: 'http://localhost:8080/api/v1/health'
wait-on-timeout: 60
browser: chrome
headless: true
env:
CYPRESS_baseUrl: http://localhost:8080
CYPRESS_mailhogUrl: http://localhost:8025
- name: Upload Cypress screenshots
uses: actions/upload-artifact@v4
if: failure()
with:
name: cypress-screenshots
path: webapp/cypress/screenshots
retention-days: 7
- name: Upload Cypress videos
uses: actions/upload-artifact@v4
if: failure()
with:
name: cypress-videos
path: webapp/cypress/videos
retention-days: 7
- name: Stop Ackify server
if: always()
run: |
if [ -f /tmp/ackify.pid ]; then
kill $(cat /tmp/ackify.pid) || true
fi

View File

@@ -29,7 +29,7 @@ Prove that collaborators have read and acknowledged important documents with **E
**Key Features**:
- ✅ Ed25519 cryptographic signatures
-OAuth2 authentication (Google, GitHub, GitLab, custom)
-**Flexible authentication**: OAuth2 (Google, GitHub, GitLab, custom) or MagicLink (passwordless email)
- ✅ One signature per user/document (database enforced)
- ✅ Immutable audit trail
- ✅ Expected signers tracking with email reminders
@@ -45,7 +45,9 @@ Prove that collaborators have read and acknowledged important documents with **E
### Prerequisites
- Docker & Docker Compose
- OAuth2 credentials (Google, GitHub, or GitLab)
- **At least ONE authentication method**:
- OAuth2 credentials (Google, GitHub, or GitLab), OR
- SMTP server for MagicLink (passwordless email authentication)
### Installation
@@ -111,15 +113,31 @@ POSTGRES_USER=ackifyr
POSTGRES_PASSWORD=your_secure_password
POSTGRES_DB=ackify
# OAuth2 (example with Google)
# Security (generate with: openssl rand -base64 32)
ACKIFY_OAUTH_COOKIE_SECRET=your_base64_secret
# ============================================================================
# Authentication (choose AT LEAST ONE method)
# ============================================================================
# Option 1: OAuth2 (Google, GitHub, GitLab, custom)
ACKIFY_OAUTH_PROVIDER=google
ACKIFY_OAUTH_CLIENT_ID=your_client_id
ACKIFY_OAUTH_CLIENT_SECRET=your_client_secret
# Security (generate with: openssl rand -base64 32)
ACKIFY_OAUTH_COOKIE_SECRET=your_base64_secret
# Option 2: MagicLink (passwordless email authentication)
# ACKIFY_MAIL_HOST=smtp.example.com
# ACKIFY_MAIL_PORT=587
# ACKIFY_MAIL_USERNAME=your_smtp_username
# ACKIFY_MAIL_PASSWORD=your_smtp_password
# ACKIFY_MAIL_FROM=noreply@example.com
```
**Auto-detection**:
- OAuth is enabled automatically if `ACKIFY_OAUTH_CLIENT_ID` and `ACKIFY_OAUTH_CLIENT_SECRET` are set
- MagicLink is enabled automatically if `ACKIFY_MAIL_HOST` is configured
- You can use **both methods simultaneously** for maximum flexibility
See [docs/en/configuration.md](docs/en/configuration.md) for all options.
---
@@ -175,7 +193,7 @@ See [docs/en/configuration.md](docs/en/configuration.md) for all options.
https://your-domain.com/?doc=security_policy_2025
```
User authenticates via OAuth2 and signs with one click.
User authenticates (OAuth2 or MagicLink) and signs with one click.
### Embed in Your Tools

View File

@@ -29,7 +29,7 @@ Prouvez que vos collaborateurs ont lu et pris connaissance de documents importan
**Fonctionnalités clés** :
- ✅ Signatures cryptographiques Ed25519
- ✅ Authentification OAuth2 (Google, GitHub, GitLab, custom)
-**Authentification flexible** : OAuth2 (Google, GitHub, GitLab, custom) ou MagicLink (email sans mot de passe)
- ✅ Une signature par utilisateur/document (contrainte base de données)
- ✅ Piste d'audit immutable
- ✅ Tracking signataires attendus avec rappels email
@@ -45,7 +45,9 @@ Prouvez que vos collaborateurs ont lu et pris connaissance de documents importan
### Prérequis
- Docker & Docker Compose
- Credentials OAuth2 (Google, GitHub, ou GitLab)
- **Au moins UNE méthode d'authentification** :
- Credentials OAuth2 (Google, GitHub, ou GitLab), OU
- Serveur SMTP pour MagicLink (authentification email sans mot de passe)
### Installation
@@ -111,15 +113,31 @@ POSTGRES_USER=ackifyr
POSTGRES_PASSWORD=votre_mot_de_passe_securise
POSTGRES_DB=ackify
# OAuth2 (exemple avec Google)
# Sécurité (générer avec: openssl rand -base64 32)
ACKIFY_OAUTH_COOKIE_SECRET=votre_secret_base64
# ============================================================================
# Authentification (choisir AU MOINS UNE méthode)
# ============================================================================
# Option 1 : OAuth2 (Google, GitHub, GitLab, custom)
ACKIFY_OAUTH_PROVIDER=google
ACKIFY_OAUTH_CLIENT_ID=votre_client_id
ACKIFY_OAUTH_CLIENT_SECRET=votre_client_secret
# Sécurité (générer avec: openssl rand -base64 32)
ACKIFY_OAUTH_COOKIE_SECRET=votre_secret_base64
# Option 2 : MagicLink (authentification email sans mot de passe)
# ACKIFY_MAIL_HOST=smtp.example.com
# ACKIFY_MAIL_PORT=587
# ACKIFY_MAIL_USERNAME=votre_utilisateur_smtp
# ACKIFY_MAIL_PASSWORD=votre_mot_de_passe_smtp
# ACKIFY_MAIL_FROM=noreply@example.com
```
**Auto-détection** :
- OAuth activé automatiquement si `ACKIFY_OAUTH_CLIENT_ID` et `ACKIFY_OAUTH_CLIENT_SECRET` sont définis
- MagicLink activé automatiquement si `ACKIFY_MAIL_HOST` est configuré
- Vous pouvez utiliser **les deux méthodes simultanément** pour une flexibilité maximale
Voir [docs/fr/configuration.md](docs/fr/configuration.md) pour toutes les options.
---
@@ -175,7 +193,7 @@ Voir [docs/fr/configuration.md](docs/fr/configuration.md) pour toutes les option
https://votre-domaine.com/?doc=politique_securite_2025
```
L'utilisateur s'authentifie via OAuth2 et signe en un clic.
L'utilisateur s'authentifie (OAuth2 ou MagicLink) et signe en un clic.
### Intégrer dans vos Outils

View File

@@ -0,0 +1,272 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package services
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net/mail"
"net/url"
"strings"
"time"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/email"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
)
// MagicLinkRepository définit les opérations sur les tokens Magic Link
type MagicLinkRepository interface {
CreateToken(ctx context.Context, token *models.MagicLinkToken) error
GetByToken(ctx context.Context, token string) (*models.MagicLinkToken, error)
MarkAsUsed(ctx context.Context, token string, ip string, userAgent string) error
DeleteExpired(ctx context.Context) (int64, error)
LogAttempt(ctx context.Context, attempt *models.MagicLinkAuthAttempt) error
CountRecentAttempts(ctx context.Context, email string, since time.Time) (int, error)
CountRecentAttemptsByIP(ctx context.Context, ip string, since time.Time) (int, error)
}
// MagicLinkService gère l'authentification par Magic Link
type MagicLinkService struct {
repo MagicLinkRepository
emailSender email.Sender
baseURL string
appName string
allowedDomains []string // Domaines email autorisés (vide = tous)
tokenValidity time.Duration
}
// MagicLinkServiceConfig pour le service Magic Link
type MagicLinkServiceConfig struct {
Repository MagicLinkRepository
EmailSender email.Sender
BaseURL string
AppName string
AllowedDomains []string
TokenValidity time.Duration // Défaut: 15 minutes
}
func NewMagicLinkService(cfg MagicLinkServiceConfig) *MagicLinkService {
if cfg.TokenValidity == 0 {
cfg.TokenValidity = 15 * time.Minute
}
if cfg.AppName == "" {
cfg.AppName = "Ackify"
}
return &MagicLinkService{
repo: cfg.Repository,
emailSender: cfg.EmailSender,
baseURL: cfg.BaseURL,
appName: cfg.AppName,
allowedDomains: cfg.AllowedDomains,
tokenValidity: cfg.TokenValidity,
}
}
// RequestMagicLink génère et envoie un Magic Link par email
func (s *MagicLinkService) RequestMagicLink(
ctx context.Context,
emailAddr string,
redirectTo string,
ip string,
userAgent string,
) error {
// Normaliser l'email
emailAddr = strings.ToLower(strings.TrimSpace(emailAddr))
// Valider le format email
if _, err := mail.ParseAddress(emailAddr); err != nil {
s.logAttempt(ctx, emailAddr, false, "invalid_email_format", ip, userAgent)
return fmt.Errorf("invalid email format")
}
// Vérifier le domaine autorisé si configuré
if len(s.allowedDomains) > 0 {
allowed := false
for _, domain := range s.allowedDomains {
if strings.HasSuffix(emailAddr, "@"+domain) {
allowed = true
break
}
}
if !allowed {
s.logAttempt(ctx, emailAddr, false, "domain_not_allowed", ip, userAgent)
return fmt.Errorf("email domain not allowed")
}
}
// Rate limiting par email (max 3/heure)
since := time.Now().Add(-1 * time.Hour)
count, err := s.repo.CountRecentAttempts(ctx, emailAddr, since)
if err != nil {
logger.Logger.Error("Failed to check rate limit for email", "email", emailAddr, "error", err)
return fmt.Errorf("rate limit check failed")
}
if count >= 3 {
s.logAttempt(ctx, emailAddr, false, "rate_limit_exceeded_email", ip, userAgent)
// Ne pas révéler le rate limiting pour éviter l'énumération
logger.Logger.Warn("Magic Link rate limit exceeded", "email", emailAddr, "count", count)
// On retourne success pour ne pas révéler qu'on a bloqué
return nil
}
// Rate limiting par IP (max 10/heure)
countIP, err := s.repo.CountRecentAttemptsByIP(ctx, ip, since)
if err != nil {
logger.Logger.Error("Failed to check rate limit for IP", "ip", ip, "error", err)
return fmt.Errorf("rate limit check failed")
}
if countIP >= 10 {
s.logAttempt(ctx, emailAddr, false, "rate_limit_exceeded_ip", ip, userAgent)
logger.Logger.Warn("Magic Link IP rate limit exceeded", "ip", ip, "count", countIP)
return nil
}
// Générer un token cryptographiquement sécurisé
token, err := s.generateSecureToken()
if err != nil {
s.logAttempt(ctx, emailAddr, false, "token_generation_failed", ip, userAgent)
return fmt.Errorf("failed to generate token: %w", err)
}
// Créer le token en DB
magicToken := &models.MagicLinkToken{
Token: token,
Email: emailAddr,
ExpiresAt: time.Now().Add(s.tokenValidity),
RedirectTo: redirectTo,
CreatedByIP: ip,
CreatedByUserAgent: userAgent,
}
if err := s.repo.CreateToken(ctx, magicToken); err != nil {
s.logAttempt(ctx, emailAddr, false, "database_error", ip, userAgent)
return fmt.Errorf("failed to create token: %w", err)
}
// Construire le lien magique avec URL encoding du redirect
redirectEncoded := url.QueryEscape(redirectTo)
magicLink := fmt.Sprintf("%s/api/v1/auth/magic-link/verify?token=%s&redirect=%s", s.baseURL, token, redirectEncoded)
// Déterminer la locale (TODO: implémenter détection de langue préférée)
locale := "en" // Défaut
// Envoyer l'email
msg := email.Message{
To: []string{emailAddr},
Subject: "Your login link",
Template: "magic_link",
Locale: locale,
Data: map[string]interface{}{
"AppName": s.appName,
"Email": emailAddr,
"MagicLink": magicLink,
"ExpiresIn": int(s.tokenValidity.Minutes()),
"BaseURL": s.baseURL,
},
}
if err := s.emailSender.Send(ctx, msg); err != nil {
s.logAttempt(ctx, emailAddr, false, "email_send_failed", ip, userAgent)
return fmt.Errorf("failed to send email: %w", err)
}
// Log succès
s.logAttempt(ctx, emailAddr, true, "", ip, userAgent)
logger.Logger.Info("Magic Link sent successfully",
"email", emailAddr,
"expires_in", s.tokenValidity,
"ip", ip)
return nil
}
// VerifyMagicLink vérifie et consomme un token Magic Link
func (s *MagicLinkService) VerifyMagicLink(
ctx context.Context,
token string,
ip string,
userAgent string,
) (*models.MagicLinkToken, error) {
// Récupérer le token
magicToken, err := s.repo.GetByToken(ctx, token)
if err != nil {
logger.Logger.Warn("Magic Link token not found", "token_prefix", token[:min(8, len(token))])
return nil, fmt.Errorf("invalid token")
}
// Vérifier la validité
if !magicToken.IsValid() {
if magicToken.UsedAt != nil {
logger.Logger.Warn("Magic Link token already used",
"email", magicToken.Email,
"used_at", magicToken.UsedAt)
return nil, fmt.Errorf("token already used")
}
logger.Logger.Warn("Magic Link token expired",
"email", magicToken.Email,
"expires_at", magicToken.ExpiresAt)
return nil, fmt.Errorf("token expired")
}
// Marquer comme utilisé
if err := s.repo.MarkAsUsed(ctx, token, ip, userAgent); err != nil {
logger.Logger.Error("Failed to mark token as used", "error", err)
return nil, fmt.Errorf("failed to mark token as used: %w", err)
}
logger.Logger.Info("Magic Link verified successfully",
"email", magicToken.Email,
"ip", ip)
return magicToken, nil
}
// generateSecureToken génère un token cryptographiquement sécurisé
func (s *MagicLinkService) generateSecureToken() (string, error) {
bytes := make([]byte, 32) // 256 bits
if _, err := rand.Read(bytes); err != nil {
return "", err
}
// Base64 URL-safe encoding (sans padding)
return base64.RawURLEncoding.EncodeToString(bytes), nil
}
// logAttempt enregistre une tentative d'authentification
func (s *MagicLinkService) logAttempt(
ctx context.Context,
email string,
success bool,
failureReason string,
ip string,
userAgent string,
) {
attempt := &models.MagicLinkAuthAttempt{
Email: email,
Success: success,
FailureReason: failureReason,
IPAddress: ip,
UserAgent: userAgent,
}
if err := s.repo.LogAttempt(ctx, attempt); err != nil {
logger.Logger.Error("Failed to log Magic Link attempt", "error", err)
}
}
// CleanupExpiredTokens supprime les tokens expirés (à appeler périodiquement)
func (s *MagicLinkService) CleanupExpiredTokens(ctx context.Context) (int64, error) {
return s.repo.DeleteExpired(ctx)
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package models
import "time"
// MagicLinkToken représente un token de connexion Magic Link
type MagicLinkToken struct {
ID int64 `json:"id" db:"id"`
Token string `json:"token" db:"token"`
Email string `json:"email" db:"email"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
UsedAt *time.Time `json:"used_at,omitempty" db:"used_at"`
UsedByIP *string `json:"used_by_ip,omitempty" db:"used_by_ip"`
UsedByUserAgent *string `json:"used_by_user_agent,omitempty" db:"used_by_user_agent"`
RedirectTo string `json:"redirect_to" db:"redirect_to"` // URL destination après auth (ex: /?doc=xxx)
CreatedByIP string `json:"created_by_ip" db:"created_by_ip"`
CreatedByUserAgent string `json:"created_by_user_agent" db:"created_by_user_agent"`
}
// IsValid vérifie si le token est valide (non expiré, non utilisé)
func (t *MagicLinkToken) IsValid() bool {
if t.UsedAt != nil {
return false // Déjà utilisé
}
if time.Now().After(t.ExpiresAt) {
return false // Expiré
}
return true
}
// MagicLinkAuthAttempt représente une tentative d'authentification
type MagicLinkAuthAttempt struct {
ID int64 `json:"id" db:"id"`
Email string `json:"email" db:"email"`
Success bool `json:"success" db:"success"`
FailureReason string `json:"failure_reason,omitempty" db:"failure_reason"`
IPAddress string `json:"ip_address" db:"ip_address"`
UserAgent string `json:"user_agent,omitempty" db:"user_agent"`
AttemptedAt time.Time `json:"attempted_at" db:"attempted_at"`
}

View File

@@ -3,20 +3,10 @@ package auth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"golang.org/x/oauth2"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
"github.com/btouchard/ackify-ce/backend/pkg/crypto"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
)
@@ -31,16 +21,12 @@ type SessionRepository interface {
DeleteExpired(ctx context.Context, olderThan time.Duration) (int64, error)
}
// OauthService is a wrapper that composes SessionService and OAuthProvider
// SessionService is ALWAYS present (required for all auth methods)
// OAuthProvider is OPTIONAL (nil if OAuth is disabled)
type OauthService struct {
oauthConfig *oauth2.Config
sessionStore *sessions.CookieStore
userInfoURL string
logoutURL string
allowedDomain string
secureCookies bool
baseURL string
sessionRepo SessionRepository
encryptionKey []byte
SessionService *SessionService // ALWAYS present - manages user sessions
OAuthProvider *OAuthProvider // OPTIONAL - nil if OAuth disabled
}
type Config struct {
@@ -59,522 +45,92 @@ type Config struct {
}
func NewOAuthService(config Config) *OauthService {
oauthConfig := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.BaseURL + "/api/v1/auth/callback",
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
}
// Create SessionService (ALWAYS required)
sessionService := NewSessionService(SessionServiceConfig{
CookieSecret: config.CookieSecret,
SecureCookies: config.SecureCookies,
SessionRepo: config.SessionRepo,
})
sessionStore := sessions.NewCookieStore(config.CookieSecret)
// Configure session options globally on the store
sessionStore.Options = &sessions.Options{
Path: "/",
HttpOnly: true,
Secure: config.SecureCookies,
SameSite: http.SameSiteLaxMode,
MaxAge: 86400 * 30, // 30 days
}
logger.Logger.Info("OAuth session store configured",
"secure_cookies", config.SecureCookies,
"max_age_days", 30)
// Use CookieSecret as encryption key (must be 32 bytes for AES-256)
encryptionKey := config.CookieSecret
if len(encryptionKey) < 32 {
logger.Logger.Warn("Encryption key too short, padding to 32 bytes",
"original_length", len(encryptionKey))
// Pad with zeros (not ideal, but prevents crashes)
padded := make([]byte, 32)
copy(padded, encryptionKey)
encryptionKey = padded
} else if len(encryptionKey) > 32 {
// Truncate to 32 bytes for AES-256
encryptionKey = encryptionKey[:32]
// Create OAuthProvider (only if OAuth is configured)
// For now, we always create it for backward compatibility
// Later, this will be conditional based on config flags
var oauthProvider *OAuthProvider
if config.ClientID != "" && config.ClientSecret != "" {
oauthProvider = NewOAuthProvider(OAuthProviderConfig{
BaseURL: config.BaseURL,
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
UserInfoURL: config.UserInfoURL,
LogoutURL: config.LogoutURL,
Scopes: config.Scopes,
AllowedDomain: config.AllowedDomain,
SessionSvc: sessionService,
})
logger.Logger.Info("OAuth service configured with OAuth provider")
} else {
logger.Logger.Info("OAuth service configured WITHOUT OAuth provider (session-only mode)")
}
return &OauthService{
oauthConfig: oauthConfig,
sessionStore: sessionStore,
userInfoURL: config.UserInfoURL,
logoutURL: config.LogoutURL,
allowedDomain: config.AllowedDomain,
secureCookies: config.SecureCookies,
baseURL: config.BaseURL,
sessionRepo: config.SessionRepo,
encryptionKey: encryptionKey,
SessionService: sessionService,
OAuthProvider: oauthProvider,
}
}
// Session management methods - delegate to SessionService
func (s *OauthService) GetUser(r *http.Request) (*models.User, error) {
session, err := s.sessionStore.Get(r, sessionName)
if err != nil {
logger.Logger.Debug("GetUser: failed to get session", "error", err.Error())
return nil, fmt.Errorf("failed to get session: %w", err)
}
userJSON, ok := session.Values["user"].(string)
if !ok || userJSON == "" {
logger.Logger.Debug("GetUser: no user in session",
"user_key_exists", ok,
"user_json_empty", userJSON == "")
return nil, models.ErrUnauthorized
}
var user models.User
if err := json.Unmarshal([]byte(userJSON), &user); err != nil {
logger.Logger.Error("GetUser: failed to unmarshal user", "error", err.Error())
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
logger.Logger.Debug("GetUser: user found", "email", user.Email)
return &user, nil
return s.SessionService.GetUser(r)
}
func (s *OauthService) SetUser(w http.ResponseWriter, r *http.Request, user *models.User) error {
// Always create a fresh new session to ensure session ID is generated
// This fixes an issue where reusing an existing invalid session results in empty session.ID
session, err := s.sessionStore.New(r, sessionName)
if err != nil {
logger.Logger.Error("SetUser: failed to create new session", "error", err.Error())
return fmt.Errorf("failed to create new session: %w", err)
}
userJSON, err := json.Marshal(user)
if err != nil {
logger.Logger.Error("SetUser: failed to marshal user", "error", err.Error())
return fmt.Errorf("failed to marshal user: %w", err)
}
logger.Logger.Debug("SetUser: saving user to new session",
"email", user.Email,
"secure_cookies", s.secureCookies,
"session_is_new", session.IsNew)
session.Values["user"] = string(userJSON)
// Session options are already configured globally on the store
// No need to set them again here
if err := session.Save(r, w); err != nil {
logger.Logger.Error("SetUser: failed to save session",
"error", err.Error(),
"session_is_new", session.IsNew,
"session_id_length", len(session.ID))
return fmt.Errorf("failed to save session: %w", err)
}
logger.Logger.Info("SetUser: session saved successfully",
"email", user.Email,
"session_id_length", len(session.ID))
return nil
return s.SessionService.SetUser(w, r, user)
}
func (s *OauthService) Logout(w http.ResponseWriter, r *http.Request) {
session, _ := s.sessionStore.Get(r, sessionName)
session.Options.MaxAge = -1
_ = session.Save(r, w)
s.SessionService.Logout(w, r)
}
// GetLogoutURL returns the SSO logout URL if configured, otherwise returns empty string
// OAuth methods - delegate to OAuthProvider (nil-safe)
func (s *OauthService) GetLogoutURL() string {
if s.logoutURL == "" {
if s.OAuthProvider == nil {
return ""
}
// For most providers, add post_logout_redirect_uri or continue parameter
logoutURL := s.logoutURL
if s.baseURL != "" {
// Google and OIDC providers use post_logout_redirect_uri
// GitHub uses a simple redirect
// GitLab uses a redirect parameter
logoutURL += "?continue=" + s.baseURL
}
return logoutURL
}
func (s *OauthService) GetAuthURL(nextURL string) string {
state := base64.RawURLEncoding.EncodeToString(securecookie.GenerateRandomKey(20)) +
":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
return s.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", "select_account"))
return s.OAuthProvider.GetLogoutURL()
}
func (s *OauthService) CreateAuthURL(w http.ResponseWriter, r *http.Request, nextURL string) string {
// Generate PKCE code verifier and challenge
codeVerifier, err := crypto.GenerateCodeVerifier()
if err != nil {
logger.Logger.Error("Failed to generate PKCE code verifier", "error", err.Error())
// Fallback to OAuth flow without PKCE for backward compatibility
return s.createAuthURLWithoutPKCE(w, r, nextURL)
if s.OAuthProvider == nil {
logger.Logger.Error("CreateAuthURL called but OAuth provider is nil")
return ""
}
codeChallenge := crypto.GenerateCodeChallenge(codeVerifier)
logger.Logger.Debug("Generated PKCE parameters for OAuth flow")
// Generate state token
randPart := securecookie.GenerateRandomKey(20)
token := base64.RawURLEncoding.EncodeToString(randPart)
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
promptParam := "select_account"
isSilent := r.URL.Query().Get("silent") == "true"
if isSilent {
promptParam = "none"
}
logger.Logger.Info("Starting OAuth flow with PKCE",
"next_url", nextURL,
"silent", isSilent,
"state_token_length", len(token))
session, err := s.sessionStore.Get(r, sessionName)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to get session from store", "error", err.Error())
// Create a new empty session if Get fails
session, _ = s.sessionStore.New(r, sessionName)
}
// Store state and code_verifier in session
session.Values["oauth_state"] = token
session.Values["code_verifier"] = codeVerifier
err = session.Save(r, w)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
}
// Generate OAuth URL with PKCE parameters
authURL := s.oauthConfig.AuthCodeURL(state,
oauth2.SetAuthURLParam("prompt", promptParam),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"))
logger.Logger.Debug("CreateAuthURL: generated auth URL with PKCE",
"prompt", promptParam,
"url_length", len(authURL))
return authURL
}
// createAuthURLWithoutPKCE is a fallback method for OAuth without PKCE
// Used for backward compatibility if PKCE generation fails
func (s *OauthService) createAuthURLWithoutPKCE(w http.ResponseWriter, r *http.Request, nextURL string) string {
randPart := securecookie.GenerateRandomKey(20)
token := base64.RawURLEncoding.EncodeToString(randPart)
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
promptParam := "select_account"
isSilent := r.URL.Query().Get("silent") == "true"
if isSilent {
promptParam = "none"
}
logger.Logger.Warn("Starting OAuth flow WITHOUT PKCE (fallback mode)",
"next_url", nextURL,
"silent", isSilent)
session, err := s.sessionStore.Get(r, sessionName)
if err != nil {
session, _ = s.sessionStore.New(r, sessionName)
}
session.Values["oauth_state"] = token
err = session.Save(r, w)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
}
authURL := s.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", promptParam))
return authURL
return s.OAuthProvider.CreateAuthURL(w, r, nextURL)
}
func (s *OauthService) VerifyState(w http.ResponseWriter, r *http.Request, stateToken string) bool {
session, _ := s.sessionStore.Get(r, sessionName)
stored, _ := session.Values["oauth_state"].(string)
logger.Logger.Debug("VerifyState: validating OAuth state",
"stored_length", len(stored),
"token_length", len(stateToken),
"stored_empty", stored == "",
"token_empty", stateToken == "")
if stored == "" || stateToken == "" {
logger.Logger.Warn("VerifyState: empty state tokens")
if s.OAuthProvider == nil {
logger.Logger.Error("VerifyState called but OAuth provider is nil")
return false
}
if subtleConstantTimeCompare(stored, stateToken) {
logger.Logger.Debug("VerifyState: state valid, clearing token")
delete(session.Values, "oauth_state")
_ = session.Save(r, w)
return true
}
logger.Logger.Warn("VerifyState: state mismatch")
return false
}
func subtleConstantTimeCompare(a, b string) bool {
if len(a) != len(b) {
return false
}
var v byte
for i := 0; i < len(a); i++ {
v |= a[i] ^ b[i]
}
return v == 0
return s.OAuthProvider.VerifyState(w, r, stateToken)
}
func (s *OauthService) HandleCallback(ctx context.Context, w http.ResponseWriter, r *http.Request, code, state string) (*models.User, string, error) {
parts := strings.SplitN(state, ":", 2)
nextURL := "/"
if len(parts) == 2 {
if nb, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
nextURL = string(nb)
}
if s.OAuthProvider == nil {
logger.Logger.Error("HandleCallback called but OAuth provider is nil")
return nil, "/", models.ErrUnauthorized
}
logger.Logger.Debug("Processing OAuth callback",
"has_code", code != "",
"next_url", nextURL)
// Retrieve code_verifier from session for PKCE
session, _ := s.sessionStore.Get(r, sessionName)
codeVerifier, hasPKCE := session.Values["code_verifier"].(string)
// Clean up code_verifier immediately after retrieval
if hasPKCE {
delete(session.Values, "code_verifier")
_ = session.Save(r, w)
}
// Exchange authorization code for token (with or without PKCE)
var token *oauth2.Token
var err error
if hasPKCE && codeVerifier != "" {
logger.Logger.Info("OAuth token exchange with PKCE")
token, err = s.oauthConfig.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", codeVerifier))
} else {
logger.Logger.Warn("OAuth token exchange without PKCE (legacy session or fallback)")
token, err = s.oauthConfig.Exchange(ctx, code)
}
if err != nil {
logger.Logger.Error("OAuth token exchange failed",
"error", err.Error(),
"with_pkce", hasPKCE)
return nil, nextURL, fmt.Errorf("oauth exchange failed: %w", err)
}
logger.Logger.Info("OAuth token exchange successful", "with_pkce", hasPKCE)
client := s.oauthConfig.Client(ctx, token)
resp, err := client.Get(s.userInfoURL)
if err != nil || resp.StatusCode != 200 {
statusCode := 0
if resp != nil {
statusCode = resp.StatusCode
}
logger.Logger.Error("User info request failed",
"error", err,
"status_code", statusCode)
return nil, nextURL, fmt.Errorf("userinfo request failed: %w", err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
logger.Logger.Debug("User info retrieved successfully",
"status_code", resp.StatusCode)
user, err := s.parseUserInfo(resp)
if err != nil {
logger.Logger.Error("Failed to parse user info",
"error", err.Error())
return nil, nextURL, fmt.Errorf("failed to parse user info: %w", err)
}
if !s.IsAllowedDomain(user.Email) {
logger.Logger.Warn("User domain not allowed",
"user_email", user.Email,
"allowed_domain", s.allowedDomain)
return nil, nextURL, models.ErrDomainNotAllowed
}
logger.Logger.Info("OAuth callback successful",
"user_email", user.Email,
"user_name", user.Name)
// Store refresh token if available and repository is configured
if token.RefreshToken != "" && s.sessionRepo != nil && s.encryptionKey != nil {
if err := s.storeRefreshToken(ctx, w, r, token, user); err != nil {
// Log error but don't fail the authentication
logger.Logger.Error("Failed to store refresh token (non-fatal)",
"user_sub", user.Sub,
"error", err.Error())
}
}
return user, nextURL, nil
}
// storeRefreshToken encrypts and stores the OAuth refresh token
func (s *OauthService) storeRefreshToken(ctx context.Context, w http.ResponseWriter, r *http.Request, token *oauth2.Token, user *models.User) error {
// Encrypt refresh token
encryptedToken, err := crypto.EncryptToken(token.RefreshToken, s.encryptionKey)
if err != nil {
return fmt.Errorf("failed to encrypt refresh token: %w", err)
}
// Generate unique session ID for OAuth session tracking
sessionID := generateSessionID()
// Get client IP and user agent for security tracking
ipAddress := getClientIP(r)
userAgent := r.UserAgent()
// Create OAuth session
oauthSession := &models.OAuthSession{
SessionID: sessionID,
UserSub: user.Sub,
RefreshTokenEncrypted: encryptedToken,
AccessTokenExpiresAt: token.Expiry,
UserAgent: userAgent,
IPAddress: ipAddress,
}
// Save to database
if err := s.sessionRepo.Create(ctx, oauthSession); err != nil {
return fmt.Errorf("failed to create OAuth session: %w", err)
}
// Link OAuth session ID to user session
userSession, _ := s.sessionStore.Get(r, sessionName)
userSession.Values["oauth_session_id"] = sessionID
if err := userSession.Save(r, w); err != nil {
logger.Logger.Error("Failed to link OAuth session to user session",
"session_id", sessionID,
"error", err.Error())
// Don't return error, session is already created in DB
}
logger.Logger.Info("Stored encrypted refresh token",
"user_sub", user.Sub,
"session_id", sessionID,
"expires_at", token.Expiry)
return nil
}
// generateSessionID generates a unique session ID for OAuth sessions
func generateSessionID() string {
nonce, _ := crypto.GenerateNonce()
return nonce
}
// getClientIP extracts the client IP address from the request
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header (if behind proxy)
forwarded := r.Header.Get("X-Forwarded-For")
if forwarded != "" {
// Take the first IP in the list
parts := strings.Split(forwarded, ",")
if len(parts) > 0 {
return strings.TrimSpace(parts[0])
}
}
// Check X-Real-IP header
realIP := r.Header.Get("X-Real-IP")
if realIP != "" {
return realIP
}
// Fallback to RemoteAddr
return r.RemoteAddr
return s.OAuthProvider.HandleCallback(ctx, w, r, code, state)
}
func (s *OauthService) IsAllowedDomain(email string) bool {
if s.allowedDomain == "" {
if s.OAuthProvider == nil {
// If no OAuth provider, allow all domains (used for MagicLink)
return true
}
return strings.HasSuffix(
strings.ToLower(email),
"@"+strings.ToLower(s.allowedDomain),
)
}
func (s *OauthService) parseUserInfo(resp *http.Response) (*models.User, error) {
var rawUser map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&rawUser); err != nil {
return nil, fmt.Errorf("failed to decode user info: %w", err)
}
// Reduce PII in standard logs; log only keys at debug level
if rawUser != nil {
keys := make([]string, 0, len(rawUser))
for k := range rawUser {
keys = append(keys, k)
}
logger.Logger.Debug("OAuth user info received", "keys", keys)
}
user := &models.User{}
if sub, ok := rawUser["sub"].(string); ok {
user.Sub = sub
} else if id, ok := rawUser["id"]; ok {
user.Sub = fmt.Sprintf("%v", id)
} else {
return nil, fmt.Errorf("missing user ID in response")
}
if email, ok := rawUser["email"].(string); ok {
user.Email = email
} else {
return nil, fmt.Errorf("missing email in user info response")
}
var name string
if fullName, ok := rawUser["name"].(string); ok && fullName != "" {
name = fullName
} else if firstName, ok := rawUser["given_name"].(string); ok {
if lastName, ok := rawUser["family_name"].(string); ok {
name = firstName + " " + lastName
} else {
name = firstName
}
} else if cn, ok := rawUser["cn"].(string); ok && cn != "" {
name = cn
} else if displayName, ok := rawUser["display_name"].(string); ok && displayName != "" {
name = displayName
} else if preferredName, ok := rawUser["preferred_username"].(string); ok && preferredName != "" {
name = preferredName
}
user.Name = name
logger.Logger.Debug("Extracted OAuth user identifiers",
"sub", user.Sub,
"email_present", user.Email != "",
"name_present", user.Name != "")
if !user.IsValid() {
return nil, fmt.Errorf("invalid user data extracted: sub=%s, email=%s", user.Sub, user.Email)
}
return user, nil
return s.OAuthProvider.IsAllowedDomain(email)
}

View File

@@ -0,0 +1,391 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package auth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gorilla/securecookie"
"golang.org/x/oauth2"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
"github.com/btouchard/ackify-ce/backend/pkg/crypto"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
)
// OAuthProvider handles OAuth2 authentication flow
// This component is optional and can be nil if OAuth is disabled
type OAuthProvider struct {
oauthConfig *oauth2.Config
userInfoURL string
logoutURL string
allowedDomain string
baseURL string
sessionSvc *SessionService // Reference to session service for state management
}
// OAuthProviderConfig holds configuration for the OAuth provider
type OAuthProviderConfig struct {
BaseURL string
ClientID string
ClientSecret string
AuthURL string
TokenURL string
UserInfoURL string
LogoutURL string
Scopes []string
AllowedDomain string
SessionSvc *SessionService
}
// NewOAuthProvider creates a new OAuth provider
func NewOAuthProvider(config OAuthProviderConfig) *OAuthProvider {
oauthConfig := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.BaseURL + "/api/v1/auth/callback",
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
}
logger.Logger.Info("OAuth provider configured",
"client_id", config.ClientID,
"auth_url", config.AuthURL,
"redirect_url", oauthConfig.RedirectURL)
return &OAuthProvider{
oauthConfig: oauthConfig,
userInfoURL: config.UserInfoURL,
logoutURL: config.LogoutURL,
allowedDomain: config.AllowedDomain,
baseURL: config.BaseURL,
sessionSvc: config.SessionSvc,
}
}
// GetLogoutURL returns the SSO logout URL if configured
func (p *OAuthProvider) GetLogoutURL() string {
if p.logoutURL == "" {
return ""
}
logoutURL := p.logoutURL
if p.baseURL != "" {
logoutURL += "?continue=" + p.baseURL
}
return logoutURL
}
// CreateAuthURL creates an OAuth authorization URL with PKCE
func (p *OAuthProvider) CreateAuthURL(w http.ResponseWriter, r *http.Request, nextURL string) string {
// Generate PKCE code verifier and challenge
codeVerifier, err := crypto.GenerateCodeVerifier()
if err != nil {
logger.Logger.Error("Failed to generate PKCE code verifier", "error", err.Error())
// Fallback to OAuth flow without PKCE for backward compatibility
return p.createAuthURLWithoutPKCE(w, r, nextURL)
}
codeChallenge := crypto.GenerateCodeChallenge(codeVerifier)
logger.Logger.Debug("Generated PKCE parameters for OAuth flow")
// Generate state token
randPart := securecookie.GenerateRandomKey(20)
token := base64.RawURLEncoding.EncodeToString(randPart)
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
promptParam := "select_account"
isSilent := r.URL.Query().Get("silent") == "true"
if isSilent {
promptParam = "none"
}
logger.Logger.Info("Starting OAuth flow with PKCE",
"next_url", nextURL,
"silent", isSilent,
"state_token_length", len(token))
session, err := p.sessionSvc.GetSession(r)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to get session from store", "error", err.Error())
// Create a new empty session if Get fails
session, _ = p.sessionSvc.GetNewSession(r)
}
// Store state and code_verifier in session
session.Values["oauth_state"] = token
session.Values["code_verifier"] = codeVerifier
err = session.Save(r, w)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
}
// Generate OAuth URL with PKCE parameters
authURL := p.oauthConfig.AuthCodeURL(state,
oauth2.SetAuthURLParam("prompt", promptParam),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"))
logger.Logger.Debug("CreateAuthURL: generated auth URL with PKCE",
"prompt", promptParam,
"url_length", len(authURL))
return authURL
}
// createAuthURLWithoutPKCE is a fallback method for OAuth without PKCE
// Used for backward compatibility if PKCE generation fails
func (p *OAuthProvider) createAuthURLWithoutPKCE(w http.ResponseWriter, r *http.Request, nextURL string) string {
randPart := securecookie.GenerateRandomKey(20)
token := base64.RawURLEncoding.EncodeToString(randPart)
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
promptParam := "select_account"
isSilent := r.URL.Query().Get("silent") == "true"
if isSilent {
promptParam = "none"
}
logger.Logger.Warn("Starting OAuth flow WITHOUT PKCE (fallback mode)",
"next_url", nextURL,
"silent", isSilent)
session, err := p.sessionSvc.GetSession(r)
if err != nil {
session, _ = p.sessionSvc.GetNewSession(r)
}
session.Values["oauth_state"] = token
err = session.Save(r, w)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
}
authURL := p.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", promptParam))
return authURL
}
// VerifyState validates the OAuth state token for CSRF protection
func (p *OAuthProvider) VerifyState(w http.ResponseWriter, r *http.Request, stateToken string) bool {
session, _ := p.sessionSvc.GetSession(r)
stored, _ := session.Values["oauth_state"].(string)
logger.Logger.Debug("VerifyState: validating OAuth state",
"stored_length", len(stored),
"token_length", len(stateToken),
"stored_empty", stored == "",
"token_empty", stateToken == "")
if stored == "" || stateToken == "" {
logger.Logger.Warn("VerifyState: empty state tokens")
return false
}
if subtleConstantTimeCompare(stored, stateToken) {
logger.Logger.Debug("VerifyState: state valid, clearing token")
delete(session.Values, "oauth_state")
_ = session.Save(r, w)
return true
}
logger.Logger.Warn("VerifyState: state mismatch")
return false
}
// subtleConstantTimeCompare performs a timing-safe string comparison
func subtleConstantTimeCompare(a, b string) bool {
if len(a) != len(b) {
return false
}
var v byte
for i := 0; i < len(a); i++ {
v |= a[i] ^ b[i]
}
return v == 0
}
// HandleCallback processes the OAuth callback and returns the authenticated user
func (p *OAuthProvider) HandleCallback(ctx context.Context, w http.ResponseWriter, r *http.Request, code, state string) (*models.User, string, error) {
parts := strings.SplitN(state, ":", 2)
nextURL := "/"
if len(parts) == 2 {
if nb, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
nextURL = string(nb)
}
}
logger.Logger.Debug("Processing OAuth callback",
"has_code", code != "",
"next_url", nextURL)
// Retrieve code_verifier from session for PKCE
session, _ := p.sessionSvc.GetSession(r)
codeVerifier, hasPKCE := session.Values["code_verifier"].(string)
// Clean up code_verifier immediately after retrieval
if hasPKCE {
delete(session.Values, "code_verifier")
_ = session.Save(r, w)
}
// Exchange authorization code for token (with or without PKCE)
var token *oauth2.Token
var err error
if hasPKCE && codeVerifier != "" {
logger.Logger.Info("OAuth token exchange with PKCE")
token, err = p.oauthConfig.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", codeVerifier))
} else {
logger.Logger.Warn("OAuth token exchange without PKCE (legacy session or fallback)")
token, err = p.oauthConfig.Exchange(ctx, code)
}
if err != nil {
logger.Logger.Error("OAuth token exchange failed",
"error", err.Error(),
"with_pkce", hasPKCE)
return nil, nextURL, fmt.Errorf("oauth exchange failed: %w", err)
}
logger.Logger.Info("OAuth token exchange successful", "with_pkce", hasPKCE)
client := p.oauthConfig.Client(ctx, token)
resp, err := client.Get(p.userInfoURL)
if err != nil || resp.StatusCode != 200 {
statusCode := 0
if resp != nil {
statusCode = resp.StatusCode
}
logger.Logger.Error("User info request failed",
"error", err,
"status_code", statusCode)
return nil, nextURL, fmt.Errorf("userinfo request failed: %w", err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
logger.Logger.Debug("User info retrieved successfully",
"status_code", resp.StatusCode)
user, err := p.parseUserInfo(resp)
if err != nil {
logger.Logger.Error("Failed to parse user info",
"error", err.Error())
return nil, nextURL, fmt.Errorf("failed to parse user info: %w", err)
}
if !p.IsAllowedDomain(user.Email) {
logger.Logger.Warn("User domain not allowed",
"user_email", user.Email,
"allowed_domain", p.allowedDomain)
return nil, nextURL, models.ErrDomainNotAllowed
}
logger.Logger.Info("OAuth callback successful",
"user_email", user.Email,
"user_name", user.Name)
// Store refresh token if available
if token.RefreshToken != "" && p.sessionSvc.sessionRepo != nil {
if err := p.sessionSvc.StoreRefreshToken(ctx, w, r, token, user); err != nil {
// Log error but don't fail the authentication
logger.Logger.Error("Failed to store refresh token (non-fatal)",
"user_sub", user.Sub,
"error", err.Error())
}
}
return user, nextURL, nil
}
// IsAllowedDomain checks if the user's email domain is allowed
func (p *OAuthProvider) IsAllowedDomain(email string) bool {
if p.allowedDomain == "" {
return true
}
domain := strings.ToLower(p.allowedDomain)
// If domain already has @ prefix, don't add another one
if !strings.HasPrefix(domain, "@") {
domain = "@" + domain
}
return strings.HasSuffix(strings.ToLower(email), domain)
}
// parseUserInfo extracts user information from the OAuth provider response
func (p *OAuthProvider) parseUserInfo(resp *http.Response) (*models.User, error) {
var rawUser map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&rawUser); err != nil {
return nil, fmt.Errorf("failed to decode user info: %w", err)
}
// Reduce PII in standard logs; log only keys at debug level
if rawUser != nil {
keys := make([]string, 0, len(rawUser))
for k := range rawUser {
keys = append(keys, k)
}
logger.Logger.Debug("OAuth user info received", "keys", keys)
}
user := &models.User{}
if sub, ok := rawUser["sub"].(string); ok {
user.Sub = sub
} else if id, ok := rawUser["id"]; ok {
user.Sub = fmt.Sprintf("%v", id)
} else {
return nil, fmt.Errorf("missing user ID in response")
}
if email, ok := rawUser["email"].(string); ok {
user.Email = email
} else {
return nil, fmt.Errorf("missing email in user info response")
}
var name string
if fullName, ok := rawUser["name"].(string); ok && fullName != "" {
name = fullName
} else if firstName, ok := rawUser["given_name"].(string); ok {
if lastName, ok := rawUser["family_name"].(string); ok {
name = firstName + " " + lastName
} else {
name = firstName
}
} else if cn, ok := rawUser["cn"].(string); ok && cn != "" {
name = cn
} else if displayName, ok := rawUser["display_name"].(string); ok && displayName != "" {
name = displayName
} else if preferredName, ok := rawUser["preferred_username"].(string); ok && preferredName != "" {
name = preferredName
}
user.Name = name
logger.Logger.Debug("Extracted OAuth user identifiers",
"sub", user.Sub,
"email_present", user.Email != "",
"name_present", user.Name != "")
if !user.IsValid() {
return nil, fmt.Errorf("invalid user data extracted: sub=%s, email=%s", user.Sub, user.Email)
}
return user, nil
}

View File

@@ -0,0 +1,513 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package auth
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
)
func TestNewOAuthProvider(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
tests := []struct {
name string
config OAuthProviderConfig
}{
{
name: "complete config",
config: OAuthProviderConfig{
BaseURL: "https://ackify.example.com",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
AuthURL: "https://provider.com/auth",
TokenURL: "https://provider.com/token",
UserInfoURL: "https://provider.com/userinfo",
LogoutURL: "https://provider.com/logout",
Scopes: []string{"openid", "email", "profile"},
AllowedDomain: "@example.com",
SessionSvc: sessionSvc,
},
},
{
name: "minimal config",
config: OAuthProviderConfig{
BaseURL: "http://localhost:8080",
ClientID: "minimal-client",
ClientSecret: "minimal-secret",
AuthURL: "https://auth.com/oauth",
TokenURL: "https://auth.com/token",
UserInfoURL: "https://api.com/user",
Scopes: []string{"user"},
SessionSvc: sessionSvc,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := NewOAuthProvider(tt.config)
if provider == nil {
t.Fatal("NewOAuthProvider() returned nil")
}
// Test that OAuth config is properly initialized
if provider.oauthConfig == nil {
t.Error("OAuth config should not be nil")
}
if provider.oauthConfig.ClientID != tt.config.ClientID {
t.Errorf("ClientID = %v, expected %v", provider.oauthConfig.ClientID, tt.config.ClientID)
}
if provider.oauthConfig.ClientSecret != tt.config.ClientSecret {
t.Errorf("ClientSecret = %v, expected %v", provider.oauthConfig.ClientSecret, tt.config.ClientSecret)
}
expectedRedirectURL := tt.config.BaseURL + "/api/v1/auth/callback"
if provider.oauthConfig.RedirectURL != expectedRedirectURL {
t.Errorf("RedirectURL = %v, expected %v", provider.oauthConfig.RedirectURL, expectedRedirectURL)
}
if len(provider.oauthConfig.Scopes) != len(tt.config.Scopes) {
t.Errorf("Scopes length = %v, expected %v", len(provider.oauthConfig.Scopes), len(tt.config.Scopes))
}
if provider.oauthConfig.Endpoint.AuthURL != tt.config.AuthURL {
t.Errorf("AuthURL = %v, expected %v", provider.oauthConfig.Endpoint.AuthURL, tt.config.AuthURL)
}
if provider.oauthConfig.Endpoint.TokenURL != tt.config.TokenURL {
t.Errorf("TokenURL = %v, expected %v", provider.oauthConfig.Endpoint.TokenURL, tt.config.TokenURL)
}
// Test provider fields
if provider.userInfoURL != tt.config.UserInfoURL {
t.Errorf("userInfoURL = %v, expected %v", provider.userInfoURL, tt.config.UserInfoURL)
}
if provider.logoutURL != tt.config.LogoutURL {
t.Errorf("logoutURL = %v, expected %v", provider.logoutURL, tt.config.LogoutURL)
}
if provider.allowedDomain != tt.config.AllowedDomain {
t.Errorf("allowedDomain = %v, expected %v", provider.allowedDomain, tt.config.AllowedDomain)
}
if provider.sessionSvc == nil {
t.Error("sessionSvc should not be nil")
}
})
}
}
func TestOAuthProvider_GetLogoutURL(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
tests := []struct {
name string
logoutURL string
baseURL string
expectedURL string
}{
{
name: "with logout URL and base URL",
logoutURL: "https://provider.com/logout",
baseURL: "https://ackify.example.com",
expectedURL: "https://provider.com/logout?continue=https://ackify.example.com",
},
{
name: "with logout URL only",
logoutURL: "https://provider.com/logout",
baseURL: "",
expectedURL: "https://provider.com/logout",
},
{
name: "without logout URL",
logoutURL: "",
baseURL: "https://ackify.example.com",
expectedURL: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := &OAuthProvider{
logoutURL: tt.logoutURL,
baseURL: tt.baseURL,
sessionSvc: sessionSvc,
}
result := provider.GetLogoutURL()
if result != tt.expectedURL {
t.Errorf("GetLogoutURL() = %v, expected %v", result, tt.expectedURL)
}
})
}
}
func TestOAuthProvider_IsAllowedDomain(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
tests := []struct {
name string
allowedDomain string
email string
expected bool
}{
{
name: "allowed domain match",
allowedDomain: "@example.com",
email: "user@example.com",
expected: true,
},
{
name: "allowed domain mismatch",
allowedDomain: "@example.com",
email: "user@other.com",
expected: false,
},
{
name: "no restriction",
allowedDomain: "",
email: "user@any.com",
expected: true,
},
{
name: "case insensitive match",
allowedDomain: "@Example.Com",
email: "user@EXAMPLE.com",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := &OAuthProvider{
allowedDomain: tt.allowedDomain,
sessionSvc: sessionSvc,
}
result := provider.IsAllowedDomain(tt.email)
if result != tt.expected {
t.Errorf("IsAllowedDomain(%v) = %v, expected %v", tt.email, result, tt.expected)
}
})
}
}
func TestOAuthProvider_CreateAuthURL(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
config := OAuthProviderConfig{
BaseURL: "https://ackify.example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
AuthURL: "https://provider.com/auth",
TokenURL: "https://provider.com/token",
UserInfoURL: "https://provider.com/userinfo",
Scopes: []string{"openid", "email"},
SessionSvc: sessionSvc,
}
provider := NewOAuthProvider(config)
t.Run("creates auth URL with PKCE", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
authURL := provider.CreateAuthURL(rec, req, "/dashboard")
if authURL == "" {
t.Fatal("CreateAuthURL() returned empty string")
}
// Parse the URL and check parameters
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
// Check for required OAuth parameters
if query.Get("client_id") == "" {
t.Error("client_id parameter missing")
}
if query.Get("redirect_uri") == "" {
t.Error("redirect_uri parameter missing")
}
if query.Get("response_type") != "code" {
t.Error("response_type should be 'code'")
}
if query.Get("state") == "" {
t.Error("state parameter missing")
}
// Check for PKCE parameters
if query.Get("code_challenge") == "" {
t.Error("code_challenge parameter missing (PKCE should be enabled)")
}
if query.Get("code_challenge_method") != "S256" {
t.Error("code_challenge_method should be 'S256'")
}
})
t.Run("creates auth URL with silent flag", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/?silent=true", nil)
authURL := provider.CreateAuthURL(rec, req, "/dashboard")
parsedURL, _ := url.Parse(authURL)
query := parsedURL.Query()
if query.Get("prompt") != "none" {
t.Errorf("prompt parameter = %v, expected 'none' for silent auth", query.Get("prompt"))
}
})
}
func TestOAuthProvider_VerifyState(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
config := OAuthProviderConfig{
BaseURL: "https://ackify.example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
AuthURL: "https://provider.com/auth",
TokenURL: "https://provider.com/token",
UserInfoURL: "https://provider.com/userinfo",
Scopes: []string{"openid"},
SessionSvc: sessionSvc,
}
provider := NewOAuthProvider(config)
t.Run("valid state", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
// Create auth URL to generate state
_ = provider.CreateAuthURL(rec, req, "/")
// Get session with state
req2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range rec.Result().Cookies() {
req2.AddCookie(cookie)
}
session, _ := sessionSvc.GetSession(req2)
storedState, ok := session.Values["oauth_state"].(string)
if !ok {
t.Fatal("State not stored in session")
}
// Verify state
rec2 := httptest.NewRecorder()
valid := provider.VerifyState(rec2, req2, storedState)
if !valid {
t.Error("VerifyState() should return true for valid state")
}
})
t.Run("invalid state", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
valid := provider.VerifyState(rec, req, "invalid-state-token")
if valid {
t.Error("VerifyState() should return false for invalid state")
}
})
t.Run("empty state", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
valid := provider.VerifyState(rec, req, "")
if valid {
t.Error("VerifyState() should return false for empty state")
}
})
}
func TestOAuthProvider_parseUserInfo(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
provider := &OAuthProvider{
sessionSvc: sessionSvc,
}
tests := []struct {
name string
responseObj map[string]interface{}
wantErr bool
checkUser func(*testing.T, *models.User)
}{
{
name: "complete user info with sub",
responseObj: map[string]interface{}{
"sub": "12345",
"email": "user@example.com",
"name": "Test User",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Sub != "12345" {
t.Errorf("Sub = %v, expected 12345", user.Sub)
}
if user.Email != "user@example.com" {
t.Errorf("Email = %v, expected user@example.com", user.Email)
}
if user.Name != "Test User" {
t.Errorf("Name = %v, expected Test User", user.Name)
}
},
},
{
name: "user info with id instead of sub",
responseObj: map[string]interface{}{
"id": 67890,
"email": "user@example.com",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Sub != "67890" {
t.Errorf("Sub = %v, expected 67890", user.Sub)
}
},
},
{
name: "user info with given_name and family_name",
responseObj: map[string]interface{}{
"sub": "12345",
"email": "user@example.com",
"given_name": "John",
"family_name": "Doe",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Name != "John Doe" {
t.Errorf("Name = %v, expected 'John Doe'", user.Name)
}
},
},
{
name: "missing email",
responseObj: map[string]interface{}{
"sub": "12345",
"name": "Test User",
},
wantErr: true,
},
{
name: "missing sub and id",
responseObj: map[string]interface{}{
"email": "user@example.com",
"name": "Test User",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock HTTP response
jsonData, _ := json.Marshal(tt.responseObj)
resp := &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(jsonData)),
}
user, err := provider.parseUserInfo(resp)
if tt.wantErr {
if err == nil {
t.Error("parseUserInfo() should return error")
}
return
}
if err != nil {
t.Fatalf("parseUserInfo() unexpected error: %v", err)
}
if user == nil {
t.Fatal("parseUserInfo() returned nil user")
}
if tt.checkUser != nil {
tt.checkUser(t, user)
}
})
}
}
func TestSubtleConstantTimeCompare(t *testing.T) {
tests := []struct {
name string
a string
b string
expected bool
}{
{
name: "equal strings",
a: "secret123",
b: "secret123",
expected: true,
},
{
name: "different strings",
a: "secret123",
b: "secret456",
expected: false,
},
{
name: "different lengths",
a: "short",
b: "longer string",
expected: false,
},
{
name: "empty strings",
a: "",
b: "",
expected: true,
},
{
name: "one empty",
a: "string",
b: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := subtleConstantTimeCompare(tt.a, tt.b)
if result != tt.expected {
t.Errorf("subtleConstantTimeCompare(%v, %v) = %v, expected %v", tt.a, tt.b, result, tt.expected)
}
})
}
}

View File

@@ -58,46 +58,60 @@ func TestNewOAuthService(t *testing.T) {
t.Fatal("NewOAuthService() returned nil")
}
// Test that OAuth config is properly initialized
if service.oauthConfig == nil {
t.Error("OAuth config should not be nil")
}
if service.oauthConfig.ClientID != tt.config.ClientID {
t.Errorf("ClientID = %v, expected %v", service.oauthConfig.ClientID, tt.config.ClientID)
}
if service.oauthConfig.ClientSecret != tt.config.ClientSecret {
t.Errorf("ClientSecret = %v, expected %v", service.oauthConfig.ClientSecret, tt.config.ClientSecret)
// Test that SessionService is always present
if service.SessionService == nil {
t.Fatal("SessionService should not be nil")
}
expectedRedirectURL := tt.config.BaseURL + "/api/v1/auth/callback"
if service.oauthConfig.RedirectURL != expectedRedirectURL {
t.Errorf("RedirectURL = %v, expected %v", service.oauthConfig.RedirectURL, expectedRedirectURL)
// Test that OAuthProvider is created when OAuth credentials are provided
if tt.config.ClientID != "" && tt.config.ClientSecret != "" {
if service.OAuthProvider == nil {
t.Error("OAuthProvider should not be nil when credentials are provided")
}
// Test OAuth provider config
if service.OAuthProvider.oauthConfig == nil {
t.Error("OAuth config should not be nil")
}
if service.OAuthProvider.oauthConfig.ClientID != tt.config.ClientID {
t.Errorf("ClientID = %v, expected %v", service.OAuthProvider.oauthConfig.ClientID, tt.config.ClientID)
}
if service.OAuthProvider.oauthConfig.ClientSecret != tt.config.ClientSecret {
t.Errorf("ClientSecret = %v, expected %v", service.OAuthProvider.oauthConfig.ClientSecret, tt.config.ClientSecret)
}
expectedRedirectURL := tt.config.BaseURL + "/api/v1/auth/callback"
if service.OAuthProvider.oauthConfig.RedirectURL != expectedRedirectURL {
t.Errorf("RedirectURL = %v, expected %v", service.OAuthProvider.oauthConfig.RedirectURL, expectedRedirectURL)
}
if len(service.OAuthProvider.oauthConfig.Scopes) != len(tt.config.Scopes) {
t.Errorf("Scopes length = %v, expected %v", len(service.OAuthProvider.oauthConfig.Scopes), len(tt.config.Scopes))
}
if service.OAuthProvider.oauthConfig.Endpoint.AuthURL != tt.config.AuthURL {
t.Errorf("AuthURL = %v, expected %v", service.OAuthProvider.oauthConfig.Endpoint.AuthURL, tt.config.AuthURL)
}
if service.OAuthProvider.oauthConfig.Endpoint.TokenURL != tt.config.TokenURL {
t.Errorf("TokenURL = %v, expected %v", service.OAuthProvider.oauthConfig.Endpoint.TokenURL, tt.config.TokenURL)
}
// Test OAuth provider fields
if service.OAuthProvider.userInfoURL != tt.config.UserInfoURL {
t.Errorf("userInfoURL = %v, expected %v", service.OAuthProvider.userInfoURL, tt.config.UserInfoURL)
}
if service.OAuthProvider.allowedDomain != tt.config.AllowedDomain {
t.Errorf("allowedDomain = %v, expected %v", service.OAuthProvider.allowedDomain, tt.config.AllowedDomain)
}
}
if len(service.oauthConfig.Scopes) != len(tt.config.Scopes) {
t.Errorf("Scopes length = %v, expected %v", len(service.oauthConfig.Scopes), len(tt.config.Scopes))
}
if service.oauthConfig.Endpoint.AuthURL != tt.config.AuthURL {
t.Errorf("AuthURL = %v, expected %v", service.oauthConfig.Endpoint.AuthURL, tt.config.AuthURL)
}
if service.oauthConfig.Endpoint.TokenURL != tt.config.TokenURL {
t.Errorf("TokenURL = %v, expected %v", service.oauthConfig.Endpoint.TokenURL, tt.config.TokenURL)
}
// Test service fields
if service.userInfoURL != tt.config.UserInfoURL {
t.Errorf("userInfoURL = %v, expected %v", service.userInfoURL, tt.config.UserInfoURL)
}
if service.allowedDomain != tt.config.AllowedDomain {
t.Errorf("allowedDomain = %v, expected %v", service.allowedDomain, tt.config.AllowedDomain)
}
if service.secureCookies != tt.config.SecureCookies {
t.Errorf("secureCookies = %v, expected %v", service.secureCookies, tt.config.SecureCookies)
// Test session service fields
if service.SessionService.secureCookies != tt.config.SecureCookies {
t.Errorf("secureCookies = %v, expected %v", service.SessionService.secureCookies, tt.config.SecureCookies)
}
// Test session store
if service.sessionStore == nil {
if service.SessionService.sessionStore == nil {
t.Error("Session store should not be nil")
}
})
@@ -143,7 +157,7 @@ func TestOauthService_GetUser(t *testing.T) {
{
name: "invalid JSON in session",
setupSession: func(w *httptest.ResponseRecorder, r *http.Request) {
session, _ := service.sessionStore.Get(r, sessionName)
session, _ := service.SessionService.sessionStore.Get(r, sessionName)
session.Values["user"] = "invalid-json"
session.Save(r, w)
},
@@ -152,7 +166,7 @@ func TestOauthService_GetUser(t *testing.T) {
{
name: "empty user value in session",
setupSession: func(w *httptest.ResponseRecorder, r *http.Request) {
session, _ := service.sessionStore.Get(r, sessionName)
session, _ := service.SessionService.sessionStore.Get(r, sessionName)
session.Values["user"] = ""
session.Save(r, w)
},
@@ -302,8 +316,8 @@ func TestOauthService_SetUser(t *testing.T) {
if sessionCookie.HttpOnly != true {
t.Error("Cookie should be HttpOnly")
}
if sessionCookie.Secure != tt.service.secureCookies {
t.Errorf("Cookie Secure = %v, expected %v", sessionCookie.Secure, tt.service.secureCookies)
if sessionCookie.Secure != tt.service.SessionService.secureCookies {
t.Errorf("Cookie Secure = %v, expected %v", sessionCookie.Secure, tt.service.SessionService.secureCookies)
}
if sessionCookie.SameSite != http.SameSiteLaxMode {
t.Errorf("Cookie SameSite = %v, expected %v", sessionCookie.SameSite, http.SameSiteLaxMode)
@@ -393,7 +407,9 @@ func TestOauthService_GetAuthURL(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authURL := service.GetAuthURL(tt.nextURL)
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
authURL := service.CreateAuthURL(w, r, tt.nextURL)
if authURL == "" {
t.Error("Auth URL should not be empty")
@@ -496,9 +512,20 @@ func TestOauthService_IsAllowedDomain(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service := &OauthService{
allowedDomain: tt.allowedDomain,
}
// Create service with custom allowed domain
service := NewOAuthService(Config{
BaseURL: "http://localhost:8080",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v1/userinfo",
Scopes: []string{"openid", "email", "profile"},
AllowedDomain: tt.allowedDomain,
CookieSecret: []byte("test-cookie-secret-32-bytes-long!"),
SecureCookies: false,
SessionRepo: &mockSessionRepository{},
})
result := service.IsAllowedDomain(tt.email)
if result != tt.expected {
@@ -671,7 +698,7 @@ func TestOauthService_parseUserInfo(t *testing.T) {
Body: io.NopCloser(bytes.NewReader(jsonBody)),
}
user, err := service.parseUserInfo(resp)
user, err := service.OAuthProvider.parseUserInfo(resp)
if tt.expectError {
if err == nil {
@@ -711,7 +738,7 @@ func TestOauthService_parseUserInfo_InvalidJSON(t *testing.T) {
Body: io.NopCloser(strings.NewReader("invalid json")),
}
_, err := service.parseUserInfo(resp)
_, err := service.OAuthProvider.parseUserInfo(resp)
if err == nil {
t.Error("Expected error for invalid JSON")
}
@@ -909,7 +936,7 @@ func TestOauthService_VerifyState_Success(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
// First, create a session with an oauth_state
session, _ := service.sessionStore.Get(r, sessionName)
session, _ := service.SessionService.sessionStore.Get(r, sessionName)
session.Values["oauth_state"] = "test-state-token-123"
_ = session.Save(r, w)
@@ -936,7 +963,7 @@ func TestOauthService_VerifyState_Mismatch(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
// Set state in session
session, _ := service.sessionStore.Get(r, sessionName)
session, _ := service.SessionService.sessionStore.Get(r, sessionName)
session.Values["oauth_state"] = "correct-state"
_ = session.Save(r, w)
@@ -977,7 +1004,7 @@ func TestOauthService_VerifyState_EmptyToken(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
// Set state in session
session, _ := service.sessionStore.Get(r, sessionName)
session, _ := service.SessionService.sessionStore.Get(r, sessionName)
session.Values["oauth_state"] = "some-state"
_ = session.Save(r, w)
@@ -1094,7 +1121,7 @@ func BenchmarkVerifyState(b *testing.B) {
r := httptest.NewRequest("GET", "/", nil)
// Setup session
session, _ := service.sessionStore.Get(r, sessionName)
session, _ := service.SessionService.sessionStore.Get(r, sessionName)
session.Values["oauth_state"] = "test-state-token"
_ = session.Save(r, w)

View File

@@ -0,0 +1,251 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package auth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/gorilla/sessions"
"golang.org/x/oauth2"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
"github.com/btouchard/ackify-ce/backend/pkg/crypto"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
)
// SessionService manages user sessions independently of authentication method
// This service is always required, regardless of whether OAuth or MagicLink is used
type SessionService struct {
sessionStore *sessions.CookieStore
sessionRepo SessionRepository
encryptionKey []byte
secureCookies bool
}
// SessionServiceConfig holds configuration for the session service
type SessionServiceConfig struct {
CookieSecret []byte
SecureCookies bool
SessionRepo SessionRepository
}
// NewSessionService creates a new session service
func NewSessionService(config SessionServiceConfig) *SessionService {
sessionStore := sessions.NewCookieStore(config.CookieSecret)
// Configure session options globally on the store
sessionStore.Options = &sessions.Options{
Path: "/",
HttpOnly: true,
Secure: config.SecureCookies,
SameSite: http.SameSiteLaxMode,
MaxAge: 86400 * 30, // 30 days
}
logger.Logger.Info("Session store configured",
"secure_cookies", config.SecureCookies,
"max_age_days", 30)
// Use CookieSecret as encryption key (must be 32 bytes for AES-256)
encryptionKey := config.CookieSecret
if len(encryptionKey) < 32 {
logger.Logger.Warn("Encryption key too short, padding to 32 bytes",
"original_length", len(encryptionKey))
// Pad with zeros (not ideal, but prevents crashes)
padded := make([]byte, 32)
copy(padded, encryptionKey)
encryptionKey = padded
} else if len(encryptionKey) > 32 {
// Truncate to 32 bytes for AES-256
encryptionKey = encryptionKey[:32]
}
return &SessionService{
sessionStore: sessionStore,
sessionRepo: config.SessionRepo,
encryptionKey: encryptionKey,
secureCookies: config.SecureCookies,
}
}
// GetUser retrieves the authenticated user from the session
func (s *SessionService) GetUser(r *http.Request) (*models.User, error) {
session, err := s.sessionStore.Get(r, sessionName)
if err != nil {
logger.Logger.Debug("GetUser: failed to get session", "error", err.Error())
return nil, fmt.Errorf("failed to get session: %w", err)
}
userJSON, ok := session.Values["user"].(string)
if !ok || userJSON == "" {
logger.Logger.Debug("GetUser: no user in session",
"user_key_exists", ok,
"user_json_empty", userJSON == "")
return nil, models.ErrUnauthorized
}
var user models.User
if err := json.Unmarshal([]byte(userJSON), &user); err != nil {
logger.Logger.Error("GetUser: failed to unmarshal user", "error", err.Error())
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
logger.Logger.Debug("GetUser: user found", "email", user.Email)
return &user, nil
}
// SetUser stores a user in the session (works for both OAuth and MagicLink)
func (s *SessionService) SetUser(w http.ResponseWriter, r *http.Request, user *models.User) error {
// Always create a fresh new session to ensure session ID is generated
// This fixes an issue where reusing an existing invalid session results in empty session.ID
session, err := s.sessionStore.New(r, sessionName)
if err != nil {
logger.Logger.Error("SetUser: failed to create new session", "error", err.Error())
return fmt.Errorf("failed to create new session: %w", err)
}
userJSON, err := json.Marshal(user)
if err != nil {
logger.Logger.Error("SetUser: failed to marshal user", "error", err.Error())
return fmt.Errorf("failed to marshal user: %w", err)
}
logger.Logger.Debug("SetUser: saving user to new session",
"email", user.Email,
"secure_cookies", s.secureCookies,
"session_is_new", session.IsNew)
session.Values["user"] = string(userJSON)
// Session options are already configured globally on the store
// No need to set them again here
if err := session.Save(r, w); err != nil {
logger.Logger.Error("SetUser: failed to save session",
"error", err.Error(),
"session_is_new", session.IsNew,
"session_id_length", len(session.ID))
return fmt.Errorf("failed to save session: %w", err)
}
logger.Logger.Info("SetUser: session saved successfully",
"email", user.Email,
"session_id_length", len(session.ID))
return nil
}
// Logout clears the user session
func (s *SessionService) Logout(w http.ResponseWriter, r *http.Request) {
session, _ := s.sessionStore.Get(r, sessionName)
// Clear all session values first (important for cookie-based sessions)
for key := range session.Values {
delete(session.Values, key)
}
// Set MaxAge to -1 to expire the cookie
session.Options.MaxAge = -1
// Save the cleared session
_ = session.Save(r, w)
logger.Logger.Debug("Logout: session cleared")
}
// GetSession returns the raw session (useful for storing additional data like OAuth state)
func (s *SessionService) GetSession(r *http.Request) (*sessions.Session, error) {
return s.sessionStore.Get(r, sessionName)
}
// GetNewSession creates a new session
func (s *SessionService) GetNewSession(r *http.Request) (*sessions.Session, error) {
return s.sessionStore.New(r, sessionName)
}
// StoreRefreshToken encrypts and stores the OAuth refresh token
// This is called by OAuthProvider after successful authentication
func (s *SessionService) StoreRefreshToken(ctx context.Context, w http.ResponseWriter, r *http.Request, token *oauth2.Token, user *models.User) error {
if s.sessionRepo == nil {
return fmt.Errorf("session repository not configured")
}
if s.encryptionKey == nil {
return fmt.Errorf("encryption key not configured")
}
// Encrypt refresh token
encryptedToken, err := crypto.EncryptToken(token.RefreshToken, s.encryptionKey)
if err != nil {
return fmt.Errorf("failed to encrypt refresh token: %w", err)
}
// Generate unique session ID for OAuth session tracking
sessionID := generateSessionID()
// Get client IP and user agent for security tracking
ipAddress := getClientIP(r)
userAgent := r.UserAgent()
// Create OAuth session
oauthSession := &models.OAuthSession{
SessionID: sessionID,
UserSub: user.Sub,
RefreshTokenEncrypted: encryptedToken,
AccessTokenExpiresAt: token.Expiry,
UserAgent: userAgent,
IPAddress: ipAddress,
}
// Save to database
if err := s.sessionRepo.Create(ctx, oauthSession); err != nil {
return fmt.Errorf("failed to create OAuth session: %w", err)
}
// Link OAuth session ID to user session
userSession, _ := s.sessionStore.Get(r, sessionName)
userSession.Values["oauth_session_id"] = sessionID
if err := userSession.Save(r, w); err != nil {
logger.Logger.Error("Failed to link OAuth session to user session",
"session_id", sessionID,
"error", err.Error())
// Don't return error, session is already created in DB
}
logger.Logger.Info("Stored encrypted refresh token",
"user_sub", user.Sub,
"session_id", sessionID,
"expires_at", token.Expiry)
return nil
}
// generateSessionID generates a unique session ID for OAuth sessions
func generateSessionID() string {
nonce, _ := crypto.GenerateNonce()
return nonce
}
// getClientIP extracts the client IP address from the request
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header (if behind proxy)
forwarded := r.Header.Get("X-Forwarded-For")
if forwarded != "" {
// Take the first IP in the list
parts := strings.Split(forwarded, ",")
if len(parts) > 0 {
return strings.TrimSpace(parts[0])
}
}
// Check X-Real-IP header
realIP := r.Header.Get("X-Real-IP")
if realIP != "" {
return realIP
}
// Fallback to RemoteAddr
return r.RemoteAddr
}

View File

@@ -0,0 +1,504 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package auth
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"golang.org/x/oauth2"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
)
var errSessionNotFound = errors.New("session not found")
// mockSessionRepository implements SessionRepository for testing
type mockSessionRepository struct {
sessions map[string]*models.OAuthSession
}
func newMockSessionRepository() *mockSessionRepository {
return &mockSessionRepository{
sessions: make(map[string]*models.OAuthSession),
}
}
func (m *mockSessionRepository) Create(ctx context.Context, session *models.OAuthSession) error {
m.sessions[session.SessionID] = session
return nil
}
func (m *mockSessionRepository) GetBySessionID(ctx context.Context, sessionID string) (*models.OAuthSession, error) {
session, ok := m.sessions[sessionID]
if !ok {
return nil, errSessionNotFound
}
return session, nil
}
func (m *mockSessionRepository) UpdateRefreshToken(ctx context.Context, sessionID string, encryptedToken []byte, expiresAt time.Time) error {
if session, ok := m.sessions[sessionID]; ok {
session.RefreshTokenEncrypted = encryptedToken
session.AccessTokenExpiresAt = expiresAt
return nil
}
return errSessionNotFound
}
func (m *mockSessionRepository) DeleteBySessionID(ctx context.Context, sessionID string) error {
delete(m.sessions, sessionID)
return nil
}
func (m *mockSessionRepository) DeleteExpired(ctx context.Context, olderThan time.Duration) (int64, error) {
return 0, nil
}
func TestNewSessionService(t *testing.T) {
tests := []struct {
name string
config SessionServiceConfig
}{
{
name: "complete config",
config: SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: true,
SessionRepo: newMockSessionRepository(),
},
},
{
name: "minimal config",
config: SessionServiceConfig{
CookieSecret: []byte("test-secret"),
SecureCookies: false,
},
},
{
name: "short encryption key",
config: SessionServiceConfig{
CookieSecret: []byte("short"),
SecureCookies: false,
},
},
{
name: "long encryption key",
config: SessionServiceConfig{
CookieSecret: []byte("this-is-a-very-long-encryption-key-that-exceeds-32-bytes"),
SecureCookies: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service := NewSessionService(tt.config)
if service == nil {
t.Fatal("NewSessionService() returned nil")
}
if service.sessionStore == nil {
t.Error("sessionStore should not be nil")
}
if service.secureCookies != tt.config.SecureCookies {
t.Errorf("secureCookies = %v, expected %v", service.secureCookies, tt.config.SecureCookies)
}
// Verify encryption key length is always 32 bytes
if len(service.encryptionKey) != 32 {
t.Errorf("encryption key length = %d, expected 32", len(service.encryptionKey))
}
// Verify session store options
if service.sessionStore.Options == nil {
t.Error("sessionStore.Options should not be nil")
} else {
opts := service.sessionStore.Options
if opts.Path != "/" {
t.Errorf("session path = %v, expected /", opts.Path)
}
if opts.HttpOnly != true {
t.Error("session should be HttpOnly")
}
if opts.Secure != tt.config.SecureCookies {
t.Errorf("session Secure = %v, expected %v", opts.Secure, tt.config.SecureCookies)
}
if opts.SameSite != http.SameSiteLaxMode {
t.Errorf("session SameSite = %v, expected Lax", opts.SameSite)
}
expectedMaxAge := 86400 * 30 // 30 days
if opts.MaxAge != expectedMaxAge {
t.Errorf("session MaxAge = %d, expected %d", opts.MaxAge, expectedMaxAge)
}
}
})
}
}
func TestSessionService_SetUser_GetUser(t *testing.T) {
config := SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false, // Use false for testing (no HTTPS)
}
service := NewSessionService(config)
testUser := &models.User{
Sub: "test-user-123",
Email: "test@example.com",
Name: "Test User",
}
// Test SetUser
t.Run("SetUser", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := service.SetUser(rec, req, testUser)
if err != nil {
t.Fatalf("SetUser() failed: %v", err)
}
// Check that cookie was set
cookies := rec.Result().Cookies()
if len(cookies) == 0 {
t.Fatal("No cookies were set")
}
foundSessionCookie := false
for _, cookie := range cookies {
if cookie.Name == sessionName {
foundSessionCookie = true
if cookie.Path != "/" {
t.Errorf("Cookie path = %v, expected /", cookie.Path)
}
if cookie.HttpOnly != true {
t.Error("Cookie should be HttpOnly")
}
}
}
if !foundSessionCookie {
t.Errorf("Session cookie %s not found", sessionName)
}
})
// Test GetUser with valid session
t.Run("GetUser with valid session", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
// First set the user
err := service.SetUser(rec, req, testUser)
if err != nil {
t.Fatalf("SetUser() failed: %v", err)
}
// Create a new request with the session cookie
req2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range rec.Result().Cookies() {
req2.AddCookie(cookie)
}
// Now get the user
user, err := service.GetUser(req2)
if err != nil {
t.Fatalf("GetUser() failed: %v", err)
}
if user == nil {
t.Fatal("GetUser() returned nil user")
}
if user.Sub != testUser.Sub {
t.Errorf("user.Sub = %v, expected %v", user.Sub, testUser.Sub)
}
if user.Email != testUser.Email {
t.Errorf("user.Email = %v, expected %v", user.Email, testUser.Email)
}
if user.Name != testUser.Name {
t.Errorf("user.Name = %v, expected %v", user.Name, testUser.Name)
}
})
// Test GetUser without session
t.Run("GetUser without session", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
user, err := service.GetUser(req)
if err == nil {
t.Error("GetUser() should return error without session")
}
if user != nil {
t.Error("GetUser() should return nil user without session")
}
})
}
func TestSessionService_Logout(t *testing.T) {
config := SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
}
service := NewSessionService(config)
testUser := &models.User{
Sub: "test-user-123",
Email: "test@example.com",
Name: "Test User",
}
// Create session
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := service.SetUser(rec, req, testUser)
if err != nil {
t.Fatalf("SetUser() failed: %v", err)
}
// Create request with session cookie
req2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range rec.Result().Cookies() {
req2.AddCookie(cookie)
}
// Verify user exists before logout
user, err := service.GetUser(req2)
if err != nil {
t.Fatalf("GetUser() before logout failed: %v", err)
}
if user == nil {
t.Fatal("User should exist before logout")
}
// Logout
rec2 := httptest.NewRecorder()
service.Logout(rec2, req2)
// Create request with expired session cookie
req3 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range rec2.Result().Cookies() {
req3.AddCookie(cookie)
}
// Verify user is gone after logout
user, err = service.GetUser(req3)
if err == nil {
t.Error("GetUser() after logout should return error")
}
if user != nil {
t.Error("User should be nil after logout")
}
}
func TestSessionService_GetSession(t *testing.T) {
config := SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
}
service := NewSessionService(config)
req := httptest.NewRequest("GET", "/", nil)
session, err := service.GetSession(req)
if err != nil {
t.Fatalf("GetSession() failed: %v", err)
}
if session == nil {
t.Fatal("GetSession() returned nil session")
}
// Test that we can store arbitrary data
session.Values["test_key"] = "test_value"
if val, ok := session.Values["test_key"].(string); !ok || val != "test_value" {
t.Error("Failed to store and retrieve value from session")
}
}
func TestSessionService_GetNewSession(t *testing.T) {
config := SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
}
service := NewSessionService(config)
req := httptest.NewRequest("GET", "/", nil)
session, err := service.GetNewSession(req)
if err != nil {
t.Fatalf("GetNewSession() failed: %v", err)
}
if session == nil {
t.Fatal("GetNewSession() returned nil session")
}
if !session.IsNew {
t.Error("GetNewSession() should return a new session")
}
}
func TestSessionService_StoreRefreshToken(t *testing.T) {
mockRepo := newMockSessionRepository()
config := SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
SessionRepo: mockRepo,
}
service := NewSessionService(config)
testUser := &models.User{
Sub: "test-user-123",
Email: "test@example.com",
Name: "Test User",
}
testToken := &oauth2.Token{
AccessToken: "access-token-123",
RefreshToken: "refresh-token-456",
Expiry: time.Now().Add(1 * time.Hour),
}
t.Run("successful storage", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := service.StoreRefreshToken(context.Background(), rec, req, testToken, testUser)
if err != nil {
t.Fatalf("StoreRefreshToken() failed: %v", err)
}
// Verify session was created in repository
if len(mockRepo.sessions) == 0 {
t.Error("No sessions were created in repository")
}
// Verify session contains encrypted refresh token
var session *models.OAuthSession
for _, s := range mockRepo.sessions {
if s.UserSub == testUser.Sub {
session = s
break
}
}
if session == nil {
t.Fatal("Session not found in repository")
}
if len(session.RefreshTokenEncrypted) == 0 {
t.Error("Refresh token was not encrypted and stored")
}
if session.SessionID == "" {
t.Error("Session ID should not be empty")
}
if session.AccessTokenExpiresAt.IsZero() {
t.Error("Access token expiry should be set")
}
})
t.Run("without repository", func(t *testing.T) {
configNoRepo := SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
SessionRepo: nil,
}
serviceNoRepo := NewSessionService(configNoRepo)
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := serviceNoRepo.StoreRefreshToken(context.Background(), rec, req, testToken, testUser)
if err == nil {
t.Error("StoreRefreshToken() should fail without repository")
}
})
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
xForwardedFor string
xRealIP string
expectedIP string
}{
{
name: "from RemoteAddr",
remoteAddr: "192.168.1.100:12345",
expectedIP: "192.168.1.100:12345",
},
{
name: "from X-Real-IP",
remoteAddr: "192.168.1.100:12345",
xRealIP: "203.0.113.45",
expectedIP: "203.0.113.45",
},
{
name: "from X-Forwarded-For single",
remoteAddr: "192.168.1.100:12345",
xForwardedFor: "203.0.113.45",
expectedIP: "203.0.113.45",
},
{
name: "from X-Forwarded-For multiple",
remoteAddr: "192.168.1.100:12345",
xForwardedFor: "203.0.113.45, 198.51.100.67, 192.0.2.123",
expectedIP: "203.0.113.45",
},
{
name: "X-Forwarded-For takes precedence over X-Real-IP",
remoteAddr: "192.168.1.100:12345",
xForwardedFor: "203.0.113.45",
xRealIP: "198.51.100.67",
expectedIP: "203.0.113.45",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
if tt.xRealIP != "" {
req.Header.Set("X-Real-IP", tt.xRealIP)
}
ip := getClientIP(req)
if ip != tt.expectedIP {
t.Errorf("getClientIP() = %v, expected %v", ip, tt.expectedIP)
}
})
}
}
func TestGenerateSessionID(t *testing.T) {
// Generate multiple session IDs and verify they're unique
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
id := generateSessionID()
if id == "" {
t.Error("generateSessionID() returned empty string")
}
if ids[id] {
t.Errorf("generateSessionID() generated duplicate ID: %s", id)
}
ids[id] = true
}
if len(ids) != 100 {
t.Errorf("Expected 100 unique IDs, got %d", len(ids))
}
}

View File

@@ -0,0 +1,406 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package auth
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
)
// mockSessionRepo implements SessionRepository for testing
type mockSessionRepoForWorker struct {
mu sync.Mutex
deleteExpiredFn func(ctx context.Context, olderThan time.Duration) (int64, error)
callCount int
}
func (m *mockSessionRepoForWorker) DeleteExpired(ctx context.Context, olderThan time.Duration) (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.deleteExpiredFn != nil {
return m.deleteExpiredFn(ctx, olderThan)
}
return 0, nil
}
func (m *mockSessionRepoForWorker) GetCallCount() int {
m.mu.Lock()
defer m.mu.Unlock()
return m.callCount
}
// Implement other SessionRepository methods (not used by worker)
func (m *mockSessionRepoForWorker) Create(ctx context.Context, session *models.OAuthSession) error {
return nil
}
func (m *mockSessionRepoForWorker) GetBySessionID(ctx context.Context, sessionID string) (*models.OAuthSession, error) {
return nil, nil
}
func (m *mockSessionRepoForWorker) UpdateRefreshToken(ctx context.Context, sessionID string, encryptedToken []byte, expiresAt time.Time) error {
return nil
}
func (m *mockSessionRepoForWorker) DeleteBySessionID(ctx context.Context, sessionID string) error {
return nil
}
func TestDefaultSessionWorkerConfig(t *testing.T) {
config := DefaultSessionWorkerConfig()
expectedInterval := 24 * time.Hour
if config.CleanupInterval != expectedInterval {
t.Errorf("CleanupInterval = %v, expected %v", config.CleanupInterval, expectedInterval)
}
expectedAge := 37 * 24 * time.Hour
if config.CleanupAge != expectedAge {
t.Errorf("CleanupAge = %v, expected %v", config.CleanupAge, expectedAge)
}
}
func TestNewSessionWorker(t *testing.T) {
repo := &mockSessionRepoForWorker{}
tests := []struct {
name string
config SessionWorkerConfig
expectedInterval time.Duration
expectedAge time.Duration
}{
{
name: "custom config",
config: SessionWorkerConfig{
CleanupInterval: 1 * time.Hour,
CleanupAge: 48 * time.Hour,
},
expectedInterval: 1 * time.Hour,
expectedAge: 48 * time.Hour,
},
{
name: "zero values should use defaults",
config: SessionWorkerConfig{
CleanupInterval: 0,
CleanupAge: 0,
},
expectedInterval: 24 * time.Hour,
expectedAge: 37 * 24 * time.Hour,
},
{
name: "negative values should use defaults",
config: SessionWorkerConfig{
CleanupInterval: -1 * time.Hour,
CleanupAge: -1 * time.Hour,
},
expectedInterval: 24 * time.Hour,
expectedAge: 37 * 24 * time.Hour,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
worker := NewSessionWorker(repo, tt.config)
if worker == nil {
t.Fatal("NewSessionWorker() returned nil")
}
if worker.cleanupInterval != tt.expectedInterval {
t.Errorf("cleanupInterval = %v, expected %v", worker.cleanupInterval, tt.expectedInterval)
}
if worker.cleanupAge != tt.expectedAge {
t.Errorf("cleanupAge = %v, expected %v", worker.cleanupAge, tt.expectedAge)
}
if worker.sessionRepo == nil {
t.Error("sessionRepo not set correctly")
}
if worker.started {
t.Error("Worker should not be started initially")
}
if worker.ctx == nil {
t.Error("Context should be initialized")
}
if worker.cancel == nil {
t.Error("Cancel function should be initialized")
}
if worker.stopChan == nil {
t.Error("Stop channel should be initialized")
}
})
}
}
func TestSessionWorker_StartStop(t *testing.T) {
repo := &mockSessionRepoForWorker{}
config := SessionWorkerConfig{
CleanupInterval: 100 * time.Millisecond,
CleanupAge: 1 * time.Hour,
}
worker := NewSessionWorker(repo, config)
// Test starting
err := worker.Start()
if err != nil {
t.Fatalf("Start() failed: %v", err)
}
if !worker.started {
t.Error("Worker should be marked as started")
}
// Test starting again should fail
err = worker.Start()
if err == nil {
t.Error("Starting already started worker should return error")
}
// Wait a bit for cleanup to run
time.Sleep(150 * time.Millisecond)
// Verify cleanup was called at least once
if repo.GetCallCount() < 1 {
t.Error("Cleanup should have been called at least once")
}
// Test stopping
err = worker.Stop()
if err != nil {
t.Errorf("Stop() failed: %v", err)
}
if worker.started {
t.Error("Worker should be marked as stopped")
}
// Test stopping again should fail
err = worker.Stop()
if err == nil {
t.Error("Stopping already stopped worker should return error")
}
}
func TestSessionWorker_CleanupSuccess(t *testing.T) {
deletedCount := int64(0)
repo := &mockSessionRepoForWorker{
deleteExpiredFn: func(ctx context.Context, olderThan time.Duration) (int64, error) {
deletedCount++
return deletedCount, nil
},
}
config := SessionWorkerConfig{
CleanupInterval: 50 * time.Millisecond,
CleanupAge: 24 * time.Hour,
}
worker := NewSessionWorker(repo, config)
err := worker.Start()
if err != nil {
t.Fatalf("Start() failed: %v", err)
}
// Wait for multiple cleanup cycles
time.Sleep(120 * time.Millisecond)
err = worker.Stop()
if err != nil {
t.Errorf("Stop() failed: %v", err)
}
// Should have been called at least twice (immediate + at least one tick)
if repo.GetCallCount() < 2 {
t.Errorf("Cleanup called %d times, expected at least 2", repo.GetCallCount())
}
}
func TestSessionWorker_CleanupError(t *testing.T) {
testError := errors.New("database error")
repo := &mockSessionRepoForWorker{
deleteExpiredFn: func(ctx context.Context, olderThan time.Duration) (int64, error) {
return 0, testError
},
}
config := SessionWorkerConfig{
CleanupInterval: 50 * time.Millisecond,
CleanupAge: 24 * time.Hour,
}
worker := NewSessionWorker(repo, config)
err := worker.Start()
if err != nil {
t.Fatalf("Start() failed: %v", err)
}
// Wait for cleanup to run
time.Sleep(100 * time.Millisecond)
err = worker.Stop()
if err != nil {
t.Errorf("Stop() failed: %v", err)
}
// Worker should continue despite errors
if repo.GetCallCount() < 1 {
t.Error("Cleanup should have been attempted despite errors")
}
}
func TestSessionWorker_ImmediateCleanupOnStart(t *testing.T) {
repo := &mockSessionRepoForWorker{
deleteExpiredFn: func(ctx context.Context, olderThan time.Duration) (int64, error) {
return 5, nil
},
}
// Use a long interval so we only get the immediate cleanup
config := SessionWorkerConfig{
CleanupInterval: 1 * time.Hour,
CleanupAge: 24 * time.Hour,
}
worker := NewSessionWorker(repo, config)
err := worker.Start()
if err != nil {
t.Fatalf("Start() failed: %v", err)
}
// Give it a moment to run the immediate cleanup
time.Sleep(50 * time.Millisecond)
err = worker.Stop()
if err != nil {
t.Errorf("Stop() failed: %v", err)
}
// Should have been called exactly once (immediate cleanup only)
if repo.GetCallCount() != 1 {
t.Errorf("Cleanup called %d times, expected exactly 1 (immediate cleanup)", repo.GetCallCount())
}
}
func TestSessionWorker_GracefulShutdown(t *testing.T) {
// Create a repo that takes time to cleanup
cleanupRunning := false
var mu sync.Mutex
repo := &mockSessionRepoForWorker{
deleteExpiredFn: func(ctx context.Context, olderThan time.Duration) (int64, error) {
mu.Lock()
cleanupRunning = true
mu.Unlock()
time.Sleep(50 * time.Millisecond)
mu.Lock()
cleanupRunning = false
mu.Unlock()
return 1, nil
},
}
config := SessionWorkerConfig{
CleanupInterval: 200 * time.Millisecond,
CleanupAge: 1 * time.Hour,
}
worker := NewSessionWorker(repo, config)
err := worker.Start()
if err != nil {
t.Fatalf("Start() failed: %v", err)
}
// Wait for immediate cleanup to start
time.Sleep(10 * time.Millisecond)
// Verify cleanup is running
mu.Lock()
if !cleanupRunning {
mu.Unlock()
t.Skip("Cleanup not running when expected, test timing issue")
return
}
mu.Unlock()
// Stop should wait for ongoing cleanup
start := time.Now()
err = worker.Stop()
duration := time.Since(start)
if err != nil {
t.Errorf("Stop() failed: %v", err)
}
// Cleanup should have finished
mu.Lock()
stillRunning := cleanupRunning
mu.Unlock()
if stillRunning {
t.Error("Cleanup still running after Stop()")
}
// Stop should have waited at least some time for cleanup
if duration < 5*time.Millisecond {
t.Logf("Stop() returned very quickly (%v), but cleanup finished cleanly", duration)
}
if duration > 10*time.Second {
t.Error("Stop() took too long, might be hanging")
}
}
func TestSessionWorker_ContextCancellation(t *testing.T) {
cleanupCalled := false
repo := &mockSessionRepoForWorker{
deleteExpiredFn: func(ctx context.Context, olderThan time.Duration) (int64, error) {
cleanupCalled = true
// Check if context is cancelled during cleanup
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
return 0, nil
}
},
}
config := SessionWorkerConfig{
CleanupInterval: 1 * time.Hour, // Long interval
CleanupAge: 1 * time.Hour,
}
worker := NewSessionWorker(repo, config)
err := worker.Start()
if err != nil {
t.Fatalf("Start() failed: %v", err)
}
// Wait for immediate cleanup
time.Sleep(50 * time.Millisecond)
if !cleanupCalled {
t.Error("Cleanup should have been called")
}
err = worker.Stop()
if err != nil {
t.Errorf("Stop() failed: %v", err)
}
}

View File

@@ -13,6 +13,7 @@ import (
type Config struct {
App AppConfig
Database DatabaseConfig
Auth AuthConfig
OAuth OAuthConfig
Server ServerConfig
Logger LoggerConfig
@@ -20,6 +21,11 @@ type Config struct {
Checksum ChecksumConfig
}
type AuthConfig struct {
OAuthEnabled bool
MagicLinkEnabled bool
}
type AppConfig struct {
BaseURL string
Organisation string
@@ -88,39 +94,54 @@ func Load() (*Config, error) {
config.Database.DSN = mustGetEnv("ACKIFY_DB_DSN")
config.OAuth.ClientID = mustGetEnv("ACKIFY_OAUTH_CLIENT_ID")
config.OAuth.ClientSecret = mustGetEnv("ACKIFY_OAUTH_CLIENT_SECRET")
config.OAuth.AllowedDomain = os.Getenv("ACKIFY_OAUTH_ALLOWED_DOMAIN")
config.OAuth.AutoLogin = strings.ToLower(getEnv("ACKIFY_OAUTH_AUTO_LOGIN", "false")) == "true"
// OAuth configuration - now OPTIONAL
config.OAuth.ClientID = getEnv("ACKIFY_OAUTH_CLIENT_ID", "")
config.OAuth.ClientSecret = getEnv("ACKIFY_OAUTH_CLIENT_SECRET", "")
config.OAuth.AllowedDomain = getEnv("ACKIFY_OAUTH_ALLOWED_DOMAIN", "")
config.OAuth.AutoLogin = getEnvBool("ACKIFY_OAUTH_AUTO_LOGIN", false)
provider := strings.ToLower(getEnv("ACKIFY_OAUTH_PROVIDER", ""))
switch provider {
case "google":
config.OAuth.AuthURL = "https://accounts.google.com/o/oauth2/auth"
config.OAuth.TokenURL = "https://oauth2.googleapis.com/token"
config.OAuth.UserInfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
config.OAuth.LogoutURL = "https://accounts.google.com/Logout"
config.OAuth.Scopes = []string{"openid", "email", "profile"}
case "github":
config.OAuth.AuthURL = "https://github.com/login/oauth/authorize"
config.OAuth.TokenURL = "https://github.com/login/oauth/access_token"
config.OAuth.UserInfoURL = "https://api.github.com/user"
config.OAuth.LogoutURL = "https://github.com/logout"
config.OAuth.Scopes = []string{"user:email", "read:user"}
case "gitlab":
gitlabURL := getEnv("ACKIFY_OAUTH_GITLAB_URL", "https://gitlab.com")
config.OAuth.AuthURL = fmt.Sprintf("%s/oauth/authorize", gitlabURL)
config.OAuth.TokenURL = fmt.Sprintf("%s/oauth/token", gitlabURL)
config.OAuth.UserInfoURL = fmt.Sprintf("%s/api/v4/user", gitlabURL)
config.OAuth.LogoutURL = fmt.Sprintf("%s/users/sign_out", gitlabURL)
config.OAuth.Scopes = []string{"read_user", "profile"}
default:
config.OAuth.AuthURL = mustGetEnv("ACKIFY_OAUTH_AUTH_URL")
config.OAuth.TokenURL = mustGetEnv("ACKIFY_OAUTH_TOKEN_URL")
config.OAuth.UserInfoURL = mustGetEnv("ACKIFY_OAUTH_USERINFO_URL")
config.OAuth.LogoutURL = getEnv("ACKIFY_OAUTH_LOGOUT_URL", "")
scopesStr := getEnv("ACKIFY_OAUTH_SCOPES", "openid,email,profile")
config.OAuth.Scopes = strings.Split(scopesStr, ",")
// Auto-detect OAuth enabled: true if ClientID and ClientSecret are provided
oauthConfigured := config.OAuth.ClientID != "" && config.OAuth.ClientSecret != ""
// Allow manual override via environment variable
if oauthEnabledStr := getEnv("ACKIFY_AUTH_OAUTH_ENABLED", ""); oauthEnabledStr != "" {
config.Auth.OAuthEnabled = getEnvBool("ACKIFY_AUTH_OAUTH_ENABLED", false)
} else {
config.Auth.OAuthEnabled = oauthConfigured
}
// Only configure OAuth URLs if OAuth is enabled
if config.Auth.OAuthEnabled {
provider := strings.ToLower(getEnv("ACKIFY_OAUTH_PROVIDER", ""))
switch provider {
case "google":
config.OAuth.AuthURL = "https://accounts.google.com/o/oauth2/auth"
config.OAuth.TokenURL = "https://oauth2.googleapis.com/token"
config.OAuth.UserInfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
config.OAuth.LogoutURL = "https://accounts.google.com/Logout"
config.OAuth.Scopes = []string{"openid", "email", "profile"}
case "github":
config.OAuth.AuthURL = "https://github.com/login/oauth/authorize"
config.OAuth.TokenURL = "https://github.com/login/oauth/access_token"
config.OAuth.UserInfoURL = "https://api.github.com/user"
config.OAuth.LogoutURL = "https://github.com/logout"
config.OAuth.Scopes = []string{"user:email", "read:user"}
case "gitlab":
gitlabURL := getEnv("ACKIFY_OAUTH_GITLAB_URL", "https://gitlab.com")
config.OAuth.AuthURL = fmt.Sprintf("%s/oauth/authorize", gitlabURL)
config.OAuth.TokenURL = fmt.Sprintf("%s/oauth/token", gitlabURL)
config.OAuth.UserInfoURL = fmt.Sprintf("%s/api/v4/user", gitlabURL)
config.OAuth.LogoutURL = fmt.Sprintf("%s/users/sign_out", gitlabURL)
config.OAuth.Scopes = []string{"read_user", "profile"}
default:
// Custom OAuth provider - require URLs
config.OAuth.AuthURL = mustGetEnv("ACKIFY_OAUTH_AUTH_URL")
config.OAuth.TokenURL = mustGetEnv("ACKIFY_OAUTH_TOKEN_URL")
config.OAuth.UserInfoURL = mustGetEnv("ACKIFY_OAUTH_USERINFO_URL")
config.OAuth.LogoutURL = getEnv("ACKIFY_OAUTH_LOGOUT_URL", "")
scopesStr := getEnv("ACKIFY_OAUTH_SCOPES", "openid,email,profile")
config.OAuth.Scopes = strings.Split(scopesStr, ",")
}
}
cookieSecret, err := parseCookieSecret()
@@ -180,6 +201,21 @@ func Load() (*Config, error) {
}
}
// Auto-detect MagicLink enabled: true if SMTP is configured
magicLinkConfigured := mailHost != ""
// Allow manual override via environment variable
if magicLinkEnabledStr := getEnv("ACKIFY_AUTH_MAGICLINK_ENABLED", ""); magicLinkEnabledStr != "" {
config.Auth.MagicLinkEnabled = getEnvBool("ACKIFY_AUTH_MAGICLINK_ENABLED", false)
} else {
config.Auth.MagicLinkEnabled = magicLinkConfigured
}
// Validation: At least one authentication method must be enabled
if !config.Auth.OAuthEnabled && !config.Auth.MagicLinkEnabled {
return nil, fmt.Errorf("at least one authentication method must be enabled: set ACKIFY_OAUTH_CLIENT_ID/CLIENT_SECRET for OAuth or ACKIFY_MAIL_HOST for MagicLink")
}
return config, nil
}

View File

@@ -556,8 +556,6 @@ func TestLoad_MissingRequiredEnvironmentVariables(t *testing.T) {
"ACKIFY_BASE_URL",
"ACKIFY_ORGANISATION",
"ACKIFY_DB_DSN",
"ACKIFY_OAUTH_CLIENT_ID",
"ACKIFY_OAUTH_CLIENT_SECRET",
}
for _, missingVar := range requiredVars {
@@ -1203,3 +1201,232 @@ func TestGetEnvBool(t *testing.T) {
})
}
}
func TestConfig_AuthValidation(t *testing.T) {
// Save original env vars
origOAuthClientID := os.Getenv("ACKIFY_OAUTH_CLIENT_ID")
origOAuthClientSecret := os.Getenv("ACKIFY_OAUTH_CLIENT_SECRET")
origMailHost := os.Getenv("ACKIFY_MAIL_HOST")
origAuthOAuthEnabled := os.Getenv("ACKIFY_AUTH_OAUTH_ENABLED")
origAuthMagicLinkEnabled := os.Getenv("ACKIFY_AUTH_MAGICLINK_ENABLED")
origBaseURL := os.Getenv("ACKIFY_BASE_URL")
origOrg := os.Getenv("ACKIFY_ORGANISATION")
origDBDSN := os.Getenv("ACKIFY_DB_DSN")
origCookieSecret := os.Getenv("ACKIFY_OAUTH_COOKIE_SECRET")
// Cleanup function
defer func() {
os.Setenv("ACKIFY_OAUTH_CLIENT_ID", origOAuthClientID)
os.Setenv("ACKIFY_OAUTH_CLIENT_SECRET", origOAuthClientSecret)
os.Setenv("ACKIFY_MAIL_HOST", origMailHost)
os.Setenv("ACKIFY_AUTH_OAUTH_ENABLED", origAuthOAuthEnabled)
os.Setenv("ACKIFY_AUTH_MAGICLINK_ENABLED", origAuthMagicLinkEnabled)
os.Setenv("ACKIFY_BASE_URL", origBaseURL)
os.Setenv("ACKIFY_ORGANISATION", origOrg)
os.Setenv("ACKIFY_DB_DSN", origDBDSN)
os.Setenv("ACKIFY_OAUTH_COOKIE_SECRET", origCookieSecret)
}()
tests := []struct {
name string
envVars map[string]string
expectError bool
errorContains string
checkAuth func(*testing.T, *Config)
}{
{
name: "OAuth only (auto-detected)",
envVars: map[string]string{
"ACKIFY_BASE_URL": "http://localhost:8080",
"ACKIFY_ORGANISATION": "Test Org",
"ACKIFY_DB_DSN": "postgres://localhost/test",
"ACKIFY_OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString([]byte("test-secret-32-bytes-long!!!!!!")),
"ACKIFY_OAUTH_CLIENT_ID": "test-client-id",
"ACKIFY_OAUTH_CLIENT_SECRET": "test-secret",
"ACKIFY_OAUTH_PROVIDER": "google",
},
expectError: false,
checkAuth: func(t *testing.T, cfg *Config) {
if !cfg.Auth.OAuthEnabled {
t.Error("OAuth should be enabled")
}
if cfg.Auth.MagicLinkEnabled {
t.Error("MagicLink should be disabled")
}
},
},
{
name: "MagicLink only (auto-detected)",
envVars: map[string]string{
"ACKIFY_BASE_URL": "http://localhost:8080",
"ACKIFY_ORGANISATION": "Test Org",
"ACKIFY_DB_DSN": "postgres://localhost/test",
"ACKIFY_OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString([]byte("test-secret-32-bytes-long!!!!!!")),
"ACKIFY_MAIL_HOST": "smtp.example.com",
},
expectError: false,
checkAuth: func(t *testing.T, cfg *Config) {
if cfg.Auth.OAuthEnabled {
t.Error("OAuth should be disabled")
}
if !cfg.Auth.MagicLinkEnabled {
t.Error("MagicLink should be enabled")
}
},
},
{
name: "Both OAuth and MagicLink enabled",
envVars: map[string]string{
"ACKIFY_BASE_URL": "http://localhost:8080",
"ACKIFY_ORGANISATION": "Test Org",
"ACKIFY_DB_DSN": "postgres://localhost/test",
"ACKIFY_OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString([]byte("test-secret-32-bytes-long!!!!!!")),
"ACKIFY_OAUTH_CLIENT_ID": "test-client-id",
"ACKIFY_OAUTH_CLIENT_SECRET": "test-secret",
"ACKIFY_OAUTH_PROVIDER": "google",
"ACKIFY_MAIL_HOST": "smtp.example.com",
},
expectError: false,
checkAuth: func(t *testing.T, cfg *Config) {
if !cfg.Auth.OAuthEnabled {
t.Error("OAuth should be enabled")
}
if !cfg.Auth.MagicLinkEnabled {
t.Error("MagicLink should be enabled")
}
},
},
{
name: "No authentication method (should fail)",
envVars: map[string]string{
"ACKIFY_BASE_URL": "http://localhost:8080",
"ACKIFY_ORGANISATION": "Test Org",
"ACKIFY_DB_DSN": "postgres://localhost/test",
"ACKIFY_OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString([]byte("test-secret-32-bytes-long!!!!!!")),
},
expectError: true,
errorContains: "at least one authentication method must be enabled",
},
{
name: "Manual override - OAuth enabled despite missing client ID",
envVars: map[string]string{
"ACKIFY_BASE_URL": "http://localhost:8080",
"ACKIFY_ORGANISATION": "Test Org",
"ACKIFY_DB_DSN": "postgres://localhost/test",
"ACKIFY_OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString([]byte("test-secret-32-bytes-long!!!!!!")),
"ACKIFY_AUTH_OAUTH_ENABLED": "true",
"ACKIFY_OAUTH_PROVIDER": "google",
},
expectError: false,
checkAuth: func(t *testing.T, cfg *Config) {
if !cfg.Auth.OAuthEnabled {
t.Error("OAuth should be force-enabled via ACKIFY_AUTH_OAUTH_ENABLED")
}
},
},
{
name: "Manual override - disable OAuth even with credentials",
envVars: map[string]string{
"ACKIFY_BASE_URL": "http://localhost:8080",
"ACKIFY_ORGANISATION": "Test Org",
"ACKIFY_DB_DSN": "postgres://localhost/test",
"ACKIFY_OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString([]byte("test-secret-32-bytes-long!!!!!!")),
"ACKIFY_OAUTH_CLIENT_ID": "test-client-id",
"ACKIFY_OAUTH_CLIENT_SECRET": "test-secret",
"ACKIFY_MAIL_HOST": "smtp.example.com",
"ACKIFY_AUTH_OAUTH_ENABLED": "false",
},
expectError: false,
checkAuth: func(t *testing.T, cfg *Config) {
if cfg.Auth.OAuthEnabled {
t.Error("OAuth should be disabled via ACKIFY_AUTH_OAUTH_ENABLED=false")
}
if !cfg.Auth.MagicLinkEnabled {
t.Error("MagicLink should still be enabled")
}
},
},
{
name: "Manual override - disable MagicLink even with SMTP configured",
envVars: map[string]string{
"ACKIFY_BASE_URL": "http://localhost:8080",
"ACKIFY_ORGANISATION": "Test Org",
"ACKIFY_DB_DSN": "postgres://localhost/test",
"ACKIFY_OAUTH_COOKIE_SECRET": base64.StdEncoding.EncodeToString([]byte("test-secret-32-bytes-long!!!!!!")),
"ACKIFY_OAUTH_CLIENT_ID": "test-client-id",
"ACKIFY_OAUTH_CLIENT_SECRET": "test-secret",
"ACKIFY_OAUTH_PROVIDER": "google",
"ACKIFY_MAIL_HOST": "smtp.example.com",
"ACKIFY_AUTH_MAGICLINK_ENABLED": "false",
},
expectError: false,
checkAuth: func(t *testing.T, cfg *Config) {
if !cfg.Auth.OAuthEnabled {
t.Error("OAuth should still be enabled")
}
if cfg.Auth.MagicLinkEnabled {
t.Error("MagicLink should be disabled via ACKIFY_AUTH_MAGICLINK_ENABLED=false")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear all auth-related env vars
os.Unsetenv("ACKIFY_OAUTH_CLIENT_ID")
os.Unsetenv("ACKIFY_OAUTH_CLIENT_SECRET")
os.Unsetenv("ACKIFY_OAUTH_PROVIDER")
os.Unsetenv("ACKIFY_MAIL_HOST")
os.Unsetenv("ACKIFY_AUTH_OAUTH_ENABLED")
os.Unsetenv("ACKIFY_AUTH_MAGICLINK_ENABLED")
os.Unsetenv("ACKIFY_BASE_URL")
os.Unsetenv("ACKIFY_ORGANISATION")
os.Unsetenv("ACKIFY_DB_DSN")
os.Unsetenv("ACKIFY_OAUTH_COOKIE_SECRET")
// Set test env vars
for k, v := range tt.envVars {
os.Setenv(k, v)
}
// Try to load config
cfg, err := Load()
// Check error expectation
if tt.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
return
}
if tt.errorContains != "" && !contains(err.Error(), tt.errorContains) {
t.Errorf("Expected error containing '%s', got: %v", tt.errorContains, err)
}
return
}
// Should not error
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Run auth check if provided
if tt.checkAuth != nil {
tt.checkAuth(t, cfg)
}
})
}
}
// Helper function to check if string contains substring
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || containsMiddle(s, substr)))
}
func containsMiddle(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -0,0 +1,163 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package database
import (
"context"
"database/sql"
"time"
"github.com/btouchard/ackify-ce/backend/internal/application/services"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
)
type magicLinkRepo struct {
db *sql.DB
}
func NewMagicLinkRepository(db *sql.DB) services.MagicLinkRepository {
return &magicLinkRepo{db: db}
}
func (r *magicLinkRepo) CreateToken(ctx context.Context, token *models.MagicLinkToken) error {
query := `
INSERT INTO magic_link_tokens
(token, email, expires_at, redirect_to, created_by_ip, created_by_user_agent)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, created_at
`
return r.db.QueryRowContext(ctx, query,
token.Token,
token.Email,
token.ExpiresAt,
token.RedirectTo,
token.CreatedByIP,
token.CreatedByUserAgent,
).Scan(&token.ID, &token.CreatedAt)
}
func (r *magicLinkRepo) GetByToken(ctx context.Context, token string) (*models.MagicLinkToken, error) {
query := `
SELECT id, token, email, created_at, expires_at, used_at, used_by_ip,
used_by_user_agent, redirect_to, created_by_ip, created_by_user_agent
FROM magic_link_tokens
WHERE token = $1
`
var t models.MagicLinkToken
var usedAt sql.NullTime
var usedByIP, usedByUserAgent sql.NullString
err := r.db.QueryRowContext(ctx, query, token).Scan(
&t.ID,
&t.Token,
&t.Email,
&t.CreatedAt,
&t.ExpiresAt,
&usedAt,
&usedByIP,
&usedByUserAgent,
&t.RedirectTo,
&t.CreatedByIP,
&t.CreatedByUserAgent,
)
if err == sql.ErrNoRows {
return nil, err
}
if err != nil {
return nil, err
}
if usedAt.Valid {
t.UsedAt = &usedAt.Time
}
if usedByIP.Valid {
t.UsedByIP = &usedByIP.String
}
if usedByUserAgent.Valid {
t.UsedByUserAgent = &usedByUserAgent.String
}
return &t, nil
}
func (r *magicLinkRepo) MarkAsUsed(ctx context.Context, token string, ip string, userAgent string) error {
query := `
UPDATE magic_link_tokens
SET used_at = now(),
used_by_ip = $2,
used_by_user_agent = $3
WHERE token = $1 AND used_at IS NULL
`
result, err := r.db.ExecContext(ctx, query, token, ip, userAgent)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return sql.ErrNoRows
}
return nil
}
func (r *magicLinkRepo) DeleteExpired(ctx context.Context) (int64, error) {
query := `
DELETE FROM magic_link_tokens
WHERE expires_at < now() OR (created_at < now() - INTERVAL '7 days' AND used_at IS NULL)
`
result, err := r.db.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
func (r *magicLinkRepo) LogAttempt(ctx context.Context, attempt *models.MagicLinkAuthAttempt) error {
query := `
INSERT INTO magic_link_auth_attempts
(email, success, failure_reason, ip_address, user_agent)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, attempted_at
`
return r.db.QueryRowContext(ctx, query,
attempt.Email,
attempt.Success,
attempt.FailureReason,
attempt.IPAddress,
attempt.UserAgent,
).Scan(&attempt.ID, &attempt.AttemptedAt)
}
func (r *magicLinkRepo) CountRecentAttempts(ctx context.Context, email string, since time.Time) (int, error) {
var count int
query := `
SELECT COUNT(*)
FROM magic_link_auth_attempts
WHERE email = $1 AND attempted_at > $2
`
err := r.db.QueryRowContext(ctx, query, email, since).Scan(&count)
return count, err
}
func (r *magicLinkRepo) CountRecentAttemptsByIP(ctx context.Context, ip string, since time.Time) (int, error) {
var count int
query := `
SELECT COUNT(*)
FROM magic_link_auth_attempts
WHERE ip_address = $1 AND attempted_at > $2
`
err := r.db.QueryRowContext(ctx, query, ip, since).Scan(&count)
return count, err
}

View File

@@ -0,0 +1,314 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//go:build integration
package database
import (
"context"
"database/sql"
"testing"
"time"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
)
func TestMagicLinkRepository_CreateToken(t *testing.T) {
testDB := SetupTestDB(t)
repo := NewMagicLinkRepository(testDB.DB)
ctx := context.Background()
token := &models.MagicLinkToken{
Token: "test-token-123",
Email: "test@example.com",
ExpiresAt: time.Now().Add(15 * time.Minute),
RedirectTo: "/dashboard",
CreatedByIP: "192.168.1.1",
CreatedByUserAgent: "Mozilla/5.0",
}
err := repo.CreateToken(ctx, token)
if err != nil {
t.Fatalf("Failed to create token: %v", err)
}
if token.ID == 0 {
t.Error("Expected token ID to be set")
}
if token.CreatedAt.IsZero() {
t.Error("Expected created_at to be set")
}
}
func TestMagicLinkRepository_GetByToken(t *testing.T) {
testDB := SetupTestDB(t)
repo := NewMagicLinkRepository(testDB.DB)
ctx := context.Background()
// Créer un token
original := &models.MagicLinkToken{
Token: "test-token-456",
Email: "user@example.com",
ExpiresAt: time.Now().Add(15 * time.Minute),
RedirectTo: "/",
CreatedByIP: "10.0.0.1",
CreatedByUserAgent: "Chrome",
}
err := repo.CreateToken(ctx, original)
if err != nil {
t.Fatalf("Failed to create token: %v", err)
}
// Récupérer le token
retrieved, err := repo.GetByToken(ctx, "test-token-456")
if err != nil {
t.Fatalf("Failed to get token: %v", err)
}
if retrieved.Email != original.Email {
t.Errorf("Expected email %s, got %s", original.Email, retrieved.Email)
}
if retrieved.UsedAt != nil {
t.Error("Expected token to not be used")
}
if !retrieved.IsValid() {
t.Error("Expected token to be valid")
}
}
func TestMagicLinkRepository_MarkAsUsed(t *testing.T) {
testDB := SetupTestDB(t)
repo := NewMagicLinkRepository(testDB.DB)
ctx := context.Background()
// Créer un token
token := &models.MagicLinkToken{
Token: "test-token-789",
Email: "mark@example.com",
ExpiresAt: time.Now().Add(15 * time.Minute),
RedirectTo: "/",
CreatedByIP: "10.0.0.2",
CreatedByUserAgent: "Firefox",
}
err := repo.CreateToken(ctx, token)
if err != nil {
t.Fatalf("Failed to create token: %v", err)
}
// Marquer comme utilisé
err = repo.MarkAsUsed(ctx, token.Token, "10.0.0.3", "Safari")
if err != nil {
t.Fatalf("Failed to mark token as used: %v", err)
}
// Vérifier que c'est bien marqué
retrieved, err := repo.GetByToken(ctx, token.Token)
if err != nil {
t.Fatalf("Failed to get token: %v", err)
}
if retrieved.UsedAt == nil {
t.Error("Expected token to be marked as used")
}
if retrieved.IsValid() {
t.Error("Expected token to be invalid after use")
}
// Tenter de marquer à nouveau (devrait échouer)
err = repo.MarkAsUsed(ctx, token.Token, "10.0.0.4", "Edge")
if err != sql.ErrNoRows {
t.Errorf("Expected ErrNoRows when marking already used token, got %v", err)
}
}
func TestMagicLinkRepository_DeleteExpired(t *testing.T) {
testDB := SetupTestDB(t)
repo := NewMagicLinkRepository(testDB.DB)
ctx := context.Background()
// Créer un token expiré
expiredToken := &models.MagicLinkToken{
Token: "expired-token",
Email: "expired@example.com",
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expiré
RedirectTo: "/",
CreatedByIP: "10.0.0.1",
CreatedByUserAgent: "Test",
}
err := repo.CreateToken(ctx, expiredToken)
if err != nil {
t.Fatalf("Failed to create expired token: %v", err)
}
// Créer un token valide
validToken := &models.MagicLinkToken{
Token: "valid-token",
Email: "valid@example.com",
ExpiresAt: time.Now().Add(15 * time.Minute),
RedirectTo: "/",
CreatedByIP: "10.0.0.2",
CreatedByUserAgent: "Test",
}
err = repo.CreateToken(ctx, validToken)
if err != nil {
t.Fatalf("Failed to create valid token: %v", err)
}
// Supprimer les tokens expirés
deleted, err := repo.DeleteExpired(ctx)
if err != nil {
t.Fatalf("Failed to delete expired tokens: %v", err)
}
if deleted == 0 {
t.Error("Expected at least one token to be deleted")
}
// Vérifier que le token expiré est supprimé
_, err = repo.GetByToken(ctx, "expired-token")
if err != sql.ErrNoRows {
t.Error("Expected expired token to be deleted")
}
// Vérifier que le token valide existe toujours
_, err = repo.GetByToken(ctx, "valid-token")
if err != nil {
t.Error("Expected valid token to still exist")
}
}
func TestMagicLinkRepository_RateLimit(t *testing.T) {
testDB := SetupTestDB(t)
repo := NewMagicLinkRepository(testDB.DB)
ctx := context.Background()
email := "ratelimit@example.com"
ip := "192.168.1.100"
// Créer 5 tentatives
for i := 0; i < 5; i++ {
attempt := &models.MagicLinkAuthAttempt{
Email: email,
Success: true,
IPAddress: ip,
UserAgent: "Test",
}
err := repo.LogAttempt(ctx, attempt)
if err != nil {
t.Fatalf("Failed to log attempt: %v", err)
}
}
// Compter les tentatives récentes (dernière heure)
since := time.Now().Add(-1 * time.Hour)
count, err := repo.CountRecentAttempts(ctx, email, since)
if err != nil {
t.Fatalf("Failed to count attempts: %v", err)
}
if count != 5 {
t.Errorf("Expected 5 attempts, got %d", count)
}
// Compter par IP
countIP, err := repo.CountRecentAttemptsByIP(ctx, ip, since)
if err != nil {
t.Fatalf("Failed to count attempts by IP: %v", err)
}
if countIP != 5 {
t.Errorf("Expected 5 attempts by IP, got %d", countIP)
}
// Compter les tentatives anciennes (devrait être 0)
oldSince := time.Now().Add(-2 * time.Hour)
oldCount, err := repo.CountRecentAttempts(ctx, email, oldSince)
if err != nil {
t.Fatalf("Failed to count old attempts: %v", err)
}
if oldCount != 5 {
t.Errorf("Expected 5 old attempts, got %d", oldCount)
}
}
func TestMagicLinkRepository_LogAttempt(t *testing.T) {
testDB := SetupTestDB(t)
repo := NewMagicLinkRepository(testDB.DB)
ctx := context.Background()
attempt := &models.MagicLinkAuthAttempt{
Email: "test@example.com",
Success: true,
FailureReason: "",
IPAddress: "192.168.1.1",
UserAgent: "Mozilla/5.0",
}
err := repo.LogAttempt(ctx, attempt)
if err != nil {
t.Fatalf("Failed to log attempt: %v", err)
}
if attempt.ID == 0 {
t.Error("Expected attempt ID to be set")
}
if attempt.AttemptedAt.IsZero() {
t.Error("Expected attempted_at to be set")
}
}
func TestMagicLinkRepository_TokenExpiration(t *testing.T) {
testDB := SetupTestDB(t)
repo := NewMagicLinkRepository(testDB.DB)
ctx := context.Background()
// Créer un token qui expire dans 1 seconde
token := &models.MagicLinkToken{
Token: "expiring-token",
Email: "expiring@example.com",
ExpiresAt: time.Now().Add(1 * time.Second),
RedirectTo: "/",
CreatedByIP: "10.0.0.1",
CreatedByUserAgent: "Test",
}
err := repo.CreateToken(ctx, token)
if err != nil {
t.Fatalf("Failed to create token: %v", err)
}
// Vérifier que le token est valide
retrieved, err := repo.GetByToken(ctx, "expiring-token")
if err != nil {
t.Fatalf("Failed to get token: %v", err)
}
if !retrieved.IsValid() {
t.Error("Expected token to be valid")
}
// Attendre 2 secondes
time.Sleep(2 * time.Second)
// Récupérer à nouveau le token
retrieved, err = repo.GetByToken(ctx, "expiring-token")
if err != nil {
t.Fatalf("Failed to get token: %v", err)
}
// Vérifier que le token est maintenant invalide
if retrieved.IsValid() {
t.Error("Expected token to be invalid after expiration")
}
}

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package workers
import (
"context"
"time"
"github.com/btouchard/ackify-ce/backend/internal/application/services"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
)
// MagicLinkCleanupWorker nettoie périodiquement les tokens expirés
type MagicLinkCleanupWorker struct {
service *services.MagicLinkService
interval time.Duration
stopChan chan struct{}
}
func NewMagicLinkCleanupWorker(service *services.MagicLinkService, interval time.Duration) *MagicLinkCleanupWorker {
if interval == 0 {
interval = 1 * time.Hour // Défaut: toutes les heures
}
return &MagicLinkCleanupWorker{
service: service,
interval: interval,
stopChan: make(chan struct{}),
}
}
func (w *MagicLinkCleanupWorker) Start(ctx context.Context) {
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
logger.Logger.Info("Magic Link cleanup worker started", "interval", w.interval)
for {
select {
case <-ticker.C:
w.cleanup(ctx)
case <-w.stopChan:
logger.Logger.Info("Magic Link cleanup worker stopped")
return
case <-ctx.Done():
logger.Logger.Info("Magic Link cleanup worker context cancelled")
return
}
}
}
func (w *MagicLinkCleanupWorker) Stop() {
close(w.stopChan)
}
func (w *MagicLinkCleanupWorker) cleanup(ctx context.Context) {
deleted, err := w.service.CleanupExpiredTokens(ctx)
if err != nil {
logger.Logger.Error("Failed to cleanup expired magic link tokens", "error", err)
return
}
if deleted > 0 {
logger.Logger.Info("Cleaned up expired magic link tokens", "count", deleted)
}
}

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"strings"
"github.com/btouchard/ackify-ce/backend/internal/application/services"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/auth"
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
"github.com/btouchard/ackify-ce/backend/internal/presentation/handlers"
@@ -16,17 +17,23 @@ import (
// Handler handles authentication API requests
type Handler struct {
authService *auth.OauthService
middleware *shared.Middleware
baseURL string
authService *auth.OauthService
magicLinkService *services.MagicLinkService
middleware *shared.Middleware
baseURL string
oauthEnabled bool
magicLinkEnabled bool
}
// NewHandler creates a new auth handler
func NewHandler(authService *auth.OauthService, middleware *shared.Middleware, baseURL string) *Handler {
func NewHandler(authService *auth.OauthService, magicLinkService *services.MagicLinkService, middleware *shared.Middleware, baseURL string, oauthEnabled bool, magicLinkEnabled bool) *Handler {
return &Handler{
authService: authService,
middleware: middleware,
baseURL: baseURL,
authService: authService,
magicLinkService: magicLinkService,
middleware: middleware,
baseURL: baseURL,
oauthEnabled: oauthEnabled,
magicLinkEnabled: magicLinkEnabled,
}
}
@@ -54,6 +61,15 @@ func (h *Handler) HandleGetCSRFToken(w http.ResponseWriter, r *http.Request) {
})
}
// HandleGetAuthConfig handles GET /api/v1/auth/config
// Returns available authentication methods
func (h *Handler) HandleGetAuthConfig(w http.ResponseWriter, r *http.Request) {
shared.WriteJSON(w, http.StatusOK, map[string]bool{
"oauth": h.oauthEnabled,
"magiclink": h.magicLinkEnabled,
})
}
// HandleStartOAuth handles POST /api/v1/auth/start
func (h *Handler) HandleStartOAuth(w http.ResponseWriter, r *http.Request) {
var req struct {

View File

@@ -98,7 +98,7 @@ func TestNewHandler(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := NewHandler(tt.authService, tt.middleware, tt.baseURL)
handler := NewHandler(tt.authService, nil, tt.middleware, tt.baseURL, true, true)
assert.NotNil(t, handler)
assert.NotNil(t, handler.authService)
@@ -116,7 +116,7 @@ func TestHandler_HandleAuthCheck_Authenticated(t *testing.T) {
t.Parallel()
authService := createTestAuthService()
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
handler := NewHandler(authService, nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
rec := httptest.NewRecorder()
@@ -190,7 +190,7 @@ func TestHandler_HandleAuthCheck_NotAuthenticated(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
req = tt.setupFunc(req)
@@ -219,7 +219,7 @@ func TestHandler_HandleAuthCheck_NotAuthenticated(t *testing.T) {
func TestHandler_HandleAuthCheck_ResponseFormat(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/check", nil)
rec := httptest.NewRecorder()
@@ -258,7 +258,7 @@ func TestHandler_HandleLogout_WithSSO(t *testing.T) {
t.Parallel()
authService := createTestAuthService()
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
handler := NewHandler(authService, nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
rec := httptest.NewRecorder()
@@ -315,7 +315,7 @@ func TestHandler_HandleLogout_WithoutSSO(t *testing.T) {
SecureCookies: false,
})
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
handler := NewHandler(authService, nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
rec := httptest.NewRecorder()
@@ -344,7 +344,7 @@ func TestHandler_HandleLogout_ClearsSession(t *testing.T) {
t.Parallel()
authService := createTestAuthService()
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
handler := NewHandler(authService, nil, createTestMiddleware(), testBaseURL, true, true)
// Set user in session
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
@@ -413,7 +413,7 @@ func TestHandler_HandleStartOAuth_WithRedirect(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
body, err := json.Marshal(tt.requestBody)
require.NoError(t, err)
@@ -455,7 +455,7 @@ func TestHandler_HandleStartOAuth_WithRedirect(t *testing.T) {
func TestHandler_HandleStartOAuth_NoBody(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", nil)
rec := httptest.NewRecorder()
@@ -483,7 +483,7 @@ func TestHandler_HandleStartOAuth_NoBody(t *testing.T) {
func TestHandler_HandleStartOAuth_InvalidJSON(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", bytes.NewReader([]byte("invalid-json")))
req.Header.Set("Content-Type", "application/json")
@@ -510,7 +510,7 @@ func TestHandler_HandleStartOAuth_InvalidJSON(t *testing.T) {
func TestHandler_HandleStartOAuth_ResponseFormat(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/start", nil)
rec := httptest.NewRecorder()
@@ -549,7 +549,7 @@ func TestHandler_HandleStartOAuth_ResponseFormat(t *testing.T) {
func TestHandler_HandleGetCSRFToken_Success(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodGet, "/api/v1/csrf", nil)
rec := httptest.NewRecorder()
@@ -595,7 +595,7 @@ func TestHandler_HandleGetCSRFToken_Success(t *testing.T) {
func TestHandler_HandleGetCSRFToken_ResponseFormat(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
req := httptest.NewRequest(http.MethodGet, "/api/v1/csrf", nil)
rec := httptest.NewRecorder()
@@ -635,7 +635,7 @@ func TestHandler_HandleAuthCheck_Concurrent(t *testing.T) {
t.Parallel()
authService := createTestAuthService()
handler := NewHandler(authService, createTestMiddleware(), testBaseURL)
handler := NewHandler(authService, nil, createTestMiddleware(), testBaseURL, true, true)
const numRequests = 100
done := make(chan bool, numRequests)
@@ -695,7 +695,7 @@ func TestHandler_HandleAuthCheck_Concurrent(t *testing.T) {
func TestHandler_HandleLogout_Concurrent(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
const numRequests = 100
done := make(chan bool, numRequests)
@@ -745,7 +745,7 @@ func TestHandler_HandleLogout_Concurrent(t *testing.T) {
func TestHandler_HandleStartOAuth_Concurrent(t *testing.T) {
t.Parallel()
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
const numRequests = 100
done := make(chan bool, numRequests)
@@ -802,7 +802,7 @@ func TestHandler_HandleStartOAuth_Concurrent(t *testing.T) {
// ============================================================================
func BenchmarkHandler_HandleAuthCheck(b *testing.B) {
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
b.ResetTimer()
@@ -815,7 +815,7 @@ func BenchmarkHandler_HandleAuthCheck(b *testing.B) {
}
func BenchmarkHandler_HandleAuthCheck_Parallel(b *testing.B) {
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
@@ -828,7 +828,7 @@ func BenchmarkHandler_HandleAuthCheck_Parallel(b *testing.B) {
}
func BenchmarkHandler_HandleLogout(b *testing.B) {
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
b.ResetTimer()
@@ -841,7 +841,7 @@ func BenchmarkHandler_HandleLogout(b *testing.B) {
}
func BenchmarkHandler_HandleLogout_Parallel(b *testing.B) {
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
@@ -854,7 +854,7 @@ func BenchmarkHandler_HandleLogout_Parallel(b *testing.B) {
}
func BenchmarkHandler_HandleStartOAuth(b *testing.B) {
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
b.ResetTimer()
@@ -867,7 +867,7 @@ func BenchmarkHandler_HandleStartOAuth(b *testing.B) {
}
func BenchmarkHandler_HandleStartOAuth_Parallel(b *testing.B) {
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
@@ -880,7 +880,7 @@ func BenchmarkHandler_HandleStartOAuth_Parallel(b *testing.B) {
}
func BenchmarkHandler_HandleGetCSRFToken(b *testing.B) {
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
b.ResetTimer()
@@ -893,7 +893,7 @@ func BenchmarkHandler_HandleGetCSRFToken(b *testing.B) {
}
func BenchmarkHandler_HandleGetCSRFToken_Parallel(b *testing.B) {
handler := NewHandler(createTestAuthService(), createTestMiddleware(), testBaseURL)
handler := NewHandler(createTestAuthService(), nil, createTestMiddleware(), testBaseURL, true, true)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package auth
import (
"encoding/json"
"net/http"
"github.com/btouchard/ackify-ce/backend/internal/domain/models"
"github.com/btouchard/ackify-ce/backend/internal/presentation/api/shared"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
)
// HandleRequestMagicLink handles POST /api/v1/auth/magic-link/request
func (h *Handler) HandleRequestMagicLink(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
RedirectTo string `json:"redirectTo"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
shared.WriteError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", nil)
return
}
if req.Email == "" {
shared.WriteError(w, http.StatusBadRequest, "missing_email", "Email is required", nil)
return
}
if req.RedirectTo == "" {
req.RedirectTo = "/"
}
// Extraire IP et User-Agent
ip := shared.GetClientIP(r)
userAgent := r.UserAgent()
// Demander le Magic Link
err := h.magicLinkService.RequestMagicLink(r.Context(), req.Email, req.RedirectTo, ip, userAgent)
// IMPORTANT: Ne jamais révéler si l'email existe ou non (protection contre énumération)
// Toujours retourner succès, même en cas d'erreur de rate limiting
if err != nil {
logger.Logger.Error("Magic Link request failed", "email", req.Email, "error", err)
// On log l'erreur mais on retourne succès au client
}
shared.WriteJSON(w, http.StatusOK, map[string]string{
"message": "If the email is valid, a magic link has been sent",
})
}
// HandleVerifyMagicLink handles GET /api/v1/auth/magic-link/verify
func (h *Handler) HandleVerifyMagicLink(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if token == "" {
http.Error(w, "Missing token", http.StatusBadRequest)
return
}
// Extraire IP et User-Agent
ip := shared.GetClientIP(r)
userAgent := r.UserAgent()
// Vérifier le token
magicToken, err := h.magicLinkService.VerifyMagicLink(r.Context(), token, ip, userAgent)
if err != nil {
logger.Logger.Warn("Magic Link verification failed", "error", err, "ip", ip)
http.Error(w, "Invalid or expired token", http.StatusBadRequest)
return
}
// Créer une session utilisateur
user := &models.User{
Sub: magicToken.Email, // Utiliser email comme sub
Email: magicToken.Email,
Name: magicToken.Email, // Par défaut, nom = email
}
// Sauvegarder dans la session (réutiliser la logique OAuth existante)
if err := h.authService.SetUser(w, r, user); err != nil {
logger.Logger.Error("Failed to create session after Magic Link", "error", err)
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
logger.Logger.Info("User authenticated via Magic Link",
"email", magicToken.Email,
"redirect_to", magicToken.RedirectTo)
// Rediriger vers la destination demandée
http.Redirect(w, r, magicToken.RedirectTo, http.StatusFound)
}

View File

@@ -26,6 +26,7 @@ import (
// RouterConfig holds configuration for the API router
type RouterConfig struct {
AuthService *auth.OauthService
MagicLinkService *services.MagicLinkService
SignatureService *services.SignatureService
DocumentService *services.DocumentService
DocumentRepository *database.DocumentRepository
@@ -37,6 +38,8 @@ type RouterConfig struct {
BaseURL string
AdminEmails []string
AutoLogin bool
OAuthEnabled bool
MagicLinkEnabled bool
}
// NewRouter creates and configures the API v1 router
@@ -63,7 +66,7 @@ func NewRouter(cfg RouterConfig) *chi.Mux {
// Initialize handlers
healthHandler := health.NewHandler()
authHandler := apiAuth.NewHandler(cfg.AuthService, apiMiddleware, cfg.BaseURL)
authHandler := apiAuth.NewHandler(cfg.AuthService, cfg.MagicLinkService, apiMiddleware, cfg.BaseURL, cfg.OAuthEnabled, cfg.MagicLinkEnabled)
usersHandler := users.NewHandler(cfg.AdminEmails)
documentsHandler := documents.NewHandlerWithPublisher(cfg.SignatureService, cfg.DocumentService, cfg.WebhookPublisher)
signaturesHandler := signatures.NewHandlerWithDeps(cfg.SignatureService, cfg.ExpectedSignerRepository, cfg.WebhookPublisher)
@@ -78,15 +81,30 @@ func NewRouter(cfg RouterConfig) *chi.Mux {
// Auth endpoints
r.Route("/auth", func(r chi.Router) {
r.Use(authRateLimit.Middleware)
// Public endpoint to expose available authentication methods
r.Get("/config", authHandler.HandleGetAuthConfig)
r.Post("/start", authHandler.HandleStartOAuth)
r.Get("/callback", authHandler.HandleOAuthCallback)
r.Get("/logout", authHandler.HandleLogout)
// Apply rate limiting to auth endpoints (except /config which should be fast)
r.Group(func(r chi.Router) {
r.Use(authRateLimit.Middleware)
if cfg.AutoLogin {
r.Get("/check", authHandler.HandleAuthCheck)
}
// OAuth endpoints (conditional)
if cfg.OAuthEnabled {
r.Post("/start", authHandler.HandleStartOAuth)
r.Get("/callback", authHandler.HandleOAuthCallback)
r.Get("/logout", authHandler.HandleLogout)
if cfg.AutoLogin {
r.Get("/check", authHandler.HandleAuthCheck)
}
}
// Magic Link endpoints (conditional)
if cfg.MagicLinkEnabled {
r.Post("/magic-link/request", authHandler.HandleRequestMagicLink)
r.Get("/magic-link/verify", authHandler.HandleVerifyMagicLink)
}
})
})
// Public document endpoints

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package shared
import (
"net/http"
"strings"
)
// GetClientIP extracts the real client IP address from the request
// It checks X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
func GetClientIP(r *http.Request) string {
// Check X-Forwarded-For header (proxy/load balancer)
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
// X-Forwarded-For can contain multiple IPs, take the first one
ips := strings.Split(forwardedFor, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// Check X-Real-IP header
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return strings.TrimSpace(realIP)
}
// Fall back to RemoteAddr
ip := r.RemoteAddr
// Remove port if present
if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx]
}
return ip
}

View File

@@ -13,5 +13,17 @@
"email.reminder.explanation": "Ihre kryptographische Bestätigung liefert einen überprüfbaren Nachweis, dass Sie dieses Dokument gelesen und zur Kenntnis genommen haben.",
"email.reminder.contact": "Bei Fragen wenden Sie sich bitte an Ihren Administrator.",
"email.reminder.regards": "Mit freundlichen Grüßen,",
"email.reminder.team": "Das {{.Organisation}}-Team"
"email.reminder.team": "Das {{.Organisation}}-Team",
"email.magic_link.subject": "Ihr Anmeldelink",
"email.magic_link.title": "🔐 Ihr Anmeldelink",
"email.magic_link.greeting": "Hallo,",
"email.magic_link.intro": "Sie haben die Anmeldung bei {{.Organisation}} mit der E-Mail-Adresse {{.Email}} angefordert.",
"email.magic_link.instructions": "Klicken Sie auf die Schaltfläche unten, um sich sofort anzumelden:",
"email.magic_link.cta_button": "🚀 Jetzt anmelden",
"email.magic_link.warning_title": "Achtung:",
"email.magic_link.warning_text": "Dieser Link läuft in {{.ExpiresIn}} Minuten ab und kann nur einmal verwendet werden.",
"email.magic_link.not_requested": "Wenn Sie diesen Link nicht angefordert haben, können Sie diese E-Mail sicher ignorieren.",
"email.magic_link.button_not_working": "Wenn die Schaltfläche nicht funktioniert, kopieren Sie diesen Link in Ihren Browser:",
"email.magic_link.footer": "Diese E-Mail wurde von {{.Organisation}} gesendet {{.BaseURL}}"
}

View File

@@ -13,5 +13,17 @@
"email.reminder.explanation": "Your cryptographic confirmation will provide verifiable proof that you have read and acknowledged this document.",
"email.reminder.contact": "If you have any questions, please contact your administrator.",
"email.reminder.regards": "Best regards,",
"email.reminder.team": "The {{.Organisation}} team"
"email.reminder.team": "The {{.Organisation}} team",
"email.magic_link.subject": "Your login link",
"email.magic_link.title": "🔐 Your login link",
"email.magic_link.greeting": "Hello,",
"email.magic_link.intro": "You have requested to log in to {{.Organisation}} with the email address {{.Email}}.",
"email.magic_link.instructions": "Click the button below to log in instantly:",
"email.magic_link.cta_button": "🚀 Log in now",
"email.magic_link.warning_title": "Attention:",
"email.magic_link.warning_text": "This link expires in {{.ExpiresIn}} minutes and can only be used once.",
"email.magic_link.not_requested": "If you did not request this link, you can safely ignore this email.",
"email.magic_link.button_not_working": "If the button doesn't work, copy and paste this link into your browser:",
"email.magic_link.footer": "This email was sent by {{.Organisation}} {{.BaseURL}}"
}

View File

@@ -13,5 +13,17 @@
"email.reminder.explanation": "Su confirmación criptográfica proporcionará una prueba verificable de que ha leído y reconocido este documento.",
"email.reminder.contact": "Si tiene alguna pregunta, póngase en contacto con su administrador.",
"email.reminder.regards": "Saludos cordiales,",
"email.reminder.team": "El equipo de {{.Organisation}}"
"email.reminder.team": "El equipo de {{.Organisation}}",
"email.magic_link.subject": "Su enlace de inicio de sesión",
"email.magic_link.title": "🔐 Su enlace de inicio de sesión",
"email.magic_link.greeting": "Hola,",
"email.magic_link.intro": "Ha solicitado iniciar sesión en {{.Organisation}} con la dirección de correo electrónico {{.Email}}.",
"email.magic_link.instructions": "Haga clic en el botón de abajo para iniciar sesión instantáneamente:",
"email.magic_link.cta_button": "🚀 Iniciar sesión ahora",
"email.magic_link.warning_title": "Atención:",
"email.magic_link.warning_text": "Este enlace caduca en {{.ExpiresIn}} minutos y solo se puede usar una vez.",
"email.magic_link.not_requested": "Si no solicitó este enlace, puede ignorar este correo electrónico de forma segura.",
"email.magic_link.button_not_working": "Si el botón no funciona, copie y pegue este enlace en su navegador:",
"email.magic_link.footer": "Este correo electrónico fue enviado por {{.Organisation}} {{.BaseURL}}"
}

View File

@@ -13,5 +13,17 @@
"email.reminder.explanation": "Votre confirmation cryptographique fournira une preuve vérifiable que vous avez lu et pris connaissance de ce document.",
"email.reminder.contact": "Si vous avez des questions, veuillez contacter votre administrateur.",
"email.reminder.regards": "Cordialement,",
"email.reminder.team": "L'équipe {{.Organisation}}"
"email.reminder.team": "L'équipe {{.Organisation}}",
"email.magic_link.subject": "Votre lien de connexion",
"email.magic_link.title": "🔐 Votre lien de connexion",
"email.magic_link.greeting": "Bonjour,",
"email.magic_link.intro": "Vous avez demandé à vous connecter à {{.Organisation}} avec l'adresse email {{.Email}}.",
"email.magic_link.instructions": "Cliquez sur le bouton ci-dessous pour vous connecter instantanément :",
"email.magic_link.cta_button": "🚀 Se connecter maintenant",
"email.magic_link.warning_title": "Attention :",
"email.magic_link.warning_text": "Ce lien expire dans {{.ExpiresIn}} minutes et ne peut être utilisé qu'une seule fois.",
"email.magic_link.not_requested": "Si vous n'avez pas demandé ce lien, vous pouvez ignorer cet email en toute sécurité.",
"email.magic_link.button_not_working": "Si le bouton ne fonctionne pas, copiez et collez ce lien dans votre navigateur :",
"email.magic_link.footer": "Cet email a été envoyé par {{.Organisation}} {{.BaseURL}}"
}

View File

@@ -13,5 +13,17 @@
"email.reminder.explanation": "La tua conferma crittografica fornirà una prova verificabile che hai letto e preso atto di questo documento.",
"email.reminder.contact": "Se hai domande, contatta il tuo amministratore.",
"email.reminder.regards": "Cordiali saluti,",
"email.reminder.team": "Il team {{.Organisation}}"
"email.reminder.team": "Il team {{.Organisation}}",
"email.magic_link.subject": "Il tuo link di accesso",
"email.magic_link.title": "🔐 Il tuo link di accesso",
"email.magic_link.greeting": "Ciao,",
"email.magic_link.intro": "Hai richiesto l'accesso a {{.Organisation}} con l'indirizzo email {{.Email}}.",
"email.magic_link.instructions": "Fai clic sul pulsante qui sotto per accedere immediatamente:",
"email.magic_link.cta_button": "🚀 Accedi ora",
"email.magic_link.warning_title": "Attenzione:",
"email.magic_link.warning_text": "Questo link scade tra {{.ExpiresIn}} minuti e può essere utilizzato solo una volta.",
"email.magic_link.not_requested": "Se non hai richiesto questo link, puoi ignorare questa email in tutta sicurezza.",
"email.magic_link.button_not_working": "Se il pulsante non funziona, copia e incolla questo link nel tuo browser:",
"email.magic_link.footer": "Questa email è stata inviata da {{.Organisation}} {{.BaseURL}}"
}

View File

@@ -0,0 +1,11 @@
-- Rollback migration 0012: Magic Link Authentication
DROP INDEX IF EXISTS idx_magic_link_attempts_email_time;
DROP INDEX IF EXISTS idx_magic_link_attempts_ip_time;
DROP TABLE IF EXISTS magic_link_auth_attempts;
DROP INDEX IF EXISTS idx_magic_link_tokens_cleanup;
DROP INDEX IF EXISTS idx_magic_link_tokens_expires;
DROP INDEX IF EXISTS idx_magic_link_tokens_email;
DROP INDEX IF EXISTS idx_magic_link_tokens_token;
DROP TABLE IF EXISTS magic_link_tokens;

View File

@@ -0,0 +1,47 @@
-- Migration 0012: Magic Link Authentication
-- Adds tables for passwordless email authentication
-- Table pour stocker les tokens Magic Link
CREATE TABLE magic_link_tokens (
id BIGSERIAL PRIMARY KEY,
token TEXT NOT NULL UNIQUE,
email TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
expires_at TIMESTAMPTZ NOT NULL,
used_at TIMESTAMPTZ,
used_by_ip INET,
used_by_user_agent TEXT,
redirect_to TEXT NOT NULL DEFAULT '/',
created_by_ip INET NOT NULL,
created_by_user_agent TEXT
);
-- Index pour requêtes fréquentes
CREATE INDEX idx_magic_link_tokens_token ON magic_link_tokens(token) WHERE used_at IS NULL;
CREATE INDEX idx_magic_link_tokens_email ON magic_link_tokens(email);
CREATE INDEX idx_magic_link_tokens_expires ON magic_link_tokens(expires_at) WHERE used_at IS NULL;
-- Index pour cleanup des tokens expirés
CREATE INDEX idx_magic_link_tokens_cleanup ON magic_link_tokens(created_at) WHERE used_at IS NULL;
COMMENT ON TABLE magic_link_tokens IS 'Tokens de connexion par Magic Link (usage unique, expiration 15min)';
COMMENT ON COLUMN magic_link_tokens.token IS 'Token cryptographiquement sécurisé (base64url, 32 bytes)';
COMMENT ON COLUMN magic_link_tokens.used_at IS 'Timestamp d''utilisation (NULL = non utilisé)';
COMMENT ON COLUMN magic_link_tokens.redirect_to IS 'URL de destination après authentification (ex: /?doc=xxx)';
-- Table pour logs des tentatives d'authentification Magic Link
CREATE TABLE magic_link_auth_attempts (
id BIGSERIAL PRIMARY KEY,
email TEXT NOT NULL,
success BOOLEAN NOT NULL,
failure_reason TEXT,
ip_address INET NOT NULL,
user_agent TEXT,
attempted_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- Index pour rate limiting
CREATE INDEX idx_magic_link_attempts_ip_time ON magic_link_auth_attempts(ip_address, attempted_at);
CREATE INDEX idx_magic_link_attempts_email_time ON magic_link_auth_attempts(email, attempted_at);
COMMENT ON TABLE magic_link_auth_attempts IS 'Logs des tentatives d''authentification Magic Link (rate limiting + audit)';

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"os"
"path/filepath"
"time"
"github.com/go-chi/chi/v5"
@@ -19,23 +20,26 @@ import (
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/email"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
whworker "github.com/btouchard/ackify-ce/backend/internal/infrastructure/webhook"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/workers"
"github.com/btouchard/ackify-ce/backend/internal/presentation/api"
"github.com/btouchard/ackify-ce/backend/internal/presentation/handlers"
"github.com/btouchard/ackify-ce/backend/pkg/crypto"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
)
type Server struct {
httpServer *http.Server
db *sql.DB
router *chi.Mux
emailSender email.Sender
emailWorker *email.Worker
webhookWorker *whworker.Worker
sessionWorker *auth.SessionWorker
baseURL string
adminEmails []string
authService *auth.OauthService
autoLogin bool
httpServer *http.Server
db *sql.DB
router *chi.Mux
emailSender email.Sender
emailWorker *email.Worker
webhookWorker *whworker.Worker
sessionWorker *auth.SessionWorker
magicLinkWorker *workers.MagicLinkCleanupWorker
baseURL string
adminEmails []string
authService *auth.OauthService
autoLogin bool
}
func NewServer(ctx context.Context, cfg *config.Config, frontend embed.FS, version string) (*Server, error) {
@@ -52,6 +56,7 @@ func NewServer(ctx context.Context, cfg *config.Config, frontend embed.FS, versi
emailQueueRepo := database.NewEmailQueueRepository(db)
webhookRepo := database.NewWebhookRepository(db)
webhookDeliveryRepo := database.NewWebhookDeliveryRepository(db)
magicLinkRepo := database.NewMagicLinkRepository(db)
// Initialize webhook publisher and worker
webhookPublisher := services.NewWebhookPublisher(webhookRepo, webhookDeliveryRepo)
@@ -60,6 +65,7 @@ func NewServer(ctx context.Context, cfg *config.Config, frontend embed.FS, versi
oauthSessionRepo := database.NewOAuthSessionRepository(db)
// Initialize OAuth auth service with session repository
// Note: SessionService is ALWAYS created, OAuthProvider is OPTIONAL (based on credentials)
authService := auth.NewOAuthService(auth.Config{
BaseURL: cfg.App.BaseURL,
ClientID: cfg.OAuth.ClientID,
@@ -75,6 +81,13 @@ func NewServer(ctx context.Context, cfg *config.Config, frontend embed.FS, versi
SessionRepo: oauthSessionRepo,
})
// Log authentication method status
if cfg.Auth.OAuthEnabled {
logger.Logger.Info("OAuth authentication enabled")
} else {
logger.Logger.Info("OAuth authentication disabled")
}
// Initialize services
signatureService := services.NewSignatureService(signatureRepo, documentRepo, signer)
signatureService.SetChecksumConfig(&cfg.Checksum)
@@ -122,6 +135,27 @@ func NewServer(ctx context.Context, cfg *config.Config, frontend embed.FS, versi
}
}
// Initialize Magic Link service (only if enabled)
var magicLinkService *services.MagicLinkService
if cfg.Auth.MagicLinkEnabled && emailSender != nil {
magicLinkService = services.NewMagicLinkService(services.MagicLinkServiceConfig{
Repository: magicLinkRepo,
EmailSender: emailSender,
BaseURL: cfg.App.BaseURL,
AppName: cfg.App.Organisation,
})
logger.Logger.Info("Magic Link authentication enabled")
} else if !cfg.Auth.MagicLinkEnabled {
logger.Logger.Info("Magic Link authentication disabled")
}
// Initialize Magic Link cleanup worker
var magicLinkWorker *workers.MagicLinkCleanupWorker
if magicLinkService != nil {
magicLinkWorker = workers.NewMagicLinkCleanupWorker(magicLinkService, 1*time.Hour)
go magicLinkWorker.Start(ctx)
}
router := chi.NewRouter()
router.Use(i18n.Middleware(i18nService))
@@ -135,6 +169,7 @@ func NewServer(ctx context.Context, cfg *config.Config, frontend embed.FS, versi
apiConfig := api.RouterConfig{
AuthService: authService,
MagicLinkService: magicLinkService,
SignatureService: signatureService,
DocumentService: documentService,
DocumentRepository: documentRepo,
@@ -146,13 +181,15 @@ func NewServer(ctx context.Context, cfg *config.Config, frontend embed.FS, versi
BaseURL: cfg.App.BaseURL,
AdminEmails: cfg.App.AdminEmails,
AutoLogin: cfg.OAuth.AutoLogin,
OAuthEnabled: cfg.Auth.OAuthEnabled,
MagicLinkEnabled: cfg.Auth.MagicLinkEnabled,
}
apiRouter := api.NewRouter(apiConfig)
router.Mount("/api/v1", apiRouter)
router.Get("/oembed", handlers.HandleOEmbed(cfg.App.BaseURL))
router.NotFound(EmbedFolder(frontend, "web/dist", cfg.App.BaseURL, version, signatureRepo))
router.NotFound(EmbedFolder(frontend, "web/dist", cfg.App.BaseURL, version, cfg.Auth.OAuthEnabled, cfg.Auth.MagicLinkEnabled, signatureRepo))
httpServer := &http.Server{
Addr: cfg.Server.ListenAddr,
@@ -160,17 +197,18 @@ func NewServer(ctx context.Context, cfg *config.Config, frontend embed.FS, versi
}
return &Server{
httpServer: httpServer,
db: db,
router: router,
emailSender: emailSender,
emailWorker: emailWorker,
webhookWorker: webhookWorker,
sessionWorker: sessionWorker,
baseURL: cfg.App.BaseURL,
adminEmails: cfg.App.AdminEmails,
authService: authService,
autoLogin: cfg.OAuth.AutoLogin,
httpServer: httpServer,
db: db,
router: router,
emailSender: emailSender,
emailWorker: emailWorker,
webhookWorker: webhookWorker,
sessionWorker: sessionWorker,
magicLinkWorker: magicLinkWorker,
baseURL: cfg.App.BaseURL,
adminEmails: cfg.App.AdminEmails,
authService: authService,
autoLogin: cfg.OAuth.AutoLogin,
}, nil
}
@@ -179,6 +217,11 @@ func (s *Server) Start() error {
}
func (s *Server) Shutdown(ctx context.Context) error {
// Stop Magic Link cleanup worker if it exists
if s.magicLinkWorker != nil {
s.magicLinkWorker.Stop()
}
// Stop OAuth session worker first if it exists
if s.sessionWorker != nil {
if err := s.sessionWorker.Stop(); err != nil {

View File

@@ -21,8 +21,9 @@ import (
// with SPA fallback support (serves index.html for non-existent routes)
// For index.html, it replaces __ACKIFY_BASE_URL__ placeholder with the actual base URL,
// __ACKIFY_VERSION__ with the application version,
// __ACKIFY_OAUTH_ENABLED__ and __ACKIFY_MAGICLINK_ENABLED__ with auth method flags,
// and __META_TAGS__ with dynamic meta tags based on query parameters
func EmbedFolder(fsEmbed embed.FS, targetPath string, baseURL string, version string, signatureRepo *database.SignatureRepository) http.HandlerFunc {
func EmbedFolder(fsEmbed embed.FS, targetPath string, baseURL string, version string, oauthEnabled bool, magicLinkEnabled bool, signatureRepo *database.SignatureRepository) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fsys, err := fs.Sub(fsEmbed, targetPath)
if err != nil {
@@ -61,7 +62,7 @@ func EmbedFolder(fsEmbed embed.FS, targetPath string, baseURL string, version st
defer file.Close()
if shouldServeIndex || strings.HasSuffix(cleanPath, "index.html") {
serveIndexTemplate(w, r, file, baseURL, version, signatureRepo)
serveIndexTemplate(w, r, file, baseURL, version, oauthEnabled, magicLinkEnabled, signatureRepo)
return
}
@@ -70,7 +71,7 @@ func EmbedFolder(fsEmbed embed.FS, targetPath string, baseURL string, version st
}
}
func serveIndexTemplate(w http.ResponseWriter, r *http.Request, file fs.File, baseURL string, version string, signatureRepo *database.SignatureRepository) {
func serveIndexTemplate(w http.ResponseWriter, r *http.Request, file fs.File, baseURL string, version string, oauthEnabled bool, magicLinkEnabled bool, signatureRepo *database.SignatureRepository) {
content, err := io.ReadAll(file)
if err != nil {
logger.Logger.Error("Failed to read index.html", "error", err.Error())
@@ -81,6 +82,19 @@ func serveIndexTemplate(w http.ResponseWriter, r *http.Request, file fs.File, ba
processedContent := strings.ReplaceAll(string(content), "__ACKIFY_BASE_URL__", baseURL)
processedContent = strings.ReplaceAll(processedContent, "__ACKIFY_VERSION__", version)
// Convert boolean to string for JavaScript
oauthEnabledStr := "false"
if oauthEnabled {
oauthEnabledStr = "true"
}
magicLinkEnabledStr := "false"
if magicLinkEnabled {
magicLinkEnabledStr = "true"
}
processedContent = strings.ReplaceAll(processedContent, "__ACKIFY_OAUTH_ENABLED__", oauthEnabledStr)
processedContent = strings.ReplaceAll(processedContent, "__ACKIFY_MAGICLINK_ENABLED__", magicLinkEnabledStr)
metaTags := generateMetaTags(r, baseURL, signatureRepo)
processedContent = strings.ReplaceAll(processedContent, "__META_TAGS__", metaTags)

View File

@@ -0,0 +1,42 @@
{{define "content"}}
<h2>{{T "email.magic_link.title"}}</h2>
<p>{{T "email.magic_link.greeting"}}</p>
<p>{{T "email.magic_link.intro" (dict "Organisation" .Organisation "Email" .Data.Email)}}</p>
<p>{{T "email.magic_link.instructions"}}</p>
<div style="text-align: center; margin: 30px 0;">
<a href="{{.Data.MagicLink}}"
style="background-color: #4F46E5;
color: white;
padding: 14px 40px;
text-decoration: none;
border-radius: 6px;
display: inline-block;
font-weight: bold;">
{{T "email.magic_link.cta_button"}}
</a>
</div>
<div style="background: #FEF3C7; padding: 15px; border-left: 4px solid #F59E0B; border-radius: 4px; margin: 20px 0;">
<p style="margin: 0;">
⏱️ <strong>{{T "email.magic_link.warning_title"}}</strong>
{{T "email.magic_link.warning_text" (dict "ExpiresIn" .Data.ExpiresIn)}}
</p>
</div>
<p>{{T "email.magic_link.not_requested"}}</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
<p style="color: #666; font-size: 0.9em;">
{{T "email.magic_link.button_not_working"}}<br>
<a href="{{.Data.MagicLink}}" style="color: #4F46E5; word-break: break-all;">{{.Data.MagicLink}}</a>
</p>
<p style="color: #999; font-size: 0.8em;">
{{T "email.magic_link.footer" (dict "Organisation" .Organisation "BaseURL" .Data.BaseURL)}}
</p>
{{end}}

View File

@@ -0,0 +1,18 @@
{{define "content"}}
{{T "email.magic_link.title"}}
{{T "email.magic_link.greeting"}}
{{T "email.magic_link.intro" (dict "Organisation" .Organisation "Email" .Data.Email)}}
{{T "email.magic_link.instructions"}}
{{.Data.MagicLink}}
{{T "email.magic_link.warning_title"}} {{T "email.magic_link.warning_text" (dict "ExpiresIn" .Data.ExpiresIn)}}
{{T "email.magic_link.not_requested"}}
---
{{T "email.magic_link.footer" (dict "Organisation" .Organisation "BaseURL" .Data.BaseURL)}}
{{end}}

78
compose.e2e.yml Normal file
View File

@@ -0,0 +1,78 @@
name: ackify-ce
services:
ackify-migrate:
image: btouchard/ackify-ce
container_name: ackify-migrate
environment:
ACKIFY_DB_DSN: "postgres://postgres:testpassword@ackify-db:5432/ackify_test?sslmode=disable"
depends_on:
ackify-db:
condition: service_healthy
command: ["/app/migrate", "up"]
entrypoint: []
restart: "no"
ackify-ce:
build:
context: .
args:
VERSION: "v0.0.0-dev"
image: btouchard/ackify-ce
container_name: ackify-ce
restart: unless-stopped
environment:
ACKIFY_LOG_LEVEL: "debug"
ACKIFY_LOG_FORMAT: "classic"
ACKIFY_BASE_URL: "http://localhost:8080"
ACKIFY_ORGANISATION: "Superkloud"
ACKIFY_DB_DSN: "postgres://postgres:testpassword@ackify-db:5432/ackify_test?sslmode=disable"
ACKIFY_OAUTH_PROVIDER: "none"
ACKIFY_OAUTH_COOKIE_SECRET: "Fk7eiBaG0sOhGn5+nWB6Ipl7TqbGfeLSdhwrEKjOPPM="
ACKIFY_AUTH_OAUTH_ENABLED: "true"
ACKIFY_AUTH_MAGICLINK_ENABLED: "true"
ACKIFY_ED25519_PRIVATE_KEY: "1aK7vTp7hiqM9Z3Xbj7QDfPlYyWy740l9Fu+Fdom5ck="
ACKIFY_LISTEN_ADDR: ":8080"
ACKIFY_ADMIN_EMAILS: "benjamin@kolapsis.com"
ACKIFY_MAIL_HOST: "mailhog"
ACKIFY_MAIL_PORT: "1025"
ACKIFY_MAIL_TLS: "false"
ACKIFY_MAIL_STARTTLS: "false"
ACKIFY_MAIL_FROM: "noreply@ackify.local"
ACKIFY_MAIL_FROM_NAME: "Ackify"
depends_on:
ackify-migrate:
condition: service_completed_successfully
ackify-db:
condition: service_healthy
ports:
- "8080:8080"
ackify-db:
image: postgres:16-alpine
container_name: ackify-db-test
restart: unless-stopped
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: testpassword
POSTGRES_DB: ackify_test
volumes:
- ackify_test:/var/lib/postgresql/data
ports:
- "5432:5432"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres -d ackify_test"]
interval: 10s
timeout: 5s
retries: 5
mailhog:
image: mailhog/mailhog:latest
container_name: ackify-mailhog-test
restart: unless-stopped
ports:
- "1025:1025"
- "8025:8025"
volumes:
ackify_test:

22
webapp/cypress.config.ts Normal file
View File

@@ -0,0 +1,22 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
import { defineConfig } from 'cypress'
export default defineConfig({
e2e: {
baseUrl: 'http://localhost:8080',
specPattern: 'cypress/e2e/**/*.cy.{js,jsx,ts,tsx}',
supportFile: 'cypress/support/e2e.ts',
fixturesFolder: 'cypress/fixtures',
video: false,
screenshotOnRunFailure: true,
defaultCommandTimeout: 10000,
requestTimeout: 10000,
env: {
mailhogUrl: 'http://localhost:8025',
},
setupNodeEvents(on, config) {
// implement node event listeners here
return config
},
},
})

11
webapp/cypress/.gitignore vendored Normal file
View File

@@ -0,0 +1,11 @@
# Cypress artifacts
screenshots/
videos/
downloads/
# Cypress cache
.cypress-cache/
# Test results
results/
reports/

View File

@@ -0,0 +1,210 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
/// <reference types="cypress" />
describe('Magic Link Authentication', () => {
const testEmail = 'test@example.com'
const baseUrl = Cypress.config('baseUrl')
beforeEach(() => {
// Clear mailbox before each test
cy.clearMailbox()
})
it('should complete full magic link authentication workflow', () => {
// Visit auth choice page with English locale
cy.visitWithLocale('/auth')
// Wait for Vue app to be fully loaded
cy.get('#app', { timeout: 10000 }).should('not.be.empty')
// Should display auth choice page
cy.contains('Sign in to Ackify', { timeout: 10000 }).should('be.visible')
cy.contains('Sign in with Email', { timeout: 10000 }).should('be.visible')
// Fill magic link form
cy.get('input[type="email"]', { timeout: 10000 }).should('be.visible').type(testEmail)
cy.contains('Send Magic Link').click()
// Should show success message
cy.contains('Check your email', { timeout: 10000 }).should('be.visible')
cy.contains('We sent you a magic link').should('be.visible')
cy.contains('The link expires in 15 minutes').should('be.visible')
// Wait for email to arrive in Mailhog
cy.waitForEmail(testEmail, 'login link', 30000).then((message) => {
// Verify email content
expect(message.To).to.have.length.greaterThan(0)
expect(message.To[0].Mailbox + '@' + message.To[0].Domain).to.equal(testEmail)
// Extract magic link from email
cy.extractMagicLink(message).then((magicLink) => {
cy.log('Magic link found:', magicLink)
// Visit the magic link
cy.visit(magicLink)
// Should redirect to home page after successful authentication
cy.url({ timeout: 10000 }).should('eq', baseUrl + '/')
// Verify user is authenticated
cy.request('/api/v1/users/me').then((response) => {
expect(response.status).to.eq(200)
expect(response.body.data).to.have.property('email', testEmail)
expect(response.body.data).to.have.property('id', testEmail)
})
})
})
})
it('should redirect to document page after authentication', () => {
const docId = 'test-document-123'
const redirectUrl = `/?doc=${docId}`
// Visit auth page with redirect parameter
cy.visitWithLocale(`/auth?redirect=${encodeURIComponent(redirectUrl)}`)
// Fill magic link form
cy.get('input[type="email"]').type(testEmail)
cy.contains('Send Magic Link').click()
// Wait for success message
cy.contains('Check your email', { timeout: 10000 }).should('be.visible')
// Wait for email
cy.waitForEmail(testEmail, 'login link', 30000).then((message) => {
// Extract magic link
cy.extractMagicLink(message).then((magicLink) => {
// Verify redirect parameter is in the link
expect(magicLink).to.include(`redirect=${encodeURIComponent(redirectUrl)}`)
// Visit the magic link
cy.visit(magicLink)
// Should redirect to document page
cy.url({ timeout: 10000 }).should('include', `/?doc=${docId}`)
})
})
})
it('should reject invalid email addresses', () => {
cy.visitWithLocale('/auth')
// Try invalid email - remove HTML5 validation and type
cy.get('input[type="email"]').then(($input) => {
$input.removeAttr('type')
$input.removeAttr('required')
cy.wrap($input).type('invalid-email')
})
cy.contains('Send Magic Link').click()
// Should show error
cy.contains('Please enter a valid email address', { timeout: 5000 }).should('be.visible')
})
it('should reject expired tokens', () => {
// This test would require mocking time or waiting 15 minutes
// Skip for now, can be tested manually or with time manipulation
cy.log('Expired token test skipped - requires time manipulation')
})
it('should prevent token reuse', () => {
cy.visitWithLocale('/auth')
// Request magic link
cy.get('input[type="email"]').type(testEmail)
cy.contains('Send Magic Link').click()
// Wait for email
cy.waitForEmail(testEmail, 'login link', 30000).then((message) => {
cy.extractMagicLink(message).then((magicLink) => {
// Use the magic link once
cy.visit(magicLink)
cy.url({ timeout: 10000 }).should('eq', baseUrl + '/')
// Clear cookies to simulate new session
cy.clearCookies()
// Try to use the same link again - should fail with 400
cy.request({ url: magicLink, failOnStatusCode: false }).then((response) => {
// Should get 400 Bad Request
expect(response.status).to.eq(400)
// Should contain error message about token
expect(response.body).to.include('Invalid or expired token')
})
})
})
})
it('should enforce rate limiting', () => {
cy.visitWithLocale('/auth')
// Send multiple requests quickly
for (let i = 0; i < 4; i++) {
cy.get('input[type="email"]').clear().type(`test${i}@example.com`)
cy.contains('Send Magic Link').click()
if (i < 3) {
// First 3 should succeed
cy.contains('Check your email', { timeout: 5000 }).should('be.visible')
cy.wait(1000) // Small delay between requests
cy.visitWithLocale('/auth') // Reload page to send another request
} else {
// 4th request should fail due to rate limiting
// Note: This depends on backend rate limiting implementation
// You may need to adjust based on actual limits
cy.log('Rate limit test - verify backend behavior')
}
}
})
it('should handle mailhog unavailability gracefully', () => {
// This test verifies frontend behavior when email fails to send
// You might want to mock the API response for this
cy.log('Email service unavailability test - manual testing recommended')
})
})
describe('Magic Link Email Content', () => {
const testEmail = 'content-test@example.com'
beforeEach(() => {
cy.clearMailbox()
})
it('should send email with correct subject and content', () => {
cy.visitWithLocale('/auth')
// Request magic link
cy.get('input[type="email"]').type(testEmail)
cy.contains('Send Magic Link').click()
// Wait for email
cy.waitForEmail(testEmail, 'login link', 30000).then((message) => {
// Verify subject
const subject = message.Content?.Headers?.Subject?.[0] || ''
expect(subject).to.include('login link')
// Verify sender
expect(message.From.Mailbox + '@' + message.From.Domain).to.include('ackify')
// Verify recipient
expect(message.To[0].Mailbox + '@' + message.To[0].Domain).to.equal(testEmail)
// Verify body contains key elements
const body = message.Content?.Body || ''
expect(body).to.include('/api/v1/auth/magic-link/verify')
expect(body).to.include('token=')
// Extract and verify token format (should be base64url)
cy.extractMagicLink(message).then((link) => {
const url = new URL(link)
const token = url.searchParams.get('token')
expect(token).to.exist
expect(token).to.have.length.greaterThan(20)
// base64url should only contain [A-Za-z0-9_-]
expect(token).to.match(/^[A-Za-z0-9_-]+$/)
})
})
})
})

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
/// <reference types="cypress" />
// ***********************************************
// This example commands.ts shows you how to
// create various custom commands and overwrite
// existing commands.
//
// For more comprehensive examples of custom
// commands please read more here:
// https://on.cypress.io/custom-commands
// ***********************************************
declare global {
namespace Cypress {
interface Chainable {
/**
* Visit a page with locale set to English
* @param url - URL to visit
* @param options - Visit options
*/
visitWithLocale(url: string, locale?: string, options?: Partial<Cypress.VisitOptions>): Chainable<Cypress.AUTWindow>
}
}
}
Cypress.Commands.add('visitWithLocale', (url: string, locale: string = 'en', options?: Partial<Cypress.VisitOptions>) => {
return cy.visit(url, {
...options,
onBeforeLoad: (win) => {
win.localStorage.setItem('locale', locale)
if (options?.onBeforeLoad) {
options.onBeforeLoad(win)
}
}
})
})
export {}

View File

@@ -0,0 +1,22 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
// ***********************************************************
// This example support/e2e.ts is processed and
// loaded automatically before your test files.
//
// This is a great place to put global configuration and
// behavior that modifies Cypress.
//
// You can change the location of this file or turn off
// automatically serving support files with the
// 'supportFile' configuration option.
//
// You can read more here:
// https://on.cypress.io/configuration
// ***********************************************************
// Import commands.js using ES2015 syntax:
import './commands'
import './mailhog'
// Alternatively you can use CommonJS syntax:
// require('./commands')

View File

@@ -0,0 +1,202 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
/// <reference types="cypress" />
/**
* Mailhog API Helper
* Documentation: https://github.com/mailhog/MailHog/blob/master/docs/APIv2.md
*/
interface MailhogMessage {
ID: string
From: {
Mailbox: string
Domain: string
}
To: Array<{
Mailbox: string
Domain: string
}>
Content: {
Headers: Record<string, string[]>
Body: string
}
Raw: {
From: string
To: string[]
Data: string
}
MIME: {
Parts: Array<{
Headers: Record<string, string[]>
Body: string
}>
} | null
}
interface MailhogResponse {
total: number
count: number
start: number
items: MailhogMessage[]
}
declare global {
namespace Cypress {
interface Chainable {
/**
* Get the latest email from Mailhog for a specific recipient
* @param email - Recipient email address
* @param timeout - Maximum time to wait for email (ms)
*/
getLatestEmail(email: string, timeout?: number): Chainable<MailhogMessage>
/**
* Extract magic link from email body
* @param message - Mailhog message
*/
extractMagicLink(message: MailhogMessage): Chainable<string>
/**
* Clear all emails from Mailhog
*/
clearMailbox(): Chainable<void>
/**
* Wait for email to arrive in Mailhog
* @param email - Recipient email address
* @param subject - Email subject (optional)
* @param timeout - Maximum time to wait (ms)
*/
waitForEmail(email: string, subject?: string, timeout?: number): Chainable<MailhogMessage>
}
}
}
Cypress.Commands.add('clearMailbox', () => {
const mailhogUrl = Cypress.env('mailhogUrl') || 'http://localhost:8025'
cy.request('DELETE', `${mailhogUrl}/api/v1/messages`).then((response) => {
expect(response.status).to.eq(200)
})
})
Cypress.Commands.add('getLatestEmail', (email: string, timeout = 10000) => {
const mailhogUrl = Cypress.env('mailhogUrl') || 'http://localhost:8025'
const startTime = Date.now()
const checkForEmail = (): Cypress.Chainable<MailhogMessage> => {
return cy.request<MailhogResponse>(`${mailhogUrl}/api/v2/messages?limit=50`).then((response) => {
expect(response.status).to.eq(200)
const messages = response.body.items || []
const targetEmail = messages.find((msg) => {
const recipients = msg.To || []
return recipients.some((to) => `${to.Mailbox}@${to.Domain}` === email)
})
if (targetEmail) {
return cy.wrap(targetEmail)
}
// Retry if timeout not reached
if (Date.now() - startTime < timeout) {
cy.wait(500)
return checkForEmail()
}
throw new Error(`No email found for ${email} after ${timeout}ms`)
})
}
return checkForEmail()
})
Cypress.Commands.add('waitForEmail', (email: string, subject?: string, timeout = 10000) => {
const mailhogUrl = Cypress.env('mailhogUrl') || 'http://localhost:8025'
const startTime = Date.now()
const checkForEmail = (): Cypress.Chainable<MailhogMessage> => {
return cy.request<MailhogResponse>(`${mailhogUrl}/api/v2/messages?limit=50`).then((response) => {
expect(response.status).to.eq(200)
const messages = response.body.items || []
const targetEmail = messages.find((msg) => {
const recipients = msg.To || []
const matchesRecipient = recipients.some((to) => `${to.Mailbox}@${to.Domain}` === email)
if (!matchesRecipient) return false
if (subject) {
const emailSubject = msg.Content?.Headers?.Subject?.[0] || ''
return emailSubject.includes(subject)
}
return true
})
if (targetEmail) {
return cy.wrap(targetEmail)
}
// Retry if timeout not reached
if (Date.now() - startTime < timeout) {
cy.wait(500)
return checkForEmail()
}
throw new Error(`No email found for ${email}${subject ? ` with subject "${subject}"` : ''} after ${timeout}ms`)
})
}
return checkForEmail()
})
Cypress.Commands.add('extractMagicLink', (message: MailhogMessage) => {
let body = message.Content?.Body || ''
// Try to get HTML part from MIME if available
if (message.MIME?.Parts && message.MIME.Parts.length > 0) {
const htmlPart = message.MIME.Parts.find((part) => {
const contentType = part.Headers['Content-Type']?.[0] || ''
return contentType.includes('text/html')
})
if (htmlPart) {
body = htmlPart.Body
} else {
// Fallback to plain text
const textPart = message.MIME.Parts.find((part) => {
const contentType = part.Headers['Content-Type']?.[0] || ''
return contentType.includes('text/plain')
})
if (textPart) {
body = textPart.Body
}
}
}
// Decode quoted-printable encoding (=3D -> =, =\n -> remove)
body = body
.replace(/=\r?\n/g, '') // Remove soft line breaks
.replace(/=([0-9A-F]{2})/g, (_, hex) => String.fromCharCode(parseInt(hex, 16)))
// Decode HTML entities (&amp; -> &, etc.)
body = body
.replace(/&amp;/g, '&')
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'")
// Extract magic link URL
// Pattern: http(s)://domain/api/v1/auth/magic-link/verify?token=xxx&redirect=xxx
const linkRegex = /(https?:\/\/[^\s]+\/api\/v1\/auth\/magic-link\/verify\?[^\s"<]+)/g
const matches = body.match(linkRegex)
if (matches && matches.length > 0) {
return cy.wrap(matches[0])
}
throw new Error('No magic link found in email body')
})
export {}

View File

@@ -9,6 +9,8 @@
<script>
window.ACKIFY_BASE_URL = '__ACKIFY_BASE_URL__';
window.ACKIFY_VERSION = '__ACKIFY_VERSION__';
window.ACKIFY_OAUTH_ENABLED = '__ACKIFY_OAUTH_ENABLED__' === 'true';
window.ACKIFY_MAGICLINK_ENABLED = '__ACKIFY_MAGICLINK_ENABLED__' === 'true';
</script>
</head>
<body>

1900
webapp/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,12 @@
"dev": "vite",
"build": "vue-tsc -b && vite build",
"preview": "vite preview",
"lint:i18n": "node scripts/check-i18n.js"
"lint:i18n": "node scripts/check-i18n.js",
"cypress:open": "cypress open",
"cypress:run": "cypress run",
"cypress:headless": "cypress run --headless",
"test:e2e": "cypress run --headless",
"test:e2e:open": "cypress open"
},
"dependencies": {
"axios": "^1.12.2",
@@ -27,6 +32,7 @@
"autoprefixer": "^10.4.21",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"cypress": "^15.5.0",
"postcss": "^8.5.6",
"tailwind-merge": "^3.3.1",
"tailwindcss": "^4.1.14",

View File

@@ -1,42 +0,0 @@
<!-- SPDX-License-Identifier: AGPL-3.0-or-later -->
<script setup lang="ts">
import { ref } from 'vue'
defineProps<{ msg: string }>()
const count = ref(0)
</script>
<template>
<h1>{{ msg }}</h1>
<div class="card">
<button type="button" @click="count++">count is {{ count }}</button>
<p>
Edit
<code>components/HelloWorld.vue</code> to test HMR
</p>
</div>
<p>
Check out
<a href="https://vuejs.org/guide/quick-start.html#local" target="_blank"
>create-vue</a
>, the official Vue + Vite starter
</p>
<p>
Learn more about IDE Support for Vue in the
<a
href="https://vuejs.org/guide/scaling-up/tooling.html#ide-support"
target="_blank"
>Vue Docs Scaling up Guide</a
>.
</p>
<p class="read-the-docs">Click on the Vite and Vue logos to learn more</p>
</template>
<style scoped>
.read-the-docs {
color: #888;
}
</style>

View File

@@ -65,7 +65,7 @@
</svg>
<span class="font-semibold">Lecture confirmée</span>
</div>
<p v-if="signedAt" class="mt-2 text-sm text-gray-600 text-center">
<p v-if="signedAt" class="mt-2 text-sm text-muted-foreground text-center">
Le {{ formatDate(signedAt) }}
</p>
</div>

View File

@@ -1,7 +1,7 @@
<!-- SPDX-License-Identifier: AGPL-3.0-or-later -->
<script setup lang="ts">
import { ref, computed } from 'vue'
import { useRoute } from 'vue-router'
import { useRoute, useRouter } from 'vue-router'
import { useAuthStore } from '@/stores/auth'
import { Menu, X, ChevronDown, User, LogOut, Shield, FileSignature } from 'lucide-vue-next'
import Button from '@/components/ui/Button.vue'
@@ -14,6 +14,7 @@ const { t } = useI18n()
const authStore = useAuthStore()
const route = useRoute()
const router = useRouter()
const mobileMenuOpen = ref(false)
const userMenuOpen = ref(false)
@@ -35,7 +36,7 @@ const toggleUserMenu = () => {
}
const login = () => {
authStore.startOAuthLogin()
router.push({name: 'auth-choice'})
}
const logout = async () => {

View File

@@ -33,7 +33,7 @@ const getInitialLocale = (): string => {
return browserLocale
}
return 'fr'
return 'en'
}
export const i18n = createI18n({

View File

@@ -26,6 +26,31 @@
"auth": {
"user": {
"connectedAs": "Angemeldet als"
},
"choice": {
"title": "Bei Ackify anmelden",
"subtitle": "Wählen Sie Ihre bevorzugte Authentifizierungsmethode",
"privacy": "Ihre Authentifizierung ist sicher und verschlüsselt"
},
"oauth": {
"title": "Anmeldung mit OAuth",
"description": "Verwenden Sie Ihr bestehendes Konto",
"button": "Mit OAuth fortfahren",
"error": "OAuth-Anmeldung fehlgeschlagen"
},
"magiclink": {
"title": "Anmeldung per E-Mail",
"description": "Wir senden Ihnen einen magischen Link",
"email_label": "E-Mail-Adresse",
"email_placeholder": "sie{'@'}beispiel.de",
"button": "Magischen Link senden",
"sent": {
"title": "Überprüfen Sie Ihre E-Mail",
"message": "Wir haben Ihnen einen magischen Link geschickt. Klicken Sie darauf, um sich anzumelden.",
"expire": "Der Link läuft in 15 Minuten ab."
},
"error_invalid_email": "Bitte geben Sie eine gültige E-Mail-Adresse ein",
"error_send": "Fehler beim Senden des magischen Links"
}
},
"sign": {
@@ -225,9 +250,9 @@
"addExpectedSigner": "Erwarteten Unterzeichner hinzufügen",
"addSigners": "Erwartete Leser hinzufügen",
"emailsLabel": "E-Mails (eine pro Zeile)",
"emailsPlaceholder": "Maria Schmidt <maria.schmidt@example.com>\njohn.mueller@example.com\nSophie Weber <sophie@example.com>",
"emailsPlaceholder": "Maria Schmidt <maria.schmidt{'@'}example.com>\njohn.mueller{'@'}example.com\nSophie Weber <sophie{'@'}example.com>",
"emailLabel": "E-Mail *",
"emailPlaceholder": "email@example.com",
"emailPlaceholder": "email{'@'}example.com",
"nameLabel": "Name",
"namePlaceholder": "Vollständiger Name",
"reader": "Leser",

View File

@@ -26,6 +26,31 @@
"auth": {
"user": {
"connectedAs": "Connected as"
},
"choice": {
"title": "Sign in to Ackify",
"subtitle": "Choose your preferred authentication method",
"privacy": "Your authentication is secure and encrypted"
},
"oauth": {
"title": "Sign in with OAuth",
"description": "Use your entreprise account",
"button": "Continue with OAuth",
"error": "OAuth login failed"
},
"magiclink": {
"title": "Sign in with Email",
"description": "We'll send you a magic link",
"email_label": "Email address",
"email_placeholder": "you{'@'}example.com",
"button": "Send Magic Link",
"sent": {
"title": "Check your email",
"message": "We sent you a magic link. Click on it to sign in.",
"expire": "The link expires in 15 minutes."
},
"error_invalid_email": "Please enter a valid email address",
"error_send": "Failed to send magic link"
}
},
"sign": {
@@ -276,9 +301,9 @@
"addExpectedSigner": "Add expected signer",
"addSigners": "Add expected readers",
"emailsLabel": "Emails (one per line)",
"emailsPlaceholder": "Mary Smith <mary.smith@example.com>\njohn.doe@example.com\nSophie Johnson <sophie@example.com>",
"emailsPlaceholder": "Mary Smith <mary.smith{'@'}example.com>\njohn.doe{'@'}example.com\nSophie Johnson <sophie{'@'}example.com>",
"emailLabel": "Email *",
"emailPlaceholder": "email@example.com",
"emailPlaceholder": "email{'@'}example.com",
"nameLabel": "Name",
"namePlaceholder": "Full name",
"reader": "Reader",

View File

@@ -26,6 +26,31 @@
"auth": {
"user": {
"connectedAs": "Conectado como"
},
"choice": {
"title": "Iniciar sesión en Ackify",
"subtitle": "Elija su método de autenticación preferido",
"privacy": "Su autenticación es segura y cifrada"
},
"oauth": {
"title": "Iniciar sesión con OAuth",
"description": "Use su cuenta existente",
"button": "Continuar con OAuth",
"error": "Error de inicio de sesión OAuth"
},
"magiclink": {
"title": "Iniciar sesión por correo electrónico",
"description": "Le enviaremos un enlace mágico",
"email_label": "Dirección de correo electrónico",
"email_placeholder": "usted{'@'}ejemplo.com",
"button": "Enviar enlace mágico",
"sent": {
"title": "Revise su correo electrónico",
"message": "Le hemos enviado un enlace mágico. Haga clic en él para iniciar sesión.",
"expire": "El enlace caduca en 15 minutos."
},
"error_invalid_email": "Por favor ingrese una dirección de correo electrónico válida",
"error_send": "Error al enviar el enlace mágico"
}
},
"sign": {
@@ -225,9 +250,9 @@
"addExpectedSigner": "Agregar firmante esperado",
"addSigners": "Agregar lectores esperados",
"emailsLabel": "Correos electrónicos (uno por línea)",
"emailsPlaceholder": "María García <maria.garcia@example.com>\njuan.martinez@example.com\nSofía Rodríguez <sofia@example.com>",
"emailsPlaceholder": "María García <maria.garcia{'@'}example.com>\njuan.martinez{'@'}example.com\nSofía Rodríguez <sofia{'@'}example.com>",
"emailLabel": "Correo electrónico *",
"emailPlaceholder": "email@example.com",
"emailPlaceholder": "email{'@'}example.com",
"nameLabel": "Nombre",
"namePlaceholder": "Nombre completo",
"reader": "Lector",

View File

@@ -26,6 +26,31 @@
"auth": {
"user": {
"connectedAs": "Connecté en tant que"
},
"choice": {
"title": "Connexion à Ackify",
"subtitle": "Choisissez votre méthode d'authentification préférée",
"privacy": "Votre authentification est sécurisée et chiffrée"
},
"oauth": {
"title": "Connexion avec OAuth",
"description": "Utilisez votre compte entreprise",
"button": "Continuer avec OAuth",
"error": "Échec de la connexion OAuth"
},
"magiclink": {
"title": "Connexion par Email",
"description": "Nous vous enverrons un lien magique",
"email_label": "Adresse email",
"email_placeholder": "vous{'@'}exemple.com",
"button": "Envoyer le lien magique",
"sent": {
"title": "Consultez vos emails",
"message": "Nous vous avons envoyé un lien magique. Cliquez dessus pour vous connecter.",
"expire": "Le lien expire dans 15 minutes."
},
"error_invalid_email": "Veuillez entrer une adresse email valide",
"error_send": "Échec de l'envoi du lien magique"
}
},
"sign": {
@@ -276,9 +301,9 @@
"addExpectedSigner": "Ajouter un signataire attendu",
"addSigners": "Ajouter des lecteurs attendus",
"emailsLabel": "Emails (un par ligne)",
"emailsPlaceholder": "Marie Dupont <marie.dupont@example.com>\njean.martin@example.com\nSophie Bernard <sophie@example.com>",
"emailsPlaceholder": "Marie Dupont <marie.dupont{'@'}example.com>\njean.martin{'@'}example.com\nSophie Bernard <sophie{'@'}example.com>",
"emailLabel": "Email *",
"emailPlaceholder": "email@example.com",
"emailPlaceholder": "email{'@'}example.com",
"nameLabel": "Nom",
"namePlaceholder": "Nom complet",
"reader": "Lecteur",

View File

@@ -26,6 +26,31 @@
"auth": {
"user": {
"connectedAs": "Connesso come"
},
"choice": {
"title": "Accedi ad Ackify",
"subtitle": "Scegli il tuo metodo di autenticazione preferito",
"privacy": "La tua autenticazione è sicura e crittografata"
},
"oauth": {
"title": "Accedi con OAuth",
"description": "Usa il tuo account esistente",
"button": "Continua con OAuth",
"error": "Accesso OAuth fallito"
},
"magiclink": {
"title": "Accedi tramite email",
"description": "Ti invieremo un link magico",
"email_label": "Indirizzo email",
"email_placeholder": "tu{'@'}esempio.com",
"button": "Invia link magico",
"sent": {
"title": "Controlla la tua email",
"message": "Ti abbiamo inviato un link magico. Clicca su di esso per accedere.",
"expire": "Il link scade tra 15 minuti."
},
"error_invalid_email": "Inserisci un indirizzo email valido",
"error_send": "Errore durante l'invio del link magico"
}
},
"sign": {
@@ -225,9 +250,9 @@
"addExpectedSigner": "Aggiungi firmatario atteso",
"addSigners": "Aggiungi lettori attesi",
"emailsLabel": "Email (una per riga)",
"emailsPlaceholder": "Maria Rossi <maria.rossi@example.com>\ngiovanni.bianchi@example.com\nSofia Verdi <sofia@example.com>",
"emailsPlaceholder": "Maria Rossi <maria.rossi{'@'}example.com>\ngiovanni.bianchi{'@'}example.com\nSofia Verdi <sofia{'@'}example.com>",
"emailLabel": "Email *",
"emailPlaceholder": "email@example.com",
"emailPlaceholder": "email{'@'}example.com",
"nameLabel": "Nome",
"namePlaceholder": "Nome completo",
"reader": "Lettore",

View File

@@ -0,0 +1,224 @@
<!-- SPDX-License-Identifier: AGPL-3.0-or-later -->
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import { useAuthStore } from '@/stores/auth'
import { useI18n } from 'vue-i18n'
import { usePageTitle } from '@/composables/usePageTitle'
import { Mail, LogIn, Loader2, AlertCircle, CheckCircle2 } from 'lucide-vue-next'
import Card from '@/components/ui/Card.vue'
import CardHeader from '@/components/ui/CardHeader.vue'
import CardTitle from '@/components/ui/CardTitle.vue'
import CardDescription from '@/components/ui/CardDescription.vue'
import CardContent from '@/components/ui/CardContent.vue'
import Button from '@/components/ui/Button.vue'
import Alert from '@/components/ui/Alert.vue'
import AlertTitle from '@/components/ui/AlertTitle.vue'
import AlertDescription from '@/components/ui/AlertDescription.vue'
const { t } = useI18n()
usePageTitle('auth.choice.title')
const router = useRouter()
const route = useRoute()
const authStore = useAuthStore()
const email = ref('')
const loading = ref(false)
const magicLinkSent = ref(false)
const errorMessage = ref('')
// Lire les flags d'authentification depuis les variables globales injectées dans index.html
const oauthEnabled = (window as any).ACKIFY_OAUTH_ENABLED || false
const magicLinkEnabled = (window as any).ACKIFY_MAGICLINK_ENABLED || false
const redirectTo = computed(() => {
return (route.query.redirect as string) || '/'
})
function checkAuthMethods() {
// Si aucune méthode disponible
if (!oauthEnabled && !magicLinkEnabled) {
errorMessage.value = t('auth.error.no_method_available')
return
}
// Si une seule méthode disponible (OAuth), rediriger automatiquement
const methods = [oauthEnabled, magicLinkEnabled].filter(Boolean)
if (methods.length === 1 && oauthEnabled) {
loginWithOAuth()
}
// Si seulement MagicLink, l'utilisateur doit quand même entrer son email (pas de redirection auto)
}
onMounted(async () => {
// Si déjà connecté, rediriger
if (!authStore.initialized) {
await authStore.checkAuth()
}
if (authStore.isAuthenticated) {
await router.push(redirectTo.value)
return
}
// Vérifier les méthodes d'authentification disponibles
checkAuthMethods()
})
async function loginWithOAuth() {
loading.value = true
errorMessage.value = ''
localStorage.setItem('preferredAuthMethod', 'oauth')
try {
await authStore.startOAuthLogin(redirectTo.value)
} catch (error: any) {
errorMessage.value = error.message || t('auth.oauth.error')
} finally {
loading.value = false
}
}
async function requestMagicLink() {
if (!email.value || !isValidEmail(email.value)) {
errorMessage.value = t('auth.magiclink.error_invalid_email')
return
}
loading.value = true
errorMessage.value = ''
magicLinkSent.value = false
localStorage.setItem('preferredAuthMethod', 'magiclink')
try {
const response = await fetch('/api/v1/auth/magic-link/request', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
email: email.value,
redirectTo: redirectTo.value,
}),
})
if (!response.ok) {
throw new Error(t('auth.magiclink.error_send'))
}
magicLinkSent.value = true
} catch (error: any) {
errorMessage.value = error.message || t('auth.magiclink.error_send')
} finally {
loading.value = false
}
}
function isValidEmail(email: string): boolean {
const re = /^[^\s@]+@[^\s@]+\.[^\s@]+$/
return re.test(email)
}
</script>
<template>
<div class="min-h-full box-border flex items-center justify-center bg-background py-12 px-4 sm:px-6 lg:px-8">
<div class="max-w-md w-full space-y-8">
<div class="text-center">
<h1 class="text-3xl font-bold text-foreground">
{{ t('auth.choice.title') }}
</h1>
<p class="mt-2 text-sm text-muted-foreground">
{{ t('auth.choice.subtitle') }}
</p>
</div>
<Alert v-if="errorMessage" variant="destructive">
<AlertCircle class="h-4 w-4" />
<AlertTitle>{{ t('common.error') }}</AlertTitle>
<AlertDescription>{{ errorMessage }}</AlertDescription>
</Alert>
<Alert v-if="magicLinkSent" variant="default" class="border-green-200 bg-green-50">
<CheckCircle2 class="h-4 w-4 text-green-600" />
<AlertTitle class="text-green-800">{{ t('auth.magiclink.sent.title') }}</AlertTitle>
<AlertDescription class="text-green-700">
{{ t('auth.magiclink.sent.message') }}
<br>
<span class="text-xs text-green-600">
{{ t('auth.magiclink.sent.expire') }}
</span>
</AlertDescription>
</Alert>
<!-- OAuth Login -->
<Card v-if="oauthEnabled">
<CardHeader>
<CardTitle class="flex items-center gap-2">
<LogIn class="h-5 w-5" />
{{ t('auth.oauth.title') }}
</CardTitle>
<CardDescription>
{{ t('auth.oauth.description') }}
</CardDescription>
</CardHeader>
<CardContent>
<Button
@click="loginWithOAuth"
:disabled="loading"
class="w-full"
size="lg"
>
<Loader2 v-if="loading" class="h-4 w-4 animate-spin mr-2" />
{{ t('auth.oauth.button') }}
</Button>
</CardContent>
</Card>
<!-- Magic Link Login -->
<Card v-if="magicLinkEnabled">
<CardHeader>
<CardTitle class="flex items-center gap-2">
<Mail class="h-5 w-5" />
{{ t('auth.magiclink.title') }}
</CardTitle>
<CardDescription>
{{ t('auth.magiclink.description') }}
</CardDescription>
</CardHeader>
<CardContent>
<form @submit.prevent="requestMagicLink" class="space-y-4">
<div>
<label for="email" class="block text-sm font-medium text-foreground mb-1">
{{ t('auth.magiclink.email_label') }}
</label>
<input
id="email"
v-model="email"
type="email"
required
:disabled="loading"
:placeholder="t('auth.magiclink.email_placeholder')"
class="w-full px-3 py-2 border border-border rounded-md shadow-sm bg-input text-foreground placeholder:text-muted-foreground focus:outline-none focus:ring-2 focus:ring-ring focus:border-transparent disabled:opacity-50 disabled:cursor-not-allowed"
/>
</div>
<Button
type="submit"
:disabled="loading"
class="w-full"
size="lg"
variant="outline"
>
<Loader2 v-if="loading" class="h-4 w-4 animate-spin mr-2" />
<Mail v-else class="h-4 w-4 mr-2" />
{{ t('auth.magiclink.button') }}
</Button>
</form>
</CardContent>
</Card>
<p class="text-center text-xs text-muted-foreground">
{{ t('auth.choice.privacy') }}
</p>
</div>
</div>
</template>

View File

@@ -8,7 +8,7 @@ usePageTitle('notFound.title')
</script>
<template>
<div class="min-h-screen bg-background text-foreground flex items-center justify-center">
<div class="min-h-full box-border bg-background text-foreground flex items-center justify-center">
<div class="text-center">
<h1 class="text-6xl font-bold text-muted-foreground">404</h1>
<p class="text-xl text-foreground mt-4">{{ t('notFound.title') }}</p>
@@ -21,4 +21,4 @@ usePageTitle('notFound.title')
</router-link>
</div>
</div>
</template>
</template>

View File

@@ -2,8 +2,9 @@
import { createRouter, createWebHistory, type RouteRecordRaw } from 'vue-router'
import { useAuthStore } from '@/stores/auth'
const SignPage = () => import('@/pages/SignPage.vue')
const HomePage = () => import('@/pages/HomePage.vue')
const SignaturesPage = () => import('@/pages/SignaturesPage.vue')
const AuthChoicePage = () => import('@/pages/AuthChoicePage.vue')
const AdminDashboard = () => import('@/pages/admin/AdminDashboard.vue')
const AdminDocumentDetail = () => import('@/pages/admin/AdminDocumentDetail.vue')
const AdminWebhooks = () => import('@/pages/admin/AdminWebhooks.vue')
@@ -14,8 +15,14 @@ const NotFoundPage = () => import('@/pages/NotFoundPage.vue')
const routes: RouteRecordRaw[] = [
{
path: '/',
name: 'sign',
component: SignPage,
name: 'home',
component: HomePage,
meta: { requiresAuth: false }
},
{
path: '/auth',
name: 'auth-choice',
component: AuthChoicePage,
meta: { requiresAuth: false }
},
{
@@ -90,12 +97,12 @@ router.beforeEach(async (to, from, next) => {
if (!authStore.isAuthenticated) {
sessionStorage.setItem('redirectAfterLogin', to.fullPath)
await authStore.startOAuthLogin(to.fullPath)
return false
next({ name: 'auth-choice', query: { redirect: to.fullPath } })
return
}
if (to.meta.requiresAdmin && !authStore.isAdmin) {
next({ name: 'sign' })
next({ name: 'home' })
return
}
}