diff --git a/cmd/hatchet-migrate/migrate/migrations/20250320210345_v1_0_6.sql b/cmd/hatchet-migrate/migrate/migrations/20250320210345_v1_0_6.sql new file mode 100644 index 000000000..10a3acd4c --- /dev/null +++ b/cmd/hatchet-migrate/migrate/migrations/20250320210345_v1_0_6.sql @@ -0,0 +1,33 @@ +-- +goose Up +-- +goose StatementBegin +CREATE OR REPLACE FUNCTION get_refill_value(rate_limit "RateLimit") +RETURNS INTEGER AS $$ +DECLARE + refill_amount INTEGER; +BEGIN + IF (NOW() - rate_limit."lastRefill") >= (rate_limit."window"::INTERVAL - INTERVAL '10 milliseconds') THEN + refill_amount := rate_limit."limitValue"; + ELSE + refill_amount := rate_limit."value"; + END IF; + RETURN refill_amount; +END; +$$ LANGUAGE plpgsql; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +CREATE OR REPLACE FUNCTION get_refill_value(rate_limit "RateLimit") +RETURNS INTEGER AS $$ +DECLARE + refill_amount INTEGER; +BEGIN + IF (NOW() - rate_limit."lastRefill") >= (rate_limit."window"::INTERVAL) THEN + refill_amount := rate_limit."limitValue"; + ELSE + refill_amount := rate_limit."value"; + END IF; + RETURN refill_amount; +END; +$$ LANGUAGE plpgsql; +-- +goose StatementEnd diff --git a/pkg/repository/v1/scheduler.go b/pkg/repository/v1/scheduler.go index 72fd478d2..9286e95c3 100644 --- a/pkg/repository/v1/scheduler.go +++ b/pkg/repository/v1/scheduler.go @@ -2,6 +2,7 @@ package v1 import ( "context" + "time" "github.com/hatchet-dev/hatchet/pkg/repository/v1/sqlcv1" "github.com/jackc/pgx/v5/pgtype" @@ -39,7 +40,7 @@ type QueueRepository interface { type RateLimitRepository interface { ListCandidateRateLimits(ctx context.Context, tenantId pgtype.UUID) ([]string, error) - UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error) + UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error) } type AssignmentRepository interface { diff --git a/pkg/repository/v1/scheduler_rate_limit.go b/pkg/repository/v1/scheduler_rate_limit.go index ad54ae65a..d8f4cff30 100644 --- a/pkg/repository/v1/scheduler_rate_limit.go +++ b/pkg/repository/v1/scheduler_rate_limit.go @@ -2,6 +2,7 @@ package v1 import ( "context" + "time" "github.com/hatchet-dev/hatchet/pkg/repository/postgres/sqlchelpers" "github.com/hatchet-dev/hatchet/pkg/repository/v1/sqlcv1" @@ -39,11 +40,11 @@ func (d *rateLimitRepository) ListCandidateRateLimits(ctx context.Context, tenan return ids, nil } -func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error) { +func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error) { tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, d.pool, d.l, 5000) if err != nil { - return nil, err + return nil, nil, err } defer rollback() @@ -62,17 +63,17 @@ func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgt _, err = d.queries.BulkUpdateRateLimits(ctx, tx, params) if err != nil { - return nil, err + return nil, nil, err } newRls, err := d.queries.ListRateLimitsForTenantWithMutate(ctx, tx, tenantId) if err != nil { - return nil, err + return nil, nil, err } if err := commit(ctx); err != nil { - return nil, err + return nil, nil, err } res := make(map[string]int, len(newRls)) @@ -81,5 +82,16 @@ func (d *rateLimitRepository) UpdateRateLimits(ctx context.Context, tenantId pgt res[rl.Key] = int(rl.Value) } - return res, err + nextRefillAt := time.Now().Add(time.Second * 2) + + if len(newRls) > 0 { + // get min of all next refill times + for _, rl := range newRls { + if rl.NextRefillAt.Time.Before(nextRefillAt) { + nextRefillAt = rl.NextRefillAt.Time + } + } + } + + return res, &nextRefillAt, err } diff --git a/pkg/repository/v1/sqlcv1/rate_limits.sql b/pkg/repository/v1/sqlcv1/rate_limits.sql index a81f1ee6d..6cf14da19 100644 --- a/pkg/repository/v1/sqlcv1/rate_limits.sql +++ b/pkg/repository/v1/sqlcv1/rate_limits.sql @@ -74,7 +74,7 @@ WITH rls_to_update AS ( "RateLimit" rl WHERE rl."tenantId" = @tenantId::uuid - AND NOW() - rl."lastRefill" >= rl."window"::INTERVAL + AND NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds') ORDER BY rl."tenantId" ASC, rl."key" ASC FOR UPDATE @@ -93,7 +93,7 @@ WITH rls_to_update AS ( ) SELECT rl.*, - (rl."lastRefill" + rl."window"::INTERVAL)::timestamp AS "nextRefillAt" + (rl."lastRefill" + rl."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt" FROM "RateLimit" rl WHERE @@ -105,7 +105,7 @@ UNION ALL SELECT refill.*, -- return the next refill time - (refill."lastRefill" + refill."window"::INTERVAL)::timestamp AS "nextRefillAt" + (refill."lastRefill" + refill."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt" FROM refill; @@ -145,7 +145,7 @@ UPDATE SET "value" = get_refill_value(rl) - (SELECT "units" FROM input WHERE "key" = rl."key"), "lastRefill" = CASE - WHEN NOW() - rl."lastRefill" >= rl."window"::INTERVAL THEN + WHEN NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds') THEN CURRENT_TIMESTAMP ELSE rl."lastRefill" diff --git a/pkg/repository/v1/sqlcv1/rate_limits.sql.go b/pkg/repository/v1/sqlcv1/rate_limits.sql.go index 25c329315..6d01416b7 100644 --- a/pkg/repository/v1/sqlcv1/rate_limits.sql.go +++ b/pkg/repository/v1/sqlcv1/rate_limits.sql.go @@ -38,7 +38,7 @@ UPDATE SET "value" = get_refill_value(rl) - (SELECT "units" FROM input WHERE "key" = rl."key"), "lastRefill" = CASE - WHEN NOW() - rl."lastRefill" >= rl."window"::INTERVAL THEN + WHEN NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds') THEN CURRENT_TIMESTAMP ELSE rl."lastRefill" @@ -216,7 +216,7 @@ WITH rls_to_update AS ( "RateLimit" rl WHERE rl."tenantId" = $1::uuid - AND NOW() - rl."lastRefill" >= rl."window"::INTERVAL + AND NOW() - rl."lastRefill" >= (rl."window"::INTERVAL - INTERVAL '10 milliseconds') ORDER BY rl."tenantId" ASC, rl."key" ASC FOR UPDATE @@ -235,7 +235,7 @@ WITH rls_to_update AS ( ) SELECT rl."tenantId", rl.key, rl."limitValue", rl.value, rl."window", rl."lastRefill", - (rl."lastRefill" + rl."window"::INTERVAL)::timestamp AS "nextRefillAt" + (rl."lastRefill" + rl."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt" FROM "RateLimit" rl WHERE @@ -247,7 +247,7 @@ UNION ALL SELECT refill."tenantId", refill.key, refill."limitValue", refill.value, refill."window", refill."lastRefill", -- return the next refill time - (refill."lastRefill" + refill."window"::INTERVAL)::timestamp AS "nextRefillAt" + (refill."lastRefill" + refill."window"::INTERVAL - INTERVAL '10 milliseconds')::timestamp AS "nextRefillAt" FROM refill ` diff --git a/pkg/scheduling/v1/rate_limit.go b/pkg/scheduling/v1/rate_limit.go index f8ac07ae9..1cf488472 100644 --- a/pkg/scheduling/v1/rate_limit.go +++ b/pkg/scheduling/v1/rate_limit.go @@ -23,6 +23,9 @@ type rateLimiter struct { tenantId pgtype.UUID + nextRefillAt *time.Time + nextRefillAtMu sync.RWMutex + l *zerolog.Logger // unacked is a map of taskId to rateLimitSet @@ -106,6 +109,13 @@ func (r *rateLimiter) use(ctx context.Context, taskId int64, rls map[string]int3 if !r.rateLimitsExist(rls) { return res } + } else if r.shouldRefill() { + err := r.flushToDatabase(ctx) + + if err != nil { + r.l.Error().Err(err).Msg("error flushing rate limits to database") + return res + } } currRls := r.copyDbRateLimits() @@ -154,6 +164,17 @@ func (r *rateLimiter) rateLimitsExist(rls map[string]int32) bool { return true } +func (r *rateLimiter) shouldRefill() bool { + r.nextRefillAtMu.Lock() + defer r.nextRefillAtMu.Unlock() + + if r.nextRefillAt == nil { + return false + } + + return r.nextRefillAt.After(time.Now().UTC()) +} + func (r *rateLimiter) copyDbRateLimits() rateLimitSet { r.dbRateLimitsMu.RLock() defer r.dbRateLimitsMu.RUnlock() @@ -263,12 +284,17 @@ func (r *rateLimiter) flushToDatabase(ctx context.Context) error { updates[k] = v.val } - newRateLimits, err := r.rateLimitRepo.UpdateRateLimits(ctx, r.tenantId, updates) + newRateLimits, nextRefillAt, err := r.rateLimitRepo.UpdateRateLimits(ctx, r.tenantId, updates) if err != nil { return err } + // update the next refill time + r.nextRefillAtMu.Lock() + r.nextRefillAt = nextRefillAt + r.nextRefillAtMu.Unlock() + r.dbRateLimits = make(rateLimitSet) // update the db rate limits diff --git a/pkg/scheduling/v1/rate_limit_test.go b/pkg/scheduling/v1/rate_limit_test.go index 6fe447020..4720ab6d2 100644 --- a/pkg/scheduling/v1/rate_limit_test.go +++ b/pkg/scheduling/v1/rate_limit_test.go @@ -6,6 +6,7 @@ import ( "math/rand" "sync" "testing" + "time" "github.com/jackc/pgx/v5/pgtype" "github.com/rs/zerolog" @@ -22,16 +23,17 @@ func (m *mockRateLimitRepo) ListCandidateRateLimits(ctx context.Context, tenantI return args.Get(0).([]string), args.Error(1) } -func (m *mockRateLimitRepo) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, error) { +func (m *mockRateLimitRepo) UpdateRateLimits(ctx context.Context, tenantId pgtype.UUID, updates map[string]int) (map[string]int, *time.Time, error) { args := m.Called(ctx, tenantId, updates) - return args.Get(0).(map[string]int), args.Error(1) + arg1 := args.Get(1).(time.Time) + return args.Get(0).(map[string]int), &arg1, args.Error(2) } func TestRateLimiter_Use(t *testing.T) { l := zerolog.Nop() mockRateLimitRepo := &mockRateLimitRepo{} - mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5, "key3": 7}, nil) + mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5, "key3": 7}, time.Now().Add(2*time.Second), nil) rateLimiter := &rateLimiter{ dbRateLimits: rateLimitSet{ @@ -62,7 +64,7 @@ func TestRateLimiter_Ack(t *testing.T) { l := zerolog.Nop() mockRateLimitRepo := &mockRateLimitRepo{} - mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil) + mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil) rateLimiter := &rateLimiter{ dbRateLimits: rateLimitSet{ @@ -87,7 +89,7 @@ func TestRateLimiter_Nack(t *testing.T) { l := zerolog.Nop() mockRateLimitRepo := &mockRateLimitRepo{} - mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil) + mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil) rateLimiter := &rateLimiter{ dbRateLimits: rateLimitSet{ @@ -112,7 +114,7 @@ func TestRateLimiter_Concurrency(t *testing.T) { l := zerolog.Nop() mockRateLimitRepo := &mockRateLimitRepo{} - mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 100, "key2": 100}, nil) + mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 100, "key2": 100}, time.Now().Add(2*time.Second), nil) rateLimiter := &rateLimiter{ dbRateLimits: rateLimitSet{ @@ -151,7 +153,7 @@ func TestRateLimiter_FlushToDatabase(t *testing.T) { l := zerolog.Nop() mockRateLimitRepo := &mockRateLimitRepo{} // Mock implementation of rateLimitRepo - mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, nil) + mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 10, "key2": 5}, time.Now().Add(2*time.Second), nil) rateLimiter := &rateLimiter{ dbRateLimits: rateLimitSet{ @@ -184,7 +186,7 @@ func BenchmarkRateLimiter(b *testing.B) { l := zerolog.Nop() mockRateLimitRepo := &mockRateLimitRepo{} - mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 1000, "key2": 1000}, nil) + mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(map[string]int{"key1": 1000, "key2": 1000}, time.Now().Add(2*time.Second), nil) r := rateLimiter{ unacked: make(map[int64]rateLimitSet),