Files
PrivateCaptcha/pkg/db/business.go
T
2026-05-06 14:21:40 +02:00

246 lines
7.8 KiB
Go

package db
import (
"context"
"errors"
"log/slog"
"strings"
"sync/atomic"
"time"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/puzzle"
"github.com/jackc/pgx/v5/pgxpool"
)
var (
	// Sentinel errors reported by this package; match them with errors.Is.
	ErrInvalidInput     = errors.New("invalid input")
	ErrRecordNotFound   = errors.New("record not found")
	ErrSoftDeleted      = errors.New("record is marked as deleted")
	ErrDuplicateAccount = errors.New("this subscription already has an account")
	ErrLocked           = errors.New("lock is already acquired")
	// ErrMaintenance is returned by BusinessStore.WithTx while maintenance
	// mode is enabled.
	ErrMaintenance   = errors.New("maintenance mode")
	ErrTestProperty  = errors.New("test property")
	ErrAlreadyExists = errors.New("record already exists")
	ErrDisabled      = errors.New("entity is disabled")
	ErrPermissions   = errors.New("insufficient permissions")
	// errInvalidCacheType is package-internal; per its message it signals a
	// cached record whose type did not match what the reader expected.
	errInvalidCacheType = errors.New("cache record type does not match")

	// Sitekey forms of the well-known property IDs: the UUID string with all
	// dashes removed.
	TestPropertySitekey   = strings.ReplaceAll(TestPropertyID, "-", "")
	PortalLoginSitekey    = strings.ReplaceAll(PortalLoginPropertyID, "-", "")
	PortalRegisterSitekey = strings.ReplaceAll(PortalRegisterPropertyID, "-", "")
	// TestPropertyUUID is TestPropertySitekey converted by UUIDFromSiteKey
	// (presumably back into a UUID value — confirm against its definition).
	TestPropertyUUID = UUIDFromSiteKey(TestPropertySitekey)
)
const (
	// Well-known property IDs (UUID strings). The package-level *Sitekey
	// variables are derived from these by stripping the dashes.
	PortalLoginPropertyID    = "1ca8041a-5761-40a4-addf-f715a991bfea"
	PortalRegisterPropertyID = "8981be7a-3a71-414d-bb74-e7b4456603fd"
	TestPropertyID           = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

	// Timings passed to NewMemoryCache by NewBusiness; defaultCacheTTL is also
	// passed to common.LoadCacheFromFile by LoadCache. Going by the names these
	// are the entry TTL, refresh interval and negative-entry TTL.
	// NOTE(review): defaultCacheRefresh is longer than defaultCacheTTL —
	// confirm this is intended for the memcache refresh semantics.
	defaultCacheTTL     = 15 * time.Minute
	defaultCacheRefresh = 30 * time.Minute
	negativeCacheTTL    = 5 * time.Minute
	// Batch size handed to NewAuditLog by NewBusinessEx.
	auditBatchSize = 100
	// File name (inside the directory given to SaveCache/LoadCache) and the
	// size argument passed to common.SaveCacheToFile when persisting the cache.
	cachePersistFile = "business_cache.gob"
	cachePersistSize = 10_000
)
// BusinessStore is the business-level data access entry point. It holds the
// database pool, a shared cache, an audit log and a puzzle cache, and hands out
// either a database-backed or a cache-only implementation depending on
// MaintenanceMode.
type BusinessStore struct {
	// Pool is the pgx connection pool; NewBusinessEx accepts a nil pool.
	Pool *pgxpool.Pool
	// defaultImpl uses both the shared cache and a querier built from Pool.
	defaultImpl *BusinessStoreImpl
	// cacheOnlyImpl has the shared cache but no querier; Impl returns it while
	// maintenance mode is on.
	cacheOnlyImpl *BusinessStoreImpl
	// Cache is shared by both implementations; WithTx commits its
	// transaction-local cache into it after a successful DB commit.
	Cache common.Cache[CacheKey, any]
	// auditLog is the real audit log; discardAuditLog is returned by AuditLog
	// instead while maintenance mode is on.
	auditLog        *AuditLog
	discardAuditLog *DiscardAuditLog
	// puzzleCache counts verifications per puzzle HashKey (see
	// CacheVerifiedPuzzle / CheckVerifiedPuzzle).
	// this could have been a bloom/cuckoo filter with expiration, if they existed
	puzzleCache *puzzleCache
	// MaintenanceMode is toggled by UpdateConfig and read atomically by Impl,
	// AuditLog and WithTx.
	MaintenanceMode atomic.Bool
}
// Implementor is the set of BusinessStore operations callers depend on, so an
// alternative implementation can be substituted (presumably for tests —
// confirm against usages).
type Implementor interface {
	// Impl returns the implementation for the current mode (BusinessStore
	// returns a cache-only one in maintenance mode).
	Impl() *BusinessStoreImpl
	// WithTx runs fn within a database transaction and returns its audit events.
	WithTx(ctx context.Context, fn func(*BusinessStoreImpl) ([]*common.AuditLogEvent, error)) ([]*common.AuditLogEvent, error)
	// Ping checks database connectivity.
	Ping(ctx context.Context) error
	// CheckVerifiedPuzzle consults the verified-puzzle counters for p.
	CheckVerifiedPuzzle(ctx context.Context, p puzzle.Puzzle, maxCount uint32) bool
	// CacheVerifiedPuzzle records that p was verified at tnow.
	CacheVerifiedPuzzle(ctx context.Context, p puzzle.Puzzle, tnow time.Time)
	// CheckUserPropertyAccess reports whether userID may access property.
	CheckUserPropertyAccess(ctx context.Context, property *dbgen.Property, userID int32) bool
	// CacheHitRatio reports the shared cache's hit ratio.
	CacheHitRatio() float64
	// AuditLog returns the audit log to use for the current mode.
	AuditLog() common.AuditLog
}

// Compile-time check that *BusinessStore satisfies Implementor.
var _ Implementor = (*BusinessStore)(nil)
// NewBusiness builds a BusinessStore on top of pool with a "default" in-memory
// cache capped at 10 million entries. If that cache cannot be created, the
// error is logged and a static cache with the same cap is used instead.
func NewBusiness(pool *pgxpool.Pool) *BusinessStore {
	const maxCacheSize = 10_000_000

	var (
		store common.Cache[CacheKey, any]
		err   error
	)
	if store, err = NewMemoryCache[CacheKey, any]("default", maxCacheSize, &struct{}{}, defaultCacheTTL, defaultCacheRefresh, negativeCacheTTL); err != nil {
		slog.Error("Failed to create memory cache", common.ErrAttr(err))
		store = NewStaticCache[CacheKey, any](maxCacheSize, &struct{}{})
	}

	return NewBusinessEx(pool, store)
}
// NewBusinessEx assembles a BusinessStore around pool and the given cache.
//
// pool may be nil: no querier is created then, so both the audit log and the
// default implementation receive a nil querier. The cache is shared by the
// default (cache + querier) and cache-only implementations.
func NewBusinessEx(pool *pgxpool.Pool, cache common.Cache[CacheKey, any]) *BusinessStore {
	var q dbgen.Querier
	if pool != nil {
		q = dbgen.New(pool)
	}

	realAudit := NewAuditLog(q, auditBatchSize)
	withDB := &BusinessStoreImpl{cache: cache, querier: q}
	cacheOnly := &BusinessStoreImpl{cache: cache}

	store := &BusinessStore{
		Pool:            pool,
		Cache:           cache,
		auditLog:        realAudit,
		discardAuditLog: &DiscardAuditLog{},
		defaultImpl:     withDB,
		cacheOnlyImpl:   cacheOnly,
		puzzleCache:     newPuzzleCache(puzzle.DefaultValidityPeriod),
	}
	return store
}
// UpdateConfig switches maintenance mode on or off. While it is on, Impl hands
// out the cache-only implementation, AuditLog discards events and WithTx
// returns ErrMaintenance.
func (s *BusinessStore) UpdateConfig(maintenanceMode bool) {
	flag := &s.MaintenanceMode
	flag.Store(maintenanceMode)
}
// AuditLog returns the real audit log normally, and a discarding one while
// maintenance mode is enabled.
func (s *BusinessStore) AuditLog() common.AuditLog {
	if !s.MaintenanceMode.Load() {
		return s.auditLog
	}
	return s.discardAuditLog
}
// Impl returns the database-backed implementation normally, and the cache-only
// implementation (no querier) while maintenance mode is enabled.
func (s *BusinessStore) Impl() *BusinessStoreImpl {
	if !s.MaintenanceMode.Load() {
		return s.defaultImpl
	}
	return s.cacheOnlyImpl
}
// Start starts the underlying audit log, forwarding ctx and auditLogInterval.
func (s *BusinessStore) Start(ctx context.Context, auditLogInterval time.Duration) {
	al := s.auditLog
	al.Start(ctx, auditLogInterval)
}
// Shutdown shuts down the underlying audit log.
func (s *BusinessStore) Shutdown() {
	al := s.auditLog
	al.Shutdown()
}
// WithTx runs fn inside a single database transaction.
//
// fn receives a BusinessStoreImpl whose querier is bound to the transaction and
// whose cache is a transaction-local TxCache. Only when fn succeeds and the
// commit succeeds is the TxCache committed into s.Cache; otherwise the shared
// cache is left untouched. If fn returns an error (or panics), the transaction
// is rolled back. The audit events returned by fn are passed back to the caller
// even when an error is returned.
//
// In maintenance mode it returns ErrMaintenance without touching the database.
func (s *BusinessStore) WithTx(ctx context.Context, fn func(*BusinessStoreImpl) ([]*common.AuditLogEvent, error)) ([]*common.AuditLogEvent, error) {
	if s.MaintenanceMode.Load() {
		return nil, ErrMaintenance
	}

	tx, err := s.Pool.Begin(ctx)
	if err != nil {
		return nil, err
	}

	// Roll back unless Commit has been attempted. Gating on this flag rather
	// than on err also covers a panic inside fn, which would otherwise leave the
	// transaction open and its pooled connection checked out. Once Commit has
	// been called the pgx transaction is closed whether or not it succeeded, so
	// a rollback afterwards would only log a spurious "tx closed" error.
	commitAttempted := false
	defer func() {
		if commitAttempted {
			return
		}
		if rerr := tx.Rollback(ctx); rerr != nil {
			slog.ErrorContext(ctx, "Failed to rollback transaction", common.ErrAttr(rerr))
		}
	}()

	tmpCache := NewTxCache()
	impl := &BusinessStoreImpl{cache: tmpCache, querier: dbgen.New(s.Pool).WithTx(tx)}

	auditEvents, err := fn(impl)
	if err != nil {
		return auditEvents, err
	}

	commitAttempted = true
	if err := tx.Commit(ctx); err != nil {
		return auditEvents, err
	}

	// Publish cached writes only after the data they describe is committed.
	tmpCache.Commit(ctx, s.Cache)
	return auditEvents, nil
}
// Ping checks database connectivity. It deliberately always goes through the
// default (querier-backed) implementation, even in maintenance mode, so the
// result reflects the real connection.
func (s *BusinessStore) Ping(ctx context.Context) error {
	realDB := s.defaultImpl
	return realDB.ping(ctx)
}
// CacheHitRatio reports the hit ratio of the shared business cache.
func (s *BusinessStore) CacheHitRatio() float64 {
	shared := s.Cache
	return shared.HitRatio()
}
// CheckVerifiedPuzzle returns the negation of puzzleCache.CheckCount for p's
// HashKey and maxCount; a nil or zero puzzle always yields false.
// NOTE(review): presumably true means p has already been verified maxCount
// times on this server — confirm against puzzleCache.CheckCount.
func (s *BusinessStore) CheckVerifiedPuzzle(ctx context.Context, p puzzle.Puzzle, maxCount uint32) bool {
	if p != nil && !p.IsZero() {
		// A HashKey collision is theoretically possible but negligible: the keys
		// would have to match within puzzle.DefaultValidityPeriod on this server.
		hashKey := p.HashKey()
		return !s.puzzleCache.CheckCount(ctx, hashKey, maxCount)
	}
	return false
}
// CacheVerifiedPuzzle increments the puzzle cache counter for p's HashKey,
// passing the time left until p expires (measured from tnow) as the duration.
// Nil or zero puzzles, and puzzles already expired at tnow, are skipped and
// logged instead.
func (s *BusinessStore) CacheVerifiedPuzzle(ctx context.Context, p puzzle.Puzzle, tnow time.Time) {
	if p == nil || p.IsZero() {
		slog.Log(ctx, common.LevelTrace, "Skipping caching zero puzzle")
		return
	}

	expiresAt := p.Expiration()
	// Expired puzzles should already have been rejected earlier in the
	// pipeline; this guard only keeps them out of the cache.
	if !tnow.Before(expiresAt) {
		slog.WarnContext(ctx, "Skipping caching expired puzzle", "now", tnow, "expiration", p.Expiration())
		return
	}

	hashKey := p.HashKey()
	remaining := expiresAt.Sub(tnow)
	times := s.puzzleCache.Inc(ctx, hashKey, remaining)
	slog.Log(ctx, common.LevelTrace, "Cached verified puzzle", "times", times, "key", hashKey)
}
// CheckUserPropertyAccess reports whether userID is a member or owner of the
// organization that owns property.
//
// It first asks the cache-only implementation; a valid cached access level is
// used directly. Otherwise it lists the organization's users through Impl()
// (itself cache-only in maintenance mode) and looks for userID among them.
// A failed lookup, or a user not in the list, yields false.
func (s *BusinessStore) CheckUserPropertyAccess(ctx context.Context, property *dbgen.Property, userID int32) bool {
	orgID := property.OrgID.Int32

	if _, cached, err := s.cacheOnlyImpl.retrieveOrganizationWithAccess(ctx, userID, orgID); err == nil && cached.Valid {
		return cached.AccessLevel == dbgen.AccessLevelMember || cached.AccessLevel == dbgen.AccessLevelOwner
	}

	// orgUsers() is used rather than userOrgs() because it is a bit faster.
	orgUsers, err := s.Impl().RetrieveOrganizationUsers(ctx, orgID)
	if err != nil {
		return false
	}

	for _, entry := range orgUsers {
		if entry.User.ID != userID {
			continue
		}
		slog.DebugContext(ctx, "Found user as org member", "level", entry.Level, "userID", userID, "orgID", property.OrgID)
		switch entry.Level {
		case dbgen.AccessLevelMember, dbgen.AccessLevelOwner:
			return true
		default:
			return false
		}
	}

	return false
}
// SaveCache writes the business cache to cachePersistFile inside dir via
// common.SaveCacheToFile, passing cachePersistSize as its size argument.
// It only works for a non-nil memcache-backed Cache; for anything else it logs
// an error and returns nil without writing.
func (s *BusinessStore) SaveCache(ctx context.Context, dir string) error {
	if mc, ok := s.Cache.(*memcache[CacheKey, any]); ok && mc != nil {
		return common.SaveCacheToFile(ctx, dir, cachePersistFile, cachePersistSize, mc.store, nil)
	}
	slog.ErrorContext(ctx, "Business cache is not memcache-backed; skipping persistence")
	return nil
}
// LoadCache restores the business cache from cachePersistFile inside dir via
// common.LoadCacheFromFile, passing defaultCacheTTL. It only works for a
// non-nil memcache-backed Cache; for anything else it logs an error and
// returns nil without loading.
func (s *BusinessStore) LoadCache(ctx context.Context, dir string) error {
	if mc, ok := s.Cache.(*memcache[CacheKey, any]); ok && mc != nil {
		return common.LoadCacheFromFile(ctx, dir, cachePersistFile, defaultCacheTTL, mc.store)
	}
	slog.ErrorContext(ctx, "Business cache is not memcache-backed; skipping load")
	return nil
}