refacto(backend): cleaning dead code

This commit is contained in:
Benjamin
2026-01-17 01:49:41 +01:00
parent 493d915fa7
commit 998d227898
26 changed files with 106 additions and 3092 deletions

View File

@@ -1,42 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package services
import (
"strings"
)
// AuthorizerService provides authorization decisions for the organization.
// In Community Edition, it uses environment variables for configuration.
type AuthorizerService struct {
adminEmails []string
onlyAdminCanCreate bool
}
// NewAuthorizerService creates a new AuthorizerService with the given configuration.
func NewAuthorizerService(adminEmails []string, onlyAdminCanCreate bool) *AuthorizerService {
// Normalize admin emails to lowercase for case-insensitive comparison
normalized := make([]string, len(adminEmails))
for i, email := range adminEmails {
normalized[i] = strings.ToLower(strings.TrimSpace(email))
}
return &AuthorizerService{
adminEmails: normalized,
onlyAdminCanCreate: onlyAdminCanCreate,
}
}
// IsAdmin checks if the given email belongs to an administrator.
func (s *AuthorizerService) IsAdmin(email string) bool {
normalizedEmail := strings.ToLower(strings.TrimSpace(email))
for _, adminEmail := range s.adminEmails {
if normalizedEmail == adminEmail {
return true
}
}
return false
}
// OnlyAdminCanCreate returns whether only administrators can create documents.
func (s *AuthorizerService) OnlyAdminCanCreate() bool {
return s.onlyAdminCanCreate
}

View File

@@ -1,218 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package services
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
// ChecksumVerificationRepository defines the interface for checksum verification persistence
type ChecksumVerificationRepository interface {
RecordVerification(ctx context.Context, verification *models.ChecksumVerification) error
GetVerificationHistory(ctx context.Context, docID string, limit int) ([]*models.ChecksumVerification, error)
GetLastVerification(ctx context.Context, docID string) (*models.ChecksumVerification, error)
}
// DocumentRepository defines the interface for document metadata operations
type DocumentRepository interface {
GetByDocID(ctx context.Context, docID string) (*models.Document, error)
}
// ChecksumService orchestrates document integrity verification with audit trail persistence
type ChecksumService struct {
verificationRepo ChecksumVerificationRepository
documentRepo DocumentRepository
}
// NewChecksumService initializes checksum verification service with required repository dependencies
func NewChecksumService(
verificationRepo ChecksumVerificationRepository,
documentRepo DocumentRepository,
) *ChecksumService {
return &ChecksumService{
verificationRepo: verificationRepo,
documentRepo: documentRepo,
}
}
// ValidateChecksumFormat ensures checksum matches expected hexadecimal length for SHA-256/SHA-512/MD5
func (s *ChecksumService) ValidateChecksumFormat(checksum, algorithm string) error {
// Remove common separators and whitespace
checksum = normalizeChecksum(checksum)
var expectedLength int
switch algorithm {
case "SHA-256":
expectedLength = 64
case "SHA-512":
expectedLength = 128
case "MD5":
expectedLength = 32
default:
return fmt.Errorf("unsupported algorithm: %s", algorithm)
}
// Check length
if len(checksum) != expectedLength {
return fmt.Errorf("invalid checksum length for %s: expected %d hexadecimal characters, got %d", algorithm, expectedLength, len(checksum))
}
// Check if it's a valid hex string
hexPattern := regexp.MustCompile("^[a-fA-F0-9]+$")
if !hexPattern.MatchString(checksum) {
return fmt.Errorf("invalid checksum format: must contain only hexadecimal characters (0-9, a-f, A-F)")
}
return nil
}
// VerifyChecksum compares calculated hash against stored reference and creates immutable audit record
func (s *ChecksumService) VerifyChecksum(ctx context.Context, docID, calculatedChecksum, verifiedBy string) (*models.ChecksumVerificationResult, error) {
// Get document metadata
doc, err := s.documentRepo.GetByDocID(ctx, docID)
if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err)
}
if doc == nil {
return nil, fmt.Errorf("document not found: %s", docID)
}
// Normalize checksums for comparison
normalizedCalculated := normalizeChecksum(calculatedChecksum)
normalizedStored := normalizeChecksum(doc.Checksum)
// Determine the algorithm to use (from document or default to SHA-256)
algorithm := doc.ChecksumAlgorithm
if algorithm == "" {
algorithm = "SHA-256"
}
// Validate the calculated checksum format
if err := s.ValidateChecksumFormat(normalizedCalculated, algorithm); err != nil {
// Record failed verification with error
errorMsg := err.Error()
verification := &models.ChecksumVerification{
DocID: docID,
VerifiedBy: verifiedBy,
VerifiedAt: time.Now(),
StoredChecksum: normalizedStored,
CalculatedChecksum: normalizedCalculated,
Algorithm: algorithm,
IsValid: false,
ErrorMessage: &errorMsg,
}
_ = s.verificationRepo.RecordVerification(ctx, verification)
return nil, fmt.Errorf("invalid checksum format: %w", err)
}
// Check if document has a reference checksum
if !doc.HasChecksum() {
result := &models.ChecksumVerificationResult{
Valid: false,
StoredChecksum: "",
CalculatedChecksum: normalizedCalculated,
Algorithm: algorithm,
Message: "No reference checksum configured for this document",
HasReferenceHash: false,
}
return result, nil
}
// Compare checksums (case-insensitive)
isValid := strings.EqualFold(normalizedCalculated, normalizedStored)
// Record verification
verification := &models.ChecksumVerification{
DocID: docID,
VerifiedBy: verifiedBy,
VerifiedAt: time.Now(),
StoredChecksum: normalizedStored,
CalculatedChecksum: normalizedCalculated,
Algorithm: algorithm,
IsValid: isValid,
ErrorMessage: nil,
}
if err := s.verificationRepo.RecordVerification(ctx, verification); err != nil {
logger.Logger.Error("Failed to record verification", "error", err.Error(), "doc_id", docID)
// Continue even if recording fails - return the result
}
var message string
if isValid {
message = "Checksums match - document integrity verified"
} else {
message = "Checksums do not match - document may have been modified"
}
result := &models.ChecksumVerificationResult{
Valid: isValid,
StoredChecksum: normalizedStored,
CalculatedChecksum: normalizedCalculated,
Algorithm: algorithm,
Message: message,
HasReferenceHash: true,
}
return result, nil
}
// GetVerificationHistory retrieves paginated audit trail of all checksum validation attempts
func (s *ChecksumService) GetVerificationHistory(ctx context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) {
if limit <= 0 {
limit = 20
}
return s.verificationRepo.GetVerificationHistory(ctx, docID, limit)
}
// GetSupportedAlgorithms returns available hash algorithms for client-side documentation
func (s *ChecksumService) GetSupportedAlgorithms() []string {
return []string{"SHA-256", "SHA-512", "MD5"}
}
// GetChecksumInfo exposes document hash metadata for public verification interfaces
func (s *ChecksumService) GetChecksumInfo(ctx context.Context, docID string) (map[string]interface{}, error) {
doc, err := s.documentRepo.GetByDocID(ctx, docID)
if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err)
}
if doc == nil {
return nil, fmt.Errorf("document not found: %s", docID)
}
algorithm := doc.ChecksumAlgorithm
if algorithm == "" {
algorithm = "SHA-256"
}
info := map[string]interface{}{
"doc_id": docID,
"has_checksum": doc.HasChecksum(),
"algorithm": algorithm,
"checksum_length": doc.GetExpectedChecksumLength(),
"supported_algorithms": s.GetSupportedAlgorithms(),
}
return info, nil
}
// normalizeChecksum removes common separators and converts to lowercase
func normalizeChecksum(checksum string) string {
// Remove spaces, hyphens, underscores
checksum = strings.ReplaceAll(checksum, " ", "")
checksum = strings.ReplaceAll(checksum, "-", "")
checksum = strings.ReplaceAll(checksum, "_", "")
checksum = strings.TrimSpace(checksum)
// Convert to lowercase for case-insensitive comparison
return strings.ToLower(checksum)
}

View File

@@ -1,483 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package services
import (
"context"
"errors"
"testing"
"time"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
type fakeVerificationRepository struct {
verifications []*models.ChecksumVerification
shouldFailRecord bool
shouldFailGetHistory bool
shouldFailGetLast bool
}
func newFakeVerificationRepository() *fakeVerificationRepository {
return &fakeVerificationRepository{
verifications: make([]*models.ChecksumVerification, 0),
}
}
func (f *fakeVerificationRepository) RecordVerification(_ context.Context, verification *models.ChecksumVerification) error {
if f.shouldFailRecord {
return errors.New("repository record failed")
}
verification.ID = int64(len(f.verifications) + 1)
f.verifications = append(f.verifications, verification)
return nil
}
func (f *fakeVerificationRepository) GetVerificationHistory(_ context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) {
if f.shouldFailGetHistory {
return nil, errors.New("repository get history failed")
}
var result []*models.ChecksumVerification
for _, v := range f.verifications {
if v.DocID == docID {
result = append(result, v)
if len(result) >= limit {
break
}
}
}
return result, nil
}
func (f *fakeVerificationRepository) GetLastVerification(_ context.Context, docID string) (*models.ChecksumVerification, error) {
if f.shouldFailGetLast {
return nil, errors.New("repository get last failed")
}
for i := len(f.verifications) - 1; i >= 0; i-- {
if f.verifications[i].DocID == docID {
return f.verifications[i], nil
}
}
return nil, nil
}
type fakeDocumentRepository struct {
documents map[string]*models.Document
shouldFailGet bool
}
func newFakeDocumentRepository() *fakeDocumentRepository {
return &fakeDocumentRepository{
documents: make(map[string]*models.Document),
}
}
func (f *fakeDocumentRepository) GetByDocID(_ context.Context, docID string) (*models.Document, error) {
if f.shouldFailGet {
return nil, errors.New("repository get failed")
}
doc, exists := f.documents[docID]
if !exists {
return nil, nil
}
return doc, nil
}
func (f *fakeDocumentRepository) Create(_ context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
if f.shouldFailGet {
return nil, errors.New("repository create failed")
}
doc := &models.Document{
DocID: docID,
Title: input.Title,
URL: input.URL,
Checksum: input.Checksum,
ChecksumAlgorithm: input.ChecksumAlgorithm,
Description: input.Description,
CreatedBy: createdBy,
}
f.documents[docID] = doc
return doc, nil
}
func (f *fakeDocumentRepository) FindByReference(_ context.Context, ref string, refType string) (*models.Document, error) {
if f.shouldFailGet {
return nil, errors.New("repository find failed")
}
for _, doc := range f.documents {
if doc.URL == ref {
return doc, nil
}
}
return nil, nil
}
func (f *fakeDocumentRepository) List(_ context.Context, _, _ int) ([]*models.Document, error) {
result := make([]*models.Document, 0, len(f.documents))
for _, doc := range f.documents {
result = append(result, doc)
}
return result, nil
}
func (f *fakeDocumentRepository) Search(_ context.Context, _ string, _, _ int) ([]*models.Document, error) {
return []*models.Document{}, nil
}
func (f *fakeDocumentRepository) Count(_ context.Context, _ string) (int, error) {
return len(f.documents), nil
}
func (f *fakeDocumentRepository) ListByCreatedBy(_ context.Context, _ string, _, _ int) ([]*models.Document, error) {
return []*models.Document{}, nil
}
func (f *fakeDocumentRepository) SearchByCreatedBy(_ context.Context, _, _ string, _, _ int) ([]*models.Document, error) {
return []*models.Document{}, nil
}
func (f *fakeDocumentRepository) CountByCreatedBy(_ context.Context, _, _ string) (int, error) {
return 0, nil
}
func TestChecksumService_ValidateChecksumFormat(t *testing.T) {
service := NewChecksumService(newFakeVerificationRepository(), newFakeDocumentRepository())
tests := []struct {
name string
checksum string
algorithm string
wantError bool
}{
{
name: "valid SHA-256",
checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
algorithm: "SHA-256",
wantError: false,
},
{
name: "valid SHA-512",
checksum: "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
algorithm: "SHA-512",
wantError: false,
},
{
name: "valid MD5",
checksum: "d41d8cd98f00b204e9800998ecf8427e",
algorithm: "MD5",
wantError: false,
},
{
name: "valid with uppercase",
checksum: "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855",
algorithm: "SHA-256",
wantError: false,
},
{
name: "valid with spaces",
checksum: "e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855",
algorithm: "SHA-256",
wantError: false,
},
{
name: "valid with hyphens",
checksum: "e3b0c442-98fc1c14-9afbf4c8-996fb924-27ae41e4-649b934c-a495991b-7852b855",
algorithm: "SHA-256",
wantError: false,
},
{
name: "invalid - too short for SHA-256",
checksum: "abc123",
algorithm: "SHA-256",
wantError: true,
},
{
name: "invalid - too long for MD5",
checksum: "d41d8cd98f00b204e9800998ecf8427eextra",
algorithm: "MD5",
wantError: true,
},
{
name: "invalid - non-hex characters",
checksum: "gggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg",
algorithm: "SHA-256",
wantError: true,
},
{
name: "invalid - unsupported algorithm",
checksum: "abc123",
algorithm: "SHA-1",
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := service.ValidateChecksumFormat(tt.checksum, tt.algorithm)
if tt.wantError && err == nil {
t.Error("expected error, got nil")
}
if !tt.wantError && err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestChecksumService_VerifyChecksum(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
docID string
document *models.Document
calculatedChecksum string
verifiedBy string
wantValid bool
wantHasReference bool
wantError bool
}{
{
name: "valid verification - checksums match",
docID: "doc-001",
document: &models.Document{
DocID: "doc-001",
Checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
ChecksumAlgorithm: "SHA-256",
},
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
verifiedBy: "user@example.com",
wantValid: true,
wantHasReference: true,
wantError: false,
},
{
name: "invalid verification - checksums differ",
docID: "doc-002",
document: &models.Document{
DocID: "doc-002",
Checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
ChecksumAlgorithm: "SHA-256",
},
calculatedChecksum: "0000000000000000000000000000000000000000000000000000000000000000",
verifiedBy: "user@example.com",
wantValid: false,
wantHasReference: true,
wantError: false,
},
{
name: "no reference checksum",
docID: "doc-003",
document: &models.Document{
DocID: "doc-003",
Checksum: "",
ChecksumAlgorithm: "SHA-256",
},
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
verifiedBy: "user@example.com",
wantValid: false,
wantHasReference: false,
wantError: false,
},
{
name: "case insensitive comparison",
docID: "doc-004",
document: &models.Document{
DocID: "doc-004",
Checksum: "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855",
ChecksumAlgorithm: "SHA-256",
},
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
verifiedBy: "user@example.com",
wantValid: true,
wantHasReference: true,
wantError: false,
},
{
name: "document not found",
docID: "non-existent",
document: nil,
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
verifiedBy: "user@example.com",
wantValid: false,
wantHasReference: false,
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
verificationRepo := newFakeVerificationRepository()
documentRepo := newFakeDocumentRepository()
if tt.document != nil {
documentRepo.documents[tt.docID] = tt.document
}
service := NewChecksumService(verificationRepo, documentRepo)
result, err := service.VerifyChecksum(ctx, tt.docID, tt.calculatedChecksum, tt.verifiedBy)
if tt.wantError {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if result.Valid != tt.wantValid {
t.Errorf("expected Valid=%v, got %v", tt.wantValid, result.Valid)
}
if result.HasReferenceHash != tt.wantHasReference {
t.Errorf("expected HasReferenceHash=%v, got %v", tt.wantHasReference, result.HasReferenceHash)
}
// Check that verification was recorded (if document has checksum)
if tt.wantHasReference {
if len(verificationRepo.verifications) != 1 {
t.Errorf("expected 1 verification recorded, got %d", len(verificationRepo.verifications))
} else {
v := verificationRepo.verifications[0]
if v.IsValid != tt.wantValid {
t.Errorf("recorded verification IsValid=%v, expected %v", v.IsValid, tt.wantValid)
}
if v.VerifiedBy != tt.verifiedBy {
t.Errorf("recorded verification VerifiedBy=%s, expected %s", v.VerifiedBy, tt.verifiedBy)
}
}
}
})
}
}
func TestChecksumService_GetVerificationHistory(t *testing.T) {
ctx := context.Background()
verificationRepo := newFakeVerificationRepository()
documentRepo := newFakeDocumentRepository()
service := NewChecksumService(verificationRepo, documentRepo)
// Add test verifications
for i := 0; i < 5; i++ {
v := &models.ChecksumVerification{
DocID: "doc-001",
VerifiedBy: "user@example.com",
VerifiedAt: time.Now(),
StoredChecksum: "abc123",
CalculatedChecksum: "abc123",
Algorithm: "SHA-256",
IsValid: true,
}
_ = verificationRepo.RecordVerification(ctx, v)
}
// Test get all
history, err := service.GetVerificationHistory(ctx, "doc-001", 10)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(history) != 5 {
t.Errorf("expected 5 verifications, got %d", len(history))
}
// Test with limit
limited, err := service.GetVerificationHistory(ctx, "doc-001", 2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(limited) != 2 {
t.Errorf("expected 2 verifications with limit, got %d", len(limited))
}
}
func TestChecksumService_GetChecksumInfo(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
docID string
document *models.Document
wantError bool
}{
{
name: "document with checksum",
docID: "doc-001",
document: &models.Document{
DocID: "doc-001",
Checksum: "abc123",
ChecksumAlgorithm: "SHA-256",
},
wantError: false,
},
{
name: "document without checksum",
docID: "doc-002",
document: &models.Document{
DocID: "doc-002",
Checksum: "",
ChecksumAlgorithm: "SHA-256",
},
wantError: false,
},
{
name: "document not found",
docID: "non-existent",
document: nil,
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
documentRepo := newFakeDocumentRepository()
if tt.document != nil {
documentRepo.documents[tt.docID] = tt.document
}
service := NewChecksumService(newFakeVerificationRepository(), documentRepo)
info, err := service.GetChecksumInfo(ctx, tt.docID)
if tt.wantError {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if info["doc_id"] != tt.docID {
t.Errorf("expected doc_id %s, got %v", tt.docID, info["doc_id"])
}
if _, ok := info["has_checksum"]; !ok {
t.Error("expected has_checksum field")
}
if _, ok := info["supported_algorithms"]; !ok {
t.Error("expected supported_algorithms field")
}
})
}
}

View File

@@ -1,238 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package services
import (
"context"
"fmt"
"time"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/email"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
// expectedSignerRepository defines minimal interface for expected signer operations
type expectedSignerRepository interface {
ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error)
}
// reminderRepository defines minimal interface for reminder logging and history
type reminderRepository interface {
LogReminder(ctx context.Context, log *models.ReminderLog) error
GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error)
GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error)
}
// magicLinkService defines minimal interface for creating reminder auth tokens
type magicLinkService interface {
CreateReminderAuthToken(ctx context.Context, email string, docID string) (string, error)
}
// ReminderService manages email notifications to pending signers with delivery tracking
type ReminderService struct {
expectedSignerRepo expectedSignerRepository
reminderRepo reminderRepository
emailSender email.Sender
magicLinkService magicLinkService
i18n *i18n.I18n
baseURL string
}
// NewReminderService initializes reminder service with email sender and repository dependencies
func NewReminderService(
expectedSignerRepo expectedSignerRepository,
reminderRepo reminderRepository,
emailSender email.Sender,
magicLinkService magicLinkService,
i18nService *i18n.I18n,
baseURL string,
) *ReminderService {
return &ReminderService{
expectedSignerRepo: expectedSignerRepo,
reminderRepo: reminderRepo,
emailSender: emailSender,
magicLinkService: magicLinkService,
i18n: i18nService,
baseURL: baseURL,
}
}
// SendReminders dispatches email notifications to all or selected pending signers with result aggregation
func (s *ReminderService) SendReminders(
ctx context.Context,
docID string,
sentBy string,
specificEmails []string,
docURL string,
locale string,
) (*models.ReminderSendResult, error) {
logger.Logger.Info("Starting reminder sending process",
"doc_id", docID,
"sent_by", sentBy,
"specific_emails_count", len(specificEmails),
"locale", locale)
allSigners, err := s.expectedSignerRepo.ListWithStatusByDocID(ctx, docID)
if err != nil {
logger.Logger.Error("Failed to get expected signers for reminders",
"doc_id", docID,
"error", err.Error())
return nil, fmt.Errorf("failed to get expected signers: %w", err)
}
logger.Logger.Debug("Retrieved expected signers",
"doc_id", docID,
"total_signers", len(allSigners))
var pendingSigners []*models.ExpectedSignerWithStatus
for _, signer := range allSigners {
if !signer.HasSigned {
if len(specificEmails) > 0 {
if containsEmail(specificEmails, signer.Email) {
pendingSigners = append(pendingSigners, signer)
}
} else {
pendingSigners = append(pendingSigners, signer)
}
}
}
logger.Logger.Info("Identified pending signers",
"doc_id", docID,
"pending_count", len(pendingSigners),
"total_signers", len(allSigners))
if len(pendingSigners) == 0 {
logger.Logger.Info("No pending signers found, no reminders to send",
"doc_id", docID)
return &models.ReminderSendResult{
TotalAttempted: 0,
SuccessfullySent: 0,
Failed: 0,
}, nil
}
result := &models.ReminderSendResult{
TotalAttempted: len(pendingSigners),
}
for _, signer := range pendingSigners {
err := s.sendSingleReminder(ctx, docID, signer.Email, signer.Name, sentBy, docURL, locale)
if err != nil {
result.Failed++
result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", signer.Email, err))
} else {
result.SuccessfullySent++
}
}
logger.Logger.Info("Reminder batch completed",
"doc_id", docID,
"total_attempted", result.TotalAttempted,
"successfully_sent", result.SuccessfullySent,
"failed", result.Failed)
return result, nil
}
// sendSingleReminder sends a reminder to a single signer
func (s *ReminderService) sendSingleReminder(
ctx context.Context,
docID string,
recipientEmail string,
recipientName string,
sentBy string,
docURL string,
locale string,
) error {
logger.Logger.Debug("Sending reminder to signer",
"doc_id", docID,
"recipient_email", recipientEmail,
"recipient_name", recipientName,
"sent_by", sentBy)
// Générer un token d'authentification pour ce lecteur
token, err := s.magicLinkService.CreateReminderAuthToken(ctx, recipientEmail, docID)
if err != nil {
logger.Logger.Error("Failed to create reminder auth token",
"doc_id", docID,
"recipient_email", recipientEmail,
"error", err.Error())
return fmt.Errorf("failed to create auth token: %w", err)
}
// Construire l'URL d'authentification qui redirigera vers la page de signature
authSignURL := fmt.Sprintf("%s/api/v1/auth/reminder-link/verify?token=%s", s.baseURL, token)
logger.Logger.Debug("Generated auth sign URL for reminder",
"doc_id", docID,
"recipient_email", recipientEmail,
"url", authSignURL)
log := &models.ReminderLog{
DocID: docID,
RecipientEmail: recipientEmail,
SentAt: time.Now(),
SentBy: sentBy,
TemplateUsed: "signature_reminder",
Status: "sent",
}
err = email.SendSignatureReminderEmail(ctx, s.emailSender, s.i18n, []string{recipientEmail}, locale, docID, docURL, authSignURL, recipientName)
if err != nil {
log.Status = "failed"
errMsg := err.Error()
log.ErrorMessage = &errMsg
logger.Logger.Warn("Failed to send reminder email",
"doc_id", docID,
"recipient_email", recipientEmail,
"error", err.Error())
if logErr := s.reminderRepo.LogReminder(ctx, log); logErr != nil {
logger.Logger.Error("Failed to log reminder error",
"doc_id", docID,
"recipient_email", recipientEmail,
"log_error", logErr.Error(),
"original_error", err.Error())
}
return fmt.Errorf("failed to send email: %w", err)
}
logger.Logger.Info("Reminder email sent successfully",
"doc_id", docID,
"recipient_email", recipientEmail)
if err := s.reminderRepo.LogReminder(ctx, log); err != nil {
logger.Logger.Error("Failed to log successful reminder",
"doc_id", docID,
"recipient_email", recipientEmail,
"error", err.Error())
return fmt.Errorf("email sent but failed to log: %w", err)
}
return nil
}
// GetReminderStats retrieves aggregated reminder metrics for monitoring dashboard
func (s *ReminderService) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) {
return s.reminderRepo.GetReminderStats(ctx, docID)
}
// GetReminderHistory retrieves complete email send log with success/failure tracking
func (s *ReminderService) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
return s.reminderRepo.GetReminderHistory(ctx, docID)
}
func containsEmail(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}

View File

@@ -315,3 +315,12 @@ func (s *ReminderAsyncService) CountSent(ctx context.Context) int {
}
return c
}
func containsEmail(emails []string, target string) bool {
for _, e := range emails {
if e == target {
return true
}
}
return false
}

View File

@@ -1,536 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package services
import (
"context"
"errors"
"testing"
"time"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/email"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
// Mock implementations for testing
type mockExpectedSignerRepository struct {
listWithStatusFunc func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error)
}
func (m *mockExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
if m.listWithStatusFunc != nil {
return m.listWithStatusFunc(ctx, docID)
}
return nil, nil
}
type mockReminderRepository struct {
logReminderFunc func(ctx context.Context, log *models.ReminderLog) error
getReminderHistoryFunc func(ctx context.Context, docID string) ([]*models.ReminderLog, error)
getReminderStatsFunc func(ctx context.Context, docID string) (*models.ReminderStats, error)
}
func (m *mockReminderRepository) LogReminder(ctx context.Context, log *models.ReminderLog) error {
if m.logReminderFunc != nil {
return m.logReminderFunc(ctx, log)
}
return nil
}
func (m *mockReminderRepository) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
if m.getReminderHistoryFunc != nil {
return m.getReminderHistoryFunc(ctx, docID)
}
return nil, nil
}
func (m *mockReminderRepository) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) {
if m.getReminderStatsFunc != nil {
return m.getReminderStatsFunc(ctx, docID)
}
return nil, nil
}
type mockEmailSender struct {
sendFunc func(ctx context.Context, msg email.Message) error
}
func (m *mockEmailSender) Send(ctx context.Context, msg email.Message) error {
if m.sendFunc != nil {
return m.sendFunc(ctx, msg)
}
return nil
}
type mockMagicLinkService struct {
createReminderAuthTokenFunc func(ctx context.Context, email string, docID string) (string, error)
}
func (m *mockMagicLinkService) CreateReminderAuthToken(ctx context.Context, email string, docID string) (string, error) {
if m.createReminderAuthTokenFunc != nil {
return m.createReminderAuthTokenFunc(ctx, email, docID)
}
return "mock-token-123", nil
}
// Test helper function
func TestContainsEmail(t *testing.T) {
t.Parallel()
tests := []struct {
name string
slice []string
item string
expected bool
}{
{
name: "Email found",
slice: []string{"alice@example.com", "bob@example.com", "charlie@example.com"},
item: "bob@example.com",
expected: true,
},
{
name: "Email not found",
slice: []string{"alice@example.com", "bob@example.com"},
item: "charlie@example.com",
expected: false,
},
{
name: "Empty slice",
slice: []string{},
item: "test@example.com",
expected: false,
},
{
name: "Case sensitive",
slice: []string{"Test@Example.com"},
item: "test@example.com",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := containsEmail(tt.slice, tt.item)
if result != tt.expected {
t.Errorf("containsEmail(%v, %q) = %v, want %v", tt.slice, tt.item, result, tt.expected)
}
})
}
}
// Test SendReminders with no pending signers
func TestReminderService_SendReminders_NoPendingSigners(t *testing.T) {
t.Parallel()
ctx := context.Background()
mockExpectedRepo := &mockExpectedSignerRepository{
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
return []*models.ExpectedSignerWithStatus{
{ExpectedSigner: models.ExpectedSigner{Email: "signed@example.com"}, HasSigned: true},
}, nil
},
}
mockReminderRepo := &mockReminderRepository{}
mockEmailSender := &mockEmailSender{}
mockMagicLink := &mockMagicLinkService{}
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if result.TotalAttempted != 0 {
t.Errorf("Expected 0 total attempted, got %d", result.TotalAttempted)
}
if result.SuccessfullySent != 0 {
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
}
}
// Test SendReminders with successful email send
func TestReminderService_SendReminders_Success(t *testing.T) {
t.Parallel()
ctx := context.Background()
mockExpectedRepo := &mockExpectedSignerRepository{
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
return []*models.ExpectedSignerWithStatus{
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
}, nil
},
}
loggedReminder := false
mockReminderRepo := &mockReminderRepository{
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
loggedReminder = true
if log.Status != "sent" {
t.Errorf("Expected status 'sent', got '%s'", log.Status)
}
return nil
},
}
emailSent := false
mockEmailSender := &mockEmailSender{
sendFunc: func(ctx context.Context, msg email.Message) error {
emailSent = true
if len(msg.To) != 1 || msg.To[0] != "pending@example.com" {
t.Errorf("Expected email to 'pending@example.com', got %v", msg.To)
}
return nil
},
}
mockMagicLink := &mockMagicLinkService{}
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if result.TotalAttempted != 1 {
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
}
if result.SuccessfullySent != 1 {
t.Errorf("Expected 1 successfully sent, got %d", result.SuccessfullySent)
}
if result.Failed != 0 {
t.Errorf("Expected 0 failed, got %d", result.Failed)
}
if !emailSent {
t.Error("Expected email to be sent")
}
if !loggedReminder {
t.Error("Expected reminder to be logged")
}
}
// Test SendReminders with email failure
func TestReminderService_SendReminders_EmailFailure(t *testing.T) {
t.Parallel()
ctx := context.Background()
mockExpectedRepo := &mockExpectedSignerRepository{
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
return []*models.ExpectedSignerWithStatus{
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
}, nil
},
}
loggedReminder := false
mockReminderRepo := &mockReminderRepository{
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
loggedReminder = true
if log.Status != "failed" {
t.Errorf("Expected status 'failed', got '%s'", log.Status)
}
if log.ErrorMessage == nil {
t.Error("Expected error message to be set")
}
return nil
},
}
mockEmailSender := &mockEmailSender{
sendFunc: func(ctx context.Context, msg email.Message) error {
return errors.New("SMTP connection failed")
},
}
mockMagicLink := &mockMagicLinkService{}
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
if err != nil {
t.Fatalf("Expected no error from SendReminders, got: %v", err)
}
if result.TotalAttempted != 1 {
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
}
if result.Failed != 1 {
t.Errorf("Expected 1 failed, got %d", result.Failed)
}
if result.SuccessfullySent != 0 {
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
}
if len(result.Errors) != 1 {
t.Errorf("Expected 1 error message, got %d", len(result.Errors))
}
if !loggedReminder {
t.Error("Expected failed reminder to be logged")
}
}
// Test SendReminders with specific emails filter
func TestReminderService_SendReminders_SpecificEmails(t *testing.T) {
t.Parallel()
ctx := context.Background()
mockExpectedRepo := &mockExpectedSignerRepository{
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
return []*models.ExpectedSignerWithStatus{
{ExpectedSigner: models.ExpectedSigner{Email: "pending1@example.com"}, HasSigned: false},
{ExpectedSigner: models.ExpectedSigner{Email: "pending2@example.com"}, HasSigned: false},
{ExpectedSigner: models.ExpectedSigner{Email: "pending3@example.com"}, HasSigned: false},
}, nil
},
}
emailsSent := []string{}
mockReminderRepo := &mockReminderRepository{
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
return nil
},
}
mockEmailSender := &mockEmailSender{
sendFunc: func(ctx context.Context, msg email.Message) error {
emailsSent = append(emailsSent, msg.To[0])
return nil
},
}
mockMagicLink := &mockMagicLinkService{}
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
specificEmails := []string{"pending2@example.com"}
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", specificEmails, "https://example.com/doc.pdf", "en")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if result.TotalAttempted != 1 {
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
}
if len(emailsSent) != 1 || emailsSent[0] != "pending2@example.com" {
t.Errorf("Expected only 'pending2@example.com' to receive email, got %v", emailsSent)
}
}
// Test SendReminders with repository error
func TestReminderService_SendReminders_RepositoryError(t *testing.T) {
t.Parallel()
ctx := context.Background()
mockExpectedRepo := &mockExpectedSignerRepository{
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
return nil, errors.New("database connection failed")
},
}
mockReminderRepo := &mockReminderRepository{}
mockEmailSender := &mockEmailSender{}
mockMagicLink := &mockMagicLinkService{}
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
if err == nil {
t.Fatal("Expected error, got nil")
}
if result != nil {
t.Errorf("Expected nil result on error, got %v", result)
}
}
// Test GetReminderHistory
func TestReminderService_GetReminderHistory(t *testing.T) {
t.Parallel()
ctx := context.Background()
expectedLogs := []*models.ReminderLog{
{
DocID: "doc1",
RecipientEmail: "user@example.com",
SentAt: time.Now(),
SentBy: "admin@example.com",
Status: "sent",
},
}
mockReminderRepo := &mockReminderRepository{
getReminderHistoryFunc: func(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
if docID != "doc1" {
t.Errorf("Expected docID 'doc1', got '%s'", docID)
}
return expectedLogs, nil
},
}
service := NewReminderService(&mockExpectedSignerRepository{}, mockReminderRepo, &mockEmailSender{}, &mockMagicLinkService{}, nil, "https://example.com")
logs, err := service.GetReminderHistory(ctx, "doc1")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if len(logs) != 1 {
t.Errorf("Expected 1 log, got %d", len(logs))
}
if logs[0].RecipientEmail != "user@example.com" {
t.Errorf("Expected recipient 'user@example.com', got '%s'", logs[0].RecipientEmail)
}
}
// Test GetReminderStats
func TestReminderService_GetReminderStats(t *testing.T) {
t.Parallel()
ctx := context.Background()
now := time.Now()
expectedStats := &models.ReminderStats{
TotalSent: 5,
LastSentAt: &now,
PendingCount: 2,
}
mockReminderRepo := &mockReminderRepository{
getReminderStatsFunc: func(ctx context.Context, docID string) (*models.ReminderStats, error) {
if docID != "doc1" {
t.Errorf("Expected docID 'doc1', got '%s'", docID)
}
return expectedStats, nil
},
}
service := NewReminderService(&mockExpectedSignerRepository{}, mockReminderRepo, &mockEmailSender{}, &mockMagicLinkService{}, nil, "https://example.com")
stats, err := service.GetReminderStats(ctx, "doc1")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if stats.TotalSent != 5 {
t.Errorf("Expected 5 total sent, got %d", stats.TotalSent)
}
if stats.PendingCount != 2 {
t.Errorf("Expected 2 pending, got %d", stats.PendingCount)
}
if stats.LastSentAt == nil {
t.Error("Expected LastSentAt to be set")
}
}
// Test SendReminders with multiple pending signers
func TestReminderService_SendReminders_MultiplePending(t *testing.T) {
t.Parallel()
ctx := context.Background()
mockExpectedRepo := &mockExpectedSignerRepository{
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
return []*models.ExpectedSignerWithStatus{
{ExpectedSigner: models.ExpectedSigner{Email: "pending1@example.com", Name: "User 1"}, HasSigned: false},
{ExpectedSigner: models.ExpectedSigner{Email: "pending2@example.com", Name: "User 2"}, HasSigned: false},
{ExpectedSigner: models.ExpectedSigner{Email: "already-signed@example.com", Name: "User 3"}, HasSigned: true},
}, nil
},
}
emailsSent := 0
mockReminderRepo := &mockReminderRepository{
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
return nil
},
}
mockEmailSender := &mockEmailSender{
sendFunc: func(ctx context.Context, msg email.Message) error {
emailsSent++
return nil
},
}
mockMagicLink := &mockMagicLinkService{}
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if result.TotalAttempted != 2 {
t.Errorf("Expected 2 total attempted, got %d", result.TotalAttempted)
}
if result.SuccessfullySent != 2 {
t.Errorf("Expected 2 successfully sent, got %d", result.SuccessfullySent)
}
if emailsSent != 2 {
t.Errorf("Expected 2 emails sent, got %d", emailsSent)
}
}
// Test SendReminders with log failure after successful email
func TestReminderService_SendReminders_LogFailure(t *testing.T) {
t.Parallel()
ctx := context.Background()
mockExpectedRepo := &mockExpectedSignerRepository{
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
return []*models.ExpectedSignerWithStatus{
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
}, nil
},
}
mockReminderRepo := &mockReminderRepository{
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
return errors.New("database write failed")
},
}
mockEmailSender := &mockEmailSender{
sendFunc: func(ctx context.Context, msg email.Message) error {
return nil // Email succeeds
},
}
mockMagicLink := &mockMagicLinkService{}
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
if err != nil {
t.Fatalf("Expected no error from SendReminders, got: %v", err)
}
// The send should fail because logging failed
if result.Failed != 1 {
t.Errorf("Expected 1 failed, got %d", result.Failed)
}
if result.SuccessfullySent != 0 {
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
}
}

View File

@@ -172,6 +172,64 @@ func (f *fakeCryptoSigner) CreateSignature(ctx context.Context, docID string, us
return payloadHash, signature, nil
}
type fakeDocumentRepository struct {
documents map[string]*models.Document
}
func newFakeDocumentRepository() *fakeDocumentRepository {
return &fakeDocumentRepository{
documents: make(map[string]*models.Document),
}
}
func (f *fakeDocumentRepository) Create(_ context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
doc := &models.Document{
DocID: docID,
Title: input.Title,
URL: input.URL,
Checksum: input.Checksum,
ChecksumAlgorithm: input.ChecksumAlgorithm,
CreatedBy: createdBy,
}
f.documents[docID] = doc
return doc, nil
}
func (f *fakeDocumentRepository) GetByDocID(_ context.Context, docID string) (*models.Document, error) {
if doc, ok := f.documents[docID]; ok {
return doc, nil
}
return nil, nil
}
func (f *fakeDocumentRepository) FindByReference(_ context.Context, _ string, _ string) (*models.Document, error) {
return nil, nil
}
func (f *fakeDocumentRepository) List(_ context.Context, _, _ int) ([]*models.Document, error) {
return nil, nil
}
func (f *fakeDocumentRepository) Search(_ context.Context, _ string, _, _ int) ([]*models.Document, error) {
return nil, nil
}
func (f *fakeDocumentRepository) Count(_ context.Context, _ string) (int, error) {
return 0, nil
}
func (f *fakeDocumentRepository) ListByCreatedBy(_ context.Context, _ string, _, _ int) ([]*models.Document, error) {
return nil, nil
}
func (f *fakeDocumentRepository) SearchByCreatedBy(_ context.Context, _, _ string, _, _ int) ([]*models.Document, error) {
return nil, nil
}
func (f *fakeDocumentRepository) CountByCreatedBy(_ context.Context, _, _ string) (int, error) {
return 0, nil
}
func TestNewSignatureService(t *testing.T) {
repo := newFakeRepository()
docRepo := newFakeDocumentRepository()

View File

@@ -1,393 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package auth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gorilla/securecookie"
"golang.org/x/oauth2"
"github.com/btouchard/ackify-ce/backend/pkg/crypto"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
// OAuthProvider handles OAuth2 authentication flow
// This component is optional and can be nil if OAuth is disabled
type OAuthProvider struct {
oauthConfig *oauth2.Config
userInfoURL string
logoutURL string
allowedDomain string
baseURL string
sessionSvc *SessionService // Reference to session service for state management
}
// OAuthProviderConfig holds configuration for the OAuth provider
type OAuthProviderConfig struct {
BaseURL string
ClientID string
ClientSecret string
AuthURL string
TokenURL string
UserInfoURL string
LogoutURL string
Scopes []string
AllowedDomain string
SessionSvc *SessionService
}
// NewOAuthProvider creates a new OAuth provider
func NewOAuthProvider(config OAuthProviderConfig) *OAuthProvider {
oauthConfig := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.BaseURL + "/api/v1/auth/callback",
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
}
logger.Logger.Info("OAuth provider configured successfully")
return &OAuthProvider{
oauthConfig: oauthConfig,
userInfoURL: config.UserInfoURL,
logoutURL: config.LogoutURL,
allowedDomain: config.AllowedDomain,
baseURL: config.BaseURL,
sessionSvc: config.SessionSvc,
}
}
// GetLogoutURL returns the SSO logout URL if configured
func (p *OAuthProvider) GetLogoutURL() string {
if p.logoutURL == "" {
return ""
}
logoutURL := p.logoutURL
if p.baseURL != "" {
logoutURL += "?continue=" + p.baseURL
}
return logoutURL
}
// CreateAuthURL creates an OAuth authorization URL with PKCE
func (p *OAuthProvider) CreateAuthURL(w http.ResponseWriter, r *http.Request, nextURL string) string {
// Generate PKCE code verifier and challenge
codeVerifier, err := crypto.GenerateCodeVerifier()
if err != nil {
logger.Logger.Error("Failed to generate PKCE code verifier", "error", err.Error())
// Fallback to OAuth flow without PKCE for backward compatibility
return p.createAuthURLWithoutPKCE(w, r, nextURL)
}
codeChallenge := crypto.GenerateCodeChallenge(codeVerifier)
logger.Logger.Debug("Generated PKCE parameters for OAuth flow")
// Generate state token
randPart := securecookie.GenerateRandomKey(20)
token := base64.RawURLEncoding.EncodeToString(randPart)
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
promptParam := "select_account"
isSilent := r.URL.Query().Get("silent") == "true"
if isSilent {
promptParam = "none"
}
logger.Logger.Info("Starting OAuth flow with PKCE",
"next_url", nextURL,
"silent", isSilent,
"state_token_length", len(token))
session, err := p.sessionSvc.GetSession(r)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to get session from store", "error", err.Error())
// Create a new empty session if Get fails
session, _ = p.sessionSvc.GetNewSession(r)
}
// Store state and code_verifier in session
session.Values["oauth_state"] = token
session.Values["code_verifier"] = codeVerifier
err = session.Save(r, w)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
}
// Generate OAuth URL with PKCE parameters
authURL := p.oauthConfig.AuthCodeURL(state,
oauth2.SetAuthURLParam("prompt", promptParam),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"))
logger.Logger.Debug("CreateAuthURL: generated auth URL with PKCE",
"prompt", promptParam,
"url_length", len(authURL))
return authURL
}
// createAuthURLWithoutPKCE is a fallback method for OAuth without PKCE
// Used for backward compatibility if PKCE generation fails
func (p *OAuthProvider) createAuthURLWithoutPKCE(w http.ResponseWriter, r *http.Request, nextURL string) string {
randPart := securecookie.GenerateRandomKey(20)
token := base64.RawURLEncoding.EncodeToString(randPart)
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
promptParam := "select_account"
isSilent := r.URL.Query().Get("silent") == "true"
if isSilent {
promptParam = "none"
}
logger.Logger.Warn("Starting OAuth flow WITHOUT PKCE (fallback mode)",
"next_url", nextURL,
"silent", isSilent)
session, err := p.sessionSvc.GetSession(r)
if err != nil {
session, _ = p.sessionSvc.GetNewSession(r)
}
session.Values["oauth_state"] = token
err = session.Save(r, w)
if err != nil {
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
}
authURL := p.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", promptParam))
return authURL
}
// VerifyState validates the OAuth state token for CSRF protection
func (p *OAuthProvider) VerifyState(w http.ResponseWriter, r *http.Request, stateToken string) bool {
session, _ := p.sessionSvc.GetSession(r)
stored, _ := session.Values["oauth_state"].(string)
logger.Logger.Debug("VerifyState: validating OAuth state",
"stored_length", len(stored),
"token_length", len(stateToken),
"stored_empty", stored == "",
"token_empty", stateToken == "")
if stored == "" || stateToken == "" {
logger.Logger.Warn("VerifyState: empty state tokens")
return false
}
if subtleConstantTimeCompare(stored, stateToken) {
logger.Logger.Debug("VerifyState: state valid, clearing token")
delete(session.Values, "oauth_state")
_ = session.Save(r, w)
return true
}
logger.Logger.Warn("VerifyState: state mismatch")
return false
}
// subtleConstantTimeCompare performs a timing-safe string comparison
func subtleConstantTimeCompare(a, b string) bool {
if len(a) != len(b) {
return false
}
var v byte
for i := 0; i < len(a); i++ {
v |= a[i] ^ b[i]
}
return v == 0
}
// HandleCallback processes the OAuth callback and returns the authenticated user
func (p *OAuthProvider) HandleCallback(ctx context.Context, w http.ResponseWriter, r *http.Request, code, state string) (*models.User, string, error) {
parts := strings.SplitN(state, ":", 2)
nextURL := "/"
if len(parts) == 2 {
if nb, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
nextURL = string(nb)
}
}
logger.Logger.Debug("Processing OAuth callback",
"has_code", code != "",
"next_url", nextURL)
// Retrieve code_verifier from session for PKCE
session, _ := p.sessionSvc.GetSession(r)
codeVerifier, hasPKCE := session.Values["code_verifier"].(string)
// Clean up code_verifier immediately after retrieval
if hasPKCE {
delete(session.Values, "code_verifier")
_ = session.Save(r, w)
}
// Exchange authorization code for token (with or without PKCE)
var token *oauth2.Token
var err error
if hasPKCE && codeVerifier != "" {
logger.Logger.Info("OAuth token exchange with PKCE")
token, err = p.oauthConfig.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", codeVerifier))
} else {
logger.Logger.Warn("OAuth token exchange without PKCE (legacy session or fallback)")
token, err = p.oauthConfig.Exchange(ctx, code)
}
if err != nil {
logger.Logger.Error("OAuth token exchange failed",
"error", err.Error(),
"with_pkce", hasPKCE)
return nil, nextURL, fmt.Errorf("oauth exchange failed: %w", err)
}
logger.Logger.Info("OAuth token exchange successful", "with_pkce", hasPKCE)
client := p.oauthConfig.Client(ctx, token)
resp, err := client.Get(p.userInfoURL)
if err != nil || resp.StatusCode != 200 {
statusCode := 0
if resp != nil {
statusCode = resp.StatusCode
}
logger.Logger.Error("User info request failed",
"error", err,
"status_code", statusCode)
return nil, nextURL, fmt.Errorf("userinfo request failed: %w", err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
logger.Logger.Debug("User info retrieved successfully",
"status_code", resp.StatusCode)
user, err := p.parseUserInfo(resp)
if err != nil {
logger.Logger.Error("Failed to parse user info",
"error", err.Error())
return nil, nextURL, fmt.Errorf("failed to parse user info: %w", err)
}
if !p.IsAllowedDomain(user.Email) {
logger.Logger.Warn("User domain not allowed")
return nil, nextURL, models.ErrDomainNotAllowed
}
logger.Logger.Info("OAuth callback successful")
// Store refresh token if available
if token.RefreshToken != "" && p.sessionSvc.sessionRepo != nil {
if err := p.sessionSvc.StoreRefreshToken(ctx, w, r, token, user); err != nil {
// Log error but don't fail the authentication
logger.Logger.Error("Failed to store refresh token (non-fatal)", "error", err.Error())
}
}
return user, nextURL, nil
}
// IsAllowedDomain checks if the user's email domain is allowed
func (p *OAuthProvider) IsAllowedDomain(email string) bool {
if p.allowedDomain == "" {
return true
}
domain := strings.ToLower(p.allowedDomain)
// If domain already has @ prefix, don't add another one
if !strings.HasPrefix(domain, "@") {
domain = "@" + domain
}
return strings.HasSuffix(strings.ToLower(email), domain)
}
// parseUserInfo extracts user information from the OAuth provider response
func (p *OAuthProvider) parseUserInfo(resp *http.Response) (*models.User, error) {
var rawUser map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&rawUser); err != nil {
return nil, fmt.Errorf("failed to decode user info: %w", err)
}
// Reduce PII in standard logs; log only keys at debug level
if rawUser != nil {
keys := make([]string, 0, len(rawUser))
for k := range rawUser {
keys = append(keys, k)
}
logger.Logger.Debug("OAuth user info received", "keys", keys)
}
user := &models.User{}
if sub, ok := rawUser["sub"].(string); ok {
user.Sub = sub
} else if id, ok := rawUser["id"]; ok {
user.Sub = fmt.Sprintf("%v", id)
} else {
return nil, fmt.Errorf("missing user ID in response")
}
// Check for email in various provider-specific fields
// - "email": Standard OIDC claim (Google, GitHub, GitLab, etc.)
// - "mail": Microsoft Graph API
// - "userPrincipalName": Microsoft fallback (UPN format)
if email, ok := rawUser["email"].(string); ok && email != "" {
user.Email = email
} else if mail, ok := rawUser["mail"].(string); ok && mail != "" {
user.Email = mail
} else if upn, ok := rawUser["userPrincipalName"].(string); ok && upn != "" {
user.Email = upn
} else {
return nil, fmt.Errorf("missing email in user info response (checked: email, mail, userPrincipalName)")
}
// Extract display name from various provider-specific fields
var name string
if fullName, ok := rawUser["name"].(string); ok && fullName != "" {
name = fullName
} else if firstName, ok := rawUser["given_name"].(string); ok {
if lastName, ok := rawUser["family_name"].(string); ok {
name = firstName + " " + lastName
} else {
name = firstName
}
} else if displayName, ok := rawUser["displayName"].(string); ok && displayName != "" {
name = displayName
} else if cn, ok := rawUser["cn"].(string); ok && cn != "" {
name = cn
} else if displayNameSnake, ok := rawUser["display_name"].(string); ok && displayNameSnake != "" {
name = displayNameSnake
} else if preferredName, ok := rawUser["preferred_username"].(string); ok && preferredName != "" {
name = preferredName
}
user.Name = name
logger.Logger.Debug("Extracted OAuth user identifiers",
"has_sub", user.Sub != "",
"has_email", user.Email != "",
"has_name", user.Name != "")
if !user.IsValid() {
return nil, fmt.Errorf("invalid user data extracted")
}
return user, nil
}

View File

@@ -1,422 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package auth
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
func TestOAuthProvider_IsAllowedDomain(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
tests := []struct {
name string
allowedDomain string
email string
expected bool
}{
{
name: "allowed domain match",
allowedDomain: "@example.com",
email: "user@example.com",
expected: true,
},
{
name: "allowed domain mismatch",
allowedDomain: "@example.com",
email: "user@other.com",
expected: false,
},
{
name: "no restriction",
allowedDomain: "",
email: "user@any.com",
expected: true,
},
{
name: "case insensitive match",
allowedDomain: "@Example.Com",
email: "user@EXAMPLE.com",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := &OAuthProvider{
allowedDomain: tt.allowedDomain,
sessionSvc: sessionSvc,
}
result := provider.IsAllowedDomain(tt.email)
if result != tt.expected {
t.Errorf("IsAllowedDomain(%v) = %v, expected %v", tt.email, result, tt.expected)
}
})
}
}
func TestOAuthProvider_CreateAuthURL(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
config := OAuthProviderConfig{
BaseURL: "https://ackify.example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
AuthURL: "https://provider.com/auth",
TokenURL: "https://provider.com/token",
UserInfoURL: "https://provider.com/userinfo",
Scopes: []string{"openid", "email"},
SessionSvc: sessionSvc,
}
provider := NewOAuthProvider(config)
t.Run("creates auth URL with PKCE", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
authURL := provider.CreateAuthURL(rec, req, "/dashboard")
if authURL == "" {
t.Fatal("CreateAuthURL() returned empty string")
}
// Parse the URL and check parameters
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
// Check for required OAuth parameters
if query.Get("client_id") == "" {
t.Error("client_id parameter missing")
}
if query.Get("redirect_uri") == "" {
t.Error("redirect_uri parameter missing")
}
if query.Get("response_type") != "code" {
t.Error("response_type should be 'code'")
}
if query.Get("state") == "" {
t.Error("state parameter missing")
}
// Check for PKCE parameters
if query.Get("code_challenge") == "" {
t.Error("code_challenge parameter missing (PKCE should be enabled)")
}
if query.Get("code_challenge_method") != "S256" {
t.Error("code_challenge_method should be 'S256'")
}
})
t.Run("creates auth URL with silent flag", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/?silent=true", nil)
authURL := provider.CreateAuthURL(rec, req, "/dashboard")
parsedURL, _ := url.Parse(authURL)
query := parsedURL.Query()
if query.Get("prompt") != "none" {
t.Errorf("prompt parameter = %v, expected 'none' for silent auth", query.Get("prompt"))
}
})
}
func TestOAuthProvider_VerifyState(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
config := OAuthProviderConfig{
BaseURL: "https://ackify.example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
AuthURL: "https://provider.com/auth",
TokenURL: "https://provider.com/token",
UserInfoURL: "https://provider.com/userinfo",
Scopes: []string{"openid"},
SessionSvc: sessionSvc,
}
provider := NewOAuthProvider(config)
t.Run("valid state", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
// Create auth URL to generate state
_ = provider.CreateAuthURL(rec, req, "/")
// Get session with state
req2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range rec.Result().Cookies() {
req2.AddCookie(cookie)
}
session, _ := sessionSvc.GetSession(req2)
storedState, ok := session.Values["oauth_state"].(string)
if !ok {
t.Fatal("State not stored in session")
}
// Verify state
rec2 := httptest.NewRecorder()
valid := provider.VerifyState(rec2, req2, storedState)
if !valid {
t.Error("VerifyState() should return true for valid state")
}
})
t.Run("invalid state", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
valid := provider.VerifyState(rec, req, "invalid-state-token")
if valid {
t.Error("VerifyState() should return false for invalid state")
}
})
t.Run("empty state", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
valid := provider.VerifyState(rec, req, "")
if valid {
t.Error("VerifyState() should return false for empty state")
}
})
}
func TestOAuthProvider_parseUserInfo(t *testing.T) {
sessionSvc := NewSessionService(SessionServiceConfig{
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
SecureCookies: false,
})
provider := &OAuthProvider{
sessionSvc: sessionSvc,
}
tests := []struct {
name string
responseObj map[string]interface{}
wantErr bool
checkUser func(*testing.T, *models.User)
}{
{
name: "complete user info with sub",
responseObj: map[string]interface{}{
"sub": "12345",
"email": "user@example.com",
"name": "Test User",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Sub != "12345" {
t.Errorf("Sub = %v, expected 12345", user.Sub)
}
if user.Email != "user@example.com" {
t.Errorf("Email = %v, expected user@example.com", user.Email)
}
if user.Name != "Test User" {
t.Errorf("Name = %v, expected Test User", user.Name)
}
},
},
{
name: "user info with id instead of sub",
responseObj: map[string]interface{}{
"id": 67890,
"email": "user@example.com",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Sub != "67890" {
t.Errorf("Sub = %v, expected 67890", user.Sub)
}
},
},
{
name: "user info with given_name and family_name",
responseObj: map[string]interface{}{
"sub": "12345",
"email": "user@example.com",
"given_name": "John",
"family_name": "Doe",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Name != "John Doe" {
t.Errorf("Name = %v, expected 'John Doe'", user.Name)
}
},
},
{
name: "Microsoft Graph API - mail field",
responseObj: map[string]interface{}{
"id": "microsoft-id-12345",
"mail": "user@company.com",
"displayName": "Microsoft User",
"userPrincipalName": "user@company.onmicrosoft.com",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Sub != "microsoft-id-12345" {
t.Errorf("Sub = %v, expected microsoft-id-12345", user.Sub)
}
if user.Email != "user@company.com" {
t.Errorf("Email = %v, expected user@company.com (from mail field)", user.Email)
}
if user.Name != "Microsoft User" {
t.Errorf("Name = %v, expected Microsoft User (from displayName)", user.Name)
}
},
},
{
name: "Microsoft Graph API - userPrincipalName fallback",
responseObj: map[string]interface{}{
"id": "microsoft-id-67890",
"displayName": "UPN User",
"userPrincipalName": "user@company.onmicrosoft.com",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Email != "user@company.onmicrosoft.com" {
t.Errorf("Email = %v, expected user@company.onmicrosoft.com (from userPrincipalName)", user.Email)
}
},
},
{
name: "email field takes priority over mail",
responseObj: map[string]interface{}{
"sub": "12345",
"email": "primary@example.com",
"mail": "secondary@example.com",
},
wantErr: false,
checkUser: func(t *testing.T, user *models.User) {
if user.Email != "primary@example.com" {
t.Errorf("Email = %v, expected primary@example.com (email should take priority)", user.Email)
}
},
},
{
name: "missing email",
responseObj: map[string]interface{}{
"sub": "12345",
"name": "Test User",
},
wantErr: true,
},
{
name: "missing sub and id",
responseObj: map[string]interface{}{
"email": "user@example.com",
"name": "Test User",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock HTTP response
jsonData, _ := json.Marshal(tt.responseObj)
resp := &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(jsonData)),
}
user, err := provider.parseUserInfo(resp)
if tt.wantErr {
if err == nil {
t.Error("parseUserInfo() should return error")
}
return
}
if err != nil {
t.Fatalf("parseUserInfo() unexpected error: %v", err)
}
if user == nil {
t.Fatal("parseUserInfo() returned nil user")
}
if tt.checkUser != nil {
tt.checkUser(t, user)
}
})
}
}
func TestSubtleConstantTimeCompare(t *testing.T) {
tests := []struct {
name string
a string
b string
expected bool
}{
{
name: "equal strings",
a: "secret123",
b: "secret123",
expected: true,
},
{
name: "different strings",
a: "secret123",
b: "secret456",
expected: false,
},
{
name: "different lengths",
a: "short",
b: "longer string",
expected: false,
},
{
name: "empty strings",
a: "",
b: "",
expected: true,
},
{
name: "one empty",
a: "string",
b: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := subtleConstantTimeCompare(tt.a, tt.b)
if result != tt.expected {
t.Errorf("subtleConstantTimeCompare(%v, %v) = %v, expected %v", tt.a, tt.b, result, tt.expected)
}
})
}
}

View File

@@ -1,37 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package email
import (
"context"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
)
func SendEmail(ctx context.Context, sender Sender, template string, to []string, locale string, subject string, data map[string]any) error {
msg := Message{
To: to,
Subject: subject,
Template: template,
Locale: locale,
Data: data,
}
return sender.Send(ctx, msg)
}
func SendSignatureReminderEmail(ctx context.Context, sender Sender, i18nService *i18n.I18n, to []string, locale, docID, docURL, signURL, recipientName string) error {
data := map[string]any{
"DocID": docID,
"DocURL": docURL,
"SignURL": signURL,
"RecipientName": recipientName,
}
// Get translated subject using i18n
subject := "Document Reading Confirmation Reminder" // Fallback
if i18nService != nil {
subject = i18nService.T(locale, "email.reminder.subject")
}
return SendEmail(ctx, sender, "signature_reminder", to, locale, subject, data)
}

View File

@@ -1,268 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package email
import (
"context"
"errors"
"testing"
)
// Mock sender for testing
type mockSender struct {
sendFunc func(ctx context.Context, msg Message) error
lastMsg *Message
}
func (m *mockSender) Send(ctx context.Context, msg Message) error {
m.lastMsg = &msg
if m.sendFunc != nil {
return m.sendFunc(ctx, msg)
}
return nil
}
func TestSendEmail(t *testing.T) {
t.Parallel()
tests := []struct {
name string
template string
to []string
locale string
subject string
data map[string]any
sendError error
expectError bool
}{
{
name: "Send email successfully",
template: "test_template",
to: []string{"user@example.com"},
locale: "en",
subject: "Test Subject",
data: map[string]any{
"name": "John",
},
sendError: nil,
expectError: false,
},
{
name: "Send email with multiple recipients",
template: "welcome",
to: []string{"user1@example.com", "user2@example.com"},
locale: "fr",
subject: "Bienvenue",
data: map[string]any{
"company": "Acme Corp",
},
sendError: nil,
expectError: false,
},
{
name: "Send email with error",
template: "error_template",
to: []string{"user@example.com"},
locale: "en",
subject: "Error Test",
data: nil,
sendError: errors.New("SMTP connection failed"),
expectError: true,
},
{
name: "Send email with empty data",
template: "simple_template",
to: []string{"test@example.com"},
locale: "en",
subject: "Simple Email",
data: map[string]any{},
sendError: nil,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
mock := &mockSender{
sendFunc: func(ctx context.Context, msg Message) error {
return tt.sendError
},
}
err := SendEmail(ctx, mock, tt.template, tt.to, tt.locale, tt.subject, tt.data)
if tt.expectError && err == nil {
t.Error("Expected error but got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Verify message was constructed correctly
if mock.lastMsg == nil {
t.Fatal("Expected message to be captured")
}
if mock.lastMsg.Template != tt.template {
t.Errorf("Expected template '%s', got '%s'", tt.template, mock.lastMsg.Template)
}
if mock.lastMsg.Subject != tt.subject {
t.Errorf("Expected subject '%s', got '%s'", tt.subject, mock.lastMsg.Subject)
}
if mock.lastMsg.Locale != tt.locale {
t.Errorf("Expected locale '%s', got '%s'", tt.locale, mock.lastMsg.Locale)
}
if len(mock.lastMsg.To) != len(tt.to) {
t.Errorf("Expected %d recipients, got %d", len(tt.to), len(mock.lastMsg.To))
}
})
}
}
func TestSendSignatureReminderEmail(t *testing.T) {
t.Parallel()
// Note: When i18n is nil, all tests use the fallback subject
fallbackSubject := "Document Reading Confirmation Reminder"
tests := []struct {
name string
to []string
locale string
docID string
docURL string
signURL string
recipientName string
expectedSubject string
sendError error
expectError bool
}{
{
name: "Send reminder in English",
to: []string{"user@example.com"},
locale: "en",
docID: "doc123",
docURL: "https://example.com/doc.pdf",
signURL: "https://example.com/?doc=doc123",
recipientName: "John Doe",
expectedSubject: fallbackSubject,
sendError: nil,
expectError: false,
},
{
name: "Send reminder in French (fallback without i18n)",
to: []string{"utilisateur@exemple.fr"},
locale: "fr",
docID: "doc456",
docURL: "https://exemple.fr/document.pdf",
signURL: "https://exemple.fr/?doc=doc456",
recipientName: "Marie Dupont",
expectedSubject: fallbackSubject,
sendError: nil,
expectError: false,
},
{
name: "Send reminder with unknown locale defaults to fallback",
to: []string{"user@example.com"},
locale: "es",
docID: "doc789",
docURL: "https://example.com/doc.pdf",
signURL: "https://example.com/?doc=doc789",
recipientName: "Juan Garcia",
expectedSubject: fallbackSubject,
sendError: nil,
expectError: false,
},
{
name: "Send reminder with error",
to: []string{"user@example.com"},
locale: "en",
docID: "doc999",
docURL: "https://example.com/doc.pdf",
signURL: "https://example.com/?doc=doc999",
recipientName: "Test User",
expectedSubject: fallbackSubject,
sendError: errors.New("email server unavailable"),
expectError: true,
},
{
name: "Send reminder with empty recipient name",
to: []string{"user@example.com"},
locale: "en",
docID: "doc000",
docURL: "https://example.com/doc.pdf",
signURL: "https://example.com/?doc=doc000",
recipientName: "",
expectedSubject: fallbackSubject,
sendError: nil,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
mock := &mockSender{
sendFunc: func(ctx context.Context, msg Message) error {
return tt.sendError
},
}
err := SendSignatureReminderEmail(ctx, mock, nil, tt.to, tt.locale, tt.docID, tt.docURL, tt.signURL, tt.recipientName)
if tt.expectError && err == nil {
t.Error("Expected error but got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Verify message construction
if mock.lastMsg == nil {
t.Fatal("Expected message to be captured")
}
if mock.lastMsg.Template != "signature_reminder" {
t.Errorf("Expected template 'signature_reminder', got '%s'", mock.lastMsg.Template)
}
if mock.lastMsg.Subject != tt.expectedSubject {
t.Errorf("Expected subject '%s', got '%s'", tt.expectedSubject, mock.lastMsg.Subject)
}
if mock.lastMsg.Locale != tt.locale {
t.Errorf("Expected locale '%s', got '%s'", tt.locale, mock.lastMsg.Locale)
}
// Verify data fields
if mock.lastMsg.Data == nil {
t.Fatal("Expected data to be present")
}
if docID, ok := mock.lastMsg.Data["DocID"].(string); !ok || docID != tt.docID {
t.Errorf("Expected DocID '%s', got '%v'", tt.docID, mock.lastMsg.Data["DocID"])
}
if docURL, ok := mock.lastMsg.Data["DocURL"].(string); !ok || docURL != tt.docURL {
t.Errorf("Expected DocURL '%s', got '%v'", tt.docURL, mock.lastMsg.Data["DocURL"])
}
if signURL, ok := mock.lastMsg.Data["SignURL"].(string); !ok || signURL != tt.signURL {
t.Errorf("Expected SignURL '%s', got '%v'", tt.signURL, mock.lastMsg.Data["SignURL"])
}
if recipientName, ok := mock.lastMsg.Data["RecipientName"].(string); !ok || recipientName != tt.recipientName {
t.Errorf("Expected RecipientName '%s', got '%v'", tt.recipientName, mock.lastMsg.Data["RecipientName"])
}
})
}
}

View File

@@ -1,113 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package email
import (
"context"
"fmt"
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
// QueueSender implements the Sender interface by queuing emails instead of sending them directly
type QueueSender struct {
queueRepo QueueRepository
baseURL string
}
// NewQueueSender creates a new queue-based email sender
func NewQueueSender(queueRepo QueueRepository, baseURL string) *QueueSender {
return &QueueSender{
queueRepo: queueRepo,
baseURL: baseURL,
}
}
// Send queues an email for asynchronous processing
func (q *QueueSender) Send(ctx context.Context, msg Message) error {
// Convert message data to proper format
data := msg.Data
if data == nil {
data = make(map[string]interface{})
}
input := models.EmailQueueInput{
ToAddresses: msg.To,
CcAddresses: msg.Cc,
BccAddresses: msg.Bcc,
Subject: msg.Subject,
Template: msg.Template,
Locale: msg.Locale,
Data: data,
Headers: msg.Headers,
Priority: models.EmailPriorityNormal,
}
// Set priority based on template type
switch msg.Template {
case "signature_reminder":
input.Priority = models.EmailPriorityHigh
case "welcome", "notification":
input.Priority = models.EmailPriorityNormal
default:
input.Priority = models.EmailPriorityNormal
}
// Queue the email
_, err := q.queueRepo.Enqueue(ctx, input)
if err != nil {
return fmt.Errorf("failed to queue email: %w", err)
}
return nil
}
// QueueSignatureReminderEmail queues a signature reminder email
func QueueSignatureReminderEmail(
ctx context.Context,
queueRepo QueueRepository,
i18nService *i18n.I18n,
recipients []string,
locale string,
docID string,
docURL string,
signURL string,
recipientName string,
sentBy string,
) error {
data := map[string]interface{}{
"doc_id": docID,
"doc_url": docURL,
"sign_url": signURL,
"recipient_name": recipientName,
"locale": locale,
}
// Get translated subject using i18n
subject := "Document Reading Confirmation Reminder" // Fallback
if i18nService != nil {
subject = i18nService.T(locale, "email.reminder.subject")
}
// Create a reference for tracking
refType := "signature_reminder"
input := models.EmailQueueInput{
ToAddresses: recipients,
Subject: subject,
Template: "signature_reminder",
Locale: locale,
Data: data,
Priority: models.EmailPriorityHigh,
ReferenceType: &refType,
ReferenceID: &docID,
CreatedBy: &sentBy,
}
_, err := queueRepo.Enqueue(ctx, input)
if err != nil {
return fmt.Errorf("failed to queue signature reminder: %w", err)
}
return nil
}

View File

@@ -120,8 +120,3 @@ func WritePaginatedJSON(w http.ResponseWriter, data interface{}, page, limit, to
WriteJSONWithMeta(w, http.StatusOK, data, meta)
}
// WriteNoContent writes a 204 No Content response
func WriteNoContent(w http.ResponseWriter) {
w.WriteHeader(http.StatusNoContent)
}

View File

@@ -222,19 +222,3 @@ func TestWritePaginatedJSON(t *testing.T) {
})
}
}
func TestWriteNoContent(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
WriteNoContent(w)
if w.Code != http.StatusNoContent {
t.Errorf("Expected status code %d, got %d", http.StatusNoContent, w.Code)
}
if w.Body.Len() != 0 {
t.Errorf("Expected empty body, got %d bytes", w.Body.Len())
}
}

View File

@@ -1,33 +0,0 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
package shared
import (
"net/http"
"strings"
)
// GetClientIP extracts the real client IP address from the request
// It checks X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
func GetClientIP(r *http.Request) string {
// Check X-Forwarded-For header (proxy/load balancer)
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
// X-Forwarded-For can contain multiple IPs, take the first one
ips := strings.Split(forwardedFor, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// Check X-Real-IP header
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return strings.TrimSpace(realIP)
}
// Fall back to RemoteAddr
ip := r.RemoteAddr
// Remove port if present
if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx]
}
return ip
}

View File

@@ -25,6 +25,7 @@ type UserDTO struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Picture string `json:"picture,omitempty"`
IsAdmin bool `json:"isAdmin"`
}
@@ -40,6 +41,7 @@ func (h *Handler) HandleGetCurrentUser(w http.ResponseWriter, r *http.Request) {
ID: user.Sub,
Email: user.Email,
Name: user.Name,
Picture: user.Picture,
IsAdmin: h.authorizer.IsAdmin(r.Context(), user.Email),
}

View File

@@ -1,39 +0,0 @@
package handlers
import (
"errors"
"net/http"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
// HandleError handles different types of errors and returns appropriate HTTP responses
func HandleError(w http.ResponseWriter, err error) {
switch {
case errors.Is(err, models.ErrUnauthorized):
logger.Logger.Warn("Unauthorized access attempt", "error", err.Error())
http.Error(w, "Unauthorized", http.StatusUnauthorized)
case errors.Is(err, models.ErrSignatureNotFound):
logger.Logger.Debug("Signature not found", "error", err.Error())
http.Error(w, "Signature not found", http.StatusNotFound)
case errors.Is(err, models.ErrSignatureAlreadyExists):
logger.Logger.Debug("Duplicate signature attempt", "error", err.Error())
http.Error(w, "Signature already exists", http.StatusConflict)
case errors.Is(err, models.ErrInvalidUser):
logger.Logger.Warn("Invalid user data", "error", err.Error())
http.Error(w, "Invalid user", http.StatusBadRequest)
case errors.Is(err, models.ErrInvalidDocument):
logger.Logger.Warn("Invalid document ID", "error", err.Error())
http.Error(w, "Invalid document ID", http.StatusBadRequest)
case errors.Is(err, models.ErrDomainNotAllowed):
logger.Logger.Warn("Domain not allowed", "error", err.Error())
http.Error(w, "Domain not allowed", http.StatusForbidden)
case errors.Is(err, models.ErrDatabaseConnection):
logger.Logger.Error("Database connection error", "error", err.Error())
http.Error(w, "Database error", http.StatusInternalServerError)
default:
logger.Logger.Error("Unhandled error", "error", err.Error())
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}

View File

@@ -4,8 +4,6 @@ package handlers
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
@@ -208,36 +206,6 @@ func TestHandleOEmbed_MissingDocParam(t *testing.T) {
}
}
func TestValidateOEmbedURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
urlStr string
baseURL string
expected bool
}{
{"valid same host", "https://example.com/?doc=123", "https://example.com", true},
{"valid with port", "https://example.com:443/?doc=123", "https://example.com", true},
{"different host", "https://other.com/?doc=123", "https://example.com", false},
{"localhost variations", "http://localhost:8080/?doc=123", "http://127.0.0.1:8080", true},
{"localhost to 127.0.0.1", "http://127.0.0.1/?doc=123", "http://localhost", true},
{"invalid URL", ":::invalid", "https://example.com", false},
{"invalid base URL", "https://example.com/?doc=123", ":::invalid", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := ValidateOEmbedURL(tt.urlStr, tt.baseURL)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
// ============================================================================
// BENCHMARKS
// ============================================================================
@@ -254,16 +222,6 @@ func BenchmarkHandleOEmbed(b *testing.B) {
}
}
func BenchmarkValidateOEmbedURL(b *testing.B) {
urlStr := "https://example.com/?doc=test123"
baseURL := "https://example.com"
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateOEmbedURL(urlStr, baseURL)
}
}
// ============================================================================
// TESTS - Middleware: SecureHeaders
// ============================================================================
@@ -423,126 +381,6 @@ func TestRequestLogger_DifferentMethods(t *testing.T) {
}
}
// ============================================================================
// TESTS - HandleError
// ============================================================================
func TestHandleError_Unauthorized(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, models.ErrUnauthorized)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Contains(t, rec.Body.String(), "Unauthorized")
}
func TestHandleError_SignatureNotFound(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, models.ErrSignatureNotFound)
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Contains(t, rec.Body.String(), "Signature not found")
}
func TestHandleError_SignatureAlreadyExists(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, models.ErrSignatureAlreadyExists)
assert.Equal(t, http.StatusConflict, rec.Code)
assert.Contains(t, rec.Body.String(), "Signature already exists")
}
func TestHandleError_InvalidUser(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, models.ErrInvalidUser)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid user")
}
func TestHandleError_InvalidDocument(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, models.ErrInvalidDocument)
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid document ID")
}
func TestHandleError_DomainNotAllowed(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, models.ErrDomainNotAllowed)
assert.Equal(t, http.StatusForbidden, rec.Code)
assert.Contains(t, rec.Body.String(), "Domain not allowed")
}
func TestHandleError_DatabaseConnection(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, models.ErrDatabaseConnection)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Database error")
}
func TestHandleError_UnknownError(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, errors.New("unknown error"))
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Internal server error")
}
func TestHandleError_WrappedErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
expectedStatus int
expectedMsg string
}{
{
"wrapped unauthorized",
fmt.Errorf("auth failed: %w", models.ErrUnauthorized),
http.StatusUnauthorized,
"Unauthorized",
},
{
"wrapped domain error",
fmt.Errorf("validation failed: %w", models.ErrDomainNotAllowed),
http.StatusForbidden,
"Domain not allowed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
rec := httptest.NewRecorder()
HandleError(rec, tt.err)
assert.Equal(t, tt.expectedStatus, rec.Code)
assert.Contains(t, rec.Body.String(), tt.expectedMsg)
})
}
}
// ============================================================================
// TESTS - statusRecorder
// ============================================================================
@@ -616,13 +454,3 @@ func BenchmarkRequestLogger(b *testing.B) {
handler.ServeHTTP(rec, req)
}
}
func BenchmarkHandleError(b *testing.B) {
err := models.ErrUnauthorized
b.ResetTimer()
for i := 0; i < b.N; i++ {
rec := httptest.NewRecorder()
HandleError(rec, err)
}
}

View File

@@ -7,47 +7,36 @@ import (
"time"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
"github.com/btouchard/ackify-ce/backend/pkg/models"
)
type userService interface {
GetUser(r *http.Request) (*models.User, error)
}
type AuthMiddleware struct {
userService userService
baseURL string
}
// SecureHeaders Enforce baseline security headers (CSP, XFO, etc.) to mitigate clickjacking, MIME sniffing, and unsafe embedding by default.
// SecureHeaders enforces baseline security headers (CSP, XFO, etc.)
func SecureHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("Referrer-Policy", "no-referrer")
// Check if this is an embed route - allow iframe embedding
isEmbedRoute := strings.HasPrefix(r.URL.Path, "/embed/") || strings.HasPrefix(r.URL.Path, "/embed")
// OAuth provider avatar domains
imgSrc := "img-src 'self' data: https://cdn.simpleicons.org https://*.googleusercontent.com https://avatars.githubusercontent.com https://secure.gravatar.com https://gitlab.com"
if isEmbedRoute {
// Allow embedding from any origin for embed pages
// Do not set X-Frame-Options to allow iframe embedding
w.Header().Set("Content-Security-Policy",
"default-src 'self'; "+
"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https://fonts.googleapis.com; "+
"font-src 'self' https://fonts.gstatic.com; "+
"script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
"img-src 'self' data: https://cdn.simpleicons.org; "+
imgSrc+"; "+
"connect-src 'self'; "+
"frame-ancestors *") // Allow embedding from any origin
"frame-ancestors *")
} else {
// Strict headers for non-embed routes
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("Content-Security-Policy",
"default-src 'self'; "+
"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https://fonts.googleapis.com; "+
"font-src 'self' https://fonts.gstatic.com; "+
"script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
"img-src 'self' data: https://cdn.simpleicons.org; "+
imgSrc+"; "+
"connect-src 'self'; "+
"frame-ancestors 'self'")
}
@@ -56,14 +45,12 @@ func SecureHeaders(next http.Handler) http.Handler {
})
}
// RequestLogger Minimal structured logging without PII; record latency and status for ops visibility.
func RequestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sr := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
start := time.Now()
next.ServeHTTP(sr, r)
duration := time.Since(start)
// Minimal structured log to avoid PII
logger.Logger.Info("http_request",
"method", r.Method,
"path", r.URL.Path,
@@ -81,8 +68,3 @@ func (sr *statusRecorder) WriteHeader(code int) {
sr.status = code
sr.ResponseWriter.WriteHeader(code)
}
type ErrorResponse struct {
Error string `json:"error"`
Message string `json:"message,omitempty"`
}

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"net/http"
"net/url"
"strings"
"github.com/btouchard/ackify-ce/backend/pkg/logger"
)
@@ -91,29 +90,3 @@ func HandleOEmbed(baseURL string) http.HandlerFunc {
"remote_addr", r.RemoteAddr)
}
}
// ValidateOEmbedURL checks if the provided URL is a valid Ackify document URL
func ValidateOEmbedURL(urlStr string, baseURL string) bool {
parsedURL, err := url.Parse(urlStr)
if err != nil {
return false
}
baseURLParsed, err := url.Parse(baseURL)
if err != nil {
return false
}
// Normalize hosts for comparison (remove ports if present)
urlHost := strings.Split(parsedURL.Host, ":")[0]
baseHost := strings.Split(baseURLParsed.Host, ":")[0]
// Allow localhost variations
if urlHost == "localhost" || urlHost == "127.0.0.1" {
if baseHost == "localhost" || baseHost == "127.0.0.1" {
return true
}
}
return urlHost == baseHost
}

View File

@@ -7,9 +7,10 @@ import "strings"
// This is the canonical user representation used by auth providers, domain models,
// and API handlers.
type User struct {
Sub string `json:"sub"` // Unique identifier (OAuth sub claim or email for MagicLink)
Email string `json:"email"` // User's email address
Name string `json:"name"` // Display name (optional)
Sub string `json:"sub"` // Unique identifier (OAuth sub claim or email for MagicLink)
Email string `json:"email"` // User's email address
Name string `json:"name"` // Display name (optional)
Picture string `json:"picture,omitempty"` // Avatar URL from OAuth provider (optional)
}
// IsValid returns true if the user has required fields populated.

View File

@@ -419,6 +419,15 @@ func (p *Provider) parseUserInfo(resp *http.Response) (*types.User, error) {
user.Name = preferredName
}
// Extract avatar URL (picture for Google/OIDC, avatar_url for GitHub/GitLab)
if picture, ok := rawUser["picture"].(string); ok && picture != "" {
user.Picture = picture
} else if avatarURL, ok := rawUser["avatar_url"].(string); ok && avatarURL != "" {
user.Picture = avatarURL
} else if photo, ok := rawUser["photo"].(string); ok && photo != "" {
user.Picture = photo
}
return user, nil
}

View File

@@ -33,7 +33,6 @@ import (
sdk "github.com/btouchard/shm/sdk/golang"
)
// Server represents the HTTP server with all its dependencies.
type Server struct {
httpServer *http.Server
db *sql.DB
@@ -90,7 +89,6 @@ type ServerBuilder struct {
configService *services.ConfigService
}
// NewServerBuilder creates a new server builder with the required configuration.
func NewServerBuilder(cfg *config.Config, frontend embed.FS, version string) *ServerBuilder {
return &ServerBuilder{
cfg: cfg,
@@ -205,7 +203,6 @@ func (b *ServerBuilder) Build(ctx context.Context) (*Server, error) {
}, nil
}
// validateProviders checks that required providers are set.
func (b *ServerBuilder) validateProviders() error {
if b.db == nil {
return errors.New("database is required: use WithDB()")
@@ -238,7 +235,6 @@ func (b *ServerBuilder) setDefaultProviders() {
}
}
// initializeInfrastructure initializes signer, i18n, email and storage.
func (b *ServerBuilder) initializeInfrastructure() error {
var err error
@@ -297,7 +293,6 @@ type repositories struct {
magicLink services.MagicLinkRepository
}
// createRepositories creates all repository instances.
func (b *ServerBuilder) createRepositories() *repositories {
return &repositories{
signature: database.NewSignatureRepository(b.db, b.tenantProvider),
@@ -336,7 +331,6 @@ func (b *ServerBuilder) initializeTelemetry(ctx context.Context) error {
return nil
}
// initializeWebhookSystem initializes webhook publisher and worker.
func (b *ServerBuilder) initializeWebhookSystem(ctx context.Context, repos *repositories) (*services.WebhookPublisher, *webhook.Worker, error) {
whPublisher := services.NewWebhookPublisher(repos.webhook, repos.webhookDelivery)
whCfg := webhook.DefaultWorkerConfig()
@@ -349,7 +343,6 @@ func (b *ServerBuilder) initializeWebhookSystem(ctx context.Context, repos *repo
return whPublisher, whWorker, nil
}
// initializeEmailWorker initializes email worker for async processing.
// emailRenderer is expected to be injected from main.go via WithEmailRenderer().
func (b *ServerBuilder) initializeEmailWorker(ctx context.Context, repos *repositories, whPublisher *services.WebhookPublisher) (*email.Worker, error) {
if b.emailSender == nil || b.cfg.Mail.Host == "" || b.emailRenderer == nil {
@@ -370,7 +363,6 @@ func (b *ServerBuilder) initializeEmailWorker(ctx context.Context, repos *reposi
return emailWorker, nil
}
// initializeCoreServices initializes signature, document, admin, and webhook services.
func (b *ServerBuilder) initializeCoreServices(repos *repositories) {
b.signatureService = services.NewSignatureService(repos.signature, repos.document, b.signer)
b.signatureService.SetChecksumConfig(&b.cfg.Checksum)
@@ -379,7 +371,6 @@ func (b *ServerBuilder) initializeCoreServices(repos *repositories) {
b.webhookService = services.NewWebhookService(repos.webhook, repos.webhookDelivery)
}
// initializeConfigService creates and initializes the configuration service.
func (b *ServerBuilder) initializeConfigService(ctx context.Context, repos *repositories) error {
encryptionKey := b.cfg.OAuth.CookieSecret
b.configService = services.NewConfigService(repos.config, b.cfg, encryptionKey)
@@ -423,7 +414,6 @@ func (b *ServerBuilder) initializeMagicLinkCleanupWorker(ctx context.Context) *w
return magicLinkWorker
}
// initializeReminderService initializes reminder service.
func (b *ServerBuilder) initializeReminderService(repos *repositories) {
b.reminderService = services.NewReminderAsyncService(
repos.expectedSigner,
@@ -435,7 +425,6 @@ func (b *ServerBuilder) initializeReminderService(repos *repositories) {
)
}
// initializeSessionWorker initializes OAuth session cleanup worker.
func (b *ServerBuilder) initializeSessionWorker(ctx context.Context, repos *repositories) (*auth.SessionWorker, error) {
if repos.oauthSession == nil {
return nil, nil
@@ -450,7 +439,6 @@ func (b *ServerBuilder) initializeSessionWorker(ctx context.Context, repos *repo
return sessionWorker, nil
}
// buildRouter creates and configures the main router.
func (b *ServerBuilder) buildRouter(repos *repositories, whPublisher *services.WebhookPublisher) *chi.Mux {
router := chi.NewRouter()
router.Use(i18n.Middleware(b.i18nService))
@@ -554,22 +542,18 @@ func (s *Server) GetDB() *sql.DB {
return s.db
}
// GetAuthProvider returns the auth provider.
func (s *Server) GetAuthProvider() AuthProvider {
return s.authProvider
}
// GetAuthorizer returns the authorizer.
func (s *Server) GetAuthorizer() Authorizer {
return s.authorizer
}
// GetQuotaEnforcer returns the quota enforcer.
func (s *Server) GetQuotaEnforcer() QuotaEnforcer {
return s.quotaEnforcer
}
// GetAuditLogger returns the audit logger.
func (s *Server) GetAuditLogger() AuditLogger {
return s.auditLogger
}

View File

@@ -234,7 +234,7 @@ watch(readMode, (newMode) => {
<div class="space-y-4">
<!-- URL input + Upload button + Submit button -->
<div class="flex gap-3" :class="mode === 'full' ? 'flex-col sm:flex-row' : ''">
<div class="flex-1">
<div class="flex-1 min-w-0">
<Label v-if="mode === 'full'" for="doc-url" class="mb-1.5">
{{ t('documentCreateForm.url.label') }}
</Label>
@@ -242,13 +242,13 @@ watch(readMode, (newMode) => {
<!-- Show selected file or URL input -->
<div v-if="selectedFile" class="flex items-center gap-2 px-4 py-2.5 rounded-lg border border-blue-300 dark:border-blue-700 bg-blue-50 dark:bg-blue-950/30" :class="mode === 'full' ? 'h-11' : 'h-12'">
<FileText class="w-4 h-4 text-blue-600 dark:text-blue-400 flex-shrink-0" />
<span data-testid="selected-file-name" class="flex-1 text-sm text-blue-800 dark:text-blue-200 truncate">{{ selectedFile.name }}</span>
<span class="text-xs text-blue-600 dark:text-blue-400">{{ formatFileSize(selectedFile.size) }}</span>
<span data-testid="selected-file-name" class="flex-1 min-w-0 text-sm text-blue-800 dark:text-blue-200 truncate">{{ selectedFile.name }}</span>
<span class="text-xs text-blue-600 dark:text-blue-400 flex-shrink-0 whitespace-nowrap">{{ formatFileSize(selectedFile.size) }}</span>
<button
type="button"
data-testid="clear-file-button"
@click="clearSelectedFile"
class="p-1 rounded hover:bg-blue-200 dark:hover:bg-blue-800 text-blue-600 dark:text-blue-400"
class="p-1 rounded hover:bg-blue-200 dark:hover:bg-blue-800 text-blue-600 dark:text-blue-400 flex-shrink-0"
:disabled="isSubmitting"
>
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">

View File

@@ -146,8 +146,18 @@ const closeUserMenu = () => {
aria-haspopup="true"
:aria-expanded="userMenuOpen"
>
<!-- User avatar with initials -->
<div class="w-8 h-8 rounded-lg bg-slate-100 dark:bg-slate-700 flex items-center justify-center text-xs font-semibold text-slate-600 dark:text-slate-300">
<!-- User avatar (picture or initials fallback) -->
<img
v-if="user?.picture"
:src="user.picture"
:alt="displayName"
class="w-8 h-8 rounded-lg object-cover"
referrerpolicy="no-referrer"
/>
<div
v-else
class="w-8 h-8 rounded-lg bg-slate-100 dark:bg-slate-700 flex items-center justify-center text-xs font-semibold text-slate-600 dark:text-slate-300"
>
{{ userInitials }}
</div>
<span class="text-slate-700 dark:text-slate-200 hidden lg:inline">{{ displayName }}</span>

View File

@@ -8,6 +8,7 @@ export interface User {
id: string
email: string
name: string
picture?: string
isAdmin: boolean
}