Files
hatchet/internal/msgqueue/shared_tenant_reader.go
matt 058968c06b Refactor: Attempt II at removing pgtype.UUID everywhere + convert string UUIDs into uuid.UUID (#2894)
* fix: add type override in sqlc.yaml

* chore: gen sqlc

* chore: big find and replace

* chore: more

* fix: clean up bunch of outdated `.Valid` refs

* refactor: remove `sqlchelpers.uuidFromStr()` in favor of `uuid.MustParse()`

* refactor: remove uuidToStr

* fix: lint

* fix: use pointers for null uuids

* chore: clean up more null pointers

* chore: clean up a bunch more

* fix: couple more

* fix: some types on the api

* fix: incorrectly non-null param

* fix: more nullable params

* fix: more refs

* refactor: start replacing tenant id strings with uuids

* refactor: more tenant id uuid casting

* refactor: fix a bunch more

* refactor: more

* refactor: more

* refactor: is that all of them?!

* fix: panic

* fix: rm scans

* fix: unwind some broken things

* chore: tests

* fix: rebase issues

* fix: more tests

* fix: nil checks

* Refactor: Make all UUIDs into `uuid.UUID` (#2897)

* refactor: remove a bunch more string uuids

* refactor: pointers and lists

* refactor: fix all the refs

* refactor: fix a few more

* fix: config loader

* fix: revert some changes

* fix: tests

* fix: test

* chore: proto

* fix: durable listener

* fix: some more string types

* fix: python health worker sleep

* fix: remove a bunch of `MustParse`s from the various gRPC servers

* fix: rm more uuid.MustParse calls

* fix: rm mustparse from api

* fix: test

* fix: merge issues

* fix: handle a bunch more uses of `MustParse` everywhere

* fix: nil id for worker label

* fix: more casting in the oss

* fix: more id parsing

* fix: stringify jwt opt

* fix: couple more bugs in untyped calls

* fix: more types

* fix: broken test

* refactor: implement `GetKeyUuid`

* chore: regen sqlc

* chore: replace pgtype.UUID again

* fix: bunch more type errors

* fix: panic
2026-02-03 11:02:59 -05:00

196 lines
3.4 KiB
Go

package msgqueue
import (
"context"
"sync"
"time"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
)
// sharedTenantSub tracks the in-process subscribers for one tenant that share
// a single underlying queue subscription.
type sharedTenantSub struct {
	// fs holds the registered subscriber callbacks keyed by the int id taken
	// from counter. The fan-out handler ranges over it without holding mu.
	fs *sync.Map
	// counter is incremented on every Subscribe call and its new value is
	// used as the subscriber id; this file only ever increments it.
	counter int
	// isRunning reports whether the shared underlying subscription has been
	// started and not yet torn down.
	isRunning bool
	// mu serializes Subscribe and unsubscribe for this tenant, guarding
	// counter, isRunning and cleanup.
	mu sync.Mutex
	// cleanup stops the shared underlying subscription; it is assigned when
	// the subscription is started.
	cleanup func() error
}
// SharedTenantReader lets many in-process subscribers for the same tenant
// share one subscription to that tenant's event consumer queue on the
// underlying MessageQueue, fanning each message out to every subscriber.
type SharedTenantReader struct {
	// tenants maps a tenant id (uuid.UUID) to its *sharedTenantSub. Entries
	// are created lazily by Subscribe and are not removed anywhere in this
	// file.
	tenants *sync.Map
	// mq is used to register tenants and subscribe to their queues.
	mq MessageQueue
}
// NewSharedTenantReader builds a SharedTenantReader backed by mq, starting
// with no tenant subscriptions.
func NewSharedTenantReader(mq MessageQueue) *SharedTenantReader {
	reader := new(SharedTenantReader)
	reader.tenants = new(sync.Map)
	reader.mq = mq
	return reader
}
// Subscribe registers postAck to be invoked for every message delivered on
// the tenant's event consumer queue, and returns a function that removes it.
//
// All subscribers for one tenant share a single underlying subscription. The
// first subscriber (or the first after a teardown) registers the tenant and
// starts that subscription; it is cleaned up when the last subscriber
// unsubscribes. Errors returned by individual hooks are combined with
// multierror and returned from the shared queue handler.
//
// If registering the tenant or starting the subscription fails, postAck is
// removed again and the tenant is left not-running, so a later Subscribe call
// retries the setup instead of silently receiving nothing.
//
// The returned unsubscribe function may be called more than once; once the
// shared subscription has been torn down, further calls do not invoke the
// underlying cleanup again.
func (s *SharedTenantReader) Subscribe(tenantId uuid.UUID, postAck AckHook) (func() error, error) {
	tenant, _ := s.tenants.LoadOrStore(tenantId, &sharedTenantSub{
		fs: &sync.Map{},
	})

	t := tenant.(*sharedTenantSub)

	t.mu.Lock()
	defer t.mu.Unlock()

	t.counter++
	subId := t.counter

	// Store the hook before starting the subscription so that no message
	// delivered right after start-up is missed by this subscriber.
	t.fs.Store(subId, postAck)

	if !t.isRunning {
		q := TenantEventConsumerQueue(tenantId)

		if err := s.mq.RegisterTenant(context.Background(), tenantId); err != nil {
			t.fs.Delete(subId)
			return nil, err
		}

		cleanupSingleSub, err := s.mq.Subscribe(q, NoOpHook, func(task *Message) error {
			var innerErr error

			t.fs.Range(func(key, value interface{}) bool {
				hook := value.(AckHook)

				if hookErr := hook(task); hookErr != nil {
					innerErr = multierror.Append(innerErr, hookErr)
				}

				return true
			})

			return innerErr
		})
		if err != nil {
			t.fs.Delete(subId)
			return nil, err
		}

		// Only mark the tenant as running once the subscription actually
		// exists; otherwise a failed start would wedge this tenant forever.
		t.cleanup = cleanupSingleSub
		t.isRunning = true
	}

	return func() error {
		t.mu.Lock()
		defer t.mu.Unlock()

		t.fs.Delete(subId)

		// Nothing to tear down if other subscribers remain or the shared
		// subscription is already stopped (e.g. this func was called twice).
		if !t.isRunning || lenSyncMap(t.fs) > 0 {
			return nil
		}

		// shut down the subscription
		if t.cleanup != nil {
			if err := t.cleanup(); err != nil {
				// Leave isRunning set so a later call can retry the cleanup.
				return err
			}
		}

		t.cleanup = nil
		t.isRunning = false

		return nil
	}, nil
}
// SharedBufferedTenantReader is the buffered counterpart of
// SharedTenantReader: subscribers for the same tenant share one MQ
// subscription buffer (see NewMQSubBuffer) on the tenant's event consumer
// queue, and each batch of payloads is fanned out to every subscriber.
type SharedBufferedTenantReader struct {
	// tenants maps a tenant id (uuid.UUID) to its *sharedTenantSub. Entries
	// are created lazily by Subscribe and are not removed anywhere in this
	// file.
	tenants *sync.Map
	// mq is used to register tenants and back the subscription buffers.
	mq MessageQueue
}
// NewSharedBufferedTenantReader builds a SharedBufferedTenantReader backed by
// mq, starting with no tenant subscriptions.
func NewSharedBufferedTenantReader(mq MessageQueue) *SharedBufferedTenantReader {
	reader := new(SharedBufferedTenantReader)
	reader.tenants = new(sync.Map)
	reader.mq = mq
	return reader
}
// Subscribe registers f to be invoked for every batch of payloads read from
// the tenant's event consumer queue, and returns a function that removes it.
//
// All subscribers for one tenant share a single MQ subscription buffer,
// created with WithKind(PostAck), WithMaxConcurrency(1), a 20ms flush
// interval and immediate flush disabled. The first subscriber (or the first
// after a teardown) registers the tenant and starts the buffer; it is
// cleaned up when the last subscriber unsubscribes. Errors returned by
// individual subscribers are combined with multierror and returned from the
// shared buffer callback.
//
// If registering the tenant or starting the buffer fails, f is removed again
// and the tenant is left not-running, so a later Subscribe call retries the
// setup instead of silently receiving nothing.
//
// The returned unsubscribe function may be called more than once; once the
// shared buffer has been torn down, further calls do not invoke the
// underlying cleanup again.
func (s *SharedBufferedTenantReader) Subscribe(tenantId uuid.UUID, f DstFunc) (func() error, error) {
	tenant, _ := s.tenants.LoadOrStore(tenantId, &sharedTenantSub{
		fs: &sync.Map{},
	})

	t := tenant.(*sharedTenantSub)

	t.mu.Lock()
	defer t.mu.Unlock()

	t.counter++
	subId := t.counter

	// Store the callback before starting the buffer so that no batch flushed
	// right after start-up is missed by this subscriber.
	t.fs.Store(subId, f)

	if !t.isRunning {
		q := TenantEventConsumerQueue(tenantId)

		if err := s.mq.RegisterTenant(context.Background(), tenantId); err != nil {
			t.fs.Delete(subId)
			return nil, err
		}

		subBuffer := NewMQSubBuffer(q, s.mq, func(msgTenantId uuid.UUID, msgId string, payloads [][]byte) error {
			var innerErr error

			t.fs.Range(func(key, value interface{}) bool {
				dst := value.(DstFunc)

				if dstErr := dst(msgTenantId, msgId, payloads); dstErr != nil {
					innerErr = multierror.Append(innerErr, dstErr)
				}

				return true
			})

			return innerErr
		}, WithKind(PostAck), WithMaxConcurrency(1), WithFlushInterval(20*time.Millisecond), WithDisableImmediateFlush(true))

		cleanupSingleSub, err := subBuffer.Start()
		if err != nil {
			t.fs.Delete(subId)
			return nil, err
		}

		// Only mark the tenant as running once the buffer actually started;
		// otherwise a failed start would wedge this tenant forever.
		t.cleanup = cleanupSingleSub
		t.isRunning = true
	}

	return func() error {
		t.mu.Lock()
		defer t.mu.Unlock()

		t.fs.Delete(subId)

		// Nothing to tear down if other subscribers remain or the shared
		// buffer is already stopped (e.g. this func was called twice).
		if !t.isRunning || lenSyncMap(t.fs) > 0 {
			return nil
		}

		// shut down the subscription
		if t.cleanup != nil {
			if err := t.cleanup(); err != nil {
				// Leave isRunning set so a later call can retry the cleanup.
				return err
			}
		}

		t.cleanup = nil
		t.isRunning = false

		return nil
	}, nil
}
// lenSyncMap counts the entries in m by walking it with Range, since sync.Map
// has no length method. The walk is O(n), and under concurrent writes the
// result is only a point-in-time approximation.
func lenSyncMap(m *sync.Map) int {
	count := 0
	m.Range(func(_, _ interface{}) bool {
		count++
		return true
	})
	return count
}