hatchet/pkg/repository/mq.go

package repository

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"

	"github.com/hatchet-dev/hatchet/pkg/repository/sqlchelpers"
	"github.com/hatchet-dev/hatchet/pkg/repository/sqlcv1"
	"github.com/hatchet-dev/hatchet/pkg/telemetry"
)
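
// PubSubMessage is the notification envelope delivered to Listen callbacks,
// carrying the queue name it was published under and the raw JSON payload.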
type PubSubMessage struct {
	QueueName string          `json:"queue_name"`
	Payload   json.RawMessage `json:"payload"`
}
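
// MessageQueueRepository provides a Postgres-backed message queue: pub/sub via
// notifications, queue metadata management, and durable messages that are stored,
// read, acked, and periodically cleaned up in the database.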
type MessageQueueRepository interface {
	// PubSub
	Listen(ctx context.Context, name string, f func(ctx context.Context, notification *PubSubMessage) error) error
	Notify(ctx context.Context, name string, payload string) error

	// Queues
	BindQueue(ctx context.Context, queue string, durable, autoDeleted, exclusive bool, exclusiveConsumer *string) error
	UpdateQueueLastActive(ctx context.Context, queue string) error
	CleanupQueues(ctx context.Context) error

	// Messages
	AddMessage(ctx context.Context, queue string, payload []byte) error
	ReadMessages(ctx context.Context, queue string, qos int) ([]*sqlcv1.ReadMessagesRow, error)
	AckMessage(ctx context.Context, id int64) error
	CleanupMessageQueueItems(ctx context.Context) error
}
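
// messageQueueRepository implements MessageQueueRepository on top of the shared
// repository and a multiplexed Postgres notification listener.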
type messageQueueRepository struct {
	*sharedRepository

	m *multiplexedListener
}
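
// newMessageQueueRepository constructs the repository and returns it alongside a
// cleanup function that cancels the underlying listener.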
func newMessageQueueRepository(shared *sharedRepository) (*messageQueueRepository, func() error) {
	m := newMultiplexedListener(shared.l, shared.pool)

	return &messageQueueRepository{
		sharedRepository: shared,
		m:                m,
	}, func() error {
		m.cancel()
		return nil
	}
}
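
// Listen registers f to receive PubSubMessages published under name, using the
// repository's multiplexed listener.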
func (m *messageQueueRepository) Listen(ctx context.Context, name string, f func(ctx context.Context, notification *PubSubMessage) error) error {
	return m.m.listen(ctx, name, f)
}
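
// Notify publishes payload under name via pg_notify, falling back to a durable
// message row when the wrapped payload exceeds the notification size limit.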
func (m *messageQueueRepository) Notify(ctx context.Context, name string, payload string) error {
	wrappedPayload, err := m.m.wrapMessage(name, payload)

	if err != nil {
		m.l.Error().Err(err).Msg("error wrapping message")
		return err
	}

	// PostgreSQL's pg_notify has an 8000 byte limit.
	// If the wrapped message exceeds this, fall back to database storage.
	if len(wrappedPayload) > 8000 {
		return m.AddMessage(ctx, name, []byte(payload))
	}

	return m.m.notify(ctx, wrappedPayload)
}
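
// AddMessage inserts a single message onto the named queue with a 5-minute expiry,
// readable immediately.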
func (m *messageQueueRepository) AddMessage(ctx context.Context, queue string, payload []byte) error {
	p := []sqlcv1.BulkAddMessageParams{}

	p = append(p, sqlcv1.BulkAddMessageParams{
		QueueId: pgtype.Text{
			String: queue,
			Valid:  true,
		},
		Payload:   payload,
		ExpiresAt: sqlchelpers.TimestampFromTime(time.Now().UTC().Add(5 * time.Minute)),
		ReadAfter: sqlchelpers.TimestampFromTime(time.Now().UTC()),
	})

	_, err := m.queries.BulkAddMessage(ctx, m.pool, p)

	return err
}
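
// BindQueue upserts the queue's metadata (durable, auto-deleted, exclusive); an
// exclusive queue must specify its exclusive consumer.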
func (m *messageQueueRepository) BindQueue(ctx context.Context, queue string, durable, autoDeleted, exclusive bool, exclusiveConsumer *string) error {
	// if exclusive, but no consumer, return error
	if exclusive && exclusiveConsumer == nil {
		return errors.New("exclusive queue must have exclusive consumer")
	}

	params := sqlcv1.UpsertMessageQueueParams{
		Name:        queue,
		Durable:     durable,
		Autodeleted: autoDeleted,
		Exclusive:   exclusive,
	}

	if exclusiveConsumer != nil {
		params.ExclusiveConsumerId = sqlchelpers.UUIDFromStr(*exclusiveConsumer)
	}

	_, err := m.queries.UpsertMessageQueue(ctx, m.pool, params)

	return err
}
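
// UpdateQueueLastActive marks the named queue as recently active.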
func (m *messageQueueRepository) UpdateQueueLastActive(ctx context.Context, queue string) error {
	return m.queries.UpdateMessageQueueActive(ctx, m.pool, queue)
}
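
// CleanupQueues runs the CleanupMessageQueue query to remove stale queues.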
func (m *messageQueueRepository) CleanupQueues(ctx context.Context) error {
	return m.queries.CleanupMessageQueue(ctx, m.pool)
}
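
// ReadMessages returns up to qos messages from the named queue.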
func (m *messageQueueRepository) ReadMessages(ctx context.Context, queue string, qos int) ([]*sqlcv1.ReadMessagesRow, error) {
	ctx, span := telemetry.NewSpan(ctx, "pgmq-read-messages")
	defer span.End()

	return m.queries.ReadMessages(ctx, m.pool, sqlcv1.ReadMessagesParams{
		Queueid: queue,
		Limit:   pgtype.Int4{Int32: int32(qos), Valid: true}, // nolint: gosec
	})
}
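
// AckMessage acknowledges a single message by its id.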
func (m *messageQueueRepository) AckMessage(ctx context.Context, id int64) error {
	return m.queries.BulkAckMessages(ctx, m.pool, []int64{id})
}
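
// CleanupMessageQueueItems deletes expired message queue items in batches of
// 10,000 ids, walking from the minimum to the maximum expired item id.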
func (m *messageQueueRepository) CleanupMessageQueueItems(ctx context.Context) error {
	// setup telemetry
	ctx, span := telemetry.NewSpan(ctx, "cleanup-message-queues-database")
	defer span.End()

	// get the min and max expired queue items
	minMax, err := m.queries.GetMinMaxExpiredMessageQueueItems(ctx, m.pool)

	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil
		}

		return fmt.Errorf("could not get min and max expired queue items: %w", err)
	}

	if minMax == nil {
		return nil
	}

	minId := minMax.MinId
	maxId := minMax.MaxId

	if maxId == 0 {
		return nil
	}

	// iterate until we have no more queue items to process
	var batchSize int64 = 10000
	var currBatch int64

	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}

		currBatch++

		currMax := minId + batchSize*currBatch

		if currMax > maxId {
			currMax = maxId
		}

		// delete the next batch of expired queue items
		err := m.queries.CleanupMessageQueueItems(ctx, m.pool, sqlcv1.CleanupMessageQueueItemsParams{
			Minid: minId,
			Maxid: currMax,
		})

		if err != nil {
			if errors.Is(err, pgx.ErrNoRows) {
				return nil
			}

			return fmt.Errorf("could not cleanup queue items: %w", err)
		}

		if currMax == maxId {
			break
		}
	}

	return nil
}