// hatchet/internal/msgqueue/v1/mq_pub_buffer.go
package v1
import (
"context"
"encoding/json"
"os"
"strconv"
"sync"
"time"
)
// nolint: staticcheck
var (
	// PUB_FLUSH_INTERVAL is how often each per-key buffer is flushed to the
	// underlying message queue. Override with SERVER_DEFAULT_BUFFER_FLUSH_INTERVAL.
	PUB_FLUSH_INTERVAL = 10 * time.Millisecond

	// PUB_BUFFER_SIZE is both the capacity of each per-key buffer channel and
	// the maximum number of messages drained per flush. Override with
	// SERVER_DEFAULT_BUFFER_SIZE.
	PUB_BUFFER_SIZE = 1000

	// PUB_MAX_CONCURRENCY caps the number of concurrent flushes per buffer.
	// Override with SERVER_DEFAULT_BUFFER_CONCURRENCY.
	PUB_MAX_CONCURRENCY = 1

	// PUB_TIMEOUT bounds a single SendMessage call made during a flush.
	PUB_TIMEOUT = 10 * time.Second
)

// init applies environment-variable overrides to the buffer tuning knobs above.
// Values that fail to parse are silently ignored and the defaults are kept.
func init() {
	if v := os.Getenv("SERVER_DEFAULT_BUFFER_FLUSH_INTERVAL"); v != "" {
		if d, err := time.ParseDuration(v); err == nil {
			PUB_FLUSH_INTERVAL = d
		}
	}

	if v := os.Getenv("SERVER_DEFAULT_BUFFER_SIZE"); v != "" {
		if size, err := strconv.Atoi(v); err == nil {
			PUB_BUFFER_SIZE = size
		}
	}

	if v := os.Getenv("SERVER_DEFAULT_BUFFER_CONCURRENCY"); v != "" {
		if concurrency, err := strconv.Atoi(v); err == nil {
			PUB_MAX_CONCURRENCY = concurrency
		}
	}
}
// PubFunc publishes one (possibly payload-batched) message to the underlying
// message queue, returning any send error.
type PubFunc func(m *Message) error
// MQPubBuffer buffers messages coming out of the task queue, groups them by tenantId and msgId, and then flushes them
// to the task handler as necessary.
type MQPubBuffer struct {
	// mq is the underlying message queue that buffered messages are flushed to.
	mq MessageQueue

	// buffers is keyed on (tenantId, msgId) and contains a buffer of messages for that tenantId and msgId.
	// Values are *msgIdPubBuffer.
	buffers sync.Map

	// ctx bounds the lifetime of all per-key flusher goroutines; cancel is
	// invoked by Stop.
	ctx    context.Context
	cancel context.CancelFunc
}
// NewMQPubBuffer constructs a publish buffer around the given message queue.
// The returned buffer keeps running until Stop is called.
func NewMQPubBuffer(mq MessageQueue) *MQPubBuffer {
	bufCtx, cancelFn := context.WithCancel(context.Background())

	buffer := &MQPubBuffer{
		mq:     mq,
		ctx:    bufCtx,
		cancel: cancelFn,
	}

	return buffer
}
// Stop cancels the buffer's context, signalling every per-key flusher
// goroutine to perform a final flush and exit.
func (m *MQPubBuffer) Stop() {
	m.cancel()
}
// msgWithErrCh pairs a buffered message with an optional channel used to
// report the flush result back to a caller that opted to wait.
type msgWithErrCh struct {
	msg *Message

	// errCh is non-nil only when the caller passed wait=true to Pub; the
	// flusher sends exactly one result on it.
	errCh chan error
}
// Pub enqueues msg into the per-(queue, tenant, msgId) buffer, creating the
// buffer and its background flusher on first use. When wait is true, Pub
// blocks until the message has been flushed and returns the flush error;
// otherwise it returns nil immediately after buffering.
//
// Messages with an empty TenantID are silently dropped (returns nil).
func (m *MQPubBuffer) Pub(ctx context.Context, queue Queue, msg *Message, wait bool) error {
	if msg.TenantID == "" {
		return nil
	}

	k := getPubKey(queue, msg.TenantID, msg.ID)

	buf, ok := m.buffers.Load(k)

	if !ok {
		// LoadOrStore guarantees a single buffer per key even when two
		// goroutines race here; the loser's freshly created buffer is dropped.
		buf, _ = m.buffers.LoadOrStore(k, newMsgIDPubBuffer(m.ctx, msg.TenantID, msg.ID, func(msg *Message) error {
			// each flush gets a fresh timeout, independent of the caller's ctx
			msgCtx, cancel := context.WithTimeout(context.Background(), PUB_TIMEOUT)
			defer cancel()

			return m.mq.SendMessage(msgCtx, queue, msg)
		}))
	}

	msgWithErr := &msgWithErrCh{
		msg: msg,
	}

	if wait {
		msgWithErr.errCh = make(chan error)
	}

	// this places some backpressure on the consumer if buffers are full
	msgBuf := buf.(*msgIdPubBuffer)
	msgBuf.msgIdPubBufferCh <- msgWithErr
	// NOTE(review): notifier is unbuffered and only read by the flusher
	// goroutine; if that goroutine has exited (after Stop) this send blocks
	// forever — confirm Pub is never called after Stop.
	msgBuf.notifier <- struct{}{}

	if wait {
		return <-msgWithErr.errCh
	}

	return nil
}
// getPubKey returns the buffers-map key for a (queue, tenant, message-id)
// triple. A NUL separator between the parts prevents distinct triples from
// colliding after concatenation (e.g. ("a", "bc") vs ("ab", "c")). The key is
// only used internally as a sync.Map key, so the format is free to change.
func getPubKey(q Queue, tenantId, msgId string) string {
	return q.Name() + "\x00" + tenantId + "\x00" + msgId
}
// msgIdPubBuffer accumulates messages for a single (tenantId, msgId) pair and
// batches their payloads into one outgoing message per flush.
type msgIdPubBuffer struct {
	tenantId string
	msgId    string

	// msgIdPubBufferCh holds pending messages; its capacity (PUB_BUFFER_SIZE)
	// provides backpressure on publishers when flushing falls behind.
	msgIdPubBufferCh chan *msgWithErrCh

	// notifier is an unbuffered channel used by Pub to request an immediate flush.
	notifier chan struct{}

	// pub sends a batched message to the underlying queue.
	pub PubFunc

	// semaphore limits concurrent flushes to PUB_MAX_CONCURRENCY.
	semaphore chan struct{}

	// serialize is the payload encoder (json.Marshal); not used in this file's
	// visible code — presumably used elsewhere, verify before removing.
	serialize func(t any) ([]byte, error)
}
// newMsgIDPubBuffer constructs a buffer for one (tenant, message-id) pair and
// starts its background flusher, which runs until ctx is cancelled.
func newMsgIDPubBuffer(ctx context.Context, tenantID, msgID string, pub PubFunc) *msgIdPubBuffer {
	buffer := &msgIdPubBuffer{
		tenantId:         tenantID,
		msgId:            msgID,
		msgIdPubBufferCh: make(chan *msgWithErrCh, PUB_BUFFER_SIZE),
		notifier:         make(chan struct{}),
		semaphore:        make(chan struct{}, PUB_MAX_CONCURRENCY),
		pub:              pub,
		serialize:        json.Marshal,
	}

	buffer.startFlusher(ctx)

	return buffer
}
// startFlusher launches the goroutine that drains this buffer. It flushes on a
// fixed ticker interval, whenever a publisher signals via notifier, and once
// more (synchronously) on context cancellation before exiting.
func (m *msgIdPubBuffer) startFlusher(ctx context.Context) {
	ticker := time.NewTicker(PUB_FLUSH_INTERVAL)

	go func() {
		// fix: stop the ticker when the flusher exits so its resources are
		// released (previously it kept firing forever)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				// final synchronous flush so already-buffered messages are not lost
				m.flush()
				return
			case <-ticker.C:
				go m.flush()
			case <-m.notifier:
				go m.flush()
			}
		}
	}()
}
// flush drains up to PUB_BUFFER_SIZE queued messages, publishes all of their
// payloads as one combined message, and reports the result to every waiting
// caller. The semaphore bounds concurrent flushes: if it is full the call is a
// no-op, since another flush is already in flight.
func (m *msgIdPubBuffer) flush() {
	select {
	case m.semaphore <- struct{}{}:
	default:
		return
	}

	startedFlush := time.Now()

	defer func() {
		// hold the semaphore slot for the remainder of the flush interval so
		// flushes are rate-limited to roughly one per PUB_FLUSH_INTERVAL
		go func() {
			<-time.After(PUB_FLUSH_INTERVAL - time.Since(startedFlush))
			<-m.semaphore
		}()
	}()

	msgsWithErrCh := make([]*msgWithErrCh, 0)
	payloadBytes := make([][]byte, 0)

	// read all messages currently in the buffer, without blocking
readLoop:
	for i := 0; i < PUB_BUFFER_SIZE; i++ {
		select {
		case msg := <-m.msgIdPubBufferCh:
			msgsWithErrCh = append(msgsWithErrCh, msg)
			payloadBytes = append(payloadBytes, msg.msg.Payloads...)
		default:
			break readLoop
		}
	}

	if len(msgsWithErrCh) == 0 {
		return
	}

	// fix: publish only when there are payloads, but always answer waiters —
	// previously a waited-on message with zero payloads hit the early return
	// and blocked its publisher forever on errCh
	var err error

	if len(payloadBytes) > 0 {
		err = m.pub(&Message{
			TenantID: m.tenantId,
			ID:       m.msgId,
			Payloads: payloadBytes,
		})
	}

	for _, msgWithErrCh := range msgsWithErrCh {
		if msgWithErrCh.errCh != nil {
			msgWithErrCh.errCh <- err
		}
	}
}