fix: boundary conditions on 1-second rate limiters (#1379)

This commit is contained in:
abelanger5
2025-03-20 17:44:08 -04:00
committed by GitHub
parent 2333090751
commit aebcf0bb0c
7 changed files with 98 additions and 24 deletions

View File

@@ -0,0 +1,33 @@
-- +goose Up
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION get_refill_value(rate_limit "RateLimit")
RETURNS INTEGER AS $$
BEGIN
    -- Returns the value a rate limit row should be treated as holding right now.
    -- Once the time elapsed since "lastRefill" has reached the row's "window",
    -- less a 10 millisecond tolerance, the row counts as refilled and the full
    -- "limitValue" is returned.
    IF (NOW() - rate_limit."lastRefill") >= (rate_limit."window"::INTERVAL - INTERVAL '10 milliseconds') THEN
        RETURN rate_limit."limitValue";
    END IF;

    -- Window has not elapsed yet (or the comparison is NULL): keep the stored value.
    RETURN rate_limit."value";
END;
$$ LANGUAGE plpgsql;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION get_refill_value(rate_limit "RateLimit")
RETURNS INTEGER AS $$
BEGIN
    -- Down migration: this version compares the elapsed time against the full
    -- "window", with no tolerance subtracted.
    -- Once the time elapsed since "lastRefill" has reached the row's "window",
    -- the row counts as refilled and the full "limitValue" is returned.
    IF (NOW() - rate_limit."lastRefill") >= (rate_limit."window"::INTERVAL) THEN
        RETURN rate_limit."limitValue";
    END IF;

    -- Window has not elapsed yet (or the comparison is NULL): keep the stored value.
    RETURN rate_limit."value";
END;
$$ LANGUAGE plpgsql;
-- +goose StatementEnd

View File

@@ -2,6 +2,7 @@ package v1
import (
"context"
"time"
"github.com/hatchet-dev/hatchet/pkg/repository/v1/sqlcv1"
"github.com/jackc/pgx/v5/pgtype"
@@ -39,7 +40,7 @@ type QueueRepository interface {
type RateLimitRepository interface {
ListCandidateRateLimits(ctx context.Context, tenantId pgtype.UUID) ([]string, error)
UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error)
UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error)
}
type AssignmentRepository interface {

View File

@@ -2,6 +2,7 @@ package v1
import (
"context"
"time"
"github.com/hatchet-dev/hatchet/pkg/repository/postgres/sqlchelpers"
"github.com/hatchet-dev/hatchet/pkg/repository/v1/sqlcv1"
@@ -39,11 +40,11 @@ func (d *rateLimitRepository) ListCandidateRateLimits(ctx context.Context, tenan
return ids, nil
}
func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error) {
func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error) {
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, d.pool, d.l, 5000)
if err != nil {
return nil, err
return nil, nil, err
}
defer rollback()
@@ -62,17 +63,17 @@ func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgt
_, err = d.queries.BulkUpdateRateLimits(ctx, tx, params)
if err != nil {
return nil, err
return nil, nil, err
}
newRls, err := d.queries.ListRateLimitsForTenantWithMutate(ctx, tx, tenantId)
if err != nil {
return nil, err
return nil, nil, err
}
if err := commit(ctx); err != nil {
return nil, err
return nil, nil, err
}
res := make(map[string]int, len(newRls))
@@ -81,5 +82,16 @@ func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgt
res[rl.Key] = int(rl.Value)
}
return res, err
nextRefillAt := time.Now().Add(time.Second * 2)
if len(newRls) > 0 {
// get min of all next refill times
for _, rl := range newRls {
if rl.NextRefillAt.Time.Before(nextRefillAt) {
nextRefillAt = rl.NextRefillAt.Time
}
}
}
return res, &nextRefillAt, err
}

View File

@@ -74,7 +74,7 @@ WITH rls_to_update AS (
"RateLimit" rl
WHERE
rl."tenantId" = @tenantId::uuid
AND NOW() - rl."lastRefill" >= rl."window"::INTERVAL
AND NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds')
ORDER BY
rl."tenantId" ASC, rl."key" ASC
FOR UPDATE
@@ -93,7 +93,7 @@ WITH rls_to_update AS (
)
SELECT
rl.*,
(rl."lastRefill" + rl."window"::INTERVAL)::timestamp AS "nextRefillAt"
(rl."lastRefill" + rl."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt"
FROM
"RateLimit" rl
WHERE
@@ -105,7 +105,7 @@ UNION ALL
SELECT
refill.*,
-- return the next refill time
(refill."lastRefill" + refill."window"::INTERVAL)::timestamp AS "nextRefillAt"
(refill."lastRefill" + refill."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt"
FROM
refill;
@@ -145,7 +145,7 @@ UPDATE
SET
"value" = get_refill_value(rl) - (SELECT "units" FROM input WHERE "key" = rl."key"),
"lastRefill" = CASE
WHEN NOW() - rl."lastRefill" >= rl."window"::INTERVAL THEN
WHEN NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds') THEN
CURRENT_TIMESTAMP
ELSE
rl."lastRefill"

View File

@@ -38,7 +38,7 @@ UPDATE
SET
"value" = get_refill_value(rl) - (SELECT "units" FROM input WHERE "key" = rl."key"),
"lastRefill" = CASE
WHEN NOW() - rl."lastRefill" >= rl."window"::INTERVAL THEN
WHEN NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds') THEN
CURRENT_TIMESTAMP
ELSE
rl."lastRefill"
@@ -216,7 +216,7 @@ WITH rls_to_update AS (
"RateLimit" rl
WHERE
rl."tenantId" = $1::uuid
AND NOW() - rl."lastRefill" >= rl."window"::INTERVAL
AND NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds')
ORDER BY
rl."tenantId" ASC, rl."key" ASC
FOR UPDATE
@@ -235,7 +235,7 @@ WITH rls_to_update AS (
)
SELECT
rl."tenantId", rl.key, rl."limitValue", rl.value, rl."window", rl."lastRefill",
(rl."lastRefill" + rl."window"::INTERVAL)::timestamp AS "nextRefillAt"
(rl."lastRefill" + rl."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt"
FROM
"RateLimit" rl
WHERE
@@ -247,7 +247,7 @@ UNION ALL
SELECT
refill."tenantId", refill.key, refill."limitValue", refill.value, refill."window", refill."lastRefill",
-- return the next refill time
(refill."lastRefill" + refill."window"::INTERVAL)::timestamp AS "nextRefillAt"
(refill."lastRefill" + refill."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt"
FROM
refill
`

View File

@@ -23,6 +23,9 @@ type rateLimiter struct {
tenantId pgtype.UUID
nextRefillAt *time.Time
nextRefillAtMu sync.RWMutex
l *zerolog.Logger
// unacked is a map of taskId to rateLimitSet
@@ -106,6 +109,13 @@ func (r *rateLimiter) use(ctx context.Context, taskId int64, rls map[string]int3
if !r.rateLimitsExist(rls) {
return res
}
} else if r.shouldRefill() {
err := r.flushToDatabase(ctx)
if err != nil {
r.l.Error().Err(err).Msg("error flushing rate limits to database")
return res
}
}
currRls := r.copyDbRateLimits()
@@ -154,6 +164,17 @@ func (r *rateLimiter) rateLimitsExist(rls map[string]int32) bool {
return true
}
// shouldRefill reports whether the cached next-refill time has been reached,
// meaning the in-memory rate limits are due to be refreshed from the database.
//
// It returns false when no next-refill time has been recorded yet.
func (r *rateLimiter) shouldRefill() bool {
	// Read-only access to nextRefillAt: a shared lock is sufficient.
	r.nextRefillAtMu.RLock()
	defer r.nextRefillAtMu.RUnlock()

	if r.nextRefillAt == nil {
		return false
	}

	// A refill is due once "now" is at or past the recorded refill time.
	// time.Time comparisons operate on instants, so no zone conversion is needed.
	return !time.Now().Before(*r.nextRefillAt)
}
func (r *rateLimiter) copyDbRateLimits() rateLimitSet {
r.dbRateLimitsMu.RLock()
defer r.dbRateLimitsMu.RUnlock()
@@ -263,12 +284,17 @@ func (r *rateLimiter) flushToDatabase(ctx context.Context) error {
updates[k] = v.val
}
newRateLimits, err := r.rateLimitRepo.UpdateRateLimits(ctx, r.tenantId, updates)
newRateLimits, nextRefillAt, err := r.rateLimitRepo.UpdateRateLimits(ctx, r.tenantId, updates)
if err != nil {
return err
}
// update the next refill time
r.nextRefillAtMu.Lock()
r.nextRefillAt = nextRefillAt
r.nextRefillAtMu.Unlock()
r.dbRateLimits = make(rateLimitSet)
// update the db rate limits

View File

@@ -6,6 +6,7 @@ import (
"math/rand"
"sync"
"testing"
"time"
"github.com/jackc/pgx/v5/pgtype"
"github.com/rs/zerolog"
@@ -22,16 +23,17 @@ func (m *mockRateLimitRepo) ListCandidateRateLimits(ctx context.Context, tenantI
return args.Get(0).([]string), args.Error(1)
}
func (m *mockRateLimitRepo) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error) {
func (m *mockRateLimitRepo) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error) {
args := m.Called(ctx, tenantId, updates)
return args.Get(0).(map[string]int), args.Error(1)
arg1 := args.Get(1).(time.Time)
return args.Get(0).(map[string]int), &arg1, args.Error(2)
}
func TestRateLimiter_Use(t *testing.T) {
l := zerolog.Nop()
mockRateLimitRepo := &mockRateLimitRepo{}
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5, "key3": 7}, nil)
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5, "key3": 7}, time.Now().Add(2*time.Second), nil)
rateLimiter := &rateLimiter{
dbRateLimits: rateLimitSet{
@@ -62,7 +64,7 @@ func TestRateLimiter_Ack(t *testing.T) {
l := zerolog.Nop()
mockRateLimitRepo := &mockRateLimitRepo{}
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil)
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil)
rateLimiter := &rateLimiter{
dbRateLimits: rateLimitSet{
@@ -87,7 +89,7 @@ func TestRateLimiter_Nack(t *testing.T) {
l := zerolog.Nop()
mockRateLimitRepo := &mockRateLimitRepo{}
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil)
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil)
rateLimiter := &rateLimiter{
dbRateLimits: rateLimitSet{
@@ -112,7 +114,7 @@ func TestRateLimiter_Concurrency(t *testing.T) {
l := zerolog.Nop()
mockRateLimitRepo := &mockRateLimitRepo{}
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 100, "key2": 100}, nil)
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 100, "key2": 100}, time.Now().Add(2*time.Second), nil)
rateLimiter := &rateLimiter{
dbRateLimits: rateLimitSet{
@@ -151,7 +153,7 @@ func TestRateLimiter_FlushToDatabase(t *testing.T) {
l := zerolog.Nop()
mockRateLimitRepo := &mockRateLimitRepo{} // Mock implementation of rateLimitRepo
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil)
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil)
rateLimiter := &rateLimiter{
dbRateLimits: rateLimitSet{
@@ -184,7 +186,7 @@ func BenchmarkRateLimiter(b *testing.B) {
l := zerolog.Nop()
mockRateLimitRepo := &mockRateLimitRepo{}
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 1000, "key2": 1000}, nil)
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 1000, "key2": 1000}, time.Now().Add(2*time.Second), nil)
r := rateLimiter{
unacked: make(map[int64]rateLimitSet),