mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2025-12-30 21:29:44 -06:00
fix: boundary conditions on 1-second rate limiters (#1379)
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
CREATE OR REPLACE FUNCTION get_refill_value(rate_limit "RateLimit")
|
||||
RETURNS INTEGER AS $$
|
||||
DECLARE
|
||||
refill_amount INTEGER;
|
||||
BEGIN
|
||||
IF (NOW() - rate_limit."lastRefill") >= (rate_limit."window"::INTERVAL - INTERVAL '10 milliseconds') THEN
|
||||
refill_amount := rate_limit."limitValue";
|
||||
ELSE
|
||||
refill_amount := rate_limit."value";
|
||||
END IF;
|
||||
RETURN refill_amount;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
CREATE OR REPLACE FUNCTION get_refill_value(rate_limit "RateLimit")
|
||||
RETURNS INTEGER AS $$
|
||||
DECLARE
|
||||
refill_amount INTEGER;
|
||||
BEGIN
|
||||
IF (NOW() - rate_limit."lastRefill") >= (rate_limit."window"::INTERVAL) THEN
|
||||
refill_amount := rate_limit."limitValue";
|
||||
ELSE
|
||||
refill_amount := rate_limit."value";
|
||||
END IF;
|
||||
RETURN refill_amount;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
-- +goose StatementEnd
|
||||
@@ -2,6 +2,7 @@ package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hatchet-dev/hatchet/pkg/repository/v1/sqlcv1"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
@@ -39,7 +40,7 @@ type QueueRepository interface {
|
||||
|
||||
type RateLimitRepository interface {
|
||||
ListCandidateRateLimits(ctx context.Context, tenantId pgtype.UUID) ([]string, error)
|
||||
UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error)
|
||||
UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error)
|
||||
}
|
||||
|
||||
type AssignmentRepository interface {
|
||||
|
||||
@@ -2,6 +2,7 @@ package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hatchet-dev/hatchet/pkg/repository/postgres/sqlchelpers"
|
||||
"github.com/hatchet-dev/hatchet/pkg/repository/v1/sqlcv1"
|
||||
@@ -39,11 +40,11 @@ func (d *rateLimitRepository) ListCandidateRateLimits(ctx context.Context, tenan
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error) {
|
||||
func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error) {
|
||||
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, d.pool, d.l, 5000)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer rollback()
|
||||
@@ -62,17 +63,17 @@ func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgt
|
||||
_, err = d.queries.BulkUpdateRateLimits(ctx, tx, params)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
newRls, err := d.queries.ListRateLimitsForTenantWithMutate(ctx, tx, tenantId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := commit(ctx); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
res := make(map[string]int, len(newRls))
|
||||
@@ -81,5 +82,16 @@ func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgt
|
||||
res[rl.Key] = int(rl.Value)
|
||||
}
|
||||
|
||||
return res, err
|
||||
nextRefillAt := time.Now().Add(time.Second * 2)
|
||||
|
||||
if len(newRls) > 0 {
|
||||
// get min of all next refill times
|
||||
for _, rl := range newRls {
|
||||
if rl.NextRefillAt.Time.Before(nextRefillAt) {
|
||||
nextRefillAt = rl.NextRefillAt.Time
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res, &nextRefillAt, err
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ WITH rls_to_update AS (
|
||||
"RateLimit" rl
|
||||
WHERE
|
||||
rl."tenantId" = @tenantId::uuid
|
||||
AND NOW() - rl."lastRefill" >= rl."window"::INTERVAL
|
||||
AND NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds')
|
||||
ORDER BY
|
||||
rl."tenantId" ASC, rl."key" ASC
|
||||
FOR UPDATE
|
||||
@@ -93,7 +93,7 @@ WITH rls_to_update AS (
|
||||
)
|
||||
SELECT
|
||||
rl.*,
|
||||
(rl."lastRefill" + rl."window"::INTERVAL)::timestamp AS "nextRefillAt"
|
||||
(rl."lastRefill" + rl."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt"
|
||||
FROM
|
||||
"RateLimit" rl
|
||||
WHERE
|
||||
@@ -105,7 +105,7 @@ UNION ALL
|
||||
SELECT
|
||||
refill.*,
|
||||
-- return the next refill time
|
||||
(refill."lastRefill" + refill."window"::INTERVAL)::timestamp AS "nextRefillAt"
|
||||
(refill."lastRefill" + refill."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt"
|
||||
FROM
|
||||
refill;
|
||||
|
||||
@@ -145,7 +145,7 @@ UPDATE
|
||||
SET
|
||||
"value" = get_refill_value(rl) - (SELECT "units" FROM input WHERE "key" = rl."key"),
|
||||
"lastRefill" = CASE
|
||||
WHEN NOW() - rl."lastRefill" >= rl."window"::INTERVAL THEN
|
||||
WHEN NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds') THEN
|
||||
CURRENT_TIMESTAMP
|
||||
ELSE
|
||||
rl."lastRefill"
|
||||
|
||||
@@ -38,7 +38,7 @@ UPDATE
|
||||
SET
|
||||
"value" = get_refill_value(rl) - (SELECT "units" FROM input WHERE "key" = rl."key"),
|
||||
"lastRefill" = CASE
|
||||
WHEN NOW() - rl."lastRefill" >= rl."window"::INTERVAL THEN
|
||||
WHEN NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds') THEN
|
||||
CURRENT_TIMESTAMP
|
||||
ELSE
|
||||
rl."lastRefill"
|
||||
@@ -216,7 +216,7 @@ WITH rls_to_update AS (
|
||||
"RateLimit" rl
|
||||
WHERE
|
||||
rl."tenantId" = $1::uuid
|
||||
AND NOW() - rl."lastRefill" >= rl."window"::INTERVAL
|
||||
AND NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds')
|
||||
ORDER BY
|
||||
rl."tenantId" ASC, rl."key" ASC
|
||||
FOR UPDATE
|
||||
@@ -235,7 +235,7 @@ WITH rls_to_update AS (
|
||||
)
|
||||
SELECT
|
||||
rl."tenantId", rl.key, rl."limitValue", rl.value, rl."window", rl."lastRefill",
|
||||
(rl."lastRefill" + rl."window"::INTERVAL)::timestamp AS "nextRefillAt"
|
||||
(rl."lastRefill" + rl."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt"
|
||||
FROM
|
||||
"RateLimit" rl
|
||||
WHERE
|
||||
@@ -247,7 +247,7 @@ UNION ALL
|
||||
SELECT
|
||||
refill."tenantId", refill.key, refill."limitValue", refill.value, refill."window", refill."lastRefill",
|
||||
-- return the next refill time
|
||||
(refill."lastRefill" + refill."window"::INTERVAL)::timestamp AS "nextRefillAt"
|
||||
(refill."lastRefill" + refill."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt"
|
||||
FROM
|
||||
refill
|
||||
`
|
||||
|
||||
@@ -23,6 +23,9 @@ type rateLimiter struct {
|
||||
|
||||
tenantId pgtype.UUID
|
||||
|
||||
nextRefillAt *time.Time
|
||||
nextRefillAtMu sync.RWMutex
|
||||
|
||||
l *zerolog.Logger
|
||||
|
||||
// unacked is a map of taskId to rateLimitSet
|
||||
@@ -106,6 +109,13 @@ func (r *rateLimiter) use(ctx context.Context, taskId int64, rls map[string]int3
|
||||
if !r.rateLimitsExist(rls) {
|
||||
return res
|
||||
}
|
||||
} else if r.shouldRefill() {
|
||||
err := r.flushToDatabase(ctx)
|
||||
|
||||
if err != nil {
|
||||
r.l.Error().Err(err).Msg("error flushing rate limits to database")
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
currRls := r.copyDbRateLimits()
|
||||
@@ -154,6 +164,17 @@ func (r *rateLimiter) rateLimitsExist(rls map[string]int32) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *rateLimiter) shouldRefill() bool {
|
||||
r.nextRefillAtMu.Lock()
|
||||
defer r.nextRefillAtMu.Unlock()
|
||||
|
||||
if r.nextRefillAt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return r.nextRefillAt.After(time.Now().UTC())
|
||||
}
|
||||
|
||||
func (r *rateLimiter) copyDbRateLimits() rateLimitSet {
|
||||
r.dbRateLimitsMu.RLock()
|
||||
defer r.dbRateLimitsMu.RUnlock()
|
||||
@@ -263,12 +284,17 @@ func (r *rateLimiter) flushToDatabase(ctx context.Context) error {
|
||||
updates[k] = v.val
|
||||
}
|
||||
|
||||
newRateLimits, err := r.rateLimitRepo.UpdateRateLimits(ctx, r.tenantId, updates)
|
||||
newRateLimits, nextRefillAt, err := r.rateLimitRepo.UpdateRateLimits(ctx, r.tenantId, updates)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update the next refill time
|
||||
r.nextRefillAtMu.Lock()
|
||||
r.nextRefillAt = nextRefillAt
|
||||
r.nextRefillAtMu.Unlock()
|
||||
|
||||
r.dbRateLimits = make(rateLimitSet)
|
||||
|
||||
// update the db rate limits
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/rs/zerolog"
|
||||
@@ -22,16 +23,17 @@ func (m *mockRateLimitRepo) ListCandidateRateLimits(ctx context.Context, tenantI
|
||||
return args.Get(0).([]string), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockRateLimitRepo) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error) {
|
||||
func (m *mockRateLimitRepo) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error) {
|
||||
args := m.Called(ctx, tenantId, updates)
|
||||
return args.Get(0).(map[string]int), args.Error(1)
|
||||
arg1 := args.Get(1).(time.Time)
|
||||
return args.Get(0).(map[string]int), &arg1, args.Error(2)
|
||||
}
|
||||
|
||||
func TestRateLimiter_Use(t *testing.T) {
|
||||
l := zerolog.Nop()
|
||||
|
||||
mockRateLimitRepo := &mockRateLimitRepo{}
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5, "key3": 7}, nil)
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5, "key3": 7}, time.Now().Add(2*time.Second), nil)
|
||||
|
||||
rateLimiter := &rateLimiter{
|
||||
dbRateLimits: rateLimitSet{
|
||||
@@ -62,7 +64,7 @@ func TestRateLimiter_Ack(t *testing.T) {
|
||||
l := zerolog.Nop()
|
||||
|
||||
mockRateLimitRepo := &mockRateLimitRepo{}
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil)
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil)
|
||||
|
||||
rateLimiter := &rateLimiter{
|
||||
dbRateLimits: rateLimitSet{
|
||||
@@ -87,7 +89,7 @@ func TestRateLimiter_Nack(t *testing.T) {
|
||||
l := zerolog.Nop()
|
||||
|
||||
mockRateLimitRepo := &mockRateLimitRepo{}
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil)
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil)
|
||||
|
||||
rateLimiter := &rateLimiter{
|
||||
dbRateLimits: rateLimitSet{
|
||||
@@ -112,7 +114,7 @@ func TestRateLimiter_Concurrency(t *testing.T) {
|
||||
l := zerolog.Nop()
|
||||
|
||||
mockRateLimitRepo := &mockRateLimitRepo{}
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 100, "key2": 100}, nil)
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 100, "key2": 100}, time.Now().Add(2*time.Second), nil)
|
||||
|
||||
rateLimiter := &rateLimiter{
|
||||
dbRateLimits: rateLimitSet{
|
||||
@@ -151,7 +153,7 @@ func TestRateLimiter_FlushToDatabase(t *testing.T) {
|
||||
l := zerolog.Nop()
|
||||
|
||||
mockRateLimitRepo := &mockRateLimitRepo{} // Mock implementation of rateLimitRepo
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil)
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil)
|
||||
|
||||
rateLimiter := &rateLimiter{
|
||||
dbRateLimits: rateLimitSet{
|
||||
@@ -184,7 +186,7 @@ func BenchmarkRateLimiter(b *testing.B) {
|
||||
l := zerolog.Nop()
|
||||
|
||||
mockRateLimitRepo := &mockRateLimitRepo{}
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 1000, "key2": 1000}, nil)
|
||||
mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 1000, "key2": 1000}, time.Now().Add(2*time.Second), nil)
|
||||
|
||||
r := rateLimiter{
|
||||
unacked: make(map[int64]rateLimitSet),
|
||||
|
||||
Reference in New Issue
Block a user