Mirror of https://github.com/hatchet-dev/hatchet.git, synced 2026-04-28 21:49:55 -05:00
feat: rabbitmq connection pooling (#387)
* feat: add rabbitmq connection pool and remove non-fatal worker errors
* chore: go mod tidy
* fix: release pool after opening channel
* fix: make sure channel is closed after all tasks return on subscribe
* fix: don't loop endlessly
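Before the diff, for orientation: a minimal, self-contained sketch of the pattern this PR introduces — a puddle/v2 pool of amqp091-go connections from which short-lived channels are carved. The amqpURL value and the error handling here are placeholders, not code from the diff.

package main

import (
	"context"
	"log"

	"github.com/jackc/puddle/v2"
	amqp "github.com/rabbitmq/amqp091-go"
)

// amqpURL is a placeholder; substitute a reachable broker.
const amqpURL = "amqp://guest:guest@localhost:5672/"

func main() {
	// The pool dials connections lazily via Constructor and closes
	// them via Destructor; MaxSize caps concurrent connections.
	pool, err := puddle.NewPool(&puddle.Config[*amqp.Connection]{
		Constructor: func(ctx context.Context) (*amqp.Connection, error) {
			return amqp.Dial(amqpURL)
		},
		Destructor: func(conn *amqp.Connection) {
			_ = conn.Close()
		},
		MaxSize: 10,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer pool.Close()

	// A "session" is a channel opened on a pooled connection.
	res, err := pool.Acquire(context.Background())
	if err != nil {
		log.Fatal(err)
	}

	ch, err := res.Value().Channel()
	if err != nil {
		// A connection that cannot open channels is likely broken;
		// Destroy evicts it from the pool instead of recycling it.
		res.Destroy()
		log.Fatal(err)
	}

	// Release the connection as soon as the channel is open: AMQP
	// channels multiplex over one connection, so other acquirers can
	// reuse it while this channel stays in service.
	res.Release()
	defer ch.Close()
}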
@@ -15,11 +15,13 @@ require (
 	github.com/hashicorp/golang-lru/v2 v2.0.7
 	github.com/jackc/pgx-zerolog v0.0.0-20230315001418-f978528409eb
 	github.com/jackc/pgx/v5 v5.5.5
+	github.com/jackc/puddle/v2 v2.2.1
 	github.com/joho/godotenv v1.5.1
 	github.com/labstack/echo/v4 v4.12.0
 	github.com/oapi-codegen/runtime v1.1.1
 	github.com/opencontainers/go-digest v1.0.0
 	github.com/posthog/posthog-go v0.0.0-20240327112532-87b23fe11103
+	github.com/shopspring/decimal v1.4.0
 	github.com/spf13/cobra v1.8.0
 	github.com/spf13/viper v1.18.2
 	github.com/steebchen/prisma-client-go v0.36.0
@@ -65,7 +67,6 @@ require (
 	github.com/invopop/yaml v0.2.0 // indirect
 	github.com/jackc/pgpassfile v1.0.0 // indirect
 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
-	github.com/jackc/puddle/v2 v2.2.1 // indirect
 	github.com/jonboulle/clockwork v0.4.0 // indirect
 	github.com/josharian/intern v1.0.0 // indirect
 	github.com/labstack/gommon v0.4.2 // indirect
@@ -75,7 +76,6 @@ require (
 	github.com/perimeterx/marshmallow v1.1.5 // indirect
 	github.com/sagikazarmark/locafero v0.4.0 // indirect
 	github.com/sagikazarmark/slog-shim v0.1.0 // indirect
-	github.com/shopspring/decimal v1.4.0 // indirect
 	github.com/sourcegraph/conc v0.3.0 // indirect
 	github.com/valyala/bytebufferpool v1.0.0 // indirect
 	github.com/valyala/fasttemplate v1.2.2 // indirect
@@ -10,6 +10,7 @@ import (
 	"time"

 	lru "github.com/hashicorp/golang-lru/v2"
+	"github.com/jackc/puddle/v2"
 	amqp "github.com/rabbitmq/amqp091-go"
 	"github.com/rs/zerolog"
@@ -98,7 +99,36 @@ func New(fs ...MessageQueueImplOpt) (func() error, *MessageQueueImpl) {
 		l: opts.l,
 	}

-	t.sessions = t.redial(ctx, opts.l, opts.url)
+	constructor := func(context.Context) (*amqp.Connection, error) {
+		conn, err := amqp.Dial(opts.url)
+
+		if err != nil {
+			opts.l.Error().Msgf("cannot (re)dial: %v: %q", err, opts.url)
+			return nil, err
+		}
+
+		return conn, nil
+	}
+
+	destructor := func(conn *amqp.Connection) {
+		err := conn.Close()
+
+		if err != nil {
+			opts.l.Error().Msgf("error closing connection: %v", err)
+		}
+	}
+
+	maxPoolSize := int32(10)
+
+	pool, err := puddle.NewPool(&puddle.Config[*amqp.Connection]{Constructor: constructor, Destructor: destructor, MaxSize: maxPoolSize})
+
+	if err != nil {
+		t.l.Error().Err(err).Msg("cannot create connection pool")
+		cancel()
+		return nil, nil
+	}
+
+	t.sessions = t.redial(ctx, opts.l, pool)
 	t.msgs = make(chan *msgWithQueue)

 	// create a new lru cache for tenant ids
@@ -138,6 +168,9 @@ func New(fs ...MessageQueueImplOpt) (func() error, *MessageQueueImpl) {
 		if err := cleanup1(); err != nil {
 			return fmt.Errorf("error cleaning up rabbitmq publisher: %w", err)
 		}

+		pool.Close()
+
 		return nil
 	}
@@ -257,14 +290,21 @@ func (t *MessageQueueImpl) startPublishing() func() error {
 		for session := range t.sessions {
 			pub := <-session

+			conn := pub.Connection
+
+			t.l.Debug().Msgf("starting publisher: %s", conn.LocalAddr().String())
+
 			for {
-				if pub.Channel.IsClosed() || pub.Connection.IsClosed() {
+				if pub.Channel.IsClosed() {
 					break
+				} else if conn.IsClosed() {
+					t.l.Error().Msgf("connection is closed, reconnecting")
+					break
 				}

 				select {
 				case <-ctx.Done():
-					return
+					break
 				case msg := <-t.msgs:
 					go func(msg *msgWithQueue) {
 						body, err := json.Marshal(msg)
@@ -318,6 +358,12 @@ func (t *MessageQueueImpl) startPublishing() func() error {
 					}(msg)
 				}
 			}
+
+			err := pub.Channel.Close()
+
+			if err != nil {
+				t.l.Error().Msgf("cannot close channel: %s, %v", conn.LocalAddr().String(), err)
+			}
 		}
 	}()
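As the two hunks above show, the publisher now breaks out of its inner loop whenever the channel or the underlying connection reports closed, and explicitly closes the channel before waiting on t.sessions for a fresh session — so a dead session no longer leaves the publisher looping with nowhere to publish.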
@@ -342,6 +388,12 @@ func (t *MessageQueueImpl) subscribe(
 		sessionCount++
 		sub := <-session

+		sessionWg := sync.WaitGroup{}
+
+		conn := sub.Connection
+
+		t.l.Debug().Msgf("starting subscriber %s on: %s", subId, conn.LocalAddr().String())
+
 		// we initialize the queue here because exclusive queues are bound to the session/connection. however, it's not clear
 		// if the exclusive queue will be available to the next session.
 		queueName, err := t.initQueue(sub, q)
@@ -365,10 +417,21 @@ func (t *MessageQueueImpl) subscribe(
 			return
 		}

+		closeChannel := func() {
+			sessionWg.Wait()
+
+			err = sub.Channel.Close()
+
+			if err != nil {
+				t.l.Error().Msgf("cannot close channel: %s, %v", conn.LocalAddr().String(), err)
+			}
+		}
+
 	inner:
 		for {
 			select {
 			case <-ctx.Done():
+				closeChannel()
 				return
 			case rabbitMsg, ok := <-deliveries:
 				if !ok {
@@ -377,9 +440,12 @@ func (t *MessageQueueImpl) subscribe(
 				}

 				wg.Add(1)
+				sessionWg.Add(1)

 				go func(rabbitMsg amqp.Delivery) {
 					defer wg.Done()
+					defer sessionWg.Done()
+
 					msg := &msgWithQueue{}

 					if len(rabbitMsg.Body) == 0 {
@@ -451,11 +517,7 @@ func (t *MessageQueueImpl) subscribe(
 			}
 		}

-		err = sub.CloseDeadline(time.Now())
-
-		if err != nil {
-			t.l.Error().Msgf("cannot close session: %s, %v", sub.LocalAddr().String(), err)
-		}
+		go closeChannel()
 	}
 }()
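The sessionWg introduced above implements a wait-then-close rule: Channel.Close may only run after every in-flight handler has returned. A minimal sketch of the same pattern, with handle standing in for the real message handler:

package main

import (
	"context"
	"sync"

	amqp "github.com/rabbitmq/amqp091-go"
)

// consume sketches the wait-then-close pattern above: a per-session
// WaitGroup guards Channel.Close so it only runs after every in-flight
// handler has returned.
func consume(ctx context.Context, ch *amqp.Channel, deliveries <-chan amqp.Delivery, handle func(amqp.Delivery)) {
	var sessionWg sync.WaitGroup

	closeChannel := func() {
		sessionWg.Wait() // drain in-flight handlers first
		_ = ch.Close()
	}

	for {
		select {
		case <-ctx.Done():
			closeChannel()
			return
		case d, ok := <-deliveries:
			if !ok {
				// Session ended: close in the background once the
				// outstanding handlers finish.
				go closeChannel()
				return
			}

			sessionWg.Add(1)

			go func(d amqp.Delivery) {
				defer sessionWg.Done()
				handle(d)
			}(d)
		}
	}
}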
@@ -472,7 +534,7 @@ func (t *MessageQueueImpl) subscribe(
 }

 // redial continually connects to the URL, exiting the program when no longer possible
-func (t *MessageQueueImpl) redial(ctx context.Context, l *zerolog.Logger, url string) chan chan session {
+func (t *MessageQueueImpl) redial(ctx context.Context, l *zerolog.Logger, pool *puddle.Pool[*amqp.Connection]) chan chan session {
 	sessions := make(chan chan session)

 	go func() {
@@ -491,7 +553,7 @@ func (t *MessageQueueImpl) redial(ctx context.Context, l *zerolog.Logger, url st
 			var err error

 			for i := 0; i < MAX_RETRY_COUNT; i++ {
-				newSession, err = getSession(ctx, l, url)
+				newSession, err = getSession(ctx, l, pool)
 				if err == nil {
 					if i > 0 {
 						l.Info().Msgf("re-established session after %d attempts", i)
@@ -506,6 +568,7 @@ func (t *MessageQueueImpl) redial(ctx context.Context, l *zerolog.Logger, url st

 			if err != nil {
 				l.Error().Msgf("failed to get session after %d attempts", MAX_RETRY_COUNT)
+				t.ready = false
 				return
 			}
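The added t.ready = false line means that exhausting MAX_RETRY_COUNT now flips the queue's readiness flag before the redial goroutine exits, rather than exiting with the flag still set.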
@@ -545,18 +608,28 @@ func identity() string {
 	return fmt.Sprintf("%x", h.Sum(nil))
 }

-func getSession(ctx context.Context, l *zerolog.Logger, url string) (session, error) {
-	conn, err := amqp.Dial(url)
+func getSession(ctx context.Context, l *zerolog.Logger, pool *puddle.Pool[*amqp.Connection]) (session, error) {
+	connFromPool, err := pool.Acquire(ctx)

 	if err != nil {
-		l.Error().Msgf("cannot (re)dial: %v: %q", err, url)
+		l.Error().Msgf("cannot acquire connection: %v", err)
 		return session{}, err
 	}

+	conn := connFromPool.Value()
+
 	ch, err := conn.Channel()

 	if err != nil {
+		connFromPool.Destroy()
 		l.Error().Msgf("cannot create channel: %v", err)
 		return session{}, err
 	}

-	return session{conn, ch}, nil
+	connFromPool.Release()
+
+	return session{
+		Channel:    ch,
+		Connection: conn,
+	}, nil
 }
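Note the ordering in the new getSession: the pooled connection is released as soon as the channel is open (the "release pool after opening channel" fix from the commit message), which works because AMQP channels multiplex over a shared connection. A connection whose Channel call fails is destroyed rather than released, so a broken connection is not recycled into the pool.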
@@ -607,18 +607,7 @@ func (ec *JobsControllerImpl) queueStepRun(ctx context.Context, tenantId, stepId

 	// set scheduling timeout
 	if scheduleTimeoutAt := stepRun.StepRun.ScheduleTimeoutAt.Time; scheduleTimeoutAt.IsZero() {
-		var timeoutDuration time.Duration
-
-		// get the schedule timeout from the step
-		stepScheduleTimeout := stepRun.StepScheduleTimeout
-
-		if stepScheduleTimeout != "" {
-			timeoutDuration, _ = time.ParseDuration(stepScheduleTimeout)
-		} else {
-			timeoutDuration = defaults.DefaultScheduleTimeout
-		}
-
-		scheduleTimeoutAt := time.Now().UTC().Add(timeoutDuration)
+		scheduleTimeoutAt = getScheduleTimeout(stepRun)

 		updateStepOpts.ScheduleTimeoutAt = &scheduleTimeoutAt
 	}
@@ -869,10 +858,13 @@ func (ec *JobsControllerImpl) handleStepRunFailed(ctx context.Context, task *msg
 		status = db.StepRunStatusPending
 	}

+	scheduleTimeoutAt := getScheduleTimeout(stepRun)
+
 	stepRun, updateInfo, err := ec.repo.StepRun().UpdateStepRun(ctx, metadata.TenantId, payload.StepRunId, &repository.UpdateStepRunOpts{
-		FinishedAt: &failedAt,
-		Error:      &payload.Error,
-		Status:     repository.StepRunStatusPtr(status),
+		FinishedAt:        &failedAt,
+		Error:             &payload.Error,
+		Status:            repository.StepRunStatusPtr(status),
+		ScheduleTimeoutAt: &scheduleTimeoutAt,
 	})

 	if err != nil {
@@ -957,7 +949,10 @@ func (ec *JobsControllerImpl) cancelStepRun(ctx context.Context, tenantId, stepR
 	defer ec.handleStepRunUpdateInfo(stepRun, updateInfo)

 	if !stepRun.StepRun.WorkerId.Valid {
-		return fmt.Errorf("step run has no worker id")
+		// this is not a fatal error
+		ec.l.Debug().Msgf("step run %s has no worker id, skipping cancellation", stepRunId)
+
+		return nil
 	}

 	workerId := sqlchelpers.UUIDToStr(stepRun.StepRun.WorkerId)
@@ -1055,3 +1050,18 @@ func stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, cancelled
 		Retries: 3,
 	}
 }
+
+func getScheduleTimeout(stepRun *dbsqlc.GetStepRunForEngineRow) time.Time {
+	var timeoutDuration time.Duration
+
+	// get the schedule timeout from the step
+	stepScheduleTimeout := stepRun.StepScheduleTimeout
+
+	if stepScheduleTimeout != "" {
+		timeoutDuration, _ = time.ParseDuration(stepScheduleTimeout)
+	} else {
+		timeoutDuration = defaults.DefaultScheduleTimeout
+	}
+
+	return time.Now().UTC().Add(timeoutDuration)
+}
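To illustrate the extracted helper's fallback behavior, a small sketch; the 5-minute default here is a stand-in for defaults.DefaultScheduleTimeout, which the diff does not show:

package main

import (
	"fmt"
	"time"
)

// scheduleTimeout mirrors getScheduleTimeout above: parse the step's
// configured timeout string, falling back to a default when unset.
func scheduleTimeout(configured string, now time.Time) time.Time {
	d := 5 * time.Minute
	if configured != "" {
		// Like the original, a malformed duration is ignored and
		// yields a zero duration.
		d, _ = time.ParseDuration(configured)
	}
	return now.UTC().Add(d)
}

func main() {
	now := time.Now()
	fmt.Println(scheduleTimeout("90s", now)) // now + 1m30s
	fmt.Println(scheduleTimeout("", now))    // now + 5m (fallback)
}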
@@ -2,6 +2,7 @@ package dispatcher

 import (
 	"context"
+	"errors"
 	"fmt"
 	"sync"
 	"time"
@@ -229,17 +230,25 @@ func (d *DispatcherImpl) Start() (func() error, error) {
 	return cleanup, nil
 }

-func (d *DispatcherImpl) handleTask(ctx context.Context, task *msgqueue.Message) error {
+func (d *DispatcherImpl) handleTask(ctx context.Context, task *msgqueue.Message) (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("recovered from panic: %v", r)
+		}
+	}()
+
 	switch task.ID {
 	case "group-key-action-assigned":
-		return d.a.WrapErr(d.handleGroupKeyActionAssignedTask(ctx, task), map[string]interface{}{})
+		err = d.a.WrapErr(d.handleGroupKeyActionAssignedTask(ctx, task), map[string]interface{}{})
 	case "step-run-assigned":
-		return d.a.WrapErr(d.handleStepRunAssignedTask(ctx, task), map[string]interface{}{})
+		err = d.a.WrapErr(d.handleStepRunAssignedTask(ctx, task), map[string]interface{}{})
 	case "step-run-cancelled":
-		return d.a.WrapErr(d.handleStepRunCancelled(ctx, task), map[string]interface{}{})
+		err = d.a.WrapErr(d.handleStepRunCancelled(ctx, task), map[string]interface{}{})
+	default:
+		err = fmt.Errorf("unknown task: %s", task.ID)
 	}

-	return fmt.Errorf("unknown task: %s", task.ID)
+	return err
 }

 func (d *DispatcherImpl) handleGroupKeyActionAssignedTask(ctx context.Context, task *msgqueue.Message) error {
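The handleTask rewrite uses the named-return-plus-recover idiom: a panic inside any handler surfaces as an ordinary error for the caller instead of crashing the dispatcher. A minimal sketch of the idiom:

package main

import "fmt"

// do runs fn and converts a panic into an ordinary error through the
// named return value, mirroring the recover block added to handleTask.
func do(fn func()) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("recovered from panic: %v", r)
		}
	}()

	fn()
	return nil
}

func main() {
	fmt.Println(do(func() { panic("boom") })) // recovered from panic: boom
	fmt.Println(do(func() {}))                // <nil>
}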
@@ -364,7 +373,7 @@ func (d *DispatcherImpl) handleStepRunCancelled(ctx context.Context, task *msgqu
 	// get the worker for this task
 	w, err := d.GetWorker(payload.WorkerId)

-	if err != nil {
+	if err != nil && !errors.Is(err, ErrWorkerNotFound) {
 		return fmt.Errorf("could not get worker: %w", err)
 	}
@@ -20,10 +20,12 @@ import (
 	"github.com/hatchet-dev/hatchet/internal/telemetry"
 )

+var ErrWorkerNotFound = fmt.Errorf("worker not found")
+
 func (d *DispatcherImpl) GetWorker(workerId string) (*subscribedWorker, error) {
 	workerInt, ok := d.workers.Load(workerId)
 	if !ok {
-		return nil, fmt.Errorf("worker with id %s not found", workerId)
+		return nil, ErrWorkerNotFound
 	}

 	worker, ok := workerInt.(subscribedWorker)
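Swapping the formatted error for the ErrWorkerNotFound sentinel is what lets handleStepRunCancelled (earlier hunk) tolerate a missing worker via errors.Is. A small sketch of the sentinel pattern, with hypothetical names:

package main

import (
	"errors"
	"fmt"
)

// ErrWorkerNotFound is a sentinel error: callers match it with
// errors.Is instead of parsing message strings. (The diff builds it
// with fmt.Errorf; errors.New is equivalent here.)
var ErrWorkerNotFound = errors.New("worker not found")

// getWorker is a hypothetical stand-in for DispatcherImpl.GetWorker.
func getWorker(id string) (string, error) {
	return "", ErrWorkerNotFound // lookup elided in this sketch
}

func main() {
	_, err := getWorker("worker-1")

	if err != nil && !errors.Is(err, ErrWorkerNotFound) {
		fmt.Println("fatal:", err)
		return
	}

	// A missing worker is tolerated: there is nothing to cancel on a
	// worker that has already disconnected.
	fmt.Println("worker not found, skipping cancellation")
}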
@@ -199,16 +201,6 @@ func (s *DispatcherImpl) Listen(request *contracts.WorkerListenRequest, stream c
 		}

 		s.workers.Delete(request.WorkerId)
-
-		inactive := db.WorkerStatusInactive
-
-		_, err := s.repo.Worker().UpdateWorker(tenantId, request.WorkerId, &repository.UpdateWorkerOpts{
-			Status: &inactive,
-		})
-
-		if err != nil {
-			s.l.Error().Err(err).Msgf("could not update worker %s status to inactive", request.WorkerId)
-		}
 	}()

 	ctx := stream.Context()
@@ -420,7 +412,6 @@ func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeT
 }

 func waitFor(wg *sync.WaitGroup, timeout time.Duration, l *zerolog.Logger) {
-
 	done := make(chan struct{})

 	go func() {
@@ -495,10 +486,14 @@ func (s *DispatcherImpl) Unsubscribe(ctx context.Context, request *contracts.Wor
 	// no matter what, remove the worker from the connection pool
 	defer s.workers.Delete(request.WorkerId)

-	err := s.repo.Worker().DeleteWorker(tenantId, request.WorkerId)
+	inactive := db.WorkerStatusInactive
+
+	_, err := s.repo.Worker().UpdateWorker(tenantId, request.WorkerId, &repository.UpdateWorkerOpts{
+		Status: &inactive,
+	})

 	if err != nil {
-		return nil, err
+		s.l.Error().Err(err).Msgf("could not update worker %s status to inactive", request.WorkerId)
 	}

 	return &contracts.WorkerUnsubscribeResponse{