feat: rabbitmq connection pooling (#387)

* feat: add rabbitmq connection pool and remove non-fatal worker errors

* chore: go mod tidy

* fix: release pool after opening channel

* fix: make sure channel is closed after all tasks return on subscribe

* fix: don't loop endlessly
abelanger5 committed 2024-04-16 16:45:03 -04:00 (committed by GitHub)
parent 33f5ff1453, commit 347bc5dd53
6 changed files with 139 additions and 52 deletions
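Before the per-file diffs, a minimal, self-contained sketch of the technique this commit introduces: a puddle/v2 pool of amqp091-go connections built from a constructor/destructor pair, mirroring the New() hunk below. The broker URL is a placeholder; the MaxSize of 10 matches maxPoolSize in the diff.

package main

import (
	"context"
	"log"

	"github.com/jackc/puddle/v2"
	amqp "github.com/rabbitmq/amqp091-go"
)

func main() {
	// The constructor dials a fresh connection on demand; the destructor
	// closes it when the pool evicts the resource or shuts down.
	constructor := func(ctx context.Context) (*amqp.Connection, error) {
		return amqp.Dial("amqp://guest:guest@localhost:5672/") // placeholder URL
	}
	destructor := func(conn *amqp.Connection) {
		if err := conn.Close(); err != nil {
			log.Printf("error closing connection: %v", err)
		}
	}
	pool, err := puddle.NewPool(&puddle.Config[*amqp.Connection]{
		Constructor: constructor,
		Destructor:  destructor,
		MaxSize:     10, // matches maxPoolSize in the diff
	})
	if err != nil {
		log.Fatal(err)
	}
	defer pool.Close() // runs the destructor on every pooled connection

	// Sessions are then built from pool.Acquire instead of dialing each
	// time; see the getSession hunk further down.
}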
File changed: +2 -2
@@ -15,11 +15,13 @@ require (
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jackc/pgx-zerolog v0.0.0-20230315001418-f978528409eb
github.com/jackc/pgx/v5 v5.5.5
+ github.com/jackc/puddle/v2 v2.2.1
github.com/joho/godotenv v1.5.1
github.com/labstack/echo/v4 v4.12.0
github.com/oapi-codegen/runtime v1.1.1
github.com/opencontainers/go-digest v1.0.0
github.com/posthog/posthog-go v0.0.0-20240327112532-87b23fe11103
+ github.com/shopspring/decimal v1.4.0
github.com/spf13/cobra v1.8.0
github.com/spf13/viper v1.18.2
github.com/steebchen/prisma-client-go v0.36.0
@@ -65,7 +67,6 @@ require (
github.com/invopop/yaml v0.2.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
- github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/labstack/gommon v0.4.2 // indirect
@@ -75,7 +76,6 @@ require (
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
- github.com/shopspring/decimal v1.4.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
File changed: +87 -14
@@ -10,6 +10,7 @@ import (
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/jackc/puddle/v2"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/rs/zerolog"
@@ -98,7 +99,36 @@ func New(fs ...MessageQueueImplOpt) (func() error, *MessageQueueImpl) {
l: opts.l,
}
- t.sessions = t.redial(ctx, opts.l, opts.url)
+ constructor := func(context.Context) (*amqp.Connection, error) {
+ conn, err := amqp.Dial(opts.url)
+ if err != nil {
+ opts.l.Error().Msgf("cannot (re)dial: %v: %q", err, opts.url)
+ return nil, err
+ }
+ return conn, nil
+ }
+ destructor := func(conn *amqp.Connection) {
+ err := conn.Close()
+ if err != nil {
+ opts.l.Error().Msgf("error closing connection: %v", err)
+ }
+ }
+ maxPoolSize := int32(10)
+ pool, err := puddle.NewPool(&puddle.Config[*amqp.Connection]{Constructor: constructor, Destructor: destructor, MaxSize: maxPoolSize})
+ if err != nil {
+ t.l.Error().Err(err).Msg("cannot create connection pool")
+ cancel()
+ return nil, nil
+ }
+ t.sessions = t.redial(ctx, opts.l, pool)
t.msgs = make(chan *msgWithQueue)
// create a new lru cache for tenant ids
@@ -138,6 +168,9 @@ func New(fs ...MessageQueueImplOpt) (func() error, *MessageQueueImpl) {
if err := cleanup1(); err != nil {
return fmt.Errorf("error cleaning up rabbitmq publisher: %w", err)
}
+ pool.Close()
return nil
}
@@ -257,14 +290,21 @@ func (t *MessageQueueImpl) startPublishing() func() error {
for session := range t.sessions {
pub := <-session
+ conn := pub.Connection
+ t.l.Debug().Msgf("starting publisher: %s", conn.LocalAddr().String())
for {
- if pub.Channel.IsClosed() || pub.Connection.IsClosed() {
+ if pub.Channel.IsClosed() {
break
+ } else if conn.IsClosed() {
+ t.l.Error().Msgf("connection is closed, reconnecting")
+ break
+ }
select {
case <-ctx.Done():
- return
+ break
case msg := <-t.msgs:
go func(msg *msgWithQueue) {
body, err := json.Marshal(msg)
@@ -318,6 +358,12 @@ func (t *MessageQueueImpl) startPublishing() func() error {
}(msg)
}
}
+ err := pub.Channel.Close()
+ if err != nil {
+ t.l.Error().Msgf("cannot close channel: %s, %v", conn.LocalAddr().String(), err)
+ }
}
}()
@@ -342,6 +388,12 @@ func (t *MessageQueueImpl) subscribe(
sessionCount++
sub := <-session
+ sessionWg := sync.WaitGroup{}
+ conn := sub.Connection
+ t.l.Debug().Msgf("starting subscriber %s on: %s", subId, conn.LocalAddr().String())
// we initialize the queue here because exclusive queues are bound to the session/connection. however, it's not clear
// if the exclusive queue will be available to the next session.
queueName, err := t.initQueue(sub, q)
@@ -365,10 +417,21 @@ func (t *MessageQueueImpl) subscribe(
return
}
+ closeChannel := func() {
+ sessionWg.Wait()
+ err = sub.Channel.Close()
+ if err != nil {
+ t.l.Error().Msgf("cannot close channel: %s, %v", conn.LocalAddr().String(), err)
+ }
+ }
inner:
for {
select {
case <-ctx.Done():
+ closeChannel()
return
case rabbitMsg, ok := <-deliveries:
if !ok {
@@ -377,9 +440,12 @@ func (t *MessageQueueImpl) subscribe(
}
wg.Add(1)
+ sessionWg.Add(1)
go func(rabbitMsg amqp.Delivery) {
defer wg.Done()
+ defer sessionWg.Done()
msg := &msgWithQueue{}
if len(rabbitMsg.Body) == 0 {
@@ -451,11 +517,7 @@ func (t *MessageQueueImpl) subscribe(
}
}
- err = sub.CloseDeadline(time.Now())
- if err != nil {
- t.l.Error().Msgf("cannot close session: %s, %v", sub.LocalAddr().String(), err)
- }
+ go closeChannel()
}
}()
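The hunk above implements the "channel is closed after all tasks return" fix by pairing a per-session WaitGroup with an asynchronous close. A compilable sketch of that ordering, with the handler body left as a placeholder:

package mq

import (
	"log"
	"sync"

	amqp "github.com/rabbitmq/amqp091-go"
)

// consume hands each delivery to its own goroutine and closes the channel
// only after every in-flight handler has returned, mirroring sessionWg and
// closeChannel in the subscribe hunk above.
func consume(ch *amqp.Channel, deliveries <-chan amqp.Delivery, handle func(amqp.Delivery)) {
	var sessionWg sync.WaitGroup
	for d := range deliveries {
		sessionWg.Add(1)
		go func(d amqp.Delivery) {
			defer sessionWg.Done()
			handle(d) // placeholder for the real task body
		}(d)
	}
	// Close in a goroutine (as the diff does with `go closeChannel()`) so a
	// slow handler doesn't block picking up the next session.
	go func() {
		sessionWg.Wait()
		if err := ch.Close(); err != nil {
			log.Printf("cannot close channel: %v", err)
		}
	}()
}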
@@ -472,7 +534,7 @@ func (t *MessageQueueImpl) subscribe(
}
// redial continually connects to the URL, exiting the program when no longer possible
- func (t *MessageQueueImpl) redial(ctx context.Context, l *zerolog.Logger, url string) chan chan session {
+ func (t *MessageQueueImpl) redial(ctx context.Context, l *zerolog.Logger, pool *puddle.Pool[*amqp.Connection]) chan chan session {
sessions := make(chan chan session)
go func() {
@@ -491,7 +553,7 @@ func (t *MessageQueueImpl) redial(ctx context.Context, l *zerolog.Logger, url st
var err error
for i := 0; i < MAX_RETRY_COUNT; i++ {
- newSession, err = getSession(ctx, l, url)
+ newSession, err = getSession(ctx, l, pool)
if err == nil {
if i > 0 {
l.Info().Msgf("re-established session after %d attempts", i)
@@ -506,6 +568,7 @@ func (t *MessageQueueImpl) redial(ctx context.Context, l *zerolog.Logger, url st
if err != nil {
l.Error().Msgf("failed to get session after %d attempts", MAX_RETRY_COUNT)
t.ready = false
+ return
}
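The redial loop above now gives up after MAX_RETRY_COUNT failed attempts instead of retrying forever. A sketch of that bounded-retry shape, generalized to a plain connect callback; the count and interval here are illustrative assumptions, not the project's constants:

package mq

import (
	"log"
	"time"
)

const (
	maxRetryCount = 5               // assumed value for illustration
	retryInterval = 5 * time.Second // assumed value for illustration
)

// connectWithRetry reports whether connect eventually succeeded. On
// exhaustion it returns false so the caller can mark itself not-ready and
// bail out rather than looping endlessly.
func connectWithRetry(connect func() error) bool {
	for i := 0; i < maxRetryCount; i++ {
		if err := connect(); err == nil {
			if i > 0 {
				log.Printf("re-established session after %d attempts", i)
			}
			return true
		}
		time.Sleep(retryInterval)
	}
	log.Printf("failed to get session after %d attempts", maxRetryCount)
	return false
}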
@@ -545,18 +608,28 @@ func identity() string {
return fmt.Sprintf("%x", h.Sum(nil))
}
- func getSession(ctx context.Context, l *zerolog.Logger, url string) (session, error) {
- conn, err := amqp.Dial(url)
+ func getSession(ctx context.Context, l *zerolog.Logger, pool *puddle.Pool[*amqp.Connection]) (session, error) {
+ connFromPool, err := pool.Acquire(ctx)
if err != nil {
l.Error().Msgf("cannot (re)dial: %v: %q", err, url)
l.Error().Msgf("cannot acquire connection: %v", err)
return session{}, err
}
+ conn := connFromPool.Value()
ch, err := conn.Channel()
if err != nil {
+ connFromPool.Destroy()
l.Error().Msgf("cannot create channel: %v", err)
return session{}, err
}
- return session{conn, ch}, nil
+ connFromPool.Release()
+ return session{
+ Channel: ch,
+ Connection: conn,
+ }, nil
}
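The getSession rewrite above encodes the "release pool after opening channel" fix: a connection is borrowed only long enough to open a channel, and a connection that cannot produce a channel is destroyed rather than returned to the pool. The same flow in isolation:

package mq

import (
	"context"

	"github.com/jackc/puddle/v2"
	amqp "github.com/rabbitmq/amqp091-go"
)

// openChannel acquires a pooled connection, opens a channel on it, and then
// releases the connection. The channel stays usable after Release because
// the pool keeps the underlying connection open.
func openChannel(ctx context.Context, pool *puddle.Pool[*amqp.Connection]) (*amqp.Channel, error) {
	res, err := pool.Acquire(ctx)
	if err != nil {
		return nil, err // no connection could be acquired or dialed
	}
	ch, err := res.Value().Channel()
	if err != nil {
		res.Destroy() // the connection is likely broken; evict it
		return nil, err
	}
	res.Release() // hand the connection back for the next caller
	return ch, nil
}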
@@ -607,18 +607,7 @@ func (ec *JobsControllerImpl) queueStepRun(ctx context.Context, tenantId, stepId
// set scheduling timeout
if scheduleTimeoutAt := stepRun.StepRun.ScheduleTimeoutAt.Time; scheduleTimeoutAt.IsZero() {
- var timeoutDuration time.Duration
- // get the schedule timeout from the step
- stepScheduleTimeout := stepRun.StepScheduleTimeout
- if stepScheduleTimeout != "" {
- timeoutDuration, _ = time.ParseDuration(stepScheduleTimeout)
- } else {
- timeoutDuration = defaults.DefaultScheduleTimeout
- }
- scheduleTimeoutAt := time.Now().UTC().Add(timeoutDuration)
+ scheduleTimeoutAt = getScheduleTimeout(stepRun)
updateStepOpts.ScheduleTimeoutAt = &scheduleTimeoutAt
}
@@ -869,10 +858,13 @@ func (ec *JobsControllerImpl) handleStepRunFailed(ctx context.Context, task *msg
status = db.StepRunStatusPending
}
+ scheduleTimeoutAt := getScheduleTimeout(stepRun)
stepRun, updateInfo, err := ec.repo.StepRun().UpdateStepRun(ctx, metadata.TenantId, payload.StepRunId, &repository.UpdateStepRunOpts{
FinishedAt: &failedAt,
Error: &payload.Error,
Status: repository.StepRunStatusPtr(status),
FinishedAt: &failedAt,
Error: &payload.Error,
Status: repository.StepRunStatusPtr(status),
ScheduleTimeoutAt: &scheduleTimeoutAt,
})
if err != nil {
@@ -957,7 +949,10 @@ func (ec *JobsControllerImpl) cancelStepRun(ctx context.Context, tenantId, stepR
defer ec.handleStepRunUpdateInfo(stepRun, updateInfo)
if !stepRun.StepRun.WorkerId.Valid {
return fmt.Errorf("step run has no worker id")
// this is not a fatal error
ec.l.Debug().Msgf("step run %s has no worker id, skipping cancellation", stepRunId)
return nil
}
workerId := sqlchelpers.UUIDToStr(stepRun.StepRun.WorkerId)
@@ -1055,3 +1050,18 @@ func stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, cancelled
Retries: 3,
}
}
+ func getScheduleTimeout(stepRun *dbsqlc.GetStepRunForEngineRow) time.Time {
+ var timeoutDuration time.Duration
+ // get the schedule timeout from the step
+ stepScheduleTimeout := stepRun.StepScheduleTimeout
+ if stepScheduleTimeout != "" {
+ timeoutDuration, _ = time.ParseDuration(stepScheduleTimeout)
+ } else {
+ timeoutDuration = defaults.DefaultScheduleTimeout
+ }
+ return time.Now().UTC().Add(timeoutDuration)
+ }
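getScheduleTimeout above consolidates the timeout logic previously inlined in queueStepRun. One quirk carried over from the original code: the ParseDuration error is discarded, so a malformed timeout string silently yields a zero duration. A small runnable demonstration, using a stand-in for defaults.DefaultScheduleTimeout:

package main

import (
	"fmt"
	"time"
)

// timeoutAt reproduces getScheduleTimeout's fallback logic with plain inputs.
func timeoutAt(stepScheduleTimeout string, def time.Duration) time.Time {
	var d time.Duration
	if stepScheduleTimeout != "" {
		d, _ = time.ParseDuration(stepScheduleTimeout) // error discarded, as above
	} else {
		d = def
	}
	return time.Now().UTC().Add(d)
}

func main() {
	def := 5 * time.Minute // stand-in; the real default lives in the defaults package
	fmt.Println(timeoutAt("30s", def))   // now + 30s
	fmt.Println(timeoutAt("", def))      // now + 5m, the default
	fmt.Println(timeoutAt("bogus", def)) // now + 0s: the parse error is dropped
}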
File changed: +15 -6
@@ -2,6 +2,7 @@ package dispatcher
import (
"context"
"errors"
"fmt"
"sync"
"time"
@@ -229,17 +230,25 @@ func (d *DispatcherImpl) Start() (func() error, error) {
return cleanup, nil
}
- func (d *DispatcherImpl) handleTask(ctx context.Context, task *msgqueue.Message) error {
+ func (d *DispatcherImpl) handleTask(ctx context.Context, task *msgqueue.Message) (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf("recovered from panic: %v", r)
+ }
+ }()
switch task.ID {
case "group-key-action-assigned":
- return d.a.WrapErr(d.handleGroupKeyActionAssignedTask(ctx, task), map[string]interface{}{})
+ err = d.a.WrapErr(d.handleGroupKeyActionAssignedTask(ctx, task), map[string]interface{}{})
case "step-run-assigned":
- return d.a.WrapErr(d.handleStepRunAssignedTask(ctx, task), map[string]interface{}{})
+ err = d.a.WrapErr(d.handleStepRunAssignedTask(ctx, task), map[string]interface{}{})
case "step-run-cancelled":
- return d.a.WrapErr(d.handleStepRunCancelled(ctx, task), map[string]interface{}{})
+ err = d.a.WrapErr(d.handleStepRunCancelled(ctx, task), map[string]interface{}{})
+ default:
+ err = fmt.Errorf("unknown task: %s", task.ID)
}
return fmt.Errorf("unknown task: %s", task.ID)
return err
}
func (d *DispatcherImpl) handleGroupKeyActionAssignedTask(ctx context.Context, task *msgqueue.Message) error {
@@ -364,7 +373,7 @@ func (d *DispatcherImpl) handleStepRunCancelled(ctx context.Context, task *msgqu
// get the worker for this task
w, err := d.GetWorker(payload.WorkerId)
- if err != nil {
+ if err != nil && !errors.Is(err, ErrWorkerNotFound) {
return fmt.Errorf("could not get worker: %w", err)
}
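handleTask above switches to a named return plus a deferred recover, so a panicking task handler surfaces as an ordinary error instead of crashing the dispatcher. The pattern in isolation:

package main

import "fmt"

// handle converts a panic in any branch into a returned error via the named
// return value, the same shape handleTask uses above.
func handle(id string) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("recovered from panic: %v", r)
		}
	}()
	switch id {
	case "ok":
		err = nil
	case "boom":
		panic("handler blew up") // stand-in for a buggy task handler
	default:
		err = fmt.Errorf("unknown task: %s", id)
	}
	return err
}

func main() {
	fmt.Println(handle("ok"))   // <nil>
	fmt.Println(handle("boom")) // recovered from panic: handler blew up
	fmt.Println(handle("nope")) // unknown task: nope
}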
File changed: +9 -14
@@ -20,10 +20,12 @@ import (
"github.com/hatchet-dev/hatchet/internal/telemetry"
)
+ var ErrWorkerNotFound = fmt.Errorf("worker not found")
func (d *DispatcherImpl) GetWorker(workerId string) (*subscribedWorker, error) {
workerInt, ok := d.workers.Load(workerId)
if !ok {
return nil, fmt.Errorf("worker with id %s not found", workerId)
return nil, ErrWorkerNotFound
}
worker, ok := workerInt.(subscribedWorker)
@@ -199,16 +201,6 @@ func (s *DispatcherImpl) Listen(request *contracts.WorkerListenRequest, stream c
}
s.workers.Delete(request.WorkerId)
- inactive := db.WorkerStatusInactive
- _, err := s.repo.Worker().UpdateWorker(tenantId, request.WorkerId, &repository.UpdateWorkerOpts{
- Status: &inactive,
- })
- if err != nil {
- s.l.Error().Err(err).Msgf("could not update worker %s status to inactive", request.WorkerId)
- }
}()
ctx := stream.Context()
@@ -420,7 +412,6 @@ func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeT
}
func waitFor(wg *sync.WaitGroup, timeout time.Duration, l *zerolog.Logger) {
done := make(chan struct{})
go func() {
@@ -495,10 +486,14 @@ func (s *DispatcherImpl) Unsubscribe(ctx context.Context, request *contracts.Wor
// no matter what, remove the worker from the connection pool
defer s.workers.Delete(request.WorkerId)
- err := s.repo.Worker().DeleteWorker(tenantId, request.WorkerId)
+ inactive := db.WorkerStatusInactive
+ _, err := s.repo.Worker().UpdateWorker(tenantId, request.WorkerId, &repository.UpdateWorkerOpts{
+ Status: &inactive,
+ })
if err != nil {
- return nil, err
+ s.l.Error().Err(err).Msgf("could not update worker %s status to inactive", request.WorkerId)
}
return &contracts.WorkerUnsubscribeResponse{
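The new ErrWorkerNotFound sentinel above lets callers distinguish a vanished worker from a real lookup failure; handleStepRunCancelled matches it with errors.Is so cancelling a step on a missing worker is no longer fatal. A sketch of that pattern:

package main

import (
	"errors"
	"fmt"
)

var ErrWorkerNotFound = fmt.Errorf("worker not found")

// getWorker is a placeholder lookup that always misses.
func getWorker(id string) (any, error) {
	return nil, ErrWorkerNotFound
}

// cancelOnWorker tolerates a missing worker: only unexpected errors
// propagate, mirroring the errors.Is check in the dispatcher diff above.
func cancelOnWorker(id string) error {
	w, err := getWorker(id)
	if err != nil && !errors.Is(err, ErrWorkerNotFound) {
		return fmt.Errorf("could not get worker: %w", err)
	}
	if err == nil {
		_ = w // deliver the cancellation to the connected worker here
	}
	return nil // a missing worker simply means there is nothing to cancel
}

func main() {
	fmt.Println(cancelOnWorker("w1")) // <nil>: the missing worker is skipped
}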