Files
hatchet/internal/msgqueue/mq_pub_buffer.go
matt 058968c06b Refactor: Attempt II at removing pgtype.UUID everywhere + convert string UUIDs into uuid.UUID (#2894)
* fix: add type override in sqlc.yaml

* chore: gen sqlc

* chore: big find and replace

* chore: more

* fix: clean up bunch of outdated `.Valid` refs

* refactor: remove `sqlchelpers.uuidFromStr()` in favor of `uuid.MustParse()`

* refactor: remove uuidToStr

* fix: lint

* fix: use pointers for null uuids

* chore: clean up more null pointers

* chore: clean up a bunch more

* fix: couple more

* fix: some types on the api

* fix: incorrectly non-null param

* fix: more nullable params

* fix: more refs

* refactor: start replacing tenant id strings with uuids

* refactor: more tenant id uuid casting

* refactor: fix a bunch more

* refactor: more

* refactor: more

* refactor: is that all of them?!

* fix: panic

* fix: rm scans

* fix: unwind some broken things

* chore: tests

* fix: rebase issues

* fix: more tests

* fix: nil checks

* Refactor: Make all UUIDs into `uuid.UUID` (#2897)

* refactor: remove a bunch more string uuids

* refactor: pointers and lists

* refactor: fix all the refs

* refactor: fix a few more

* fix: config loader

* fix: revert some changes

* fix: tests

* fix: test

* chore: proto

* fix: durable listener

* fix: some more string types

* fix: python health worker sleep

* fix: remove a bunch of `MustParse`s from the various gRPC servers

* fix: rm more uuid.MustParse calls

* fix: rm mustparse from api

* fix: test

* fix: merge issues

* fix: handle a bunch more uses of `MustParse` everywhere

* fix: nil id for worker label

* fix: more casting in the oss

* fix: more id parsing

* fix: stringify jwt opt

* fix: couple more bugs in untyped calls

* fix: more types

* fix: broken test

* refactor: implement `GetKeyUuid`

* chore: regen sqlc

* chore: replace pgtype.UUID again

* fix: bunch more type errors

* fix: panic
2026-02-03 11:02:59 -05:00

273 lines
5.3 KiB
Go

package msgqueue
import (
"context"
"encoding/json"
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
// Buffering defaults. These are package-level variables (not constants) so that
// init below can override them from the environment.
// nolint: staticcheck
var (
	// PUB_FLUSH_INTERVAL is how often each per-key buffer is flushed; it is also
	// used as the minimum spacing between consecutive flushes of one buffer.
	PUB_FLUSH_INTERVAL = 10 * time.Millisecond

	// PUB_BUFFER_SIZE is the capacity of each per-key message channel and the
	// maximum number of messages drained in a single flush.
	PUB_BUFFER_SIZE = 1000

	// PUB_MAX_CONCURRENCY bounds concurrent flushes per buffer.
	PUB_MAX_CONCURRENCY = 1

	// PUB_TIMEOUT bounds a single publish to the underlying message queue.
	PUB_TIMEOUT = 10 * time.Second
)

// init overrides the buffering defaults from the environment. Values that fail
// to parse are silently ignored and the compiled-in defaults are kept.
func init() {
	// Read each env var once via an if-initializer instead of calling os.Getenv
	// twice per branch as before.
	if v := os.Getenv("SERVER_DEFAULT_BUFFER_FLUSH_INTERVAL"); v != "" {
		if d, err := time.ParseDuration(v); err == nil {
			PUB_FLUSH_INTERVAL = d
		}
	}

	if v := os.Getenv("SERVER_DEFAULT_BUFFER_SIZE"); v != "" {
		if size, err := strconv.Atoi(v); err == nil {
			PUB_BUFFER_SIZE = size
		}
	}

	if v := os.Getenv("SERVER_DEFAULT_BUFFER_CONCURRENCY"); v != "" {
		if c, err := strconv.Atoi(v); err == nil {
			PUB_MAX_CONCURRENCY = c
		}
	}
}
type PubFunc func(m *Message) error
// MQPubBuffer buffers messages coming out of the task queue, groups them by tenantId and msgId, and then flushes them
// to the task handler as necessary.
type MQPubBuffer struct {
	// mq is the underlying message queue that batched messages are sent to.
	mq MessageQueue

	// buffers is keyed on (tenantId, msgId) and contains a buffer of messages for that tenantId and msgId.
	buffers sync.Map

	// ctx governs the lifetime of every per-key flusher goroutine; cancel (called
	// by Stop) shuts them all down.
	ctx    context.Context
	cancel context.CancelFunc
}
// NewMQPubBuffer constructs a publish buffer on top of mq. Call Stop to shut
// down the background flusher goroutines it spawns on demand.
func NewMQPubBuffer(mq MessageQueue) *MQPubBuffer {
	bufCtx, cancelBuf := context.WithCancel(context.Background())

	buffer := &MQPubBuffer{
		mq:     mq,
		ctx:    bufCtx,
		cancel: cancelBuf,
	}

	return buffer
}
// Stop cancels the buffer's context, shutting down all per-key flusher
// goroutines (each performs a final flush on cancellation).
func (m *MQPubBuffer) Stop() {
	m.cancel()
}
// msgWithErrCh pairs a message with an optional channel used to report the
// batch publish result back to a caller that passed wait=true; errCh is nil
// for fire-and-forget publishes.
type msgWithErrCh struct {
	msg   *Message
	errCh chan error
}
// Pub enqueues msg into the per-(queue, tenant, message id) buffer, creating
// the buffer (and its flusher goroutines) on first use. When wait is true, Pub
// blocks until the batch containing msg has been published and returns the
// publish error; otherwise it returns nil immediately after enqueueing.
// Messages with a nil tenant ID are silently dropped.
func (m *MQPubBuffer) Pub(ctx context.Context, queue Queue, msg *Message, wait bool) error {
	if msg.TenantID == uuid.Nil {
		return nil
	}

	k := getPubKey(queue, msg.TenantID, msg.ID)

	buf, ok := m.buffers.Load(k)

	if !ok {
		// LoadOrStore ensures concurrent callers agree on a single buffer.
		// NOTE(review): when this call loses the race, the freshly constructed
		// buffer's goroutines keep running until m.ctx is cancelled — a small
		// leak worth revisiting.
		buf, _ = m.buffers.LoadOrStore(k, newMsgIDPubBuffer(m.ctx, msg.TenantID, msg.ID, func(msg *Message) error {
			msgCtx, cancel := context.WithTimeout(context.Background(), PUB_TIMEOUT)
			defer cancel()

			return m.mq.SendMessage(msgCtx, queue, msg)
		}))
	}

	msgWithErr := &msgWithErrCh{
		msg: msg,
	}

	if wait {
		msgWithErr.errCh = make(chan error)
	}

	msgBuf := buf.(*msgIdPubBuffer)

	// this places some backpressure on the consumer if buffers are full
	msgBuf.msgIdPubBufferCh <- msgWithErr

	// Best-effort wakeup of the flusher. A blocking send here deadlocked callers
	// after Stop() (the flusher goroutine that received on notifier has exited);
	// if the flusher isn't ready, the periodic ticker flush picks the message up
	// within PUB_FLUSH_INTERVAL instead.
	select {
	case msgBuf.notifier <- struct{}{}:
	default:
	}

	if wait {
		return <-msgWithErr.errCh
	}

	return nil
}
// getPubKey builds the buffers-map key for a (queue, tenant, message id) triple.
func getPubKey(q Queue, tenantId uuid.UUID, msgId string) string {
	key := q.Name()
	key += tenantId.String()
	key += msgId

	return key
}
// msgIdPubBuffer accumulates messages for one (tenant id, message id) pair and
// periodically publishes them as a single batched message.
type msgIdPubBuffer struct {
	tenantId uuid.UUID
	msgId    string

	// msgIdPubBufferCh holds pending messages; its PUB_BUFFER_SIZE capacity is
	// what applies backpressure to publishers when flushing falls behind.
	msgIdPubBufferCh chan *msgWithErrCh

	// notifier wakes the flusher as soon as a message is enqueued.
	notifier chan struct{}

	// pub sends one batched message to the underlying queue.
	pub PubFunc

	// semaphore bounds in-flight flushes to PUB_MAX_CONCURRENCY; semaphoreRelease
	// carries the delay to wait before a slot is handed back.
	semaphore        chan struct{}
	semaphoreRelease chan time.Duration

	// serialize is the payload encoder (json.Marshal). NOTE(review): not
	// referenced by any code visible in this file — presumably used elsewhere;
	// verify before removing.
	serialize func(t any) ([]byte, error)
}
// newMsgIDPubBuffer creates the buffer for a single (tenant, message id) pair
// and starts its background flusher and semaphore-releaser goroutines, both of
// which run until ctx is cancelled.
func newMsgIDPubBuffer(ctx context.Context, tenantID uuid.UUID, msgID string, pub PubFunc) *msgIdPubBuffer {
	buf := &msgIdPubBuffer{
		tenantId:         tenantID,
		msgId:            msgID,
		pub:              pub,
		serialize:        json.Marshal,
		msgIdPubBufferCh: make(chan *msgWithErrCh, PUB_BUFFER_SIZE),
		notifier:         make(chan struct{}),
		semaphore:        make(chan struct{}, PUB_MAX_CONCURRENCY),
		semaphoreRelease: make(chan time.Duration, PUB_MAX_CONCURRENCY),
	}

	buf.startFlusher(ctx)
	buf.startSemaphoreReleaser(ctx)

	return buf
}
// startFlusher launches a goroutine that flushes the buffer on every
// PUB_FLUSH_INTERVAL tick and on every notifier signal, until ctx is cancelled;
// a final synchronous flush runs on cancellation.
func (m *msgIdPubBuffer) startFlusher(ctx context.Context) {
	go func() {
		ticker := time.NewTicker(PUB_FLUSH_INTERVAL)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				// Final drain so buffered messages are not silently dropped at shutdown.
				m.flush()
				return
			case <-ticker.C:
				// flush acquires a bounded semaphore and returns immediately when
				// it can't, so spawning a goroutine per tick/notification does not
				// pile up more than PUB_MAX_CONCURRENCY concurrent flushes.
				go m.flush()
			case <-m.notifier:
				go m.flush()
			}
		}
	}()
}
// startSemaphoreReleaser launches a goroutine that returns flush-semaphore
// slots after an optional delay, spacing consecutive flushes apart. It exits
// when ctx is cancelled.
func (m *msgIdPubBuffer) startSemaphoreReleaser(ctx context.Context) {
	go func() {
		// time.NewTimer(0) fires immediately, so drain the stale tick before the
		// loop. Previously the undrained tick was consumed by the first
		// Reset(delay) + <-timer.C, making the first delayed release return
		// instantly instead of waiting.
		timer := time.NewTimer(0)
		if !timer.Stop() {
			<-timer.C
		}
		defer timer.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case delay := <-m.semaphoreRelease:
				if delay > 0 {
					timer.Reset(delay)

					// Also exit promptly if cancelled mid-delay rather than
					// sleeping out the remainder.
					select {
					case <-timer.C:
					case <-ctx.Done():
						return
					}
				}

				// Hand the slot back so the next flush may run.
				<-m.semaphore
			}
		}
	}()
}
// flush drains up to PUB_BUFFER_SIZE pending messages, publishes their payloads
// as one batched message, and reports the publish result to every waiting
// caller. It is a no-op when PUB_MAX_CONCURRENCY flushes are already running.
func (m *msgIdPubBuffer) flush() {
	// Non-blocking semaphore acquire: if the concurrency limit is reached, skip
	// this flush — buffered messages are picked up by a later one.
	select {
	case m.semaphore <- struct{}{}:
	default:
		return
	}

	startedFlush := time.Now()

	defer func() {
		// Release the slot asynchronously, delayed so flushes of this buffer are
		// spaced at least PUB_FLUSH_INTERVAL apart.
		go func() {
			delay := PUB_FLUSH_INTERVAL - time.Since(startedFlush)
			m.semaphoreRelease <- delay
		}()
	}()

	msgsWithErrCh := make([]*msgWithErrCh, 0)
	payloadBytes := make([][]byte, 0)

	var isPersistent *bool
	var immediatelyExpire *bool
	var retries *int

	// Read all messages currently in the buffer without blocking, capped at
	// PUB_BUFFER_SIZE per flush.
readLoop:
	for i := 0; i < PUB_BUFFER_SIZE; i++ {
		select {
		case msg := <-m.msgIdPubBufferCh:
			msgsWithErrCh = append(msgsWithErrCh, msg)
			payloadBytes = append(payloadBytes, msg.msg.Payloads...)

			// The first message's flags apply to the whole batch; differing flags
			// on later messages in the same batch are ignored.
			if isPersistent == nil {
				isPersistent = &msg.msg.Persistent
			}

			if immediatelyExpire == nil {
				immediatelyExpire = &msg.msg.ImmediatelyExpire
			}

			if retries == nil {
				retries = &msg.msg.Retries
			}
		default:
			break readLoop
		}
	}

	if len(msgsWithErrCh) == 0 {
		return
	}

	var err error

	// Only publish when there is at least one payload, but ALWAYS fall through
	// to signal waiters. Previously a batch whose messages all had empty
	// Payloads returned early without writing to errCh, permanently blocking
	// Pub(..., wait=true) callers.
	if len(payloadBytes) > 0 {
		msgToSend := &Message{
			TenantID: m.tenantId,
			ID:       m.msgId,
			Payloads: payloadBytes,
		}

		if isPersistent != nil {
			msgToSend.Persistent = *isPersistent
		}

		if immediatelyExpire != nil {
			msgToSend.ImmediatelyExpire = *immediatelyExpire
		}

		if retries != nil {
			msgToSend.Retries = *retries
		}

		err = m.pub(msgToSend)
	}

	// Unblock every waiting publisher with the shared batch result.
	for _, msgWithErrCh := range msgsWithErrCh {
		if msgWithErrCh.errCh != nil {
			msgWithErrCh.errCh <- err
		}
	}
}