Files
hatchet/internal/services/controllers/workflows/controller.go
T
Gabe Ruttner b7cec9ec53 feat: soft delete (#717)
* feat: soft delete workflows and versions

* feat: filter soft deletes wf and wfr

* feat: filter events and step runs

* fix: query

* fix: query

* chore: generate

* wip

* chore: squash migrations

* chore: separate retention into new service

* feat: regularly clean up

* chore: migrations

* fix: tests

* fix: queries

* fix: ambiguous

* fix: refs

* fix: ambiguous id

* fix: remove update from

* fix: soft delete

* fix: cleanup retention scheduler

* fix: has more query

* chore: gen

* fix: query

* fix: table
2024-07-18 09:06:05 -04:00

498 lines
14 KiB
Go

package workflows
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-co-op/gocron/v2"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"github.com/hatchet-dev/hatchet/internal/datautils"
"github.com/hatchet-dev/hatchet/internal/integrations/alerting"
"github.com/hatchet-dev/hatchet/internal/msgqueue"
"github.com/hatchet-dev/hatchet/internal/services/shared/recoveryutils"
"github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes"
"github.com/hatchet-dev/hatchet/internal/telemetry"
hatcheterrors "github.com/hatchet-dev/hatchet/pkg/errors"
"github.com/hatchet-dev/hatchet/pkg/logger"
"github.com/hatchet-dev/hatchet/pkg/repository"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/db"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers"
)
// WorkflowsController describes a workflows controller that is started with a
// context.
//
// NOTE(review): WorkflowsControllerImpl.Start in this file is declared as
// Start() (func() error, error), which differs from the method below, so the
// Impl does not appear to satisfy this interface as written — confirm whether
// this interface is still in use.
type WorkflowsController interface {
// Start takes a context and returns an error.
Start(ctx context.Context) error
}
// WorkflowsControllerImpl holds the dependencies of the workflows controller.
// Instances are built by New.
type WorkflowsControllerImpl struct {
// mq is the message queue the controller subscribes to in Start.
mq msgqueue.MessageQueue
// l is the controller's logger; New tags it with service=workflows-controller.
l *zerolog.Logger
// repo is the engine repository used for database reads and updates.
repo repository.EngineRepository
// dv decodes and validates task payloads and metadata.
dv datautils.DataDecoderValidator
// s is the gocron scheduler that runs the periodic jobs registered in Start.
s gocron.Scheduler
// tenantAlerter is the value supplied through WithTenantAlerter.
tenantAlerter *alerting.TenantAlertManager
// a wraps the configured alerter; handleTask passes it to
// recoveryutils.RecoverWithAlert when a task handler panics.
a *hatcheterrors.Wrapped
// partitionId is the value supplied through WithPartitionId.
partitionId string
}
// WorkflowsControllerOpt is a functional option that mutates a
// WorkflowsControllerOpts. New applies every option it is given, in order, on
// top of the defaults.
type WorkflowsControllerOpt func(*WorkflowsControllerOpts)
// WorkflowsControllerOpts collects the settings consumed by New. The fields are
// unexported and are populated through the With* option functions.
type WorkflowsControllerOpts struct {
// mq is required; set with WithMessageQueue.
mq msgqueue.MessageQueue
// l defaults to a "workflows-controller" logger; override with WithLogger.
l *zerolog.Logger
// repo is required; set with WithRepository.
repo repository.EngineRepository
// dv defaults to datautils.NewDataDecoderValidator(); override with
// WithDataDecoderValidator.
dv datautils.DataDecoderValidator
// ta is required; set with WithTenantAlerter.
ta *alerting.TenantAlertManager
// alerter defaults to a no-op alerter; override with WithAlerter.
alerter hatcheterrors.Alerter
// partitionId is required (must be non-empty); set with WithPartitionId.
partitionId string
}
// defaultWorkflowsControllerOpts returns the baseline options used by New
// before any caller-supplied option is applied: a default logger named
// "workflows-controller", a fresh data decoder/validator and a no-op alerter.
// All other fields are left at their zero value.
func defaultWorkflowsControllerOpts() *WorkflowsControllerOpts {
	defaultLogger := logger.NewDefaultLogger("workflows-controller")

	opts := &WorkflowsControllerOpts{
		dv:      datautils.NewDataDecoderValidator(),
		alerter: hatcheterrors.NoOpAlerter{},
	}
	opts.l = &defaultLogger

	return opts
}
// WithMessageQueue returns an option that sets the message queue used by the
// controller. New returns an error when no message queue has been provided.
func WithMessageQueue(mq msgqueue.MessageQueue) WorkflowsControllerOpt {
	return func(o *WorkflowsControllerOpts) { o.mq = mq }
}
// WithLogger returns an option that replaces the default logger with l.
func WithLogger(l *zerolog.Logger) WorkflowsControllerOpt {
	return func(o *WorkflowsControllerOpts) { o.l = l }
}
// WithRepository returns an option that sets the engine repository. New returns
// an error when no repository has been provided.
func WithRepository(r repository.EngineRepository) WorkflowsControllerOpt {
	return func(o *WorkflowsControllerOpts) { o.repo = r }
}
// WithAlerter returns an option that replaces the default no-op alerter with a.
func WithAlerter(a hatcheterrors.Alerter) WorkflowsControllerOpt {
	return func(o *WorkflowsControllerOpts) { o.alerter = a }
}
// WithDataDecoderValidator returns an option that replaces the default data
// decoder/validator with dv.
func WithDataDecoderValidator(dv datautils.DataDecoderValidator) WorkflowsControllerOpt {
	return func(o *WorkflowsControllerOpts) { o.dv = dv }
}
// WithTenantAlerter returns an option that sets the tenant alert manager. New
// returns an error when no tenant alerter has been provided.
func WithTenantAlerter(ta *alerting.TenantAlertManager) WorkflowsControllerOpt {
	return func(o *WorkflowsControllerOpts) { o.ta = ta }
}
// WithPartitionId returns an option that sets the partition ID. New returns an
// error when the partition ID is empty.
func WithPartitionId(partitionId string) WorkflowsControllerOpt {
	return func(o *WorkflowsControllerOpts) { o.partitionId = partitionId }
}
// New builds a WorkflowsControllerImpl from the default options overlaid with
// fs, applied in order.
//
// It returns an error when the message queue, repository, tenant alerter or
// partition ID is missing, or when the UTC scheduler cannot be created. The
// logger and the wrapped alerter are both tagged with
// service=workflows-controller.
func New(fs ...WorkflowsControllerOpt) (*WorkflowsControllerImpl, error) {
	opts := defaultWorkflowsControllerOpts()

	for i := range fs {
		fs[i](opts)
	}

	// Required options are checked in a fixed order so the first missing one
	// is the one reported.
	switch {
	case opts.mq == nil:
		return nil, fmt.Errorf("task queue is required. use WithMessageQueue")
	case opts.repo == nil:
		return nil, fmt.Errorf("repository is required. use WithRepository")
	case opts.ta == nil:
		return nil, fmt.Errorf("tenant alerter is required. use WithTenantAlerter")
	case opts.partitionId == "":
		return nil, fmt.Errorf("partition ID is required. use WithPartitionId")
	}

	scheduler, err := gocron.NewScheduler(gocron.WithLocation(time.UTC))
	if err != nil {
		return nil, fmt.Errorf("could not create scheduler: %w", err)
	}

	scopedLogger := opts.l.With().Str("service", "workflows-controller").Logger()

	wrappedAlerter := hatcheterrors.NewWrapped(opts.alerter)
	wrappedAlerter.WithData(map[string]interface{}{"service": "workflows-controller"})

	controller := &WorkflowsControllerImpl{
		mq:            opts.mq,
		l:             &scopedLogger,
		repo:          opts.repo,
		dv:            opts.dv,
		s:             scheduler,
		tenantAlerter: opts.ta,
		a:             wrappedAlerter,
		partitionId:   opts.partitionId,
	}

	return controller, nil
}
// Start registers the controller's periodic jobs (group key run requeue and
// reassign every 5 seconds, active queue polling every 15 seconds), starts the
// scheduler and subscribes to the workflow processing queue.
//
// On success it returns a cleanup function that cancels the context handed to
// the scheduled jobs, unsubscribes from the queue, waits for in-flight task
// handlers to return and then shuts the scheduler down. On failure it returns
// a nil cleanup function and the error.
func (wc *WorkflowsControllerImpl) Start() (func() error, error) {
	wc.l.Debug().Msg("starting workflows controller")

	ctx, cancel := context.WithCancel(context.Background())

	wg := sync.WaitGroup{}

	_, err := wc.s.NewJob(
		gocron.DurationJob(time.Second*5),
		gocron.NewTask(
			wc.runGetGroupKeyRunRequeue(ctx),
		),
	)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("could not schedule get group key run requeue: %w", err)
	}

	_, err = wc.s.NewJob(
		gocron.DurationJob(time.Second*5),
		gocron.NewTask(
			wc.runGetGroupKeyRunReassign(ctx),
		),
	)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("could not schedule get group key run reassign: %w", err)
	}

	_, err = wc.s.NewJob(
		gocron.DurationJob(time.Second*15),
		gocron.NewTask(
			wc.runPollActiveQueues(ctx),
		),
	)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("could not poll active queues: %w", err)
	}

	wc.s.Start()

	// Each message is handled with a fresh background context, so in-flight
	// handlers are not interrupted by cleanup; the wait group lets cleanup
	// block until they have returned.
	f := func(task *msgqueue.Message) error {
		wg.Add(1)
		defer wg.Done()

		err := wc.handleTask(context.Background(), task)
		if err != nil {
			wc.l.Error().Err(err).Msg("could not handle job task")
			return err
		}

		return nil
	}

	cleanupQueue, err := wc.mq.Subscribe(msgqueue.WORKFLOW_PROCESSING_QUEUE, f, msgqueue.NoOpHook)
	if err != nil {
		cancel()

		// The scheduler is already running at this point. Stop it so its jobs
		// do not keep firing with a cancelled context after a failed start.
		// NOTE(review): a scheduler that has been shut down presumably cannot
		// be started again, so the controller should be discarded after this
		// error — confirm against the gocron documentation.
		if shutdownErr := wc.s.Shutdown(); shutdownErr != nil {
			wc.l.Error().Err(shutdownErr).Msg("could not shutdown scheduler")
		}

		return nil, err
	}

	cleanup := func() error {
		cancel()

		if err := cleanupQueue(); err != nil {
			return fmt.Errorf("could not cleanup queue: %w", err)
		}

		wg.Wait()

		if err := wc.s.Shutdown(); err != nil {
			return fmt.Errorf("could not shutdown scheduler: %w", err)
		}

		return nil
	}

	return cleanup, nil
}
// handleTask dispatches a queue message to the handler matching task.ID and
// returns that handler's error. An unrecognised ID yields an "unknown task"
// error.
//
// A panic raised while handling the task is recovered and passed to
// recoveryutils.RecoverWithAlert; if that returns a non-nil error it becomes
// the return value of handleTask.
func (wc *WorkflowsControllerImpl) handleTask(ctx context.Context, task *msgqueue.Message) (err error) {
	defer func() {
		// recover must be called directly from the deferred function.
		recovered := recover()
		if recovered == nil {
			return
		}

		if alertErr := recoveryutils.RecoverWithAlert(wc.l, wc.a, recovered); alertErr != nil {
			err = alertErr
		}
	}()

	switch task.ID {
	case "workflow-run-queued":
		return wc.handleWorkflowRunQueued(ctx, task)
	case "get-group-key-run-started":
		return wc.handleGroupKeyRunStarted(ctx, task)
	case "get-group-key-run-finished":
		return wc.handleGroupKeyRunFinished(ctx, task)
	case "get-group-key-run-failed":
		return wc.handleGroupKeyRunFailed(ctx, task)
	case "get-group-key-run-timed-out":
		return wc.handleGetGroupKeyRunTimedOut(ctx, task)
	case "workflow-run-finished":
		return wc.handleWorkflowRunFinished(ctx, task)
	default:
		return fmt.Errorf("unknown task: %s", task.ID)
	}
}
// handleGroupKeyRunStarted processes a "get-group-key-run-started" message. It
// decodes and validates the task payload and metadata, parses the RFC 3339
// start time and marks the get group key run as running with that start time.
// Any decode, parse or update error is returned.
func (wc *WorkflowsControllerImpl) handleGroupKeyRunStarted(ctx context.Context, task *msgqueue.Message) error {
	ctx, span := telemetry.NewSpan(ctx, "get-group-key-run-started") // nolint:ineffassign
	defer span.End()

	payload := tasktypes.GetGroupKeyRunStartedTaskPayload{}
	if err := wc.dv.DecodeAndValidate(task.Payload, &payload); err != nil {
		return fmt.Errorf("could not decode group key run started task payload: %w", err)
	}

	metadata := tasktypes.GetGroupKeyRunStartedTaskMetadata{}
	if err := wc.dv.DecodeAndValidate(task.Metadata, &metadata); err != nil {
		return fmt.Errorf("could not decode group key run started task metadata: %w", err)
	}

	startedAt, err := time.Parse(time.RFC3339, payload.StartedAt)
	if err != nil {
		return fmt.Errorf("could not parse started at: %w", err)
	}

	// record the start time and running status on the get group key run
	updateOpts := &repository.UpdateGetGroupKeyRunOpts{
		StartedAt: &startedAt,
		Status:    repository.StepRunStatusPtr(db.StepRunStatusRunning),
	}

	_, err = wc.repo.GetGroupKeyRun().UpdateGetGroupKeyRun(ctx, metadata.TenantId, payload.GetGroupKeyRunId, updateOpts)

	return err
}
// handleGroupKeyRunFinished processes a "get-group-key-run-finished" message.
// It decodes and validates the task payload and metadata, parses the RFC 3339
// finish time, marks the get group key run as succeeded with the group key as
// its output, and then bumps the queue for the run's workflow version and
// group key.
//
// Any decode, parse, update or bump error is returned.
func (wc *WorkflowsControllerImpl) handleGroupKeyRunFinished(ctx context.Context, task *msgqueue.Message) error {
	ctx, span := telemetry.NewSpan(ctx, "handle-group-key-run-finished")
	defer span.End()

	payload := tasktypes.GetGroupKeyRunFinishedTaskPayload{}
	metadata := tasktypes.GetGroupKeyRunFinishedTaskMetadata{}

	err := wc.dv.DecodeAndValidate(task.Payload, &payload)
	if err != nil {
		return fmt.Errorf("could not decode group key run finished task payload: %w", err)
	}

	err = wc.dv.DecodeAndValidate(task.Metadata, &metadata)
	if err != nil {
		return fmt.Errorf("could not decode group key run finished task metadata: %w", err)
	}

	// update the group key run in the database
	finishedAt, err := time.Parse(time.RFC3339, payload.FinishedAt)
	if err != nil {
		return fmt.Errorf("could not parse finished at: %w", err)
	}

	groupKeyRun, err := wc.repo.GetGroupKeyRun().UpdateGetGroupKeyRun(ctx, metadata.TenantId, payload.GetGroupKeyRunId, &repository.UpdateGetGroupKeyRunOpts{
		FinishedAt: &finishedAt,
		Status:     repository.StepRunStatusPtr(db.StepRunStatusSucceeded),
		Output:     &payload.GroupKey,
	})
	if err != nil {
		return fmt.Errorf("could not update group key run: %w", err)
	}

	// Bump the queue synchronously on this goroutine. There is only one unit
	// of work, and running it here keeps any panic within reach of the
	// recover in handleTask.
	workflowVersionId := sqlchelpers.UUIDToStr(groupKeyRun.WorkflowVersionId)

	return wc.bumpQueue(ctx, metadata.TenantId, workflowVersionId, &payload.GroupKey)
}
// runPollActiveQueues returns the function scheduled by Start to poll active
// queues. Each invocation lists the active queued workflow versions and bumps
// the queue for each of them concurrently, passing the concurrency group ID as
// the group key when it is set.
//
// Errors are logged rather than returned, because the scheduler invokes the
// function without a caller to report to.
func (wc *WorkflowsControllerImpl) runPollActiveQueues(ctx context.Context) func() {
	return func() {
		wc.l.Debug().Msg("polling active queues")

		toQueueList, err := wc.repo.WorkflowRun().ListActiveQueuedWorkflowVersions(ctx)
		if err != nil {
			wc.l.Error().Err(err).Msg("could not list active queued workflow versions")
			return
		}

		errGroup := new(errgroup.Group)

		for i := range toQueueList {
			// Copy the element so each goroutine works on its own value.
			toQueue := toQueueList[i]

			errGroup.Go(func() error {
				workflowVersionId := sqlchelpers.UUIDToStr(toQueue.WorkflowVersionId)
				tenantId := sqlchelpers.UUIDToStr(toQueue.TenantId)

				// key stays nil when the row has no concurrency group ID
				var key *string
				if toQueue.ConcurrencyGroupId.Valid {
					key = &toQueue.ConcurrencyGroupId.String
				}

				return wc.bumpQueue(ctx, tenantId, workflowVersionId, key)
			})
		}

		// Wait returns the first non-nil error from the group. The log event
		// is only emitted once Msg is called.
		if err := errGroup.Wait(); err != nil {
			wc.l.Error().Err(err).Msg("could not bump active queues")
		}
	}
}
// bumpQueue loads the workflow version and, when it has a valid concurrency
// limit strategy, queues runs according to that strategy.
//
// It returns nil when the workflow version has no valid concurrency limit
// strategy. The cancel-in-progress strategy requires a non-nil groupKey; the
// group-round-robin strategy ignores it. Any other strategy produces an
// "unimplemented" error.
func (wc *WorkflowsControllerImpl) bumpQueue(ctx context.Context, tenantId string, workflowVersionId string, groupKey *string) error {
	workflowVersion, err := wc.repo.Workflow().GetWorkflowVersionById(ctx, tenantId, workflowVersionId)
	if err != nil {
		return fmt.Errorf("could not get workflow version: %w", err)
	}

	// nothing to do without a concurrency limit strategy
	if !workflowVersion.ConcurrencyLimitStrategy.Valid {
		return nil
	}

	switch workflowVersion.ConcurrencyLimitStrategy.ConcurrencyLimitStrategy {
	case dbsqlc.ConcurrencyLimitStrategyCANCELINPROGRESS:
		if groupKey == nil {
			return fmt.Errorf("group key is required for cancel in progress strategy")
		}

		return wc.queueByCancelInProgress(ctx, tenantId, *groupKey, workflowVersion)
	case dbsqlc.ConcurrencyLimitStrategyGROUPROUNDROBIN:
		return wc.queueByGroupRoundRobin(ctx, tenantId, workflowVersion)
	default:
		return fmt.Errorf("unimplemented concurrency limit strategy: %s", workflowVersion.ConcurrencyLimitStrategy.ConcurrencyLimitStrategy)
	}
}
// handleGroupKeyRunFailed processes a "get-group-key-run-failed" message. It
// decodes and validates the task payload and metadata, parses the RFC 3339
// failure time and marks the get group key run as failed, recording the
// failure time as its finish time together with the reported error.
//
// Any decode, parse or update error is returned.
func (wc *WorkflowsControllerImpl) handleGroupKeyRunFailed(ctx context.Context, task *msgqueue.Message) error {
	ctx, span := telemetry.NewSpan(ctx, "handle-group-key-run-failed") // nolint: ineffassign
	defer span.End()

	payload := tasktypes.GetGroupKeyRunFailedTaskPayload{}
	metadata := tasktypes.GetGroupKeyRunFailedTaskMetadata{}

	err := wc.dv.DecodeAndValidate(task.Payload, &payload)
	if err != nil {
		return fmt.Errorf("could not decode group key run failed task payload: %w", err)
	}

	err = wc.dv.DecodeAndValidate(task.Metadata, &metadata)
	if err != nil {
		return fmt.Errorf("could not decode group key run failed task metadata: %w", err)
	}

	// update the group key run in the database
	failedAt, err := time.Parse(time.RFC3339, payload.FailedAt)
	if err != nil {
		return fmt.Errorf("could not parse failed at: %w", err)
	}

	_, err = wc.repo.GetGroupKeyRun().UpdateGetGroupKeyRun(ctx, metadata.TenantId, payload.GetGroupKeyRunId, &repository.UpdateGetGroupKeyRunOpts{
		FinishedAt: &failedAt,
		Error:      &payload.Error,
		Status:     repository.StepRunStatusPtr(db.StepRunStatusFailed),
	})
	if err != nil {
		return fmt.Errorf("could not update get group key run: %w", err)
	}

	return nil
}
// handleGetGroupKeyRunTimedOut processes a "get-group-key-run-timed-out"
// message. It decodes and validates the task payload and metadata, then
// cancels the get group key run with the reason "TIMED_OUT".
func (wc *WorkflowsControllerImpl) handleGetGroupKeyRunTimedOut(ctx context.Context, task *msgqueue.Message) error {
	ctx, span := telemetry.NewSpan(ctx, "handle-get-group-key-run-timed-out")
	defer span.End()

	payload := tasktypes.GetGroupKeyRunTimedOutTaskPayload{}
	if err := wc.dv.DecodeAndValidate(task.Payload, &payload); err != nil {
		return fmt.Errorf("could not decode get group key run run timed out task payload: %w", err)
	}

	metadata := tasktypes.GetGroupKeyRunTimedOutTaskMetadata{}
	if err := wc.dv.DecodeAndValidate(task.Metadata, &metadata); err != nil {
		return fmt.Errorf("could not decode get group key run run timed out task metadata: %w", err)
	}

	const reason = "TIMED_OUT"

	return wc.cancelGetGroupKeyRun(ctx, metadata.TenantId, payload.GetGroupKeyRunId, reason)
}
// cancelGetGroupKeyRun marks the get group key run as cancelled, recording the
// current UTC time and the given reason, then loads the workflow run it
// belongs to and cancels that workflow run's jobs.
//
// It returns an error if the update, the workflow run lookup or the job
// cancellation fails.
func (wc *WorkflowsControllerImpl) cancelGetGroupKeyRun(ctx context.Context, tenantId, getGroupKeyRunId, reason string) error {
	ctx, span := telemetry.NewSpan(ctx, "cancel-get-group-key-run") // nolint: ineffassign
	defer span.End()

	// cancel the get group key run
	now := time.Now().UTC()

	groupKeyRun, err := wc.repo.GetGroupKeyRun().UpdateGetGroupKeyRun(ctx, tenantId, getGroupKeyRunId, &repository.UpdateGetGroupKeyRunOpts{
		CancelledAt:     &now,
		CancelledReason: repository.StringPtr(reason),
		Status:          repository.StepRunStatusPtr(db.StepRunStatusCancelled),
	})
	if err != nil {
		return fmt.Errorf("could not update get group key run: %w", err)
	}

	// cancel all existing jobs on the workflow run
	workflowRunId := sqlchelpers.UUIDToStr(groupKeyRun.WorkflowRunId)

	workflowRun, err := wc.repo.WorkflowRun().GetWorkflowRunById(ctx, tenantId, workflowRunId)
	if err != nil {
		return fmt.Errorf("could not get workflow run: %w", err)
	}

	return wc.cancelWorkflowRunJobs(ctx, workflowRun)
}