Files
hatchet/pkg/scheduling/v1/rate_limit_test.go
matt 058968c06b Refactor: Attempt II at removing pgtype.UUID everywhere + convert string UUIDs into uuid.UUID (#2894)
* fix: add type override in sqlc.yaml

* chore: gen sqlc

* chore: big find and replace

* chore: more

* fix: clean up bunch of outdated `.Valid` refs

* refactor: remove `sqlchelpers.uuidFromStr()` in favor of `uuid.MustParse()`

* refactor: remove uuidToStr

* fix: lint

* fix: use pointers for null uuids

* chore: clean up more null pointers

* chore: clean up a bunch more

* fix: couple more

* fix: some types on the api

* fix: incorrectly non-null param

* fix: more nullable params

* fix: more refs

* refactor: start replacing tenant id strings with uuids

* refactor: more tenant id uuid casting

* refactor: fix a bunch more

* refactor: more

* refactor: more

* refactor: is that all of them?!

* fix: panic

* fix: rm scans

* fix: unwind some broken things

* chore: tests

* fix: rebase issues

* fix: more tests

* fix: nil checks

* Refactor: Make all UUIDs into `uuid.UUID` (#2897)

* refactor: remove a bunch more string uuids

* refactor: pointers and lists

* refactor: fix all the refs

* refactor: fix a few more

* fix: config loader

* fix: revert some changes

* fix: tests

* fix: test

* chore: proto

* fix: durable listener

* fix: some more string types

* fix: python health worker sleep

* fix: remove a bunch of `MustParse`s from the various gRPC servers

* fix: rm more uuid.MustParse calls

* fix: rm mustparse from api

* fix: test

* fix: merge issues

* fix: handle a bunch more uses of `MustParse` everywhere

* fix: nil id for worker label

* fix: more casting in the oss

* fix: more id parsing

* fix: stringify jwt opt

* fix: couple more bugs in untyped calls

* fix: more types

* fix: broken test

* refactor: implement `GetKeyUuid`

* chore: regen sqlc

* chore: replace pgtype.UUID again

* fix: bunch more type errors

* fix: panic
2026-02-03 11:02:59 -05:00

263 lines
7.8 KiB
Go

//go:build !e2e && !load && !rampup && !integration
package v1
import (
	"context"
	"fmt"
	"math/rand"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/google/uuid"
	v1 "github.com/hatchet-dev/hatchet/pkg/repository"
	"github.com/hatchet-dev/hatchet/pkg/repository/sqlcv1"
	"github.com/rs/zerolog"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)
// mockRateLimitRepo is a testify-backed test double for the rate limit
// repository used by rateLimiter. Only UpdateRateLimits is stubbed; the other
// methods panic so that an unexpected call fails the test loudly.
type mockRateLimitRepo struct {
	mock.Mock
}

// UpdateRateLimits records the call and returns the values configured via
// On("UpdateRateLimits", ...).Return(rows, nextRefill, err).
//
// An untyped nil passed to Return for rows or nextRefill (e.g. when stubbing
// an error path with Return(nil, nil, someErr)) is returned as a typed nil
// instead of panicking. A value of the wrong type still panics, so a
// misconfigured stub is not silently ignored.
func (m *mockRateLimitRepo) UpdateRateLimits(ctx context.Context, tenantId uuid.UUID, updates map[string]int) ([]*sqlcv1.ListRateLimitsForTenantWithMutateRow, *time.Time, error) {
	args := m.Called(ctx, tenantId, updates)

	var rows []*sqlcv1.ListRateLimitsForTenantWithMutateRow
	if v := args.Get(0); v != nil {
		rows = v.([]*sqlcv1.ListRateLimitsForTenantWithMutateRow)
	}

	var nextRefill *time.Time
	if v := args.Get(1); v != nil {
		nextRefill = v.(*time.Time)
	}

	return rows, nextRefill, args.Error(2)
}

// UpsertRateLimit is not exercised by these tests.
func (m *mockRateLimitRepo) UpsertRateLimit(ctx context.Context, tenantId uuid.UUID, key string, opts *v1.UpsertRateLimitOpts) (*sqlcv1.RateLimit, error) {
	panic("not implemented")
}

// ListRateLimits is not exercised by these tests.
func (m *mockRateLimitRepo) ListRateLimits(ctx context.Context, tenantId uuid.UUID, opts *v1.ListRateLimitOpts) (*v1.ListRateLimitsResult, error) {
	panic("not implemented")
}
// TestRateLimiter_Use drives a sequence of use calls against one limiter and
// checks whether each is admitted. Steps run in order, so every step sees the
// capacity consumed by the steps before it.
func TestRateLimiter_Use(t *testing.T) {
	logger := zerolog.Nop()

	refillAt := time.Now().Add(2 * time.Second)
	repo := &mockRateLimitRepo{}
	repo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(
		[]*sqlcv1.ListRateLimitsForTenantWithMutateRow{
			{Key: "key1", Value: 10},
			{Key: "key2", Value: 5},
			{Key: "key3", Value: 7},
		},
		&refillAt,
		nil,
	)

	rl := &rateLimiter{
		dbRateLimits: rateLimitSet{
			"key1": {key: "key1", val: 10},
			"key2": {key: "key2", val: 5},
			"key3": {key: "key3", val: 7},
		},
		unacked:       make(map[int64]rateLimitSet),
		unflushed:     make(rateLimitSet),
		l:             &logger,
		rateLimitRepo: repo,
	}

	steps := []struct {
		taskID   int64
		requests map[string]int32
		admitted bool
	}{
		// single key, fits within key1's limit of 10
		{taskID: 1, requests: map[string]int32{"key1": 5}, admitted: true},
		// expected to be rejected: 5 of key1 is already taken
		{taskID: 2, requests: map[string]int32{"key1": 6}, admitted: false},
		// multiple keys at once, each within its limit
		{taskID: 3, requests: map[string]int32{"key2": 3, "key3": 4}, admitted: true},
		// expected to be rejected: repeating the request would overrun the limits
		{taskID: 4, requests: map[string]int32{"key2": 3, "key3": 4}, admitted: false},
	}

	for _, step := range steps {
		got := rl.use(context.Background(), step.taskID, step.requests)
		if step.admitted {
			assert.True(t, got.succeeded)
		} else {
			assert.False(t, got.succeeded)
		}
	}
}
// TestRateLimiter_Ack checks that acking a task removes it from unacked and
// records its consumed units in unflushed.
func TestRateLimiter_Ack(t *testing.T) {
	logger := zerolog.Nop()

	refillAt := time.Now().Add(2 * time.Second)
	repo := &mockRateLimitRepo{}
	repo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(
		[]*sqlcv1.ListRateLimitsForTenantWithMutateRow{
			{Key: "key1", Value: 10},
			{Key: "key2", Value: 5},
		},
		&refillAt,
		nil,
	)

	rl := &rateLimiter{
		dbRateLimits: rateLimitSet{
			"key1": {key: "key1", val: 10},
			"key2": {key: "key2", val: 5},
		},
		unacked:       make(map[int64]rateLimitSet),
		unflushed:     make(rateLimitSet),
		l:             &logger,
		rateLimitRepo: repo,
	}

	const taskID int64 = 1

	rl.use(context.Background(), taskID, map[string]int32{"key1": 5})
	rl.ack(taskID)

	// The acked task must no longer be pending, and its 5 units of key1 must
	// now be queued for the next flush.
	assert.Empty(t, rl.unacked)
	assert.Equal(t, 5, rl.unflushed["key1"].val)
}
// TestRateLimiter_Nack checks that nacking a task drops it from unacked
// without carrying its consumed units over into unflushed.
func TestRateLimiter_Nack(t *testing.T) {
	l := zerolog.Nop()
	mockRateLimitRepo := &mockRateLimitRepo{}
	mockRows := []*sqlcv1.ListRateLimitsForTenantWithMutateRow{
		{Key: "key1", Value: 10},
		{Key: "key2", Value: 5},
	}
	nextRefill := time.Now().Add(2 * time.Second)
	mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(mockRows, &nextRefill, nil)

	rateLimiter := &rateLimiter{
		dbRateLimits: rateLimitSet{
			"key1": {key: "key1", val: 10},
			"key2": {key: "key2", val: 5},
		},
		unacked:       make(map[int64]rateLimitSet),
		unflushed:     make(rateLimitSet),
		l:             &l,
		rateLimitRepo: mockRateLimitRepo,
	}

	rateLimiter.use(context.Background(), 1, map[string]int32{"key1": 5})
	rateLimiter.nack(1)

	// Verify unacked is empty and unflushed doesn't contain task 1's usage.
	assert.Empty(t, rateLimiter.unacked)
	// unflushed is keyed by rate limit key (a string). Checking for the int 1
	// (the task ID) could never match a string key and so always passed;
	// check the key task 1 actually consumed instead.
	assert.NotContains(t, rateLimiter.unflushed, "key1")
}
// TestRateLimiter_Concurrency runs many goroutines against one limiter. Each
// goroutine consumes useAmount of both key1 and key2 and immediately acks.
// The test then checks that every unit was accounted for exactly once.
func TestRateLimiter_Concurrency(t *testing.T) {
	l := zerolog.Nop()
	mockRateLimitRepo := &mockRateLimitRepo{}
	mockRows := []*sqlcv1.ListRateLimitsForTenantWithMutateRow{
		{Key: "key1", Value: 100},
		{Key: "key2", Value: 100},
	}
	nextRefill := time.Now().Add(2 * time.Second)
	mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(mockRows, &nextRefill, nil)

	rateLimiter := &rateLimiter{
		dbRateLimits: rateLimitSet{
			"key1": {key: "key1", val: 100},
			"key2": {key: "key2", val: 100},
		},
		unacked:       make(map[int64]rateLimitSet),
		unflushed:     make(rateLimitSet),
		l:             &l,
		rateLimitRepo: mockRateLimitRepo,
	}

	var wg sync.WaitGroup
	numUsers := 100
	useAmount := 1

	wg.Add(numUsers)
	for i := 0; i < numUsers; i++ {
		// The task ID is passed as an argument so each goroutine gets its own
		// copy regardless of the module's Go version.
		go func(taskId int64) {
			defer wg.Done()
			res := rateLimiter.use(context.Background(), taskId, map[string]int32{"key1": int32(useAmount), "key2": int32(useAmount)}) // nolint: gosec
			// assert (not require) is safe to call from non-test goroutines.
			assert.True(t, res.succeeded)
			rateLimiter.ack(taskId)
		}(int64(i))
	}
	wg.Wait()

	// Every task was acked, so nothing should remain pending, and both keys
	// should have accumulated numUsers * useAmount in unflushed.
	assert.Empty(t, rateLimiter.unacked)
	assert.Equal(t, numUsers*useAmount, rateLimiter.unflushed["key1"].val)
	assert.Equal(t, numUsers*useAmount, rateLimiter.unflushed["key2"].val)
}
// TestRateLimiter_FlushToDatabase checks two things. Flushing pending usage
// must replace dbRateLimits with the rows the repository returns. It must
// also leave unflushed empty.
func TestRateLimiter_FlushToDatabase(t *testing.T) {
	logger := zerolog.Nop()

	// Stub repository: every UpdateRateLimits call reports key1=10, key2=5.
	refillAt := time.Now().Add(2 * time.Second)
	repo := &mockRateLimitRepo{}
	repo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(
		[]*sqlcv1.ListRateLimitsForTenantWithMutateRow{
			{Key: "key1", Value: 10},
			{Key: "key2", Value: 5},
		},
		&refillAt,
		nil,
	)

	rl := &rateLimiter{
		dbRateLimits: rateLimitSet{
			"key1": {key: "key1", val: 10},
			"key2": {key: "key2", val: 5},
		},
		unacked:       make(map[int64]rateLimitSet),
		unflushed:     make(rateLimitSet),
		l:             &logger,
		rateLimitRepo: repo,
	}

	// Seed pending usage directly so the flush has something to send.
	for key, used := range map[string]int{"key1": 5, "key2": 3} {
		rl.unflushed[key] = &rateLimit{key: key, val: used}
	}

	assert.NoError(t, rl.flushToDatabase(context.Background()))

	// dbRateLimits should now mirror the stubbed repository rows...
	assert.Equal(t, 10, rl.dbRateLimits["key1"].val)
	assert.Equal(t, 5, rl.dbRateLimits["key2"].val)
	// ...and nothing should be left waiting to be flushed.
	assert.Empty(t, rl.unflushed)
}
// BenchmarkRateLimiter measures rateLimiter.use under parallel load. The
// limiter is seeded with 1000 rate limits of random capacity.
func BenchmarkRateLimiter(b *testing.B) {
	l := zerolog.Nop()
	mockRateLimitRepo := &mockRateLimitRepo{}
	mockRows := []*sqlcv1.ListRateLimitsForTenantWithMutateRow{
		{Key: "key1", Value: 1000},
		{Key: "key2", Value: 1000},
	}
	nextRefill := time.Now().Add(2 * time.Second)
	mockRateLimitRepo.On("UpdateRateLimits", context.Background(), mock.Anything, mock.Anything).Return(mockRows, &nextRefill, nil)

	r := rateLimiter{
		unacked:       make(map[int64]rateLimitSet),
		unflushed:     make(rateLimitSet),
		dbRateLimits:  make(rateLimitSet),
		l:             &l,
		rateLimitRepo: mockRateLimitRepo,
	}

	// Initialize dbRateLimits with some random rate limits.
	for i := 0; i < 1000; i++ {
		key := fmt.Sprintf("rate_limit_%d", i)
		value := rand.Intn(1000) // nolint: gosec
		r.dbRateLimits[key] = &rateLimit{
			key: key,
			val: value,
		}
	}

	// All RunParallel goroutines share one task-ID counter. A per-goroutine
	// counter would give every worker the same IDs (0, 1, 2, ...), so
	// concurrent use calls would collide on the same unacked entry instead of
	// modelling distinct tasks.
	var nextTaskID atomic.Int64

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			taskId := nextTaskID.Add(1) - 1
			requests := map[string]int32{
				"rate_limit_1": rand.Int31n(5), // nolint: gosec
				"rate_limit_2": rand.Int31n(5), // nolint: gosec
				"rate_limit_3": rand.Int31n(5), // nolint: gosec
			}
			r.use(context.Background(), taskId, requests)
		}
	})
}