mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-02-04 23:28:47 -06:00
The requeue and reassign logic did not limit the number of step runs to requeue, so when events accumulated with no worker present, memory spiked and query latency on the database became very high. This commit limits the number of step runs returned by the requeue and reassign queries, and also properly locks step run rows for these queries so that only a step run in a PENDING or PENDING_ASSIGNMENT state can be requeued. It also improves the performance of the `AssignStepRunToWorker` query and ensures that `maxRuns` on workers is always respected through the introduction of a `WorkerSemaphore` model, whose value is decremented when a step run is assigned and incremented when a step run reaches a final state. Co-authored-by: Luca Steeb <contact@luca-steeb.com> * Update controller.go --------- Co-authored-by: steebchen <contact@luca-steeb.com>
606 lines
14 KiB
Go
606 lines
14 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
|
|
"github.com/hatchet-dev/hatchet/internal/logger"
|
|
"github.com/hatchet-dev/hatchet/pkg/client"
|
|
"github.com/hatchet-dev/hatchet/pkg/client/types"
|
|
"github.com/hatchet-dev/hatchet/pkg/errors"
|
|
"github.com/hatchet-dev/hatchet/pkg/integrations"
|
|
)
|
|
|
|
// actionFunc is the variadic call signature stored on a registered action:
// it receives the call arguments and returns the raw results as a slice.
type actionFunc func(args ...any) []any
|
|
|
|
// Action is an individual action that can be run by the worker.
|
|
type Action interface {
|
|
// Name returns the name of the action
|
|
Name() string
|
|
|
|
// Run runs the action
|
|
Run(args ...any) []any
|
|
|
|
MethodFn() any
|
|
|
|
ConcurrencyFn() GetWorkflowConcurrencyGroupFn
|
|
|
|
// Service returns the service that the action belongs to
|
|
Service() string
|
|
}
|
|
|
|
type actionImpl struct {
|
|
name string
|
|
run actionFunc
|
|
runConcurrencyAction GetWorkflowConcurrencyGroupFn
|
|
method any
|
|
service string
|
|
}
|
|
|
|
func (j *actionImpl) Name() string {
|
|
return j.name
|
|
}
|
|
|
|
func (j *actionImpl) Run(args ...interface{}) []interface{} {
|
|
return j.run(args...)
|
|
}
|
|
|
|
func (j *actionImpl) MethodFn() any {
|
|
return j.method
|
|
}
|
|
|
|
func (j *actionImpl) ConcurrencyFn() GetWorkflowConcurrencyGroupFn {
|
|
return j.runConcurrencyAction
|
|
}
|
|
|
|
func (j *actionImpl) Service() string {
|
|
return j.service
|
|
}
|
|
|
|
type Worker struct {
|
|
client client.Client
|
|
|
|
name string
|
|
|
|
actions map[string]Action
|
|
|
|
l *zerolog.Logger
|
|
|
|
cancelMap sync.Map
|
|
|
|
cancelConcurrencyMap sync.Map
|
|
|
|
services sync.Map
|
|
|
|
alerter errors.Alerter
|
|
|
|
middlewares *middlewares
|
|
|
|
maxRuns *int
|
|
}
|
|
|
|
// WorkerOpt is a functional option that mutates the WorkerOpts consumed by
// NewWorker.
type WorkerOpt func(*WorkerOpts)
|
|
|
|
type WorkerOpts struct {
|
|
client client.Client
|
|
name string
|
|
l *zerolog.Logger
|
|
|
|
integrations []integrations.Integration
|
|
alerter errors.Alerter
|
|
maxRuns *int
|
|
}
|
|
|
|
func defaultWorkerOpts() *WorkerOpts {
|
|
logger := logger.NewDefaultLogger("worker")
|
|
|
|
return &WorkerOpts{
|
|
name: getHostName(),
|
|
l: &logger,
|
|
integrations: []integrations.Integration{},
|
|
alerter: errors.NoOpAlerter{},
|
|
}
|
|
}
|
|
|
|
func WithLogLevel(lvl string) WorkerOpt {
|
|
return func(opts *WorkerOpts) {
|
|
logger := logger.NewDefaultLogger("worker")
|
|
lvl, err := zerolog.ParseLevel(lvl)
|
|
|
|
if err == nil {
|
|
logger = logger.Level(lvl)
|
|
}
|
|
|
|
opts.l = &logger
|
|
}
|
|
}
|
|
|
|
func WithName(name string) WorkerOpt {
|
|
return func(opts *WorkerOpts) {
|
|
opts.name = name
|
|
}
|
|
}
|
|
|
|
func WithClient(client client.Client) WorkerOpt {
|
|
return func(opts *WorkerOpts) {
|
|
opts.client = client
|
|
}
|
|
}
|
|
|
|
func WithIntegration(integration integrations.Integration) WorkerOpt {
|
|
return func(opts *WorkerOpts) {
|
|
opts.integrations = append(opts.integrations, integration)
|
|
}
|
|
}
|
|
|
|
func WithErrorAlerter(alerter errors.Alerter) WorkerOpt {
|
|
return func(opts *WorkerOpts) {
|
|
opts.alerter = alerter
|
|
}
|
|
}
|
|
|
|
func WithMaxRuns(maxRuns int) WorkerOpt {
|
|
return func(opts *WorkerOpts) {
|
|
opts.maxRuns = &maxRuns
|
|
}
|
|
}
|
|
|
|
// NewWorker creates a new worker instance
|
|
func NewWorker(fs ...WorkerOpt) (*Worker, error) {
|
|
opts := defaultWorkerOpts()
|
|
|
|
for _, f := range fs {
|
|
f(opts)
|
|
}
|
|
|
|
mws := newMiddlewares()
|
|
|
|
mws.add(panicMiddleware)
|
|
|
|
w := &Worker{
|
|
client: opts.client,
|
|
name: opts.name,
|
|
l: opts.l,
|
|
actions: map[string]Action{},
|
|
alerter: opts.alerter,
|
|
middlewares: mws,
|
|
maxRuns: opts.maxRuns,
|
|
}
|
|
|
|
// register all integrations
|
|
for _, integration := range opts.integrations {
|
|
actions := integration.Actions()
|
|
integrationId := integration.GetId()
|
|
|
|
for _, integrationAction := range actions {
|
|
action := fmt.Sprintf("%s:%s", integrationId, integrationAction)
|
|
|
|
err := w.registerAction(integrationId, action, integration.ActionHandler(integrationAction))
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not register integration action %s: %w", action, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
w.NewService("default")
|
|
|
|
return w, nil
|
|
}
|
|
|
|
func (w *Worker) Use(mws ...MiddlewareFunc) {
|
|
w.middlewares.add(mws...)
|
|
}
|
|
|
|
func (w *Worker) NewService(name string) *Service {
|
|
svc := &Service{
|
|
Name: name,
|
|
worker: w,
|
|
mws: newMiddlewares(),
|
|
}
|
|
|
|
w.services.Store(name, svc)
|
|
|
|
return svc
|
|
}
|
|
|
|
func (w *Worker) On(t triggerConverter, workflow workflowConverter) error {
|
|
// get the default service
|
|
svc, ok := w.services.Load("default")
|
|
|
|
if !ok {
|
|
return fmt.Errorf("could not load default service")
|
|
}
|
|
|
|
return svc.(*Service).On(t, workflow)
|
|
}
|
|
|
|
// RegisterAction can be used to register a single action which can be reused across multiple workflows.
|
|
//
|
|
// An action should be of the format <service>:<verb>, for example slack:create-channel.
|
|
//
|
|
// The method must match the following signatures:
|
|
// - func(ctx context.Context) error
|
|
// - func(ctx context.Context, input *Input) error
|
|
// - func(ctx context.Context, input *Input) (*Output, error)
|
|
// - func(ctx context.Context) (*Output, error)
|
|
func (w *Worker) RegisterAction(actionId string, method any) error {
|
|
// parse the action
|
|
action, err := types.ParseActionID(actionId)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not parse action id: %w", err)
|
|
}
|
|
|
|
return w.registerAction(action.Service, action.Verb, method)
|
|
}
|
|
|
|
func (w *Worker) registerAction(service, verb string, method any) error {
|
|
actionId := fmt.Sprintf("%s:%s", service, verb)
|
|
|
|
// if the service is "concurrency", then this is a special action
|
|
if service == "concurrency" {
|
|
w.actions[actionId] = &actionImpl{
|
|
name: actionId,
|
|
runConcurrencyAction: method.(GetWorkflowConcurrencyGroupFn),
|
|
method: method,
|
|
service: service,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
actionFunc, err := getFnFromMethod(method)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not get function from method: %w", err)
|
|
}
|
|
|
|
// if action has already been registered, ensure that the method is the same
|
|
if currMethod, ok := w.actions[actionId]; ok {
|
|
if reflect.ValueOf(currMethod.MethodFn()).Pointer() != reflect.ValueOf(method).Pointer() {
|
|
return fmt.Errorf("action %s is already registered with function %s", actionId, getFnName(currMethod.MethodFn()))
|
|
}
|
|
}
|
|
|
|
w.actions[actionId] = &actionImpl{
|
|
name: actionId,
|
|
run: actionFunc,
|
|
method: method,
|
|
service: service,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Start starts the worker in blocking fashion
|
|
func (w *Worker) Start() (func() error, error) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
actionNames := []string{}
|
|
|
|
for _, action := range w.actions {
|
|
actionNames = append(actionNames, action.Name())
|
|
}
|
|
|
|
listener, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
|
|
WorkerName: w.name,
|
|
Actions: actionNames,
|
|
MaxRuns: w.maxRuns,
|
|
})
|
|
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("could not get action listener: %w", err)
|
|
}
|
|
|
|
actionCh, err := listener.Actions(ctx)
|
|
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("could not get action channel: %w", err)
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
select {
|
|
case action := <-actionCh:
|
|
go func(action *client.Action) {
|
|
err := w.executeAction(context.Background(), action)
|
|
|
|
if err != nil {
|
|
w.l.Error().Err(err).Msgf("could not execute action: %s", action.ActionId)
|
|
}
|
|
|
|
w.l.Debug().Msgf("action %s completed", action.ActionId)
|
|
}(action)
|
|
case <-ctx.Done():
|
|
w.l.Debug().Msgf("worker %s received context done, stopping", w.name)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
cleanup := func() error {
|
|
cancel()
|
|
|
|
w.l.Debug().Msgf("worker %s is stopping...", w.name)
|
|
|
|
err := listener.Unregister()
|
|
if err != nil {
|
|
return fmt.Errorf("could not unregister worker: %w", err)
|
|
}
|
|
|
|
w.l.Debug().Msgf("worker %s stopped", w.name)
|
|
|
|
return nil
|
|
}
|
|
|
|
return cleanup, nil
|
|
}
|
|
|
|
func (w *Worker) executeAction(ctx context.Context, assignedAction *client.Action) error {
|
|
switch assignedAction.ActionType {
|
|
case client.ActionTypeStartStepRun:
|
|
return w.startStepRun(ctx, assignedAction)
|
|
case client.ActionTypeCancelStepRun:
|
|
return w.cancelStepRun(ctx, assignedAction)
|
|
case client.ActionTypeStartGetGroupKey:
|
|
return w.startGetGroupKey(ctx, assignedAction)
|
|
default:
|
|
return fmt.Errorf("unknown action type: %s", assignedAction.ActionType)
|
|
}
|
|
}
|
|
|
|
func (w *Worker) startStepRun(ctx context.Context, assignedAction *client.Action) error {
|
|
// send a message that the step run started
|
|
_, err := w.client.Dispatcher().SendStepActionEvent(
|
|
ctx,
|
|
w.getActionEvent(assignedAction, client.ActionEventTypeStarted),
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not send action event: %w", err)
|
|
}
|
|
|
|
action, ok := w.actions[assignedAction.ActionId]
|
|
|
|
if !ok {
|
|
return fmt.Errorf("job not found")
|
|
}
|
|
|
|
arg, err := decodeArgsToInterface(reflect.TypeOf(action.MethodFn()))
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not decode args to interface: %w", err)
|
|
}
|
|
|
|
runContext, cancel := context.WithCancel(context.Background())
|
|
|
|
w.cancelMap.Store(assignedAction.StepRunId, cancel)
|
|
|
|
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not create hatchet context: %w", err)
|
|
}
|
|
|
|
// get the action's service
|
|
svcAny, ok := w.services.Load(action.Service())
|
|
|
|
if !ok {
|
|
return fmt.Errorf("could not load service %s", action.Service())
|
|
}
|
|
|
|
svc := svcAny.(*Service)
|
|
|
|
// wrap the run with middleware. start by wrapping the global worker middleware, then
|
|
// the service-specific middleware
|
|
return w.middlewares.runAll(hCtx, func(ctx HatchetContext) error {
|
|
return svc.mws.runAll(ctx, func(ctx HatchetContext) error {
|
|
args := []any{ctx}
|
|
|
|
if arg != nil {
|
|
args = append(args, arg)
|
|
}
|
|
|
|
runResults := action.Run(args...)
|
|
|
|
// check whether run context was cancelled while action was running
|
|
select {
|
|
case <-ctx.Done():
|
|
w.l.Debug().Msgf("step run %s was cancelled, returning", assignedAction.StepRunId)
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
var result any
|
|
|
|
if len(runResults) == 2 {
|
|
result = runResults[0]
|
|
}
|
|
|
|
if runResults[len(runResults)-1] != nil {
|
|
err = runResults[len(runResults)-1].(error)
|
|
}
|
|
|
|
if err != nil {
|
|
failureEvent := w.getActionEvent(assignedAction, client.ActionEventTypeFailed)
|
|
|
|
w.alerter.SendAlert(context.Background(), err, map[string]interface{}{
|
|
"actionId": assignedAction.ActionId,
|
|
"workerId": assignedAction.WorkerId,
|
|
"stepRunId": assignedAction.StepRunId,
|
|
"jobName": assignedAction.JobName,
|
|
"actionType": assignedAction.ActionType,
|
|
})
|
|
|
|
failureEvent.EventPayload = err.Error()
|
|
|
|
_, err := w.client.Dispatcher().SendStepActionEvent(
|
|
ctx,
|
|
failureEvent,
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not send action event: %w", err)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// send a message that the step run completed
|
|
finishedEvent, err := w.getActionFinishedEvent(assignedAction, result)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not create finished event: %w", err)
|
|
}
|
|
|
|
_, err = w.client.Dispatcher().SendStepActionEvent(
|
|
ctx,
|
|
finishedEvent,
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not send action event: %w", err)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
})
|
|
}
|
|
|
|
func (w *Worker) startGetGroupKey(ctx context.Context, assignedAction *client.Action) error {
|
|
// send a message that the step run started
|
|
_, err := w.client.Dispatcher().SendGroupKeyActionEvent(
|
|
ctx,
|
|
w.getActionEvent(assignedAction, client.ActionEventTypeStarted),
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not send action event: %w", err)
|
|
}
|
|
|
|
action, ok := w.actions[assignedAction.ActionId]
|
|
|
|
if !ok {
|
|
return fmt.Errorf("job not found")
|
|
}
|
|
|
|
// action should be concurrency action
|
|
if action.ConcurrencyFn() == nil {
|
|
return fmt.Errorf("action %s is not a concurrency action", action.Name())
|
|
}
|
|
|
|
runContext, cancel := context.WithCancel(context.Background())
|
|
|
|
w.cancelConcurrencyMap.Store(assignedAction.WorkflowRunId, cancel)
|
|
|
|
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not create hatchet context: %w", err)
|
|
}
|
|
|
|
concurrencyKey, err := action.ConcurrencyFn()(hCtx)
|
|
|
|
if err != nil {
|
|
failureEvent := w.getActionEvent(assignedAction, client.ActionEventTypeFailed)
|
|
|
|
w.alerter.SendAlert(context.Background(), err, map[string]interface{}{
|
|
"actionId": assignedAction.ActionId,
|
|
"workerId": assignedAction.WorkerId,
|
|
"workflowRunId": assignedAction.WorkflowRunId,
|
|
"jobName": assignedAction.JobName,
|
|
"actionType": assignedAction.ActionType,
|
|
})
|
|
|
|
failureEvent.EventPayload = err.Error()
|
|
|
|
_, err := w.client.Dispatcher().SendGroupKeyActionEvent(
|
|
ctx,
|
|
failureEvent,
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not send action event: %w", err)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// send a message that the step run completed
|
|
finishedEvent, err := w.getGroupKeyActionFinishedEvent(assignedAction, concurrencyKey)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not create finished event: %w", err)
|
|
}
|
|
|
|
_, err = w.client.Dispatcher().SendGroupKeyActionEvent(
|
|
ctx,
|
|
finishedEvent,
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("could not send action event: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *Worker) cancelStepRun(ctx context.Context, assignedAction *client.Action) error {
|
|
cancel, ok := w.cancelMap.Load(assignedAction.StepRunId)
|
|
|
|
if !ok {
|
|
return fmt.Errorf("could not find step run to cancel")
|
|
}
|
|
|
|
w.l.Debug().Msgf("cancelling step run %s", assignedAction.StepRunId)
|
|
|
|
cancelFn := cancel.(context.CancelFunc)
|
|
|
|
cancelFn()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *Worker) getActionEvent(action *client.Action, eventType client.ActionEventType) *client.ActionEvent {
|
|
timestamp := time.Now().UTC()
|
|
|
|
return &client.ActionEvent{
|
|
Action: action,
|
|
EventTimestamp: ×tamp,
|
|
EventType: eventType,
|
|
}
|
|
}
|
|
|
|
func (w *Worker) getActionFinishedEvent(action *client.Action, output any) (*client.ActionEvent, error) {
|
|
event := w.getActionEvent(action, client.ActionEventTypeCompleted)
|
|
|
|
event.EventPayload = output
|
|
|
|
return event, nil
|
|
}
|
|
|
|
func (w *Worker) getGroupKeyActionFinishedEvent(action *client.Action, output string) (*client.ActionEvent, error) {
|
|
event := w.getActionEvent(action, client.ActionEventTypeCompleted)
|
|
|
|
event.EventPayload = output
|
|
|
|
return event, nil
|
|
}
|
|
|
|
// getHostName returns the host name reported by the operating system, or
// "Unknown" when it cannot be determined.
func getHostName() string {
	if name, err := os.Hostname(); err == nil {
		return name
	}

	return "Unknown"
}
|