Files
PrivateCaptcha/pkg/api/middlewares.go
2025-08-16 10:56:35 +03:00

382 lines
13 KiB
Go

package api
import (
"context"
"log/slog"
"net/http"
"time"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/billing"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
)
const (
userLimitTTL = 1 * time.Hour
userLimitRefresh = 3 * time.Hour
)
type UserLimiter interface {
CheckUsers(ctx context.Context, users map[int32]uint) error
Evaluate(ctx context.Context, userID int32) (bool, error)
}
type AuthMiddleware struct {
Store db.Implementor
PlanService billing.PlanService
SitekeyChan chan string
UsersChan chan int32
BatchSize int
SitekeyBackfillCancel context.CancelFunc
UsersBackfillCancel context.CancelFunc
Limiter UserLimiter
// this is a simple way to control negative cache spam, disabled by default
NegativeSitekeyThreshold uint
}
type baseUserLimiter struct {
store db.Implementor
userLimits common.Cache[int32, bool]
}
var _ UserLimiter = (*baseUserLimiter)(nil)
func (ul *baseUserLimiter) unknownUsers(ctx context.Context, users map[int32]uint) []int32 {
result := make([]int32, 0, len(users))
for userID := range users {
if _, err := ul.userLimits.Get(ctx, userID); err == db.ErrCacheMiss {
result = append(result, userID)
}
}
return result
}
func (ul *baseUserLimiter) CheckUsers(ctx context.Context, batch map[int32]uint) error {
if len(batch) == 0 {
slog.DebugContext(ctx, "No users to check")
return nil
}
unknownUsers := ul.unknownUsers(ctx, batch)
if len(unknownUsers) == 0 {
slog.DebugContext(ctx, "All user limits were recently checked", "count", len(batch))
return nil
}
t := struct{}{}
users, err := ul.store.Impl().RetrieveUsersWithoutSubscription(ctx, unknownUsers)
if err == nil {
violatorsMap := make(map[int32]struct{})
for _, u := range users {
_ = ul.userLimits.Set(ctx, u.ID, true)
violatorsMap[u.ID] = t
}
for _, u := range unknownUsers {
if _, found := violatorsMap[u]; !found {
_ = ul.userLimits.SetMissing(ctx, u)
}
}
} else {
slog.ErrorContext(ctx, "Failed to check users without subscriptions", "count", len(unknownUsers), common.ErrAttr(err))
}
return err
}
func (ul *baseUserLimiter) Evaluate(ctx context.Context, userID int32) (bool, error) {
_, err := ul.userLimits.Get(ctx, userID)
// "false" because by we only check if user has a subscription at all, we don't verify usage limits
return false, err
}
func NewUserLimiter(store db.Implementor) *baseUserLimiter {
const maxLimitedUsers = 10_000
var userLimits common.Cache[int32, bool]
var err error
// missing TTL should be equal to "usual" TTL here because it has the same meaning (we mark user has no violation)
userLimits, err = db.NewMemoryCache[int32, bool]("user_limits", maxLimitedUsers, false /*missing value*/, userLimitTTL, userLimitRefresh, userLimitTTL)
if err != nil {
slog.Error("Failed to create memory cache for user limits", common.ErrAttr(err))
userLimits = db.NewStaticCache[int32, bool](maxLimitedUsers, false /*missing data*/)
}
return &baseUserLimiter{
userLimits: userLimits,
store: store,
}
}
func NewAuthMiddleware(store db.Implementor,
userLimiter UserLimiter,
planService billing.PlanService) *AuthMiddleware {
const batchSize = 10
am := &AuthMiddleware{
Store: store,
Limiter: userLimiter,
PlanService: planService,
SitekeyChan: make(chan string, 100*batchSize),
UsersChan: make(chan int32, 10*batchSize),
BatchSize: batchSize,
SitekeyBackfillCancel: func() {},
UsersBackfillCancel: func() {},
}
return am
}
func (am *AuthMiddleware) StartBackfill(backfillDelay time.Duration) {
var sitekeyBackfillCtx context.Context
sitekeyBackfillCtx, am.SitekeyBackfillCancel = context.WithCancel(
context.WithValue(context.Background(), common.TraceIDContextKey, "sitekey_backfill"))
go common.ProcessBatchMap(sitekeyBackfillCtx, am.SitekeyChan, backfillDelay, am.BatchSize, am.BatchSize*100, am.backfillSitekeyImpl)
var usersBackfillCtx context.Context
usersBackfillCtx, am.UsersBackfillCancel = context.WithCancel(
context.WithValue(context.Background(), common.TraceIDContextKey, "users_backfill"))
// NOTE: we use the same backfill delay because users processing is slower and sitekey channel will block on it
go common.ProcessBatchMap(usersBackfillCtx, am.UsersChan, backfillDelay, am.BatchSize, am.BatchSize*10, am.backfillUsersImpl)
}
func (am *AuthMiddleware) Shutdown() {
slog.Debug("Shutting down auth middleware")
am.SitekeyBackfillCancel()
am.UsersBackfillCancel()
close(am.SitekeyChan)
close(am.UsersChan)
}
// we cache properties and send owners down the background pipeline
func (am *AuthMiddleware) backfillSitekeyImpl(ctx context.Context, batch map[string]uint) error {
properties, err := am.Store.Impl().RetrievePropertiesBySitekey(ctx, batch, am.NegativeSitekeyThreshold)
if err == nil {
for _, p := range properties {
if p.OrgOwnerID.Valid {
am.UsersChan <- p.OrgOwnerID.Int32
}
if p.CreatorID.Valid && (!p.OrgOwnerID.Valid || (p.CreatorID.Int32 != p.OrgOwnerID.Int32)) {
am.UsersChan <- p.CreatorID.Int32
}
}
} else {
slog.ErrorContext(ctx, "Failed to retrieve properties by sitekey", "count", len(batch), common.ErrAttr(err))
}
return err
}
// we block users without a subscription and (re)cache users API keys to ensure smooth auth in /verify codepath
func (am *AuthMiddleware) backfillUsersImpl(ctx context.Context, batch map[int32]uint) error {
if err := am.Limiter.CheckUsers(ctx, batch); err != nil {
slog.ErrorContext(ctx, "Failed to check user limits", common.ErrAttr(err))
// NOTE: we ignore this error because it is not critical for retry
}
// TODO: Refactor linear fetching of API keys to use batch mode
// we do it linearly instead of in a batch with the assumption that most of these will be cached
// (to be verified in metrics)
// but we can use another SQL query and also BulkGet API of otter (postponed as benefit is not obvious _atm_)
// also the same is in WarmupAPICacheJob (maintenance)
for userID := range batch {
if _, err := am.Store.Impl().RetrieveUserAPIKeys(ctx, userID); err != nil {
slog.ErrorContext(ctx, "Failed to retrieve users API keys", "userID", userID, common.ErrAttr(err))
}
}
// we ignore errors as both of the above are not critical to retry the batch
return nil
}
func (am *AuthMiddleware) originAllowed(r *http.Request, origin string) (bool, []string) {
return len(origin) > 0, nil
}
func isOriginAllowed(origin string, property *dbgen.Property) bool {
if common.IsLocalhost(origin) {
return property.AllowLocalhost
}
if property.AllowSubdomains {
return common.IsSubDomainOrDomain(origin, property.Domain)
}
return origin == property.Domain
}
func (am *AuthMiddleware) SitekeyOptions(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
sitekey := r.URL.Query().Get(common.ParamSiteKey)
// don't validate all characters for speed reasons
if len(sitekey) != db.SitekeyLen {
slog.Log(ctx, common.LevelTrace, "Sitekey is not valid", "method", r.Method)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
ctx = context.WithValue(ctx, common.SitekeyContextKey, sitekey)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func (am *AuthMiddleware) refreshPropertyBySitekey(sitekey string) {
// backfill in the background
am.SitekeyChan <- sitekey
}
func (am *AuthMiddleware) Sitekey(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
origin := r.Header.Get("Origin")
if len(origin) == 0 {
slog.Log(ctx, common.LevelTrace, "Origin header is missing from the request")
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
// we verify sitekey in underlying DB call
sitekey := r.URL.Query().Get(common.ParamSiteKey)
property, err := am.Store.Impl().GetCachedPropertyBySitekey(ctx, sitekey, am.refreshPropertyBySitekey)
if err != nil {
switch err {
// this will happen when the user does not have such property or it was deleted
case db.ErrNegativeCacheHit, db.ErrRecordNotFound, db.ErrSoftDeleted:
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
case db.ErrInvalidInput:
slog.Log(ctx, common.LevelTrace, "Sitekey is not valid", "sitekey", len(sitekey), "origin", origin)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
case db.ErrTestProperty:
// BUMP
case db.ErrCacheMiss:
// backfill in the background
am.SitekeyChan <- sitekey
default:
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}
if property != nil {
if originHost, err := common.ParseDomainName(origin); err == nil {
if !isOriginAllowed(originHost, property) {
slog.WarnContext(ctx, "Origin is not allowed", "origin", originHost, "domain", property.Domain, "subdomains", property.AllowSubdomains)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
} else {
slog.WarnContext(ctx, "Failed to parse origin domain name", common.ErrAttr(err))
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
if softRestriction, err := am.Limiter.Evaluate(ctx, property.OrgOwnerID.Int32); err == nil {
// if user is not an active subscriber, their properties and orgs might still exist but should not serve puzzles
if !softRestriction {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
} else {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}
return
}
ctx = context.WithValue(ctx, common.PropertyContextKey, property)
} else {
ctx = context.WithValue(ctx, common.SitekeyContextKey, sitekey)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func isAPIKeyValid(ctx context.Context, key *dbgen.APIKey, tnow time.Time) bool {
if key == nil {
return false
}
if !key.Enabled.Valid || !key.Enabled.Bool {
slog.WarnContext(ctx, "API key is disabled", "keyID", key.ID)
return false
}
if !key.ExpiresAt.Valid || key.ExpiresAt.Time.Before(tnow) {
slog.WarnContext(ctx, "API key is expired", "keyID", key.ID, "expiresAt", key.ExpiresAt)
return false
}
return true
}
func headerAPIKey(r *http.Request) string {
return r.Header.Get(common.HeaderAPIKey)
}
func formSecretAPIKey(r *http.Request) string {
return r.PostFormValue(common.ParamSecret)
}
func (am *AuthMiddleware) APIKey(keyFunc func(r *http.Request) string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
secret := keyFunc(r)
if len(secret) != db.SecretLen {
slog.Log(ctx, common.LevelTrace, "Invalid secret length", "length", len(secret))
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
// security assumptions here are that API keys of all legitimate users should be already cached via
// the backfill routine for puzzles (legitimate verification assumes a previously issued puzzle if on the same server)
// for everybody else, we rely on rate limiting and delaying DB access to check API key as long as possible.
// The only exception is when due to routing and/or horizontally scaled servers verify request lands on another node
apiKey, err := am.Store.Impl().GetCachedAPIKey(ctx, secret)
if err != nil {
slog.Log(ctx, common.LevelTrace, "Failed to get cached API key", common.ErrAttr(err))
switch err {
case db.ErrNegativeCacheHit, db.ErrRecordNotFound, db.ErrSoftDeleted:
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
case db.ErrInvalidInput:
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
case db.ErrCacheMiss:
// do nothing - we postpone accessing DB to after we verify parts of the payload itself
// we do not backfill API keys like puzzles as we have to check API key validity synchronously
default:
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}
if apiKey != nil {
now := time.Now().UTC()
if !isAPIKeyValid(ctx, apiKey, now) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
// if user is not an active subscriber, their properties and orgs might still exist but should not allow API
if softRestriction, err := am.Limiter.Evaluate(ctx, apiKey.UserID.Int32); (err == nil) && !softRestriction {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
ctx = context.WithValue(ctx, common.APIKeyContextKey, apiKey)
} else {
ctx = context.WithValue(ctx, common.SecretContextKey, secret)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}