From 998d227898e3f2d0ce14652a255c94234879ec2d Mon Sep 17 00:00:00 2001 From: Benjamin Date: Sat, 17 Jan 2026 01:49:41 +0100 Subject: [PATCH] refacto(backend): cleaning dead code --- .../application/services/authorizer.go | 42 -- .../application/services/checksum_service.go | 218 ------- .../services/checksum_service_test.go | 483 ---------------- .../internal/application/services/reminder.go | 238 -------- .../application/services/reminder_async.go | 9 + .../application/services/reminder_test.go | 536 ------------------ .../application/services/signature_test.go | 58 ++ .../infrastructure/auth/oauth_provider.go | 393 ------------- .../auth/oauth_provider_test.go | 422 -------------- .../internal/infrastructure/email/helpers.go | 37 -- .../infrastructure/email/helpers_test.go | 268 --------- .../infrastructure/email/queue_helpers.go | 113 ---- .../presentation/api/shared/response.go | 5 - .../presentation/api/shared/response_test.go | 16 - .../internal/presentation/api/shared/utils.go | 33 -- .../presentation/api/users/handler.go | 2 + .../internal/presentation/handlers/errors.go | 39 -- .../presentation/handlers/handlers_test.go | 172 ------ .../presentation/handlers/middleware.go | 32 +- .../internal/presentation/handlers/oembed.go | 27 - backend/pkg/types/user.go | 7 +- backend/pkg/web/auth/dynamic_provider.go | 9 + backend/pkg/web/server.go | 16 - webapp/src/components/DocumentCreateForm.vue | 8 +- webapp/src/components/layout/AppHeader.vue | 14 +- webapp/src/stores/auth.ts | 1 + 26 files changed, 106 insertions(+), 3092 deletions(-) delete mode 100644 backend/internal/application/services/authorizer.go delete mode 100644 backend/internal/application/services/checksum_service.go delete mode 100644 backend/internal/application/services/checksum_service_test.go delete mode 100644 backend/internal/application/services/reminder.go delete mode 100644 backend/internal/application/services/reminder_test.go delete mode 100644 backend/internal/infrastructure/auth/oauth_provider.go delete mode 100644 backend/internal/infrastructure/auth/oauth_provider_test.go delete mode 100644 backend/internal/infrastructure/email/helpers.go delete mode 100644 backend/internal/infrastructure/email/helpers_test.go delete mode 100644 backend/internal/infrastructure/email/queue_helpers.go delete mode 100644 backend/internal/presentation/api/shared/utils.go delete mode 100644 backend/internal/presentation/handlers/errors.go diff --git a/backend/internal/application/services/authorizer.go b/backend/internal/application/services/authorizer.go deleted file mode 100644 index 98a5183..0000000 --- a/backend/internal/application/services/authorizer.go +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package services - -import ( - "strings" -) - -// AuthorizerService provides authorization decisions for the organization. -// In Community Edition, it uses environment variables for configuration. -type AuthorizerService struct { - adminEmails []string - onlyAdminCanCreate bool -} - -// NewAuthorizerService creates a new AuthorizerService with the given configuration. -func NewAuthorizerService(adminEmails []string, onlyAdminCanCreate bool) *AuthorizerService { - // Normalize admin emails to lowercase for case-insensitive comparison - normalized := make([]string, len(adminEmails)) - for i, email := range adminEmails { - normalized[i] = strings.ToLower(strings.TrimSpace(email)) - } - return &AuthorizerService{ - adminEmails: normalized, - onlyAdminCanCreate: onlyAdminCanCreate, - } -} - -// IsAdmin checks if the given email belongs to an administrator. -func (s *AuthorizerService) IsAdmin(email string) bool { - normalizedEmail := strings.ToLower(strings.TrimSpace(email)) - for _, adminEmail := range s.adminEmails { - if normalizedEmail == adminEmail { - return true - } - } - return false -} - -// OnlyAdminCanCreate returns whether only administrators can create documents. -func (s *AuthorizerService) OnlyAdminCanCreate() bool { - return s.onlyAdminCanCreate -} diff --git a/backend/internal/application/services/checksum_service.go b/backend/internal/application/services/checksum_service.go deleted file mode 100644 index 125471d..0000000 --- a/backend/internal/application/services/checksum_service.go +++ /dev/null @@ -1,218 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package services - -import ( - "context" - "fmt" - "regexp" - "strings" - "time" - - "github.com/btouchard/ackify-ce/backend/pkg/logger" - "github.com/btouchard/ackify-ce/backend/pkg/models" -) - -// ChecksumVerificationRepository defines the interface for checksum verification persistence -type ChecksumVerificationRepository interface { - RecordVerification(ctx context.Context, verification *models.ChecksumVerification) error - GetVerificationHistory(ctx context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) - GetLastVerification(ctx context.Context, docID string) (*models.ChecksumVerification, error) -} - -// DocumentRepository defines the interface for document metadata operations -type DocumentRepository interface { - GetByDocID(ctx context.Context, docID string) (*models.Document, error) -} - -// ChecksumService orchestrates document integrity verification with audit trail persistence -type ChecksumService struct { - verificationRepo ChecksumVerificationRepository - documentRepo DocumentRepository -} - -// NewChecksumService initializes checksum verification service with required repository dependencies -func NewChecksumService( - verificationRepo ChecksumVerificationRepository, - documentRepo DocumentRepository, -) *ChecksumService { - return &ChecksumService{ - verificationRepo: verificationRepo, - documentRepo: documentRepo, - } -} - -// ValidateChecksumFormat ensures checksum matches expected hexadecimal length for SHA-256/SHA-512/MD5 -func (s *ChecksumService) ValidateChecksumFormat(checksum, algorithm string) error { - // Remove common separators and whitespace - checksum = normalizeChecksum(checksum) - - var expectedLength int - switch algorithm { - case "SHA-256": - expectedLength = 64 - case "SHA-512": - expectedLength = 128 - case "MD5": - expectedLength = 32 - default: - return fmt.Errorf("unsupported algorithm: %s", algorithm) - } - - // Check length - if len(checksum) != expectedLength { - return fmt.Errorf("invalid checksum length for %s: expected %d hexadecimal characters, got %d", algorithm, expectedLength, len(checksum)) - } - - // Check if it's a valid hex string - hexPattern := regexp.MustCompile("^[a-fA-F0-9]+$") - if !hexPattern.MatchString(checksum) { - return fmt.Errorf("invalid checksum format: must contain only hexadecimal characters (0-9, a-f, A-F)") - } - - return nil -} - -// VerifyChecksum compares calculated hash against stored reference and creates immutable audit record -func (s *ChecksumService) VerifyChecksum(ctx context.Context, docID, calculatedChecksum, verifiedBy string) (*models.ChecksumVerificationResult, error) { - // Get document metadata - doc, err := s.documentRepo.GetByDocID(ctx, docID) - if err != nil { - return nil, fmt.Errorf("failed to get document: %w", err) - } - - if doc == nil { - return nil, fmt.Errorf("document not found: %s", docID) - } - - // Normalize checksums for comparison - normalizedCalculated := normalizeChecksum(calculatedChecksum) - normalizedStored := normalizeChecksum(doc.Checksum) - - // Determine the algorithm to use (from document or default to SHA-256) - algorithm := doc.ChecksumAlgorithm - if algorithm == "" { - algorithm = "SHA-256" - } - - // Validate the calculated checksum format - if err := s.ValidateChecksumFormat(normalizedCalculated, algorithm); err != nil { - // Record failed verification with error - errorMsg := err.Error() - verification := &models.ChecksumVerification{ - DocID: docID, - VerifiedBy: verifiedBy, - VerifiedAt: time.Now(), - StoredChecksum: normalizedStored, - CalculatedChecksum: normalizedCalculated, - Algorithm: algorithm, - IsValid: false, - ErrorMessage: &errorMsg, - } - _ = s.verificationRepo.RecordVerification(ctx, verification) - - return nil, fmt.Errorf("invalid checksum format: %w", err) - } - - // Check if document has a reference checksum - if !doc.HasChecksum() { - result := &models.ChecksumVerificationResult{ - Valid: false, - StoredChecksum: "", - CalculatedChecksum: normalizedCalculated, - Algorithm: algorithm, - Message: "No reference checksum configured for this document", - HasReferenceHash: false, - } - return result, nil - } - - // Compare checksums (case-insensitive) - isValid := strings.EqualFold(normalizedCalculated, normalizedStored) - - // Record verification - verification := &models.ChecksumVerification{ - DocID: docID, - VerifiedBy: verifiedBy, - VerifiedAt: time.Now(), - StoredChecksum: normalizedStored, - CalculatedChecksum: normalizedCalculated, - Algorithm: algorithm, - IsValid: isValid, - ErrorMessage: nil, - } - - if err := s.verificationRepo.RecordVerification(ctx, verification); err != nil { - logger.Logger.Error("Failed to record verification", "error", err.Error(), "doc_id", docID) - // Continue even if recording fails - return the result - } - - var message string - if isValid { - message = "Checksums match - document integrity verified" - } else { - message = "Checksums do not match - document may have been modified" - } - - result := &models.ChecksumVerificationResult{ - Valid: isValid, - StoredChecksum: normalizedStored, - CalculatedChecksum: normalizedCalculated, - Algorithm: algorithm, - Message: message, - HasReferenceHash: true, - } - - return result, nil -} - -// GetVerificationHistory retrieves paginated audit trail of all checksum validation attempts -func (s *ChecksumService) GetVerificationHistory(ctx context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) { - if limit <= 0 { - limit = 20 - } - - return s.verificationRepo.GetVerificationHistory(ctx, docID, limit) -} - -// GetSupportedAlgorithms returns available hash algorithms for client-side documentation -func (s *ChecksumService) GetSupportedAlgorithms() []string { - return []string{"SHA-256", "SHA-512", "MD5"} -} - -// GetChecksumInfo exposes document hash metadata for public verification interfaces -func (s *ChecksumService) GetChecksumInfo(ctx context.Context, docID string) (map[string]interface{}, error) { - doc, err := s.documentRepo.GetByDocID(ctx, docID) - if err != nil { - return nil, fmt.Errorf("failed to get document: %w", err) - } - - if doc == nil { - return nil, fmt.Errorf("document not found: %s", docID) - } - - algorithm := doc.ChecksumAlgorithm - if algorithm == "" { - algorithm = "SHA-256" - } - - info := map[string]interface{}{ - "doc_id": docID, - "has_checksum": doc.HasChecksum(), - "algorithm": algorithm, - "checksum_length": doc.GetExpectedChecksumLength(), - "supported_algorithms": s.GetSupportedAlgorithms(), - } - - return info, nil -} - -// normalizeChecksum removes common separators and converts to lowercase -func normalizeChecksum(checksum string) string { - // Remove spaces, hyphens, underscores - checksum = strings.ReplaceAll(checksum, " ", "") - checksum = strings.ReplaceAll(checksum, "-", "") - checksum = strings.ReplaceAll(checksum, "_", "") - checksum = strings.TrimSpace(checksum) - // Convert to lowercase for case-insensitive comparison - return strings.ToLower(checksum) -} diff --git a/backend/internal/application/services/checksum_service_test.go b/backend/internal/application/services/checksum_service_test.go deleted file mode 100644 index 0f649ff..0000000 --- a/backend/internal/application/services/checksum_service_test.go +++ /dev/null @@ -1,483 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package services - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/btouchard/ackify-ce/backend/pkg/models" -) - -type fakeVerificationRepository struct { - verifications []*models.ChecksumVerification - shouldFailRecord bool - shouldFailGetHistory bool - shouldFailGetLast bool -} - -func newFakeVerificationRepository() *fakeVerificationRepository { - return &fakeVerificationRepository{ - verifications: make([]*models.ChecksumVerification, 0), - } -} - -func (f *fakeVerificationRepository) RecordVerification(_ context.Context, verification *models.ChecksumVerification) error { - if f.shouldFailRecord { - return errors.New("repository record failed") - } - - verification.ID = int64(len(f.verifications) + 1) - f.verifications = append(f.verifications, verification) - return nil -} - -func (f *fakeVerificationRepository) GetVerificationHistory(_ context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) { - if f.shouldFailGetHistory { - return nil, errors.New("repository get history failed") - } - - var result []*models.ChecksumVerification - for _, v := range f.verifications { - if v.DocID == docID { - result = append(result, v) - if len(result) >= limit { - break - } - } - } - - return result, nil -} - -func (f *fakeVerificationRepository) GetLastVerification(_ context.Context, docID string) (*models.ChecksumVerification, error) { - if f.shouldFailGetLast { - return nil, errors.New("repository get last failed") - } - - for i := len(f.verifications) - 1; i >= 0; i-- { - if f.verifications[i].DocID == docID { - return f.verifications[i], nil - } - } - - return nil, nil -} - -type fakeDocumentRepository struct { - documents map[string]*models.Document - shouldFailGet bool -} - -func newFakeDocumentRepository() *fakeDocumentRepository { - return &fakeDocumentRepository{ - documents: make(map[string]*models.Document), - } -} - -func (f *fakeDocumentRepository) GetByDocID(_ context.Context, docID string) (*models.Document, error) { - if f.shouldFailGet { - return nil, errors.New("repository get failed") - } - - doc, exists := f.documents[docID] - if !exists { - return nil, nil - } - - return doc, nil -} - -func (f *fakeDocumentRepository) Create(_ context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) { - if f.shouldFailGet { - return nil, errors.New("repository create failed") - } - - doc := &models.Document{ - DocID: docID, - Title: input.Title, - URL: input.URL, - Checksum: input.Checksum, - ChecksumAlgorithm: input.ChecksumAlgorithm, - Description: input.Description, - CreatedBy: createdBy, - } - f.documents[docID] = doc - return doc, nil -} - -func (f *fakeDocumentRepository) FindByReference(_ context.Context, ref string, refType string) (*models.Document, error) { - if f.shouldFailGet { - return nil, errors.New("repository find failed") - } - - for _, doc := range f.documents { - if doc.URL == ref { - return doc, nil - } - } - - return nil, nil -} - -func (f *fakeDocumentRepository) List(_ context.Context, _, _ int) ([]*models.Document, error) { - result := make([]*models.Document, 0, len(f.documents)) - for _, doc := range f.documents { - result = append(result, doc) - } - return result, nil -} - -func (f *fakeDocumentRepository) Search(_ context.Context, _ string, _, _ int) ([]*models.Document, error) { - return []*models.Document{}, nil -} - -func (f *fakeDocumentRepository) Count(_ context.Context, _ string) (int, error) { - return len(f.documents), nil -} - -func (f *fakeDocumentRepository) ListByCreatedBy(_ context.Context, _ string, _, _ int) ([]*models.Document, error) { - return []*models.Document{}, nil -} - -func (f *fakeDocumentRepository) SearchByCreatedBy(_ context.Context, _, _ string, _, _ int) ([]*models.Document, error) { - return []*models.Document{}, nil -} - -func (f *fakeDocumentRepository) CountByCreatedBy(_ context.Context, _, _ string) (int, error) { - return 0, nil -} - -func TestChecksumService_ValidateChecksumFormat(t *testing.T) { - service := NewChecksumService(newFakeVerificationRepository(), newFakeDocumentRepository()) - - tests := []struct { - name string - checksum string - algorithm string - wantError bool - }{ - { - name: "valid SHA-256", - checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - algorithm: "SHA-256", - wantError: false, - }, - { - name: "valid SHA-512", - checksum: "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", - algorithm: "SHA-512", - wantError: false, - }, - { - name: "valid MD5", - checksum: "d41d8cd98f00b204e9800998ecf8427e", - algorithm: "MD5", - wantError: false, - }, - { - name: "valid with uppercase", - checksum: "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855", - algorithm: "SHA-256", - wantError: false, - }, - { - name: "valid with spaces", - checksum: "e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855", - algorithm: "SHA-256", - wantError: false, - }, - { - name: "valid with hyphens", - checksum: "e3b0c442-98fc1c14-9afbf4c8-996fb924-27ae41e4-649b934c-a495991b-7852b855", - algorithm: "SHA-256", - wantError: false, - }, - { - name: "invalid - too short for SHA-256", - checksum: "abc123", - algorithm: "SHA-256", - wantError: true, - }, - { - name: "invalid - too long for MD5", - checksum: "d41d8cd98f00b204e9800998ecf8427eextra", - algorithm: "MD5", - wantError: true, - }, - { - name: "invalid - non-hex characters", - checksum: "gggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg", - algorithm: "SHA-256", - wantError: true, - }, - { - name: "invalid - unsupported algorithm", - checksum: "abc123", - algorithm: "SHA-1", - wantError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := service.ValidateChecksumFormat(tt.checksum, tt.algorithm) - - if tt.wantError && err == nil { - t.Error("expected error, got nil") - } - if !tt.wantError && err != nil { - t.Errorf("unexpected error: %v", err) - } - }) - } -} - -func TestChecksumService_VerifyChecksum(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - docID string - document *models.Document - calculatedChecksum string - verifiedBy string - wantValid bool - wantHasReference bool - wantError bool - }{ - { - name: "valid verification - checksums match", - docID: "doc-001", - document: &models.Document{ - DocID: "doc-001", - Checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - ChecksumAlgorithm: "SHA-256", - }, - calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - verifiedBy: "user@example.com", - wantValid: true, - wantHasReference: true, - wantError: false, - }, - { - name: "invalid verification - checksums differ", - docID: "doc-002", - document: &models.Document{ - DocID: "doc-002", - Checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - ChecksumAlgorithm: "SHA-256", - }, - calculatedChecksum: "0000000000000000000000000000000000000000000000000000000000000000", - verifiedBy: "user@example.com", - wantValid: false, - wantHasReference: true, - wantError: false, - }, - { - name: "no reference checksum", - docID: "doc-003", - document: &models.Document{ - DocID: "doc-003", - Checksum: "", - ChecksumAlgorithm: "SHA-256", - }, - calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - verifiedBy: "user@example.com", - wantValid: false, - wantHasReference: false, - wantError: false, - }, - { - name: "case insensitive comparison", - docID: "doc-004", - document: &models.Document{ - DocID: "doc-004", - Checksum: "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855", - ChecksumAlgorithm: "SHA-256", - }, - calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - verifiedBy: "user@example.com", - wantValid: true, - wantHasReference: true, - wantError: false, - }, - { - name: "document not found", - docID: "non-existent", - document: nil, - calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - verifiedBy: "user@example.com", - wantValid: false, - wantHasReference: false, - wantError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - verificationRepo := newFakeVerificationRepository() - documentRepo := newFakeDocumentRepository() - - if tt.document != nil { - documentRepo.documents[tt.docID] = tt.document - } - - service := NewChecksumService(verificationRepo, documentRepo) - - result, err := service.VerifyChecksum(ctx, tt.docID, tt.calculatedChecksum, tt.verifiedBy) - - if tt.wantError { - if err == nil { - t.Error("expected error, got nil") - } - return - } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if result.Valid != tt.wantValid { - t.Errorf("expected Valid=%v, got %v", tt.wantValid, result.Valid) - } - - if result.HasReferenceHash != tt.wantHasReference { - t.Errorf("expected HasReferenceHash=%v, got %v", tt.wantHasReference, result.HasReferenceHash) - } - - // Check that verification was recorded (if document has checksum) - if tt.wantHasReference { - if len(verificationRepo.verifications) != 1 { - t.Errorf("expected 1 verification recorded, got %d", len(verificationRepo.verifications)) - } else { - v := verificationRepo.verifications[0] - if v.IsValid != tt.wantValid { - t.Errorf("recorded verification IsValid=%v, expected %v", v.IsValid, tt.wantValid) - } - if v.VerifiedBy != tt.verifiedBy { - t.Errorf("recorded verification VerifiedBy=%s, expected %s", v.VerifiedBy, tt.verifiedBy) - } - } - } - }) - } -} - -func TestChecksumService_GetVerificationHistory(t *testing.T) { - ctx := context.Background() - verificationRepo := newFakeVerificationRepository() - documentRepo := newFakeDocumentRepository() - service := NewChecksumService(verificationRepo, documentRepo) - - // Add test verifications - for i := 0; i < 5; i++ { - v := &models.ChecksumVerification{ - DocID: "doc-001", - VerifiedBy: "user@example.com", - VerifiedAt: time.Now(), - StoredChecksum: "abc123", - CalculatedChecksum: "abc123", - Algorithm: "SHA-256", - IsValid: true, - } - _ = verificationRepo.RecordVerification(ctx, v) - } - - // Test get all - history, err := service.GetVerificationHistory(ctx, "doc-001", 10) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(history) != 5 { - t.Errorf("expected 5 verifications, got %d", len(history)) - } - - // Test with limit - limited, err := service.GetVerificationHistory(ctx, "doc-001", 2) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(limited) != 2 { - t.Errorf("expected 2 verifications with limit, got %d", len(limited)) - } -} - -func TestChecksumService_GetChecksumInfo(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - docID string - document *models.Document - wantError bool - }{ - { - name: "document with checksum", - docID: "doc-001", - document: &models.Document{ - DocID: "doc-001", - Checksum: "abc123", - ChecksumAlgorithm: "SHA-256", - }, - wantError: false, - }, - { - name: "document without checksum", - docID: "doc-002", - document: &models.Document{ - DocID: "doc-002", - Checksum: "", - ChecksumAlgorithm: "SHA-256", - }, - wantError: false, - }, - { - name: "document not found", - docID: "non-existent", - document: nil, - wantError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - documentRepo := newFakeDocumentRepository() - if tt.document != nil { - documentRepo.documents[tt.docID] = tt.document - } - - service := NewChecksumService(newFakeVerificationRepository(), documentRepo) - - info, err := service.GetChecksumInfo(ctx, tt.docID) - - if tt.wantError { - if err == nil { - t.Error("expected error, got nil") - } - return - } - - if err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if info["doc_id"] != tt.docID { - t.Errorf("expected doc_id %s, got %v", tt.docID, info["doc_id"]) - } - - if _, ok := info["has_checksum"]; !ok { - t.Error("expected has_checksum field") - } - - if _, ok := info["supported_algorithms"]; !ok { - t.Error("expected supported_algorithms field") - } - }) - } -} diff --git a/backend/internal/application/services/reminder.go b/backend/internal/application/services/reminder.go deleted file mode 100644 index 32a5d5c..0000000 --- a/backend/internal/application/services/reminder.go +++ /dev/null @@ -1,238 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package services - -import ( - "context" - "fmt" - "time" - - "github.com/btouchard/ackify-ce/backend/internal/infrastructure/email" - "github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n" - "github.com/btouchard/ackify-ce/backend/pkg/logger" - "github.com/btouchard/ackify-ce/backend/pkg/models" -) - -// expectedSignerRepository defines minimal interface for expected signer operations -type expectedSignerRepository interface { - ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) -} - -// reminderRepository defines minimal interface for reminder logging and history -type reminderRepository interface { - LogReminder(ctx context.Context, log *models.ReminderLog) error - GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) - GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) -} - -// magicLinkService defines minimal interface for creating reminder auth tokens -type magicLinkService interface { - CreateReminderAuthToken(ctx context.Context, email string, docID string) (string, error) -} - -// ReminderService manages email notifications to pending signers with delivery tracking -type ReminderService struct { - expectedSignerRepo expectedSignerRepository - reminderRepo reminderRepository - emailSender email.Sender - magicLinkService magicLinkService - i18n *i18n.I18n - baseURL string -} - -// NewReminderService initializes reminder service with email sender and repository dependencies -func NewReminderService( - expectedSignerRepo expectedSignerRepository, - reminderRepo reminderRepository, - emailSender email.Sender, - magicLinkService magicLinkService, - i18nService *i18n.I18n, - baseURL string, -) *ReminderService { - return &ReminderService{ - expectedSignerRepo: expectedSignerRepo, - reminderRepo: reminderRepo, - emailSender: emailSender, - magicLinkService: magicLinkService, - i18n: i18nService, - baseURL: baseURL, - } -} - -// SendReminders dispatches email notifications to all or selected pending signers with result aggregation -func (s *ReminderService) SendReminders( - ctx context.Context, - docID string, - sentBy string, - specificEmails []string, - docURL string, - locale string, -) (*models.ReminderSendResult, error) { - - logger.Logger.Info("Starting reminder sending process", - "doc_id", docID, - "sent_by", sentBy, - "specific_emails_count", len(specificEmails), - "locale", locale) - - allSigners, err := s.expectedSignerRepo.ListWithStatusByDocID(ctx, docID) - if err != nil { - logger.Logger.Error("Failed to get expected signers for reminders", - "doc_id", docID, - "error", err.Error()) - return nil, fmt.Errorf("failed to get expected signers: %w", err) - } - - logger.Logger.Debug("Retrieved expected signers", - "doc_id", docID, - "total_signers", len(allSigners)) - - var pendingSigners []*models.ExpectedSignerWithStatus - for _, signer := range allSigners { - if !signer.HasSigned { - if len(specificEmails) > 0 { - if containsEmail(specificEmails, signer.Email) { - pendingSigners = append(pendingSigners, signer) - } - } else { - pendingSigners = append(pendingSigners, signer) - } - } - } - - logger.Logger.Info("Identified pending signers", - "doc_id", docID, - "pending_count", len(pendingSigners), - "total_signers", len(allSigners)) - - if len(pendingSigners) == 0 { - logger.Logger.Info("No pending signers found, no reminders to send", - "doc_id", docID) - return &models.ReminderSendResult{ - TotalAttempted: 0, - SuccessfullySent: 0, - Failed: 0, - }, nil - } - - result := &models.ReminderSendResult{ - TotalAttempted: len(pendingSigners), - } - - for _, signer := range pendingSigners { - err := s.sendSingleReminder(ctx, docID, signer.Email, signer.Name, sentBy, docURL, locale) - if err != nil { - result.Failed++ - result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", signer.Email, err)) - } else { - result.SuccessfullySent++ - } - } - - logger.Logger.Info("Reminder batch completed", - "doc_id", docID, - "total_attempted", result.TotalAttempted, - "successfully_sent", result.SuccessfullySent, - "failed", result.Failed) - - return result, nil -} - -// sendSingleReminder sends a reminder to a single signer -func (s *ReminderService) sendSingleReminder( - ctx context.Context, - docID string, - recipientEmail string, - recipientName string, - sentBy string, - docURL string, - locale string, -) error { - - logger.Logger.Debug("Sending reminder to signer", - "doc_id", docID, - "recipient_email", recipientEmail, - "recipient_name", recipientName, - "sent_by", sentBy) - - // Générer un token d'authentification pour ce lecteur - token, err := s.magicLinkService.CreateReminderAuthToken(ctx, recipientEmail, docID) - if err != nil { - logger.Logger.Error("Failed to create reminder auth token", - "doc_id", docID, - "recipient_email", recipientEmail, - "error", err.Error()) - return fmt.Errorf("failed to create auth token: %w", err) - } - - // Construire l'URL d'authentification qui redirigera vers la page de signature - authSignURL := fmt.Sprintf("%s/api/v1/auth/reminder-link/verify?token=%s", s.baseURL, token) - - logger.Logger.Debug("Generated auth sign URL for reminder", - "doc_id", docID, - "recipient_email", recipientEmail, - "url", authSignURL) - - log := &models.ReminderLog{ - DocID: docID, - RecipientEmail: recipientEmail, - SentAt: time.Now(), - SentBy: sentBy, - TemplateUsed: "signature_reminder", - Status: "sent", - } - - err = email.SendSignatureReminderEmail(ctx, s.emailSender, s.i18n, []string{recipientEmail}, locale, docID, docURL, authSignURL, recipientName) - if err != nil { - log.Status = "failed" - errMsg := err.Error() - log.ErrorMessage = &errMsg - - logger.Logger.Warn("Failed to send reminder email", - "doc_id", docID, - "recipient_email", recipientEmail, - "error", err.Error()) - - if logErr := s.reminderRepo.LogReminder(ctx, log); logErr != nil { - logger.Logger.Error("Failed to log reminder error", - "doc_id", docID, - "recipient_email", recipientEmail, - "log_error", logErr.Error(), - "original_error", err.Error()) - } - - return fmt.Errorf("failed to send email: %w", err) - } - - logger.Logger.Info("Reminder email sent successfully", - "doc_id", docID, - "recipient_email", recipientEmail) - - if err := s.reminderRepo.LogReminder(ctx, log); err != nil { - logger.Logger.Error("Failed to log successful reminder", - "doc_id", docID, - "recipient_email", recipientEmail, - "error", err.Error()) - return fmt.Errorf("email sent but failed to log: %w", err) - } - - return nil -} - -// GetReminderStats retrieves aggregated reminder metrics for monitoring dashboard -func (s *ReminderService) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) { - return s.reminderRepo.GetReminderStats(ctx, docID) -} - -// GetReminderHistory retrieves complete email send log with success/failure tracking -func (s *ReminderService) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) { - return s.reminderRepo.GetReminderHistory(ctx, docID) -} - -func containsEmail(slice []string, item string) bool { - for _, s := range slice { - if s == item { - return true - } - } - return false -} diff --git a/backend/internal/application/services/reminder_async.go b/backend/internal/application/services/reminder_async.go index d3020b9..b83f493 100644 --- a/backend/internal/application/services/reminder_async.go +++ b/backend/internal/application/services/reminder_async.go @@ -315,3 +315,12 @@ func (s *ReminderAsyncService) CountSent(ctx context.Context) int { } return c } + +func containsEmail(emails []string, target string) bool { + for _, e := range emails { + if e == target { + return true + } + } + return false +} diff --git a/backend/internal/application/services/reminder_test.go b/backend/internal/application/services/reminder_test.go deleted file mode 100644 index c98a1e2..0000000 --- a/backend/internal/application/services/reminder_test.go +++ /dev/null @@ -1,536 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package services - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/btouchard/ackify-ce/backend/internal/infrastructure/email" - "github.com/btouchard/ackify-ce/backend/pkg/models" -) - -// Mock implementations for testing -type mockExpectedSignerRepository struct { - listWithStatusFunc func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) -} - -func (m *mockExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - if m.listWithStatusFunc != nil { - return m.listWithStatusFunc(ctx, docID) - } - return nil, nil -} - -type mockReminderRepository struct { - logReminderFunc func(ctx context.Context, log *models.ReminderLog) error - getReminderHistoryFunc func(ctx context.Context, docID string) ([]*models.ReminderLog, error) - getReminderStatsFunc func(ctx context.Context, docID string) (*models.ReminderStats, error) -} - -func (m *mockReminderRepository) LogReminder(ctx context.Context, log *models.ReminderLog) error { - if m.logReminderFunc != nil { - return m.logReminderFunc(ctx, log) - } - return nil -} - -func (m *mockReminderRepository) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) { - if m.getReminderHistoryFunc != nil { - return m.getReminderHistoryFunc(ctx, docID) - } - return nil, nil -} - -func (m *mockReminderRepository) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) { - if m.getReminderStatsFunc != nil { - return m.getReminderStatsFunc(ctx, docID) - } - return nil, nil -} - -type mockEmailSender struct { - sendFunc func(ctx context.Context, msg email.Message) error -} - -func (m *mockEmailSender) Send(ctx context.Context, msg email.Message) error { - if m.sendFunc != nil { - return m.sendFunc(ctx, msg) - } - return nil -} - -type mockMagicLinkService struct { - createReminderAuthTokenFunc func(ctx context.Context, email string, docID string) (string, error) -} - -func (m *mockMagicLinkService) CreateReminderAuthToken(ctx context.Context, email string, docID string) (string, error) { - if m.createReminderAuthTokenFunc != nil { - return m.createReminderAuthTokenFunc(ctx, email, docID) - } - return "mock-token-123", nil -} - -// Test helper function -func TestContainsEmail(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - slice []string - item string - expected bool - }{ - { - name: "Email found", - slice: []string{"alice@example.com", "bob@example.com", "charlie@example.com"}, - item: "bob@example.com", - expected: true, - }, - { - name: "Email not found", - slice: []string{"alice@example.com", "bob@example.com"}, - item: "charlie@example.com", - expected: false, - }, - { - name: "Empty slice", - slice: []string{}, - item: "test@example.com", - expected: false, - }, - { - name: "Case sensitive", - slice: []string{"Test@Example.com"}, - item: "test@example.com", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - result := containsEmail(tt.slice, tt.item) - if result != tt.expected { - t.Errorf("containsEmail(%v, %q) = %v, want %v", tt.slice, tt.item, result, tt.expected) - } - }) - } -} - -// Test SendReminders with no pending signers -func TestReminderService_SendReminders_NoPendingSigners(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mockExpectedRepo := &mockExpectedSignerRepository{ - listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - return []*models.ExpectedSignerWithStatus{ - {ExpectedSigner: models.ExpectedSigner{Email: "signed@example.com"}, HasSigned: true}, - }, nil - }, - } - - mockReminderRepo := &mockReminderRepository{} - mockEmailSender := &mockEmailSender{} - mockMagicLink := &mockMagicLinkService{} - - service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com") - - result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en") - - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if result.TotalAttempted != 0 { - t.Errorf("Expected 0 total attempted, got %d", result.TotalAttempted) - } - - if result.SuccessfullySent != 0 { - t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent) - } -} - -// Test SendReminders with successful email send -func TestReminderService_SendReminders_Success(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mockExpectedRepo := &mockExpectedSignerRepository{ - listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - return []*models.ExpectedSignerWithStatus{ - {ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false}, - }, nil - }, - } - - loggedReminder := false - mockReminderRepo := &mockReminderRepository{ - logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error { - loggedReminder = true - if log.Status != "sent" { - t.Errorf("Expected status 'sent', got '%s'", log.Status) - } - return nil - }, - } - - emailSent := false - mockEmailSender := &mockEmailSender{ - sendFunc: func(ctx context.Context, msg email.Message) error { - emailSent = true - if len(msg.To) != 1 || msg.To[0] != "pending@example.com" { - t.Errorf("Expected email to 'pending@example.com', got %v", msg.To) - } - return nil - }, - } - mockMagicLink := &mockMagicLinkService{} - - service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com") - - result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en") - - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if result.TotalAttempted != 1 { - t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted) - } - - if result.SuccessfullySent != 1 { - t.Errorf("Expected 1 successfully sent, got %d", result.SuccessfullySent) - } - - if result.Failed != 0 { - t.Errorf("Expected 0 failed, got %d", result.Failed) - } - - if !emailSent { - t.Error("Expected email to be sent") - } - - if !loggedReminder { - t.Error("Expected reminder to be logged") - } -} - -// Test SendReminders with email failure -func TestReminderService_SendReminders_EmailFailure(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mockExpectedRepo := &mockExpectedSignerRepository{ - listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - return []*models.ExpectedSignerWithStatus{ - {ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false}, - }, nil - }, - } - - loggedReminder := false - mockReminderRepo := &mockReminderRepository{ - logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error { - loggedReminder = true - if log.Status != "failed" { - t.Errorf("Expected status 'failed', got '%s'", log.Status) - } - if log.ErrorMessage == nil { - t.Error("Expected error message to be set") - } - return nil - }, - } - - mockEmailSender := &mockEmailSender{ - sendFunc: func(ctx context.Context, msg email.Message) error { - return errors.New("SMTP connection failed") - }, - } - mockMagicLink := &mockMagicLinkService{} - - service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com") - - result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en") - - if err != nil { - t.Fatalf("Expected no error from SendReminders, got: %v", err) - } - - if result.TotalAttempted != 1 { - t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted) - } - - if result.Failed != 1 { - t.Errorf("Expected 1 failed, got %d", result.Failed) - } - - if result.SuccessfullySent != 0 { - t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent) - } - - if len(result.Errors) != 1 { - t.Errorf("Expected 1 error message, got %d", len(result.Errors)) - } - - if !loggedReminder { - t.Error("Expected failed reminder to be logged") - } -} - -// Test SendReminders with specific emails filter -func TestReminderService_SendReminders_SpecificEmails(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mockExpectedRepo := &mockExpectedSignerRepository{ - listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - return []*models.ExpectedSignerWithStatus{ - {ExpectedSigner: models.ExpectedSigner{Email: "pending1@example.com"}, HasSigned: false}, - {ExpectedSigner: models.ExpectedSigner{Email: "pending2@example.com"}, HasSigned: false}, - {ExpectedSigner: models.ExpectedSigner{Email: "pending3@example.com"}, HasSigned: false}, - }, nil - }, - } - - emailsSent := []string{} - mockReminderRepo := &mockReminderRepository{ - logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error { - return nil - }, - } - - mockEmailSender := &mockEmailSender{ - sendFunc: func(ctx context.Context, msg email.Message) error { - emailsSent = append(emailsSent, msg.To[0]) - return nil - }, - } - mockMagicLink := &mockMagicLinkService{} - - service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com") - - specificEmails := []string{"pending2@example.com"} - result, err := service.SendReminders(ctx, "doc1", "admin@example.com", specificEmails, "https://example.com/doc.pdf", "en") - - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if result.TotalAttempted != 1 { - t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted) - } - - if len(emailsSent) != 1 || emailsSent[0] != "pending2@example.com" { - t.Errorf("Expected only 'pending2@example.com' to receive email, got %v", emailsSent) - } -} - -// Test SendReminders with repository error -func TestReminderService_SendReminders_RepositoryError(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mockExpectedRepo := &mockExpectedSignerRepository{ - listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - return nil, errors.New("database connection failed") - }, - } - - mockReminderRepo := &mockReminderRepository{} - mockEmailSender := &mockEmailSender{} - mockMagicLink := &mockMagicLinkService{} - - service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com") - - result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en") - - if err == nil { - t.Fatal("Expected error, got nil") - } - - if result != nil { - t.Errorf("Expected nil result on error, got %v", result) - } -} - -// Test GetReminderHistory -func TestReminderService_GetReminderHistory(t *testing.T) { - t.Parallel() - ctx := context.Background() - - expectedLogs := []*models.ReminderLog{ - { - DocID: "doc1", - RecipientEmail: "user@example.com", - SentAt: time.Now(), - SentBy: "admin@example.com", - Status: "sent", - }, - } - - mockReminderRepo := &mockReminderRepository{ - getReminderHistoryFunc: func(ctx context.Context, docID string) ([]*models.ReminderLog, error) { - if docID != "doc1" { - t.Errorf("Expected docID 'doc1', got '%s'", docID) - } - return expectedLogs, nil - }, - } - - service := NewReminderService(&mockExpectedSignerRepository{}, mockReminderRepo, &mockEmailSender{}, &mockMagicLinkService{}, nil, "https://example.com") - - logs, err := service.GetReminderHistory(ctx, "doc1") - - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if len(logs) != 1 { - t.Errorf("Expected 1 log, got %d", len(logs)) - } - - if logs[0].RecipientEmail != "user@example.com" { - t.Errorf("Expected recipient 'user@example.com', got '%s'", logs[0].RecipientEmail) - } -} - -// Test GetReminderStats -func TestReminderService_GetReminderStats(t *testing.T) { - t.Parallel() - ctx := context.Background() - - now := time.Now() - expectedStats := &models.ReminderStats{ - TotalSent: 5, - LastSentAt: &now, - PendingCount: 2, - } - - mockReminderRepo := &mockReminderRepository{ - getReminderStatsFunc: func(ctx context.Context, docID string) (*models.ReminderStats, error) { - if docID != "doc1" { - t.Errorf("Expected docID 'doc1', got '%s'", docID) - } - return expectedStats, nil - }, - } - - service := NewReminderService(&mockExpectedSignerRepository{}, mockReminderRepo, &mockEmailSender{}, &mockMagicLinkService{}, nil, "https://example.com") - - stats, err := service.GetReminderStats(ctx, "doc1") - - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if stats.TotalSent != 5 { - t.Errorf("Expected 5 total sent, got %d", stats.TotalSent) - } - - if stats.PendingCount != 2 { - t.Errorf("Expected 2 pending, got %d", stats.PendingCount) - } - - if stats.LastSentAt == nil { - t.Error("Expected LastSentAt to be set") - } -} - -// Test SendReminders with multiple pending signers -func TestReminderService_SendReminders_MultiplePending(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mockExpectedRepo := &mockExpectedSignerRepository{ - listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - return []*models.ExpectedSignerWithStatus{ - {ExpectedSigner: models.ExpectedSigner{Email: "pending1@example.com", Name: "User 1"}, HasSigned: false}, - {ExpectedSigner: models.ExpectedSigner{Email: "pending2@example.com", Name: "User 2"}, HasSigned: false}, - {ExpectedSigner: models.ExpectedSigner{Email: "already-signed@example.com", Name: "User 3"}, HasSigned: true}, - }, nil - }, - } - - emailsSent := 0 - mockReminderRepo := &mockReminderRepository{ - logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error { - return nil - }, - } - - mockEmailSender := &mockEmailSender{ - sendFunc: func(ctx context.Context, msg email.Message) error { - emailsSent++ - return nil - }, - } - mockMagicLink := &mockMagicLinkService{} - - service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com") - - result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en") - - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if result.TotalAttempted != 2 { - t.Errorf("Expected 2 total attempted, got %d", result.TotalAttempted) - } - - if result.SuccessfullySent != 2 { - t.Errorf("Expected 2 successfully sent, got %d", result.SuccessfullySent) - } - - if emailsSent != 2 { - t.Errorf("Expected 2 emails sent, got %d", emailsSent) - } -} - -// Test SendReminders with log failure after successful email -func TestReminderService_SendReminders_LogFailure(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mockExpectedRepo := &mockExpectedSignerRepository{ - listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) { - return []*models.ExpectedSignerWithStatus{ - {ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false}, - }, nil - }, - } - - mockReminderRepo := &mockReminderRepository{ - logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error { - return errors.New("database write failed") - }, - } - - mockEmailSender := &mockEmailSender{ - sendFunc: func(ctx context.Context, msg email.Message) error { - return nil // Email succeeds - }, - } - mockMagicLink := &mockMagicLinkService{} - - service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com") - - result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en") - - if err != nil { - t.Fatalf("Expected no error from SendReminders, got: %v", err) - } - - // The send should fail because logging failed - if result.Failed != 1 { - t.Errorf("Expected 1 failed, got %d", result.Failed) - } - - if result.SuccessfullySent != 0 { - t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent) - } -} diff --git a/backend/internal/application/services/signature_test.go b/backend/internal/application/services/signature_test.go index fd04ff3..82313e9 100644 --- a/backend/internal/application/services/signature_test.go +++ b/backend/internal/application/services/signature_test.go @@ -172,6 +172,64 @@ func (f *fakeCryptoSigner) CreateSignature(ctx context.Context, docID string, us return payloadHash, signature, nil } +type fakeDocumentRepository struct { + documents map[string]*models.Document +} + +func newFakeDocumentRepository() *fakeDocumentRepository { + return &fakeDocumentRepository{ + documents: make(map[string]*models.Document), + } +} + +func (f *fakeDocumentRepository) Create(_ context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) { + doc := &models.Document{ + DocID: docID, + Title: input.Title, + URL: input.URL, + Checksum: input.Checksum, + ChecksumAlgorithm: input.ChecksumAlgorithm, + CreatedBy: createdBy, + } + f.documents[docID] = doc + return doc, nil +} + +func (f *fakeDocumentRepository) GetByDocID(_ context.Context, docID string) (*models.Document, error) { + if doc, ok := f.documents[docID]; ok { + return doc, nil + } + return nil, nil +} + +func (f *fakeDocumentRepository) FindByReference(_ context.Context, _ string, _ string) (*models.Document, error) { + return nil, nil +} + +func (f *fakeDocumentRepository) List(_ context.Context, _, _ int) ([]*models.Document, error) { + return nil, nil +} + +func (f *fakeDocumentRepository) Search(_ context.Context, _ string, _, _ int) ([]*models.Document, error) { + return nil, nil +} + +func (f *fakeDocumentRepository) Count(_ context.Context, _ string) (int, error) { + return 0, nil +} + +func (f *fakeDocumentRepository) ListByCreatedBy(_ context.Context, _ string, _, _ int) ([]*models.Document, error) { + return nil, nil +} + +func (f *fakeDocumentRepository) SearchByCreatedBy(_ context.Context, _, _ string, _, _ int) ([]*models.Document, error) { + return nil, nil +} + +func (f *fakeDocumentRepository) CountByCreatedBy(_ context.Context, _, _ string) (int, error) { + return 0, nil +} + func TestNewSignatureService(t *testing.T) { repo := newFakeRepository() docRepo := newFakeDocumentRepository() diff --git a/backend/internal/infrastructure/auth/oauth_provider.go b/backend/internal/infrastructure/auth/oauth_provider.go deleted file mode 100644 index 09d9cff..0000000 --- a/backend/internal/infrastructure/auth/oauth_provider.go +++ /dev/null @@ -1,393 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package auth - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - - "github.com/gorilla/securecookie" - "golang.org/x/oauth2" - - "github.com/btouchard/ackify-ce/backend/pkg/crypto" - "github.com/btouchard/ackify-ce/backend/pkg/logger" - "github.com/btouchard/ackify-ce/backend/pkg/models" -) - -// OAuthProvider handles OAuth2 authentication flow -// This component is optional and can be nil if OAuth is disabled -type OAuthProvider struct { - oauthConfig *oauth2.Config - userInfoURL string - logoutURL string - allowedDomain string - baseURL string - sessionSvc *SessionService // Reference to session service for state management -} - -// OAuthProviderConfig holds configuration for the OAuth provider -type OAuthProviderConfig struct { - BaseURL string - ClientID string - ClientSecret string - AuthURL string - TokenURL string - UserInfoURL string - LogoutURL string - Scopes []string - AllowedDomain string - SessionSvc *SessionService -} - -// NewOAuthProvider creates a new OAuth provider -func NewOAuthProvider(config OAuthProviderConfig) *OAuthProvider { - oauthConfig := &oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - RedirectURL: config.BaseURL + "/api/v1/auth/callback", - Scopes: config.Scopes, - Endpoint: oauth2.Endpoint{ - AuthURL: config.AuthURL, - TokenURL: config.TokenURL, - }, - } - - logger.Logger.Info("OAuth provider configured successfully") - - return &OAuthProvider{ - oauthConfig: oauthConfig, - userInfoURL: config.UserInfoURL, - logoutURL: config.LogoutURL, - allowedDomain: config.AllowedDomain, - baseURL: config.BaseURL, - sessionSvc: config.SessionSvc, - } -} - -// GetLogoutURL returns the SSO logout URL if configured -func (p *OAuthProvider) GetLogoutURL() string { - if p.logoutURL == "" { - return "" - } - - logoutURL := p.logoutURL - if p.baseURL != "" { - logoutURL += "?continue=" + p.baseURL - } - - return logoutURL -} - -// CreateAuthURL creates an OAuth authorization URL with PKCE -func (p *OAuthProvider) CreateAuthURL(w http.ResponseWriter, r *http.Request, nextURL string) string { - // Generate PKCE code verifier and challenge - codeVerifier, err := crypto.GenerateCodeVerifier() - if err != nil { - logger.Logger.Error("Failed to generate PKCE code verifier", "error", err.Error()) - // Fallback to OAuth flow without PKCE for backward compatibility - return p.createAuthURLWithoutPKCE(w, r, nextURL) - } - - codeChallenge := crypto.GenerateCodeChallenge(codeVerifier) - logger.Logger.Debug("Generated PKCE parameters for OAuth flow") - - // Generate state token - randPart := securecookie.GenerateRandomKey(20) - token := base64.RawURLEncoding.EncodeToString(randPart) - state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL)) - - promptParam := "select_account" - isSilent := r.URL.Query().Get("silent") == "true" - if isSilent { - promptParam = "none" - } - - logger.Logger.Info("Starting OAuth flow with PKCE", - "next_url", nextURL, - "silent", isSilent, - "state_token_length", len(token)) - - session, err := p.sessionSvc.GetSession(r) - if err != nil { - logger.Logger.Error("CreateAuthURL: failed to get session from store", "error", err.Error()) - // Create a new empty session if Get fails - session, _ = p.sessionSvc.GetNewSession(r) - } - - // Store state and code_verifier in session - session.Values["oauth_state"] = token - session.Values["code_verifier"] = codeVerifier - - err = session.Save(r, w) - if err != nil { - logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error()) - } - - // Generate OAuth URL with PKCE parameters - authURL := p.oauthConfig.AuthCodeURL(state, - oauth2.SetAuthURLParam("prompt", promptParam), - oauth2.SetAuthURLParam("code_challenge", codeChallenge), - oauth2.SetAuthURLParam("code_challenge_method", "S256")) - - logger.Logger.Debug("CreateAuthURL: generated auth URL with PKCE", - "prompt", promptParam, - "url_length", len(authURL)) - - return authURL -} - -// createAuthURLWithoutPKCE is a fallback method for OAuth without PKCE -// Used for backward compatibility if PKCE generation fails -func (p *OAuthProvider) createAuthURLWithoutPKCE(w http.ResponseWriter, r *http.Request, nextURL string) string { - randPart := securecookie.GenerateRandomKey(20) - token := base64.RawURLEncoding.EncodeToString(randPart) - state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL)) - - promptParam := "select_account" - isSilent := r.URL.Query().Get("silent") == "true" - if isSilent { - promptParam = "none" - } - - logger.Logger.Warn("Starting OAuth flow WITHOUT PKCE (fallback mode)", - "next_url", nextURL, - "silent", isSilent) - - session, err := p.sessionSvc.GetSession(r) - if err != nil { - session, _ = p.sessionSvc.GetNewSession(r) - } - - session.Values["oauth_state"] = token - - err = session.Save(r, w) - if err != nil { - logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error()) - } - - authURL := p.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", promptParam)) - - return authURL -} - -// VerifyState validates the OAuth state token for CSRF protection -func (p *OAuthProvider) VerifyState(w http.ResponseWriter, r *http.Request, stateToken string) bool { - session, _ := p.sessionSvc.GetSession(r) - stored, _ := session.Values["oauth_state"].(string) - - logger.Logger.Debug("VerifyState: validating OAuth state", - "stored_length", len(stored), - "token_length", len(stateToken), - "stored_empty", stored == "", - "token_empty", stateToken == "") - - if stored == "" || stateToken == "" { - logger.Logger.Warn("VerifyState: empty state tokens") - return false - } - - if subtleConstantTimeCompare(stored, stateToken) { - logger.Logger.Debug("VerifyState: state valid, clearing token") - delete(session.Values, "oauth_state") - _ = session.Save(r, w) - return true - } - - logger.Logger.Warn("VerifyState: state mismatch") - return false -} - -// subtleConstantTimeCompare performs a timing-safe string comparison -func subtleConstantTimeCompare(a, b string) bool { - if len(a) != len(b) { - return false - } - var v byte - for i := 0; i < len(a); i++ { - v |= a[i] ^ b[i] - } - return v == 0 -} - -// HandleCallback processes the OAuth callback and returns the authenticated user -func (p *OAuthProvider) HandleCallback(ctx context.Context, w http.ResponseWriter, r *http.Request, code, state string) (*models.User, string, error) { - parts := strings.SplitN(state, ":", 2) - nextURL := "/" - if len(parts) == 2 { - if nb, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil { - nextURL = string(nb) - } - } - - logger.Logger.Debug("Processing OAuth callback", - "has_code", code != "", - "next_url", nextURL) - - // Retrieve code_verifier from session for PKCE - session, _ := p.sessionSvc.GetSession(r) - codeVerifier, hasPKCE := session.Values["code_verifier"].(string) - - // Clean up code_verifier immediately after retrieval - if hasPKCE { - delete(session.Values, "code_verifier") - _ = session.Save(r, w) - } - - // Exchange authorization code for token (with or without PKCE) - var token *oauth2.Token - var err error - - if hasPKCE && codeVerifier != "" { - logger.Logger.Info("OAuth token exchange with PKCE") - token, err = p.oauthConfig.Exchange(ctx, code, - oauth2.SetAuthURLParam("code_verifier", codeVerifier)) - } else { - logger.Logger.Warn("OAuth token exchange without PKCE (legacy session or fallback)") - token, err = p.oauthConfig.Exchange(ctx, code) - } - - if err != nil { - logger.Logger.Error("OAuth token exchange failed", - "error", err.Error(), - "with_pkce", hasPKCE) - return nil, nextURL, fmt.Errorf("oauth exchange failed: %w", err) - } - - logger.Logger.Info("OAuth token exchange successful", "with_pkce", hasPKCE) - - client := p.oauthConfig.Client(ctx, token) - resp, err := client.Get(p.userInfoURL) - if err != nil || resp.StatusCode != 200 { - statusCode := 0 - if resp != nil { - statusCode = resp.StatusCode - } - logger.Logger.Error("User info request failed", - "error", err, - "status_code", statusCode) - return nil, nextURL, fmt.Errorf("userinfo request failed: %w", err) - } - defer func(Body io.ReadCloser) { - _ = Body.Close() - }(resp.Body) - - logger.Logger.Debug("User info retrieved successfully", - "status_code", resp.StatusCode) - - user, err := p.parseUserInfo(resp) - if err != nil { - logger.Logger.Error("Failed to parse user info", - "error", err.Error()) - return nil, nextURL, fmt.Errorf("failed to parse user info: %w", err) - } - - if !p.IsAllowedDomain(user.Email) { - logger.Logger.Warn("User domain not allowed") - return nil, nextURL, models.ErrDomainNotAllowed - } - - logger.Logger.Info("OAuth callback successful") - - // Store refresh token if available - if token.RefreshToken != "" && p.sessionSvc.sessionRepo != nil { - if err := p.sessionSvc.StoreRefreshToken(ctx, w, r, token, user); err != nil { - // Log error but don't fail the authentication - logger.Logger.Error("Failed to store refresh token (non-fatal)", "error", err.Error()) - } - } - - return user, nextURL, nil -} - -// IsAllowedDomain checks if the user's email domain is allowed -func (p *OAuthProvider) IsAllowedDomain(email string) bool { - if p.allowedDomain == "" { - return true - } - - domain := strings.ToLower(p.allowedDomain) - // If domain already has @ prefix, don't add another one - if !strings.HasPrefix(domain, "@") { - domain = "@" + domain - } - - return strings.HasSuffix(strings.ToLower(email), domain) -} - -// parseUserInfo extracts user information from the OAuth provider response -func (p *OAuthProvider) parseUserInfo(resp *http.Response) (*models.User, error) { - var rawUser map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&rawUser); err != nil { - return nil, fmt.Errorf("failed to decode user info: %w", err) - } - - // Reduce PII in standard logs; log only keys at debug level - if rawUser != nil { - keys := make([]string, 0, len(rawUser)) - for k := range rawUser { - keys = append(keys, k) - } - logger.Logger.Debug("OAuth user info received", "keys", keys) - } - - user := &models.User{} - - if sub, ok := rawUser["sub"].(string); ok { - user.Sub = sub - } else if id, ok := rawUser["id"]; ok { - user.Sub = fmt.Sprintf("%v", id) - } else { - return nil, fmt.Errorf("missing user ID in response") - } - - // Check for email in various provider-specific fields - // - "email": Standard OIDC claim (Google, GitHub, GitLab, etc.) - // - "mail": Microsoft Graph API - // - "userPrincipalName": Microsoft fallback (UPN format) - if email, ok := rawUser["email"].(string); ok && email != "" { - user.Email = email - } else if mail, ok := rawUser["mail"].(string); ok && mail != "" { - user.Email = mail - } else if upn, ok := rawUser["userPrincipalName"].(string); ok && upn != "" { - user.Email = upn - } else { - return nil, fmt.Errorf("missing email in user info response (checked: email, mail, userPrincipalName)") - } - - // Extract display name from various provider-specific fields - var name string - if fullName, ok := rawUser["name"].(string); ok && fullName != "" { - name = fullName - } else if firstName, ok := rawUser["given_name"].(string); ok { - if lastName, ok := rawUser["family_name"].(string); ok { - name = firstName + " " + lastName - } else { - name = firstName - } - } else if displayName, ok := rawUser["displayName"].(string); ok && displayName != "" { - name = displayName - } else if cn, ok := rawUser["cn"].(string); ok && cn != "" { - name = cn - } else if displayNameSnake, ok := rawUser["display_name"].(string); ok && displayNameSnake != "" { - name = displayNameSnake - } else if preferredName, ok := rawUser["preferred_username"].(string); ok && preferredName != "" { - name = preferredName - } - - user.Name = name - - logger.Logger.Debug("Extracted OAuth user identifiers", - "has_sub", user.Sub != "", - "has_email", user.Email != "", - "has_name", user.Name != "") - - if !user.IsValid() { - return nil, fmt.Errorf("invalid user data extracted") - } - - return user, nil -} diff --git a/backend/internal/infrastructure/auth/oauth_provider_test.go b/backend/internal/infrastructure/auth/oauth_provider_test.go deleted file mode 100644 index 128ca76..0000000 --- a/backend/internal/infrastructure/auth/oauth_provider_test.go +++ /dev/null @@ -1,422 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package auth - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/btouchard/ackify-ce/backend/pkg/models" -) - -func TestOAuthProvider_IsAllowedDomain(t *testing.T) { - sessionSvc := NewSessionService(SessionServiceConfig{ - CookieSecret: []byte("32-byte-secret-for-secure-cookies"), - SecureCookies: false, - }) - - tests := []struct { - name string - allowedDomain string - email string - expected bool - }{ - { - name: "allowed domain match", - allowedDomain: "@example.com", - email: "user@example.com", - expected: true, - }, - { - name: "allowed domain mismatch", - allowedDomain: "@example.com", - email: "user@other.com", - expected: false, - }, - { - name: "no restriction", - allowedDomain: "", - email: "user@any.com", - expected: true, - }, - { - name: "case insensitive match", - allowedDomain: "@Example.Com", - email: "user@EXAMPLE.com", - expected: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - provider := &OAuthProvider{ - allowedDomain: tt.allowedDomain, - sessionSvc: sessionSvc, - } - - result := provider.IsAllowedDomain(tt.email) - if result != tt.expected { - t.Errorf("IsAllowedDomain(%v) = %v, expected %v", tt.email, result, tt.expected) - } - }) - } -} - -func TestOAuthProvider_CreateAuthURL(t *testing.T) { - sessionSvc := NewSessionService(SessionServiceConfig{ - CookieSecret: []byte("32-byte-secret-for-secure-cookies"), - SecureCookies: false, - }) - - config := OAuthProviderConfig{ - BaseURL: "https://ackify.example.com", - ClientID: "test-client", - ClientSecret: "test-secret", - AuthURL: "https://provider.com/auth", - TokenURL: "https://provider.com/token", - UserInfoURL: "https://provider.com/userinfo", - Scopes: []string{"openid", "email"}, - SessionSvc: sessionSvc, - } - - provider := NewOAuthProvider(config) - - t.Run("creates auth URL with PKCE", func(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - - authURL := provider.CreateAuthURL(rec, req, "/dashboard") - - if authURL == "" { - t.Fatal("CreateAuthURL() returned empty string") - } - - // Parse the URL and check parameters - parsedURL, err := url.Parse(authURL) - if err != nil { - t.Fatalf("Failed to parse auth URL: %v", err) - } - - query := parsedURL.Query() - - // Check for required OAuth parameters - if query.Get("client_id") == "" { - t.Error("client_id parameter missing") - } - if query.Get("redirect_uri") == "" { - t.Error("redirect_uri parameter missing") - } - if query.Get("response_type") != "code" { - t.Error("response_type should be 'code'") - } - if query.Get("state") == "" { - t.Error("state parameter missing") - } - - // Check for PKCE parameters - if query.Get("code_challenge") == "" { - t.Error("code_challenge parameter missing (PKCE should be enabled)") - } - if query.Get("code_challenge_method") != "S256" { - t.Error("code_challenge_method should be 'S256'") - } - }) - - t.Run("creates auth URL with silent flag", func(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/?silent=true", nil) - - authURL := provider.CreateAuthURL(rec, req, "/dashboard") - - parsedURL, _ := url.Parse(authURL) - query := parsedURL.Query() - - if query.Get("prompt") != "none" { - t.Errorf("prompt parameter = %v, expected 'none' for silent auth", query.Get("prompt")) - } - }) -} - -func TestOAuthProvider_VerifyState(t *testing.T) { - sessionSvc := NewSessionService(SessionServiceConfig{ - CookieSecret: []byte("32-byte-secret-for-secure-cookies"), - SecureCookies: false, - }) - - config := OAuthProviderConfig{ - BaseURL: "https://ackify.example.com", - ClientID: "test-client", - ClientSecret: "test-secret", - AuthURL: "https://provider.com/auth", - TokenURL: "https://provider.com/token", - UserInfoURL: "https://provider.com/userinfo", - Scopes: []string{"openid"}, - SessionSvc: sessionSvc, - } - - provider := NewOAuthProvider(config) - - t.Run("valid state", func(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - - // Create auth URL to generate state - _ = provider.CreateAuthURL(rec, req, "/") - - // Get session with state - req2 := httptest.NewRequest("GET", "/", nil) - for _, cookie := range rec.Result().Cookies() { - req2.AddCookie(cookie) - } - - session, _ := sessionSvc.GetSession(req2) - storedState, ok := session.Values["oauth_state"].(string) - if !ok { - t.Fatal("State not stored in session") - } - - // Verify state - rec2 := httptest.NewRecorder() - valid := provider.VerifyState(rec2, req2, storedState) - if !valid { - t.Error("VerifyState() should return true for valid state") - } - }) - - t.Run("invalid state", func(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - - valid := provider.VerifyState(rec, req, "invalid-state-token") - if valid { - t.Error("VerifyState() should return false for invalid state") - } - }) - - t.Run("empty state", func(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - - valid := provider.VerifyState(rec, req, "") - if valid { - t.Error("VerifyState() should return false for empty state") - } - }) -} - -func TestOAuthProvider_parseUserInfo(t *testing.T) { - sessionSvc := NewSessionService(SessionServiceConfig{ - CookieSecret: []byte("32-byte-secret-for-secure-cookies"), - SecureCookies: false, - }) - - provider := &OAuthProvider{ - sessionSvc: sessionSvc, - } - - tests := []struct { - name string - responseObj map[string]interface{} - wantErr bool - checkUser func(*testing.T, *models.User) - }{ - { - name: "complete user info with sub", - responseObj: map[string]interface{}{ - "sub": "12345", - "email": "user@example.com", - "name": "Test User", - }, - wantErr: false, - checkUser: func(t *testing.T, user *models.User) { - if user.Sub != "12345" { - t.Errorf("Sub = %v, expected 12345", user.Sub) - } - if user.Email != "user@example.com" { - t.Errorf("Email = %v, expected user@example.com", user.Email) - } - if user.Name != "Test User" { - t.Errorf("Name = %v, expected Test User", user.Name) - } - }, - }, - { - name: "user info with id instead of sub", - responseObj: map[string]interface{}{ - "id": 67890, - "email": "user@example.com", - }, - wantErr: false, - checkUser: func(t *testing.T, user *models.User) { - if user.Sub != "67890" { - t.Errorf("Sub = %v, expected 67890", user.Sub) - } - }, - }, - { - name: "user info with given_name and family_name", - responseObj: map[string]interface{}{ - "sub": "12345", - "email": "user@example.com", - "given_name": "John", - "family_name": "Doe", - }, - wantErr: false, - checkUser: func(t *testing.T, user *models.User) { - if user.Name != "John Doe" { - t.Errorf("Name = %v, expected 'John Doe'", user.Name) - } - }, - }, - { - name: "Microsoft Graph API - mail field", - responseObj: map[string]interface{}{ - "id": "microsoft-id-12345", - "mail": "user@company.com", - "displayName": "Microsoft User", - "userPrincipalName": "user@company.onmicrosoft.com", - }, - wantErr: false, - checkUser: func(t *testing.T, user *models.User) { - if user.Sub != "microsoft-id-12345" { - t.Errorf("Sub = %v, expected microsoft-id-12345", user.Sub) - } - if user.Email != "user@company.com" { - t.Errorf("Email = %v, expected user@company.com (from mail field)", user.Email) - } - if user.Name != "Microsoft User" { - t.Errorf("Name = %v, expected Microsoft User (from displayName)", user.Name) - } - }, - }, - { - name: "Microsoft Graph API - userPrincipalName fallback", - responseObj: map[string]interface{}{ - "id": "microsoft-id-67890", - "displayName": "UPN User", - "userPrincipalName": "user@company.onmicrosoft.com", - }, - wantErr: false, - checkUser: func(t *testing.T, user *models.User) { - if user.Email != "user@company.onmicrosoft.com" { - t.Errorf("Email = %v, expected user@company.onmicrosoft.com (from userPrincipalName)", user.Email) - } - }, - }, - { - name: "email field takes priority over mail", - responseObj: map[string]interface{}{ - "sub": "12345", - "email": "primary@example.com", - "mail": "secondary@example.com", - }, - wantErr: false, - checkUser: func(t *testing.T, user *models.User) { - if user.Email != "primary@example.com" { - t.Errorf("Email = %v, expected primary@example.com (email should take priority)", user.Email) - } - }, - }, - { - name: "missing email", - responseObj: map[string]interface{}{ - "sub": "12345", - "name": "Test User", - }, - wantErr: true, - }, - { - name: "missing sub and id", - responseObj: map[string]interface{}{ - "email": "user@example.com", - "name": "Test User", - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create mock HTTP response - jsonData, _ := json.Marshal(tt.responseObj) - resp := &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewReader(jsonData)), - } - - user, err := provider.parseUserInfo(resp) - - if tt.wantErr { - if err == nil { - t.Error("parseUserInfo() should return error") - } - return - } - - if err != nil { - t.Fatalf("parseUserInfo() unexpected error: %v", err) - } - - if user == nil { - t.Fatal("parseUserInfo() returned nil user") - } - - if tt.checkUser != nil { - tt.checkUser(t, user) - } - }) - } -} - -func TestSubtleConstantTimeCompare(t *testing.T) { - tests := []struct { - name string - a string - b string - expected bool - }{ - { - name: "equal strings", - a: "secret123", - b: "secret123", - expected: true, - }, - { - name: "different strings", - a: "secret123", - b: "secret456", - expected: false, - }, - { - name: "different lengths", - a: "short", - b: "longer string", - expected: false, - }, - { - name: "empty strings", - a: "", - b: "", - expected: true, - }, - { - name: "one empty", - a: "string", - b: "", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := subtleConstantTimeCompare(tt.a, tt.b) - if result != tt.expected { - t.Errorf("subtleConstantTimeCompare(%v, %v) = %v, expected %v", tt.a, tt.b, result, tt.expected) - } - }) - } -} diff --git a/backend/internal/infrastructure/email/helpers.go b/backend/internal/infrastructure/email/helpers.go deleted file mode 100644 index c28370d..0000000 --- a/backend/internal/infrastructure/email/helpers.go +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package email - -import ( - "context" - - "github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n" -) - -func SendEmail(ctx context.Context, sender Sender, template string, to []string, locale string, subject string, data map[string]any) error { - msg := Message{ - To: to, - Subject: subject, - Template: template, - Locale: locale, - Data: data, - } - - return sender.Send(ctx, msg) -} - -func SendSignatureReminderEmail(ctx context.Context, sender Sender, i18nService *i18n.I18n, to []string, locale, docID, docURL, signURL, recipientName string) error { - data := map[string]any{ - "DocID": docID, - "DocURL": docURL, - "SignURL": signURL, - "RecipientName": recipientName, - } - - // Get translated subject using i18n - subject := "Document Reading Confirmation Reminder" // Fallback - if i18nService != nil { - subject = i18nService.T(locale, "email.reminder.subject") - } - - return SendEmail(ctx, sender, "signature_reminder", to, locale, subject, data) -} diff --git a/backend/internal/infrastructure/email/helpers_test.go b/backend/internal/infrastructure/email/helpers_test.go deleted file mode 100644 index a9f2bc8..0000000 --- a/backend/internal/infrastructure/email/helpers_test.go +++ /dev/null @@ -1,268 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package email - -import ( - "context" - "errors" - "testing" -) - -// Mock sender for testing -type mockSender struct { - sendFunc func(ctx context.Context, msg Message) error - lastMsg *Message -} - -func (m *mockSender) Send(ctx context.Context, msg Message) error { - m.lastMsg = &msg - if m.sendFunc != nil { - return m.sendFunc(ctx, msg) - } - return nil -} - -func TestSendEmail(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - template string - to []string - locale string - subject string - data map[string]any - sendError error - expectError bool - }{ - { - name: "Send email successfully", - template: "test_template", - to: []string{"user@example.com"}, - locale: "en", - subject: "Test Subject", - data: map[string]any{ - "name": "John", - }, - sendError: nil, - expectError: false, - }, - { - name: "Send email with multiple recipients", - template: "welcome", - to: []string{"user1@example.com", "user2@example.com"}, - locale: "fr", - subject: "Bienvenue", - data: map[string]any{ - "company": "Acme Corp", - }, - sendError: nil, - expectError: false, - }, - { - name: "Send email with error", - template: "error_template", - to: []string{"user@example.com"}, - locale: "en", - subject: "Error Test", - data: nil, - sendError: errors.New("SMTP connection failed"), - expectError: true, - }, - { - name: "Send email with empty data", - template: "simple_template", - to: []string{"test@example.com"}, - locale: "en", - subject: "Simple Email", - data: map[string]any{}, - sendError: nil, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mock := &mockSender{ - sendFunc: func(ctx context.Context, msg Message) error { - return tt.sendError - }, - } - - err := SendEmail(ctx, mock, tt.template, tt.to, tt.locale, tt.subject, tt.data) - - if tt.expectError && err == nil { - t.Error("Expected error but got nil") - } - - if !tt.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - - // Verify message was constructed correctly - if mock.lastMsg == nil { - t.Fatal("Expected message to be captured") - } - - if mock.lastMsg.Template != tt.template { - t.Errorf("Expected template '%s', got '%s'", tt.template, mock.lastMsg.Template) - } - - if mock.lastMsg.Subject != tt.subject { - t.Errorf("Expected subject '%s', got '%s'", tt.subject, mock.lastMsg.Subject) - } - - if mock.lastMsg.Locale != tt.locale { - t.Errorf("Expected locale '%s', got '%s'", tt.locale, mock.lastMsg.Locale) - } - - if len(mock.lastMsg.To) != len(tt.to) { - t.Errorf("Expected %d recipients, got %d", len(tt.to), len(mock.lastMsg.To)) - } - }) - } -} - -func TestSendSignatureReminderEmail(t *testing.T) { - t.Parallel() - - // Note: When i18n is nil, all tests use the fallback subject - fallbackSubject := "Document Reading Confirmation Reminder" - - tests := []struct { - name string - to []string - locale string - docID string - docURL string - signURL string - recipientName string - expectedSubject string - sendError error - expectError bool - }{ - { - name: "Send reminder in English", - to: []string{"user@example.com"}, - locale: "en", - docID: "doc123", - docURL: "https://example.com/doc.pdf", - signURL: "https://example.com/?doc=doc123", - recipientName: "John Doe", - expectedSubject: fallbackSubject, - sendError: nil, - expectError: false, - }, - { - name: "Send reminder in French (fallback without i18n)", - to: []string{"utilisateur@exemple.fr"}, - locale: "fr", - docID: "doc456", - docURL: "https://exemple.fr/document.pdf", - signURL: "https://exemple.fr/?doc=doc456", - recipientName: "Marie Dupont", - expectedSubject: fallbackSubject, - sendError: nil, - expectError: false, - }, - { - name: "Send reminder with unknown locale defaults to fallback", - to: []string{"user@example.com"}, - locale: "es", - docID: "doc789", - docURL: "https://example.com/doc.pdf", - signURL: "https://example.com/?doc=doc789", - recipientName: "Juan Garcia", - expectedSubject: fallbackSubject, - sendError: nil, - expectError: false, - }, - { - name: "Send reminder with error", - to: []string{"user@example.com"}, - locale: "en", - docID: "doc999", - docURL: "https://example.com/doc.pdf", - signURL: "https://example.com/?doc=doc999", - recipientName: "Test User", - expectedSubject: fallbackSubject, - sendError: errors.New("email server unavailable"), - expectError: true, - }, - { - name: "Send reminder with empty recipient name", - to: []string{"user@example.com"}, - locale: "en", - docID: "doc000", - docURL: "https://example.com/doc.pdf", - signURL: "https://example.com/?doc=doc000", - recipientName: "", - expectedSubject: fallbackSubject, - sendError: nil, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctx := context.Background() - - mock := &mockSender{ - sendFunc: func(ctx context.Context, msg Message) error { - return tt.sendError - }, - } - - err := SendSignatureReminderEmail(ctx, mock, nil, tt.to, tt.locale, tt.docID, tt.docURL, tt.signURL, tt.recipientName) - - if tt.expectError && err == nil { - t.Error("Expected error but got nil") - } - - if !tt.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - - // Verify message construction - if mock.lastMsg == nil { - t.Fatal("Expected message to be captured") - } - - if mock.lastMsg.Template != "signature_reminder" { - t.Errorf("Expected template 'signature_reminder', got '%s'", mock.lastMsg.Template) - } - - if mock.lastMsg.Subject != tt.expectedSubject { - t.Errorf("Expected subject '%s', got '%s'", tt.expectedSubject, mock.lastMsg.Subject) - } - - if mock.lastMsg.Locale != tt.locale { - t.Errorf("Expected locale '%s', got '%s'", tt.locale, mock.lastMsg.Locale) - } - - // Verify data fields - if mock.lastMsg.Data == nil { - t.Fatal("Expected data to be present") - } - - if docID, ok := mock.lastMsg.Data["DocID"].(string); !ok || docID != tt.docID { - t.Errorf("Expected DocID '%s', got '%v'", tt.docID, mock.lastMsg.Data["DocID"]) - } - - if docURL, ok := mock.lastMsg.Data["DocURL"].(string); !ok || docURL != tt.docURL { - t.Errorf("Expected DocURL '%s', got '%v'", tt.docURL, mock.lastMsg.Data["DocURL"]) - } - - if signURL, ok := mock.lastMsg.Data["SignURL"].(string); !ok || signURL != tt.signURL { - t.Errorf("Expected SignURL '%s', got '%v'", tt.signURL, mock.lastMsg.Data["SignURL"]) - } - - if recipientName, ok := mock.lastMsg.Data["RecipientName"].(string); !ok || recipientName != tt.recipientName { - t.Errorf("Expected RecipientName '%s', got '%v'", tt.recipientName, mock.lastMsg.Data["RecipientName"]) - } - }) - } -} diff --git a/backend/internal/infrastructure/email/queue_helpers.go b/backend/internal/infrastructure/email/queue_helpers.go deleted file mode 100644 index cdb697c..0000000 --- a/backend/internal/infrastructure/email/queue_helpers.go +++ /dev/null @@ -1,113 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package email - -import ( - "context" - "fmt" - - "github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n" - "github.com/btouchard/ackify-ce/backend/pkg/models" -) - -// QueueSender implements the Sender interface by queuing emails instead of sending them directly -type QueueSender struct { - queueRepo QueueRepository - baseURL string -} - -// NewQueueSender creates a new queue-based email sender -func NewQueueSender(queueRepo QueueRepository, baseURL string) *QueueSender { - return &QueueSender{ - queueRepo: queueRepo, - baseURL: baseURL, - } -} - -// Send queues an email for asynchronous processing -func (q *QueueSender) Send(ctx context.Context, msg Message) error { - // Convert message data to proper format - data := msg.Data - if data == nil { - data = make(map[string]interface{}) - } - - input := models.EmailQueueInput{ - ToAddresses: msg.To, - CcAddresses: msg.Cc, - BccAddresses: msg.Bcc, - Subject: msg.Subject, - Template: msg.Template, - Locale: msg.Locale, - Data: data, - Headers: msg.Headers, - Priority: models.EmailPriorityNormal, - } - - // Set priority based on template type - switch msg.Template { - case "signature_reminder": - input.Priority = models.EmailPriorityHigh - case "welcome", "notification": - input.Priority = models.EmailPriorityNormal - default: - input.Priority = models.EmailPriorityNormal - } - - // Queue the email - _, err := q.queueRepo.Enqueue(ctx, input) - if err != nil { - return fmt.Errorf("failed to queue email: %w", err) - } - - return nil -} - -// QueueSignatureReminderEmail queues a signature reminder email -func QueueSignatureReminderEmail( - ctx context.Context, - queueRepo QueueRepository, - i18nService *i18n.I18n, - recipients []string, - locale string, - docID string, - docURL string, - signURL string, - recipientName string, - sentBy string, -) error { - data := map[string]interface{}{ - "doc_id": docID, - "doc_url": docURL, - "sign_url": signURL, - "recipient_name": recipientName, - "locale": locale, - } - - // Get translated subject using i18n - subject := "Document Reading Confirmation Reminder" // Fallback - if i18nService != nil { - subject = i18nService.T(locale, "email.reminder.subject") - } - - // Create a reference for tracking - refType := "signature_reminder" - - input := models.EmailQueueInput{ - ToAddresses: recipients, - Subject: subject, - Template: "signature_reminder", - Locale: locale, - Data: data, - Priority: models.EmailPriorityHigh, - ReferenceType: &refType, - ReferenceID: &docID, - CreatedBy: &sentBy, - } - - _, err := queueRepo.Enqueue(ctx, input) - if err != nil { - return fmt.Errorf("failed to queue signature reminder: %w", err) - } - - return nil -} diff --git a/backend/internal/presentation/api/shared/response.go b/backend/internal/presentation/api/shared/response.go index 23ab494..5e473e3 100644 --- a/backend/internal/presentation/api/shared/response.go +++ b/backend/internal/presentation/api/shared/response.go @@ -120,8 +120,3 @@ func WritePaginatedJSON(w http.ResponseWriter, data interface{}, page, limit, to WriteJSONWithMeta(w, http.StatusOK, data, meta) } - -// WriteNoContent writes a 204 No Content response -func WriteNoContent(w http.ResponseWriter) { - w.WriteHeader(http.StatusNoContent) -} diff --git a/backend/internal/presentation/api/shared/response_test.go b/backend/internal/presentation/api/shared/response_test.go index 38a26a8..66d60e8 100644 --- a/backend/internal/presentation/api/shared/response_test.go +++ b/backend/internal/presentation/api/shared/response_test.go @@ -222,19 +222,3 @@ func TestWritePaginatedJSON(t *testing.T) { }) } } - -func TestWriteNoContent(t *testing.T) { - t.Parallel() - - w := httptest.NewRecorder() - - WriteNoContent(w) - - if w.Code != http.StatusNoContent { - t.Errorf("Expected status code %d, got %d", http.StatusNoContent, w.Code) - } - - if w.Body.Len() != 0 { - t.Errorf("Expected empty body, got %d bytes", w.Body.Len()) - } -} diff --git a/backend/internal/presentation/api/shared/utils.go b/backend/internal/presentation/api/shared/utils.go deleted file mode 100644 index 0a29b49..0000000 --- a/backend/internal/presentation/api/shared/utils.go +++ /dev/null @@ -1,33 +0,0 @@ -// SPDX-License-Identifier: AGPL-3.0-or-later -package shared - -import ( - "net/http" - "strings" -) - -// GetClientIP extracts the real client IP address from the request -// It checks X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr -func GetClientIP(r *http.Request) string { - // Check X-Forwarded-For header (proxy/load balancer) - if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" { - // X-Forwarded-For can contain multiple IPs, take the first one - ips := strings.Split(forwardedFor, ",") - if len(ips) > 0 { - return strings.TrimSpace(ips[0]) - } - } - - // Check X-Real-IP header - if realIP := r.Header.Get("X-Real-IP"); realIP != "" { - return strings.TrimSpace(realIP) - } - - // Fall back to RemoteAddr - ip := r.RemoteAddr - // Remove port if present - if idx := strings.LastIndex(ip, ":"); idx != -1 { - ip = ip[:idx] - } - return ip -} diff --git a/backend/internal/presentation/api/users/handler.go b/backend/internal/presentation/api/users/handler.go index 221dbd7..a61419c 100644 --- a/backend/internal/presentation/api/users/handler.go +++ b/backend/internal/presentation/api/users/handler.go @@ -25,6 +25,7 @@ type UserDTO struct { ID string `json:"id"` Email string `json:"email"` Name string `json:"name"` + Picture string `json:"picture,omitempty"` IsAdmin bool `json:"isAdmin"` } @@ -40,6 +41,7 @@ func (h *Handler) HandleGetCurrentUser(w http.ResponseWriter, r *http.Request) { ID: user.Sub, Email: user.Email, Name: user.Name, + Picture: user.Picture, IsAdmin: h.authorizer.IsAdmin(r.Context(), user.Email), } diff --git a/backend/internal/presentation/handlers/errors.go b/backend/internal/presentation/handlers/errors.go deleted file mode 100644 index 259568c..0000000 --- a/backend/internal/presentation/handlers/errors.go +++ /dev/null @@ -1,39 +0,0 @@ -package handlers - -import ( - "errors" - "net/http" - - "github.com/btouchard/ackify-ce/backend/pkg/logger" - "github.com/btouchard/ackify-ce/backend/pkg/models" -) - -// HandleError handles different types of errors and returns appropriate HTTP responses -func HandleError(w http.ResponseWriter, err error) { - switch { - case errors.Is(err, models.ErrUnauthorized): - logger.Logger.Warn("Unauthorized access attempt", "error", err.Error()) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - case errors.Is(err, models.ErrSignatureNotFound): - logger.Logger.Debug("Signature not found", "error", err.Error()) - http.Error(w, "Signature not found", http.StatusNotFound) - case errors.Is(err, models.ErrSignatureAlreadyExists): - logger.Logger.Debug("Duplicate signature attempt", "error", err.Error()) - http.Error(w, "Signature already exists", http.StatusConflict) - case errors.Is(err, models.ErrInvalidUser): - logger.Logger.Warn("Invalid user data", "error", err.Error()) - http.Error(w, "Invalid user", http.StatusBadRequest) - case errors.Is(err, models.ErrInvalidDocument): - logger.Logger.Warn("Invalid document ID", "error", err.Error()) - http.Error(w, "Invalid document ID", http.StatusBadRequest) - case errors.Is(err, models.ErrDomainNotAllowed): - logger.Logger.Warn("Domain not allowed", "error", err.Error()) - http.Error(w, "Domain not allowed", http.StatusForbidden) - case errors.Is(err, models.ErrDatabaseConnection): - logger.Logger.Error("Database connection error", "error", err.Error()) - http.Error(w, "Database error", http.StatusInternalServerError) - default: - logger.Logger.Error("Unhandled error", "error", err.Error()) - http.Error(w, "Internal server error", http.StatusInternalServerError) - } -} diff --git a/backend/internal/presentation/handlers/handlers_test.go b/backend/internal/presentation/handlers/handlers_test.go index cc3af41..8c53e9f 100644 --- a/backend/internal/presentation/handlers/handlers_test.go +++ b/backend/internal/presentation/handlers/handlers_test.go @@ -4,8 +4,6 @@ package handlers import ( "context" "encoding/json" - "errors" - "fmt" "net/http" "net/http/httptest" "net/url" @@ -208,36 +206,6 @@ func TestHandleOEmbed_MissingDocParam(t *testing.T) { } } -func TestValidateOEmbedURL(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - urlStr string - baseURL string - expected bool - }{ - {"valid same host", "https://example.com/?doc=123", "https://example.com", true}, - {"valid with port", "https://example.com:443/?doc=123", "https://example.com", true}, - {"different host", "https://other.com/?doc=123", "https://example.com", false}, - {"localhost variations", "http://localhost:8080/?doc=123", "http://127.0.0.1:8080", true}, - {"localhost to 127.0.0.1", "http://127.0.0.1/?doc=123", "http://localhost", true}, - {"invalid URL", ":::invalid", "https://example.com", false}, - {"invalid base URL", "https://example.com/?doc=123", ":::invalid", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - result := ValidateOEmbedURL(tt.urlStr, tt.baseURL) - if result != tt.expected { - t.Errorf("Expected %v, got %v", tt.expected, result) - } - }) - } -} - // ============================================================================ // BENCHMARKS // ============================================================================ @@ -254,16 +222,6 @@ func BenchmarkHandleOEmbed(b *testing.B) { } } -func BenchmarkValidateOEmbedURL(b *testing.B) { - urlStr := "https://example.com/?doc=test123" - baseURL := "https://example.com" - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ValidateOEmbedURL(urlStr, baseURL) - } -} - // ============================================================================ // TESTS - Middleware: SecureHeaders // ============================================================================ @@ -423,126 +381,6 @@ func TestRequestLogger_DifferentMethods(t *testing.T) { } } -// ============================================================================ -// TESTS - HandleError -// ============================================================================ - -func TestHandleError_Unauthorized(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, models.ErrUnauthorized) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) - assert.Contains(t, rec.Body.String(), "Unauthorized") -} - -func TestHandleError_SignatureNotFound(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, models.ErrSignatureNotFound) - - assert.Equal(t, http.StatusNotFound, rec.Code) - assert.Contains(t, rec.Body.String(), "Signature not found") -} - -func TestHandleError_SignatureAlreadyExists(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, models.ErrSignatureAlreadyExists) - - assert.Equal(t, http.StatusConflict, rec.Code) - assert.Contains(t, rec.Body.String(), "Signature already exists") -} - -func TestHandleError_InvalidUser(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, models.ErrInvalidUser) - - assert.Equal(t, http.StatusBadRequest, rec.Code) - assert.Contains(t, rec.Body.String(), "Invalid user") -} - -func TestHandleError_InvalidDocument(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, models.ErrInvalidDocument) - - assert.Equal(t, http.StatusBadRequest, rec.Code) - assert.Contains(t, rec.Body.String(), "Invalid document ID") -} - -func TestHandleError_DomainNotAllowed(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, models.ErrDomainNotAllowed) - - assert.Equal(t, http.StatusForbidden, rec.Code) - assert.Contains(t, rec.Body.String(), "Domain not allowed") -} - -func TestHandleError_DatabaseConnection(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, models.ErrDatabaseConnection) - - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, rec.Body.String(), "Database error") -} - -func TestHandleError_UnknownError(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, errors.New("unknown error")) - - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, rec.Body.String(), "Internal server error") -} - -func TestHandleError_WrappedErrors(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - err error - expectedStatus int - expectedMsg string - }{ - { - "wrapped unauthorized", - fmt.Errorf("auth failed: %w", models.ErrUnauthorized), - http.StatusUnauthorized, - "Unauthorized", - }, - { - "wrapped domain error", - fmt.Errorf("validation failed: %w", models.ErrDomainNotAllowed), - http.StatusForbidden, - "Domain not allowed", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - rec := httptest.NewRecorder() - HandleError(rec, tt.err) - - assert.Equal(t, tt.expectedStatus, rec.Code) - assert.Contains(t, rec.Body.String(), tt.expectedMsg) - }) - } -} - // ============================================================================ // TESTS - statusRecorder // ============================================================================ @@ -616,13 +454,3 @@ func BenchmarkRequestLogger(b *testing.B) { handler.ServeHTTP(rec, req) } } - -func BenchmarkHandleError(b *testing.B) { - err := models.ErrUnauthorized - - b.ResetTimer() - for i := 0; i < b.N; i++ { - rec := httptest.NewRecorder() - HandleError(rec, err) - } -} diff --git a/backend/internal/presentation/handlers/middleware.go b/backend/internal/presentation/handlers/middleware.go index 133fa67..65f0325 100644 --- a/backend/internal/presentation/handlers/middleware.go +++ b/backend/internal/presentation/handlers/middleware.go @@ -7,47 +7,36 @@ import ( "time" "github.com/btouchard/ackify-ce/backend/pkg/logger" - "github.com/btouchard/ackify-ce/backend/pkg/models" ) -type userService interface { - GetUser(r *http.Request) (*models.User, error) -} - -type AuthMiddleware struct { - userService userService - baseURL string -} - -// SecureHeaders Enforce baseline security headers (CSP, XFO, etc.) to mitigate clickjacking, MIME sniffing, and unsafe embedding by default. +// SecureHeaders enforces baseline security headers (CSP, XFO, etc.) func SecureHeaders(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Referrer-Policy", "no-referrer") - // Check if this is an embed route - allow iframe embedding isEmbedRoute := strings.HasPrefix(r.URL.Path, "/embed/") || strings.HasPrefix(r.URL.Path, "/embed") + // OAuth provider avatar domains + imgSrc := "img-src 'self' data: https://cdn.simpleicons.org https://*.googleusercontent.com https://avatars.githubusercontent.com https://secure.gravatar.com https://gitlab.com" + if isEmbedRoute { - // Allow embedding from any origin for embed pages - // Do not set X-Frame-Options to allow iframe embedding w.Header().Set("Content-Security-Policy", "default-src 'self'; "+ "style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https://fonts.googleapis.com; "+ "font-src 'self' https://fonts.gstatic.com; "+ "script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+ - "img-src 'self' data: https://cdn.simpleicons.org; "+ + imgSrc+"; "+ "connect-src 'self'; "+ - "frame-ancestors *") // Allow embedding from any origin + "frame-ancestors *") } else { - // Strict headers for non-embed routes w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("Content-Security-Policy", "default-src 'self'; "+ "style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https://fonts.googleapis.com; "+ "font-src 'self' https://fonts.gstatic.com; "+ "script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+ - "img-src 'self' data: https://cdn.simpleicons.org; "+ + imgSrc+"; "+ "connect-src 'self'; "+ "frame-ancestors 'self'") } @@ -56,14 +45,12 @@ func SecureHeaders(next http.Handler) http.Handler { }) } -// RequestLogger Minimal structured logging without PII; record latency and status for ops visibility. func RequestLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sr := &statusRecorder{ResponseWriter: w, status: http.StatusOK} start := time.Now() next.ServeHTTP(sr, r) duration := time.Since(start) - // Minimal structured log to avoid PII logger.Logger.Info("http_request", "method", r.Method, "path", r.URL.Path, @@ -81,8 +68,3 @@ func (sr *statusRecorder) WriteHeader(code int) { sr.status = code sr.ResponseWriter.WriteHeader(code) } - -type ErrorResponse struct { - Error string `json:"error"` - Message string `json:"message,omitempty"` -} diff --git a/backend/internal/presentation/handlers/oembed.go b/backend/internal/presentation/handlers/oembed.go index e4ff921..ba98c5e 100644 --- a/backend/internal/presentation/handlers/oembed.go +++ b/backend/internal/presentation/handlers/oembed.go @@ -5,7 +5,6 @@ import ( "encoding/json" "net/http" "net/url" - "strings" "github.com/btouchard/ackify-ce/backend/pkg/logger" ) @@ -91,29 +90,3 @@ func HandleOEmbed(baseURL string) http.HandlerFunc { "remote_addr", r.RemoteAddr) } } - -// ValidateOEmbedURL checks if the provided URL is a valid Ackify document URL -func ValidateOEmbedURL(urlStr string, baseURL string) bool { - parsedURL, err := url.Parse(urlStr) - if err != nil { - return false - } - - baseURLParsed, err := url.Parse(baseURL) - if err != nil { - return false - } - - // Normalize hosts for comparison (remove ports if present) - urlHost := strings.Split(parsedURL.Host, ":")[0] - baseHost := strings.Split(baseURLParsed.Host, ":")[0] - - // Allow localhost variations - if urlHost == "localhost" || urlHost == "127.0.0.1" { - if baseHost == "localhost" || baseHost == "127.0.0.1" { - return true - } - } - - return urlHost == baseHost -} diff --git a/backend/pkg/types/user.go b/backend/pkg/types/user.go index ebb3d5a..f0dc38f 100644 --- a/backend/pkg/types/user.go +++ b/backend/pkg/types/user.go @@ -7,9 +7,10 @@ import "strings" // This is the canonical user representation used by auth providers, domain models, // and API handlers. type User struct { - Sub string `json:"sub"` // Unique identifier (OAuth sub claim or email for MagicLink) - Email string `json:"email"` // User's email address - Name string `json:"name"` // Display name (optional) + Sub string `json:"sub"` // Unique identifier (OAuth sub claim or email for MagicLink) + Email string `json:"email"` // User's email address + Name string `json:"name"` // Display name (optional) + Picture string `json:"picture,omitempty"` // Avatar URL from OAuth provider (optional) } // IsValid returns true if the user has required fields populated. diff --git a/backend/pkg/web/auth/dynamic_provider.go b/backend/pkg/web/auth/dynamic_provider.go index c8ddff2..e2079c2 100644 --- a/backend/pkg/web/auth/dynamic_provider.go +++ b/backend/pkg/web/auth/dynamic_provider.go @@ -419,6 +419,15 @@ func (p *Provider) parseUserInfo(resp *http.Response) (*types.User, error) { user.Name = preferredName } + // Extract avatar URL (picture for Google/OIDC, avatar_url for GitHub/GitLab) + if picture, ok := rawUser["picture"].(string); ok && picture != "" { + user.Picture = picture + } else if avatarURL, ok := rawUser["avatar_url"].(string); ok && avatarURL != "" { + user.Picture = avatarURL + } else if photo, ok := rawUser["photo"].(string); ok && photo != "" { + user.Picture = photo + } + return user, nil } diff --git a/backend/pkg/web/server.go b/backend/pkg/web/server.go index 051eeb7..633f574 100644 --- a/backend/pkg/web/server.go +++ b/backend/pkg/web/server.go @@ -33,7 +33,6 @@ import ( sdk "github.com/btouchard/shm/sdk/golang" ) -// Server represents the HTTP server with all its dependencies. type Server struct { httpServer *http.Server db *sql.DB @@ -90,7 +89,6 @@ type ServerBuilder struct { configService *services.ConfigService } -// NewServerBuilder creates a new server builder with the required configuration. func NewServerBuilder(cfg *config.Config, frontend embed.FS, version string) *ServerBuilder { return &ServerBuilder{ cfg: cfg, @@ -205,7 +203,6 @@ func (b *ServerBuilder) Build(ctx context.Context) (*Server, error) { }, nil } -// validateProviders checks that required providers are set. func (b *ServerBuilder) validateProviders() error { if b.db == nil { return errors.New("database is required: use WithDB()") @@ -238,7 +235,6 @@ func (b *ServerBuilder) setDefaultProviders() { } } -// initializeInfrastructure initializes signer, i18n, email and storage. func (b *ServerBuilder) initializeInfrastructure() error { var err error @@ -297,7 +293,6 @@ type repositories struct { magicLink services.MagicLinkRepository } -// createRepositories creates all repository instances. func (b *ServerBuilder) createRepositories() *repositories { return &repositories{ signature: database.NewSignatureRepository(b.db, b.tenantProvider), @@ -336,7 +331,6 @@ func (b *ServerBuilder) initializeTelemetry(ctx context.Context) error { return nil } -// initializeWebhookSystem initializes webhook publisher and worker. func (b *ServerBuilder) initializeWebhookSystem(ctx context.Context, repos *repositories) (*services.WebhookPublisher, *webhook.Worker, error) { whPublisher := services.NewWebhookPublisher(repos.webhook, repos.webhookDelivery) whCfg := webhook.DefaultWorkerConfig() @@ -349,7 +343,6 @@ func (b *ServerBuilder) initializeWebhookSystem(ctx context.Context, repos *repo return whPublisher, whWorker, nil } -// initializeEmailWorker initializes email worker for async processing. // emailRenderer is expected to be injected from main.go via WithEmailRenderer(). func (b *ServerBuilder) initializeEmailWorker(ctx context.Context, repos *repositories, whPublisher *services.WebhookPublisher) (*email.Worker, error) { if b.emailSender == nil || b.cfg.Mail.Host == "" || b.emailRenderer == nil { @@ -370,7 +363,6 @@ func (b *ServerBuilder) initializeEmailWorker(ctx context.Context, repos *reposi return emailWorker, nil } -// initializeCoreServices initializes signature, document, admin, and webhook services. func (b *ServerBuilder) initializeCoreServices(repos *repositories) { b.signatureService = services.NewSignatureService(repos.signature, repos.document, b.signer) b.signatureService.SetChecksumConfig(&b.cfg.Checksum) @@ -379,7 +371,6 @@ func (b *ServerBuilder) initializeCoreServices(repos *repositories) { b.webhookService = services.NewWebhookService(repos.webhook, repos.webhookDelivery) } -// initializeConfigService creates and initializes the configuration service. func (b *ServerBuilder) initializeConfigService(ctx context.Context, repos *repositories) error { encryptionKey := b.cfg.OAuth.CookieSecret b.configService = services.NewConfigService(repos.config, b.cfg, encryptionKey) @@ -423,7 +414,6 @@ func (b *ServerBuilder) initializeMagicLinkCleanupWorker(ctx context.Context) *w return magicLinkWorker } -// initializeReminderService initializes reminder service. func (b *ServerBuilder) initializeReminderService(repos *repositories) { b.reminderService = services.NewReminderAsyncService( repos.expectedSigner, @@ -435,7 +425,6 @@ func (b *ServerBuilder) initializeReminderService(repos *repositories) { ) } -// initializeSessionWorker initializes OAuth session cleanup worker. func (b *ServerBuilder) initializeSessionWorker(ctx context.Context, repos *repositories) (*auth.SessionWorker, error) { if repos.oauthSession == nil { return nil, nil @@ -450,7 +439,6 @@ func (b *ServerBuilder) initializeSessionWorker(ctx context.Context, repos *repo return sessionWorker, nil } -// buildRouter creates and configures the main router. func (b *ServerBuilder) buildRouter(repos *repositories, whPublisher *services.WebhookPublisher) *chi.Mux { router := chi.NewRouter() router.Use(i18n.Middleware(b.i18nService)) @@ -554,22 +542,18 @@ func (s *Server) GetDB() *sql.DB { return s.db } -// GetAuthProvider returns the auth provider. func (s *Server) GetAuthProvider() AuthProvider { return s.authProvider } -// GetAuthorizer returns the authorizer. func (s *Server) GetAuthorizer() Authorizer { return s.authorizer } -// GetQuotaEnforcer returns the quota enforcer. func (s *Server) GetQuotaEnforcer() QuotaEnforcer { return s.quotaEnforcer } -// GetAuditLogger returns the audit logger. func (s *Server) GetAuditLogger() AuditLogger { return s.auditLogger } diff --git a/webapp/src/components/DocumentCreateForm.vue b/webapp/src/components/DocumentCreateForm.vue index 4715e26..05a72c4 100644 --- a/webapp/src/components/DocumentCreateForm.vue +++ b/webapp/src/components/DocumentCreateForm.vue @@ -234,7 +234,7 @@ watch(readMode, (newMode) => {
-
+
@@ -242,13 +242,13 @@ watch(readMode, (newMode) => {
- {{ selectedFile.name }} - {{ formatFileSize(selectedFile.size) }} + {{ selectedFile.name }} + {{ formatFileSize(selectedFile.size) }}