diff --git a/.dockerignore b/.dockerignore index d2f27e7..088df0d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -53,6 +53,7 @@ install/ client_secret*.json # Node.js (if any frontend assets) +**/node_modules/ node_modules/ npm-debug.log yarn-error.log \ No newline at end of file diff --git a/.env.example b/.env.example index 25acbf1..b82a094 100644 --- a/.env.example +++ b/.env.example @@ -5,10 +5,8 @@ ACKIFY_LOG_LEVEL=info ACKIFY_LOG_FORMAT=classic # Database Configuration -POSTGRES_USER=ackifyr POSTGRES_PASSWORD=your_secure_password -POSTGRES_DB=ackify -ACKIFY_DB_DSN=postgres://ackifyr:your_secure_password@localhost:5432/ackify?sslmode=disable +ACKIFY_APP_PASSWORD=ackify_app_password # ============================================================================ # Authentication Configuration diff --git a/CHANGELOG.md b/CHANGELOG.md index e5d176e..2dc8864 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,48 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.2.8] - 2025-12-15 + +### 🔐 Multi-Tenant Security & Row Level Security + +Version majeure de sécurité introduisant l'isolation des données par tenant avec PostgreSQL Row Level Security (RLS). + +### Added + +- **Row Level Security (RLS)** + - Isolation des données au niveau PostgreSQL avec politiques RLS + - Protection de 11 tables : documents, signatures, expected_signers, webhooks, reminder_logs, email_queue, checksum_verifications, webhook_deliveries, oauth_sessions, magic_link_tokens, magic_link_auth_attempts + - Fonction `current_tenant_id()` pour récupérer le tenant de la session + - `FORCE ROW LEVEL SECURITY` pour appliquer les politiques même aux propriétaires des tables + - Comportement sécurisé par défaut : aucune donnée accessible si tenant non défini + +- **Support Multi-Tenant** + - Nouvelle table `instance_metadata` stockant l'UUID unique du tenant + - Colonne `tenant_id` (UUID) ajoutée à toutes les tables métier et d'authentification + - Index optimisés sur `tenant_id` pour des performances optimales + - Triggers d'immutabilité empêchant la modification du `tenant_id` après création + - Backfill automatique des données existantes avec le tenant de l'instance + +- **Gestion du Rôle Applicatif** + - Création automatique du rôle `ackify_app` par l'outil de migration + - Séparation des privilèges (rôle applicatif vs rôle superuser) + - Variable d'environnement `ACKIFY_APP_PASSWORD` pour définir le mot de passe du rôle + - Privilèges par défaut configurés pour les futures tables + +### Technical Details + +**Nouvelles migrations :** +- `0015_add_tenant_support.{up,down}.sql` - Support multi-tenant +- `0016_add_rls_policies.{up,down}.sql` - Politiques RLS + +**Fichiers modifiés :** +- `backend/cmd/migrate/main.go` - Création du rôle `ackify_app` + +**Sécurité :** +- Les politiques RLS utilisent `USING` et `WITH CHECK` pour filtrer lectures et écritures +- Les tokens magic link acceptent `tenant_id IS NULL` pour les requêtes de login +- Les sessions OAuth sont isolées par tenant après authentification + ## [1.2.6] - 2025-12-08 ### 🏗️ Architecture & CI/CD @@ -569,6 +611,7 @@ For users upgrading from v1.1.x to v1.2.0: - NULL UserName handling in database operations - Proper string conversion for UserName field +[1.2.8]: https://github.com/btouchard/ackify-ce/compare/v1.2.6...v1.2.8 [1.2.6]: https://github.com/btouchard/ackify-ce/compare/v1.2.5...v1.2.6 [1.2.5]: https://github.com/btouchard/ackify-ce/compare/v1.2.4...v1.2.5 [1.2.4]: https://github.com/btouchard/ackify-ce/compare/v1.2.3...v1.2.4 diff --git a/backend/cmd/migrate/main.go b/backend/cmd/migrate/main.go index 979ebd8..d1293c3 100644 --- a/backend/cmd/migrate/main.go +++ b/backend/cmd/migrate/main.go @@ -1,13 +1,13 @@ package main import ( + "database/sql" "errors" "flag" "fmt" "log" "os" - - "database/sql" + "strings" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" @@ -52,6 +52,11 @@ func main() { switch command { case "up": + // Ensure ackify_app role exists before running migrations (for RLS support) + if err := ensureAppRole(db); err != nil { + log.Fatal("Failed to ensure ackify_app role:", err) + } + err = m.Up() if err != nil && !errors.Is(err, migrate.ErrNoChange) { log.Fatal("Migration up failed:", err) @@ -128,6 +133,9 @@ func printUsage() { fmt.Println(" -db-dsn string Database DSN (or DB_DSN env var)") fmt.Println(" -migrations-path string Path to migrations (default: file://migrations)") fmt.Println() + fmt.Println("Environment:") + fmt.Println(" ACKIFY_APP_PASSWORD Password for the ackify_app role (required for RLS)") + fmt.Println() fmt.Println("Examples:") fmt.Println(" migrate up") fmt.Println(" migrate down 2") @@ -135,3 +143,76 @@ func printUsage() { fmt.Println(" migrate force 1 # For existing DB with only signatures table") fmt.Println(" migrate version") } + +// ensureAppRole creates or updates the ackify_app role used for RLS. +// The password is read from ACKIFY_APP_PASSWORD environment variable. +// If not set, the function logs a warning and continues (for backward compatibility). +// If set, the role is created (or password updated) before migrations run. +func ensureAppRole(db *sql.DB) error { + password := strings.TrimSpace(os.Getenv("ACKIFY_APP_PASSWORD")) + if password == "" { + log.Println("WARNING: ACKIFY_APP_PASSWORD not set. ackify_app role will not be created.") + log.Println(" RLS migrations will fail if the role doesn't exist.") + log.Println(" Set ACKIFY_APP_PASSWORD to enable RLS support.") + return nil + } + + // Check if role exists + var exists bool + err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM pg_roles WHERE rolname = 'ackify_app')").Scan(&exists) + if err != nil { + return fmt.Errorf("failed to check if ackify_app role exists: %w", err) + } + + if exists { + // Update password to ensure it matches environment + _, err = db.Exec(fmt.Sprintf("ALTER ROLE ackify_app WITH PASSWORD '%s'", escapePassword(password))) + if err != nil { + return fmt.Errorf("failed to update ackify_app password: %w", err) + } + log.Println("ackify_app role exists, password updated") + } else { + // Create the role with all necessary attributes + createSQL := fmt.Sprintf(` + CREATE ROLE ackify_app WITH + LOGIN + PASSWORD '%s' + NOCREATEDB + NOCREATEROLE + NOINHERIT + NOREPLICATION + CONNECTION LIMIT -1 + `, escapePassword(password)) + + _, err = db.Exec(createSQL) + if err != nil { + return fmt.Errorf("failed to create ackify_app role: %w", err) + } + log.Println("ackify_app role created successfully") + } + + // Grant CONNECT on database (idempotent) + var dbName string + err = db.QueryRow("SELECT current_database()").Scan(&dbName) + if err != nil { + return fmt.Errorf("failed to get current database name: %w", err) + } + + _, err = db.Exec(fmt.Sprintf("GRANT CONNECT ON DATABASE %s TO ackify_app", dbName)) + if err != nil { + return fmt.Errorf("failed to grant CONNECT to ackify_app: %w", err) + } + + // Grant USAGE on public schema (idempotent) + _, err = db.Exec("GRANT USAGE ON SCHEMA public TO ackify_app") + if err != nil { + return fmt.Errorf("failed to grant USAGE on public schema: %w", err) + } + + return nil +} + +// escapePassword escapes single quotes in password for SQL +func escapePassword(password string) string { + return strings.ReplaceAll(password, "'", "''") +} diff --git a/backend/internal/infrastructure/auth/session_worker.go b/backend/internal/infrastructure/auth/session_worker.go index f8cf2df..99f3a0c 100644 --- a/backend/internal/infrastructure/auth/session_worker.go +++ b/backend/internal/infrastructure/auth/session_worker.go @@ -3,10 +3,12 @@ package auth import ( "context" + "database/sql" "fmt" "sync" "time" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -16,6 +18,10 @@ type SessionWorker struct { cleanupInterval time.Duration cleanupAge time.Duration + // RLS support + db *sql.DB + tenants tenant.Provider + ctx context.Context cancel context.CancelFunc wg sync.WaitGroup @@ -39,7 +45,7 @@ func DefaultSessionWorkerConfig() SessionWorkerConfig { } // NewSessionWorker creates a new OAuth session cleanup worker -func NewSessionWorker(sessionRepo SessionRepository, config SessionWorkerConfig) *SessionWorker { +func NewSessionWorker(sessionRepo SessionRepository, config SessionWorkerConfig, db *sql.DB, tenants tenant.Provider) *SessionWorker { // Apply defaults if config.CleanupInterval <= 0 { config.CleanupInterval = 24 * time.Hour @@ -54,6 +60,8 @@ func NewSessionWorker(sessionRepo SessionRepository, config SessionWorkerConfig) sessionRepo: sessionRepo, cleanupInterval: config.CleanupInterval, cleanupAge: config.CleanupAge, + db: db, + tenants: tenants, ctx: ctx, cancel: cancel, stopChan: make(chan struct{}), @@ -148,7 +156,28 @@ func (w *SessionWorker) performCleanup() { logger.Logger.Debug("Starting OAuth session cleanup", "older_than", w.cleanupAge) - deleted, err := w.sessionRepo.DeleteExpired(ctx, w.cleanupAge) + var deleted int64 + var err error + + // Use RLS context if db and tenants are available + if w.db != nil && w.tenants != nil { + tenantID, tenantErr := w.tenants.CurrentTenant(ctx) + if tenantErr != nil { + logger.Logger.Error("Failed to get tenant for session cleanup", + "error", tenantErr.Error()) + return + } + + err = tenant.WithTenantContext(ctx, w.db, tenantID, func(txCtx context.Context) error { + var cleanupErr error + deleted, cleanupErr = w.sessionRepo.DeleteExpired(txCtx, w.cleanupAge) + return cleanupErr + }) + } else { + // No RLS - direct repository access (for tests) + deleted, err = w.sessionRepo.DeleteExpired(ctx, w.cleanupAge) + } + if err != nil { logger.Logger.Error("Failed to cleanup expired OAuth sessions", "error", err.Error()) diff --git a/backend/internal/infrastructure/auth/session_worker_test.go b/backend/internal/infrastructure/auth/session_worker_test.go index c508bae..765f17c 100644 --- a/backend/internal/infrastructure/auth/session_worker_test.go +++ b/backend/internal/infrastructure/auth/session_worker_test.go @@ -10,8 +10,18 @@ import ( "time" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/google/uuid" ) +// mockTenantProvider for testing +type mockTenantProvider struct { + tenantID uuid.UUID +} + +func (m *mockTenantProvider) CurrentTenant(ctx context.Context) (uuid.UUID, error) { + return m.tenantID, nil +} + // mockSessionRepo implements SessionRepository for testing type mockSessionRepoForWorker struct { mu sync.Mutex @@ -56,7 +66,7 @@ func TestSessionWorker_StartStop(t *testing.T) { CleanupAge: 1 * time.Hour, } - worker := NewSessionWorker(repo, config) + worker := NewSessionWorker(repo, config, nil, &mockTenantProvider{tenantID: uuid.New()}) // Test starting err := worker.Start() @@ -113,7 +123,7 @@ func TestSessionWorker_CleanupSuccess(t *testing.T) { CleanupAge: 24 * time.Hour, } - worker := NewSessionWorker(repo, config) + worker := NewSessionWorker(repo, config, nil, &mockTenantProvider{tenantID: uuid.New()}) err := worker.Start() if err != nil { @@ -147,7 +157,7 @@ func TestSessionWorker_CleanupError(t *testing.T) { CleanupAge: 24 * time.Hour, } - worker := NewSessionWorker(repo, config) + worker := NewSessionWorker(repo, config, nil, &mockTenantProvider{tenantID: uuid.New()}) err := worker.Start() if err != nil { @@ -181,7 +191,7 @@ func TestSessionWorker_ImmediateCleanupOnStart(t *testing.T) { CleanupAge: 24 * time.Hour, } - worker := NewSessionWorker(repo, config) + worker := NewSessionWorker(repo, config, nil, &mockTenantProvider{tenantID: uuid.New()}) err := worker.Start() if err != nil { @@ -228,7 +238,7 @@ func TestSessionWorker_GracefulShutdown(t *testing.T) { CleanupAge: 1 * time.Hour, } - worker := NewSessionWorker(repo, config) + worker := NewSessionWorker(repo, config, nil, &mockTenantProvider{tenantID: uuid.New()}) err := worker.Start() if err != nil { @@ -295,7 +305,7 @@ func TestSessionWorker_ContextCancellation(t *testing.T) { CleanupAge: 1 * time.Hour, } - worker := NewSessionWorker(repo, config) + worker := NewSessionWorker(repo, config, nil, &mockTenantProvider{tenantID: uuid.New()}) err := worker.Start() if err != nil { diff --git a/backend/internal/infrastructure/database/document_repository.go b/backend/internal/infrastructure/database/document_repository.go index a4c56bd..bb92b3a 100644 --- a/backend/internal/infrastructure/database/document_repository.go +++ b/backend/internal/infrastructure/database/document_repository.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -46,7 +47,7 @@ func (r *DocumentRepository) Create(ctx context.Context, docID string, input mod } doc := &models.Document{} - err = r.db.QueryRowContext( + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext( ctx, query, tenantID, @@ -80,20 +81,16 @@ func (r *DocumentRepository) Create(ctx context.Context, docID string, input mod } // GetByDocID retrieves document metadata by document ID (excluding soft-deleted documents) +// RLS policy automatically filters by tenant_id func (r *DocumentRepository) GetByDocID(ctx context.Context, docID string) (*models.Document, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT doc_id, tenant_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at FROM documents - WHERE tenant_id = $1 AND doc_id = $2 AND deleted_at IS NULL + WHERE doc_id = $1 AND deleted_at IS NULL ` doc := &models.Document{} - err = r.db.QueryRowContext(ctx, query, tenantID, docID).Scan( + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID).Scan( &doc.DocID, &doc.TenantID, &doc.Title, @@ -120,12 +117,8 @@ func (r *DocumentRepository) GetByDocID(ctx context.Context, docID string) (*mod } // FindByReference searches for a document by reference (URL, path, or doc_id) +// RLS policy automatically filters by tenant_id func (r *DocumentRepository) FindByReference(ctx context.Context, ref string, refType string) (*models.Document, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - var query string var args []interface{} @@ -135,37 +128,37 @@ func (r *DocumentRepository) FindByReference(ctx context.Context, ref string, re query = ` SELECT doc_id, tenant_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at FROM documents - WHERE tenant_id = $1 AND url = $2 AND deleted_at IS NULL + WHERE url = $1 AND deleted_at IS NULL LIMIT 1 ` - args = []interface{}{tenantID, ref} + args = []interface{}{ref} case "path": // Search by URL field (paths are also stored in url field, excluding soft-deleted) query = ` SELECT doc_id, tenant_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at FROM documents - WHERE tenant_id = $1 AND url = $2 AND deleted_at IS NULL + WHERE url = $1 AND deleted_at IS NULL LIMIT 1 ` - args = []interface{}{tenantID, ref} + args = []interface{}{ref} case "reference": // Search by doc_id (excluding soft-deleted) query = ` SELECT doc_id, tenant_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at FROM documents - WHERE tenant_id = $1 AND doc_id = $2 AND deleted_at IS NULL + WHERE doc_id = $1 AND deleted_at IS NULL LIMIT 1 ` - args = []interface{}{tenantID, ref} + args = []interface{}{ref} default: return nil, fmt.Errorf("unknown reference type: %s", refType) } doc := &models.Document{} - err = r.db.QueryRowContext(ctx, query, args...).Scan( + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, args...).Scan( &doc.DocID, &doc.TenantID, &doc.Title, @@ -203,16 +196,12 @@ func (r *DocumentRepository) FindByReference(ctx context.Context, ref string, re } // Update modifies existing document metadata while preserving creation timestamp and creator +// RLS policy automatically filters by tenant_id func (r *DocumentRepository) Update(ctx context.Context, docID string, input models.DocumentInput) (*models.Document, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` UPDATE documents - SET title = $3, url = $4, checksum = $5, checksum_algorithm = $6, description = $7 - WHERE tenant_id = $1 AND doc_id = $2 AND deleted_at IS NULL + SET title = $2, url = $3, checksum = $4, checksum_algorithm = $5, description = $6 + WHERE doc_id = $1 AND deleted_at IS NULL RETURNING doc_id, tenant_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at ` @@ -224,10 +213,9 @@ func (r *DocumentRepository) Update(ctx context.Context, docID string, input mod } doc := &models.Document{} - err = r.db.QueryRowContext( + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext( ctx, query, - tenantID, docID, input.Title, input.URL, @@ -288,7 +276,7 @@ func (r *DocumentRepository) CreateOrUpdate(ctx context.Context, docID string, i } doc := &models.Document{} - err = r.db.QueryRowContext( + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext( ctx, query, tenantID, @@ -322,15 +310,11 @@ func (r *DocumentRepository) CreateOrUpdate(ctx context.Context, docID string, i } // Delete soft-deletes document by setting deleted_at timestamp, preserving metadata and signature history +// RLS policy automatically filters by tenant_id func (r *DocumentRepository) Delete(ctx context.Context, docID string) error { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return fmt.Errorf("failed to get tenant: %w", err) - } + query := `UPDATE documents SET deleted_at = now() WHERE doc_id = $1 AND deleted_at IS NULL` - query := `UPDATE documents SET deleted_at = now() WHERE tenant_id = $1 AND doc_id = $2 AND deleted_at IS NULL` - - result, err := r.db.ExecContext(ctx, query, tenantID, docID) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, docID) if err != nil { logger.Logger.Error("Failed to delete document", "error", err.Error(), "doc_id", docID) return fmt.Errorf("failed to delete document: %w", err) @@ -349,21 +333,17 @@ func (r *DocumentRepository) Delete(ctx context.Context, docID string) error { } // List retrieves paginated documents ordered by creation date, newest first (excluding soft-deleted) +// RLS policy automatically filters by tenant_id func (r *DocumentRepository) List(ctx context.Context, limit, offset int) ([]*models.Document, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT doc_id, tenant_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at FROM documents - WHERE tenant_id = $1 AND deleted_at IS NULL + WHERE deleted_at IS NULL ORDER BY created_at DESC - LIMIT $2 OFFSET $3 + LIMIT $1 OFFSET $2 ` - rows, err := r.db.QueryContext(ctx, query, tenantID, limit, offset) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, limit, offset) if err != nil { logger.Logger.Error("Failed to list documents", "error", err.Error()) return nil, fmt.Errorf("failed to list documents: %w", err) @@ -402,28 +382,24 @@ func (r *DocumentRepository) List(ctx context.Context, limit, offset int) ([]*mo // Search retrieves paginated documents matching the search query (excluding soft-deleted) // Searches in doc_id, title, url, and description fields using case-insensitive pattern matching +// RLS policy automatically filters by tenant_id func (r *DocumentRepository) Search(ctx context.Context, query string, limit, offset int) ([]*models.Document, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - searchQuery := ` SELECT doc_id, tenant_id, title, url, checksum, checksum_algorithm, description, created_at, updated_at, created_by, deleted_at FROM documents - WHERE tenant_id = $1 AND deleted_at IS NULL + WHERE deleted_at IS NULL AND ( - doc_id ILIKE $2 - OR title ILIKE $2 - OR url ILIKE $2 - OR description ILIKE $2 + doc_id ILIKE $1 + OR title ILIKE $1 + OR url ILIKE $1 + OR description ILIKE $1 ) ORDER BY created_at DESC - LIMIT $3 OFFSET $4 + LIMIT $2 OFFSET $3 ` searchPattern := "%" + query + "%" - rows, err := r.db.QueryContext(ctx, searchQuery, tenantID, searchPattern, limit, offset) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, searchQuery, searchPattern, limit, offset) if err != nil { logger.Logger.Error("Failed to search documents", "error", err.Error(), "query", query) return nil, fmt.Errorf("failed to search documents: %w", err) @@ -467,12 +443,8 @@ func (r *DocumentRepository) Search(ctx context.Context, query string, limit, of } // Count returns the total number of documents matching the optional search query (excluding soft-deleted) +// RLS policy automatically filters by tenant_id func (r *DocumentRepository) Count(ctx context.Context, searchQuery string) (int, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return 0, fmt.Errorf("failed to get tenant: %w", err) - } - var query string var args []interface{} @@ -481,28 +453,28 @@ func (r *DocumentRepository) Count(ctx context.Context, searchQuery string) (int query = ` SELECT COUNT(*) FROM documents - WHERE tenant_id = $1 AND deleted_at IS NULL + WHERE deleted_at IS NULL AND ( - doc_id ILIKE $2 - OR title ILIKE $2 - OR url ILIKE $2 - OR description ILIKE $2 + doc_id ILIKE $1 + OR title ILIKE $1 + OR url ILIKE $1 + OR description ILIKE $1 ) ` searchPattern := "%" + searchQuery + "%" - args = []interface{}{tenantID, searchPattern} + args = []interface{}{searchPattern} } else { // Count all documents query = ` SELECT COUNT(*) FROM documents - WHERE tenant_id = $1 AND deleted_at IS NULL + WHERE deleted_at IS NULL ` - args = []interface{}{tenantID} + args = []interface{}{} } var count int - err = r.db.QueryRowContext(ctx, query, args...).Scan(&count) + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, args...).Scan(&count) if err != nil { logger.Logger.Error("Failed to count documents", "error", err.Error(), "search", searchQuery) return 0, fmt.Errorf("failed to count documents: %w", err) diff --git a/backend/internal/infrastructure/database/email_queue_repository.go b/backend/internal/infrastructure/database/email_queue_repository.go index 068ae2a..7a80cdb 100644 --- a/backend/internal/infrastructure/database/email_queue_repository.go +++ b/backend/internal/infrastructure/database/email_queue_repository.go @@ -11,6 +11,7 @@ import ( "github.com/lib/pq" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -92,7 +93,7 @@ func (r *EmailQueueRepository) Enqueue(ctx context.Context, input models.EmailQu CreatedBy: input.CreatedBy, } - err = r.db.QueryRowContext( + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext( ctx, query, tenantID, @@ -158,7 +159,7 @@ func (r *EmailQueueRepository) GetNextToProcess(ctx context.Context, limit int) last_error, error_details, reference_type, reference_id, created_by ` - rows, err := r.db.QueryContext(ctx, query, time.Now(), limit) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, time.Now(), limit) if err != nil { return nil, fmt.Errorf("failed to get next emails to process: %w", err) } @@ -209,7 +210,7 @@ func (r *EmailQueueRepository) MarkAsSent(ctx context.Context, id int64) error { WHERE id = $2 ` - result, err := r.db.ExecContext(ctx, query, time.Now(), id) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, time.Now(), id) if err != nil { return fmt.Errorf("failed to mark email as sent: %w", err) } @@ -290,7 +291,7 @@ func (r *EmailQueueRepository) MarkAsFailedWithDelay(ctx context.Context, id int args = []interface{}{time.Now(), errorMsg, errorDetailsJSON, id} } - result, err := r.db.ExecContext(ctx, query, args...) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("failed to mark email as failed: %w", err) } @@ -310,7 +311,7 @@ func (r *EmailQueueRepository) MarkAsFailedWithDelay(ctx context.Context, id int error_details = $3 WHERE id = $4 ` - _, err = r.db.ExecContext(ctx, query, time.Now(), errorMsg, errorDetailsJSON, id) + _, err = dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, time.Now(), errorMsg, errorDetailsJSON, id) if err != nil { return fmt.Errorf("failed to mark email as permanently failed: %w", err) } @@ -347,7 +348,7 @@ func (r *EmailQueueRepository) GetRetryableEmails(ctx context.Context, limit int last_error, error_details, reference_type, reference_id, created_by ` - rows, err := r.db.QueryContext(ctx, query, time.Now(), limit) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, time.Now(), limit) if err != nil { return nil, fmt.Errorf("failed to get retryable emails: %w", err) } @@ -402,7 +403,7 @@ func (r *EmailQueueRepository) GetQueueStats(ctx context.Context) (*models.Email FROM email_queue GROUP BY status ` - rows, err := r.db.QueryContext(ctx, statusQuery) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, statusQuery) if err != nil { return nil, fmt.Errorf("failed to get status counts: %w", err) } @@ -430,7 +431,7 @@ func (r *EmailQueueRepository) GetQueueStats(ctx context.Context) (*models.Email // Get oldest pending email var oldestPending sql.NullTime - err = r.db.QueryRowContext(ctx, ` + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, ` SELECT MIN(created_at) FROM email_queue WHERE status = 'pending' @@ -443,7 +444,7 @@ func (r *EmailQueueRepository) GetQueueStats(ctx context.Context) (*models.Email } // Get average retry count - err = r.db.QueryRowContext(ctx, ` + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, ` SELECT AVG(retry_count)::float FROM email_queue WHERE status IN ('sent', 'failed') @@ -453,7 +454,7 @@ func (r *EmailQueueRepository) GetQueueStats(ctx context.Context) (*models.Email } // Get last 24 hours stats - err = r.db.QueryRowContext(ctx, ` + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, ` SELECT COUNT(*) FILTER (WHERE status = 'sent' AND processed_at >= NOW() - INTERVAL '24 hours') as sent, COUNT(*) FILTER (WHERE status = 'failed' AND processed_at >= NOW() - INTERVAL '24 hours') as failed, @@ -476,7 +477,7 @@ func (r *EmailQueueRepository) CancelEmail(ctx context.Context, id int64) error WHERE id = $2 AND status IN ('pending', 'processing') ` - result, err := r.db.ExecContext(ctx, query, time.Now(), id) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, time.Now(), id) if err != nil { return fmt.Errorf("failed to cancel email: %w", err) } @@ -503,7 +504,7 @@ func (r *EmailQueueRepository) CleanupOldEmails(ctx context.Context, olderThan t ` cutoff := time.Now().Add(-olderThan) - result, err := r.db.ExecContext(ctx, query, cutoff) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, cutoff) if err != nil { return 0, fmt.Errorf("failed to cleanup old emails: %w", err) } diff --git a/backend/internal/infrastructure/database/expected_signer_repository.go b/backend/internal/infrastructure/database/expected_signer_repository.go index 111b0cf..fae6252 100644 --- a/backend/internal/infrastructure/database/expected_signer_repository.go +++ b/backend/internal/infrastructure/database/expected_signer_repository.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -49,7 +50,7 @@ func (r *ExpectedSignerRepository) AddExpected(ctx context.Context, docID string ON CONFLICT (doc_id, email) DO NOTHING `, strings.Join(valueStrings, ",")) - _, err = r.db.ExecContext(ctx, query, valueArgs...) + _, err = dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, valueArgs...) if err != nil { return fmt.Errorf("failed to add expected signers: %w", err) } @@ -58,20 +59,16 @@ func (r *ExpectedSignerRepository) AddExpected(ctx context.Context, docID string } // ListByDocID retrieves all expected signers for a document, ordered chronologically by when they were added +// RLS policy automatically filters by tenant_id func (r *ExpectedSignerRepository) ListByDocID(ctx context.Context, docID string) ([]*models.ExpectedSigner, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT id, tenant_id, doc_id, email, name, added_at, added_by, notes FROM expected_signers - WHERE tenant_id = $1 AND doc_id = $2 + WHERE doc_id = $1 ORDER BY added_at ASC ` - rows, err := r.db.QueryContext(ctx, query, tenantID, docID) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, docID) if err != nil { return nil, fmt.Errorf("failed to query expected signers: %w", err) } @@ -105,12 +102,8 @@ func (r *ExpectedSignerRepository) ListByDocID(ctx context.Context, docID string } // ListWithStatusByDocID enriches signer data with signature completion status and reminder tracking metrics +// RLS policy automatically filters by tenant_id func (r *ExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT es.id, @@ -131,12 +124,12 @@ func (r *ExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, do FROM expected_signers es LEFT JOIN signatures s ON es.tenant_id = s.tenant_id AND es.doc_id = s.doc_id AND es.email = s.user_email LEFT JOIN reminder_logs rl ON es.tenant_id = rl.tenant_id AND es.doc_id = rl.doc_id AND es.email = rl.recipient_email - WHERE es.tenant_id = $1 AND es.doc_id = $2 + WHERE es.doc_id = $1 GROUP BY es.id, es.tenant_id, es.doc_id, es.email, es.name, es.added_at, es.added_by, es.notes, s.id, s.signed_at, s.user_name ORDER BY has_signed DESC, es.added_at ASC ` - rows, err := r.db.QueryContext(ctx, query, tenantID, docID) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, docID) if err != nil { return nil, fmt.Errorf("failed to query expected signers with status: %w", err) } @@ -190,18 +183,14 @@ func (r *ExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, do } // Remove deletes a specific expected signer by document ID and email address +// RLS policy automatically filters by tenant_id func (r *ExpectedSignerRepository) Remove(ctx context.Context, docID, email string) error { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return fmt.Errorf("failed to get tenant: %w", err) - } - query := ` DELETE FROM expected_signers - WHERE tenant_id = $1 AND doc_id = $2 AND email = $3 + WHERE doc_id = $1 AND email = $2 ` - result, err := r.db.ExecContext(ctx, query, tenantID, docID, email) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, docID, email) if err != nil { return fmt.Errorf("failed to remove expected signer: %w", err) } @@ -219,18 +208,14 @@ func (r *ExpectedSignerRepository) Remove(ctx context.Context, docID, email stri } // RemoveAllForDoc purges all expected signers associated with a document in a single operation +// RLS policy automatically filters by tenant_id func (r *ExpectedSignerRepository) RemoveAllForDoc(ctx context.Context, docID string) error { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return fmt.Errorf("failed to get tenant: %w", err) - } - query := ` DELETE FROM expected_signers - WHERE tenant_id = $1 AND doc_id = $2 + WHERE doc_id = $1 ` - _, err = r.db.ExecContext(ctx, query, tenantID, docID) + _, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, docID) if err != nil { return fmt.Errorf("failed to remove all expected signers: %w", err) } @@ -239,21 +224,17 @@ func (r *ExpectedSignerRepository) RemoveAllForDoc(ctx context.Context, docID st } // IsExpected efficiently verifies if an email address is in the expected signer list for a document +// RLS policy automatically filters by tenant_id func (r *ExpectedSignerRepository) IsExpected(ctx context.Context, docID, email string) (bool, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return false, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT EXISTS( SELECT 1 FROM expected_signers - WHERE tenant_id = $1 AND doc_id = $2 AND email = $3 + WHERE doc_id = $1 AND email = $2 ) ` var exists bool - err = r.db.QueryRowContext(ctx, query, tenantID, docID, email).Scan(&exists) + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID, email).Scan(&exists) if err != nil { return false, fmt.Errorf("failed to check if email is expected: %w", err) } @@ -262,26 +243,22 @@ func (r *ExpectedSignerRepository) IsExpected(ctx context.Context, docID, email } // GetStats calculates signature completion metrics including percentage progress for a document +// RLS policy automatically filters by tenant_id func (r *ExpectedSignerRepository) GetStats(ctx context.Context, docID string) (*models.DocCompletionStats, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT COUNT(*) as expected_count, COUNT(s.id) as signed_count FROM expected_signers es LEFT JOIN signatures s ON es.tenant_id = s.tenant_id AND es.doc_id = s.doc_id AND es.email = s.user_email - WHERE es.tenant_id = $1 AND es.doc_id = $2 + WHERE es.doc_id = $1 ` stats := &models.DocCompletionStats{ DocID: docID, } - err = r.db.QueryRowContext(ctx, query, tenantID, docID).Scan(&stats.ExpectedCount, &stats.SignedCount) + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID).Scan(&stats.ExpectedCount, &stats.SignedCount) if err != nil { return nil, fmt.Errorf("failed to get stats: %w", err) } diff --git a/backend/internal/infrastructure/database/oauth_session_repository.go b/backend/internal/infrastructure/database/oauth_session_repository.go index 6365ecd..725ca31 100644 --- a/backend/internal/infrastructure/database/oauth_session_repository.go +++ b/backend/internal/infrastructure/database/oauth_session_repository.go @@ -8,6 +8,7 @@ import ( "time" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -52,7 +53,7 @@ func (r *OAuthSessionRepository) Create(ctx context.Context, session *models.OAu RETURNING id, created_at, updated_at ` - err = r.db.QueryRowContext( + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext( ctx, query, tenantID, @@ -82,12 +83,8 @@ func (r *OAuthSessionRepository) Create(ctx context.Context, session *models.OAu } // GetBySessionID retrieves an OAuth session by session ID +// RLS policy automatically filters by tenant_id func (r *OAuthSessionRepository) GetBySessionID(ctx context.Context, sessionID string) (*models.OAuthSession, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT id, @@ -102,13 +99,13 @@ func (r *OAuthSessionRepository) GetBySessionID(ctx context.Context, sessionID s user_agent, ip_address FROM oauth_sessions - WHERE session_id = $1 AND tenant_id = $2 + WHERE session_id = $1 ` session := &models.OAuthSession{} var lastRefreshedAt sql.NullTime - err = r.db.QueryRowContext(ctx, query, sessionID, tenantID).Scan( + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, sessionID).Scan( &session.ID, &session.TenantID, &session.SessionID, @@ -141,12 +138,8 @@ func (r *OAuthSessionRepository) GetBySessionID(ctx context.Context, sessionID s } // UpdateRefreshToken updates the refresh token and expiration time +// RLS policy automatically filters by tenant_id func (r *OAuthSessionRepository) UpdateRefreshToken(ctx context.Context, sessionID string, encryptedToken []byte, expiresAt time.Time) error { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return fmt.Errorf("failed to get tenant: %w", err) - } - query := ` UPDATE oauth_sessions SET @@ -154,10 +147,10 @@ func (r *OAuthSessionRepository) UpdateRefreshToken(ctx context.Context, session access_token_expires_at = $2, last_refreshed_at = now(), updated_at = now() - WHERE session_id = $3 AND tenant_id = $4 + WHERE session_id = $3 ` - result, err := r.db.ExecContext(ctx, query, encryptedToken, expiresAt, sessionID, tenantID) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, encryptedToken, expiresAt, sessionID) if err != nil { logger.Logger.Error("Failed to update OAuth session refresh token", "session_id", sessionID, @@ -181,15 +174,11 @@ func (r *OAuthSessionRepository) UpdateRefreshToken(ctx context.Context, session } // DeleteBySessionID deletes an OAuth session by session ID +// RLS policy automatically filters by tenant_id func (r *OAuthSessionRepository) DeleteBySessionID(ctx context.Context, sessionID string) error { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return fmt.Errorf("failed to get tenant: %w", err) - } + query := `DELETE FROM oauth_sessions WHERE session_id = $1` - query := `DELETE FROM oauth_sessions WHERE session_id = $1 AND tenant_id = $2` - - result, err := r.db.ExecContext(ctx, query, sessionID, tenantID) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, sessionID) if err != nil { logger.Logger.Error("Failed to delete OAuth session", "session_id", sessionID, @@ -217,7 +206,7 @@ func (r *OAuthSessionRepository) DeleteExpired(ctx context.Context, olderThan ti ` cutoffTime := time.Now().Add(-olderThan) - result, err := r.db.ExecContext(ctx, query, cutoffTime) + result, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, cutoffTime) if err != nil { logger.Logger.Error("Failed to delete expired OAuth sessions", "cutoff_time", cutoffTime, diff --git a/backend/internal/infrastructure/database/reminder_repository.go b/backend/internal/infrastructure/database/reminder_repository.go index 8667b69..669ec5e 100644 --- a/backend/internal/infrastructure/database/reminder_repository.go +++ b/backend/internal/infrastructure/database/reminder_repository.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -36,7 +37,7 @@ func (r *ReminderRepository) LogReminder(ctx context.Context, log *models.Remind RETURNING id ` - err = r.db.QueryRowContext(ctx, query, + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, tenantID, log.DocID, log.RecipientEmail, @@ -56,20 +57,16 @@ func (r *ReminderRepository) LogReminder(ctx context.Context, log *models.Remind } // GetReminderHistory retrieves complete reminder audit trail for a document, ordered by send time descending +// RLS policy automatically filters by tenant_id func (r *ReminderRepository) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT id, tenant_id, doc_id, recipient_email, sent_at, sent_by, template_used, status, error_message FROM reminder_logs - WHERE tenant_id = $1 AND doc_id = $2 + WHERE doc_id = $1 ORDER BY sent_at DESC ` - rows, err := r.db.QueryContext(ctx, query, tenantID, docID) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, docID) if err != nil { return nil, fmt.Errorf("failed to query reminder history: %w", err) } @@ -104,22 +101,18 @@ func (r *ReminderRepository) GetReminderHistory(ctx context.Context, docID strin } // GetLastReminderByEmail retrieves the most recent reminder sent to a specific recipient for throttling logic +// RLS policy automatically filters by tenant_id func (r *ReminderRepository) GetLastReminderByEmail(ctx context.Context, docID, email string) (*models.ReminderLog, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT id, tenant_id, doc_id, recipient_email, sent_at, sent_by, template_used, status, error_message FROM reminder_logs - WHERE tenant_id = $1 AND doc_id = $2 AND recipient_email = $3 + WHERE doc_id = $1 AND recipient_email = $2 ORDER BY sent_at DESC LIMIT 1 ` log := &models.ReminderLog{} - err = r.db.QueryRowContext(ctx, query, tenantID, docID, email).Scan( + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID, email).Scan( &log.ID, &log.TenantID, &log.DocID, @@ -143,20 +136,16 @@ func (r *ReminderRepository) GetLastReminderByEmail(ctx context.Context, docID, } // GetReminderCount tallies successfully delivered reminders to a recipient for rate limiting +// RLS policy automatically filters by tenant_id func (r *ReminderRepository) GetReminderCount(ctx context.Context, docID, email string) (int, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return 0, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT COUNT(*) FROM reminder_logs - WHERE tenant_id = $1 AND doc_id = $2 AND recipient_email = $3 AND status = 'sent' + WHERE doc_id = $1 AND recipient_email = $2 AND status = 'sent' ` var count int - err = r.db.QueryRowContext(ctx, query, tenantID, docID, email).Scan(&count) + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID, email).Scan(&count) if err != nil { return 0, fmt.Errorf("failed to get reminder count: %w", err) } @@ -165,24 +154,20 @@ func (r *ReminderRepository) GetReminderCount(ctx context.Context, docID, email } // GetReminderStats aggregates reminder metrics including pending signers and last send timestamp +// RLS policy automatically filters by tenant_id func (r *ReminderRepository) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT COUNT(*) as total_sent, MAX(sent_at) as last_sent_at FROM reminder_logs - WHERE tenant_id = $1 AND doc_id = $2 AND status = 'sent' + WHERE doc_id = $1 AND status = 'sent' ` stats := &models.ReminderStats{} var lastSent sql.NullTime - err = r.db.QueryRowContext(ctx, query, tenantID, docID).Scan(&stats.TotalSent, &lastSent) + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID).Scan(&stats.TotalSent, &lastSent) if err != nil { return nil, fmt.Errorf("failed to get reminder stats: %w", err) } @@ -195,10 +180,10 @@ func (r *ReminderRepository) GetReminderStats(ctx context.Context, docID string) SELECT COUNT(*) FROM expected_signers es LEFT JOIN signatures s ON es.tenant_id = s.tenant_id AND es.doc_id = s.doc_id AND es.email = s.user_email - WHERE es.tenant_id = $1 AND es.doc_id = $2 AND s.id IS NULL + WHERE es.doc_id = $1 AND s.id IS NULL ` - err = r.db.QueryRowContext(ctx, pendingQuery, tenantID, docID).Scan(&stats.PendingCount) + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, pendingQuery, docID).Scan(&stats.PendingCount) if err != nil { return nil, fmt.Errorf("failed to get pending count: %w", err) } diff --git a/backend/internal/infrastructure/database/repository_concurrency_test.go b/backend/internal/infrastructure/database/repository_concurrency_test.go index 0ccb3c3..aa75c67 100644 --- a/backend/internal/infrastructure/database/repository_concurrency_test.go +++ b/backend/internal/infrastructure/database/repository_concurrency_test.go @@ -309,9 +309,14 @@ func TestRepository_Concurrency_Integration(t *testing.T) { const duration = 2 * time.Second const numWorkers = 20 - ctx, cancel := context.WithTimeout(ctx, duration) + // Use a separate context for timeout control (not passed to DB operations) + // This avoids race conditions when context is cancelled during row scanning + timeoutCtx, cancel := context.WithTimeout(context.Background(), duration) defer cancel() + // Use background context for DB operations to avoid cancellation races + dbCtx := context.Background() + var wg sync.WaitGroup operationCounts := make(chan map[string]int, numWorkers) @@ -331,7 +336,7 @@ func TestRepository_Concurrency_Integration(t *testing.T) { for { select { - case <-ctx.Done(): + case <-timeoutCtx.Done(): operationCounts <- counts return default: @@ -341,14 +346,14 @@ func TestRepository_Concurrency_Integration(t *testing.T) { fmt.Sprintf("stress-user-%d-%d", workerID, counts["creates"]), fmt.Sprintf("stress%d%d@example.com", workerID, counts["creates"]), ) - if err := repo.Create(ctx, sig); err != nil { + if err := repo.Create(dbCtx, sig); err != nil { counts["errors"]++ } else { counts["creates"]++ } case 1: // GetByDocAndUser - _, err := repo.GetByDocAndUser(ctx, "test-doc-123", "user-123") + _, err := repo.GetByDocAndUser(dbCtx, "test-doc-123", "user-123") if err != nil && !strings.Contains(err.Error(), "not found") { counts["errors"]++ } else { @@ -356,7 +361,7 @@ func TestRepository_Concurrency_Integration(t *testing.T) { } case 2: // ExistsByDocAndUser - _, err := repo.ExistsByDocAndUser(ctx, "test-doc-123", "user-123") + _, err := repo.ExistsByDocAndUser(dbCtx, "test-doc-123", "user-123") if err != nil { counts["errors"]++ } else { @@ -364,7 +369,7 @@ func TestRepository_Concurrency_Integration(t *testing.T) { } case 3: // GetLastSignature - _, err := repo.GetLastSignature(ctx, "test-doc-123") + _, err := repo.GetLastSignature(dbCtx, "test-doc-123") if err != nil { counts["errors"]++ } else { @@ -372,7 +377,7 @@ func TestRepository_Concurrency_Integration(t *testing.T) { } case 4: // GetAllSignaturesOrdered - _, err := repo.GetAllSignaturesOrdered(ctx) + _, err := repo.GetAllSignaturesOrdered(dbCtx) if err != nil { counts["errors"]++ } else { diff --git a/backend/internal/infrastructure/database/rls_integration_test.go b/backend/internal/infrastructure/database/rls_integration_test.go new file mode 100644 index 0000000..5c28bbd --- /dev/null +++ b/backend/internal/infrastructure/database/rls_integration_test.go @@ -0,0 +1,635 @@ +//go:build integration + +// SPDX-License-Identifier: AGPL-3.0-or-later +package database + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" + "github.com/google/uuid" +) + +// isSuperuser checks if the current database connection is a superuser. +// Superusers bypass RLS policies, so isolation tests will fail when connected as superuser. +func isSuperuser(db *sql.DB) bool { + var isSuperuser bool + err := db.QueryRow("SELECT usesuper FROM pg_user WHERE usename = current_user").Scan(&isSuperuser) + if err != nil { + return true // Assume superuser on error + } + return isSuperuser +} + +// skipIfSuperuser skips the test if connected as a superuser. +// RLS policies are only enforced for non-superuser roles. +func skipIfSuperuser(t *testing.T, db *sql.DB) { + t.Helper() + if isSuperuser(db) { + t.Skip("Skipping RLS isolation test: connected as superuser (RLS is bypassed). Use ackify_app role to test RLS enforcement.") + } +} + +// TestRLS_TenantIsolation verifies that data inserted by one tenant +// cannot be accessed by another tenant when using RLS. +// NOTE: This test requires a non-superuser connection to verify RLS enforcement. +func TestRLS_TenantIsolation(t *testing.T) { + testDB := SetupTestDB(t) + skipIfSuperuser(t, testDB.DB) + ctx := context.Background() + + // Get the default tenant ID from the test setup + tenantA, err := testDB.TenantProvider.CurrentTenant(ctx) + if err != nil { + t.Fatalf("Failed to get tenant A ID: %v", err) + } + + // Create a different tenant ID for isolation testing + tenantB := uuid.New() + + docRepo := NewDocumentRepository(testDB.DB, testDB.TenantProvider) + sigRepo := NewSignatureRepository(testDB.DB, testDB.TenantProvider) + + // Create a document with tenant A + docID := "doc-tenant-a-" + uuid.New().String()[:8] + docInput := models.DocumentInput{ + Title: "Document A", + URL: "https://example.com/doc-a", + Checksum: "checksum-a", + ChecksumAlgorithm: "SHA-256", + Description: "Test document for tenant A", + } + + var docA *models.Document + err = tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + var createErr error + docA, createErr = docRepo.Create(txCtx, docID, docInput, "user-a@example.com") + return createErr + }) + if err != nil { + t.Fatalf("Failed to create document with tenant A: %v", err) + } + + // Create a signature with tenant A + sigA := &models.Signature{ + DocID: docA.DocID, + UserSub: "user-sub-a", + UserEmail: "user-a@example.com", + UserName: "User A", + SignedAtUTC: time.Now().UTC(), + PayloadHash: "cGF5bG9hZC1oYXNoLWE=", + Signature: "c2lnbmF0dXJlLWE=", + Nonce: "nonce-a-" + uuid.New().String()[:8], + } + + err = tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + return sigRepo.Create(txCtx, sigA) + }) + if err != nil { + t.Fatalf("Failed to create signature with tenant A: %v", err) + } + + // Verify tenant A can access its own data + t.Run("tenant_A_can_access_own_data", func(t *testing.T) { + var doc *models.Document + var signatures []*models.Signature + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + var fetchErr error + doc, fetchErr = docRepo.GetByDocID(txCtx, docA.DocID) + if fetchErr != nil { + return fetchErr + } + + signatures, fetchErr = sigRepo.GetByDoc(txCtx, docA.DocID) + return fetchErr + }) + + if err != nil { + t.Errorf("Tenant A should be able to access its own data: %v", err) + } + + if doc == nil { + t.Error("Tenant A should see its document") + } + + if len(signatures) != 1 { + t.Errorf("Tenant A should see 1 signature, got %d", len(signatures)) + } + }) + + // Verify tenant B cannot access tenant A's data + t.Run("tenant_B_cannot_access_tenant_A_data", func(t *testing.T) { + var doc *models.Document + var signatures []*models.Signature + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantB, func(txCtx context.Context) error { + var fetchErr error + doc, fetchErr = docRepo.GetByDocID(txCtx, docA.DocID) + if fetchErr != nil && fetchErr != models.ErrDocumentNotFound { + return fetchErr + } + + signatures, fetchErr = sigRepo.GetByDoc(txCtx, docA.DocID) + return fetchErr + }) + + // Should not return an error, just no data + if err != nil && err != models.ErrDocumentNotFound { + t.Errorf("Unexpected error for tenant B: %v", err) + } + + if doc != nil { + t.Error("Tenant B should NOT see tenant A's document") + } + + if len(signatures) != 0 { + t.Errorf("Tenant B should see 0 signatures from tenant A, got %d", len(signatures)) + } + }) +} + +// TestRLS_DocumentIsolation tests document isolation specifically. +// NOTE: This test requires a non-superuser connection to verify RLS enforcement. +func TestRLS_DocumentIsolation(t *testing.T) { + testDB := SetupTestDB(t) + skipIfSuperuser(t, testDB.DB) + ctx := context.Background() + + tenantA, _ := testDB.TenantProvider.CurrentTenant(ctx) + tenantB := uuid.New() + + docRepo := NewDocumentRepository(testDB.DB, testDB.TenantProvider) + + // Create documents for tenant A + docsToCreate := []struct { + docID string + input models.DocumentInput + }{ + { + docID: "rls-doc-1-" + uuid.New().String()[:8], + input: models.DocumentInput{ + Title: "Doc 1", + URL: "https://example.com/1", + Checksum: "checksum1", + ChecksumAlgorithm: "SHA-256", + }, + }, + { + docID: "rls-doc-2-" + uuid.New().String()[:8], + input: models.DocumentInput{ + Title: "Doc 2", + URL: "https://example.com/2", + Checksum: "checksum2", + ChecksumAlgorithm: "SHA-256", + }, + }, + } + + for _, doc := range docsToCreate { + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + _, createErr := docRepo.Create(txCtx, doc.docID, doc.input, "admin@a.com") + return createErr + }) + if err != nil { + t.Fatalf("Failed to create document: %v", err) + } + } + + // Verify tenant A sees both documents + t.Run("tenant_A_sees_all_its_documents", func(t *testing.T) { + var docs []*models.Document + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + var fetchErr error + docs, fetchErr = docRepo.List(txCtx, 100, 0) + return fetchErr + }) + + if err != nil { + t.Fatalf("Failed to list documents for tenant A: %v", err) + } + + if len(docs) < 2 { + t.Errorf("Tenant A should see at least 2 documents, got %d", len(docs)) + } + }) + + // Verify tenant B sees no documents from tenant A + t.Run("tenant_B_sees_no_documents", func(t *testing.T) { + var docs []*models.Document + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantB, func(txCtx context.Context) error { + var fetchErr error + docs, fetchErr = docRepo.List(txCtx, 100, 0) + return fetchErr + }) + + if err != nil { + t.Fatalf("Failed to list documents for tenant B: %v", err) + } + + if len(docs) != 0 { + t.Errorf("Tenant B should see 0 documents, got %d", len(docs)) + } + }) +} + +// TestRLS_SignatureIsolation tests signature isolation specifically. +// NOTE: This test requires a non-superuser connection to verify RLS enforcement. +func TestRLS_SignatureIsolation(t *testing.T) { + testDB := SetupTestDB(t) + skipIfSuperuser(t, testDB.DB) + ctx := context.Background() + + tenantA, _ := testDB.TenantProvider.CurrentTenant(ctx) + tenantB := uuid.New() + + docRepo := NewDocumentRepository(testDB.DB, testDB.TenantProvider) + sigRepo := NewSignatureRepository(testDB.DB, testDB.TenantProvider) + + // Create document for tenant A + docID := "rls-sig-test-" + uuid.New().String()[:8] + docInput := models.DocumentInput{ + Title: "Signature Test Doc", + URL: "https://example.com/sig-test", + Checksum: "checksum-sig", + ChecksumAlgorithm: "SHA-256", + } + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + _, createErr := docRepo.Create(txCtx, docID, docInput, "admin@a.com") + return createErr + }) + if err != nil { + t.Fatalf("Failed to create document: %v", err) + } + + // Create signatures for tenant A + userSub := "rls-user-" + uuid.New().String()[:8] + sig := &models.Signature{ + DocID: docID, + UserSub: userSub, + UserEmail: "signer@a.com", + UserName: "Signer A", + SignedAtUTC: time.Now().UTC(), + PayloadHash: "cGF5bG9hZC1oYXNo", + Signature: "c2lnbmF0dXJl", + Nonce: "nonce-" + uuid.New().String()[:8], + } + + err = tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + return sigRepo.Create(txCtx, sig) + }) + if err != nil { + t.Fatalf("Failed to create signature: %v", err) + } + + // Tenant A can get signature by doc and user + t.Run("tenant_A_can_get_signature", func(t *testing.T) { + var fetchedSig *models.Signature + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + var fetchErr error + fetchedSig, fetchErr = sigRepo.GetByDocAndUser(txCtx, docID, userSub) + return fetchErr + }) + + if err != nil { + t.Errorf("Tenant A should be able to get its signature: %v", err) + } + + if fetchedSig == nil { + t.Error("Tenant A should see its signature") + } + }) + + // Tenant B cannot get tenant A's signature + t.Run("tenant_B_cannot_get_signature", func(t *testing.T) { + var fetchedSig *models.Signature + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantB, func(txCtx context.Context) error { + var fetchErr error + fetchedSig, fetchErr = sigRepo.GetByDocAndUser(txCtx, docID, userSub) + if fetchErr == models.ErrSignatureNotFound { + return nil // Expected + } + return fetchErr + }) + + if err != nil { + t.Errorf("Unexpected error for tenant B: %v", err) + } + + if fetchedSig != nil { + t.Error("Tenant B should NOT see tenant A's signature") + } + }) + + // Tenant A can check signature status + t.Run("tenant_A_can_check_signature_status", func(t *testing.T) { + var hasSigned bool + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + var checkErr error + hasSigned, checkErr = sigRepo.CheckUserSignatureStatus(txCtx, docID, userSub) + return checkErr + }) + + if err != nil { + t.Errorf("Tenant A should be able to check signature status: %v", err) + } + + if !hasSigned { + t.Error("Tenant A should see that user has signed") + } + }) + + // Tenant B gets false for signature status check + t.Run("tenant_B_signature_status_is_false", func(t *testing.T) { + var hasSigned bool + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantB, func(txCtx context.Context) error { + var checkErr error + hasSigned, checkErr = sigRepo.CheckUserSignatureStatus(txCtx, docID, userSub) + return checkErr + }) + + if err != nil { + t.Errorf("Unexpected error for tenant B: %v", err) + } + + if hasSigned { + t.Error("Tenant B should NOT see tenant A's signature status") + } + }) +} + +// TestRLS_TransactionCommitRollback tests that RLS transactions properly commit and rollback +func TestRLS_TransactionCommitRollback(t *testing.T) { + testDB := SetupTestDB(t) + ctx := context.Background() + + tenantID, _ := testDB.TenantProvider.CurrentTenant(ctx) + docRepo := NewDocumentRepository(testDB.DB, testDB.TenantProvider) + + t.Run("successful_transaction_commits", func(t *testing.T) { + docID := "rls-commit-test-" + uuid.New().String()[:8] + docInput := models.DocumentInput{ + Title: "Commit Test", + URL: "https://example.com/commit", + Checksum: "checksum", + ChecksumAlgorithm: "SHA-256", + } + + // Create document in transaction + err := tenant.WithTenantContext(ctx, testDB.DB, tenantID, func(txCtx context.Context) error { + _, createErr := docRepo.Create(txCtx, docID, docInput, "admin@test.com") + return createErr + }) + if err != nil { + t.Fatalf("Failed to create document: %v", err) + } + + // Verify document exists after commit + var doc *models.Document + err = tenant.WithTenantContext(ctx, testDB.DB, tenantID, func(txCtx context.Context) error { + var fetchErr error + doc, fetchErr = docRepo.GetByDocID(txCtx, docID) + return fetchErr + }) + + if err != nil { + t.Errorf("Should be able to fetch document after commit: %v", err) + } + + if doc == nil { + t.Error("Document should exist after commit") + } + }) + + t.Run("failed_transaction_rollbacks", func(t *testing.T) { + docID := "rls-rollback-test-" + uuid.New().String()[:8] + docInput := models.DocumentInput{ + Title: "Rollback Test", + URL: "https://example.com/rollback", + Checksum: "checksum", + ChecksumAlgorithm: "SHA-256", + } + + // Attempt to create document but return an error to trigger rollback + err := tenant.WithTenantContext(ctx, testDB.DB, tenantID, func(txCtx context.Context) error { + _, createErr := docRepo.Create(txCtx, docID, docInput, "admin@test.com") + if createErr != nil { + return createErr + } + // Return error to trigger rollback + return models.ErrDocumentNotFound + }) + + if err == nil { + t.Fatal("Expected error to be returned") + } + + // Verify document does NOT exist after rollback + var doc *models.Document + err = tenant.WithTenantContext(ctx, testDB.DB, tenantID, func(txCtx context.Context) error { + var fetchErr error + doc, fetchErr = docRepo.GetByDocID(txCtx, docID) + if fetchErr == models.ErrDocumentNotFound { + return nil // Expected + } + return fetchErr + }) + + if err != nil { + t.Errorf("Unexpected error when checking for rolled back document: %v", err) + } + + if doc != nil { + t.Error("Document should NOT exist after rollback") + } + }) +} + +// TestRLS_ExpectedSignersIsolation tests expected signers isolation. +// NOTE: This test requires a non-superuser connection to verify RLS enforcement. +func TestRLS_ExpectedSignersIsolation(t *testing.T) { + testDB := SetupTestDB(t) + skipIfSuperuser(t, testDB.DB) + ctx := context.Background() + + tenantA, _ := testDB.TenantProvider.CurrentTenant(ctx) + tenantB := uuid.New() + + docRepo := NewDocumentRepository(testDB.DB, testDB.TenantProvider) + signerRepo := NewExpectedSignerRepository(testDB.DB, testDB.TenantProvider) + + // Create document for tenant A + docID := "rls-signer-test-" + uuid.New().String()[:8] + docInput := models.DocumentInput{ + Title: "Expected Signers Test", + URL: "https://example.com/signers", + Checksum: "checksum", + ChecksumAlgorithm: "SHA-256", + } + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + _, createErr := docRepo.Create(txCtx, docID, docInput, "admin@a.com") + return createErr + }) + if err != nil { + t.Fatalf("Failed to create document: %v", err) + } + + // Add expected signers for tenant A using AddExpected + contacts := []models.ContactInfo{ + {Name: "Signer One", Email: "signer1@a.com"}, + {Name: "Signer Two", Email: "signer2@a.com"}, + } + + err = tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + return signerRepo.AddExpected(txCtx, docID, contacts, "admin@a.com") + }) + if err != nil { + t.Fatalf("Failed to add expected signers: %v", err) + } + + // Tenant A can see expected signers + t.Run("tenant_A_sees_expected_signers", func(t *testing.T) { + var fetchedSigners []*models.ExpectedSigner + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + var fetchErr error + fetchedSigners, fetchErr = signerRepo.ListByDocID(txCtx, docID) + return fetchErr + }) + + if err != nil { + t.Errorf("Tenant A should be able to get expected signers: %v", err) + } + + if len(fetchedSigners) != 2 { + t.Errorf("Tenant A should see 2 expected signers, got %d", len(fetchedSigners)) + } + }) + + // Tenant B cannot see tenant A's expected signers + t.Run("tenant_B_cannot_see_expected_signers", func(t *testing.T) { + var fetchedSigners []*models.ExpectedSigner + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantB, func(txCtx context.Context) error { + var fetchErr error + fetchedSigners, fetchErr = signerRepo.ListByDocID(txCtx, docID) + return fetchErr + }) + + if err != nil { + t.Errorf("Unexpected error for tenant B: %v", err) + } + + if len(fetchedSigners) != 0 { + t.Errorf("Tenant B should see 0 expected signers, got %d", len(fetchedSigners)) + } + }) +} + +// TestRLS_WebhookIsolation tests webhook isolation. +// NOTE: This test requires a non-superuser connection to verify RLS enforcement. +func TestRLS_WebhookIsolation(t *testing.T) { + testDB := SetupTestDB(t) + skipIfSuperuser(t, testDB.DB) + ctx := context.Background() + + tenantA, _ := testDB.TenantProvider.CurrentTenant(ctx) + tenantB := uuid.New() + + webhookRepo := NewWebhookRepository(testDB.DB, testDB.TenantProvider) + + // Create webhook for tenant A + var webhookID int64 + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + wh, createErr := webhookRepo.Create(txCtx, models.WebhookInput{ + Title: "Webhook A", + TargetURL: "https://hooks.example.com/a", + Secret: "secret-a", + Description: "Webhook for tenant A", + Active: true, + Events: []string{"document.signed"}, + CreatedBy: "admin@a.com", + }) + if createErr != nil { + return createErr + } + webhookID = wh.ID + return nil + }) + if err != nil { + t.Fatalf("Failed to create webhook: %v", err) + } + + // Tenant A can see its webhook + t.Run("tenant_A_sees_webhook", func(t *testing.T) { + var webhook *models.Webhook + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantA, func(txCtx context.Context) error { + var fetchErr error + webhook, fetchErr = webhookRepo.GetByID(txCtx, webhookID) + return fetchErr + }) + + if err != nil { + t.Errorf("Tenant A should be able to get its webhook: %v", err) + } + + if webhook == nil { + t.Error("Tenant A should see its webhook") + } + }) + + // Tenant B cannot see tenant A's webhook (will get sql.ErrNoRows) + t.Run("tenant_B_cannot_see_webhook", func(t *testing.T) { + var webhook *models.Webhook + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantB, func(txCtx context.Context) error { + var fetchErr error + webhook, fetchErr = webhookRepo.GetByID(txCtx, webhookID) + // sql.ErrNoRows is expected, ignore it + if fetchErr != nil { + return nil // Expected: no rows + } + return nil + }) + + if err != nil { + t.Errorf("Unexpected error for tenant B: %v", err) + } + + if webhook != nil { + t.Error("Tenant B should NOT see tenant A's webhook") + } + }) + + // Tenant B sees empty list + t.Run("tenant_B_sees_empty_webhook_list", func(t *testing.T) { + var webhooks []*models.Webhook + + err := tenant.WithTenantContext(ctx, testDB.DB, tenantB, func(txCtx context.Context) error { + var fetchErr error + webhooks, fetchErr = webhookRepo.List(txCtx, 100, 0) + return fetchErr + }) + + if err != nil { + t.Errorf("Unexpected error for tenant B: %v", err) + } + + if len(webhooks) != 0 { + t.Errorf("Tenant B should see 0 webhooks, got %d", len(webhooks)) + } + }) +} diff --git a/backend/internal/infrastructure/database/signature_repository.go b/backend/internal/infrastructure/database/signature_repository.go index 0bf5319..c027bbc 100644 --- a/backend/internal/infrastructure/database/signature_repository.go +++ b/backend/internal/infrastructure/database/signature_repository.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" ) @@ -104,7 +105,7 @@ func (r *SignatureRepository) Create(ctx context.Context, signature *models.Sign docChecksum = sql.NullString{String: signature.DocChecksum, Valid: true} } - err = r.db.QueryRowContext( + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext( ctx, query, tenantID, signature.DocID, @@ -129,23 +130,19 @@ func (r *SignatureRepository) Create(ctx context.Context, signature *models.Sign } // GetByDocAndUser retrieves a specific signature by document ID and user OAuth subject identifier +// RLS policy automatically filters by tenant_id func (r *SignatureRepository) GetByDocAndUser(ctx context.Context, docID, userSub string) (*models.Signature, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT s.id, s.tenant_id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum, s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash, s.hash_version, s.doc_deleted_at, d.title, d.url FROM signatures s - LEFT JOIN documents d ON s.doc_id = d.doc_id - WHERE s.tenant_id = $1 AND s.doc_id = $2 AND s.user_sub = $3 + LEFT JOIN documents d ON s.doc_id = d.doc_id AND s.tenant_id = d.tenant_id + WHERE s.doc_id = $1 AND s.user_sub = $2 ` signature := &models.Signature{} - err = scanSignature(r.db.QueryRowContext(ctx, query, tenantID, docID, userSub), signature) + err := scanSignature(dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID, userSub), signature) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -158,23 +155,19 @@ func (r *SignatureRepository) GetByDocAndUser(ctx context.Context, docID, userSu } // GetByDoc retrieves all signatures for a specific document, ordered by creation timestamp descending +// RLS policy automatically filters by tenant_id func (r *SignatureRepository) GetByDoc(ctx context.Context, docID string) ([]*models.Signature, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT s.id, s.tenant_id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum, s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash, s.hash_version, s.doc_deleted_at, d.title, d.url FROM signatures s - LEFT JOIN documents d ON s.doc_id = d.doc_id - WHERE s.tenant_id = $1 AND s.doc_id = $2 + LEFT JOIN documents d ON s.doc_id = d.doc_id AND s.tenant_id = d.tenant_id + WHERE s.doc_id = $1 ORDER BY s.created_at DESC ` - rows, err := r.db.QueryContext(ctx, query, tenantID, docID) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, docID) if err != nil { return nil, fmt.Errorf("failed to query signatures: %w", err) } @@ -195,23 +188,19 @@ func (r *SignatureRepository) GetByDoc(ctx context.Context, docID string) ([]*mo } // GetByUser retrieves all signatures created by a specific user, ordered by creation timestamp descending +// RLS policy automatically filters by tenant_id func (r *SignatureRepository) GetByUser(ctx context.Context, userSub string) ([]*models.Signature, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT s.id, s.tenant_id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum, s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash, s.hash_version, s.doc_deleted_at, d.title, d.url FROM signatures s - LEFT JOIN documents d ON s.doc_id = d.doc_id - WHERE s.tenant_id = $1 AND s.user_sub = $2 + LEFT JOIN documents d ON s.doc_id = d.doc_id AND s.tenant_id = d.tenant_id + WHERE s.user_sub = $1 ORDER BY s.created_at DESC ` - rows, err := r.db.QueryContext(ctx, query, tenantID, userSub) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, userSub) if err != nil { return nil, fmt.Errorf("failed to query user signatures: %w", err) } @@ -232,16 +221,12 @@ func (r *SignatureRepository) GetByUser(ctx context.Context, userSub string) ([] } // ExistsByDocAndUser efficiently checks if a signature already exists without retrieving full record data +// RLS policy automatically filters by tenant_id func (r *SignatureRepository) ExistsByDocAndUser(ctx context.Context, docID, userSub string) (bool, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return false, fmt.Errorf("failed to get tenant: %w", err) - } - - query := `SELECT EXISTS(SELECT 1 FROM signatures WHERE tenant_id = $1 AND doc_id = $2 AND user_sub = $3)` + query := `SELECT EXISTS(SELECT 1 FROM signatures WHERE doc_id = $1 AND user_sub = $2)` var exists bool - err = r.db.QueryRowContext(ctx, query, tenantID, docID, userSub).Scan(&exists) + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID, userSub).Scan(&exists) if err != nil { return false, fmt.Errorf("failed to check signature existence: %w", err) } @@ -250,21 +235,17 @@ func (r *SignatureRepository) ExistsByDocAndUser(ctx context.Context, docID, use } // CheckUserSignatureStatus verifies if a user has signed, accepting either OAuth subject or email as identifier +// RLS policy automatically filters by tenant_id func (r *SignatureRepository) CheckUserSignatureStatus(ctx context.Context, docID, userIdentifier string) (bool, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return false, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT EXISTS( SELECT 1 FROM signatures - WHERE tenant_id = $1 AND doc_id = $2 AND (user_sub = $3 OR LOWER(user_email) = LOWER($3)) + WHERE doc_id = $1 AND (user_sub = $2 OR LOWER(user_email) = LOWER($2)) ) ` var exists bool - err = r.db.QueryRowContext(ctx, query, tenantID, docID, userIdentifier).Scan(&exists) + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID, userIdentifier).Scan(&exists) if err != nil { return false, fmt.Errorf("failed to check user signature status: %w", err) } @@ -273,25 +254,21 @@ func (r *SignatureRepository) CheckUserSignatureStatus(ctx context.Context, docI } // GetLastSignature retrieves the most recent signature for hash chain linking (returns nil if no signatures exist) +// RLS policy automatically filters by tenant_id func (r *SignatureRepository) GetLastSignature(ctx context.Context, docID string) (*models.Signature, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT s.id, s.tenant_id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum, s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash, s.hash_version, s.doc_deleted_at, d.title, d.url FROM signatures s - LEFT JOIN documents d ON s.doc_id = d.doc_id - WHERE s.tenant_id = $1 AND s.doc_id = $2 + LEFT JOIN documents d ON s.doc_id = d.doc_id AND s.tenant_id = d.tenant_id + WHERE s.doc_id = $1 ORDER BY s.id DESC LIMIT 1 ` signature := &models.Signature{} - err = scanSignature(r.db.QueryRowContext(ctx, query, tenantID, docID), signature) + err := scanSignature(dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, docID), signature) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -304,22 +281,17 @@ func (r *SignatureRepository) GetLastSignature(ctx context.Context, docID string } // GetAllSignaturesOrdered retrieves all signatures in chronological order for chain integrity verification +// RLS policy automatically filters by tenant_id func (r *SignatureRepository) GetAllSignaturesOrdered(ctx context.Context) ([]*models.Signature, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT s.id, s.tenant_id, s.doc_id, s.user_sub, s.user_email, s.user_name, s.signed_at, s.doc_checksum, s.payload_hash, s.signature, s.nonce, s.created_at, s.referer, s.prev_hash, s.hash_version, s.doc_deleted_at, d.title, d.url FROM signatures s - LEFT JOIN documents d ON s.doc_id = d.doc_id - WHERE s.tenant_id = $1 + LEFT JOIN documents d ON s.doc_id = d.doc_id AND s.tenant_id = d.tenant_id ORDER BY s.id ASC` - rows, err := r.db.QueryContext(ctx, query, tenantID) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("failed to query all signatures: %w", err) } @@ -340,14 +312,10 @@ func (r *SignatureRepository) GetAllSignaturesOrdered(ctx context.Context) ([]*m } // UpdatePrevHash modifies the previous hash pointer for chain reconstruction operations +// RLS policy automatically filters by tenant_id func (r *SignatureRepository) UpdatePrevHash(ctx context.Context, id int64, prevHash *string) error { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return fmt.Errorf("failed to get tenant: %w", err) - } - - query := `UPDATE signatures SET prev_hash = $3 WHERE tenant_id = $1 AND id = $2` - if _, err := r.db.ExecContext(ctx, query, tenantID, id, prevHash); err != nil { + query := `UPDATE signatures SET prev_hash = $2 WHERE id = $1` + if _, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, query, id, prevHash); err != nil { return fmt.Errorf("failed to update prev_hash: %w", err) } return nil diff --git a/backend/internal/infrastructure/database/webhook_delivery_repository.go b/backend/internal/infrastructure/database/webhook_delivery_repository.go index dd8d5ff..a6f7127 100644 --- a/backend/internal/infrastructure/database/webhook_delivery_repository.go +++ b/backend/internal/infrastructure/database/webhook_delivery_repository.go @@ -9,6 +9,7 @@ import ( "time" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" ) @@ -72,7 +73,7 @@ func (r *WebhookDeliveryRepository) Enqueue(ctx context.Context, input models.We MaxRetries: maxRetries, ScheduledFor: scheduled, } - err = r.db.QueryRowContext(ctx, q, + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, q, tenantID, input.WebhookID, input.EventType, input.EventID, payloadJSON, input.Priority, maxRetries, scheduled, ).Scan(&item.ID, &item.TenantID, &item.Status, &item.RetryCount, &item.CreatedAt, &item.ProcessedAt, &item.NextRetryAt) if err != nil { @@ -102,7 +103,7 @@ func (r *WebhookDeliveryRepository) GetNextToProcess(ctx context.Context, limit FROM upd u JOIN webhooks w ON w.id = u.webhook_id ` - rows, err := r.db.QueryContext(ctx, q, time.Now(), limit) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, q, time.Now(), limit) if err != nil { return nil, fmt.Errorf("failed to get next webhook deliveries: %w", err) } @@ -144,7 +145,7 @@ func (r *WebhookDeliveryRepository) GetRetryable(ctx context.Context, limit int) FROM upd u JOIN webhooks w ON w.id = u.webhook_id ` - rows, err := r.db.QueryContext(ctx, q, time.Now(), limit) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, q, time.Now(), limit) if err != nil { return nil, fmt.Errorf("failed to get retryable webhook deliveries: %w", err) } @@ -178,7 +179,7 @@ func (r *WebhookDeliveryRepository) MarkDelivered(ctx context.Context, id int64, SET status='delivered', processed_at=now(), response_status=$1, response_headers=$2, response_body=$3 WHERE id=$4 ` - _, err := r.db.ExecContext(ctx, q, responseStatus, headersJSON, responseBody, id) + _, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, q, responseStatus, headersJSON, responseBody, id) return err } @@ -193,38 +194,35 @@ func (r *WebhookDeliveryRepository) MarkFailed(ctx context.Context, id int64, er SET status='pending', retry_count=retry_count+1, last_error=$1, scheduled_for=calculate_next_retry_time(retry_count+1) WHERE id=$2 AND retry_count < max_retries ` - res, e := r.db.ExecContext(ctx, q, errMsg, id) + res, e := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, q, errMsg, id) if e != nil { return e } if n, _ := res.RowsAffected(); n == 0 { // mark as permanently failed q := `UPDATE webhook_deliveries SET status='failed', processed_at=now(), last_error=$1 WHERE id=$2` - _, e = r.db.ExecContext(ctx, q, errMsg, id) + _, e = dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, q, errMsg, id) return e } return nil } q := `UPDATE webhook_deliveries SET status='failed', processed_at=now(), last_error=$1 WHERE id=$2` - _, e := r.db.ExecContext(ctx, q, errMsg, id) + _, e := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, q, errMsg, id) return e } +// ListByWebhook retrieves paginated webhook deliveries for a specific webhook +// RLS policy automatically filters by tenant_id func (r *WebhookDeliveryRepository) ListByWebhook(ctx context.Context, webhookID int64, limit, offset int) ([]*models.WebhookDelivery, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - q := ` SELECT id, tenant_id, webhook_id, event_type, event_id, payload, status, retry_count, max_retries, priority, created_at, scheduled_for, processed_at, next_retry_at, request_headers, response_status, response_headers, response_body, last_error FROM webhook_deliveries - WHERE webhook_id=$1 AND tenant_id=$2 + WHERE webhook_id=$1 ORDER BY id DESC - LIMIT $3 OFFSET $4 + LIMIT $2 OFFSET $3 ` - rows, err := r.db.QueryContext(ctx, q, webhookID, tenantID, limit, offset) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, q, webhookID, limit, offset) if err != nil { return nil, fmt.Errorf("failed to list deliveries: %w", err) } @@ -246,7 +244,7 @@ func (r *WebhookDeliveryRepository) ListByWebhook(ctx context.Context, webhookID func (r *WebhookDeliveryRepository) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) { q := `DELETE FROM webhook_deliveries WHERE status IN ('delivered','failed','cancelled') AND processed_at < $1` cutoff := time.Now().Add(-olderThan) - res, err := r.db.ExecContext(ctx, q, cutoff) + res, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, q, cutoff) if err != nil { return 0, fmt.Errorf("failed to cleanup old deliveries: %w", err) } diff --git a/backend/internal/infrastructure/database/webhook_repository.go b/backend/internal/infrastructure/database/webhook_repository.go index d3d58ca..a4e8836 100644 --- a/backend/internal/infrastructure/database/webhook_repository.go +++ b/backend/internal/infrastructure/database/webhook_repository.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/lib/pq" ) @@ -41,7 +42,7 @@ func (r *WebhookRepository) Create(ctx context.Context, input models.WebhookInpu ` wh := &models.Webhook{} var headersOut models.NullRawMessage - err = r.db.QueryRowContext(ctx, query, + err = dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, tenantID, input.Title, input.TargetURL, @@ -64,12 +65,9 @@ func (r *WebhookRepository) Create(ctx context.Context, input models.WebhookInpu return wh, nil } +// Update modifies an existing webhook configuration +// RLS policy automatically filters by tenant_id func (r *WebhookRepository) Update(ctx context.Context, id int64, input models.WebhookInput) (*models.Webhook, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - headersJSON := []byte("{}") if input.Headers != nil { if data, err := json.Marshal(input.Headers); err == nil { @@ -80,12 +78,12 @@ func (r *WebhookRepository) Update(ctx context.Context, id int64, input models.W query := ` UPDATE webhooks SET title=$1, target_url=$2, secret=COALESCE(NULLIF($3,''), secret), active=$4, events=$5, headers=$6, description=$7, updated_at=now() - WHERE id=$8 AND tenant_id=$9 + WHERE id=$8 RETURNING id, tenant_id, title, target_url, secret, active, events, headers, description, created_by, created_at, updated_at, last_delivered_at, failure_count ` wh := &models.Webhook{} var headersOut models.NullRawMessage - err = r.db.QueryRowContext(ctx, query, + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, input.Title, input.TargetURL, input.Secret, @@ -94,7 +92,6 @@ func (r *WebhookRepository) Update(ctx context.Context, id int64, input models.W headersJSON, input.Description, id, - tenantID, ).Scan( &wh.ID, &wh.TenantID, &wh.Title, &wh.TargetURL, &wh.Secret, &wh.Active, pq.Array(&wh.Events), &headersOut, &wh.Description, &wh.CreatedBy, &wh.CreatedAt, &wh.UpdatedAt, &wh.LastDeliveredAt, &wh.FailureCount, @@ -108,13 +105,10 @@ func (r *WebhookRepository) Update(ctx context.Context, id int64, input models.W return wh, nil } +// SetActive enables or disables a webhook +// RLS policy automatically filters by tenant_id func (r *WebhookRepository) SetActive(ctx context.Context, id int64, active bool) error { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return fmt.Errorf("failed to get tenant: %w", err) - } - - res, err := r.db.ExecContext(ctx, `UPDATE webhooks SET active=$1, updated_at=now() WHERE id=$2 AND tenant_id=$3`, active, id, tenantID) + res, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, `UPDATE webhooks SET active=$1, updated_at=now() WHERE id=$2`, active, id) if err != nil { return fmt.Errorf("failed to set active: %w", err) } @@ -125,34 +119,28 @@ func (r *WebhookRepository) SetActive(ctx context.Context, id int64, active bool return nil } +// Delete removes a webhook configuration +// RLS policy automatically filters by tenant_id func (r *WebhookRepository) Delete(ctx context.Context, id int64) error { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return fmt.Errorf("failed to get tenant: %w", err) - } - - _, err = r.db.ExecContext(ctx, `DELETE FROM webhooks WHERE id=$1 AND tenant_id=$2`, id, tenantID) + _, err := dbctx.GetQuerier(ctx, r.db).ExecContext(ctx, `DELETE FROM webhooks WHERE id=$1`, id) if err != nil { return fmt.Errorf("failed to delete webhook: %w", err) } return nil } +// GetByID retrieves a webhook by its ID +// RLS policy automatically filters by tenant_id func (r *WebhookRepository) GetByID(ctx context.Context, id int64) (*models.Webhook, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT id, tenant_id, title, target_url, secret, active, events, headers, description, created_by, created_at, updated_at, last_delivered_at, failure_count FROM webhooks - WHERE id=$1 AND tenant_id=$2 + WHERE id=$1 ` wh := &models.Webhook{} var events []string var headersJSON models.NullRawMessage - err = r.db.QueryRowContext(ctx, query, id, tenantID).Scan( + err := dbctx.GetQuerier(ctx, r.db).QueryRowContext(ctx, query, id).Scan( &wh.ID, &wh.TenantID, &wh.Title, &wh.TargetURL, &wh.Secret, &wh.Active, pq.Array(&events), &headersJSON, &wh.Description, &wh.CreatedBy, &wh.CreatedAt, &wh.UpdatedAt, &wh.LastDeliveredAt, &wh.FailureCount, ) @@ -166,20 +154,16 @@ func (r *WebhookRepository) GetByID(ctx context.Context, id int64) (*models.Webh return wh, nil } +// List retrieves paginated webhooks +// RLS policy automatically filters by tenant_id func (r *WebhookRepository) List(ctx context.Context, limit, offset int) ([]*models.Webhook, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT id, tenant_id, title, target_url, secret, active, events, headers, description, created_by, created_at, updated_at, last_delivered_at, failure_count FROM webhooks - WHERE tenant_id=$1 ORDER BY id DESC - LIMIT $2 OFFSET $3 + LIMIT $1 OFFSET $2 ` - rows, err := r.db.QueryContext(ctx, query, tenantID, limit, offset) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, limit, offset) if err != nil { return nil, fmt.Errorf("failed to list webhooks: %w", err) } @@ -206,18 +190,14 @@ func (r *WebhookRepository) List(ctx context.Context, limit, offset int) ([]*mod } // ListActiveByEvent returns active webhooks subscribed to a given event type +// RLS policy automatically filters by tenant_id func (r *WebhookRepository) ListActiveByEvent(ctx context.Context, event string) ([]*models.Webhook, error) { - tenantID, err := r.tenants.CurrentTenant(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tenant: %w", err) - } - query := ` SELECT id, tenant_id, title, target_url, secret, active, events, headers, description, created_by, created_at, updated_at, last_delivered_at, failure_count FROM webhooks - WHERE tenant_id=$1 AND active = TRUE AND $2 = ANY(events) + WHERE active = TRUE AND $1 = ANY(events) ` - rows, err := r.db.QueryContext(ctx, query, tenantID, event) + rows, err := dbctx.GetQuerier(ctx, r.db).QueryContext(ctx, query, event) if err != nil { return nil, fmt.Errorf("failed to list active webhooks: %w", err) } diff --git a/backend/internal/infrastructure/dbctx/context.go b/backend/internal/infrastructure/dbctx/context.go new file mode 100644 index 0000000..8c920b8 --- /dev/null +++ b/backend/internal/infrastructure/dbctx/context.go @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +// Package dbctx provides context helpers for database transactions. +// It enables RLS (Row Level Security) by storing transactions in context.Context, +// allowing repositories to transparently use either a transaction or raw DB connection. +package dbctx + +import ( + "context" + "database/sql" +) + +// Querier is a common interface for *sql.DB and *sql.Tx. +// It allows repositories to work transparently with either a raw DB connection +// or a transaction, enabling RLS isolation via transactional set_config. +type Querier interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +// Compile-time interface checks +var ( + _ Querier = (*sql.DB)(nil) + _ Querier = (*sql.Tx)(nil) +) + +// txKey is the context key for storing the current transaction. +type txKey struct{} + +// WithTx returns a new context containing the given transaction. +// This is used by the RLS middleware to propagate the transaction +// through the request lifecycle. +func WithTx(ctx context.Context, tx *sql.Tx) context.Context { + return context.WithValue(ctx, txKey{}, tx) +} + +// TxFromContext extracts the transaction from the context if present. +// Returns nil if no transaction is stored in the context. +func TxFromContext(ctx context.Context) *sql.Tx { + if tx, ok := ctx.Value(txKey{}).(*sql.Tx); ok { + return tx + } + return nil +} + +// GetQuerier returns the Querier to use for database operations. +// If a transaction is present in the context (set by RLS middleware), +// it returns the transaction. Otherwise, it returns the raw DB connection. +// +// This allows repositories to transparently benefit from RLS isolation +// when called within a transactional context, while still working +// correctly for operations that bypass RLS (e.g., migrations, admin tasks). +func GetQuerier(ctx context.Context, db *sql.DB) Querier { + if tx := TxFromContext(ctx); tx != nil { + return tx + } + return db +} diff --git a/backend/internal/infrastructure/email/worker.go b/backend/internal/infrastructure/email/worker.go index e584ba4..5087606 100644 --- a/backend/internal/infrastructure/email/worker.go +++ b/backend/internal/infrastructure/email/worker.go @@ -3,6 +3,7 @@ package email import ( "context" + "database/sql" "encoding/json" "fmt" "strings" @@ -10,6 +11,7 @@ import ( "time" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -48,6 +50,10 @@ type Worker struct { renderer *Renderer publisher EventPublisher + // RLS support + db *sql.DB + tenants tenant.Provider + // Worker configuration batchSize int pollInterval time.Duration @@ -85,7 +91,7 @@ func DefaultWorkerConfig() WorkerConfig { } // NewWorker creates a new email worker -func NewWorker(queueRepo QueueRepository, sender Sender, renderer *Renderer, config WorkerConfig) *Worker { +func NewWorker(queueRepo QueueRepository, sender Sender, renderer *Renderer, config WorkerConfig, db *sql.DB, tenants tenant.Provider) *Worker { // Apply defaults if config.BatchSize <= 0 { config.BatchSize = 10 @@ -109,6 +115,8 @@ func NewWorker(queueRepo QueueRepository, sender Sender, renderer *Renderer, con queueRepo: queueRepo, sender: sender, renderer: renderer, + db: db, + tenants: tenants, batchSize: config.BatchSize, pollInterval: config.PollInterval, cleanupInterval: config.CleanupInterval, @@ -218,28 +226,40 @@ func (w *Worker) processBatch() { ctx, cancel := context.WithTimeout(w.ctx, 5*time.Minute) defer cancel() - // Get next batch of emails - emails, err := w.queueRepo.GetNextToProcess(ctx, w.batchSize) + // Get tenant ID for RLS context + tenantID, err := w.tenants.CurrentTenant(ctx) + if err != nil { + logger.Logger.Error("Failed to get tenant for email worker", "error", err.Error()) + return + } + + // Get next batch of emails within tenant context + var emails []*models.EmailQueueItem + err = tenant.WithTenantContext(ctx, w.db, tenantID, func(txCtx context.Context) error { + var fetchErr error + emails, fetchErr = w.queueRepo.GetNextToProcess(txCtx, w.batchSize) + if fetchErr != nil { + return fetchErr + } + if len(emails) == 0 { + // Also check for retryable emails + emails, fetchErr = w.queueRepo.GetRetryableEmails(txCtx, w.batchSize) + } + return fetchErr + }) if err != nil { logger.Logger.Error("Failed to get emails to process", "error", err.Error()) return } if len(emails) == 0 { - // Also check for retryable emails - emails, err = w.queueRepo.GetRetryableEmails(ctx, w.batchSize) - if err != nil { - logger.Logger.Error("Failed to get retryable emails", "error", err.Error()) - return - } - if len(emails) == 0 { - return // Nothing to process - } + return // Nothing to process } logger.Logger.Debug("Processing email batch", "count", len(emails)) // Process emails concurrently with limited concurrency + // Each goroutine gets its own tenant context (transaction) sem := make(chan struct{}, w.maxConcurrent) var wg sync.WaitGroup @@ -251,7 +271,16 @@ func (w *Worker) processBatch() { defer wg.Done() defer func() { <-sem }() // Release semaphore - w.processEmail(ctx, item) + // Each email processing gets its own RLS transaction + err := tenant.WithTenantContext(ctx, w.db, tenantID, func(txCtx context.Context) error { + w.processEmail(txCtx, item) + return nil + }) + if err != nil { + logger.Logger.Error("Failed to process email with tenant context", + "id", item.ID, + "error", err.Error()) + } }(email) } @@ -395,7 +424,19 @@ func (w *Worker) performCleanup() { ctx, cancel := context.WithTimeout(w.ctx, 5*time.Minute) defer cancel() - deleted, err := w.queueRepo.CleanupOldEmails(ctx, w.cleanupAge) + // Get tenant ID for RLS context + tenantID, err := w.tenants.CurrentTenant(ctx) + if err != nil { + logger.Logger.Error("Failed to get tenant for email cleanup", "error", err.Error()) + return + } + + var deleted int64 + err = tenant.WithTenantContext(ctx, w.db, tenantID, func(txCtx context.Context) error { + var cleanupErr error + deleted, cleanupErr = w.queueRepo.CleanupOldEmails(txCtx, w.cleanupAge) + return cleanupErr + }) if err != nil { logger.Logger.Error("Failed to cleanup old emails", "error", err.Error()) return diff --git a/backend/internal/infrastructure/tenant/context.go b/backend/internal/infrastructure/tenant/context.go new file mode 100644 index 0000000..b39ce62 --- /dev/null +++ b/backend/internal/infrastructure/tenant/context.go @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later +package tenant + +import ( + "context" + "database/sql" + "fmt" + + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" + "github.com/google/uuid" +) + +// WithTenantContext executes the given function within a transactional context +// configured with RLS tenant isolation. It: +// 1. Begins a new transaction +// 2. Sets the app.tenant_id session variable for RLS policies +// 3. Stores the transaction in the context for use by repositories +// 4. Commits on success, rolls back on error or panic +// +// This is the primary mechanism for ensuring RLS isolation in workers, +// background jobs, and tests. HTTP handlers should use the RLS middleware instead. +// +// Example usage: +// +// err := tenant.WithTenantContext(ctx, db, tenantID, func(ctx context.Context) error { +// // All repository calls here will use RLS isolation +// doc, err := docRepo.GetByDocID(ctx, docID) +// return err +// }) +func WithTenantContext(ctx context.Context, db *sql.DB, tenantID uuid.UUID, fn func(ctx context.Context) error) (err error) { + // Begin transaction + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + // Ensure cleanup on panic or error + defer func() { + if p := recover(); p != nil { + _ = tx.Rollback() + panic(p) // Re-throw panic after rollback + } else if err != nil { + _ = tx.Rollback() + } + }() + + // Set tenant_id for RLS policies (LOCAL = transaction scope only) + _, err = tx.ExecContext(ctx, "SELECT set_config('app.tenant_id', $1, true)", tenantID.String()) + if err != nil { + return fmt.Errorf("failed to set tenant context: %w", err) + } + + // Store transaction in context for GetQuerier + txCtx := dbctx.WithTx(ctx, tx) + + // Execute the function + if err = fn(txCtx); err != nil { + return err + } + + // Commit transaction + if err = tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} + +// WithTenantContextFromProvider is like WithTenantContext but obtains the tenant ID +// from a Provider. This is useful when the tenant ID is not known upfront. +func WithTenantContextFromProvider(ctx context.Context, db *sql.DB, provider Provider, fn func(ctx context.Context) error) error { + tenantID, err := provider.CurrentTenant(ctx) + if err != nil { + return fmt.Errorf("failed to get tenant ID: %w", err) + } + return WithTenantContext(ctx, db, tenantID, fn) +} diff --git a/backend/internal/infrastructure/webhook/worker.go b/backend/internal/infrastructure/webhook/worker.go index db5ac64..89f6392 100644 --- a/backend/internal/infrastructure/webhook/worker.go +++ b/backend/internal/infrastructure/webhook/worker.go @@ -5,17 +5,19 @@ import ( "context" "crypto/hmac" "crypto/sha256" + "database/sql" "encoding/hex" + "fmt" "io" "net/http" + "strconv" "strings" "sync" "time" - "fmt" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/database" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" - "strconv" ) // DeliveryRepository is the minimal interface used by the worker @@ -52,6 +54,10 @@ type Worker struct { http HTTPDoer cfg WorkerConfig + // RLS support + db *sql.DB + tenants tenant.Provider + ctx context.Context cancel context.CancelFunc wg sync.WaitGroup @@ -60,7 +66,7 @@ type Worker struct { started bool } -func NewWorker(repo DeliveryRepository, httpClient HTTPDoer, cfg WorkerConfig) *Worker { +func NewWorker(repo DeliveryRepository, httpClient HTTPDoer, cfg WorkerConfig, db *sql.DB, tenants tenant.Provider) *Worker { if cfg.BatchSize <= 0 { cfg.BatchSize = 10 } @@ -80,7 +86,7 @@ func NewWorker(repo DeliveryRepository, httpClient HTTPDoer, cfg WorkerConfig) * cfg.RequestTimeout = 10 * time.Second } ctx, cancel := context.WithCancel(context.Background()) - return &Worker{repo: repo, http: httpClient, cfg: cfg, ctx: ctx, cancel: cancel, stopChan: make(chan struct{})} + return &Worker{repo: repo, http: httpClient, cfg: cfg, db: db, tenants: tenants, ctx: ctx, cancel: cancel, stopChan: make(chan struct{})} } func (w *Worker) Start() error { @@ -148,33 +154,85 @@ func (w *Worker) cleanupLoop() { case <-w.stopChan: return case <-t.C: - if n, err := w.repo.CleanupOld(w.ctx, w.cfg.CleanupAge); err != nil { - logger.Logger.Error("Failed to cleanup webhook deliveries", "error", err.Error()) - } else if n > 0 { - logger.Logger.Info("Cleaned webhook deliveries", "count", n) - } + w.performCleanup() } } } +func (w *Worker) performCleanup() { + ctx, cancel := context.WithTimeout(w.ctx, 5*time.Minute) + defer cancel() + + var deleted int64 + var err error + + // Use RLS context if db and tenants are available + if w.db != nil && w.tenants != nil { + tenantID, tenantErr := w.tenants.CurrentTenant(ctx) + if tenantErr != nil { + logger.Logger.Error("Failed to get tenant for webhook cleanup", "error", tenantErr.Error()) + return + } + + err = tenant.WithTenantContext(ctx, w.db, tenantID, func(txCtx context.Context) error { + var cleanupErr error + deleted, cleanupErr = w.repo.CleanupOld(txCtx, w.cfg.CleanupAge) + return cleanupErr + }) + } else { + // No RLS - direct repository access (for tests) + deleted, err = w.repo.CleanupOld(ctx, w.cfg.CleanupAge) + } + + if err != nil { + logger.Logger.Error("Failed to cleanup webhook deliveries", "error", err.Error()) + } else if deleted > 0 { + logger.Logger.Info("Cleaned webhook deliveries", "count", deleted) + } +} + func (w *Worker) processBatch() { ctx, cancel := context.WithTimeout(w.ctx, 5*time.Minute) defer cancel() - items, err := w.repo.GetNextToProcess(ctx, w.cfg.BatchSize) + + var items []*database.WebhookDeliveryItem + var err error + + // Use RLS context if db and tenants are available + if w.db != nil && w.tenants != nil { + tenantID, tenantErr := w.tenants.CurrentTenant(ctx) + if tenantErr != nil { + logger.Logger.Error("Failed to get tenant for webhook worker", "error", tenantErr.Error()) + return + } + + err = tenant.WithTenantContext(ctx, w.db, tenantID, func(txCtx context.Context) error { + var fetchErr error + items, fetchErr = w.repo.GetNextToProcess(txCtx, w.cfg.BatchSize) + if fetchErr != nil { + return fetchErr + } + if len(items) == 0 { + items, fetchErr = w.repo.GetRetryable(txCtx, w.cfg.BatchSize) + } + return fetchErr + }) + } else { + // No RLS - direct repository access (for tests) + items, err = w.repo.GetNextToProcess(ctx, w.cfg.BatchSize) + if err == nil && len(items) == 0 { + items, err = w.repo.GetRetryable(ctx, w.cfg.BatchSize) + } + } + if err != nil { logger.Logger.Error("Failed to get webhook deliveries", "error", err.Error()) return } if len(items) == 0 { - items, err = w.repo.GetRetryable(ctx, w.cfg.BatchSize) - if err != nil { - logger.Logger.Error("Failed to get retryable webhook deliveries", "error", err.Error()) - return - } - if len(items) == 0 { - return - } + return } + sem := make(chan struct{}, w.cfg.MaxConcurrent) var wg sync.WaitGroup for _, it := range items { @@ -183,7 +241,22 @@ func (w *Worker) processBatch() { go func(item *database.WebhookDeliveryItem) { defer wg.Done() defer func() { <-sem }() - w.processOne(ctx, item) + + // Use RLS context if available + if w.db != nil && w.tenants != nil { + tenantID, _ := w.tenants.CurrentTenant(ctx) + err := tenant.WithTenantContext(ctx, w.db, tenantID, func(txCtx context.Context) error { + w.processOne(txCtx, item) + return nil + }) + if err != nil { + logger.Logger.Error("Failed to process webhook with tenant context", + "id", item.ID, + "error", err.Error()) + } + } else { + w.processOne(ctx, item) + } }(it) } wg.Wait() diff --git a/backend/internal/infrastructure/webhook/worker_test.go b/backend/internal/infrastructure/webhook/worker_test.go index d7ff02e..9ce3123 100644 --- a/backend/internal/infrastructure/webhook/worker_test.go +++ b/backend/internal/infrastructure/webhook/worker_test.go @@ -3,15 +3,25 @@ package webhook import ( "context" + "io" "net/http" "strings" "testing" "time" "github.com/btouchard/ackify-ce/backend/internal/infrastructure/database" - "io" + "github.com/google/uuid" ) +// mockTenantProvider for testing +type mockTenantProviderWebhook struct { + tenantID uuid.UUID +} + +func (m *mockTenantProviderWebhook) CurrentTenant(ctx context.Context) (uuid.UUID, error) { + return m.tenantID, nil +} + func TestComputeSignature(t *testing.T) { secret := "supersecret" ts := int64(1730000000) @@ -63,7 +73,8 @@ func (f *fakeDelRepo) CleanupOld(ctx context.Context, olderThan time.Duration) ( func TestWorker_ProcessBatch_Success(t *testing.T) { repo := &fakeDelRepo{} doer := &fakeDoer{resp: &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok")), Header: http.Header{}}} - w := NewWorker(repo, doer, DefaultWorkerConfig()) + tenants := &mockTenantProviderWebhook{tenantID: uuid.New()} + w := NewWorker(repo, doer, DefaultWorkerConfig(), nil, tenants) w.processBatch() if repo.delivered != 1 { t.Fatalf("expected delivered=1, got %d", repo.delivered) diff --git a/backend/internal/infrastructure/workers/magic_link_cleanup.go b/backend/internal/infrastructure/workers/magic_link_cleanup.go index 26c845d..a32296a 100644 --- a/backend/internal/infrastructure/workers/magic_link_cleanup.go +++ b/backend/internal/infrastructure/workers/magic_link_cleanup.go @@ -3,9 +3,11 @@ package workers import ( "context" + "database/sql" "time" "github.com/btouchard/ackify-ce/backend/internal/application/services" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -14,9 +16,13 @@ type MagicLinkCleanupWorker struct { service *services.MagicLinkService interval time.Duration stopChan chan struct{} + + // RLS support + db *sql.DB + tenants tenant.Provider } -func NewMagicLinkCleanupWorker(service *services.MagicLinkService, interval time.Duration) *MagicLinkCleanupWorker { +func NewMagicLinkCleanupWorker(service *services.MagicLinkService, interval time.Duration, db *sql.DB, tenants tenant.Provider) *MagicLinkCleanupWorker { if interval == 0 { interval = 1 * time.Hour // Défaut: toutes les heures } @@ -25,6 +31,8 @@ func NewMagicLinkCleanupWorker(service *services.MagicLinkService, interval time service: service, interval: interval, stopChan: make(chan struct{}), + db: db, + tenants: tenants, } } @@ -53,7 +61,19 @@ func (w *MagicLinkCleanupWorker) Stop() { } func (w *MagicLinkCleanupWorker) cleanup(ctx context.Context) { - deleted, err := w.service.CleanupExpiredTokens(ctx) + // Get tenant ID for RLS context + tenantID, err := w.tenants.CurrentTenant(ctx) + if err != nil { + logger.Logger.Error("Failed to get tenant for magic link cleanup", "error", err) + return + } + + var deleted int64 + err = tenant.WithTenantContext(ctx, w.db, tenantID, func(txCtx context.Context) error { + var cleanupErr error + deleted, cleanupErr = w.service.CleanupExpiredTokens(txCtx) + return cleanupErr + }) if err != nil { logger.Logger.Error("Failed to cleanup expired magic link tokens", "error", err) return diff --git a/backend/internal/presentation/api/router.go b/backend/internal/presentation/api/router.go index 8fc22a5..3d02ca5 100644 --- a/backend/internal/presentation/api/router.go +++ b/backend/internal/presentation/api/router.go @@ -3,6 +3,7 @@ package api import ( "context" + "database/sql" "encoding/json" "net/http" "os" @@ -14,6 +15,7 @@ import ( "github.com/btouchard/ackify-ce/backend/internal/application/services" "github.com/btouchard/ackify-ce/backend/internal/domain/models" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" apiAdmin "github.com/btouchard/ackify-ce/backend/internal/presentation/api/admin" apiAuth "github.com/btouchard/ackify-ce/backend/internal/presentation/api/auth" "github.com/btouchard/ackify-ce/backend/internal/presentation/api/documents" @@ -93,6 +95,10 @@ type webhookService interface { // RouterConfig holds configuration for the API router type RouterConfig struct { + // Database for RLS middleware + DB *sql.DB // Required for RLS transaction management + TenantProvider tenant.Provider // Required for tenant context + // Capability providers AuthProvider providers.AuthProvider // Required for session management OAuthProvider providers.OAuthAuthProvider // Optional, for OAuth authentication @@ -153,6 +159,13 @@ func NewRouter(cfg RouterConfig) *chi.Mux { r.Use(apiMiddleware.CORS) r.Use(generalRateLimit.Middleware) + // RLS middleware for database tenant isolation (always active) + // Must be after Recoverer to handle panics, before handlers that use DB + if cfg.DB != nil && cfg.TenantProvider != nil { + rlsMiddleware := shared.NewRLSMiddleware(cfg.DB, cfg.TenantProvider) + r.Use(rlsMiddleware.Handler) + } + // Initialize handlers healthHandler := health.NewHandler() authHandler := apiAuth.NewHandler(cfg.AuthProvider, cfg.OAuthProvider, cfg.MagicLinkService, apiMiddleware, cfg.BaseURL, cfg.OAuthEnabled, cfg.MagicLinkEnabled) diff --git a/backend/internal/presentation/api/shared/rls_middleware.go b/backend/internal/presentation/api/shared/rls_middleware.go new file mode 100644 index 0000000..5f36746 --- /dev/null +++ b/backend/internal/presentation/api/shared/rls_middleware.go @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later +package shared + +import ( + "database/sql" + "net/http" + + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/dbctx" + "github.com/btouchard/ackify-ce/backend/internal/infrastructure/tenant" + "github.com/btouchard/ackify-ce/backend/pkg/logger" +) + +// RLSMiddleware provides Row Level Security context for database queries. +// It wraps each request in a transaction with app.tenant_id set via set_config. +// RLS is always active - this is a security feature that cannot be disabled. +type RLSMiddleware struct { + db *sql.DB + tenants tenant.Provider +} + +// NewRLSMiddleware creates a new RLS middleware. +func NewRLSMiddleware(db *sql.DB, tenants tenant.Provider) *RLSMiddleware { + return &RLSMiddleware{ + db: db, + tenants: tenants, + } +} + +// Handler wraps HTTP requests with RLS transaction context. +// For each request: +// 1. Gets the current tenant ID from the provider +// 2. Starts a database transaction +// 3. Sets app.tenant_id in the session via set_config +// 4. Stores the transaction in the request context +// 5. Calls the next handler +// 6. Commits on success (2xx-3xx status) or rolls back on error/panic +func (m *RLSMiddleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + requestID := getRequestID(ctx) + + // Get current tenant from provider + tenantID, err := m.tenants.CurrentTenant(ctx) + if err != nil { + logger.Logger.Error("rls_middleware: failed to get tenant", + "request_id", requestID, + "error", err.Error()) + WriteError(w, http.StatusInternalServerError, "RLS_ERROR", "Failed to establish tenant context", nil) + return + } + + // Start transaction + tx, err := m.db.BeginTx(ctx, nil) + if err != nil { + logger.Logger.Error("rls_middleware: failed to begin transaction", + "request_id", requestID, + "error", err.Error()) + WriteError(w, http.StatusInternalServerError, "RLS_ERROR", "Failed to start database transaction", nil) + return + } + + // Set tenant context in session + // The 'true' makes it local to this transaction only + _, err = tx.ExecContext(ctx, "SELECT set_config('app.tenant_id', $1, true)", tenantID.String()) + if err != nil { + tx.Rollback() + logger.Logger.Error("rls_middleware: failed to set tenant context", + "request_id", requestID, + "tenant_id", tenantID.String(), + "error", err.Error()) + WriteError(w, http.StatusInternalServerError, "RLS_ERROR", "Failed to set tenant context", nil) + return + } + + logger.Logger.Debug("rls_middleware: tenant context set", + "request_id", requestID, + "tenant_id", tenantID.String()) + + // Store transaction in context for repositories to use + ctxWithTx := dbctx.WithTx(ctx, tx) + + // Wrap response writer to capture status code + wrapped := &statusCapturingResponseWriter{ResponseWriter: w, status: http.StatusOK} + + // Handle panics - rollback on panic + defer func() { + if rec := recover(); rec != nil { + tx.Rollback() + logger.Logger.Error("rls_middleware: panic recovered, transaction rolled back", + "request_id", requestID, + "panic", rec) + panic(rec) // re-panic after rollback to let recovery middleware handle it + } + }() + + // Call next handler with transaction context + next.ServeHTTP(wrapped, r.WithContext(ctxWithTx)) + + // Commit or rollback based on response status + if wrapped.status >= 200 && wrapped.status < 400 { + if err := tx.Commit(); err != nil { + logger.Logger.Error("rls_middleware: failed to commit transaction", + "request_id", requestID, + "status", wrapped.status, + "error", err.Error()) + // Transaction already used, can't send error response + } else { + logger.Logger.Debug("rls_middleware: transaction committed", + "request_id", requestID, + "status", wrapped.status) + } + } else { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + logger.Logger.Error("rls_middleware: failed to rollback transaction", + "request_id", requestID, + "status", wrapped.status, + "error", err.Error()) + } else { + logger.Logger.Debug("rls_middleware: transaction rolled back", + "request_id", requestID, + "status", wrapped.status) + } + } + }) +} + +// statusCapturingResponseWriter captures the HTTP status code for decision making +type statusCapturingResponseWriter struct { + http.ResponseWriter + status int + wroteHeader bool +} + +func (w *statusCapturingResponseWriter) WriteHeader(code int) { + if !w.wroteHeader { + w.status = code + w.wroteHeader = true + } + w.ResponseWriter.WriteHeader(code) +} + +func (w *statusCapturingResponseWriter) Write(b []byte) (int, error) { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.ResponseWriter.Write(b) +} diff --git a/backend/internal/presentation/api/shared/rls_middleware_test.go b/backend/internal/presentation/api/shared/rls_middleware_test.go new file mode 100644 index 0000000..7253f09 --- /dev/null +++ b/backend/internal/presentation/api/shared/rls_middleware_test.go @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later +package shared + +import ( + "context" + "database/sql" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" +) + +// mockTenantProvider is a test implementation of tenant.Provider +type mockTenantProvider struct { + tenantID uuid.UUID + err error +} + +func (m *mockTenantProvider) CurrentTenant(ctx context.Context) (uuid.UUID, error) { + if m.err != nil { + return uuid.Nil, m.err + } + return m.tenantID, nil +} + +func TestNewRLSMiddleware(t *testing.T) { + tenantID := uuid.New() + provider := &mockTenantProvider{tenantID: tenantID} + + // Test with nil db (should still create middleware) + m := NewRLSMiddleware(nil, provider) + if m == nil { + t.Error("NewRLSMiddleware returned nil") + } + if m.db != nil { + t.Error("db should be nil") + } + if m.tenants != provider { + t.Error("tenants should be the provided provider") + } +} + +func TestStatusCapturingResponseWriter_WriteHeader(t *testing.T) { + rr := httptest.NewRecorder() + wrapped := &statusCapturingResponseWriter{ResponseWriter: rr, status: http.StatusOK} + + // First WriteHeader should set status + wrapped.WriteHeader(http.StatusNotFound) + if wrapped.status != http.StatusNotFound { + t.Errorf("Expected status %d, got %d", http.StatusNotFound, wrapped.status) + } + if !wrapped.wroteHeader { + t.Error("wroteHeader should be true after WriteHeader") + } + + // Second WriteHeader should be ignored + wrapped.WriteHeader(http.StatusOK) + if wrapped.status != http.StatusNotFound { + t.Errorf("Expected status %d after second WriteHeader, got %d", http.StatusNotFound, wrapped.status) + } +} + +func TestStatusCapturingResponseWriter_Write(t *testing.T) { + rr := httptest.NewRecorder() + wrapped := &statusCapturingResponseWriter{ResponseWriter: rr, status: http.StatusOK} + + // Write without explicit WriteHeader should trigger implicit 200 + _, err := wrapped.Write([]byte("test")) + if err != nil { + t.Errorf("Write returned error: %v", err) + } + if wrapped.status != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, wrapped.status) + } + if !wrapped.wroteHeader { + t.Error("wroteHeader should be true after Write") + } +} + +func TestRLSMiddleware_TenantError(t *testing.T) { + provider := &mockTenantProvider{err: sql.ErrNoRows} + m := NewRLSMiddleware(nil, provider) + + called := false + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rr := httptest.NewRecorder() + + m.Handler(handler).ServeHTTP(rr, req) + + if called { + t.Error("Handler should not be called when tenant lookup fails") + } + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rr.Code) + } +} diff --git a/backend/migrations/0016_add_rls_policies.down.sql b/backend/migrations/0016_add_rls_policies.down.sql new file mode 100644 index 0000000..bc017a9 --- /dev/null +++ b/backend/migrations/0016_add_rls_policies.down.sql @@ -0,0 +1,76 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +-- ============================================================================ +-- Migration Rollback: Remove Row Level Security (RLS) Policies +-- ============================================================================ +-- This rollback disables RLS and removes all tenant isolation policies. +-- WARNING: After this rollback, tenant isolation relies solely on application +-- code (WHERE tenant_id = ...). Use with caution in production. +-- ============================================================================ + +-- Step 1: Drop policies and disable RLS on all tables + +-- ----- DOCUMENTS ----- +DROP POLICY IF EXISTS tenant_isolation_documents ON documents; +ALTER TABLE documents DISABLE ROW LEVEL SECURITY; + +-- ----- SIGNATURES ----- +DROP POLICY IF EXISTS tenant_isolation_signatures ON signatures; +ALTER TABLE signatures DISABLE ROW LEVEL SECURITY; + +-- ----- EXPECTED_SIGNERS ----- +DROP POLICY IF EXISTS tenant_isolation_expected_signers ON expected_signers; +ALTER TABLE expected_signers DISABLE ROW LEVEL SECURITY; + +-- ----- WEBHOOKS ----- +DROP POLICY IF EXISTS tenant_isolation_webhooks ON webhooks; +ALTER TABLE webhooks DISABLE ROW LEVEL SECURITY; + +-- ----- REMINDER_LOGS ----- +DROP POLICY IF EXISTS tenant_isolation_reminder_logs ON reminder_logs; +ALTER TABLE reminder_logs DISABLE ROW LEVEL SECURITY; + +-- ----- EMAIL_QUEUE ----- +DROP POLICY IF EXISTS tenant_isolation_email_queue ON email_queue; +ALTER TABLE email_queue DISABLE ROW LEVEL SECURITY; + +-- ----- CHECKSUM_VERIFICATIONS ----- +DROP POLICY IF EXISTS tenant_isolation_checksum_verifications ON checksum_verifications; +ALTER TABLE checksum_verifications DISABLE ROW LEVEL SECURITY; + +-- ----- WEBHOOK_DELIVERIES ----- +DROP POLICY IF EXISTS tenant_isolation_webhook_deliveries ON webhook_deliveries; +ALTER TABLE webhook_deliveries DISABLE ROW LEVEL SECURITY; + +-- ----- OAUTH_SESSIONS ----- +DROP POLICY IF EXISTS tenant_isolation_oauth_sessions ON oauth_sessions; +ALTER TABLE oauth_sessions DISABLE ROW LEVEL SECURITY; + +-- ----- MAGIC_LINK_TOKENS ----- +DROP POLICY IF EXISTS tenant_isolation_magic_link_tokens ON magic_link_tokens; +ALTER TABLE magic_link_tokens DISABLE ROW LEVEL SECURITY; + +-- ----- MAGIC_LINK_AUTH_ATTEMPTS ----- +DROP POLICY IF EXISTS tenant_isolation_magic_link_auth_attempts ON magic_link_auth_attempts; +ALTER TABLE magic_link_auth_attempts DISABLE ROW LEVEL SECURITY; + +-- Step 2: Revoke privileges from ackify_app role +-- Note: We don't DROP the role as it might be in use by other connections +REVOKE SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public FROM ackify_app; +REVOKE USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public FROM ackify_app; +REVOKE USAGE ON SCHEMA public FROM ackify_app; +-- REVOKE CONNECT is not done to avoid breaking active connections + +-- Step 3: Remove default privileges +ALTER DEFAULT PRIVILEGES IN SCHEMA public + REVOKE SELECT, INSERT, UPDATE, DELETE ON TABLES FROM ackify_app; + +ALTER DEFAULT PRIVILEGES IN SCHEMA public + REVOKE USAGE, SELECT ON SEQUENCES FROM ackify_app; + +-- Step 4: Drop the helper function +DROP FUNCTION IF EXISTS current_tenant_id(); + +-- Note: The ackify_app role is NOT dropped to avoid breaking existing connections. +-- To fully remove it, run: DROP ROLE IF EXISTS ackify_app; +-- after ensuring no active connections use this role. diff --git a/backend/migrations/0016_add_rls_policies.up.sql b/backend/migrations/0016_add_rls_policies.up.sql new file mode 100644 index 0000000..9472b35 --- /dev/null +++ b/backend/migrations/0016_add_rls_policies.up.sql @@ -0,0 +1,196 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +-- ============================================================================ +-- Migration: Add Row Level Security (RLS) Policies +-- ============================================================================ +-- This migration enables PostgreSQL Row Level Security for tenant isolation. +-- It ensures that all queries are automatically filtered by tenant_id, +-- eliminating the risk of data leakage if application code forgets the filter. +-- +-- Prerequisites: +-- - Migration 0015 must have run (tenant_id columns exist) +-- - A non-superuser role 'ackify_app' should be used for runtime queries +-- +-- How it works: +-- 1. current_tenant_id() reads 'app.tenant_id' from session config +-- 2. RLS policies filter rows where tenant_id = current_tenant_id() +-- 3. FORCE ROW LEVEL SECURITY ensures policies apply even to table owners +-- 4. Application sets app.tenant_id via: SELECT set_config('app.tenant_id', $1, true) +-- ============================================================================ + +-- Create helper function to get current tenant from session +-- The function returns NULL if app.tenant_id is not set, which means +-- RLS policies will filter out ALL rows (secure by default). +CREATE OR REPLACE FUNCTION current_tenant_id() RETURNS UUID AS $$ +DECLARE + tenant_id_str TEXT; +BEGIN + tenant_id_str := current_setting('app.tenant_id', true); + IF tenant_id_str IS NULL OR tenant_id_str = '' THEN + RETURN NULL; + END IF; + RETURN tenant_id_str::UUID; +EXCEPTION WHEN OTHERS THEN + -- Invalid UUID format - return NULL for safety + RAISE WARNING 'current_tenant_id(): Invalid UUID format: %', tenant_id_str; + RETURN NULL; +END; +$$ LANGUAGE plpgsql STABLE; + +COMMENT ON FUNCTION current_tenant_id() IS 'Returns the current tenant UUID from session config (app.tenant_id). Returns NULL if not set.'; + +-- IMPORTANT: The ackify_app role is created by the migrate tool before running migrations. +-- Set ACKIFY_APP_PASSWORD environment variable to enable RLS support. +-- The migrate tool will create the role with the specified password. + +-- ============================================================================ +-- Enable RLS and create policies for each tenant-aware table +-- ============================================================================ + +-- ----- DOCUMENTS ----- +ALTER TABLE documents ENABLE ROW LEVEL SECURITY; +ALTER TABLE documents FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_documents ON documents; +CREATE POLICY tenant_isolation_documents ON documents + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON documents TO ackify_app; + +-- ----- SIGNATURES ----- +ALTER TABLE signatures ENABLE ROW LEVEL SECURITY; +ALTER TABLE signatures FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_signatures ON signatures; +CREATE POLICY tenant_isolation_signatures ON signatures + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON signatures TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE signatures_id_seq TO ackify_app; + +-- ----- EXPECTED_SIGNERS ----- +ALTER TABLE expected_signers ENABLE ROW LEVEL SECURITY; +ALTER TABLE expected_signers FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_expected_signers ON expected_signers; +CREATE POLICY tenant_isolation_expected_signers ON expected_signers + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON expected_signers TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE expected_signers_id_seq TO ackify_app; + +-- ----- WEBHOOKS ----- +ALTER TABLE webhooks ENABLE ROW LEVEL SECURITY; +ALTER TABLE webhooks FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_webhooks ON webhooks; +CREATE POLICY tenant_isolation_webhooks ON webhooks + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON webhooks TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE webhooks_id_seq TO ackify_app; + +-- ----- REMINDER_LOGS ----- +ALTER TABLE reminder_logs ENABLE ROW LEVEL SECURITY; +ALTER TABLE reminder_logs FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_reminder_logs ON reminder_logs; +CREATE POLICY tenant_isolation_reminder_logs ON reminder_logs + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON reminder_logs TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE reminder_logs_id_seq TO ackify_app; + +-- ----- EMAIL_QUEUE ----- +ALTER TABLE email_queue ENABLE ROW LEVEL SECURITY; +ALTER TABLE email_queue FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_email_queue ON email_queue; +CREATE POLICY tenant_isolation_email_queue ON email_queue + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON email_queue TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE email_queue_id_seq TO ackify_app; + +-- ----- CHECKSUM_VERIFICATIONS ----- +ALTER TABLE checksum_verifications ENABLE ROW LEVEL SECURITY; +ALTER TABLE checksum_verifications FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_checksum_verifications ON checksum_verifications; +CREATE POLICY tenant_isolation_checksum_verifications ON checksum_verifications + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON checksum_verifications TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE checksum_verifications_id_seq TO ackify_app; + +-- ----- WEBHOOK_DELIVERIES ----- +ALTER TABLE webhook_deliveries ENABLE ROW LEVEL SECURITY; +ALTER TABLE webhook_deliveries FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_webhook_deliveries ON webhook_deliveries; +CREATE POLICY tenant_isolation_webhook_deliveries ON webhook_deliveries + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON webhook_deliveries TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE webhook_deliveries_id_seq TO ackify_app; + +-- ----- OAUTH_SESSIONS ----- +ALTER TABLE oauth_sessions ENABLE ROW LEVEL SECURITY; +ALTER TABLE oauth_sessions FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_oauth_sessions ON oauth_sessions; +CREATE POLICY tenant_isolation_oauth_sessions ON oauth_sessions + USING (tenant_id = current_tenant_id()) + WITH CHECK (tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON oauth_sessions TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE oauth_sessions_id_seq TO ackify_app; + +-- ----- MAGIC_LINK_TOKENS ----- +-- Note: Magic link tokens may have NULL tenant_id for login requests +-- Policy allows NULL tenant_id OR matching tenant_id +ALTER TABLE magic_link_tokens ENABLE ROW LEVEL SECURITY; +ALTER TABLE magic_link_tokens FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_magic_link_tokens ON magic_link_tokens; +CREATE POLICY tenant_isolation_magic_link_tokens ON magic_link_tokens + USING (tenant_id IS NULL OR tenant_id = current_tenant_id()) + WITH CHECK (tenant_id IS NULL OR tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON magic_link_tokens TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE magic_link_tokens_id_seq TO ackify_app; + +-- ----- MAGIC_LINK_AUTH_ATTEMPTS ----- +-- Note: Auth attempts may have NULL tenant_id before authentication +ALTER TABLE magic_link_auth_attempts ENABLE ROW LEVEL SECURITY; +ALTER TABLE magic_link_auth_attempts FORCE ROW LEVEL SECURITY; + +DROP POLICY IF EXISTS tenant_isolation_magic_link_auth_attempts ON magic_link_auth_attempts; +CREATE POLICY tenant_isolation_magic_link_auth_attempts ON magic_link_auth_attempts + USING (tenant_id IS NULL OR tenant_id = current_tenant_id()) + WITH CHECK (tenant_id IS NULL OR tenant_id = current_tenant_id()); + +GRANT SELECT, INSERT, UPDATE, DELETE ON magic_link_auth_attempts TO ackify_app; +GRANT USAGE, SELECT ON SEQUENCE magic_link_auth_attempts_id_seq TO ackify_app; + +-- ----- INSTANCE_METADATA ----- +-- This table is read-only for the app (tenant ID source) +-- No RLS needed as it contains only one row per instance +GRANT SELECT ON instance_metadata TO ackify_app; + +-- ============================================================================ +-- Set default privileges for future tables +-- ============================================================================ +ALTER DEFAULT PRIVILEGES IN SCHEMA public + GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO ackify_app; + +ALTER DEFAULT PRIVILEGES IN SCHEMA public + GRANT USAGE, SELECT ON SEQUENCES TO ackify_app; diff --git a/backend/pkg/web/server.go b/backend/pkg/web/server.go index f44bc6a..7da0670 100644 --- a/backend/pkg/web/server.go +++ b/backend/pkg/web/server.go @@ -352,7 +352,7 @@ func (b *ServerBuilder) createRepositories() *repositories { func (b *ServerBuilder) initializeWebhookSystem(repos *repositories) (*services.WebhookPublisher, *webhook.Worker, error) { whPublisher := services.NewWebhookPublisher(repos.webhook, repos.webhookDelivery) whCfg := webhook.DefaultWorkerConfig() - whWorker := webhook.NewWorker(repos.webhookDelivery, &http.Client{}, whCfg) + whWorker := webhook.NewWorker(repos.webhookDelivery, &http.Client{}, whCfg, b.db, b.tenantProvider) if err := whWorker.Start(); err != nil { return nil, nil, fmt.Errorf("failed to start webhook worker: %w", err) @@ -370,7 +370,7 @@ func (b *ServerBuilder) initializeEmailWorker(repos *repositories, whPublisher * renderer := email.NewRenderer(getTemplatesDir(), b.cfg.App.BaseURL, b.cfg.App.Organisation, b.cfg.Mail.FromName, b.cfg.Mail.From, "fr", b.i18nService) workerConfig := email.DefaultWorkerConfig() - emailWorker := email.NewWorker(repos.emailQueue, b.emailSender, renderer, workerConfig) + emailWorker := email.NewWorker(repos.emailQueue, b.emailSender, renderer, workerConfig, b.db, b.tenantProvider) if whPublisher != nil { emailWorker.SetPublisher(whPublisher) @@ -424,7 +424,7 @@ func (b *ServerBuilder) initializeMagicLinkService(ctx context.Context, repos *r var magicLinkWorker *workers.MagicLinkCleanupWorker if b.magicLinkEnabled { logger.Logger.Info("Magic Link authentication enabled") - magicLinkWorker = workers.NewMagicLinkCleanupWorker(b.magicLinkService, 1*time.Hour) + magicLinkWorker = workers.NewMagicLinkCleanupWorker(b.magicLinkService, 1*time.Hour, b.db, b.tenantProvider) go magicLinkWorker.Start(ctx) } else { logger.Logger.Info("Magic Link authentication disabled") @@ -454,7 +454,7 @@ func (b *ServerBuilder) initializeSessionWorker(repos *repositories) (*auth.Sess } workerConfig := auth.DefaultSessionWorkerConfig() - sessionWorker := auth.NewSessionWorker(repos.oauthSession, workerConfig) + sessionWorker := auth.NewSessionWorker(repos.oauthSession, workerConfig, b.db, b.tenantProvider) if err := sessionWorker.Start(); err != nil { return nil, fmt.Errorf("failed to start OAuth session worker: %w", err) } @@ -470,6 +470,11 @@ func (b *ServerBuilder) buildRouter(repos *repositories, whPublisher *services.W // Build API router config using providers apiConfig := api.RouterConfig{ + // Database for RLS middleware + DB: b.db, + TenantProvider: b.tenantProvider, + + // Capability providers AuthProvider: b.authProvider, OAuthProvider: b.oauthProvider, Authorizer: b.authorizer, diff --git a/compose.e2e.yml b/compose.e2e.yml index 8577470..3061334 100644 --- a/compose.e2e.yml +++ b/compose.e2e.yml @@ -6,6 +6,7 @@ services: container_name: ackify-migrate environment: ACKIFY_DB_DSN: "postgres://postgres:testpassword@ackify-db:5432/ackify_test?sslmode=disable" + ACKIFY_APP_PASSWORD: "ackifytestpassword" depends_on: ackify-db: condition: service_healthy @@ -27,7 +28,7 @@ services: ACKIFY_LOG_FORMAT: "classic" ACKIFY_BASE_URL: "http://localhost:8080" ACKIFY_ORGANISATION: "Ackify Test" - ACKIFY_DB_DSN: "postgres://postgres:testpassword@ackify-db:5432/ackify_test?sslmode=disable" + ACKIFY_DB_DSN: "postgres://ackify_app:ackifytestpassword@ackify-db:5432/ackify_test?sslmode=disable" ACKIFY_AUTH_OAUTH_ENABLED: "true" ACKIFY_AUTH_MAGICLINK_ENABLED: "true" ACKIFY_OAUTH_PROVIDER: "custom" diff --git a/compose.yml b/compose.yml index 87285d2..5c975de 100644 --- a/compose.yml +++ b/compose.yml @@ -7,7 +7,8 @@ services: container_name: ackify-ce-migrate environment: ACKIFY_LOG_LEVEL: "${ACKIFY_LOG_LEVEL}" - ACKIFY_DB_DSN: "postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@ackify-db:5432/${POSTGRES_DB}?sslmode=disable" + ACKIFY_DB_DSN: "postgres://postgres:${POSTGRES_PASSWORD}@ackify-db:5432/ackify?sslmode=disable" + ACKIFY_APP_PASSWORD: "${ACKIFY_APP_PASSWORD:-ackify}" depends_on: ackify-db: condition: service_healthy @@ -25,7 +26,7 @@ services: ACKIFY_LOG_LEVEL: "${ACKIFY_LOG_LEVEL}" ACKIFY_BASE_URL: "${ACKIFY_BASE_URL}" ACKIFY_ORGANISATION: "${ACKIFY_ORGANISATION}" - ACKIFY_DB_DSN: "postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@ackify-db:5432/${POSTGRES_DB}?sslmode=disable" + ACKIFY_DB_DSN: "postgres://ackify_app:${ACKIFY_APP_PASSWORD}@ackify-db:5432/ackify?sslmode=disable" ACKIFY_OAUTH_PROVIDER: "${ACKIFY_OAUTH_PROVIDER}" ACKIFY_OAUTH_CLIENT_ID: "${ACKIFY_OAUTH_CLIENT_ID}" ACKIFY_OAUTH_CLIENT_SECRET: "${ACKIFY_OAUTH_CLIENT_SECRET}" @@ -59,15 +60,15 @@ services: container_name: ackify-db restart: unless-stopped environment: - POSTGRES_USER: ${POSTGRES_USER} + POSTGRES_USER: postgres POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} - POSTGRES_DB: ${POSTGRES_DB} + POSTGRES_DB: ackify volumes: - ackify_data:/var/lib/postgresql/data networks: - internal healthcheck: - test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"] + test: ["CMD-SHELL", "pg_isready -U postgres -d ackify"] interval: 10s timeout: 5s retries: 5 diff --git a/docs/en/configuration.md b/docs/en/configuration.md index 79ed877..7b82f7b 100644 --- a/docs/en/configuration.md +++ b/docs/en/configuration.md @@ -173,6 +173,23 @@ ACKIFY_CHECKSUM_INSECURE_SKIP_VERIFY=false These variables disable critical security protections and should **only** be used in isolated test environments. +### Row Level Security (RLS) + +Ackify uses PostgreSQL Row Level Security for tenant data isolation. This is configured automatically during migrations. + +```bash +# Password for the ackify_app database role (required for RLS) +ACKIFY_APP_PASSWORD=your_secure_app_password +``` + +**How it works**: +- The `migrate` tool creates a non-superuser role `ackify_app` before running migrations +- The application connects to PostgreSQL using this role +- RLS policies automatically filter all queries by tenant +- No data leakage possible even if application code forgets tenant filtering + +See [Row Level Security](configuration/rls.md) for detailed documentation. + ## Advanced Configuration ### OAuth2 Providers @@ -200,10 +217,13 @@ ACKIFY_LOG_LEVEL=info ACKIFY_LISTEN_ADDR=:8080 # Database -POSTGRES_USER=ackifyr +POSTGRES_USER=postgres POSTGRES_PASSWORD=super_secure_password_123 POSTGRES_DB=ackify +# RLS (Row Level Security) +ACKIFY_APP_PASSWORD=another_secure_password_456 + # OAuth2 (Google) ACKIFY_OAUTH_PROVIDER=google ACKIFY_OAUTH_CLIENT_ID=123456789-abc.apps.googleusercontent.com @@ -258,6 +278,7 @@ curl http://localhost:8080/api/v1/health - ✅ Use HTTPS (`ACKIFY_BASE_URL=https://...`) - ✅ Generate strong secrets (64+ characters) +- ✅ Configure RLS with strong password (`ACKIFY_APP_PASSWORD`) - ✅ Restrict OAuth domain (`ACKIFY_OAUTH_ALLOWED_DOMAIN`) - ✅ Configure admin emails (`ACKIFY_ADMIN_EMAILS`) - ✅ Use PostgreSQL with SSL in production diff --git a/docs/en/configuration/rls.md b/docs/en/configuration/rls.md new file mode 100644 index 0000000..1d5532b --- /dev/null +++ b/docs/en/configuration/rls.md @@ -0,0 +1,188 @@ +# Row Level Security (RLS) + +PostgreSQL Row Level Security provides automatic tenant data isolation at the database level. + +## Overview + +RLS ensures that each tenant can only access their own data, regardless of how the application queries the database. This is a critical security feature for multi-tenant deployments. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Request Flow │ +├─────────────────────────────────────────────────────────────────┤ +│ 1. HTTP Request arrives │ +│ 2. RLS Middleware starts a transaction │ +│ 3. Middleware sets: SET app.tenant_id = '' │ +│ 4. All queries automatically filtered by tenant_id │ +│ 5. Transaction committed on success, rolled back on error │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Configuration + +### Required Variable + +```bash +# Password for the ackify_app database role +ACKIFY_APP_PASSWORD=your_secure_password +``` + +### How It Works + +1. **During migration** (`migrate up`): + - The migrate tool reads `ACKIFY_APP_PASSWORD` + - Creates the `ackify_app` role if it doesn't exist + - Updates the password if the role already exists + - Runs migrations that enable RLS policies + +2. **At runtime**: + - Application connects as `ackify_app` (not `postgres`) + - RLS policies filter all queries by `tenant_id` + - No data leakage possible + +### compose.yml Configuration + +```yaml +services: + ackify-migrate: + environment: + # Superuser connection for migrations + ACKIFY_DB_DSN: "postgres://postgres:${POSTGRES_PASSWORD}@db:5432/ackify?sslmode=disable" + # Password for ackify_app role creation + ACKIFY_APP_PASSWORD: "${ACKIFY_APP_PASSWORD}" + + ackify-ce: + environment: + # Application connects with ackify_app role (RLS enforced) + ACKIFY_DB_DSN: "postgres://ackify_app:${ACKIFY_APP_PASSWORD}@db:5432/ackify?sslmode=disable" +``` + +## Security Benefits + +### Automatic Filtering + +Without RLS, application code must always include tenant filtering: + +```sql +-- Without RLS: Easy to forget tenant_id filter +SELECT * FROM documents WHERE doc_id = '123'; -- BUG: Returns any tenant's data! +``` + +With RLS, filtering is automatic: + +```sql +-- With RLS: Database enforces tenant isolation +SELECT * FROM documents WHERE doc_id = '123'; -- Only returns current tenant's data +``` + +### Defense in Depth + +Even if application code has a bug that forgets tenant filtering, RLS prevents data leakage at the database level. + +## Tables with RLS + +RLS policies are applied to all tenant-aware tables: + +| Table | Policy | +|-------|--------| +| `documents` | `tenant_id = current_tenant_id()` | +| `signatures` | `tenant_id = current_tenant_id()` | +| `expected_signers` | `tenant_id = current_tenant_id()` | +| `webhooks` | `tenant_id = current_tenant_id()` | +| `reminder_logs` | `tenant_id = current_tenant_id()` | +| `email_queue` | `tenant_id = current_tenant_id()` | +| `checksum_verifications` | `tenant_id = current_tenant_id()` | +| `webhook_deliveries` | `tenant_id = current_tenant_id()` | +| `oauth_sessions` | `tenant_id = current_tenant_id()` | +| `magic_link_tokens` | `tenant_id IS NULL OR tenant_id = current_tenant_id()` | +| `magic_link_auth_attempts` | `tenant_id IS NULL OR tenant_id = current_tenant_id()` | + +## Troubleshooting + +### Empty Results When Querying Directly + +If you connect to the database with `psql` and get empty results: + +```sql +-- This returns 0 rows because app.tenant_id is not set +SELECT COUNT(*) FROM documents; +``` + +**Solution**: Set the tenant context first: + +```sql +-- Option 1: Session-level (persists until disconnect) +SELECT set_config('app.tenant_id', 'your-tenant-uuid', false); + +-- Option 2: Transaction-level +BEGIN; +SELECT set_config('app.tenant_id', 'your-tenant-uuid', true); +SELECT * FROM documents; +COMMIT; +``` + +### Superuser Bypasses RLS + +If you connect as `postgres` (superuser), RLS is bypassed: + +```sql +-- As postgres: Returns ALL data (no RLS filtering) +SELECT COUNT(*) FROM documents; +``` + +This is by design. Use `ackify_app` for application connections. + +### Migration Fails with "role does not exist" + +If migrations fail because `ackify_app` doesn't exist: + +1. Ensure `ACKIFY_APP_PASSWORD` is set +2. Check migrate tool logs for warnings +3. Verify the migrate tool runs before migrations + +## Manual Role Management + +In rare cases, you may need to manage the role manually: + +```sql +-- Create role (if not using migrate tool) +CREATE ROLE ackify_app WITH + LOGIN + PASSWORD 'your_password' + NOCREATEDB + NOCREATEROLE + NOINHERIT; + +-- Grant permissions +GRANT CONNECT ON DATABASE ackify TO ackify_app; +GRANT USAGE ON SCHEMA public TO ackify_app; +GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO ackify_app; +GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO ackify_app; + +-- Change password +ALTER ROLE ackify_app WITH PASSWORD 'new_password'; +``` + +## Testing RLS + +To verify RLS is working correctly: + +```bash +# Connect as ackify_app +psql -U ackify_app -d ackify + +# Without tenant context - should return 0 rows +SELECT COUNT(*) FROM documents; + +# With tenant context - should return tenant's rows +SELECT set_config('app.tenant_id', '', false); +SELECT COUNT(*) FROM documents; +``` + +## Best Practices + +1. **Always use strong passwords** for `ACKIFY_APP_PASSWORD` +2. **Never connect as superuser** from the application +3. **Use SSL** for database connections in production +4. **Rotate passwords** periodically +5. **Monitor** for failed authentication attempts diff --git a/docs/fr/configuration.md b/docs/fr/configuration.md index 39c2b10..663f8aa 100644 --- a/docs/fr/configuration.md +++ b/docs/fr/configuration.md @@ -173,6 +173,23 @@ ACKIFY_CHECKSUM_INSECURE_SKIP_VERIFY=false Ces variables désactivent des protections de sécurité critiques et ne doivent être utilisées **que** dans des environnements de test isolés. +### Row Level Security (RLS) + +Ackify utilise PostgreSQL Row Level Security pour l'isolation des données par tenant. Ceci est configuré automatiquement lors des migrations. + +```bash +# Mot de passe pour le rôle base de données ackify_app (requis pour RLS) +ACKIFY_APP_PASSWORD=your_secure_app_password +``` + +**Fonctionnement** : +- L'outil `migrate` crée un rôle non-superuser `ackify_app` avant d'exécuter les migrations +- L'application se connecte à PostgreSQL avec ce rôle +- Les policies RLS filtrent automatiquement toutes les requêtes par tenant +- Aucune fuite de données possible même si le code applicatif oublie le filtrage tenant + +Voir [Row Level Security](configuration/rls.md) pour la documentation détaillée. + ## Configuration Avancée ### OAuth2 Providers @@ -200,10 +217,13 @@ ACKIFY_LOG_LEVEL=info ACKIFY_LISTEN_ADDR=:8080 # Base de données -POSTGRES_USER=ackifyr +POSTGRES_USER=postgres POSTGRES_PASSWORD=super_secure_password_123 POSTGRES_DB=ackify +# RLS (Row Level Security) +ACKIFY_APP_PASSWORD=another_secure_password_456 + # OAuth2 (Google) ACKIFY_OAUTH_PROVIDER=google ACKIFY_OAUTH_CLIENT_ID=123456789-abc.apps.googleusercontent.com @@ -258,6 +278,7 @@ curl http://localhost:8080/api/v1/health - ✅ Utiliser HTTPS (`ACKIFY_BASE_URL=https://...`) - ✅ Générer des secrets forts (64+ caractères) +- ✅ Configurer RLS avec un mot de passe fort (`ACKIFY_APP_PASSWORD`) - ✅ Restreindre le domaine OAuth (`ACKIFY_OAUTH_ALLOWED_DOMAIN`) - ✅ Configurer les emails admin (`ACKIFY_ADMIN_EMAILS`) - ✅ Utiliser PostgreSQL avec SSL en production diff --git a/docs/fr/configuration/rls.md b/docs/fr/configuration/rls.md new file mode 100644 index 0000000..89954cd --- /dev/null +++ b/docs/fr/configuration/rls.md @@ -0,0 +1,188 @@ +# Row Level Security (RLS) + +PostgreSQL Row Level Security fournit une isolation automatique des données par tenant au niveau de la base de données. + +## Vue d'ensemble + +RLS garantit que chaque tenant ne peut accéder qu'à ses propres données, peu importe comment l'application interroge la base de données. C'est une fonctionnalité de sécurité critique pour les déploiements multi-tenant. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Flux de requête │ +├─────────────────────────────────────────────────────────────────┤ +│ 1. Requête HTTP arrive │ +│ 2. Le middleware RLS démarre une transaction │ +│ 3. Le middleware définit : SET app.tenant_id = '' │ +│ 4. Toutes les requêtes filtrées automatiquement par tenant_id │ +│ 5. Transaction validée si succès, annulée si erreur │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Configuration + +### Variable requise + +```bash +# Mot de passe pour le rôle base de données ackify_app +ACKIFY_APP_PASSWORD=your_secure_password +``` + +### Fonctionnement + +1. **Pendant la migration** (`migrate up`) : + - L'outil migrate lit `ACKIFY_APP_PASSWORD` + - Crée le rôle `ackify_app` s'il n'existe pas + - Met à jour le mot de passe si le rôle existe déjà + - Exécute les migrations qui activent les policies RLS + +2. **À l'exécution** : + - L'application se connecte en tant que `ackify_app` (pas `postgres`) + - Les policies RLS filtrent toutes les requêtes par `tenant_id` + - Aucune fuite de données possible + +### Configuration compose.yml + +```yaml +services: + ackify-migrate: + environment: + # Connexion superuser pour les migrations + ACKIFY_DB_DSN: "postgres://postgres:${POSTGRES_PASSWORD}@db:5432/ackify?sslmode=disable" + # Mot de passe pour la création du rôle ackify_app + ACKIFY_APP_PASSWORD: "${ACKIFY_APP_PASSWORD}" + + ackify-ce: + environment: + # L'application se connecte avec le rôle ackify_app (RLS appliqué) + ACKIFY_DB_DSN: "postgres://ackify_app:${ACKIFY_APP_PASSWORD}@db:5432/ackify?sslmode=disable" +``` + +## Avantages sécurité + +### Filtrage automatique + +Sans RLS, le code applicatif doit toujours inclure le filtrage tenant : + +```sql +-- Sans RLS : Facile d'oublier le filtre tenant_id +SELECT * FROM documents WHERE doc_id = '123'; -- BUG : Retourne les données de n'importe quel tenant ! +``` + +Avec RLS, le filtrage est automatique : + +```sql +-- Avec RLS : La base de données impose l'isolation tenant +SELECT * FROM documents WHERE doc_id = '123'; -- Retourne uniquement les données du tenant courant +``` + +### Défense en profondeur + +Même si le code applicatif contient un bug qui oublie le filtrage tenant, RLS empêche les fuites de données au niveau base de données. + +## Tables avec RLS + +Les policies RLS sont appliquées à toutes les tables tenant-aware : + +| Table | Policy | +|-------|--------| +| `documents` | `tenant_id = current_tenant_id()` | +| `signatures` | `tenant_id = current_tenant_id()` | +| `expected_signers` | `tenant_id = current_tenant_id()` | +| `webhooks` | `tenant_id = current_tenant_id()` | +| `reminder_logs` | `tenant_id = current_tenant_id()` | +| `email_queue` | `tenant_id = current_tenant_id()` | +| `checksum_verifications` | `tenant_id = current_tenant_id()` | +| `webhook_deliveries` | `tenant_id = current_tenant_id()` | +| `oauth_sessions` | `tenant_id = current_tenant_id()` | +| `magic_link_tokens` | `tenant_id IS NULL OR tenant_id = current_tenant_id()` | +| `magic_link_auth_attempts` | `tenant_id IS NULL OR tenant_id = current_tenant_id()` | + +## Dépannage + +### Résultats vides lors de requêtes directes + +Si vous vous connectez à la base avec `psql` et obtenez des résultats vides : + +```sql +-- Retourne 0 lignes car app.tenant_id n'est pas défini +SELECT COUNT(*) FROM documents; +``` + +**Solution** : Définir le contexte tenant d'abord : + +```sql +-- Option 1 : Niveau session (persiste jusqu'à déconnexion) +SELECT set_config('app.tenant_id', 'votre-tenant-uuid', false); + +-- Option 2 : Niveau transaction +BEGIN; +SELECT set_config('app.tenant_id', 'votre-tenant-uuid', true); +SELECT * FROM documents; +COMMIT; +``` + +### Le superuser contourne RLS + +Si vous vous connectez en tant que `postgres` (superuser), RLS est contourné : + +```sql +-- En tant que postgres : Retourne TOUTES les données (pas de filtrage RLS) +SELECT COUNT(*) FROM documents; +``` + +C'est voulu. Utilisez `ackify_app` pour les connexions applicatives. + +### La migration échoue avec "role does not exist" + +Si les migrations échouent parce que `ackify_app` n'existe pas : + +1. Vérifiez que `ACKIFY_APP_PASSWORD` est défini +2. Consultez les logs du migrate tool pour les warnings +3. Vérifiez que le migrate tool s'exécute avant les migrations + +## Gestion manuelle du rôle + +Dans de rares cas, vous pourriez avoir besoin de gérer le rôle manuellement : + +```sql +-- Créer le rôle (si vous n'utilisez pas le migrate tool) +CREATE ROLE ackify_app WITH + LOGIN + PASSWORD 'votre_mot_de_passe' + NOCREATEDB + NOCREATEROLE + NOINHERIT; + +-- Accorder les permissions +GRANT CONNECT ON DATABASE ackify TO ackify_app; +GRANT USAGE ON SCHEMA public TO ackify_app; +GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO ackify_app; +GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO ackify_app; + +-- Changer le mot de passe +ALTER ROLE ackify_app WITH PASSWORD 'nouveau_mot_de_passe'; +``` + +## Tester RLS + +Pour vérifier que RLS fonctionne correctement : + +```bash +# Se connecter en tant que ackify_app +psql -U ackify_app -d ackify + +# Sans contexte tenant - devrait retourner 0 lignes +SELECT COUNT(*) FROM documents; + +# Avec contexte tenant - devrait retourner les lignes du tenant +SELECT set_config('app.tenant_id', '', false); +SELECT COUNT(*) FROM documents; +``` + +## Bonnes pratiques + +1. **Toujours utiliser des mots de passe forts** pour `ACKIFY_APP_PASSWORD` +2. **Ne jamais se connecter en superuser** depuis l'application +3. **Utiliser SSL** pour les connexions base de données en production +4. **Faire tourner les mots de passe** périodiquement +5. **Surveiller** les tentatives d'authentification échouées diff --git a/run-tests-suite.sh b/run-tests-suite.sh index bfdef71..d46c8a9 100755 --- a/run-tests-suite.sh +++ b/run-tests-suite.sh @@ -144,6 +144,7 @@ else # Run migrations echo -e "${YELLOW}📝 Running database migrations...${NC}" export ACKIFY_DB_DSN="postgres://postgres:testpassword@localhost:5432/ackify_test?sslmode=disable" + export ACKIFY_APP_PASSWORD="ackifytestpassword" cd "$PROJECT_ROOT" if go run ./backend/cmd/migrate/main.go -migrations-path file://backend/migrations up; then echo -e "${GREEN}✓ Migrations applied${NC}" diff --git a/webapp/cypress/e2e/13-embed-page.cy.ts b/webapp/cypress/e2e/13-embed-page.cy.ts index 33d3163..d903287 100644 --- a/webapp/cypress/e2e/13-embed-page.cy.ts +++ b/webapp/cypress/e2e/13-embed-page.cy.ts @@ -12,7 +12,8 @@ describe('Test 13: Embed Page Functionality', () => { it('should display embed page with no signatures state', () => { // Step 1: Visit embed page with new document (force English locale) - cy.visitWithLocale(`/embed?doc=${sharedDocId}`, 'en') + cy.loginViaMagicLink('embed-user1@test.com') + cy.visitWithLocale(`/embed?doc=${sharedDocId}`, 'en') // Step 2: Should load without authentication cy.url({ timeout: 10000 }).should('include', '/embed')