mirror of
https://github.com/btouchard/ackify.git
synced 2026-02-09 15:28:28 -06:00
refacto(backend): cleaning dead code
This commit is contained in:
@@ -1,42 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AuthorizerService provides authorization decisions for the organization.
|
||||
// In Community Edition, it uses environment variables for configuration.
|
||||
type AuthorizerService struct {
|
||||
adminEmails []string
|
||||
onlyAdminCanCreate bool
|
||||
}
|
||||
|
||||
// NewAuthorizerService creates a new AuthorizerService with the given configuration.
|
||||
func NewAuthorizerService(adminEmails []string, onlyAdminCanCreate bool) *AuthorizerService {
|
||||
// Normalize admin emails to lowercase for case-insensitive comparison
|
||||
normalized := make([]string, len(adminEmails))
|
||||
for i, email := range adminEmails {
|
||||
normalized[i] = strings.ToLower(strings.TrimSpace(email))
|
||||
}
|
||||
return &AuthorizerService{
|
||||
adminEmails: normalized,
|
||||
onlyAdminCanCreate: onlyAdminCanCreate,
|
||||
}
|
||||
}
|
||||
|
||||
// IsAdmin checks if the given email belongs to an administrator.
|
||||
func (s *AuthorizerService) IsAdmin(email string) bool {
|
||||
normalizedEmail := strings.ToLower(strings.TrimSpace(email))
|
||||
for _, adminEmail := range s.adminEmails {
|
||||
if normalizedEmail == adminEmail {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// OnlyAdminCanCreate returns whether only administrators can create documents.
|
||||
func (s *AuthorizerService) OnlyAdminCanCreate() bool {
|
||||
return s.onlyAdminCanCreate
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
// ChecksumVerificationRepository defines the interface for checksum verification persistence
|
||||
type ChecksumVerificationRepository interface {
|
||||
RecordVerification(ctx context.Context, verification *models.ChecksumVerification) error
|
||||
GetVerificationHistory(ctx context.Context, docID string, limit int) ([]*models.ChecksumVerification, error)
|
||||
GetLastVerification(ctx context.Context, docID string) (*models.ChecksumVerification, error)
|
||||
}
|
||||
|
||||
// DocumentRepository defines the interface for document metadata operations
|
||||
type DocumentRepository interface {
|
||||
GetByDocID(ctx context.Context, docID string) (*models.Document, error)
|
||||
}
|
||||
|
||||
// ChecksumService orchestrates document integrity verification with audit trail persistence
|
||||
type ChecksumService struct {
|
||||
verificationRepo ChecksumVerificationRepository
|
||||
documentRepo DocumentRepository
|
||||
}
|
||||
|
||||
// NewChecksumService initializes checksum verification service with required repository dependencies
|
||||
func NewChecksumService(
|
||||
verificationRepo ChecksumVerificationRepository,
|
||||
documentRepo DocumentRepository,
|
||||
) *ChecksumService {
|
||||
return &ChecksumService{
|
||||
verificationRepo: verificationRepo,
|
||||
documentRepo: documentRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateChecksumFormat ensures checksum matches expected hexadecimal length for SHA-256/SHA-512/MD5
|
||||
func (s *ChecksumService) ValidateChecksumFormat(checksum, algorithm string) error {
|
||||
// Remove common separators and whitespace
|
||||
checksum = normalizeChecksum(checksum)
|
||||
|
||||
var expectedLength int
|
||||
switch algorithm {
|
||||
case "SHA-256":
|
||||
expectedLength = 64
|
||||
case "SHA-512":
|
||||
expectedLength = 128
|
||||
case "MD5":
|
||||
expectedLength = 32
|
||||
default:
|
||||
return fmt.Errorf("unsupported algorithm: %s", algorithm)
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(checksum) != expectedLength {
|
||||
return fmt.Errorf("invalid checksum length for %s: expected %d hexadecimal characters, got %d", algorithm, expectedLength, len(checksum))
|
||||
}
|
||||
|
||||
// Check if it's a valid hex string
|
||||
hexPattern := regexp.MustCompile("^[a-fA-F0-9]+$")
|
||||
if !hexPattern.MatchString(checksum) {
|
||||
return fmt.Errorf("invalid checksum format: must contain only hexadecimal characters (0-9, a-f, A-F)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyChecksum compares calculated hash against stored reference and creates immutable audit record
|
||||
func (s *ChecksumService) VerifyChecksum(ctx context.Context, docID, calculatedChecksum, verifiedBy string) (*models.ChecksumVerificationResult, error) {
|
||||
// Get document metadata
|
||||
doc, err := s.documentRepo.GetByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get document: %w", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
return nil, fmt.Errorf("document not found: %s", docID)
|
||||
}
|
||||
|
||||
// Normalize checksums for comparison
|
||||
normalizedCalculated := normalizeChecksum(calculatedChecksum)
|
||||
normalizedStored := normalizeChecksum(doc.Checksum)
|
||||
|
||||
// Determine the algorithm to use (from document or default to SHA-256)
|
||||
algorithm := doc.ChecksumAlgorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA-256"
|
||||
}
|
||||
|
||||
// Validate the calculated checksum format
|
||||
if err := s.ValidateChecksumFormat(normalizedCalculated, algorithm); err != nil {
|
||||
// Record failed verification with error
|
||||
errorMsg := err.Error()
|
||||
verification := &models.ChecksumVerification{
|
||||
DocID: docID,
|
||||
VerifiedBy: verifiedBy,
|
||||
VerifiedAt: time.Now(),
|
||||
StoredChecksum: normalizedStored,
|
||||
CalculatedChecksum: normalizedCalculated,
|
||||
Algorithm: algorithm,
|
||||
IsValid: false,
|
||||
ErrorMessage: &errorMsg,
|
||||
}
|
||||
_ = s.verificationRepo.RecordVerification(ctx, verification)
|
||||
|
||||
return nil, fmt.Errorf("invalid checksum format: %w", err)
|
||||
}
|
||||
|
||||
// Check if document has a reference checksum
|
||||
if !doc.HasChecksum() {
|
||||
result := &models.ChecksumVerificationResult{
|
||||
Valid: false,
|
||||
StoredChecksum: "",
|
||||
CalculatedChecksum: normalizedCalculated,
|
||||
Algorithm: algorithm,
|
||||
Message: "No reference checksum configured for this document",
|
||||
HasReferenceHash: false,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Compare checksums (case-insensitive)
|
||||
isValid := strings.EqualFold(normalizedCalculated, normalizedStored)
|
||||
|
||||
// Record verification
|
||||
verification := &models.ChecksumVerification{
|
||||
DocID: docID,
|
||||
VerifiedBy: verifiedBy,
|
||||
VerifiedAt: time.Now(),
|
||||
StoredChecksum: normalizedStored,
|
||||
CalculatedChecksum: normalizedCalculated,
|
||||
Algorithm: algorithm,
|
||||
IsValid: isValid,
|
||||
ErrorMessage: nil,
|
||||
}
|
||||
|
||||
if err := s.verificationRepo.RecordVerification(ctx, verification); err != nil {
|
||||
logger.Logger.Error("Failed to record verification", "error", err.Error(), "doc_id", docID)
|
||||
// Continue even if recording fails - return the result
|
||||
}
|
||||
|
||||
var message string
|
||||
if isValid {
|
||||
message = "Checksums match - document integrity verified"
|
||||
} else {
|
||||
message = "Checksums do not match - document may have been modified"
|
||||
}
|
||||
|
||||
result := &models.ChecksumVerificationResult{
|
||||
Valid: isValid,
|
||||
StoredChecksum: normalizedStored,
|
||||
CalculatedChecksum: normalizedCalculated,
|
||||
Algorithm: algorithm,
|
||||
Message: message,
|
||||
HasReferenceHash: true,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetVerificationHistory retrieves paginated audit trail of all checksum validation attempts
|
||||
func (s *ChecksumService) GetVerificationHistory(ctx context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
return s.verificationRepo.GetVerificationHistory(ctx, docID, limit)
|
||||
}
|
||||
|
||||
// GetSupportedAlgorithms returns available hash algorithms for client-side documentation
|
||||
func (s *ChecksumService) GetSupportedAlgorithms() []string {
|
||||
return []string{"SHA-256", "SHA-512", "MD5"}
|
||||
}
|
||||
|
||||
// GetChecksumInfo exposes document hash metadata for public verification interfaces
|
||||
func (s *ChecksumService) GetChecksumInfo(ctx context.Context, docID string) (map[string]interface{}, error) {
|
||||
doc, err := s.documentRepo.GetByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get document: %w", err)
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
return nil, fmt.Errorf("document not found: %s", docID)
|
||||
}
|
||||
|
||||
algorithm := doc.ChecksumAlgorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "SHA-256"
|
||||
}
|
||||
|
||||
info := map[string]interface{}{
|
||||
"doc_id": docID,
|
||||
"has_checksum": doc.HasChecksum(),
|
||||
"algorithm": algorithm,
|
||||
"checksum_length": doc.GetExpectedChecksumLength(),
|
||||
"supported_algorithms": s.GetSupportedAlgorithms(),
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// normalizeChecksum removes common separators and converts to lowercase
|
||||
func normalizeChecksum(checksum string) string {
|
||||
// Remove spaces, hyphens, underscores
|
||||
checksum = strings.ReplaceAll(checksum, " ", "")
|
||||
checksum = strings.ReplaceAll(checksum, "-", "")
|
||||
checksum = strings.ReplaceAll(checksum, "_", "")
|
||||
checksum = strings.TrimSpace(checksum)
|
||||
// Convert to lowercase for case-insensitive comparison
|
||||
return strings.ToLower(checksum)
|
||||
}
|
||||
@@ -1,483 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
type fakeVerificationRepository struct {
|
||||
verifications []*models.ChecksumVerification
|
||||
shouldFailRecord bool
|
||||
shouldFailGetHistory bool
|
||||
shouldFailGetLast bool
|
||||
}
|
||||
|
||||
func newFakeVerificationRepository() *fakeVerificationRepository {
|
||||
return &fakeVerificationRepository{
|
||||
verifications: make([]*models.ChecksumVerification, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeVerificationRepository) RecordVerification(_ context.Context, verification *models.ChecksumVerification) error {
|
||||
if f.shouldFailRecord {
|
||||
return errors.New("repository record failed")
|
||||
}
|
||||
|
||||
verification.ID = int64(len(f.verifications) + 1)
|
||||
f.verifications = append(f.verifications, verification)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeVerificationRepository) GetVerificationHistory(_ context.Context, docID string, limit int) ([]*models.ChecksumVerification, error) {
|
||||
if f.shouldFailGetHistory {
|
||||
return nil, errors.New("repository get history failed")
|
||||
}
|
||||
|
||||
var result []*models.ChecksumVerification
|
||||
for _, v := range f.verifications {
|
||||
if v.DocID == docID {
|
||||
result = append(result, v)
|
||||
if len(result) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (f *fakeVerificationRepository) GetLastVerification(_ context.Context, docID string) (*models.ChecksumVerification, error) {
|
||||
if f.shouldFailGetLast {
|
||||
return nil, errors.New("repository get last failed")
|
||||
}
|
||||
|
||||
for i := len(f.verifications) - 1; i >= 0; i-- {
|
||||
if f.verifications[i].DocID == docID {
|
||||
return f.verifications[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type fakeDocumentRepository struct {
|
||||
documents map[string]*models.Document
|
||||
shouldFailGet bool
|
||||
}
|
||||
|
||||
func newFakeDocumentRepository() *fakeDocumentRepository {
|
||||
return &fakeDocumentRepository{
|
||||
documents: make(map[string]*models.Document),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) GetByDocID(_ context.Context, docID string) (*models.Document, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository get failed")
|
||||
}
|
||||
|
||||
doc, exists := f.documents[docID]
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) Create(_ context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository create failed")
|
||||
}
|
||||
|
||||
doc := &models.Document{
|
||||
DocID: docID,
|
||||
Title: input.Title,
|
||||
URL: input.URL,
|
||||
Checksum: input.Checksum,
|
||||
ChecksumAlgorithm: input.ChecksumAlgorithm,
|
||||
Description: input.Description,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
f.documents[docID] = doc
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) FindByReference(_ context.Context, ref string, refType string) (*models.Document, error) {
|
||||
if f.shouldFailGet {
|
||||
return nil, errors.New("repository find failed")
|
||||
}
|
||||
|
||||
for _, doc := range f.documents {
|
||||
if doc.URL == ref {
|
||||
return doc, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) List(_ context.Context, _, _ int) ([]*models.Document, error) {
|
||||
result := make([]*models.Document, 0, len(f.documents))
|
||||
for _, doc := range f.documents {
|
||||
result = append(result, doc)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) Search(_ context.Context, _ string, _, _ int) ([]*models.Document, error) {
|
||||
return []*models.Document{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) Count(_ context.Context, _ string) (int, error) {
|
||||
return len(f.documents), nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) ListByCreatedBy(_ context.Context, _ string, _, _ int) ([]*models.Document, error) {
|
||||
return []*models.Document{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) SearchByCreatedBy(_ context.Context, _, _ string, _, _ int) ([]*models.Document, error) {
|
||||
return []*models.Document{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) CountByCreatedBy(_ context.Context, _, _ string) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestChecksumService_ValidateChecksumFormat(t *testing.T) {
|
||||
service := NewChecksumService(newFakeVerificationRepository(), newFakeDocumentRepository())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
checksum string
|
||||
algorithm string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid SHA-256",
|
||||
checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
algorithm: "SHA-256",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid SHA-512",
|
||||
checksum: "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
|
||||
algorithm: "SHA-512",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid MD5",
|
||||
checksum: "d41d8cd98f00b204e9800998ecf8427e",
|
||||
algorithm: "MD5",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid with uppercase",
|
||||
checksum: "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855",
|
||||
algorithm: "SHA-256",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid with spaces",
|
||||
checksum: "e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855",
|
||||
algorithm: "SHA-256",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid with hyphens",
|
||||
checksum: "e3b0c442-98fc1c14-9afbf4c8-996fb924-27ae41e4-649b934c-a495991b-7852b855",
|
||||
algorithm: "SHA-256",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid - too short for SHA-256",
|
||||
checksum: "abc123",
|
||||
algorithm: "SHA-256",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - too long for MD5",
|
||||
checksum: "d41d8cd98f00b204e9800998ecf8427eextra",
|
||||
algorithm: "MD5",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - non-hex characters",
|
||||
checksum: "gggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg",
|
||||
algorithm: "SHA-256",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - unsupported algorithm",
|
||||
checksum: "abc123",
|
||||
algorithm: "SHA-1",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := service.ValidateChecksumFormat(tt.checksum, tt.algorithm)
|
||||
|
||||
if tt.wantError && err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
if !tt.wantError && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumService_VerifyChecksum(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
document *models.Document
|
||||
calculatedChecksum string
|
||||
verifiedBy string
|
||||
wantValid bool
|
||||
wantHasReference bool
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid verification - checksums match",
|
||||
docID: "doc-001",
|
||||
document: &models.Document{
|
||||
DocID: "doc-001",
|
||||
Checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: true,
|
||||
wantHasReference: true,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid verification - checksums differ",
|
||||
docID: "doc-002",
|
||||
document: &models.Document{
|
||||
DocID: "doc-002",
|
||||
Checksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
calculatedChecksum: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: false,
|
||||
wantHasReference: true,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "no reference checksum",
|
||||
docID: "doc-003",
|
||||
document: &models.Document{
|
||||
DocID: "doc-003",
|
||||
Checksum: "",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: false,
|
||||
wantHasReference: false,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "case insensitive comparison",
|
||||
docID: "doc-004",
|
||||
document: &models.Document{
|
||||
DocID: "doc-004",
|
||||
Checksum: "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: true,
|
||||
wantHasReference: true,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "document not found",
|
||||
docID: "non-existent",
|
||||
document: nil,
|
||||
calculatedChecksum: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
verifiedBy: "user@example.com",
|
||||
wantValid: false,
|
||||
wantHasReference: false,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verificationRepo := newFakeVerificationRepository()
|
||||
documentRepo := newFakeDocumentRepository()
|
||||
|
||||
if tt.document != nil {
|
||||
documentRepo.documents[tt.docID] = tt.document
|
||||
}
|
||||
|
||||
service := NewChecksumService(verificationRepo, documentRepo)
|
||||
|
||||
result, err := service.VerifyChecksum(ctx, tt.docID, tt.calculatedChecksum, tt.verifiedBy)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Valid != tt.wantValid {
|
||||
t.Errorf("expected Valid=%v, got %v", tt.wantValid, result.Valid)
|
||||
}
|
||||
|
||||
if result.HasReferenceHash != tt.wantHasReference {
|
||||
t.Errorf("expected HasReferenceHash=%v, got %v", tt.wantHasReference, result.HasReferenceHash)
|
||||
}
|
||||
|
||||
// Check that verification was recorded (if document has checksum)
|
||||
if tt.wantHasReference {
|
||||
if len(verificationRepo.verifications) != 1 {
|
||||
t.Errorf("expected 1 verification recorded, got %d", len(verificationRepo.verifications))
|
||||
} else {
|
||||
v := verificationRepo.verifications[0]
|
||||
if v.IsValid != tt.wantValid {
|
||||
t.Errorf("recorded verification IsValid=%v, expected %v", v.IsValid, tt.wantValid)
|
||||
}
|
||||
if v.VerifiedBy != tt.verifiedBy {
|
||||
t.Errorf("recorded verification VerifiedBy=%s, expected %s", v.VerifiedBy, tt.verifiedBy)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumService_GetVerificationHistory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
verificationRepo := newFakeVerificationRepository()
|
||||
documentRepo := newFakeDocumentRepository()
|
||||
service := NewChecksumService(verificationRepo, documentRepo)
|
||||
|
||||
// Add test verifications
|
||||
for i := 0; i < 5; i++ {
|
||||
v := &models.ChecksumVerification{
|
||||
DocID: "doc-001",
|
||||
VerifiedBy: "user@example.com",
|
||||
VerifiedAt: time.Now(),
|
||||
StoredChecksum: "abc123",
|
||||
CalculatedChecksum: "abc123",
|
||||
Algorithm: "SHA-256",
|
||||
IsValid: true,
|
||||
}
|
||||
_ = verificationRepo.RecordVerification(ctx, v)
|
||||
}
|
||||
|
||||
// Test get all
|
||||
history, err := service.GetVerificationHistory(ctx, "doc-001", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(history) != 5 {
|
||||
t.Errorf("expected 5 verifications, got %d", len(history))
|
||||
}
|
||||
|
||||
// Test with limit
|
||||
limited, err := service.GetVerificationHistory(ctx, "doc-001", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(limited) != 2 {
|
||||
t.Errorf("expected 2 verifications with limit, got %d", len(limited))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumService_GetChecksumInfo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
docID string
|
||||
document *models.Document
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "document with checksum",
|
||||
docID: "doc-001",
|
||||
document: &models.Document{
|
||||
DocID: "doc-001",
|
||||
Checksum: "abc123",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "document without checksum",
|
||||
docID: "doc-002",
|
||||
document: &models.Document{
|
||||
DocID: "doc-002",
|
||||
Checksum: "",
|
||||
ChecksumAlgorithm: "SHA-256",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "document not found",
|
||||
docID: "non-existent",
|
||||
document: nil,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
documentRepo := newFakeDocumentRepository()
|
||||
if tt.document != nil {
|
||||
documentRepo.documents[tt.docID] = tt.document
|
||||
}
|
||||
|
||||
service := NewChecksumService(newFakeVerificationRepository(), documentRepo)
|
||||
|
||||
info, err := service.GetChecksumInfo(ctx, tt.docID)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if info["doc_id"] != tt.docID {
|
||||
t.Errorf("expected doc_id %s, got %v", tt.docID, info["doc_id"])
|
||||
}
|
||||
|
||||
if _, ok := info["has_checksum"]; !ok {
|
||||
t.Error("expected has_checksum field")
|
||||
}
|
||||
|
||||
if _, ok := info["supported_algorithms"]; !ok {
|
||||
t.Error("expected supported_algorithms field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,238 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/email"
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
// expectedSignerRepository defines minimal interface for expected signer operations
|
||||
type expectedSignerRepository interface {
|
||||
ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error)
|
||||
}
|
||||
|
||||
// reminderRepository defines minimal interface for reminder logging and history
|
||||
type reminderRepository interface {
|
||||
LogReminder(ctx context.Context, log *models.ReminderLog) error
|
||||
GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error)
|
||||
GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error)
|
||||
}
|
||||
|
||||
// magicLinkService defines minimal interface for creating reminder auth tokens
|
||||
type magicLinkService interface {
|
||||
CreateReminderAuthToken(ctx context.Context, email string, docID string) (string, error)
|
||||
}
|
||||
|
||||
// ReminderService manages email notifications to pending signers with delivery tracking
|
||||
type ReminderService struct {
|
||||
expectedSignerRepo expectedSignerRepository
|
||||
reminderRepo reminderRepository
|
||||
emailSender email.Sender
|
||||
magicLinkService magicLinkService
|
||||
i18n *i18n.I18n
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewReminderService initializes reminder service with email sender and repository dependencies
|
||||
func NewReminderService(
|
||||
expectedSignerRepo expectedSignerRepository,
|
||||
reminderRepo reminderRepository,
|
||||
emailSender email.Sender,
|
||||
magicLinkService magicLinkService,
|
||||
i18nService *i18n.I18n,
|
||||
baseURL string,
|
||||
) *ReminderService {
|
||||
return &ReminderService{
|
||||
expectedSignerRepo: expectedSignerRepo,
|
||||
reminderRepo: reminderRepo,
|
||||
emailSender: emailSender,
|
||||
magicLinkService: magicLinkService,
|
||||
i18n: i18nService,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// SendReminders dispatches email notifications to all or selected pending signers with result aggregation
|
||||
func (s *ReminderService) SendReminders(
|
||||
ctx context.Context,
|
||||
docID string,
|
||||
sentBy string,
|
||||
specificEmails []string,
|
||||
docURL string,
|
||||
locale string,
|
||||
) (*models.ReminderSendResult, error) {
|
||||
|
||||
logger.Logger.Info("Starting reminder sending process",
|
||||
"doc_id", docID,
|
||||
"sent_by", sentBy,
|
||||
"specific_emails_count", len(specificEmails),
|
||||
"locale", locale)
|
||||
|
||||
allSigners, err := s.expectedSignerRepo.ListWithStatusByDocID(ctx, docID)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to get expected signers for reminders",
|
||||
"doc_id", docID,
|
||||
"error", err.Error())
|
||||
return nil, fmt.Errorf("failed to get expected signers: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Retrieved expected signers",
|
||||
"doc_id", docID,
|
||||
"total_signers", len(allSigners))
|
||||
|
||||
var pendingSigners []*models.ExpectedSignerWithStatus
|
||||
for _, signer := range allSigners {
|
||||
if !signer.HasSigned {
|
||||
if len(specificEmails) > 0 {
|
||||
if containsEmail(specificEmails, signer.Email) {
|
||||
pendingSigners = append(pendingSigners, signer)
|
||||
}
|
||||
} else {
|
||||
pendingSigners = append(pendingSigners, signer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Info("Identified pending signers",
|
||||
"doc_id", docID,
|
||||
"pending_count", len(pendingSigners),
|
||||
"total_signers", len(allSigners))
|
||||
|
||||
if len(pendingSigners) == 0 {
|
||||
logger.Logger.Info("No pending signers found, no reminders to send",
|
||||
"doc_id", docID)
|
||||
return &models.ReminderSendResult{
|
||||
TotalAttempted: 0,
|
||||
SuccessfullySent: 0,
|
||||
Failed: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
result := &models.ReminderSendResult{
|
||||
TotalAttempted: len(pendingSigners),
|
||||
}
|
||||
|
||||
for _, signer := range pendingSigners {
|
||||
err := s.sendSingleReminder(ctx, docID, signer.Email, signer.Name, sentBy, docURL, locale)
|
||||
if err != nil {
|
||||
result.Failed++
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", signer.Email, err))
|
||||
} else {
|
||||
result.SuccessfullySent++
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Info("Reminder batch completed",
|
||||
"doc_id", docID,
|
||||
"total_attempted", result.TotalAttempted,
|
||||
"successfully_sent", result.SuccessfullySent,
|
||||
"failed", result.Failed)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sendSingleReminder sends a reminder to a single signer
|
||||
func (s *ReminderService) sendSingleReminder(
|
||||
ctx context.Context,
|
||||
docID string,
|
||||
recipientEmail string,
|
||||
recipientName string,
|
||||
sentBy string,
|
||||
docURL string,
|
||||
locale string,
|
||||
) error {
|
||||
|
||||
logger.Logger.Debug("Sending reminder to signer",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"recipient_name", recipientName,
|
||||
"sent_by", sentBy)
|
||||
|
||||
// Générer un token d'authentification pour ce lecteur
|
||||
token, err := s.magicLinkService.CreateReminderAuthToken(ctx, recipientEmail, docID)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to create reminder auth token",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"error", err.Error())
|
||||
return fmt.Errorf("failed to create auth token: %w", err)
|
||||
}
|
||||
|
||||
// Construire l'URL d'authentification qui redirigera vers la page de signature
|
||||
authSignURL := fmt.Sprintf("%s/api/v1/auth/reminder-link/verify?token=%s", s.baseURL, token)
|
||||
|
||||
logger.Logger.Debug("Generated auth sign URL for reminder",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"url", authSignURL)
|
||||
|
||||
log := &models.ReminderLog{
|
||||
DocID: docID,
|
||||
RecipientEmail: recipientEmail,
|
||||
SentAt: time.Now(),
|
||||
SentBy: sentBy,
|
||||
TemplateUsed: "signature_reminder",
|
||||
Status: "sent",
|
||||
}
|
||||
|
||||
err = email.SendSignatureReminderEmail(ctx, s.emailSender, s.i18n, []string{recipientEmail}, locale, docID, docURL, authSignURL, recipientName)
|
||||
if err != nil {
|
||||
log.Status = "failed"
|
||||
errMsg := err.Error()
|
||||
log.ErrorMessage = &errMsg
|
||||
|
||||
logger.Logger.Warn("Failed to send reminder email",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"error", err.Error())
|
||||
|
||||
if logErr := s.reminderRepo.LogReminder(ctx, log); logErr != nil {
|
||||
logger.Logger.Error("Failed to log reminder error",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"log_error", logErr.Error(),
|
||||
"original_error", err.Error())
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("Reminder email sent successfully",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail)
|
||||
|
||||
if err := s.reminderRepo.LogReminder(ctx, log); err != nil {
|
||||
logger.Logger.Error("Failed to log successful reminder",
|
||||
"doc_id", docID,
|
||||
"recipient_email", recipientEmail,
|
||||
"error", err.Error())
|
||||
return fmt.Errorf("email sent but failed to log: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReminderStats retrieves aggregated reminder metrics for monitoring dashboard
|
||||
func (s *ReminderService) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) {
|
||||
return s.reminderRepo.GetReminderStats(ctx, docID)
|
||||
}
|
||||
|
||||
// GetReminderHistory retrieves complete email send log with success/failure tracking
|
||||
func (s *ReminderService) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
|
||||
return s.reminderRepo.GetReminderHistory(ctx, docID)
|
||||
}
|
||||
|
||||
func containsEmail(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -315,3 +315,12 @@ func (s *ReminderAsyncService) CountSent(ctx context.Context) int {
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func containsEmail(emails []string, target string) bool {
|
||||
for _, e := range emails {
|
||||
if e == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,536 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/email"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type mockExpectedSignerRepository struct {
|
||||
listWithStatusFunc func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error)
|
||||
}
|
||||
|
||||
func (m *mockExpectedSignerRepository) ListWithStatusByDocID(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
if m.listWithStatusFunc != nil {
|
||||
return m.listWithStatusFunc(ctx, docID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockReminderRepository struct {
|
||||
logReminderFunc func(ctx context.Context, log *models.ReminderLog) error
|
||||
getReminderHistoryFunc func(ctx context.Context, docID string) ([]*models.ReminderLog, error)
|
||||
getReminderStatsFunc func(ctx context.Context, docID string) (*models.ReminderStats, error)
|
||||
}
|
||||
|
||||
func (m *mockReminderRepository) LogReminder(ctx context.Context, log *models.ReminderLog) error {
|
||||
if m.logReminderFunc != nil {
|
||||
return m.logReminderFunc(ctx, log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReminderRepository) GetReminderHistory(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
|
||||
if m.getReminderHistoryFunc != nil {
|
||||
return m.getReminderHistoryFunc(ctx, docID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockReminderRepository) GetReminderStats(ctx context.Context, docID string) (*models.ReminderStats, error) {
|
||||
if m.getReminderStatsFunc != nil {
|
||||
return m.getReminderStatsFunc(ctx, docID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockEmailSender struct {
|
||||
sendFunc func(ctx context.Context, msg email.Message) error
|
||||
}
|
||||
|
||||
func (m *mockEmailSender) Send(ctx context.Context, msg email.Message) error {
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(ctx, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockMagicLinkService struct {
|
||||
createReminderAuthTokenFunc func(ctx context.Context, email string, docID string) (string, error)
|
||||
}
|
||||
|
||||
func (m *mockMagicLinkService) CreateReminderAuthToken(ctx context.Context, email string, docID string) (string, error) {
|
||||
if m.createReminderAuthTokenFunc != nil {
|
||||
return m.createReminderAuthTokenFunc(ctx, email, docID)
|
||||
}
|
||||
return "mock-token-123", nil
|
||||
}
|
||||
|
||||
// Test helper function
|
||||
func TestContainsEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
item string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Email found",
|
||||
slice: []string{"alice@example.com", "bob@example.com", "charlie@example.com"},
|
||||
item: "bob@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Email not found",
|
||||
slice: []string{"alice@example.com", "bob@example.com"},
|
||||
item: "charlie@example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
slice: []string{},
|
||||
item: "test@example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive",
|
||||
slice: []string{"Test@Example.com"},
|
||||
item: "test@example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := containsEmail(tt.slice, tt.item)
|
||||
if result != tt.expected {
|
||||
t.Errorf("containsEmail(%v, %q) = %v, want %v", tt.slice, tt.item, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with no pending signers
|
||||
func TestReminderService_SendReminders_NoPendingSigners(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "signed@example.com"}, HasSigned: true},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{}
|
||||
mockEmailSender := &mockEmailSender{}
|
||||
mockMagicLink := &mockMagicLinkService{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 0 {
|
||||
t.Errorf("Expected 0 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 0 {
|
||||
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with successful email send
|
||||
func TestReminderService_SendReminders_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
loggedReminder := false
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
loggedReminder = true
|
||||
if log.Status != "sent" {
|
||||
t.Errorf("Expected status 'sent', got '%s'", log.Status)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
emailSent := false
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
emailSent = true
|
||||
if len(msg.To) != 1 || msg.To[0] != "pending@example.com" {
|
||||
t.Errorf("Expected email to 'pending@example.com', got %v", msg.To)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
mockMagicLink := &mockMagicLinkService{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 1 {
|
||||
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 1 {
|
||||
t.Errorf("Expected 1 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
|
||||
if result.Failed != 0 {
|
||||
t.Errorf("Expected 0 failed, got %d", result.Failed)
|
||||
}
|
||||
|
||||
if !emailSent {
|
||||
t.Error("Expected email to be sent")
|
||||
}
|
||||
|
||||
if !loggedReminder {
|
||||
t.Error("Expected reminder to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with email failure
|
||||
func TestReminderService_SendReminders_EmailFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
loggedReminder := false
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
loggedReminder = true
|
||||
if log.Status != "failed" {
|
||||
t.Errorf("Expected status 'failed', got '%s'", log.Status)
|
||||
}
|
||||
if log.ErrorMessage == nil {
|
||||
t.Error("Expected error message to be set")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
return errors.New("SMTP connection failed")
|
||||
},
|
||||
}
|
||||
mockMagicLink := &mockMagicLinkService{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error from SendReminders, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 1 {
|
||||
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if result.Failed != 1 {
|
||||
t.Errorf("Expected 1 failed, got %d", result.Failed)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 0 {
|
||||
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
|
||||
if len(result.Errors) != 1 {
|
||||
t.Errorf("Expected 1 error message, got %d", len(result.Errors))
|
||||
}
|
||||
|
||||
if !loggedReminder {
|
||||
t.Error("Expected failed reminder to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with specific emails filter
|
||||
func TestReminderService_SendReminders_SpecificEmails(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending1@example.com"}, HasSigned: false},
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending2@example.com"}, HasSigned: false},
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending3@example.com"}, HasSigned: false},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
emailsSent := []string{}
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
emailsSent = append(emailsSent, msg.To[0])
|
||||
return nil
|
||||
},
|
||||
}
|
||||
mockMagicLink := &mockMagicLinkService{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
|
||||
|
||||
specificEmails := []string{"pending2@example.com"}
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", specificEmails, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 1 {
|
||||
t.Errorf("Expected 1 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if len(emailsSent) != 1 || emailsSent[0] != "pending2@example.com" {
|
||||
t.Errorf("Expected only 'pending2@example.com' to receive email, got %v", emailsSent)
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with repository error
|
||||
func TestReminderService_SendReminders_RepositoryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return nil, errors.New("database connection failed")
|
||||
},
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{}
|
||||
mockEmailSender := &mockEmailSender{}
|
||||
mockMagicLink := &mockMagicLinkService{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil result on error, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetReminderHistory
|
||||
func TestReminderService_GetReminderHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
expectedLogs := []*models.ReminderLog{
|
||||
{
|
||||
DocID: "doc1",
|
||||
RecipientEmail: "user@example.com",
|
||||
SentAt: time.Now(),
|
||||
SentBy: "admin@example.com",
|
||||
Status: "sent",
|
||||
},
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
getReminderHistoryFunc: func(ctx context.Context, docID string) ([]*models.ReminderLog, error) {
|
||||
if docID != "doc1" {
|
||||
t.Errorf("Expected docID 'doc1', got '%s'", docID)
|
||||
}
|
||||
return expectedLogs, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(&mockExpectedSignerRepository{}, mockReminderRepo, &mockEmailSender{}, &mockMagicLinkService{}, nil, "https://example.com")
|
||||
|
||||
logs, err := service.GetReminderHistory(ctx, "doc1")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if len(logs) != 1 {
|
||||
t.Errorf("Expected 1 log, got %d", len(logs))
|
||||
}
|
||||
|
||||
if logs[0].RecipientEmail != "user@example.com" {
|
||||
t.Errorf("Expected recipient 'user@example.com', got '%s'", logs[0].RecipientEmail)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetReminderStats
|
||||
func TestReminderService_GetReminderStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
expectedStats := &models.ReminderStats{
|
||||
TotalSent: 5,
|
||||
LastSentAt: &now,
|
||||
PendingCount: 2,
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
getReminderStatsFunc: func(ctx context.Context, docID string) (*models.ReminderStats, error) {
|
||||
if docID != "doc1" {
|
||||
t.Errorf("Expected docID 'doc1', got '%s'", docID)
|
||||
}
|
||||
return expectedStats, nil
|
||||
},
|
||||
}
|
||||
|
||||
service := NewReminderService(&mockExpectedSignerRepository{}, mockReminderRepo, &mockEmailSender{}, &mockMagicLinkService{}, nil, "https://example.com")
|
||||
|
||||
stats, err := service.GetReminderStats(ctx, "doc1")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if stats.TotalSent != 5 {
|
||||
t.Errorf("Expected 5 total sent, got %d", stats.TotalSent)
|
||||
}
|
||||
|
||||
if stats.PendingCount != 2 {
|
||||
t.Errorf("Expected 2 pending, got %d", stats.PendingCount)
|
||||
}
|
||||
|
||||
if stats.LastSentAt == nil {
|
||||
t.Error("Expected LastSentAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with multiple pending signers
|
||||
func TestReminderService_SendReminders_MultiplePending(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending1@example.com", Name: "User 1"}, HasSigned: false},
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending2@example.com", Name: "User 2"}, HasSigned: false},
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "already-signed@example.com", Name: "User 3"}, HasSigned: true},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
emailsSent := 0
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
emailsSent++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
mockMagicLink := &mockMagicLinkService{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if result.TotalAttempted != 2 {
|
||||
t.Errorf("Expected 2 total attempted, got %d", result.TotalAttempted)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 2 {
|
||||
t.Errorf("Expected 2 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
|
||||
if emailsSent != 2 {
|
||||
t.Errorf("Expected 2 emails sent, got %d", emailsSent)
|
||||
}
|
||||
}
|
||||
|
||||
// Test SendReminders with log failure after successful email
|
||||
func TestReminderService_SendReminders_LogFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mockExpectedRepo := &mockExpectedSignerRepository{
|
||||
listWithStatusFunc: func(ctx context.Context, docID string) ([]*models.ExpectedSignerWithStatus, error) {
|
||||
return []*models.ExpectedSignerWithStatus{
|
||||
{ExpectedSigner: models.ExpectedSigner{Email: "pending@example.com", Name: "Pending User"}, HasSigned: false},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockReminderRepo := &mockReminderRepository{
|
||||
logReminderFunc: func(ctx context.Context, log *models.ReminderLog) error {
|
||||
return errors.New("database write failed")
|
||||
},
|
||||
}
|
||||
|
||||
mockEmailSender := &mockEmailSender{
|
||||
sendFunc: func(ctx context.Context, msg email.Message) error {
|
||||
return nil // Email succeeds
|
||||
},
|
||||
}
|
||||
mockMagicLink := &mockMagicLinkService{}
|
||||
|
||||
service := NewReminderService(mockExpectedRepo, mockReminderRepo, mockEmailSender, mockMagicLink, nil, "https://example.com")
|
||||
|
||||
result, err := service.SendReminders(ctx, "doc1", "admin@example.com", nil, "https://example.com/doc.pdf", "en")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error from SendReminders, got: %v", err)
|
||||
}
|
||||
|
||||
// The send should fail because logging failed
|
||||
if result.Failed != 1 {
|
||||
t.Errorf("Expected 1 failed, got %d", result.Failed)
|
||||
}
|
||||
|
||||
if result.SuccessfullySent != 0 {
|
||||
t.Errorf("Expected 0 successfully sent, got %d", result.SuccessfullySent)
|
||||
}
|
||||
}
|
||||
@@ -172,6 +172,64 @@ func (f *fakeCryptoSigner) CreateSignature(ctx context.Context, docID string, us
|
||||
return payloadHash, signature, nil
|
||||
}
|
||||
|
||||
type fakeDocumentRepository struct {
|
||||
documents map[string]*models.Document
|
||||
}
|
||||
|
||||
func newFakeDocumentRepository() *fakeDocumentRepository {
|
||||
return &fakeDocumentRepository{
|
||||
documents: make(map[string]*models.Document),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) Create(_ context.Context, docID string, input models.DocumentInput, createdBy string) (*models.Document, error) {
|
||||
doc := &models.Document{
|
||||
DocID: docID,
|
||||
Title: input.Title,
|
||||
URL: input.URL,
|
||||
Checksum: input.Checksum,
|
||||
ChecksumAlgorithm: input.ChecksumAlgorithm,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
f.documents[docID] = doc
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) GetByDocID(_ context.Context, docID string) (*models.Document, error) {
|
||||
if doc, ok := f.documents[docID]; ok {
|
||||
return doc, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) FindByReference(_ context.Context, _ string, _ string) (*models.Document, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) List(_ context.Context, _, _ int) ([]*models.Document, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) Search(_ context.Context, _ string, _, _ int) ([]*models.Document, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) Count(_ context.Context, _ string) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) ListByCreatedBy(_ context.Context, _ string, _, _ int) ([]*models.Document, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) SearchByCreatedBy(_ context.Context, _, _ string, _, _ int) ([]*models.Document, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDocumentRepository) CountByCreatedBy(_ context.Context, _, _ string) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestNewSignatureService(t *testing.T) {
|
||||
repo := newFakeRepository()
|
||||
docRepo := newFakeDocumentRepository()
|
||||
|
||||
@@ -1,393 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/crypto"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
// OAuthProvider handles OAuth2 authentication flow
|
||||
// This component is optional and can be nil if OAuth is disabled
|
||||
type OAuthProvider struct {
|
||||
oauthConfig *oauth2.Config
|
||||
userInfoURL string
|
||||
logoutURL string
|
||||
allowedDomain string
|
||||
baseURL string
|
||||
sessionSvc *SessionService // Reference to session service for state management
|
||||
}
|
||||
|
||||
// OAuthProviderConfig holds configuration for the OAuth provider
|
||||
type OAuthProviderConfig struct {
|
||||
BaseURL string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
LogoutURL string
|
||||
Scopes []string
|
||||
AllowedDomain string
|
||||
SessionSvc *SessionService
|
||||
}
|
||||
|
||||
// NewOAuthProvider creates a new OAuth provider
|
||||
func NewOAuthProvider(config OAuthProviderConfig) *OAuthProvider {
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.BaseURL + "/api/v1/auth/callback",
|
||||
Scopes: config.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.AuthURL,
|
||||
TokenURL: config.TokenURL,
|
||||
},
|
||||
}
|
||||
|
||||
logger.Logger.Info("OAuth provider configured successfully")
|
||||
|
||||
return &OAuthProvider{
|
||||
oauthConfig: oauthConfig,
|
||||
userInfoURL: config.UserInfoURL,
|
||||
logoutURL: config.LogoutURL,
|
||||
allowedDomain: config.AllowedDomain,
|
||||
baseURL: config.BaseURL,
|
||||
sessionSvc: config.SessionSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// GetLogoutURL returns the SSO logout URL if configured
|
||||
func (p *OAuthProvider) GetLogoutURL() string {
|
||||
if p.logoutURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
logoutURL := p.logoutURL
|
||||
if p.baseURL != "" {
|
||||
logoutURL += "?continue=" + p.baseURL
|
||||
}
|
||||
|
||||
return logoutURL
|
||||
}
|
||||
|
||||
// CreateAuthURL creates an OAuth authorization URL with PKCE
|
||||
func (p *OAuthProvider) CreateAuthURL(w http.ResponseWriter, r *http.Request, nextURL string) string {
|
||||
// Generate PKCE code verifier and challenge
|
||||
codeVerifier, err := crypto.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to generate PKCE code verifier", "error", err.Error())
|
||||
// Fallback to OAuth flow without PKCE for backward compatibility
|
||||
return p.createAuthURLWithoutPKCE(w, r, nextURL)
|
||||
}
|
||||
|
||||
codeChallenge := crypto.GenerateCodeChallenge(codeVerifier)
|
||||
logger.Logger.Debug("Generated PKCE parameters for OAuth flow")
|
||||
|
||||
// Generate state token
|
||||
randPart := securecookie.GenerateRandomKey(20)
|
||||
token := base64.RawURLEncoding.EncodeToString(randPart)
|
||||
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
|
||||
|
||||
promptParam := "select_account"
|
||||
isSilent := r.URL.Query().Get("silent") == "true"
|
||||
if isSilent {
|
||||
promptParam = "none"
|
||||
}
|
||||
|
||||
logger.Logger.Info("Starting OAuth flow with PKCE",
|
||||
"next_url", nextURL,
|
||||
"silent", isSilent,
|
||||
"state_token_length", len(token))
|
||||
|
||||
session, err := p.sessionSvc.GetSession(r)
|
||||
if err != nil {
|
||||
logger.Logger.Error("CreateAuthURL: failed to get session from store", "error", err.Error())
|
||||
// Create a new empty session if Get fails
|
||||
session, _ = p.sessionSvc.GetNewSession(r)
|
||||
}
|
||||
|
||||
// Store state and code_verifier in session
|
||||
session.Values["oauth_state"] = token
|
||||
session.Values["code_verifier"] = codeVerifier
|
||||
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
|
||||
}
|
||||
|
||||
// Generate OAuth URL with PKCE parameters
|
||||
authURL := p.oauthConfig.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("prompt", promptParam),
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"))
|
||||
|
||||
logger.Logger.Debug("CreateAuthURL: generated auth URL with PKCE",
|
||||
"prompt", promptParam,
|
||||
"url_length", len(authURL))
|
||||
|
||||
return authURL
|
||||
}
|
||||
|
||||
// createAuthURLWithoutPKCE is a fallback method for OAuth without PKCE
|
||||
// Used for backward compatibility if PKCE generation fails
|
||||
func (p *OAuthProvider) createAuthURLWithoutPKCE(w http.ResponseWriter, r *http.Request, nextURL string) string {
|
||||
randPart := securecookie.GenerateRandomKey(20)
|
||||
token := base64.RawURLEncoding.EncodeToString(randPart)
|
||||
state := token + ":" + base64.RawURLEncoding.EncodeToString([]byte(nextURL))
|
||||
|
||||
promptParam := "select_account"
|
||||
isSilent := r.URL.Query().Get("silent") == "true"
|
||||
if isSilent {
|
||||
promptParam = "none"
|
||||
}
|
||||
|
||||
logger.Logger.Warn("Starting OAuth flow WITHOUT PKCE (fallback mode)",
|
||||
"next_url", nextURL,
|
||||
"silent", isSilent)
|
||||
|
||||
session, err := p.sessionSvc.GetSession(r)
|
||||
if err != nil {
|
||||
session, _ = p.sessionSvc.GetNewSession(r)
|
||||
}
|
||||
|
||||
session.Values["oauth_state"] = token
|
||||
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
logger.Logger.Error("CreateAuthURL: failed to save session", "error", err.Error())
|
||||
}
|
||||
|
||||
authURL := p.oauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", promptParam))
|
||||
|
||||
return authURL
|
||||
}
|
||||
|
||||
// VerifyState validates the OAuth state token for CSRF protection
|
||||
func (p *OAuthProvider) VerifyState(w http.ResponseWriter, r *http.Request, stateToken string) bool {
|
||||
session, _ := p.sessionSvc.GetSession(r)
|
||||
stored, _ := session.Values["oauth_state"].(string)
|
||||
|
||||
logger.Logger.Debug("VerifyState: validating OAuth state",
|
||||
"stored_length", len(stored),
|
||||
"token_length", len(stateToken),
|
||||
"stored_empty", stored == "",
|
||||
"token_empty", stateToken == "")
|
||||
|
||||
if stored == "" || stateToken == "" {
|
||||
logger.Logger.Warn("VerifyState: empty state tokens")
|
||||
return false
|
||||
}
|
||||
|
||||
if subtleConstantTimeCompare(stored, stateToken) {
|
||||
logger.Logger.Debug("VerifyState: state valid, clearing token")
|
||||
delete(session.Values, "oauth_state")
|
||||
_ = session.Save(r, w)
|
||||
return true
|
||||
}
|
||||
|
||||
logger.Logger.Warn("VerifyState: state mismatch")
|
||||
return false
|
||||
}
|
||||
|
||||
// subtleConstantTimeCompare performs a timing-safe string comparison
|
||||
func subtleConstantTimeCompare(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
var v byte
|
||||
for i := 0; i < len(a); i++ {
|
||||
v |= a[i] ^ b[i]
|
||||
}
|
||||
return v == 0
|
||||
}
|
||||
|
||||
// HandleCallback processes the OAuth callback and returns the authenticated user
|
||||
func (p *OAuthProvider) HandleCallback(ctx context.Context, w http.ResponseWriter, r *http.Request, code, state string) (*models.User, string, error) {
|
||||
parts := strings.SplitN(state, ":", 2)
|
||||
nextURL := "/"
|
||||
if len(parts) == 2 {
|
||||
if nb, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
|
||||
nextURL = string(nb)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Logger.Debug("Processing OAuth callback",
|
||||
"has_code", code != "",
|
||||
"next_url", nextURL)
|
||||
|
||||
// Retrieve code_verifier from session for PKCE
|
||||
session, _ := p.sessionSvc.GetSession(r)
|
||||
codeVerifier, hasPKCE := session.Values["code_verifier"].(string)
|
||||
|
||||
// Clean up code_verifier immediately after retrieval
|
||||
if hasPKCE {
|
||||
delete(session.Values, "code_verifier")
|
||||
_ = session.Save(r, w)
|
||||
}
|
||||
|
||||
// Exchange authorization code for token (with or without PKCE)
|
||||
var token *oauth2.Token
|
||||
var err error
|
||||
|
||||
if hasPKCE && codeVerifier != "" {
|
||||
logger.Logger.Info("OAuth token exchange with PKCE")
|
||||
token, err = p.oauthConfig.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", codeVerifier))
|
||||
} else {
|
||||
logger.Logger.Warn("OAuth token exchange without PKCE (legacy session or fallback)")
|
||||
token, err = p.oauthConfig.Exchange(ctx, code)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Logger.Error("OAuth token exchange failed",
|
||||
"error", err.Error(),
|
||||
"with_pkce", hasPKCE)
|
||||
return nil, nextURL, fmt.Errorf("oauth exchange failed: %w", err)
|
||||
}
|
||||
|
||||
logger.Logger.Info("OAuth token exchange successful", "with_pkce", hasPKCE)
|
||||
|
||||
client := p.oauthConfig.Client(ctx, token)
|
||||
resp, err := client.Get(p.userInfoURL)
|
||||
if err != nil || resp.StatusCode != 200 {
|
||||
statusCode := 0
|
||||
if resp != nil {
|
||||
statusCode = resp.StatusCode
|
||||
}
|
||||
logger.Logger.Error("User info request failed",
|
||||
"error", err,
|
||||
"status_code", statusCode)
|
||||
return nil, nextURL, fmt.Errorf("userinfo request failed: %w", err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
logger.Logger.Debug("User info retrieved successfully",
|
||||
"status_code", resp.StatusCode)
|
||||
|
||||
user, err := p.parseUserInfo(resp)
|
||||
if err != nil {
|
||||
logger.Logger.Error("Failed to parse user info",
|
||||
"error", err.Error())
|
||||
return nil, nextURL, fmt.Errorf("failed to parse user info: %w", err)
|
||||
}
|
||||
|
||||
if !p.IsAllowedDomain(user.Email) {
|
||||
logger.Logger.Warn("User domain not allowed")
|
||||
return nil, nextURL, models.ErrDomainNotAllowed
|
||||
}
|
||||
|
||||
logger.Logger.Info("OAuth callback successful")
|
||||
|
||||
// Store refresh token if available
|
||||
if token.RefreshToken != "" && p.sessionSvc.sessionRepo != nil {
|
||||
if err := p.sessionSvc.StoreRefreshToken(ctx, w, r, token, user); err != nil {
|
||||
// Log error but don't fail the authentication
|
||||
logger.Logger.Error("Failed to store refresh token (non-fatal)", "error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return user, nextURL, nil
|
||||
}
|
||||
|
||||
// IsAllowedDomain checks if the user's email domain is allowed
|
||||
func (p *OAuthProvider) IsAllowedDomain(email string) bool {
|
||||
if p.allowedDomain == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
domain := strings.ToLower(p.allowedDomain)
|
||||
// If domain already has @ prefix, don't add another one
|
||||
if !strings.HasPrefix(domain, "@") {
|
||||
domain = "@" + domain
|
||||
}
|
||||
|
||||
return strings.HasSuffix(strings.ToLower(email), domain)
|
||||
}
|
||||
|
||||
// parseUserInfo extracts user information from the OAuth provider response
|
||||
func (p *OAuthProvider) parseUserInfo(resp *http.Response) (*models.User, error) {
|
||||
var rawUser map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&rawUser); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
// Reduce PII in standard logs; log only keys at debug level
|
||||
if rawUser != nil {
|
||||
keys := make([]string, 0, len(rawUser))
|
||||
for k := range rawUser {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
logger.Logger.Debug("OAuth user info received", "keys", keys)
|
||||
}
|
||||
|
||||
user := &models.User{}
|
||||
|
||||
if sub, ok := rawUser["sub"].(string); ok {
|
||||
user.Sub = sub
|
||||
} else if id, ok := rawUser["id"]; ok {
|
||||
user.Sub = fmt.Sprintf("%v", id)
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing user ID in response")
|
||||
}
|
||||
|
||||
// Check for email in various provider-specific fields
|
||||
// - "email": Standard OIDC claim (Google, GitHub, GitLab, etc.)
|
||||
// - "mail": Microsoft Graph API
|
||||
// - "userPrincipalName": Microsoft fallback (UPN format)
|
||||
if email, ok := rawUser["email"].(string); ok && email != "" {
|
||||
user.Email = email
|
||||
} else if mail, ok := rawUser["mail"].(string); ok && mail != "" {
|
||||
user.Email = mail
|
||||
} else if upn, ok := rawUser["userPrincipalName"].(string); ok && upn != "" {
|
||||
user.Email = upn
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing email in user info response (checked: email, mail, userPrincipalName)")
|
||||
}
|
||||
|
||||
// Extract display name from various provider-specific fields
|
||||
var name string
|
||||
if fullName, ok := rawUser["name"].(string); ok && fullName != "" {
|
||||
name = fullName
|
||||
} else if firstName, ok := rawUser["given_name"].(string); ok {
|
||||
if lastName, ok := rawUser["family_name"].(string); ok {
|
||||
name = firstName + " " + lastName
|
||||
} else {
|
||||
name = firstName
|
||||
}
|
||||
} else if displayName, ok := rawUser["displayName"].(string); ok && displayName != "" {
|
||||
name = displayName
|
||||
} else if cn, ok := rawUser["cn"].(string); ok && cn != "" {
|
||||
name = cn
|
||||
} else if displayNameSnake, ok := rawUser["display_name"].(string); ok && displayNameSnake != "" {
|
||||
name = displayNameSnake
|
||||
} else if preferredName, ok := rawUser["preferred_username"].(string); ok && preferredName != "" {
|
||||
name = preferredName
|
||||
}
|
||||
|
||||
user.Name = name
|
||||
|
||||
logger.Logger.Debug("Extracted OAuth user identifiers",
|
||||
"has_sub", user.Sub != "",
|
||||
"has_email", user.Email != "",
|
||||
"has_name", user.Name != "")
|
||||
|
||||
if !user.IsValid() {
|
||||
return nil, fmt.Errorf("invalid user data extracted")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -1,422 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
func TestOAuthProvider_IsAllowedDomain(t *testing.T) {
|
||||
sessionSvc := NewSessionService(SessionServiceConfig{
|
||||
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
|
||||
SecureCookies: false,
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedDomain string
|
||||
email string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "allowed domain match",
|
||||
allowedDomain: "@example.com",
|
||||
email: "user@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "allowed domain mismatch",
|
||||
allowedDomain: "@example.com",
|
||||
email: "user@other.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no restriction",
|
||||
allowedDomain: "",
|
||||
email: "user@any.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "case insensitive match",
|
||||
allowedDomain: "@Example.Com",
|
||||
email: "user@EXAMPLE.com",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider := &OAuthProvider{
|
||||
allowedDomain: tt.allowedDomain,
|
||||
sessionSvc: sessionSvc,
|
||||
}
|
||||
|
||||
result := provider.IsAllowedDomain(tt.email)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsAllowedDomain(%v) = %v, expected %v", tt.email, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthProvider_CreateAuthURL(t *testing.T) {
|
||||
sessionSvc := NewSessionService(SessionServiceConfig{
|
||||
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
|
||||
SecureCookies: false,
|
||||
})
|
||||
|
||||
config := OAuthProviderConfig{
|
||||
BaseURL: "https://ackify.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
AuthURL: "https://provider.com/auth",
|
||||
TokenURL: "https://provider.com/token",
|
||||
UserInfoURL: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid", "email"},
|
||||
SessionSvc: sessionSvc,
|
||||
}
|
||||
|
||||
provider := NewOAuthProvider(config)
|
||||
|
||||
t.Run("creates auth URL with PKCE", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
authURL := provider.CreateAuthURL(rec, req, "/dashboard")
|
||||
|
||||
if authURL == "" {
|
||||
t.Fatal("CreateAuthURL() returned empty string")
|
||||
}
|
||||
|
||||
// Parse the URL and check parameters
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
|
||||
// Check for required OAuth parameters
|
||||
if query.Get("client_id") == "" {
|
||||
t.Error("client_id parameter missing")
|
||||
}
|
||||
if query.Get("redirect_uri") == "" {
|
||||
t.Error("redirect_uri parameter missing")
|
||||
}
|
||||
if query.Get("response_type") != "code" {
|
||||
t.Error("response_type should be 'code'")
|
||||
}
|
||||
if query.Get("state") == "" {
|
||||
t.Error("state parameter missing")
|
||||
}
|
||||
|
||||
// Check for PKCE parameters
|
||||
if query.Get("code_challenge") == "" {
|
||||
t.Error("code_challenge parameter missing (PKCE should be enabled)")
|
||||
}
|
||||
if query.Get("code_challenge_method") != "S256" {
|
||||
t.Error("code_challenge_method should be 'S256'")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates auth URL with silent flag", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/?silent=true", nil)
|
||||
|
||||
authURL := provider.CreateAuthURL(rec, req, "/dashboard")
|
||||
|
||||
parsedURL, _ := url.Parse(authURL)
|
||||
query := parsedURL.Query()
|
||||
|
||||
if query.Get("prompt") != "none" {
|
||||
t.Errorf("prompt parameter = %v, expected 'none' for silent auth", query.Get("prompt"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOAuthProvider_VerifyState(t *testing.T) {
|
||||
sessionSvc := NewSessionService(SessionServiceConfig{
|
||||
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
|
||||
SecureCookies: false,
|
||||
})
|
||||
|
||||
config := OAuthProviderConfig{
|
||||
BaseURL: "https://ackify.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
AuthURL: "https://provider.com/auth",
|
||||
TokenURL: "https://provider.com/token",
|
||||
UserInfoURL: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid"},
|
||||
SessionSvc: sessionSvc,
|
||||
}
|
||||
|
||||
provider := NewOAuthProvider(config)
|
||||
|
||||
t.Run("valid state", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Create auth URL to generate state
|
||||
_ = provider.CreateAuthURL(rec, req, "/")
|
||||
|
||||
// Get session with state
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range rec.Result().Cookies() {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
session, _ := sessionSvc.GetSession(req2)
|
||||
storedState, ok := session.Values["oauth_state"].(string)
|
||||
if !ok {
|
||||
t.Fatal("State not stored in session")
|
||||
}
|
||||
|
||||
// Verify state
|
||||
rec2 := httptest.NewRecorder()
|
||||
valid := provider.VerifyState(rec2, req2, storedState)
|
||||
if !valid {
|
||||
t.Error("VerifyState() should return true for valid state")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid state", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
valid := provider.VerifyState(rec, req, "invalid-state-token")
|
||||
if valid {
|
||||
t.Error("VerifyState() should return false for invalid state")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty state", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
valid := provider.VerifyState(rec, req, "")
|
||||
if valid {
|
||||
t.Error("VerifyState() should return false for empty state")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOAuthProvider_parseUserInfo(t *testing.T) {
|
||||
sessionSvc := NewSessionService(SessionServiceConfig{
|
||||
CookieSecret: []byte("32-byte-secret-for-secure-cookies"),
|
||||
SecureCookies: false,
|
||||
})
|
||||
|
||||
provider := &OAuthProvider{
|
||||
sessionSvc: sessionSvc,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseObj map[string]interface{}
|
||||
wantErr bool
|
||||
checkUser func(*testing.T, *models.User)
|
||||
}{
|
||||
{
|
||||
name: "complete user info with sub",
|
||||
responseObj: map[string]interface{}{
|
||||
"sub": "12345",
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
wantErr: false,
|
||||
checkUser: func(t *testing.T, user *models.User) {
|
||||
if user.Sub != "12345" {
|
||||
t.Errorf("Sub = %v, expected 12345", user.Sub)
|
||||
}
|
||||
if user.Email != "user@example.com" {
|
||||
t.Errorf("Email = %v, expected user@example.com", user.Email)
|
||||
}
|
||||
if user.Name != "Test User" {
|
||||
t.Errorf("Name = %v, expected Test User", user.Name)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user info with id instead of sub",
|
||||
responseObj: map[string]interface{}{
|
||||
"id": 67890,
|
||||
"email": "user@example.com",
|
||||
},
|
||||
wantErr: false,
|
||||
checkUser: func(t *testing.T, user *models.User) {
|
||||
if user.Sub != "67890" {
|
||||
t.Errorf("Sub = %v, expected 67890", user.Sub)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user info with given_name and family_name",
|
||||
responseObj: map[string]interface{}{
|
||||
"sub": "12345",
|
||||
"email": "user@example.com",
|
||||
"given_name": "John",
|
||||
"family_name": "Doe",
|
||||
},
|
||||
wantErr: false,
|
||||
checkUser: func(t *testing.T, user *models.User) {
|
||||
if user.Name != "John Doe" {
|
||||
t.Errorf("Name = %v, expected 'John Doe'", user.Name)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Microsoft Graph API - mail field",
|
||||
responseObj: map[string]interface{}{
|
||||
"id": "microsoft-id-12345",
|
||||
"mail": "user@company.com",
|
||||
"displayName": "Microsoft User",
|
||||
"userPrincipalName": "user@company.onmicrosoft.com",
|
||||
},
|
||||
wantErr: false,
|
||||
checkUser: func(t *testing.T, user *models.User) {
|
||||
if user.Sub != "microsoft-id-12345" {
|
||||
t.Errorf("Sub = %v, expected microsoft-id-12345", user.Sub)
|
||||
}
|
||||
if user.Email != "user@company.com" {
|
||||
t.Errorf("Email = %v, expected user@company.com (from mail field)", user.Email)
|
||||
}
|
||||
if user.Name != "Microsoft User" {
|
||||
t.Errorf("Name = %v, expected Microsoft User (from displayName)", user.Name)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Microsoft Graph API - userPrincipalName fallback",
|
||||
responseObj: map[string]interface{}{
|
||||
"id": "microsoft-id-67890",
|
||||
"displayName": "UPN User",
|
||||
"userPrincipalName": "user@company.onmicrosoft.com",
|
||||
},
|
||||
wantErr: false,
|
||||
checkUser: func(t *testing.T, user *models.User) {
|
||||
if user.Email != "user@company.onmicrosoft.com" {
|
||||
t.Errorf("Email = %v, expected user@company.onmicrosoft.com (from userPrincipalName)", user.Email)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "email field takes priority over mail",
|
||||
responseObj: map[string]interface{}{
|
||||
"sub": "12345",
|
||||
"email": "primary@example.com",
|
||||
"mail": "secondary@example.com",
|
||||
},
|
||||
wantErr: false,
|
||||
checkUser: func(t *testing.T, user *models.User) {
|
||||
if user.Email != "primary@example.com" {
|
||||
t.Errorf("Email = %v, expected primary@example.com (email should take priority)", user.Email)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing email",
|
||||
responseObj: map[string]interface{}{
|
||||
"sub": "12345",
|
||||
"name": "Test User",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing sub and id",
|
||||
responseObj: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock HTTP response
|
||||
jsonData, _ := json.Marshal(tt.responseObj)
|
||||
resp := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewReader(jsonData)),
|
||||
}
|
||||
|
||||
user, err := provider.parseUserInfo(resp)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("parseUserInfo() should return error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("parseUserInfo() unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
t.Fatal("parseUserInfo() returned nil user")
|
||||
}
|
||||
|
||||
if tt.checkUser != nil {
|
||||
tt.checkUser(t, user)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubtleConstantTimeCompare(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a string
|
||||
b string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "equal strings",
|
||||
a: "secret123",
|
||||
b: "secret123",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different strings",
|
||||
a: "secret123",
|
||||
b: "secret456",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different lengths",
|
||||
a: "short",
|
||||
b: "longer string",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty strings",
|
||||
a: "",
|
||||
b: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "one empty",
|
||||
a: "string",
|
||||
b: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := subtleConstantTimeCompare(tt.a, tt.b)
|
||||
if result != tt.expected {
|
||||
t.Errorf("subtleConstantTimeCompare(%v, %v) = %v, expected %v", tt.a, tt.b, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
|
||||
)
|
||||
|
||||
func SendEmail(ctx context.Context, sender Sender, template string, to []string, locale string, subject string, data map[string]any) error {
|
||||
msg := Message{
|
||||
To: to,
|
||||
Subject: subject,
|
||||
Template: template,
|
||||
Locale: locale,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
return sender.Send(ctx, msg)
|
||||
}
|
||||
|
||||
func SendSignatureReminderEmail(ctx context.Context, sender Sender, i18nService *i18n.I18n, to []string, locale, docID, docURL, signURL, recipientName string) error {
|
||||
data := map[string]any{
|
||||
"DocID": docID,
|
||||
"DocURL": docURL,
|
||||
"SignURL": signURL,
|
||||
"RecipientName": recipientName,
|
||||
}
|
||||
|
||||
// Get translated subject using i18n
|
||||
subject := "Document Reading Confirmation Reminder" // Fallback
|
||||
if i18nService != nil {
|
||||
subject = i18nService.T(locale, "email.reminder.subject")
|
||||
}
|
||||
|
||||
return SendEmail(ctx, sender, "signature_reminder", to, locale, subject, data)
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Mock sender for testing
|
||||
type mockSender struct {
|
||||
sendFunc func(ctx context.Context, msg Message) error
|
||||
lastMsg *Message
|
||||
}
|
||||
|
||||
func (m *mockSender) Send(ctx context.Context, msg Message) error {
|
||||
m.lastMsg = &msg
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(ctx, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSendEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
to []string
|
||||
locale string
|
||||
subject string
|
||||
data map[string]any
|
||||
sendError error
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Send email successfully",
|
||||
template: "test_template",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
subject: "Test Subject",
|
||||
data: map[string]any{
|
||||
"name": "John",
|
||||
},
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send email with multiple recipients",
|
||||
template: "welcome",
|
||||
to: []string{"user1@example.com", "user2@example.com"},
|
||||
locale: "fr",
|
||||
subject: "Bienvenue",
|
||||
data: map[string]any{
|
||||
"company": "Acme Corp",
|
||||
},
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send email with error",
|
||||
template: "error_template",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
subject: "Error Test",
|
||||
data: nil,
|
||||
sendError: errors.New("SMTP connection failed"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Send email with empty data",
|
||||
template: "simple_template",
|
||||
to: []string{"test@example.com"},
|
||||
locale: "en",
|
||||
subject: "Simple Email",
|
||||
data: map[string]any{},
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mock := &mockSender{
|
||||
sendFunc: func(ctx context.Context, msg Message) error {
|
||||
return tt.sendError
|
||||
},
|
||||
}
|
||||
|
||||
err := SendEmail(ctx, mock, tt.template, tt.to, tt.locale, tt.subject, tt.data)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify message was constructed correctly
|
||||
if mock.lastMsg == nil {
|
||||
t.Fatal("Expected message to be captured")
|
||||
}
|
||||
|
||||
if mock.lastMsg.Template != tt.template {
|
||||
t.Errorf("Expected template '%s', got '%s'", tt.template, mock.lastMsg.Template)
|
||||
}
|
||||
|
||||
if mock.lastMsg.Subject != tt.subject {
|
||||
t.Errorf("Expected subject '%s', got '%s'", tt.subject, mock.lastMsg.Subject)
|
||||
}
|
||||
|
||||
if mock.lastMsg.Locale != tt.locale {
|
||||
t.Errorf("Expected locale '%s', got '%s'", tt.locale, mock.lastMsg.Locale)
|
||||
}
|
||||
|
||||
if len(mock.lastMsg.To) != len(tt.to) {
|
||||
t.Errorf("Expected %d recipients, got %d", len(tt.to), len(mock.lastMsg.To))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendSignatureReminderEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Note: When i18n is nil, all tests use the fallback subject
|
||||
fallbackSubject := "Document Reading Confirmation Reminder"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
to []string
|
||||
locale string
|
||||
docID string
|
||||
docURL string
|
||||
signURL string
|
||||
recipientName string
|
||||
expectedSubject string
|
||||
sendError error
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Send reminder in English",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
docID: "doc123",
|
||||
docURL: "https://example.com/doc.pdf",
|
||||
signURL: "https://example.com/?doc=doc123",
|
||||
recipientName: "John Doe",
|
||||
expectedSubject: fallbackSubject,
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send reminder in French (fallback without i18n)",
|
||||
to: []string{"utilisateur@exemple.fr"},
|
||||
locale: "fr",
|
||||
docID: "doc456",
|
||||
docURL: "https://exemple.fr/document.pdf",
|
||||
signURL: "https://exemple.fr/?doc=doc456",
|
||||
recipientName: "Marie Dupont",
|
||||
expectedSubject: fallbackSubject,
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send reminder with unknown locale defaults to fallback",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "es",
|
||||
docID: "doc789",
|
||||
docURL: "https://example.com/doc.pdf",
|
||||
signURL: "https://example.com/?doc=doc789",
|
||||
recipientName: "Juan Garcia",
|
||||
expectedSubject: fallbackSubject,
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Send reminder with error",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
docID: "doc999",
|
||||
docURL: "https://example.com/doc.pdf",
|
||||
signURL: "https://example.com/?doc=doc999",
|
||||
recipientName: "Test User",
|
||||
expectedSubject: fallbackSubject,
|
||||
sendError: errors.New("email server unavailable"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Send reminder with empty recipient name",
|
||||
to: []string{"user@example.com"},
|
||||
locale: "en",
|
||||
docID: "doc000",
|
||||
docURL: "https://example.com/doc.pdf",
|
||||
signURL: "https://example.com/?doc=doc000",
|
||||
recipientName: "",
|
||||
expectedSubject: fallbackSubject,
|
||||
sendError: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
mock := &mockSender{
|
||||
sendFunc: func(ctx context.Context, msg Message) error {
|
||||
return tt.sendError
|
||||
},
|
||||
}
|
||||
|
||||
err := SendSignatureReminderEmail(ctx, mock, nil, tt.to, tt.locale, tt.docID, tt.docURL, tt.signURL, tt.recipientName)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify message construction
|
||||
if mock.lastMsg == nil {
|
||||
t.Fatal("Expected message to be captured")
|
||||
}
|
||||
|
||||
if mock.lastMsg.Template != "signature_reminder" {
|
||||
t.Errorf("Expected template 'signature_reminder', got '%s'", mock.lastMsg.Template)
|
||||
}
|
||||
|
||||
if mock.lastMsg.Subject != tt.expectedSubject {
|
||||
t.Errorf("Expected subject '%s', got '%s'", tt.expectedSubject, mock.lastMsg.Subject)
|
||||
}
|
||||
|
||||
if mock.lastMsg.Locale != tt.locale {
|
||||
t.Errorf("Expected locale '%s', got '%s'", tt.locale, mock.lastMsg.Locale)
|
||||
}
|
||||
|
||||
// Verify data fields
|
||||
if mock.lastMsg.Data == nil {
|
||||
t.Fatal("Expected data to be present")
|
||||
}
|
||||
|
||||
if docID, ok := mock.lastMsg.Data["DocID"].(string); !ok || docID != tt.docID {
|
||||
t.Errorf("Expected DocID '%s', got '%v'", tt.docID, mock.lastMsg.Data["DocID"])
|
||||
}
|
||||
|
||||
if docURL, ok := mock.lastMsg.Data["DocURL"].(string); !ok || docURL != tt.docURL {
|
||||
t.Errorf("Expected DocURL '%s', got '%v'", tt.docURL, mock.lastMsg.Data["DocURL"])
|
||||
}
|
||||
|
||||
if signURL, ok := mock.lastMsg.Data["SignURL"].(string); !ok || signURL != tt.signURL {
|
||||
t.Errorf("Expected SignURL '%s', got '%v'", tt.signURL, mock.lastMsg.Data["SignURL"])
|
||||
}
|
||||
|
||||
if recipientName, ok := mock.lastMsg.Data["RecipientName"].(string); !ok || recipientName != tt.recipientName {
|
||||
t.Errorf("Expected RecipientName '%s', got '%v'", tt.recipientName, mock.lastMsg.Data["RecipientName"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/internal/infrastructure/i18n"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
// QueueSender implements the Sender interface by queuing emails instead of sending them directly
|
||||
type QueueSender struct {
|
||||
queueRepo QueueRepository
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewQueueSender creates a new queue-based email sender
|
||||
func NewQueueSender(queueRepo QueueRepository, baseURL string) *QueueSender {
|
||||
return &QueueSender{
|
||||
queueRepo: queueRepo,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// Send queues an email for asynchronous processing
|
||||
func (q *QueueSender) Send(ctx context.Context, msg Message) error {
|
||||
// Convert message data to proper format
|
||||
data := msg.Data
|
||||
if data == nil {
|
||||
data = make(map[string]interface{})
|
||||
}
|
||||
|
||||
input := models.EmailQueueInput{
|
||||
ToAddresses: msg.To,
|
||||
CcAddresses: msg.Cc,
|
||||
BccAddresses: msg.Bcc,
|
||||
Subject: msg.Subject,
|
||||
Template: msg.Template,
|
||||
Locale: msg.Locale,
|
||||
Data: data,
|
||||
Headers: msg.Headers,
|
||||
Priority: models.EmailPriorityNormal,
|
||||
}
|
||||
|
||||
// Set priority based on template type
|
||||
switch msg.Template {
|
||||
case "signature_reminder":
|
||||
input.Priority = models.EmailPriorityHigh
|
||||
case "welcome", "notification":
|
||||
input.Priority = models.EmailPriorityNormal
|
||||
default:
|
||||
input.Priority = models.EmailPriorityNormal
|
||||
}
|
||||
|
||||
// Queue the email
|
||||
_, err := q.queueRepo.Enqueue(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to queue email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueSignatureReminderEmail queues a signature reminder email
|
||||
func QueueSignatureReminderEmail(
|
||||
ctx context.Context,
|
||||
queueRepo QueueRepository,
|
||||
i18nService *i18n.I18n,
|
||||
recipients []string,
|
||||
locale string,
|
||||
docID string,
|
||||
docURL string,
|
||||
signURL string,
|
||||
recipientName string,
|
||||
sentBy string,
|
||||
) error {
|
||||
data := map[string]interface{}{
|
||||
"doc_id": docID,
|
||||
"doc_url": docURL,
|
||||
"sign_url": signURL,
|
||||
"recipient_name": recipientName,
|
||||
"locale": locale,
|
||||
}
|
||||
|
||||
// Get translated subject using i18n
|
||||
subject := "Document Reading Confirmation Reminder" // Fallback
|
||||
if i18nService != nil {
|
||||
subject = i18nService.T(locale, "email.reminder.subject")
|
||||
}
|
||||
|
||||
// Create a reference for tracking
|
||||
refType := "signature_reminder"
|
||||
|
||||
input := models.EmailQueueInput{
|
||||
ToAddresses: recipients,
|
||||
Subject: subject,
|
||||
Template: "signature_reminder",
|
||||
Locale: locale,
|
||||
Data: data,
|
||||
Priority: models.EmailPriorityHigh,
|
||||
ReferenceType: &refType,
|
||||
ReferenceID: &docID,
|
||||
CreatedBy: &sentBy,
|
||||
}
|
||||
|
||||
_, err := queueRepo.Enqueue(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to queue signature reminder: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -120,8 +120,3 @@ func WritePaginatedJSON(w http.ResponseWriter, data interface{}, page, limit, to
|
||||
|
||||
WriteJSONWithMeta(w, http.StatusOK, data, meta)
|
||||
}
|
||||
|
||||
// WriteNoContent writes a 204 No Content response
|
||||
func WriteNoContent(w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@@ -222,19 +222,3 @@ func TestWritePaginatedJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteNoContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteNoContent(w)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
if w.Body.Len() != 0 {
|
||||
t.Errorf("Expected empty body, got %d bytes", w.Body.Len())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
package shared
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetClientIP extracts the real client IP address from the request
|
||||
// It checks X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
|
||||
func GetClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (proxy/load balancer)
|
||||
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
||||
// X-Forwarded-For can contain multiple IPs, take the first one
|
||||
ips := strings.Split(forwardedFor, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
|
||||
return strings.TrimSpace(realIP)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
ip := r.RemoteAddr
|
||||
// Remove port if present
|
||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
@@ -25,6 +25,7 @@ type UserDTO struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
}
|
||||
|
||||
@@ -40,6 +41,7 @@ func (h *Handler) HandleGetCurrentUser(w http.ResponseWriter, r *http.Request) {
|
||||
ID: user.Sub,
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
Picture: user.Picture,
|
||||
IsAdmin: h.authorizer.IsAdmin(r.Context(), user.Email),
|
||||
}
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
// HandleError handles different types of errors and returns appropriate HTTP responses
|
||||
func HandleError(w http.ResponseWriter, err error) {
|
||||
switch {
|
||||
case errors.Is(err, models.ErrUnauthorized):
|
||||
logger.Logger.Warn("Unauthorized access attempt", "error", err.Error())
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
case errors.Is(err, models.ErrSignatureNotFound):
|
||||
logger.Logger.Debug("Signature not found", "error", err.Error())
|
||||
http.Error(w, "Signature not found", http.StatusNotFound)
|
||||
case errors.Is(err, models.ErrSignatureAlreadyExists):
|
||||
logger.Logger.Debug("Duplicate signature attempt", "error", err.Error())
|
||||
http.Error(w, "Signature already exists", http.StatusConflict)
|
||||
case errors.Is(err, models.ErrInvalidUser):
|
||||
logger.Logger.Warn("Invalid user data", "error", err.Error())
|
||||
http.Error(w, "Invalid user", http.StatusBadRequest)
|
||||
case errors.Is(err, models.ErrInvalidDocument):
|
||||
logger.Logger.Warn("Invalid document ID", "error", err.Error())
|
||||
http.Error(w, "Invalid document ID", http.StatusBadRequest)
|
||||
case errors.Is(err, models.ErrDomainNotAllowed):
|
||||
logger.Logger.Warn("Domain not allowed", "error", err.Error())
|
||||
http.Error(w, "Domain not allowed", http.StatusForbidden)
|
||||
case errors.Is(err, models.ErrDatabaseConnection):
|
||||
logger.Logger.Error("Database connection error", "error", err.Error())
|
||||
http.Error(w, "Database error", http.StatusInternalServerError)
|
||||
default:
|
||||
logger.Logger.Error("Unhandled error", "error", err.Error())
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@@ -4,8 +4,6 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -208,36 +206,6 @@ func TestHandleOEmbed_MissingDocParam(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOEmbedURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
baseURL string
|
||||
expected bool
|
||||
}{
|
||||
{"valid same host", "https://example.com/?doc=123", "https://example.com", true},
|
||||
{"valid with port", "https://example.com:443/?doc=123", "https://example.com", true},
|
||||
{"different host", "https://other.com/?doc=123", "https://example.com", false},
|
||||
{"localhost variations", "http://localhost:8080/?doc=123", "http://127.0.0.1:8080", true},
|
||||
{"localhost to 127.0.0.1", "http://127.0.0.1/?doc=123", "http://localhost", true},
|
||||
{"invalid URL", ":::invalid", "https://example.com", false},
|
||||
{"invalid base URL", "https://example.com/?doc=123", ":::invalid", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := ValidateOEmbedURL(tt.urlStr, tt.baseURL)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BENCHMARKS
|
||||
// ============================================================================
|
||||
@@ -254,16 +222,6 @@ func BenchmarkHandleOEmbed(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateOEmbedURL(b *testing.B) {
|
||||
urlStr := "https://example.com/?doc=test123"
|
||||
baseURL := "https://example.com"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ValidateOEmbedURL(urlStr, baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - Middleware: SecureHeaders
|
||||
// ============================================================================
|
||||
@@ -423,126 +381,6 @@ func TestRequestLogger_DifferentMethods(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - HandleError
|
||||
// ============================================================================
|
||||
|
||||
func TestHandleError_Unauthorized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrUnauthorized)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Unauthorized")
|
||||
}
|
||||
|
||||
func TestHandleError_SignatureNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrSignatureNotFound)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Signature not found")
|
||||
}
|
||||
|
||||
func TestHandleError_SignatureAlreadyExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrSignatureAlreadyExists)
|
||||
|
||||
assert.Equal(t, http.StatusConflict, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Signature already exists")
|
||||
}
|
||||
|
||||
func TestHandleError_InvalidUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrInvalidUser)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid user")
|
||||
}
|
||||
|
||||
func TestHandleError_InvalidDocument(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrInvalidDocument)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid document ID")
|
||||
}
|
||||
|
||||
func TestHandleError_DomainNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrDomainNotAllowed)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Domain not allowed")
|
||||
}
|
||||
|
||||
func TestHandleError_DatabaseConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, models.ErrDatabaseConnection)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Database error")
|
||||
}
|
||||
|
||||
func TestHandleError_UnknownError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, errors.New("unknown error"))
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Internal server error")
|
||||
}
|
||||
|
||||
func TestHandleError_WrappedErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedStatus int
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
"wrapped unauthorized",
|
||||
fmt.Errorf("auth failed: %w", models.ErrUnauthorized),
|
||||
http.StatusUnauthorized,
|
||||
"Unauthorized",
|
||||
},
|
||||
{
|
||||
"wrapped domain error",
|
||||
fmt.Errorf("validation failed: %w", models.ErrDomainNotAllowed),
|
||||
http.StatusForbidden,
|
||||
"Domain not allowed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, tt.err)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), tt.expectedMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TESTS - statusRecorder
|
||||
// ============================================================================
|
||||
@@ -616,13 +454,3 @@ func BenchmarkRequestLogger(b *testing.B) {
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandleError(b *testing.B) {
|
||||
err := models.ErrUnauthorized
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
HandleError(rec, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,47 +7,36 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/models"
|
||||
)
|
||||
|
||||
type userService interface {
|
||||
GetUser(r *http.Request) (*models.User, error)
|
||||
}
|
||||
|
||||
type AuthMiddleware struct {
|
||||
userService userService
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// SecureHeaders Enforce baseline security headers (CSP, XFO, etc.) to mitigate clickjacking, MIME sniffing, and unsafe embedding by default.
|
||||
// SecureHeaders enforces baseline security headers (CSP, XFO, etc.)
|
||||
func SecureHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("Referrer-Policy", "no-referrer")
|
||||
|
||||
// Check if this is an embed route - allow iframe embedding
|
||||
isEmbedRoute := strings.HasPrefix(r.URL.Path, "/embed/") || strings.HasPrefix(r.URL.Path, "/embed")
|
||||
|
||||
// OAuth provider avatar domains
|
||||
imgSrc := "img-src 'self' data: https://cdn.simpleicons.org https://*.googleusercontent.com https://avatars.githubusercontent.com https://secure.gravatar.com https://gitlab.com"
|
||||
|
||||
if isEmbedRoute {
|
||||
// Allow embedding from any origin for embed pages
|
||||
// Do not set X-Frame-Options to allow iframe embedding
|
||||
w.Header().Set("Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https://fonts.googleapis.com; "+
|
||||
"font-src 'self' https://fonts.gstatic.com; "+
|
||||
"script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
|
||||
"img-src 'self' data: https://cdn.simpleicons.org; "+
|
||||
imgSrc+"; "+
|
||||
"connect-src 'self'; "+
|
||||
"frame-ancestors *") // Allow embedding from any origin
|
||||
"frame-ancestors *")
|
||||
} else {
|
||||
// Strict headers for non-embed routes
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https://fonts.googleapis.com; "+
|
||||
"font-src 'self' https://fonts.gstatic.com; "+
|
||||
"script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
|
||||
"img-src 'self' data: https://cdn.simpleicons.org; "+
|
||||
imgSrc+"; "+
|
||||
"connect-src 'self'; "+
|
||||
"frame-ancestors 'self'")
|
||||
}
|
||||
@@ -56,14 +45,12 @@ func SecureHeaders(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// RequestLogger Minimal structured logging without PII; record latency and status for ops visibility.
|
||||
func RequestLogger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sr := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||
start := time.Now()
|
||||
next.ServeHTTP(sr, r)
|
||||
duration := time.Since(start)
|
||||
// Minimal structured log to avoid PII
|
||||
logger.Logger.Info("http_request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
@@ -81,8 +68,3 @@ func (sr *statusRecorder) WriteHeader(code int) {
|
||||
sr.status = code
|
||||
sr.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/btouchard/ackify-ce/backend/pkg/logger"
|
||||
)
|
||||
@@ -91,29 +90,3 @@ func HandleOEmbed(baseURL string) http.HandlerFunc {
|
||||
"remote_addr", r.RemoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateOEmbedURL checks if the provided URL is a valid Ackify document URL
|
||||
func ValidateOEmbedURL(urlStr string, baseURL string) bool {
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
baseURLParsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Normalize hosts for comparison (remove ports if present)
|
||||
urlHost := strings.Split(parsedURL.Host, ":")[0]
|
||||
baseHost := strings.Split(baseURLParsed.Host, ":")[0]
|
||||
|
||||
// Allow localhost variations
|
||||
if urlHost == "localhost" || urlHost == "127.0.0.1" {
|
||||
if baseHost == "localhost" || baseHost == "127.0.0.1" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return urlHost == baseHost
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@ import "strings"
|
||||
// This is the canonical user representation used by auth providers, domain models,
|
||||
// and API handlers.
|
||||
type User struct {
|
||||
Sub string `json:"sub"` // Unique identifier (OAuth sub claim or email for MagicLink)
|
||||
Email string `json:"email"` // User's email address
|
||||
Name string `json:"name"` // Display name (optional)
|
||||
Sub string `json:"sub"` // Unique identifier (OAuth sub claim or email for MagicLink)
|
||||
Email string `json:"email"` // User's email address
|
||||
Name string `json:"name"` // Display name (optional)
|
||||
Picture string `json:"picture,omitempty"` // Avatar URL from OAuth provider (optional)
|
||||
}
|
||||
|
||||
// IsValid returns true if the user has required fields populated.
|
||||
|
||||
@@ -419,6 +419,15 @@ func (p *Provider) parseUserInfo(resp *http.Response) (*types.User, error) {
|
||||
user.Name = preferredName
|
||||
}
|
||||
|
||||
// Extract avatar URL (picture for Google/OIDC, avatar_url for GitHub/GitLab)
|
||||
if picture, ok := rawUser["picture"].(string); ok && picture != "" {
|
||||
user.Picture = picture
|
||||
} else if avatarURL, ok := rawUser["avatar_url"].(string); ok && avatarURL != "" {
|
||||
user.Picture = avatarURL
|
||||
} else if photo, ok := rawUser["photo"].(string); ok && photo != "" {
|
||||
user.Picture = photo
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ import (
|
||||
sdk "github.com/btouchard/shm/sdk/golang"
|
||||
)
|
||||
|
||||
// Server represents the HTTP server with all its dependencies.
|
||||
type Server struct {
|
||||
httpServer *http.Server
|
||||
db *sql.DB
|
||||
@@ -90,7 +89,6 @@ type ServerBuilder struct {
|
||||
configService *services.ConfigService
|
||||
}
|
||||
|
||||
// NewServerBuilder creates a new server builder with the required configuration.
|
||||
func NewServerBuilder(cfg *config.Config, frontend embed.FS, version string) *ServerBuilder {
|
||||
return &ServerBuilder{
|
||||
cfg: cfg,
|
||||
@@ -205,7 +203,6 @@ func (b *ServerBuilder) Build(ctx context.Context) (*Server, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validateProviders checks that required providers are set.
|
||||
func (b *ServerBuilder) validateProviders() error {
|
||||
if b.db == nil {
|
||||
return errors.New("database is required: use WithDB()")
|
||||
@@ -238,7 +235,6 @@ func (b *ServerBuilder) setDefaultProviders() {
|
||||
}
|
||||
}
|
||||
|
||||
// initializeInfrastructure initializes signer, i18n, email and storage.
|
||||
func (b *ServerBuilder) initializeInfrastructure() error {
|
||||
var err error
|
||||
|
||||
@@ -297,7 +293,6 @@ type repositories struct {
|
||||
magicLink services.MagicLinkRepository
|
||||
}
|
||||
|
||||
// createRepositories creates all repository instances.
|
||||
func (b *ServerBuilder) createRepositories() *repositories {
|
||||
return &repositories{
|
||||
signature: database.NewSignatureRepository(b.db, b.tenantProvider),
|
||||
@@ -336,7 +331,6 @@ func (b *ServerBuilder) initializeTelemetry(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initializeWebhookSystem initializes webhook publisher and worker.
|
||||
func (b *ServerBuilder) initializeWebhookSystem(ctx context.Context, repos *repositories) (*services.WebhookPublisher, *webhook.Worker, error) {
|
||||
whPublisher := services.NewWebhookPublisher(repos.webhook, repos.webhookDelivery)
|
||||
whCfg := webhook.DefaultWorkerConfig()
|
||||
@@ -349,7 +343,6 @@ func (b *ServerBuilder) initializeWebhookSystem(ctx context.Context, repos *repo
|
||||
return whPublisher, whWorker, nil
|
||||
}
|
||||
|
||||
// initializeEmailWorker initializes email worker for async processing.
|
||||
// emailRenderer is expected to be injected from main.go via WithEmailRenderer().
|
||||
func (b *ServerBuilder) initializeEmailWorker(ctx context.Context, repos *repositories, whPublisher *services.WebhookPublisher) (*email.Worker, error) {
|
||||
if b.emailSender == nil || b.cfg.Mail.Host == "" || b.emailRenderer == nil {
|
||||
@@ -370,7 +363,6 @@ func (b *ServerBuilder) initializeEmailWorker(ctx context.Context, repos *reposi
|
||||
return emailWorker, nil
|
||||
}
|
||||
|
||||
// initializeCoreServices initializes signature, document, admin, and webhook services.
|
||||
func (b *ServerBuilder) initializeCoreServices(repos *repositories) {
|
||||
b.signatureService = services.NewSignatureService(repos.signature, repos.document, b.signer)
|
||||
b.signatureService.SetChecksumConfig(&b.cfg.Checksum)
|
||||
@@ -379,7 +371,6 @@ func (b *ServerBuilder) initializeCoreServices(repos *repositories) {
|
||||
b.webhookService = services.NewWebhookService(repos.webhook, repos.webhookDelivery)
|
||||
}
|
||||
|
||||
// initializeConfigService creates and initializes the configuration service.
|
||||
func (b *ServerBuilder) initializeConfigService(ctx context.Context, repos *repositories) error {
|
||||
encryptionKey := b.cfg.OAuth.CookieSecret
|
||||
b.configService = services.NewConfigService(repos.config, b.cfg, encryptionKey)
|
||||
@@ -423,7 +414,6 @@ func (b *ServerBuilder) initializeMagicLinkCleanupWorker(ctx context.Context) *w
|
||||
return magicLinkWorker
|
||||
}
|
||||
|
||||
// initializeReminderService initializes reminder service.
|
||||
func (b *ServerBuilder) initializeReminderService(repos *repositories) {
|
||||
b.reminderService = services.NewReminderAsyncService(
|
||||
repos.expectedSigner,
|
||||
@@ -435,7 +425,6 @@ func (b *ServerBuilder) initializeReminderService(repos *repositories) {
|
||||
)
|
||||
}
|
||||
|
||||
// initializeSessionWorker initializes OAuth session cleanup worker.
|
||||
func (b *ServerBuilder) initializeSessionWorker(ctx context.Context, repos *repositories) (*auth.SessionWorker, error) {
|
||||
if repos.oauthSession == nil {
|
||||
return nil, nil
|
||||
@@ -450,7 +439,6 @@ func (b *ServerBuilder) initializeSessionWorker(ctx context.Context, repos *repo
|
||||
return sessionWorker, nil
|
||||
}
|
||||
|
||||
// buildRouter creates and configures the main router.
|
||||
func (b *ServerBuilder) buildRouter(repos *repositories, whPublisher *services.WebhookPublisher) *chi.Mux {
|
||||
router := chi.NewRouter()
|
||||
router.Use(i18n.Middleware(b.i18nService))
|
||||
@@ -554,22 +542,18 @@ func (s *Server) GetDB() *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// GetAuthProvider returns the auth provider.
|
||||
func (s *Server) GetAuthProvider() AuthProvider {
|
||||
return s.authProvider
|
||||
}
|
||||
|
||||
// GetAuthorizer returns the authorizer.
|
||||
func (s *Server) GetAuthorizer() Authorizer {
|
||||
return s.authorizer
|
||||
}
|
||||
|
||||
// GetQuotaEnforcer returns the quota enforcer.
|
||||
func (s *Server) GetQuotaEnforcer() QuotaEnforcer {
|
||||
return s.quotaEnforcer
|
||||
}
|
||||
|
||||
// GetAuditLogger returns the audit logger.
|
||||
func (s *Server) GetAuditLogger() AuditLogger {
|
||||
return s.auditLogger
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ watch(readMode, (newMode) => {
|
||||
<div class="space-y-4">
|
||||
<!-- URL input + Upload button + Submit button -->
|
||||
<div class="flex gap-3" :class="mode === 'full' ? 'flex-col sm:flex-row' : ''">
|
||||
<div class="flex-1">
|
||||
<div class="flex-1 min-w-0">
|
||||
<Label v-if="mode === 'full'" for="doc-url" class="mb-1.5">
|
||||
{{ t('documentCreateForm.url.label') }}
|
||||
</Label>
|
||||
@@ -242,13 +242,13 @@ watch(readMode, (newMode) => {
|
||||
<!-- Show selected file or URL input -->
|
||||
<div v-if="selectedFile" class="flex items-center gap-2 px-4 py-2.5 rounded-lg border border-blue-300 dark:border-blue-700 bg-blue-50 dark:bg-blue-950/30" :class="mode === 'full' ? 'h-11' : 'h-12'">
|
||||
<FileText class="w-4 h-4 text-blue-600 dark:text-blue-400 flex-shrink-0" />
|
||||
<span data-testid="selected-file-name" class="flex-1 text-sm text-blue-800 dark:text-blue-200 truncate">{{ selectedFile.name }}</span>
|
||||
<span class="text-xs text-blue-600 dark:text-blue-400">{{ formatFileSize(selectedFile.size) }}</span>
|
||||
<span data-testid="selected-file-name" class="flex-1 min-w-0 text-sm text-blue-800 dark:text-blue-200 truncate">{{ selectedFile.name }}</span>
|
||||
<span class="text-xs text-blue-600 dark:text-blue-400 flex-shrink-0 whitespace-nowrap">{{ formatFileSize(selectedFile.size) }}</span>
|
||||
<button
|
||||
type="button"
|
||||
data-testid="clear-file-button"
|
||||
@click="clearSelectedFile"
|
||||
class="p-1 rounded hover:bg-blue-200 dark:hover:bg-blue-800 text-blue-600 dark:text-blue-400"
|
||||
class="p-1 rounded hover:bg-blue-200 dark:hover:bg-blue-800 text-blue-600 dark:text-blue-400 flex-shrink-0"
|
||||
:disabled="isSubmitting"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
|
||||
@@ -146,8 +146,18 @@ const closeUserMenu = () => {
|
||||
aria-haspopup="true"
|
||||
:aria-expanded="userMenuOpen"
|
||||
>
|
||||
<!-- User avatar with initials -->
|
||||
<div class="w-8 h-8 rounded-lg bg-slate-100 dark:bg-slate-700 flex items-center justify-center text-xs font-semibold text-slate-600 dark:text-slate-300">
|
||||
<!-- User avatar (picture or initials fallback) -->
|
||||
<img
|
||||
v-if="user?.picture"
|
||||
:src="user.picture"
|
||||
:alt="displayName"
|
||||
class="w-8 h-8 rounded-lg object-cover"
|
||||
referrerpolicy="no-referrer"
|
||||
/>
|
||||
<div
|
||||
v-else
|
||||
class="w-8 h-8 rounded-lg bg-slate-100 dark:bg-slate-700 flex items-center justify-center text-xs font-semibold text-slate-600 dark:text-slate-300"
|
||||
>
|
||||
{{ userInitials }}
|
||||
</div>
|
||||
<span class="text-slate-700 dark:text-slate-200 hidden lg:inline">{{ displayName }}</span>
|
||||
|
||||
@@ -8,6 +8,7 @@ export interface User {
|
||||
id: string
|
||||
email: string
|
||||
name: string
|
||||
picture?: string
|
||||
isAdmin: boolean
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user