// Source: hatchet/pkg/repository/v1/task.go
package v1
import (
"context"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"sort"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"github.com/hatchet-dev/hatchet/internal/cel"
"github.com/hatchet-dev/hatchet/internal/telemetry"
"github.com/hatchet-dev/hatchet/pkg/repository/postgres/sqlchelpers"
"github.com/hatchet-dev/hatchet/pkg/repository/v1/sqlcv1"
)
type CreateTaskOpts struct {
// (required) the external id
ExternalId string `validate:"required,uuid"`
// (required) the workflow run id. note this may be the same as the external id if this is a
// single-task workflow, otherwise it represents the external id of the DAG.
WorkflowRunId string `validate:"required,uuid"`
// (required) the step id
StepId string `validate:"required,uuid"`
// (required) the input to the task
Input *TaskInput
// (optional) the filter payload for the task
FilterPayload []byte
// (required) the step index for the task
StepIndex int
// (optional) the additional metadata for the task
AdditionalMetadata []byte
// (optional) the desired worker id
DesiredWorkerId *string
// (optional) the DAG id for the task
DagId *int64
// (optional) the DAG inserted at for the task
DagInsertedAt pgtype.Timestamptz
// (required) the initial state for the task
InitialState sqlcv1.V1TaskInitialState
// (optional) the parent task external id
ParentTaskExternalId *string
// (optional) the parent task id
ParentTaskId *int64
// (optional) the parent task inserted at
ParentTaskInsertedAt *time.Time
// (optional) The priority of a task, between 1 and 3
Priority *int32
// (optional) the child index for the task
ChildIndex *int64
// (optional) the child key for the task
ChildKey *string
}
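// Example (illustrative sketch, not part of the package API): a minimal CreateTaskOpts for a
// single-task workflow, where the workflow run id equals the external id. The uuids and input
// value below are hypothetical.
//
//	opts := CreateTaskOpts{
//		ExternalId:    "9e5f0f1a-0b1c-4d2e-8f3a-5b6c7d8e9f0a",
//		WorkflowRunId: "9e5f0f1a-0b1c-4d2e-8f3a-5b6c7d8e9f0a",
//		StepId:        "1a2b3c4d-5e6f-7a8b-9c0d-1e2f3a4b5c6d",
//		StepIndex:     0,
//		Input:         input, // a *TaskInput built by the caller
//		InitialState:  sqlcv1.V1TaskInitialStateQUEUED,
//	}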
type ReplayTasksResult struct {
ReplayedTasks []TaskIdInsertedAtRetryCount
UpsertedTasks []*V1TaskWithPayload
InternalEventResults *EventMatchResults
}
type ReplayTaskOpts struct {
// (required) the task id
TaskId int64
// (required) the inserted at time
InsertedAt pgtype.Timestamptz
// (required) the external id
ExternalId string
// (required) the step id
StepId string
// (optional) the input bytes to the task, uses the existing input if not set
Input *TaskInput
// (required) the initial state for the task
InitialState sqlcv1.V1TaskInitialState
// (optional) the additional metadata for the task
AdditionalMetadata []byte
}
type TaskIdInsertedAtRetryCount struct {
// (required) the task id
Id int64 `validate:"required"`
// (required) the inserted at time
InsertedAt pgtype.Timestamptz
// (required) the retry count
RetryCount int32
}
type TaskIdInsertedAtSignalKey struct {
// (required) the task id
Id int64 `validate:"required"`
// (required) the inserted at time
InsertedAt pgtype.Timestamptz
// (required) the signal key for the event
SignalKey string
}
type CompleteTaskOpts struct {
*TaskIdInsertedAtRetryCount
// (required) the output bytes for the task
Output []byte
}
type FailTaskOpts struct {
*TaskIdInsertedAtRetryCount
// (required) whether this is an application-level error or an internal error on the Hatchet side
IsAppError bool
// (optional) the error message for the task
ErrorMessage string
// (optional) whether the error is non-retryable; if true, the task will _not_ be retried. Defaults to false.
IsNonRetryable bool
}
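// Example (illustrative sketch): failing a task with an application-level, non-retryable error;
// the embedded identifiers are assumed to come from a previously fetched task row.
//
//	opt := FailTaskOpts{
//		TaskIdInsertedAtRetryCount: &TaskIdInsertedAtRetryCount{
//			Id:         task.ID,
//			InsertedAt: task.InsertedAt,
//			RetryCount: task.RetryCount,
//		},
//		IsAppError:     true,
//		ErrorMessage:   "invalid input payload",
//		IsNonRetryable: true,
//	}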
type TaskIdEventKeyTuple struct {
Id int64 `validate:"required"`
EventKey string `validate:"required"`
}
// InternalTaskEvent resembles sqlcv1.V1TaskEvent, but doesn't include the id field as we
// use COPY FROM to write the events to the database.
type InternalTaskEvent struct {
TenantID string `json:"tenant_id"`
TaskID int64 `json:"task_id"`
TaskExternalID string `json:"task_external_id"`
RetryCount int32 `json:"retry_count"`
EventType sqlcv1.V1TaskEventType `json:"event_type"`
EventKey string `json:"event_key"`
Data []byte `json:"data"`
}
type FinalizedTaskResponse struct {
ReleasedTasks []*sqlcv1.ReleaseTasksRow
InternalEvents []InternalTaskEvent
}
type RetriedTask struct {
*TaskIdInsertedAtRetryCount
IsAppError bool
AppRetryCount int32
RetryBackoffFactor pgtype.Float8
RetryMaxBackoff pgtype.Int4
}
type FailTasksResponse struct {
*FinalizedTaskResponse
RetriedTasks []RetriedTask
}
type TimeoutTasksResponse struct {
*FailTasksResponse
TimeoutTasks []*sqlcv1.ListTasksToTimeoutRow
}
type ListFinalizedWorkflowRunsResponse struct {
WorkflowRunId string
OutputEvents []*TaskOutputEvent
}
type RefreshTimeoutBy struct {
TaskExternalId string `validate:"required,uuid"`
IncrementTimeoutBy string `validate:"required,duration"`
}
type TaskRepository interface {
EnsureTablePartitionsExist(ctx context.Context) (bool, error)
UpdateTablePartitions(ctx context.Context) error
// GetTaskByExternalId is a heavily cached method to return task metadata by its external id
GetTaskByExternalId(ctx context.Context, tenantId, taskExternalId string, skipCache bool) (*sqlcv1.FlattenExternalIdsRow, error)
// FlattenExternalIds is a non-cached method to look up all tasks in a workflow run by their external ids.
// This is non-cacheable because tasks can be added to a workflow run as it executes.
FlattenExternalIds(ctx context.Context, tenantId string, externalIds []string) ([]*sqlcv1.FlattenExternalIdsRow, error)
CompleteTasks(ctx context.Context, tenantId string, tasks []CompleteTaskOpts) (*FinalizedTaskResponse, error)
FailTasks(ctx context.Context, tenantId string, tasks []FailTaskOpts) (*FailTasksResponse, error)
CancelTasks(ctx context.Context, tenantId string, tasks []TaskIdInsertedAtRetryCount) (*FinalizedTaskResponse, error)
ListTasks(ctx context.Context, tenantId string, tasks []int64) ([]*sqlcv1.V1Task, error)
ListTaskMetas(ctx context.Context, tenantId string, tasks []int64) ([]*sqlcv1.ListTaskMetasRow, error)
ListFinalizedWorkflowRuns(ctx context.Context, tenantId string, rootExternalIds []string) ([]*ListFinalizedWorkflowRunsResponse, error)
// ListTaskParentOutputs is a method to return the output of a task's parent and grandparent tasks. This is for v0 compatibility
// with the v1 engine, and shouldn't be called from new v1 endpoints.
ListTaskParentOutputs(ctx context.Context, tenantId string, tasks []*sqlcv1.V1Task) (map[int64][]*TaskOutputEvent, error)
ProcessTaskTimeouts(ctx context.Context, tenantId string) (*TimeoutTasksResponse, bool, error)
ProcessTaskReassignments(ctx context.Context, tenantId string) (*FailTasksResponse, bool, error)
ProcessTaskRetryQueueItems(ctx context.Context, tenantId string) ([]*sqlcv1.V1RetryQueueItem, bool, error)
ProcessDurableSleeps(ctx context.Context, tenantId string) (*EventMatchResults, bool, error)
GetQueueCounts(ctx context.Context, tenantId string) (map[string]interface{}, error)
ReplayTasks(ctx context.Context, tenantId string, tasks []TaskIdInsertedAtRetryCount) (*ReplayTasksResult, error)
RefreshTimeoutBy(ctx context.Context, tenantId string, opt RefreshTimeoutBy) (*sqlcv1.V1TaskRuntime, error)
ReleaseSlot(ctx context.Context, tenantId string, externalId string) (*sqlcv1.V1TaskRuntime, error)
ListSignalCompletedEvents(ctx context.Context, tenantId string, tasks []TaskIdInsertedAtSignalKey) ([]*V1TaskEventWithPayload, error)
// AnalyzeTaskTables runs ANALYZE on the task tables
AnalyzeTaskTables(ctx context.Context) error
}
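// Example (illustrative sketch, assuming a concrete TaskRepository named repo and a retry count
// known to the caller): a typical finalization flow looks up a task by its external id and then
// cancels it. Error handling is elided for brevity.
//
//	row, _ := repo.GetTaskByExternalId(ctx, tenantId, externalId, false)
//	_, _ = repo.CancelTasks(ctx, tenantId, []TaskIdInsertedAtRetryCount{{
//		Id:         row.ID,
//		InsertedAt: row.InsertedAt,
//		RetryCount: retryCount, // hypothetical: the retry count known to the caller
//	}})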
type TaskRepositoryImpl struct {
*sharedRepository
taskRetentionPeriod time.Duration
maxInternalRetryCount int32
timeoutLimit int
reassignLimit int
retryQueueLimit int
durableSleepLimit int
}
func newTaskRepository(s *sharedRepository, taskRetentionPeriod time.Duration, maxInternalRetryCount int32, timeoutLimit, reassignLimit, retryQueueLimit, durableSleepLimit int) TaskRepository {
return &TaskRepositoryImpl{
sharedRepository: s,
taskRetentionPeriod: taskRetentionPeriod,
maxInternalRetryCount: maxInternalRetryCount,
timeoutLimit: timeoutLimit,
reassignLimit: reassignLimit,
retryQueueLimit: retryQueueLimit,
durableSleepLimit: durableSleepLimit,
}
}
func (r *TaskRepositoryImpl) EnsureTablePartitionsExist(ctx context.Context) (bool, error) {
return r.queries.EnsureTablePartitionsExist(ctx, r.pool)
}
func (r *TaskRepositoryImpl) UpdateTablePartitions(ctx context.Context) error {
exists, err := r.EnsureTablePartitionsExist(ctx)
if err != nil {
return fmt.Errorf("failed to check if table partitions exist: %w", err)
}
if exists {
r.l.Debug().Msg("table partitions already exist, skipping")
return nil
}
const PARTITION_LOCK_OFFSET = 9000000000000000000
const partitionLockKey = PARTITION_LOCK_OFFSET + 1
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 600000) // 10 minutes
if err != nil {
return fmt.Errorf("failed to prepare transaction: %w", err)
}
defer rollback()
acquired, err := r.queries.TryAdvisoryLock(ctx, tx, partitionLockKey)
if err != nil {
return fmt.Errorf("failed to try advisory lock for partition operations: %w", err)
}
if !acquired {
r.l.Debug().Msg("partition operations already running on another controller instance, skipping")
return nil
}
r.l.Debug().Msg("acquired advisory lock for partition operations")
today := time.Now().UTC()
tomorrow := today.AddDate(0, 0, 1)
removeBefore := today.Add(-1 * r.taskRetentionPeriod)
err = r.queries.CreatePartitions(ctx, r.pool, pgtype.Date{
Time: today,
Valid: true,
})
if err != nil {
return err
}
err = r.queries.CreatePartitions(ctx, r.pool, pgtype.Date{
Time: tomorrow,
Valid: true,
})
if err != nil {
return err
}
partitions, err := r.queries.ListPartitionsBeforeDate(ctx, r.pool, pgtype.Date{
Time: removeBefore,
Valid: true,
})
if err != nil {
return err
}
if len(partitions) > 0 {
r.l.Warn().Msgf("removing partitions before %s using retention period of %s", removeBefore.Format(time.RFC3339), r.taskRetentionPeriod)
}
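// NOTE: ALTER TABLE ... DETACH PARTITION CONCURRENTLY cannot run inside a transaction block in
// Postgres, which is presumably why the statements below execute directly on the pool rather than on tx.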
for _, partition := range partitions {
r.l.Warn().Msgf("detaching partition %s", partition.PartitionName)
_, err := r.pool.Exec(
ctx,
fmt.Sprintf("ALTER TABLE %s DETACH PARTITION %s CONCURRENTLY", partition.ParentTable, partition.PartitionName),
)
if err != nil {
return err
}
_, err = r.pool.Exec(
ctx,
fmt.Sprintf("DROP TABLE %s", partition.PartitionName),
)
if err != nil {
return err
}
}
err = commit(ctx)
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
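// Example (illustrative sketch): UpdateTablePartitions returns early when partitions already exist
// and uses an advisory lock to skip when another controller instance is already running maintenance,
// so a hypothetical scheduler loop might call it periodically:
//
//	ticker := time.NewTicker(time.Hour)
//	for range ticker.C {
//		if err := repo.UpdateTablePartitions(ctx); err != nil {
//			logger.Error().Err(err).Msg("partition maintenance failed")
//		}
//	}
//
// (repo, logger, and the hourly cadence are assumptions, not part of this package.)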
func (r *TaskRepositoryImpl) GetTaskByExternalId(ctx context.Context, tenantId, taskExternalId string, skipCache bool) (*sqlcv1.FlattenExternalIdsRow, error) {
if !skipCache {
// check the cache first
key := taskExternalIdTenantIdTuple{
externalId: taskExternalId,
tenantId: tenantId,
}
if val, ok := r.taskLookupCache.Get(key); ok {
return val, nil
}
}
// lookup the task
dbTasks, err := r.queries.FlattenExternalIds(ctx, r.pool, sqlcv1.FlattenExternalIdsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Externalids: []pgtype.UUID{sqlchelpers.UUIDFromStr(taskExternalId)},
})
if err != nil {
return nil, err
}
if len(dbTasks) == 0 {
return nil, pgx.ErrNoRows
}
if len(dbTasks) > 1 {
return nil, fmt.Errorf("found more than one task for %s", taskExternalId)
}
// set the cache
res := dbTasks[0]
key := taskExternalIdTenantIdTuple{
externalId: taskExternalId,
tenantId: tenantId,
}
r.taskLookupCache.Add(key, res)
return res, nil
}
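// Example (illustrative sketch): because results are cached per (external id, tenant id), callers
// that need a guaranteed-fresh row, e.g. right after a replay, can bypass the cache:
//
//	row, err := repo.GetTaskByExternalId(ctx, tenantId, externalId, true)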
func (r *TaskRepositoryImpl) FlattenExternalIds(ctx context.Context, tenantId string, externalIds []string) ([]*sqlcv1.FlattenExternalIdsRow, error) {
return r.lookupExternalIds(ctx, r.pool, tenantId, externalIds)
}
func (r *sharedRepository) lookupExternalIds(ctx context.Context, tx sqlcv1.DBTX, tenantId string, externalIds []string) ([]*sqlcv1.FlattenExternalIdsRow, error) {
externalIdsToLookup := make([]pgtype.UUID, 0, len(externalIds))
res := make([]*sqlcv1.FlattenExternalIdsRow, 0, len(externalIds))
for _, externalId := range externalIds {
if externalId == "" {
r.l.Error().Msgf("passed in empty external id")
continue
}
externalIdsToLookup = append(externalIdsToLookup, sqlchelpers.UUIDFromStr(externalId))
}
// lookup the task
dbTasks, err := r.queries.FlattenExternalIds(ctx, tx, sqlcv1.FlattenExternalIdsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Externalids: externalIdsToLookup,
})
if err != nil {
return nil, err
}
// group tasks by their root (workflow run) external id
groupedExternalIds := make(map[string][]*sqlcv1.FlattenExternalIdsRow)
for _, task := range dbTasks {
rootExternalId := sqlchelpers.UUIDToStr(task.WorkflowRunExternalID)
groupedExternalIds[rootExternalId] = append(groupedExternalIds[rootExternalId], task)
}
for _, tasks := range groupedExternalIds {
res = append(res, tasks...)
}
return res, nil
}
func (r *TaskRepositoryImpl) verifyAllTasksFinalized(ctx context.Context, tx sqlcv1.DBTX, tenantId string, flattenedTasks []*sqlcv1.FlattenExternalIdsRow) ([]string, map[string]int64, error) {
taskIdsToCheck := make([]int64, len(flattenedTasks))
taskInsertedAtsToCheck := make([]pgtype.Timestamptz, len(flattenedTasks))
taskIdsToTasks := make(map[int64]*sqlcv1.FlattenExternalIdsRow)
minInsertedAt := sqlchelpers.TimestamptzFromTime(time.Now()) // current time as a placeholder - will be overwritten
for i, task := range flattenedTasks {
taskIdsToCheck[i] = task.ID
taskInsertedAtsToCheck[i] = task.InsertedAt
taskIdsToTasks[task.ID] = task
if task.InsertedAt.Time.Before(minInsertedAt.Time) {
minInsertedAt = task.InsertedAt
}
}
// run preflight check on tasks
notFinalized, err := r.queries.PreflightCheckTasksForReplay(ctx, tx, sqlcv1.PreflightCheckTasksForReplayParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Taskids: taskIdsToCheck,
Taskinsertedats: taskInsertedAtsToCheck,
Mininsertedat: minInsertedAt,
})
if err != nil {
return nil, nil, err
}
notFinalizedMap := make(map[int64]bool)
for _, task := range notFinalized {
notFinalizedMap[task.ID] = true
}
dagsToCheck := make([]int64, 0)
dagsToTasks := make(map[int64][]*sqlcv1.FlattenExternalIdsRow)
for _, task := range flattenedTasks {
if !notFinalizedMap[task.ID] && task.DagID.Valid {
dagsToCheck = append(dagsToCheck, task.DagID.Int64)
dagsToTasks[task.DagID.Int64] = append(dagsToTasks[task.DagID.Int64], task)
}
}
// check DAGs
notFinalizedDags, err := r.queries.PreflightCheckDAGsForReplay(ctx, tx, sqlcv1.PreflightCheckDAGsForReplayParams{
Dagids: dagsToCheck,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return nil, nil, err
}
notFinalizedDAGsMap := make(map[int64]bool)
finalizedDAGToStepCount := make(map[string]int64)
for _, dag := range notFinalizedDags {
if dag.StepCount != dag.TaskCount {
notFinalizedDAGsMap[dag.ID] = true
} else {
rootId := sqlchelpers.UUIDToStr(dag.ExternalID)
finalizedDAGToStepCount[rootId] = dag.StepCount
}
}
candidateFinalizedRootExternalIds := make(map[string]bool, 0)
for _, task := range flattenedTasks {
candidateFinalizedRootExternalIds[sqlchelpers.UUIDToStr(task.WorkflowRunExternalID)] = true
}
// iterate through tasks one last time
for _, task := range flattenedTasks {
rootId := sqlchelpers.UUIDToStr(task.WorkflowRunExternalID)
// if root is already non-finalized, skip
if !candidateFinalizedRootExternalIds[rootId] {
continue
}
if notFinalizedMap[task.ID] {
candidateFinalizedRootExternalIds[rootId] = false
continue
}
if task.DagID.Valid && notFinalizedDAGsMap[task.DagID.Int64] {
candidateFinalizedRootExternalIds[rootId] = false
continue
}
}
finalizedRootExternalIds := make([]string, 0)
for rootId, finalized := range candidateFinalizedRootExternalIds {
if finalized {
finalizedRootExternalIds = append(finalizedRootExternalIds, rootId)
}
}
return finalizedRootExternalIds, finalizedDAGToStepCount, nil
}
func (r *TaskRepositoryImpl) CompleteTasks(ctx context.Context, tenantId string, tasks []CompleteTaskOpts) (*FinalizedTaskResponse, error) {
ctx, span := telemetry.NewSpan(ctx, "TaskRepositoryImpl.CompleteTasks")
defer span.End()
// TODO: ADD BACK VALIDATION
// if err := r.v.Validate(tasks); err != nil {
// fmt.Println("FAILED VALIDATION HERE!!!")
// return err
// }
taskIdRetryCounts := make([]TaskIdInsertedAtRetryCount, len(tasks))
for i, task := range tasks {
taskIdRetryCounts[i] = *task.TaskIdInsertedAtRetryCount
}
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 5000)
if err != nil {
err = fmt.Errorf("failed to prepare tx: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to prepare tx")
return nil, err
}
defer rollback()
taskIdRetryCounts = uniqueSet(taskIdRetryCounts)
// release queue items
releasedTasks, err := r.releaseTasks(ctx, tx, tenantId, taskIdRetryCounts)
if err != nil {
err = fmt.Errorf("failed to release tasks: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to release tasks")
return nil, err
}
if len(taskIdRetryCounts) != len(releasedTasks) {
return nil, fmt.Errorf("failed to release all tasks")
}
outputs := make([][]byte, len(releasedTasks))
for i, releasedTask := range releasedTasks {
out := NewCompletedTaskOutputEvent(releasedTask, tasks[i].Output).Bytes()
outputs[i] = out
}
internalEvents, err := r.createTaskEventsAfterRelease(
ctx,
tx,
tenantId,
taskIdRetryCounts,
outputs,
releasedTasks,
sqlcv1.V1TaskEventTypeCOMPLETED,
)
if err != nil {
err = fmt.Errorf("failed to create task events after release: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to create task events after release")
return nil, err
}
// commit the transaction
if err := commit(ctx); err != nil {
err = fmt.Errorf("failed to commit transaction: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to commit transaction")
return nil, err
}
return &FinalizedTaskResponse{
ReleasedTasks: releasedTasks,
InternalEvents: internalEvents,
}, nil
}
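// Example (illustrative sketch): completing a finished task with its output bytes; the identifiers
// are assumed to come from the task row the caller is working on.
//
//	resp, err := repo.CompleteTasks(ctx, tenantId, []CompleteTaskOpts{{
//		TaskIdInsertedAtRetryCount: &TaskIdInsertedAtRetryCount{
//			Id:         taskId,
//			InsertedAt: insertedAt,
//			RetryCount: retryCount,
//		},
//		Output: outputBytes,
//	}})
//
// On success, resp.InternalEvents carries the COMPLETED task events written in the same transaction.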
func (r *TaskRepositoryImpl) FailTasks(ctx context.Context, tenantId string, failureOpts []FailTaskOpts) (*FailTasksResponse, error) {
ctx, span := telemetry.NewSpan(ctx, "TaskRepositoryImpl.FailTasks")
defer span.End()
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 5000)
if err != nil {
err = fmt.Errorf("failed to prepare tx: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to prepare tx")
return nil, err
}
defer rollback()
res, err := r.failTasksTx(ctx, tx, tenantId, failureOpts)
if err != nil {
err = fmt.Errorf("failed to fail tasks: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to fail tasks")
return nil, err
}
// commit the transaction
if err := commit(ctx); err != nil {
err = fmt.Errorf("failed to commit transaction: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to commit transaction")
return nil, err
}
return res, nil
}
func (r *TaskRepositoryImpl) failTasksTx(ctx context.Context, tx sqlcv1.DBTX, tenantId string, failureOpts []FailTaskOpts) (*FailTasksResponse, error) {
// TODO: ADD BACK VALIDATION
// if err := r.v.Validate(tasks); err != nil {
// fmt.Println("FAILED VALIDATION HERE!!!")
// return err
// }
ctx, span := telemetry.NewSpan(ctx, "tasks_repository_impl.fail_tasks_tx")
defer span.End()
tasks := make([]TaskIdInsertedAtRetryCount, len(failureOpts))
appFailureTaskIds := make([]int64, 0)
appFailureTaskInsertedAts := make([]pgtype.Timestamptz, 0)
appFailureTaskRetryCounts := make([]int32, 0)
appFailureIsNonRetryableStatuses := make([]bool, 0)
internalFailureTaskIds := make([]int64, 0)
internalFailureInsertedAts := make([]pgtype.Timestamptz, 0)
internalFailureTaskRetryCounts := make([]int32, 0)
for i, failureOpt := range failureOpts {
tasks[i] = *failureOpt.TaskIdInsertedAtRetryCount
if failureOpt.IsAppError {
appFailureTaskIds = append(appFailureTaskIds, failureOpt.Id)
appFailureTaskInsertedAts = append(appFailureTaskInsertedAts, failureOpt.InsertedAt)
appFailureTaskRetryCounts = append(appFailureTaskRetryCounts, failureOpt.RetryCount)
appFailureIsNonRetryableStatuses = append(appFailureIsNonRetryableStatuses, failureOpt.IsNonRetryable)
} else {
internalFailureTaskIds = append(internalFailureTaskIds, failureOpt.Id)
internalFailureInsertedAts = append(internalFailureInsertedAts, failureOpt.InsertedAt)
internalFailureTaskRetryCounts = append(internalFailureTaskRetryCounts, failureOpt.RetryCount)
}
}
tasks = uniqueSet(tasks)
retriedTasks := make([]RetriedTask, 0)
// write app failures
if len(appFailureTaskIds) > 0 {
span.SetAttributes(
attribute.KeyValue{
Key: "tasks_repository_impl.fail_tasks_tx.fail_task_app_failure.batch_size",
Value: attribute.IntValue(len(appFailureTaskIds)),
},
)
appFailureRetries, err := r.queries.FailTaskAppFailure(ctx, tx, sqlcv1.FailTaskAppFailureParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Taskids: appFailureTaskIds,
Taskinsertedats: appFailureTaskInsertedAts,
Taskretrycounts: appFailureTaskRetryCounts,
Isnonretryables: appFailureIsNonRetryableStatuses,
})
if err != nil {
return nil, err
}
for _, task := range appFailureRetries {
retriedTasks = append(retriedTasks, RetriedTask{
TaskIdInsertedAtRetryCount: &TaskIdInsertedAtRetryCount{
Id: task.ID,
InsertedAt: task.InsertedAt,
RetryCount: task.RetryCount,
},
IsAppError: true,
AppRetryCount: task.AppRetryCount,
RetryBackoffFactor: task.RetryBackoffFactor,
RetryMaxBackoff: task.RetryMaxBackoff,
},
)
}
}
// write internal failures
if len(internalFailureTaskIds) > 0 {
span.SetAttributes(
attribute.KeyValue{
Key: "tasks_repository_impl.fail_tasks_tx.fail_task_internal_failure.batch_size",
Value: attribute.IntValue(len(internalFailureTaskIds)),
},
)
internalFailureRetries, err := r.queries.FailTaskInternalFailure(ctx, tx, sqlcv1.FailTaskInternalFailureParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Taskids: internalFailureTaskIds,
Taskinsertedats: internalFailureInsertedAts,
Taskretrycounts: internalFailureTaskRetryCounts,
Maxinternalretries: r.maxInternalRetryCount,
})
if err != nil {
return nil, err
}
for _, task := range internalFailureRetries {
retriedTasks = append(retriedTasks, RetriedTask{
TaskIdInsertedAtRetryCount: &TaskIdInsertedAtRetryCount{
Id: task.ID,
InsertedAt: task.InsertedAt,
RetryCount: task.RetryCount,
},
})
}
}
// release queue items
// NOTE: it's important that we do this after we've written the retries, as some of the triggers for
// concurrency slots depend on the existence of the retry queue item.
releasedTasks, err := r.releaseTasks(ctx, tx, tenantId, tasks)
if err != nil {
return nil, err
}
outputs := make([][]byte, len(releasedTasks))
for i, releasedTask := range releasedTasks {
out := NewFailedTaskOutputEvent(releasedTask, failureOpts[i].ErrorMessage).Bytes()
outputs[i] = out
}
internalEvents, err := r.createTaskEventsAfterRelease(
ctx,
tx,
tenantId,
tasks,
outputs,
releasedTasks,
sqlcv1.V1TaskEventTypeFAILED,
)
if err != nil {
return nil, err
}
return &FailTasksResponse{
FinalizedTaskResponse: &FinalizedTaskResponse{
ReleasedTasks: releasedTasks,
InternalEvents: internalEvents,
},
RetriedTasks: retriedTasks,
}, nil
}
func (r *TaskRepositoryImpl) ListFinalizedWorkflowRuns(ctx context.Context, tenantId string, rootExternalIds []string) ([]*ListFinalizedWorkflowRunsResponse, error) {
start := time.Now()
checkpoint := time.Now()
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 30000)
if err != nil {
return nil, err
}
defer rollback()
externalIdsToEvents := make(map[string][]*TaskOutputEvent)
tasks, err := r.lookupExternalIds(ctx, tx, tenantId, rootExternalIds)
if err != nil {
return nil, err
}
durLookup := time.Since(checkpoint)
checkpoint = time.Now()
finalizedRootIds, rootIdToStepCounts, err := r.verifyAllTasksFinalized(ctx, tx, tenantId, tasks)
if err != nil {
return nil, err
}
durVerify := time.Since(checkpoint)
checkpoint = time.Now()
taskExternalIds := make([]string, 0, len(tasks))
taskExternalIdsToRootIds := make(map[string]string)
for _, task := range tasks {
taskExternalIds = append(taskExternalIds, sqlchelpers.UUIDToStr(task.ExternalID))
taskExternalIdsToRootIds[sqlchelpers.UUIDToStr(task.ExternalID)] = sqlchelpers.UUIDToStr(task.WorkflowRunExternalID)
}
outputEvents, err := r.listTaskOutputEvents(ctx, tx, tenantId, taskExternalIds)
if err != nil {
return nil, err
}
durOutputEvents := time.Since(checkpoint)
if err := commit(ctx); err != nil {
return nil, err
}
taskExternalIdsHasOutputEvent := make(map[string]bool)
// group the output events by their root (workflow run) external id
for _, outputEvent := range outputEvents {
rootId, ok := taskExternalIdsToRootIds[outputEvent.TaskExternalId]
if !ok {
r.l.Warn().Msgf("could not find root id for task %s", outputEvent.TaskExternalId)
continue
}
externalIdsToEvents[rootId] = append(externalIdsToEvents[rootId], outputEvent)
taskExternalIdsHasOutputEvent[outputEvent.TaskExternalId] = true
}
finalizedRootIdsMap := make(map[string]bool)
for _, rootId := range finalizedRootIds {
finalizedRootIdsMap[rootId] = true
}
// if tasks that we read originally don't have a TaskOutputEvent, they're not finalized, so set their root
// ids in finalizedRootIdsMap to false
for _, taskExternalId := range taskExternalIds {
if !taskExternalIdsHasOutputEvent[taskExternalId] {
rootId, ok := taskExternalIdsToRootIds[taskExternalId]
if !ok {
r.l.Warn().Msgf("could not find root id for task %s", taskExternalId)
continue
}
finalizedRootIdsMap[rootId] = false
}
}
// look for finalized events...
eventsForFinalizedRootIds := make(map[string][]*TaskOutputEvent)
for _, rootId := range finalizedRootIds {
if !finalizedRootIdsMap[rootId] {
continue
}
events := externalIdsToEvents[rootId]
// if the DAG's step count for this root id exceeds the number of output events, it's not finalized
if _, ok := rootIdToStepCounts[rootId]; ok && rootIdToStepCounts[rootId] > int64(len(events)) {
continue
}
eventsForFinalizedRootIds[rootId] = events
}
// put together response
res := make([]*ListFinalizedWorkflowRunsResponse, 0, len(eventsForFinalizedRootIds))
for rootId, events := range eventsForFinalizedRootIds {
res = append(res, &ListFinalizedWorkflowRunsResponse{
WorkflowRunId: rootId,
OutputEvents: events,
})
}
if time.Since(start) > 100*time.Millisecond {
r.l.Warn().Dur(
"lookup_duration",
durLookup,
).Dur(
"verify_duration",
durVerify,
).Dur(
"output_events_duration",
durOutputEvents,
).Dur("total_duration", time.Since(start)).Msgf("slow finalized workflow runs lookup")
}
return res, nil
}
func (r *TaskRepositoryImpl) CancelTasks(ctx context.Context, tenantId string, tasks []TaskIdInsertedAtRetryCount) (*FinalizedTaskResponse, error) {
ctx, span := telemetry.NewSpan(ctx, "TaskRepositoryImpl.CancelTasks")
defer span.End()
// TODO: ADD BACK VALIDATION
// if err := r.v.Validate(tasks); err != nil {
// fmt.Println("FAILED VALIDATION HERE!!!")
// return err
// }
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 5000)
if err != nil {
err = fmt.Errorf("failed to prepare tx: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to prepare tx")
return nil, err
}
defer rollback()
// release queue items
res, err := r.cancelTasks(ctx, tx, tenantId, tasks)
if err != nil {
err = fmt.Errorf("failed to cancel tasks: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to cancel tasks")
return nil, err
}
// commit the transaction
if err := commit(ctx); err != nil {
err = fmt.Errorf("failed to commit transaction: %w", err)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to commit transaction")
return nil, err
}
return res, nil
}
func (r *sharedRepository) cancelTasks(ctx context.Context, dbtx sqlcv1.DBTX, tenantId string, tasks []TaskIdInsertedAtRetryCount) (*FinalizedTaskResponse, error) {
// get a unique set of task ids and retry counts
tasks = uniqueSet(tasks)
// release queue items
releasedTasks, err := r.releaseTasks(ctx, dbtx, tenantId, tasks)
if err != nil {
return nil, err
}
outputs := make([][]byte, len(releasedTasks))
for i, releasedTask := range releasedTasks {
out := NewCancelledTaskOutputEvent(releasedTask).Bytes()
outputs[i] = out
}
internalEvents, err := r.createTaskEventsAfterRelease(
ctx,
dbtx,
tenantId,
tasks,
outputs,
releasedTasks,
sqlcv1.V1TaskEventTypeCANCELLED,
)
if err != nil {
return nil, err
}
return &FinalizedTaskResponse{
ReleasedTasks: releasedTasks,
InternalEvents: internalEvents,
}, nil
}
func (r *TaskRepositoryImpl) ListTasks(ctx context.Context, tenantId string, tasks []int64) ([]*sqlcv1.V1Task, error) {
return r.listTasks(ctx, r.pool, tenantId, tasks)
}
func (r *sharedRepository) listTasks(ctx context.Context, dbtx sqlcv1.DBTX, tenantId string, tasks []int64) ([]*sqlcv1.V1Task, error) {
return r.queries.ListTasks(ctx, dbtx, sqlcv1.ListTasksParams{
TenantID: sqlchelpers.UUIDFromStr(tenantId),
Ids: tasks,
})
}
func (r *TaskRepositoryImpl) listTaskOutputEvents(ctx context.Context, tx sqlcv1.DBTX, tenantId string, taskExternalIds []string) ([]*TaskOutputEvent, error) {
externalIds := make([]pgtype.UUID, 0)
eventTypes := make([][]string, 0)
for _, externalId := range taskExternalIds {
externalIds = append(externalIds, sqlchelpers.UUIDFromStr(externalId))
eventTypes = append(eventTypes, []string{
string(sqlcv1.V1TaskEventTypeCOMPLETED),
string(sqlcv1.V1TaskEventTypeFAILED),
string(sqlcv1.V1TaskEventTypeCANCELLED),
})
}
matchedEvents, err := r.queries.ListMatchingTaskEvents(ctx, tx, sqlcv1.ListMatchingTaskEventsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Taskexternalids: externalIds,
Eventtypes: eventTypes,
})
if err != nil {
return nil, err
}
retrieveOpts := make([]RetrievePayloadOpts, len(matchedEvents))
retrieveOptsToEventData := make(map[RetrievePayloadOpts][]byte)
matchedEventToRetrieveOpts := make(map[*sqlcv1.ListMatchingTaskEventsRow]RetrievePayloadOpts)
for i, event := range matchedEvents {
opt := RetrievePayloadOpts{
Id: event.ID,
InsertedAt: event.InsertedAt,
Type: sqlcv1.V1PayloadTypeTASKEVENTDATA,
TenantId: sqlchelpers.UUIDFromStr(tenantId),
}
retrieveOpts[i] = opt
retrieveOptsToEventData[opt] = event.Data
matchedEventToRetrieveOpts[event] = opt
}
payloads, err := r.payloadStore.BulkRetrieve(ctx, retrieveOpts...)
if err != nil {
return nil, err
}
res := make([]*TaskOutputEvent, 0, len(matchedEvents))
for _, event := range matchedEvents {
retrieveOpts := matchedEventToRetrieveOpts[event]
payload, ok := payloads[retrieveOpts]
if !ok {
r.l.Error().Msgf("ListenForDurableEvent: matched event %s with created at %s and id %d has empty payload, falling back to input", event.ExternalID, event.CreatedAt.Time, event.ID)
payload = retrieveOptsToEventData[retrieveOpts]
}
o, err := newTaskEventFromBytes(payload)
if err != nil {
return nil, err
}
res = append(res, o)
}
return res, nil
}
func (r *TaskRepositoryImpl) ListTaskMetas(ctx context.Context, tenantId string, tasks []int64) ([]*sqlcv1.ListTaskMetasRow, error) {
return r.queries.ListTaskMetas(ctx, r.pool, sqlcv1.ListTaskMetasParams{
TenantID: sqlchelpers.UUIDFromStr(tenantId),
Ids: tasks,
})
}
func (r *TaskRepositoryImpl) ProcessTaskTimeouts(ctx context.Context, tenantId string) (*TimeoutTasksResponse, bool, error) {
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 25000)
if err != nil {
return nil, false, err
}
defer rollback()
limit := r.timeoutLimit
// get task timeouts
toTimeout, err := r.queries.ListTasksToTimeout(ctx, tx, sqlcv1.ListTasksToTimeoutParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Limit: pgtype.Int4{
Int32: int32(limit),
Valid: true,
},
})
if err != nil {
return nil, false, err
}
if len(toTimeout) == 0 {
return &TimeoutTasksResponse{
FailTasksResponse: &FailTasksResponse{
FinalizedTaskResponse: &FinalizedTaskResponse{
ReleasedTasks: make([]*sqlcv1.ReleaseTasksRow, 0),
InternalEvents: make([]InternalTaskEvent, 0),
},
RetriedTasks: make([]RetriedTask, 0),
},
TimeoutTasks: make([]*sqlcv1.ListTasksToTimeoutRow, 0),
}, false, nil
}
// parse into FailTaskOpts
failOpts := make([]FailTaskOpts, 0, len(toTimeout))
for _, task := range toTimeout {
failOpts = append(failOpts, FailTaskOpts{
TaskIdInsertedAtRetryCount: &TaskIdInsertedAtRetryCount{
Id: task.ID,
InsertedAt: task.InsertedAt,
RetryCount: task.RetryCount,
},
IsAppError: true,
ErrorMessage: fmt.Sprintf("Task exceeded timeout of %s", task.StepTimeout.String),
IsNonRetryable: false,
})
}
// fail the tasks
failResp, err := r.failTasksTx(ctx, tx, tenantId, failOpts)
if err != nil {
return nil, false, err
}
// commit the transaction
if err := commit(ctx); err != nil {
return nil, false, err
}
return &TimeoutTasksResponse{
FailTasksResponse: failResp,
TimeoutTasks: toTimeout,
}, len(toTimeout) == limit, nil
}
func (r *TaskRepositoryImpl) ProcessTaskReassignments(ctx context.Context, tenantId string) (*FailTasksResponse, bool, error) {
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 25000)
if err != nil {
return nil, false, err
}
defer rollback()
limit := r.reassignLimit
toReassign, err := r.queries.ListTasksToReassign(ctx, tx, sqlcv1.ListTasksToReassignParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Limit: pgtype.Int4{
Int32: int32(limit),
Valid: true,
},
})
if err != nil {
return nil, false, err
}
if len(toReassign) == 0 {
return &FailTasksResponse{
FinalizedTaskResponse: &FinalizedTaskResponse{
ReleasedTasks: make([]*sqlcv1.ReleaseTasksRow, 0),
InternalEvents: make([]InternalTaskEvent, 0),
},
RetriedTasks: make([]RetriedTask, 0),
}, false, nil
}
// parse into FailTaskOpts
failOpts := make([]FailTaskOpts, 0, len(toReassign))
for _, task := range toReassign {
failOpts = append(failOpts, FailTaskOpts{
TaskIdInsertedAtRetryCount: &TaskIdInsertedAtRetryCount{
Id: task.ID,
InsertedAt: task.InsertedAt,
RetryCount: task.RetryCount,
},
IsAppError: false,
IsNonRetryable: false,
})
}
// fail the tasks
res, err := r.failTasksTx(ctx, tx, tenantId, failOpts)
if err != nil {
return nil, false, err
}
// commit the transaction
if err := commit(ctx); err != nil {
return nil, false, err
}
return res, len(toReassign) == limit, nil
}
func (r *TaskRepositoryImpl) ProcessTaskRetryQueueItems(ctx context.Context, tenantId string) ([]*sqlcv1.V1RetryQueueItem, bool, error) {
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 25000)
if err != nil {
return nil, false, err
}
defer rollback()
limit := r.retryQueueLimit
// pop items from the retry queue
res, err := r.queries.ProcessRetryQueueItems(ctx, tx, sqlcv1.ProcessRetryQueueItemsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Limit: pgtype.Int4{
Int32: int32(limit),
Valid: true,
},
})
if err != nil {
return nil, false, err
}
// commit the transaction
if err := commit(ctx); err != nil {
return nil, false, err
}
return res, len(res) == limit, nil
}
type durableSleepEventData struct {
SleepDuration string `json:"sleep_duration"`
}
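// For illustration, a durable sleep event payload serializes as:
//
//	{"sleep_duration":"5m"}
//
// (the "5m" value is a hypothetical sleep duration)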
func (r *TaskRepositoryImpl) ProcessDurableSleeps(ctx context.Context, tenantId string) (*EventMatchResults, bool, error) {
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 25000)
if err != nil {
return nil, false, err
}
defer rollback()
limit := r.durableSleepLimit
emitted, err := r.queries.PopDurableSleep(ctx, tx, sqlcv1.PopDurableSleepParams{
TenantID: sqlchelpers.UUIDFromStr(tenantId),
Limit: pgtype.Int4{Int32: int32(limit), Valid: true},
})
if err != nil {
return nil, false, err
}
// each emitted item becomes a candidate event match for internal events
events := make([]CandidateEventMatch, 0, len(emitted))
for _, sleep := range emitted {
data, err := json.Marshal(durableSleepEventData{
SleepDuration: sleep.SleepDuration,
})
if err != nil {
return nil, false, err
}
events = append(events, CandidateEventMatch{
ID: uuid.New().String(),
EventTimestamp: time.Now(),
Key: getDurableSleepEventKey(sleep.ID),
Data: data,
})
}
results, err := r.processEventMatches(ctx, tx, tenantId, events, sqlcv1.V1EventTypeINTERNAL)
if err != nil {
return nil, false, err
}
storePayloadOpts := make([]StorePayloadOpts, len(results.CreatedTasks))
for i, task := range results.CreatedTasks {
storePayloadOpts[i] = StorePayloadOpts{
Id: task.ID,
InsertedAt: task.InsertedAt,
Type: sqlcv1.V1PayloadTypeTASKINPUT,
Payload: task.Payload,
TenantId: task.TenantID.String(),
}
}
if len(storePayloadOpts) > 0 {
err = r.payloadStore.Store(ctx, tx, storePayloadOpts...)
if err != nil {
return nil, false, fmt.Errorf("failed to store payloads for created tasks for durable sleep matches: %w", err)
}
}
if err := commit(ctx); err != nil {
return nil, false, err
}
return results, len(emitted) == limit, nil
}
func (r *TaskRepositoryImpl) GetQueueCounts(ctx context.Context, tenantId string) (map[string]interface{}, error) {
counts, err := r.getFIFOQueuedCounts(ctx, tenantId)
if err != nil {
return nil, err
}
concurrencyCounts, err := r.getConcurrencyQueuedCounts(ctx, tenantId)
if err != nil {
return nil, err
}
res := make(map[string]interface{})
for k, v := range counts {
res[k] = v
}
for k, v := range concurrencyCounts {
res[k] = v
}
return res, nil
}
func (r *TaskRepositoryImpl) getFIFOQueuedCounts(ctx context.Context, tenantId string) (map[string]interface{}, error) {
counts, err := r.queries.GetQueuedCounts(ctx, r.pool, sqlchelpers.UUIDFromStr(tenantId))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return map[string]interface{}{}, nil
}
return nil, err
}
res := make(map[string]interface{})
for _, count := range counts {
res[count.Queue] = int(count.Count)
}
return res, nil
}
func (r *TaskRepositoryImpl) getConcurrencyQueuedCounts(ctx context.Context, tenantId string) (map[string]interface{}, error) {
concurrencyCounts, err := r.queries.GetWorkflowConcurrencyQueueCounts(ctx, r.pool, sqlchelpers.UUIDFromStr(tenantId))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return map[string]interface{}{}, nil
}
return nil, err
}
res := make(map[string]interface{})
for _, count := range concurrencyCounts {
if _, ok := res[count.WorkflowName]; !ok {
res[count.WorkflowName] = map[string]int{}
}
v := res[count.WorkflowName].(map[string]int)
v[count.Key] = int(count.Count)
res[count.WorkflowName] = v
}
return res, nil
}
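// Example (illustrative): the merged result of GetQueueCounts combines flat FIFO queue depths with
// nested per-workflow concurrency queue depths, e.g.
//
//	map[string]interface{}{
//		"default":     12,                         // FIFO queue depth
//		"my-workflow": map[string]int{"key-a": 3}, // concurrency queue depths by key
//	}
//
// (queue and workflow names are hypothetical; note that a workflow name that collides with a queue
// name overwrites the FIFO entry, since concurrency counts are merged second.)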
func (r *TaskRepositoryImpl) RefreshTimeoutBy(ctx context.Context, tenantId string, opt RefreshTimeoutBy) (*sqlcv1.V1TaskRuntime, error) {
if err := r.v.Validate(opt); err != nil {
return nil, err
}
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 5000)
if err != nil {
return nil, err
}
defer rollback()
res, err := r.queries.RefreshTimeoutBy(ctx, tx, sqlcv1.RefreshTimeoutByParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Externalid: sqlchelpers.UUIDFromStr(opt.TaskExternalId),
IncrementTimeoutBy: sqlchelpers.TextFromStr(opt.IncrementTimeoutBy),
})
if err != nil {
return nil, err
}
if err := commit(ctx); err != nil {
return nil, err
}
return res, nil
}
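// Example (illustrative sketch): extending a running task's timeout. The IncrementTimeoutBy value
// must satisfy the "duration" validation tag; a Go-style duration string such as "30s" is assumed here.
//
//	rt, err := repo.RefreshTimeoutBy(ctx, tenantId, RefreshTimeoutBy{
//		TaskExternalId:     externalId,
//		IncrementTimeoutBy: "30s",
//	})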
func (r *TaskRepositoryImpl) ReleaseSlot(ctx context.Context, tenantId, externalId string) (*sqlcv1.V1TaskRuntime, error) {
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 5000)
if err != nil {
return nil, err
}
defer rollback()
resp, err := r.queries.ManualSlotRelease(
ctx,
tx,
sqlcv1.ManualSlotReleaseParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Externalid: sqlchelpers.UUIDFromStr(externalId),
},
)
if err != nil {
return nil, err
}
if err := commit(ctx); err != nil {
return nil, err
}
return resp, nil
}
func (r *sharedRepository) releaseTasks(ctx context.Context, tx sqlcv1.DBTX, tenantId string, tasks []TaskIdInsertedAtRetryCount) ([]*sqlcv1.ReleaseTasksRow, error) {
taskIds := make([]int64, len(tasks))
taskInsertedAts := make([]pgtype.Timestamptz, len(tasks))
retryCounts := make([]int32, len(tasks))
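// orderedMap records the caller's index for each (task id, retry count) pair so that the released
// rows can be returned in the same order as the input slice.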
orderedMap := make(map[string]int)
for i, task := range tasks {
taskIds[i] = task.Id
taskInsertedAts[i] = task.InsertedAt
retryCounts[i] = task.RetryCount
orderedMap[fmt.Sprintf("%d:%d", task.Id, task.RetryCount)] = i
}
releasedTasks, err := r.queries.ReleaseTasks(ctx, tx, sqlcv1.ReleaseTasksParams{
Taskids: taskIds,
Taskinsertedats: taskInsertedAts,
Retrycounts: retryCounts,
})
if err != nil {
return nil, err
}
if len(releasedTasks) != len(tasks) {
size := min(10, len(tasks))
taskIds := make([]int64, size)
for i := range size {
taskIds[i] = tasks[i].Id
}
return nil, fmt.Errorf("failed to release all tasks for tenant %s: %d/%d. Relevant task IDs: %v", tenantId, len(releasedTasks), size, taskIds)
}
res := make([]*sqlcv1.ReleaseTasksRow, len(tasks))
for _, task := range releasedTasks {
idx := orderedMap[fmt.Sprintf("%d:%d", task.ID, task.RetryCount)]
res[idx] = task
}
return res, nil
}
func (r *sharedRepository) upsertQueues(ctx context.Context, tx sqlcv1.DBTX, tenantId string, queues []string) (func(), error) {
queuesToInsert := make(map[string]struct{}, 0)
for _, queue := range queues {
if _, ok := queuesToInsert[queue]; ok {
continue
}
key := getQueueCacheKey(tenantId, queue)
if hasSetQueue, ok := r.queueCache.Get(key); ok && hasSetQueue.(bool) {
continue
}
queuesToInsert[queue] = struct{}{}
}
uniqueQueues := make([]string, 0, len(queuesToInsert))
for queue := range queuesToInsert {
uniqueQueues = append(uniqueQueues, queue)
}
err := r.queries.UpsertQueues(ctx, tx, sqlcv1.UpsertQueuesParams{
TenantID: sqlchelpers.UUIDFromStr(tenantId),
Names: uniqueQueues,
})
if err != nil {
return nil, err
}
// set all the queues to true in the cache
save := func() {
for _, queue := range uniqueQueues {
key := getQueueCacheKey(tenantId, queue)
r.queueCache.Set(key, true)
}
}
return save, nil
}
func getQueueCacheKey(tenantId string, queue string) string {
return fmt.Sprintf("%s:%s", tenantId, queue)
}
func (r *sharedRepository) createTasks(
ctx context.Context,
tx sqlcv1.DBTX,
tenantId string,
tasks []CreateTaskOpts,
) ([]*V1TaskWithPayload, error) {
// list the steps for the tasks
uniqueStepIds := make(map[string]struct{})
stepIds := make([]pgtype.UUID, 0)
externalIdToPayload := make(map[string][]byte, len(tasks))
for _, task := range tasks {
if _, ok := uniqueStepIds[task.StepId]; !ok {
uniqueStepIds[task.StepId] = struct{}{}
stepIds = append(stepIds, sqlchelpers.UUIDFromStr(task.StepId))
externalIdToPayload[task.ExternalId] = task.Input.Bytes()
}
}
steps, err := r.queries.ListStepsByIds(ctx, tx, sqlcv1.ListStepsByIdsParams{
Ids: stepIds,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return nil, err
}
stepIdsToConfig := make(map[string]*sqlcv1.ListStepsByIdsRow)
for _, step := range steps {
stepIdsToConfig[sqlchelpers.UUIDToStr(step.ID)] = step
}
return r.insertTasks(ctx, tx, tenantId, tasks, stepIdsToConfig)
}
// insertTasks inserts new tasks into the database. note that we're using Postgres rules to automatically insert the created
// tasks into the queue_items table.
func (r *sharedRepository) insertTasks(
ctx context.Context,
tx sqlcv1.DBTX,
tenantId string,
tasks []CreateTaskOpts,
stepIdsToConfig map[string]*sqlcv1.ListStepsByIdsRow,
) ([]*V1TaskWithPayload, error) {
if len(tasks) == 0 {
return nil, nil
}
expressions, err := r.getStepExpressions(ctx, tx, stepIdsToConfig)
if err != nil {
return nil, fmt.Errorf("failed to get step expressions: %w", err)
}
concurrencyStrats, err := r.getConcurrencyExpressions(ctx, tx, tenantId, stepIdsToConfig)
if err != nil {
return nil, fmt.Errorf("failed to get concurrency expressions: %w", err)
}
tenantIds := make([]pgtype.UUID, len(tasks))
queues := make([]string, len(tasks))
actionIds := make([]string, len(tasks))
stepIds := make([]pgtype.UUID, len(tasks))
stepReadableIds := make([]string, len(tasks))
workflowIds := make([]pgtype.UUID, len(tasks))
scheduleTimeouts := make([]string, len(tasks))
stepTimeouts := make([]string, len(tasks))
priorities := make([]int32, len(tasks))
stickies := make([]string, len(tasks))
desiredWorkerIds := make([]pgtype.UUID, len(tasks))
externalIds := make([]pgtype.UUID, len(tasks))
displayNames := make([]string, len(tasks))
retryCounts := make([]int32, len(tasks))
additionalMetadatas := make([][]byte, len(tasks))
initialStates := make([]string, len(tasks))
initialStateReasons := make([]pgtype.Text, len(tasks))
dagIds := make([]pgtype.Int8, len(tasks))
dagInsertedAts := make([]pgtype.Timestamptz, len(tasks))
parentStrategyIds := make([][]pgtype.Int8, len(tasks))
strategyIds := make([][]int64, len(tasks))
concurrencyKeys := make([][]string, len(tasks))
parentTaskExternalIds := make([]pgtype.UUID, len(tasks))
parentTaskIds := make([]pgtype.Int8, len(tasks))
parentTaskInsertedAts := make([]pgtype.Timestamptz, len(tasks))
childIndices := make([]pgtype.Int8, len(tasks))
childKeys := make([]pgtype.Text, len(tasks))
stepIndices := make([]int64, len(tasks))
retryBackoffFactors := make([]pgtype.Float8, len(tasks))
retryMaxBackoffs := make([]pgtype.Int4, len(tasks))
createExpressionOpts := make(map[string][]createTaskExpressionEvalOpt, 0)
workflowVersionIds := make([]pgtype.UUID, len(tasks))
workflowRunIds := make([]pgtype.UUID, len(tasks))
externalIdToInput := make(map[string][]byte, len(tasks))
unix := time.Now().UnixMilli()
cleanupParentStrategyIds := make([]int64, 0)
cleanupWorkflowVersionIds := make([]pgtype.UUID, 0)
cleanupWorkflowRunIds := make([]pgtype.UUID, 0)
for i, task := range tasks {
stepConfig := stepIdsToConfig[task.StepId]
tenantIds[i] = sqlchelpers.UUIDFromStr(tenantId)
queues[i] = stepConfig.ActionId // FIXME: make the queue name dynamic
actionIds[i] = stepConfig.ActionId
stepIds[i] = sqlchelpers.UUIDFromStr(task.StepId)
stepReadableIds[i] = stepConfig.ReadableId.String
workflowIds[i] = stepConfig.WorkflowId
workflowVersionIds[i] = stepConfig.WorkflowVersionId
scheduleTimeouts[i] = stepConfig.ScheduleTimeout
stepTimeouts[i] = stepConfig.Timeout.String
externalIds[i] = sqlchelpers.UUIDFromStr(task.ExternalId)
displayNames[i] = fmt.Sprintf("%s-%d", stepConfig.ReadableId.String, unix)
stepIndices[i] = int64(task.StepIndex)
retryBackoffFactors[i] = stepConfig.RetryBackoffFactor
retryMaxBackoffs[i] = stepConfig.RetryMaxBackoff
workflowRunIds[i] = sqlchelpers.UUIDFromStr(task.WorkflowRunId)
// TODO: case on whether this is a v1 or v2 task by looking at the step data. for now,
// we're assuming a v1 task.
externalIdToInput[task.ExternalId] = r.ToV1StepRunData(task.Input).Bytes()
retryCounts[i] = 0
defaultPriority := stepConfig.DefaultPriority
priority := defaultPriority
if task.Priority != nil {
priority = *task.Priority
}
priorities[i] = priority
stickies[i] = string(sqlcv1.V1StickyStrategyNONE)
if stepConfig.WorkflowVersionSticky.Valid {
stickies[i] = string(stepConfig.WorkflowVersionSticky.StickyStrategy)
}
desiredWorkerIds[i] = pgtype.UUID{
Valid: false,
}
if task.DesiredWorkerId != nil {
desiredWorkerIds[i] = sqlchelpers.UUIDFromStr(*task.DesiredWorkerId)
}
initialStates[i] = string(task.InitialState)
if initialStates[i] == "" {
initialStates[i] = string(sqlcv1.V1TaskInitialStateQUEUED)
}
if len(task.AdditionalMetadata) > 0 {
additionalMetadatas[i] = task.AdditionalMetadata
}
if task.DagId != nil && task.DagInsertedAt.Valid {
dagIds[i] = pgtype.Int8{
Int64: *task.DagId,
Valid: true,
}
dagInsertedAts[i] = task.DagInsertedAt
}
if task.ParentTaskExternalId != nil {
parentTaskExternalIds[i] = sqlchelpers.UUIDFromStr(*task.ParentTaskExternalId)
}
if task.ParentTaskId != nil {
parentTaskIds[i] = pgtype.Int8{
Int64: *task.ParentTaskId,
Valid: true,
}
}
if task.ParentTaskInsertedAt != nil {
parentTaskInsertedAts[i] = sqlchelpers.TimestamptzFromTime(*task.ParentTaskInsertedAt)
}
if task.ChildIndex != nil {
childIndices[i] = pgtype.Int8{
Int64: *task.ChildIndex,
Valid: true,
}
}
if task.ChildKey != nil {
childKeys[i] = pgtype.Text{
String: *task.ChildKey,
Valid: true,
}
}
concurrencyKeys[i] = make([]string, 0)
// we write any parent strategy ids to the task regardless of initial state, as we need to know
// when to release the parent concurrency slot
taskParentStrategyIds := make([]pgtype.Int8, 0)
taskStrategyIds := make([]int64, 0)
emptyConcurrencyKeys := make([]string, 0)
if strats, ok := concurrencyStrats[task.StepId]; ok {
for _, strat := range strats {
taskStrategyIds = append(taskStrategyIds, strat.ID)
taskParentStrategyIds = append(taskParentStrategyIds, strat.ParentStrategyID)
emptyConcurrencyKeys = append(emptyConcurrencyKeys, "")
// we only need to cleanup parent strategy ids if the task is not in a QUEUED state, because
// this skips the creation of a concurrency slot and means we might want to cleanup the workflow slot
if strat.ParentStrategyID.Valid && task.InitialState != sqlcv1.V1TaskInitialStateQUEUED {
cleanupParentStrategyIds = append(cleanupParentStrategyIds, strat.ParentStrategyID.Int64)
cleanupWorkflowRunIds = append(cleanupWorkflowRunIds, sqlchelpers.UUIDFromStr(task.WorkflowRunId))
cleanupWorkflowVersionIds = append(cleanupWorkflowVersionIds, stepConfig.WorkflowVersionId)
}
}
}
parentStrategyIds[i] = taskParentStrategyIds
strategyIds[i] = taskStrategyIds
concurrencyKeys[i] = emptyConcurrencyKeys
// only check for concurrency if the task is in a queued state, otherwise we don't need to
// evaluate the expression (and it will likely fail if we do)
if task.InitialState == sqlcv1.V1TaskInitialStateQUEUED {
// if we have a step expression, evaluate the expression
if strats, ok := concurrencyStrats[task.StepId]; ok {
taskConcurrencyKeys := make([]string, 0)
var failTaskError error
for _, strat := range strats {
var additionalMeta map[string]interface{}
if len(additionalMetadatas[i]) > 0 {
if err := json.Unmarshal(additionalMetadatas[i], &additionalMeta); err != nil {
failTaskError = fmt.Errorf("failed to process additional metadata: not a json object")
break
}
}
if task.Input == nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): input is nil", strat.Expression)
break
}
// Make sure to fail the task with a user-friendly error if we can't parse the CEL for the
// concurrency key. Setting failTaskError below inserts the task with an initial state of FAILED.
res, err := r.celParser.ParseAndEvalStepRun(strat.Expression, cel.NewInput(
cel.WithInput(task.Input.Input),
cel.WithAdditionalMetadata(additionalMeta),
cel.WithWorkflowRunID(task.ExternalId),
cel.WithParents(task.Input.TriggerData),
))
if err != nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): %w", strat.Expression, err)
break
}
if res.String == nil {
prefix := "expected string output for concurrency key"
if res.Int != nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): %s, got int", strat.Expression, prefix)
break
}
failTaskError = fmt.Errorf("failed to parse step expression (%s): %s, got unknown type", strat.Expression, prefix)
break
}
taskConcurrencyKeys = append(taskConcurrencyKeys, *res.String)
}
if failTaskError != nil {
// place the task into a failed state
initialStates[i] = string(sqlcv1.V1TaskInitialStateFAILED)
initialStateReasons[i] = pgtype.Text{
String: failTaskError.Error(),
Valid: true,
}
} else {
concurrencyKeys[i] = taskConcurrencyKeys
}
}
}
// next, check for step expressions to evaluate
if task.InitialState == sqlcv1.V1TaskInitialStateQUEUED && stepConfig.ExprCount > 0 {
expressions, ok := expressions[task.StepId]
if ok {
var failTaskError error
opts := make([]createTaskExpressionEvalOpt, 0)
for _, expr := range expressions {
var additionalMeta map[string]interface{}
if len(additionalMetadatas[i]) > 0 {
if err := json.Unmarshal(additionalMetadatas[i], &additionalMeta); err != nil {
failTaskError = fmt.Errorf("failed to process additional metadata: not a json object")
break
}
}
if task.Input == nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): input is nil", expr.Expression)
break
}
res, err := r.celParser.ParseAndEvalStepRun(expr.Expression, cel.NewInput(
cel.WithInput(task.Input.Input),
cel.WithAdditionalMetadata(additionalMeta),
cel.WithWorkflowRunID(task.ExternalId),
cel.WithParents(task.Input.TriggerData),
))
if err != nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): %w", expr.Expression, err)
break
}
if err := r.celParser.CheckStepRunOutAgainstKnownV1(res, expr.Kind); err != nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): %w", expr.Expression, err)
break
}
opts = append(opts, createTaskExpressionEvalOpt{
Key: expr.Key,
Kind: expr.Kind,
ValueStr: res.String,
ValueInt: res.Int,
})
}
if failTaskError != nil {
// place the task into a failed state
initialStates[i] = string(sqlcv1.V1TaskInitialStateFAILED)
initialStateReasons[i] = pgtype.Text{
String: failTaskError.Error(),
Valid: true,
}
} else {
createExpressionOpts[task.ExternalId] = opts
}
} else {
r.l.Warn().Msgf("no expressions found for step %s", task.StepId)
}
}
}
saveQueueCache, err := r.upsertQueues(ctx, tx, tenantId, queues)
if err != nil {
return nil, fmt.Errorf("failed to upsert queues: %w", err)
}
// group by step_id
stepIdsToParams := make(map[string]sqlcv1.CreateTasksParams, 0)
for i, task := range tasks {
params, ok := stepIdsToParams[task.StepId]
if !ok {
params = sqlcv1.CreateTasksParams{
Tenantids: make([]pgtype.UUID, 0),
Queues: make([]string, 0),
Actionids: make([]string, 0),
Stepids: make([]pgtype.UUID, 0),
Stepreadableids: make([]string, 0),
Workflowids: make([]pgtype.UUID, 0),
Scheduletimeouts: make([]string, 0),
Steptimeouts: make([]string, 0),
Priorities: make([]int32, 0),
Stickies: make([]string, 0),
Desiredworkerids: make([]pgtype.UUID, 0),
Externalids: make([]pgtype.UUID, 0),
Displaynames: make([]string, 0),
Retrycounts: make([]int32, 0),
Additionalmetadatas: make([][]byte, 0),
InitialStates: make([]string, 0),
InitialStateReasons: make([]pgtype.Text, 0),
Dagids: make([]pgtype.Int8, 0),
Daginsertedats: make([]pgtype.Timestamptz, 0),
Concurrencyparentstrategyids: make([][]pgtype.Int8, 0),
ConcurrencyStrategyIds: make([][]int64, 0),
ConcurrencyKeys: make([][]string, 0),
ParentTaskExternalIds: make([]pgtype.UUID, 0),
ParentTaskIds: make([]pgtype.Int8, 0),
ParentTaskInsertedAts: make([]pgtype.Timestamptz, 0),
ChildIndex: make([]pgtype.Int8, 0),
ChildKey: make([]pgtype.Text, 0),
StepIndex: make([]int64, 0),
RetryBackoffFactor: make([]pgtype.Float8, 0),
RetryMaxBackoff: make([]pgtype.Int4, 0),
WorkflowVersionIds: make([]pgtype.UUID, 0),
WorkflowRunIds: make([]pgtype.UUID, 0),
Inputs: make([][]byte, 0),
}
}
params.Tenantids = append(params.Tenantids, tenantIds[i])
params.Queues = append(params.Queues, queues[i])
params.Actionids = append(params.Actionids, actionIds[i])
params.Stepids = append(params.Stepids, stepIds[i])
params.Stepreadableids = append(params.Stepreadableids, stepReadableIds[i])
params.Workflowids = append(params.Workflowids, workflowIds[i])
params.Scheduletimeouts = append(params.Scheduletimeouts, scheduleTimeouts[i])
params.Steptimeouts = append(params.Steptimeouts, stepTimeouts[i])
params.Priorities = append(params.Priorities, priorities[i])
params.Stickies = append(params.Stickies, stickies[i])
params.Desiredworkerids = append(params.Desiredworkerids, desiredWorkerIds[i])
params.Externalids = append(params.Externalids, externalIds[i])
params.Displaynames = append(params.Displaynames, displayNames[i])
params.Retrycounts = append(params.Retrycounts, retryCounts[i])
params.Additionalmetadatas = append(params.Additionalmetadatas, additionalMetadatas[i])
params.InitialStates = append(params.InitialStates, initialStates[i])
params.InitialStateReasons = append(params.InitialStateReasons, initialStateReasons[i])
params.Dagids = append(params.Dagids, dagIds[i])
params.Daginsertedats = append(params.Daginsertedats, dagInsertedAts[i])
params.Concurrencyparentstrategyids = append(params.Concurrencyparentstrategyids, parentStrategyIds[i])
params.ConcurrencyStrategyIds = append(params.ConcurrencyStrategyIds, strategyIds[i])
params.ConcurrencyKeys = append(params.ConcurrencyKeys, concurrencyKeys[i])
params.ParentTaskExternalIds = append(params.ParentTaskExternalIds, parentTaskExternalIds[i])
params.ParentTaskIds = append(params.ParentTaskIds, parentTaskIds[i])
params.ParentTaskInsertedAts = append(params.ParentTaskInsertedAts, parentTaskInsertedAts[i])
params.ChildIndex = append(params.ChildIndex, childIndices[i])
params.ChildKey = append(params.ChildKey, childKeys[i])
params.StepIndex = append(params.StepIndex, stepIndices[i])
params.RetryBackoffFactor = append(params.RetryBackoffFactor, retryBackoffFactors[i])
params.RetryMaxBackoff = append(params.RetryMaxBackoff, retryMaxBackoffs[i])
params.WorkflowVersionIds = append(params.WorkflowVersionIds, workflowVersionIds[i])
params.WorkflowRunIds = append(params.WorkflowRunIds, workflowRunIds[i])
if r.payloadStore.DualWritesEnabled() {
// if dual writes are enabled, write the inputs to the tasks table
params.Inputs = append(params.Inputs, externalIdToInput[task.ExternalId])
} else {
// otherwise, write an empty json object to the inputs column
params.Inputs = append(params.Inputs, []byte("{}"))
}
stepIdsToParams[task.StepId] = params
}
res := make([]*V1TaskWithPayload, 0)
// for any initial states which are not queued, create a finalizing task event
eventTaskIdRetryCounts := make([]TaskIdInsertedAtRetryCount, 0)
eventTaskExternalIds := make([]string, 0)
eventDatas := make([][]byte, 0)
eventTypes := make([]sqlcv1.V1TaskEventType, 0)
for stepId, params := range stepIdsToParams {
createdTasks, err := r.queries.CreateTasks(ctx, tx, params)
if err != nil {
return nil, fmt.Errorf("failed to create tasks for step id %s: %w", stepId, err)
}
createdTasksWithPayloads := make([]*V1TaskWithPayload, len(createdTasks))
for i, task := range createdTasks {
input := externalIdToInput[sqlchelpers.UUIDToStr(task.ExternalID)]
withPayload := V1TaskWithPayload{
V1Task: task,
Payload: input,
}
res = append(res, &withPayload)
createdTasksWithPayloads[i] = &withPayload
}
for _, createdTask := range createdTasksWithPayloads {
idRetryCount := TaskIdInsertedAtRetryCount{
Id: createdTask.ID,
InsertedAt: createdTask.InsertedAt,
RetryCount: createdTask.RetryCount,
}
switch createdTask.InitialState {
case sqlcv1.V1TaskInitialStateFAILED:
eventTaskIdRetryCounts = append(eventTaskIdRetryCounts, idRetryCount)
eventTaskExternalIds = append(eventTaskExternalIds, sqlchelpers.UUIDToStr(createdTask.ExternalID))
eventDatas = append(eventDatas, NewFailedTaskOutputEventFromTask(createdTask).Bytes())
eventTypes = append(eventTypes, sqlcv1.V1TaskEventTypeFAILED)
case sqlcv1.V1TaskInitialStateCANCELLED:
eventTaskIdRetryCounts = append(eventTaskIdRetryCounts, idRetryCount)
eventTaskExternalIds = append(eventTaskExternalIds, sqlchelpers.UUIDToStr(createdTask.ExternalID))
eventDatas = append(eventDatas, NewCancelledTaskOutputEventFromTask(createdTask).Bytes())
eventTypes = append(eventTypes, sqlcv1.V1TaskEventTypeCANCELLED)
case sqlcv1.V1TaskInitialStateSKIPPED:
eventTaskIdRetryCounts = append(eventTaskIdRetryCounts, idRetryCount)
eventTaskExternalIds = append(eventTaskExternalIds, sqlchelpers.UUIDToStr(createdTask.ExternalID))
eventDatas = append(eventDatas, NewSkippedTaskOutputEventFromTask(createdTask).Bytes())
eventTypes = append(eventTypes, sqlcv1.V1TaskEventTypeCOMPLETED)
}
}
}
_, err = r.createTaskEvents(
ctx,
tx,
tenantId,
eventTaskIdRetryCounts,
eventTaskExternalIds,
eventDatas,
eventTypes,
make([]string, len(eventTaskIdRetryCounts)),
nil,
)
if err != nil {
return nil, fmt.Errorf("failed to create task events: %w", err)
}
if len(createExpressionOpts) > 0 {
err = r.createExpressionEvals(ctx, tx, res, createExpressionOpts)
if err != nil {
return nil, fmt.Errorf("failed to create expression evals: %w", err)
}
}
if len(cleanupParentStrategyIds) > 0 {
err = r.queries.CleanupWorkflowConcurrencySlotsAfterInsert(
ctx,
tx,
sqlcv1.CleanupWorkflowConcurrencySlotsAfterInsertParams{
Concurrencyparentstrategyids: cleanupParentStrategyIds,
Workflowrunids: cleanupWorkflowRunIds,
Workflowversionids: cleanupWorkflowVersionIds,
},
)
if err != nil {
return nil, fmt.Errorf("failed to cleanup workflow concurrency slots after insert: %w", err)
}
}
// TODO: this should be moved to after the transaction commits
saveQueueCache()
return res, nil
}
// replayTasks updates existing tasks in the database. note that we're using Postgres rules to automatically insert the
// replayed tasks into the queue_items table.
func (r *sharedRepository) replayTasks(
ctx context.Context,
tx sqlcv1.DBTX,
tenantId string,
tasks []ReplayTaskOpts,
) ([]*V1TaskWithPayload, error) {
if len(tasks) == 0 {
return nil, nil
}
uniqueStepIds := make(map[string]struct{})
stepIds := make([]pgtype.UUID, 0)
for _, task := range tasks {
if _, ok := uniqueStepIds[task.StepId]; !ok {
uniqueStepIds[task.StepId] = struct{}{}
stepIds = append(stepIds, sqlchelpers.UUIDFromStr(task.StepId))
}
}
steps, err := r.queries.ListStepsByIds(ctx, tx, sqlcv1.ListStepsByIdsParams{
Ids: stepIds,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return nil, err
}
stepIdsToConfig := make(map[string]*sqlcv1.ListStepsByIdsRow)
for _, step := range steps {
stepIdsToConfig[sqlchelpers.UUIDToStr(step.ID)] = step
}
concurrencyStrats, err := r.getConcurrencyExpressions(ctx, tx, tenantId, stepIdsToConfig)
if err != nil {
return nil, fmt.Errorf("failed to get step expressions: %w", err)
}
taskIds := make([]int64, len(tasks))
taskInsertedAts := make([]pgtype.Timestamptz, len(tasks))
initialStates := make([]string, len(tasks))
initialStateReasons := make([]pgtype.Text, len(tasks))
concurrencyKeys := make([][]string, len(tasks))
additionalMetadatas := make([][]byte, len(tasks))
queues := make([]string, len(tasks))
externalIdToInput := make(map[string][]byte, len(tasks))
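// first pass: resolve queue names and inputs, default the initial states, and (for tasks which
// will be re-queued) evaluate the concurrency key expressions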
for i, task := range tasks {
stepConfig := stepIdsToConfig[task.StepId]
queues[i] = stepConfig.ActionId // FIXME: make the queue name dynamic
taskIds[i] = task.TaskId
taskInsertedAts[i] = task.InsertedAt
// TODO: case on whether this is a v1 or v2 task by looking at the step data. for now,
// we're assuming a v1 task.
if task.Input != nil {
externalIdToInput[task.ExternalId] = r.ToV1StepRunData(task.Input).Bytes()
}
initialStates[i] = string(task.InitialState)
if initialStates[i] == "" {
initialStates[i] = string(sqlcv1.V1TaskInitialStateQUEUED)
}
if len(task.AdditionalMetadata) > 0 {
additionalMetadatas[i] = task.AdditionalMetadata
}
// only check for concurrency if the task is in a queued state, otherwise we don't need to
// evaluate the expression (and it will likely fail if we do)
if task.InitialState == sqlcv1.V1TaskInitialStateQUEUED {
// if we have a step expression, evaluate the expression
if strats, ok := concurrencyStrats[task.StepId]; ok {
taskConcurrencyKeys := make([]string, 0)
var failTaskError error
for _, strat := range strats {
var additionalMeta map[string]interface{}
if len(additionalMetadatas[i]) > 0 {
if err := json.Unmarshal(additionalMetadatas[i], &additionalMeta); err != nil {
failTaskError = fmt.Errorf("failed to process additional metadata: not a json object")
break
}
}
if task.Input == nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): input is nil", strat.Expression)
break
}
res, err := r.celParser.ParseAndEvalStepRun(strat.Expression, cel.NewInput(
cel.WithInput(task.Input.Input),
cel.WithAdditionalMetadata(additionalMeta),
cel.WithWorkflowRunID(task.ExternalId),
cel.WithParents(task.Input.TriggerData),
))
if err != nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): %w", strat.Expression, err)
break
}
if res.String == nil {
prefix := "expected string output for concurrency key"
if res.Int != nil {
failTaskError = fmt.Errorf("failed to parse step expression (%s): %s, got int", strat.Expression, prefix)
break
}
failTaskError = fmt.Errorf("failed to parse step expression (%s): %s, got unknown type", strat.Expression, prefix)
break
}
taskConcurrencyKeys = append(taskConcurrencyKeys, *res.String)
}
if failTaskError != nil {
// place the task into a failed state
initialStates[i] = string(sqlcv1.V1TaskInitialStateFAILED)
initialStateReasons[i] = pgtype.Text{
String: failTaskError.Error(),
Valid: true,
}
} else {
concurrencyKeys[i] = taskConcurrencyKeys
}
}
}
}
saveQueueCache, err := r.upsertQueues(ctx, tx, tenantId, queues)
if err != nil {
return nil, fmt.Errorf("failed to upsert queues: %w", err)
}
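// second pass: group the replay parameters and payload-store writes by step id so that each
// step's tasks are replayed as a single batch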
stepIdsToParams := make(map[string]sqlcv1.ReplayTasksParams)
stepIdsToStorePayloadOpts := make(map[string][]StorePayloadOpts)
for i, task := range tasks {
params, ok := stepIdsToParams[task.StepId]
if !ok {
params = sqlcv1.ReplayTasksParams{
Taskids: make([]int64, 0),
Taskinsertedats: make([]pgtype.Timestamptz, 0),
Inputs: make([][]byte, 0),
InitialStates: make([]string, 0),
InitialStateReasons: make([]pgtype.Text, 0),
Concurrencykeys: make([][]string, 0),
}
}
input := externalIdToInput[task.ExternalId]
params.Taskids = append(params.Taskids, taskIds[i])
params.Taskinsertedats = append(params.Taskinsertedats, taskInsertedAts[i])
params.Inputs = append(params.Inputs, input)
params.InitialStates = append(params.InitialStates, initialStates[i])
params.InitialStateReasons = append(params.InitialStateReasons, initialStateReasons[i])
params.Concurrencykeys = append(params.Concurrencykeys, concurrencyKeys[i])
stepIdsToParams[task.StepId] = params
storePayloadOpts := StorePayloadOpts{
Id: taskIds[i],
InsertedAt: taskInsertedAts[i],
Type: sqlcv1.V1PayloadTypeTASKINPUT,
Payload: input,
TenantId: tenantId,
}
stepIdsToStorePayloadOpts[task.StepId] = append(stepIdsToStorePayloadOpts[task.StepId], storePayloadOpts)
}
res := make([]*V1TaskWithPayload, 0)
// for any initial states which are not queued, create a finalizing task event
eventTaskIdRetryCounts := make([]TaskIdInsertedAtRetryCount, 0)
eventTaskExternalIds := make([]string, 0)
eventDatas := make([][]byte, 0)
eventTypes := make([]sqlcv1.V1TaskEventType, 0)
for stepId, params := range stepIdsToParams {
replayRes, err := r.queries.ReplayTasks(ctx, tx, params)
if err != nil {
return nil, fmt.Errorf("failed to replay tasks for step id %s: %w", stepId, err)
}
storePayloadOpts, ok := stepIdsToStorePayloadOpts[stepId]
if !ok {
return nil, fmt.Errorf("missing payload store opts for step id %s", stepId)
}
err = r.payloadStore.Store(ctx, tx, storePayloadOpts...)
if err != nil {
return nil, fmt.Errorf("failed to store payloads for step id %s: %w", stepId, err)
}
replayResWithPayloads := make([]*V1TaskWithPayload, len(replayRes))
for i, task := range replayRes {
input := externalIdToInput[sqlchelpers.UUIDToStr(task.ExternalID)]
withPayload := V1TaskWithPayload{
V1Task: task,
Payload: input,
}
replayResWithPayloads[i] = &withPayload
res = append(res, &withPayload)
}
for _, replayedTask := range replayResWithPayloads {
idRetryCount := TaskIdInsertedAtRetryCount{
Id: replayedTask.ID,
InsertedAt: replayedTask.InsertedAt,
RetryCount: replayedTask.RetryCount,
}
switch replayedTask.InitialState {
case sqlcv1.V1TaskInitialStateFAILED:
eventTaskIdRetryCounts = append(eventTaskIdRetryCounts, idRetryCount)
eventTaskExternalIds = append(eventTaskExternalIds, sqlchelpers.UUIDToStr(replayedTask.ExternalID))
eventDatas = append(eventDatas, NewFailedTaskOutputEventFromTask(replayedTask).Bytes())
eventTypes = append(eventTypes, sqlcv1.V1TaskEventTypeFAILED)
case sqlcv1.V1TaskInitialStateCANCELLED:
eventTaskIdRetryCounts = append(eventTaskIdRetryCounts, idRetryCount)
eventTaskExternalIds = append(eventTaskExternalIds, sqlchelpers.UUIDToStr(replayedTask.ExternalID))
eventDatas = append(eventDatas, NewCancelledTaskOutputEventFromTask(replayedTask).Bytes())
eventTypes = append(eventTypes, sqlcv1.V1TaskEventTypeCANCELLED)
case sqlcv1.V1TaskInitialStateSKIPPED:
eventTaskIdRetryCounts = append(eventTaskIdRetryCounts, idRetryCount)
eventTaskExternalIds = append(eventTaskExternalIds, sqlchelpers.UUIDToStr(replayedTask.ExternalID))
eventDatas = append(eventDatas, NewSkippedTaskOutputEventFromTask(replayedTask).Bytes())
eventTypes = append(eventTypes, sqlcv1.V1TaskEventTypeCOMPLETED)
}
}
}
_, err = r.createTaskEvents(
ctx,
tx,
tenantId,
eventTaskIdRetryCounts,
eventTaskExternalIds,
eventDatas,
eventTypes,
make([]string, len(eventTaskIdRetryCounts)),
nil,
)
if err != nil {
return nil, fmt.Errorf("failed to create task events: %w", err)
}
// TODO: this should be moved to after the transaction commits
saveQueueCache()
return res, nil
}
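// getConcurrencyExpressions returns the concurrency strategies for each step which declares at
// least one, keyed by step id. within a step, strategies with a parent strategy sort first
// (ordered by parent strategy id), followed by the remaining strategies ordered by id, so that
// concurrency keys are evaluated in a deterministic order.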
func (r *sharedRepository) getConcurrencyExpressions(
ctx context.Context,
tx sqlcv1.DBTX,
tenantId string,
stepIdsToConfig map[string]*sqlcv1.ListStepsByIdsRow,
) (map[string][]*sqlcv1.V1StepConcurrency, error) {
stepIdsWithConcurrency := make(map[string]struct{})
for _, step := range stepIdsToConfig {
if step.ConcurrencyCount > 0 {
stepIdsWithConcurrency[sqlchelpers.UUIDToStr(step.ID)] = struct{}{}
}
}
if len(stepIdsWithConcurrency) == 0 {
return nil, nil
}
stepIds := make([]pgtype.UUID, 0, len(stepIdsWithConcurrency))
for stepId := range stepIdsWithConcurrency {
stepIds = append(stepIds, sqlchelpers.UUIDFromStr(stepId))
}
strats, err := r.queries.ListConcurrencyStrategiesByStepId(ctx, tx, sqlcv1.ListConcurrencyStrategiesByStepIdParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Stepids: stepIds,
})
if err != nil {
return nil, err
}
stepIdToStrats := make(map[string][]*sqlcv1.V1StepConcurrency)
sort.SliceStable(strats, func(i, j int) bool {
iStrat := strats[i]
jStrat := strats[j]
if iStrat.ParentStrategyID.Valid && jStrat.ParentStrategyID.Valid && iStrat.ParentStrategyID.Int64 != jStrat.ParentStrategyID.Int64 {
return iStrat.ParentStrategyID.Int64 < jStrat.ParentStrategyID.Int64
}
if iStrat.ParentStrategyID.Valid && !jStrat.ParentStrategyID.Valid {
return true
}
if !iStrat.ParentStrategyID.Valid && jStrat.ParentStrategyID.Valid {
return false
}
return iStrat.ID < jStrat.ID
})
for _, strat := range strats {
stepId := sqlchelpers.UUIDToStr(strat.StepID)
if _, ok := stepIdToStrats[stepId]; !ok {
stepIdToStrats[stepId] = make([]*sqlcv1.V1StepConcurrency, 0)
}
stepIdToStrats[stepId] = append(stepIdToStrats[stepId], strat)
}
return stepIdToStrats, nil
}
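// getStepExpressions returns the step expressions for each step which declares any, keyed by
// step id. steps without expressions are omitted from the result.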
func (r *sharedRepository) getStepExpressions(
ctx context.Context,
tx sqlcv1.DBTX,
stepIdsToConfig map[string]*sqlcv1.ListStepsByIdsRow,
) (map[string][]*sqlcv1.StepExpression, error) {
stepIdsWithExpressions := make(map[string]struct{})
for _, step := range stepIdsToConfig {
if step.ExprCount > 0 {
stepIdsWithExpressions[sqlchelpers.UUIDToStr(step.ID)] = struct{}{}
}
}
if len(stepIdsWithExpressions) == 0 {
return map[string][]*sqlcv1.StepExpression{}, nil
}
stepIds := make([]pgtype.UUID, 0, len(stepIdsWithExpressions))
for stepId := range stepIdsWithExpressions {
stepIds = append(stepIds, sqlchelpers.UUIDFromStr(stepId))
}
expressions, err := r.queries.ListStepExpressions(ctx, tx, stepIds)
if err != nil {
return nil, err
}
stepIdToExpressions := make(map[string][]*sqlcv1.StepExpression)
for _, expression := range expressions {
stepId := sqlchelpers.UUIDToStr(expression.StepId)
if _, ok := stepIdToExpressions[stepId]; !ok {
stepIdToExpressions[stepId] = make([]*sqlcv1.StepExpression, 0)
}
stepIdToExpressions[stepId] = append(stepIdToExpressions[stepId], expression)
}
return stepIdToExpressions, nil
}
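// createTaskEventsAfterRelease writes terminal task events for a batch of just-released tasks,
// filtering out rows which do not correspond to the current retry so that stale retries don't
// emit terminal events.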
func (r *sharedRepository) createTaskEventsAfterRelease(
ctx context.Context,
tx sqlcv1.DBTX,
tenantId string,
taskIdRetryCounts []TaskIdInsertedAtRetryCount,
outputs [][]byte,
releasedTasks []*sqlcv1.ReleaseTasksRow,
eventType sqlcv1.V1TaskEventType,
) ([]InternalTaskEvent, error) {
if len(taskIdRetryCounts) != len(releasedTasks) || len(taskIdRetryCounts) != len(outputs) {
return nil, fmt.Errorf("failed to release all tasks")
}
datas := make([][]byte, len(releasedTasks))
externalIds := make([]string, len(releasedTasks))
isCurrentRetry := make([]bool, len(releasedTasks))
for i, releasedTask := range releasedTasks {
datas[i] = outputs[i]
externalIds[i] = sqlchelpers.UUIDToStr(releasedTask.ExternalID)
isCurrentRetry[i] = releasedTask.IsCurrentRetry
}
// filter out any rows which are not the current retry
filteredTaskIdRetryCounts := make([]TaskIdInsertedAtRetryCount, 0)
filteredDatas := make([][]byte, 0)
filteredExternalIds := make([]string, 0)
for i := range datas {
if !isCurrentRetry[i] {
continue
}
filteredTaskIdRetryCounts = append(filteredTaskIdRetryCounts, taskIdRetryCounts[i])
filteredDatas = append(filteredDatas, datas[i])
filteredExternalIds = append(filteredExternalIds, externalIds[i])
}
return r.createTaskEvents(
ctx,
tx,
tenantId,
filteredTaskIdRetryCounts,
filteredExternalIds,
filteredDatas,
makeEventTypeArr(eventType, len(filteredExternalIds)),
make([]string, len(filteredExternalIds)),
nil,
)
}
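// createTaskEvents inserts one task event row per task and persists the event payloads to the
// payload store. the inline data column is only populated when task-event dual writes are
// enabled; if eventExternalIds is nil, a fresh external id is generated for each event.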
func (r *sharedRepository) createTaskEvents(
ctx context.Context,
dbtx sqlcv1.DBTX,
tenantId string,
tasks []TaskIdInsertedAtRetryCount,
taskExternalIds []string,
eventDatas [][]byte,
eventTypes []sqlcv1.V1TaskEventType,
eventKeys []string,
eventExternalIds *[]string,
) ([]InternalTaskEvent, error) {
if len(tasks) != len(eventDatas) {
return nil, fmt.Errorf("mismatched task and event data lengths")
}
taskIds := make([]int64, len(tasks))
taskInsertedAts := make([]pgtype.Timestamptz, len(tasks))
retryCounts := make([]int32, len(tasks))
eventTypesStrs := make([]string, len(tasks))
paramDatas := make([][]byte, len(tasks))
paramKeys := make([]pgtype.Text, len(tasks))
externalIds := make([]pgtype.UUID, len(tasks))
internalTaskEvents := make([]InternalTaskEvent, len(tasks))
externalIdToData := make(map[pgtype.UUID][]byte, len(tasks))
for i, task := range tasks {
taskIds[i] = task.Id
taskInsertedAts[i] = task.InsertedAt
retryCounts[i] = task.RetryCount
eventTypesStrs[i] = string(eventTypes[i])
externalId := sqlchelpers.UUIDFromStr(uuid.NewString())
if eventExternalIds != nil {
externalId = sqlchelpers.UUIDFromStr((*eventExternalIds)[i])
}
externalIds[i] = externalId
// important: always set this to `eventDatas[i]`. if we allowed it to be nil, we'd get errors
// downstream when reading the payload back and parsing it in `registerChildWorkflows`, because
// it would try to unmarshal the `nil` value.
externalIdToData[externalId] = eventDatas[i]
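// the payload store is always written below; the inline data column is only populated when
// task-event dual writes are enabled and the payload is non-empty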
if len(eventDatas[i]) == 0 || !r.payloadStore.TaskEventDualWritesEnabled() {
paramDatas[i] = nil
} else {
paramDatas[i] = eventDatas[i]
}
if eventKeys[i] != "" {
paramKeys[i] = pgtype.Text{
String: eventKeys[i],
Valid: true,
}
}
internalTaskEvents[i] = InternalTaskEvent{
TaskID: task.Id,
TaskExternalID: taskExternalIds[i],
TenantID: tenantId,
RetryCount: task.RetryCount,
EventType: eventTypes[i],
EventKey: eventKeys[i],
Data: eventDatas[i],
}
}
taskEvents, err := r.queries.CreateTaskEvents(ctx, dbtx, sqlcv1.CreateTaskEventsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Taskids: taskIds,
Taskinsertedats: taskInsertedAts,
Retrycounts: retryCounts,
Eventtypes: eventTypesStrs,
Datas: paramDatas,
Eventkeys: paramKeys,
Externalids: externalIds,
})
if err != nil {
return nil, err
}
storePayloadOpts := make([]StorePayloadOpts, len(taskEvents))
for i, taskEvent := range taskEvents {
data := externalIdToData[taskEvent.ExternalID]
storePayloadOpts[i] = StorePayloadOpts{
Id: taskEvent.ID,
InsertedAt: taskEvent.InsertedAt,
Type: sqlcv1.V1PayloadTypeTASKEVENTDATA,
Payload: data,
TenantId: tenantId,
}
}
err = r.payloadStore.Store(ctx, dbtx, storePayloadOpts...)
if err != nil {
return nil, fmt.Errorf("failed to store task event payloads: %w", err)
}
return internalTaskEvents, nil
}
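// makeEventTypeArr returns a slice of length n with every element set to status.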
func makeEventTypeArr(status sqlcv1.V1TaskEventType, n int) []sqlcv1.V1TaskEventType {
a := make([]sqlcv1.V1TaskEventType, n)
for i := range a {
a[i] = status
}
return a
}
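// hash maps a string to an int64 via FNV-1a. it's used to derive Postgres advisory lock keys,
// so the only requirement is that it's deterministic; the uint64 -> int64 conversion may wrap,
// which is fine for this purpose.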
func hash(s string) int64 {
h := fnv.New64a()
h.Write([]byte(s))
return int64(h.Sum64())
}
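// ReplayTasks replays a set of tasks for a tenant. at a high level, it:
//  1. takes a per-tenant advisory lock so that concurrent replays are serialized,
//  2. lists the tasks (plus their DAG descendants) and locks them and their DAGs for update,
//  3. runs preflight checks, discarding tasks (or entire DAGs) which cannot safely be replayed,
//  4. immediately replays the root tasks of each replayed subtree, and
//  5. recreates signal and trigger matches so that downstream tasks re-run once their parents
//     finish.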
func (r *TaskRepositoryImpl) ReplayTasks(ctx context.Context, tenantId string, tasks []TaskIdInsertedAtRetryCount) (*ReplayTasksResult, error) {
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, 30000)
if err != nil {
return nil, err
}
// defer the rollback before acquiring the advisory lock so the transaction is cleaned up even
// if lock acquisition fails
defer rollback()
err = r.queries.AdvisoryLock(ctx, tx, hash("replay_"+tenantId))
if err != nil {
return nil, fmt.Errorf("failed to acquire advisory lock: %w", err)
}
taskIds := make([]int64, len(tasks))
taskInsertedAts := make([]pgtype.Timestamptz, len(tasks))
for i, task := range tasks {
taskIds[i] = task.Id
taskInsertedAts[i] = task.InsertedAt
}
// list the tasks (augmented with their descendants) and lock them for update
lockedTasks, err := r.queries.ListTasksForReplay(ctx, tx, sqlcv1.ListTasksForReplayParams{
Taskids: taskIds,
Taskinsertedats: taskInsertedAts,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return nil, fmt.Errorf("failed to list tasks for replay: %w", err)
}
lockedTaskIds := make([]int64, len(lockedTasks))
lockedTaskInsertedAts := make([]pgtype.Timestamptz, len(lockedTasks))
subtreeStepIds := make(map[int64]map[string]bool) // dag id -> step id -> true
subtreeExternalIds := make(map[string]struct{})
dagIdsToLockMap := make(map[int64]struct{})
minInsertedAt := sqlchelpers.TimestamptzFromTime(time.Now()) // current time as a placeholder - will be overwritten
for i, task := range lockedTasks {
lockedTaskIds[i] = task.ID
lockedTaskInsertedAts[i] = task.InsertedAt
if task.DagID.Valid {
if _, ok := subtreeStepIds[task.DagID.Int64]; !ok {
subtreeStepIds[task.DagID.Int64] = make(map[string]bool)
}
dagIdsToLockMap[task.DagID.Int64] = struct{}{}
subtreeStepIds[task.DagID.Int64][sqlchelpers.UUIDToStr(task.StepID)] = true
subtreeExternalIds[sqlchelpers.UUIDToStr(task.ExternalID)] = struct{}{}
}
if task.InsertedAt.Time.Before(minInsertedAt.Time) {
minInsertedAt = task.InsertedAt
}
}
// lock all tasks in the DAGs
dagIdsToLock := make([]int64, 0, len(dagIdsToLockMap))
for dagId := range dagIdsToLockMap {
dagIdsToLock = append(dagIdsToLock, dagId)
}
successfullyLockedDAGIds, err := r.queries.LockDAGsForReplay(ctx, tx, sqlcv1.LockDAGsForReplayParams{
Dagids: dagIdsToLock,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return nil, fmt.Errorf("failed to lock DAGs for replay: %w", err)
}
successfullyLockedDAGsMap := make(map[int64]bool)
for _, dagId := range successfullyLockedDAGIds {
successfullyLockedDAGsMap[dagId] = true
}
// Discard tasks which can't be replayed. Discard rules are as follows:
// 1. If a task is in a running state, discard it.
// 2. If a task is in a running state and has a DAG id, discard all tasks in the DAG.
// 3. If a task has a DAG id but it is not present in the successfully locked DAGs, discard it.
dagIdsFailedPreflight := make(map[int64]bool)
preflightDAGs, err := r.queries.PreflightCheckDAGsForReplay(ctx, tx, sqlcv1.PreflightCheckDAGsForReplayParams{
Dagids: successfullyLockedDAGIds,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return nil, fmt.Errorf("failed to preflight check DAGs for replay: %w", err)
}
for _, dag := range preflightDAGs {
if dag.StepCount != dag.TaskCount {
dagIdsFailedPreflight[dag.ID] = true
}
}
tasksFailedPreflight := make(map[int64]bool)
failedPreflightChecks, err := r.queries.PreflightCheckTasksForReplay(ctx, tx, sqlcv1.PreflightCheckTasksForReplayParams{
Taskids: lockedTaskIds,
Taskinsertedats: lockedTaskInsertedAts,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Mininsertedat: minInsertedAt,
})
if err != nil {
return nil, fmt.Errorf("failed to preflight check tasks for replay: %w", err)
}
for _, task := range failedPreflightChecks {
tasksFailedPreflight[task.ID] = true
}
// group tasks by their dag_id, if it exists
dagIdsToChildTasks := make(map[int64][]*sqlcv1.ListTasksForReplayRow)
dagIds := make(map[int64]struct{})
// figure out which tasks to replay immediately
replayOpts := make([]ReplayTaskOpts, 0)
replayedTasks := make([]TaskIdInsertedAtRetryCount, 0)
retrieveOpts := make([]RetrievePayloadOpts, len(lockedTasks))
for i, task := range lockedTasks {
retrieveOpts[i] = RetrievePayloadOpts{
Id: task.ID,
InsertedAt: task.InsertedAt,
Type: sqlcv1.V1PayloadTypeTASKINPUT,
TenantId: sqlchelpers.UUIDFromStr(tenantId),
}
}
payloads, err := r.payloadStore.BulkRetrieve(ctx, retrieveOpts...)
if err != nil {
return nil, fmt.Errorf("failed to bulk retrieve task inputs: %w", err)
}
for _, task := range lockedTasks {
// check whether to discard the task
if task.DagID.Valid && !successfullyLockedDAGsMap[task.DagID.Int64] {
r.l.Warn().Int64("task_id", task.ID).Msg("discarding task, could not lock DAG")
continue
}
if task.DagID.Valid && dagIdsFailedPreflight[task.DagID.Int64] {
r.l.Warn().Int64("task_id", task.ID).Msg("discarding task, failed preflight check for DAG")
continue
}
if tasksFailedPreflight[task.ID] {
r.l.Warn().Int64("task_id", task.ID).Msg("discarding task, failed preflight check")
continue
}
replayedTasks = append(replayedTasks, TaskIdInsertedAtRetryCount{
Id: task.ID,
InsertedAt: task.InsertedAt,
RetryCount: task.RetryCount,
})
if task.DagID.Valid {
dagIds[task.DagID.Int64] = struct{}{}
}
if task.DagID.Valid && len(task.Parents) > 0 {
isParentBeingReplayed := false
for _, parent := range task.Parents {
if subtreeStepIds[task.DagID.Int64][sqlchelpers.UUIDToStr(parent)] {
isParentBeingReplayed = true
break
}
}
if isParentBeingReplayed {
if _, ok := dagIdsToChildTasks[task.DagID.Int64]; !ok {
dagIdsToChildTasks[task.DagID.Int64] = make([]*sqlcv1.ListTasksForReplayRow, 0)
}
dagIdsToChildTasks[task.DagID.Int64] = append(dagIdsToChildTasks[task.DagID.Int64], task)
continue
}
}
if task.DagID.Valid && task.JobKind == sqlcv1.JobKindONFAILURE {
// we need to check if there are other steps in the subtree
doesOnFailureHaveOtherSteps := false
for stepId := range subtreeStepIds[task.DagID.Int64] {
if stepId == sqlchelpers.UUIDToStr(task.StepID) {
continue
}
doesOnFailureHaveOtherSteps = true
break
}
if doesOnFailureHaveOtherSteps {
if _, ok := dagIdsToChildTasks[task.DagID.Int64]; !ok {
dagIdsToChildTasks[task.DagID.Int64] = make([]*sqlcv1.ListTasksForReplayRow, 0)
}
dagIdsToChildTasks[task.DagID.Int64] = append(dagIdsToChildTasks[task.DagID.Int64], task)
continue
}
}
retrieveOpt := RetrievePayloadOpts{
Id: task.ID,
InsertedAt: task.InsertedAt,
Type: sqlcv1.V1PayloadTypeTASKINPUT,
TenantId: sqlchelpers.UUIDFromStr(tenantId),
}
input, ok := payloads[retrieveOpt]
if !ok {
// If the input wasn't found in the payload store,
// fall back to the input stored on the task itself.
r.l.Error().Msgf("ReplayTasks: task %s with ID %d and inserted_at %s has empty payload, falling back to input", task.ExternalID.String(), task.ID, task.InsertedAt.Time)
input = task.Input
}
replayOpts = append(replayOpts, ReplayTaskOpts{
TaskId: task.ID,
InsertedAt: task.InsertedAt,
StepId: sqlchelpers.UUIDToStr(task.StepID),
ExternalId: sqlchelpers.UUIDToStr(task.ExternalID),
InitialState: sqlcv1.V1TaskInitialStateQUEUED,
AdditionalMetadata: task.AdditionalMetadata,
// NOTE: we require the input to be passed in to the replay method so we can re-evaluate the concurrency keys
// Ideally we could preserve the same concurrency keys, but the replay tasks method is currently unaware of existing concurrency
// keys because they may change between retries.
Input: r.newTaskInputFromExistingBytes(input),
})
}
dagIdsArr := make([]int64, 0, len(dagIds))
for dagId := range dagIds {
dagIdsArr = append(dagIdsArr, dagId)
}
allTasksInDAGs, err := r.queries.ListAllTasksInDags(ctx, tx, sqlcv1.ListAllTasksInDagsParams{
Dagids: dagIdsArr,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return nil, fmt.Errorf("failed to list all tasks in DAGs: %w", err)
}
dagIdsToAllTasks := make(map[int64][]*sqlcv1.ListAllTasksInDagsRow)
stepIdsInDAGs := make([]pgtype.UUID, 0)
for _, task := range allTasksInDAGs {
if _, ok := dagIdsToAllTasks[task.DagID.Int64]; !ok {
dagIdsToAllTasks[task.DagID.Int64] = make([]*sqlcv1.ListAllTasksInDagsRow, 0)
}
stepIdsInDAGs = append(stepIdsInDAGs, task.StepID)
dagIdsToAllTasks[task.DagID.Int64] = append(dagIdsToAllTasks[task.DagID.Int64], task)
}
upsertedTasks := make([]*V1TaskWithPayload, 0)
// NOTE: the tasks which are passed in represent a *subtree* of the DAG.
if len(replayOpts) > 0 {
upsertedTasks, err = r.replayTasks(ctx, tx, tenantId, replayOpts)
if err != nil {
return nil, fmt.Errorf("failed to replay existing tasks: %w", err)
}
}
// for any tasks which are child tasks, we need to reset the match signals for the parent tasks
spawnedChildTasks := make([]*sqlcv1.ListTasksForReplayRow, 0)
for _, task := range lockedTasks {
if task.ParentTaskID.Valid {
spawnedChildTasks = append(spawnedChildTasks, task)
}
}
eventMatches := make([]CreateMatchOpts, 0)
if len(spawnedChildTasks) > 0 {
// construct a list of signals to reset
signalEventKeys := make([]string, 0)
parentTaskIds := make([]int64, 0)
parentTaskInsertedAts := make([]pgtype.Timestamptz, 0)
for _, task := range spawnedChildTasks {
if !task.ChildIndex.Valid {
// TODO: handle error better/check with validation that this won't happen
r.l.Error().Msg("could not find child key or index for child workflow")
continue
}
var childKey *string
if task.ChildKey.Valid {
childKey = &task.ChildKey.String
}
parentExternalId := sqlchelpers.UUIDToStr(task.ParentTaskExternalID)
k := getChildSignalEventKey(parentExternalId, task.StepIndex, task.ChildIndex.Int64, childKey)
signalEventKeys = append(signalEventKeys, k)
parentTaskIds = append(parentTaskIds, task.ParentTaskID.Int64)
parentTaskInsertedAts = append(parentTaskInsertedAts, task.ParentTaskInsertedAt)
eventMatches = append(eventMatches, CreateMatchOpts{
Kind: sqlcv1.V1MatchKindSIGNAL,
Conditions: getChildWorkflowGroupMatches(sqlchelpers.UUIDToStr(task.ExternalID), task.StepReadableID),
SignalExternalId: &parentExternalId,
SignalTaskId: &task.ParentTaskID.Int64,
SignalTaskInsertedAt: task.ParentTaskInsertedAt,
SignalKey: &k,
})
}
err = r.queries.DeleteMatchingSignalEvents(ctx, tx, sqlcv1.DeleteMatchingSignalEventsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Eventkeys: signalEventKeys,
Taskids: parentTaskIds,
Taskinsertedats: parentTaskInsertedAts,
Eventtype: sqlcv1.V1TaskEventTypeSIGNALCOMPLETED,
})
if err != nil {
return nil, fmt.Errorf("failed to delete matching signal events: %w", err)
}
}
// For any DAGs, reset all match conditions which refer to internal events within the subtree of the DAG.
// we do not reset other match conditions (for example, ones which refer to completed events for tasks
// which are outside of this subtree). otherwise, we would end up in a state where these events would
// never be matched.
// if any steps have additional match conditions, query for the additional matches
stepsToAdditionalMatches := make(map[string][]*sqlcv1.V1StepMatchCondition)
if len(stepIdsInDAGs) > 0 {
additionalMatches, err := r.queries.ListStepMatchConditions(ctx, r.pool, sqlcv1.ListStepMatchConditionsParams{
Stepids: sqlchelpers.UniqueSet(stepIdsInDAGs),
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
})
if err != nil {
return nil, fmt.Errorf("failed to list step match conditions: %w", err)
}
for _, match := range additionalMatches {
stepId := sqlchelpers.UUIDToStr(match.StepID)
stepsToAdditionalMatches[stepId] = append(stepsToAdditionalMatches[stepId], match)
}
}
for dagId, tasks := range dagIdsToChildTasks {
allTasks := dagIdsToAllTasks[dagId]
for _, task := range tasks {
taskExternalId := sqlchelpers.UUIDToStr(task.ExternalID)
stepId := sqlchelpers.UUIDToStr(task.StepID)
switch {
case task.JobKind == sqlcv1.JobKindONFAILURE:
conditions := make([]GroupMatchCondition, 0)
groupId := uuid.NewString()
for _, otherTask := range allTasks {
if sqlchelpers.UUIDToStr(otherTask.StepID) == stepId {
continue
}
otherExternalId := sqlchelpers.UUIDToStr(otherTask.ExternalID)
readableId := otherTask.StepReadableID
conditions = append(conditions, getParentOnFailureGroupMatches(groupId, otherExternalId, readableId)...)
}
eventMatches = append(eventMatches, CreateMatchOpts{
Kind: sqlcv1.V1MatchKindTRIGGER,
Conditions: conditions,
TriggerExternalId: &taskExternalId,
TriggerStepId: &stepId,
TriggerStepIndex: pgtype.Int8{
Int64: task.StepIndex,
Valid: true,
},
TriggerDAGId: &task.DagID.Int64,
TriggerDAGInsertedAt: task.DagInsertedAt,
// NOTE: we don't need to set parent task id/child index/child key because
// the task already exists
TriggerExistingTaskId: &task.ID,
TriggerExistingTaskInsertedAt: task.InsertedAt,
})
default:
conditions := make([]GroupMatchCondition, 0)
cancelGroupId := uuid.NewString()
additionalMatches, ok := stepsToAdditionalMatches[stepId]
if !ok {
additionalMatches = make([]*sqlcv1.V1StepMatchCondition, 0)
}
for _, parent := range task.Parents {
// FIXME: n^2 complexity here, fix it.
for _, otherTask := range allTasks {
if otherTask.StepID == parent {
parentExternalId := sqlchelpers.UUIDToStr(otherTask.ExternalID)
readableId := otherTask.StepReadableID
hasUserEventOrSleepMatches := false
hasAnySkippingParentOverrides := false
parentOverrideMatches := make([]*sqlcv1.V1StepMatchCondition, 0)
for _, match := range additionalMatches {
if match.Kind == sqlcv1.V1StepMatchConditionKindPARENTOVERRIDE {
if match.ParentReadableID.String == readableId {
parentOverrideMatches = append(parentOverrideMatches, match)
}
if match.Action == sqlcv1.V1MatchConditionActionSKIP {
hasAnySkippingParentOverrides = true
}
} else {
hasUserEventOrSleepMatches = true
}
}
conditions = append(conditions, getParentInDAGGroupMatch(cancelGroupId, parentExternalId, readableId, parentOverrideMatches, hasUserEventOrSleepMatches, hasAnySkippingParentOverrides)...)
}
}
}
// create an event match
eventMatches = append(eventMatches, CreateMatchOpts{
Kind: sqlcv1.V1MatchKindTRIGGER,
Conditions: conditions,
TriggerExternalId: &taskExternalId,
TriggerStepId: &stepId,
TriggerStepIndex: pgtype.Int8{
Int64: task.StepIndex,
Valid: true,
},
TriggerDAGId: &task.DagID.Int64,
TriggerDAGInsertedAt: task.DagInsertedAt,
// NOTE: we don't need to set parent task id/child index/child key because
// the task already exists
TriggerExistingTaskId: &task.ID,
TriggerExistingTaskInsertedAt: task.InsertedAt,
})
}
}
}
// reconstruct group conditions
reconstructedMatches, candidateEvents, err := r.reconstructGroupConditions(ctx, tx, tenantId, subtreeExternalIds, eventMatches)
if err != nil {
return nil, fmt.Errorf("failed to reconstruct group conditions: %w", err)
}
// create the event matches
err = r.createEventMatches(ctx, tx, tenantId, reconstructedMatches)
if err != nil {
return nil, fmt.Errorf("failed to create event matches: %w", err)
}
// process event matches
// TODO: signal the event matches to the caller
internalMatchResults, err := r.processEventMatches(ctx, tx, tenantId, candidateEvents, sqlcv1.V1EventTypeINTERNAL)
if err != nil {
return nil, fmt.Errorf("failed to process internal event matches: %w", err)
}
if err := commit(ctx); err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}
return &ReplayTasksResult{
ReplayedTasks: replayedTasks,
UpsertedTasks: upsertedTasks,
InternalEventResults: internalMatchResults,
}, nil
}
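// reconstructGroupConditions rewrites trigger match conditions whose internal events refer to
// tasks outside the replayed subtree: if the referenced event has already occurred, its data is
// attached to the condition and the event is returned as a candidate for immediate matching.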
func (r *TaskRepositoryImpl) reconstructGroupConditions(
ctx context.Context,
tx sqlcv1.DBTX,
tenantId string,
subtreeExternalIds map[string]struct{},
eventMatches []CreateMatchOpts,
) ([]CreateMatchOpts, []CandidateEventMatch, error) {
// track down completed tasks and failed tasks which represent parents that aren't in the subtree
// of the DAG. for these tasks, we need to write the match conditions which refer to these tasks
// as satisfied match conditions.
// in other words, if the group match condition is an INTERNAL event and refers to a parentExternalId
// which is NOT in the subtree of what we're replaying, it represents a group condition where we'd like
// to query the task_events table to ensure the event has already occurred. if it has, we can mark the
// group condition as satisfied.
externalIds := make([]pgtype.UUID, 0)
eventTypes := make([][]string, 0)
for _, match := range eventMatches {
if match.TriggerDAGId == nil {
continue
}
for _, groupCondition := range match.Conditions {
if groupCondition.EventType == sqlcv1.V1EventTypeINTERNAL && groupCondition.EventResourceHint != nil {
externalId := *groupCondition.EventResourceHint
// if the parent task is not in the subtree, we need to query the task_events table
// to ensure the event has already occurred
if _, ok := subtreeExternalIds[externalId]; !ok {
externalIds = append(externalIds, sqlchelpers.UUIDFromStr(*groupCondition.EventResourceHint))
eventTypes = append(eventTypes, []string{groupCondition.EventKey})
}
}
}
}
// for candidate group matches, track down the task events which satisfy the group match conditions.
// we do this by constructing arrays for dag ids, external ids and event types, and then querying
// by the dag_id -> v1_task (on external_id) -> v1_task_event (on event type)
//
// NOTE: at this point, we have already deleted the tasks and events that are in the subtree, so we
// don't have to worry about collisions with the tasks we're replaying.
matchedEvents, err := r.queries.ListMatchingTaskEvents(ctx, tx, sqlcv1.ListMatchingTaskEventsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Taskexternalids: externalIds,
Eventtypes: eventTypes,
})
if err != nil {
return nil, nil, err
}
foundMatchKeys := make(map[string]*sqlcv1.ListMatchingTaskEventsRow)
for _, eventMatch := range matchedEvents {
key := fmt.Sprintf("%s:%s", sqlchelpers.UUIDToStr(eventMatch.ExternalID), string(eventMatch.EventType))
foundMatchKeys[key] = eventMatch
}
resMatches := make([]CreateMatchOpts, 0)
resCandidateEvents := make([]CandidateEventMatch, 0)
// for each group condition, if we have a match, mark the group condition as satisfied and use
// the data from the match to update the group condition.
for _, match := range eventMatches {
if match.TriggerDAGId == nil {
resMatches = append(resMatches, match)
continue
}
conditions := make([]GroupMatchCondition, 0)
for _, groupCondition := range match.Conditions {
cond := groupCondition
if groupCondition.EventType == sqlcv1.V1EventTypeINTERNAL && groupCondition.EventResourceHint != nil {
key := fmt.Sprintf("%s:%s", *groupCondition.EventResourceHint, string(groupCondition.EventKey))
if match, ok := foundMatchKeys[key]; ok {
cond.Data = match.Data
taskExternalId := sqlchelpers.UUIDToStr(match.ExternalID)
resCandidateEvents = append(resCandidateEvents, CandidateEventMatch{
ID: uuid.NewString(),
EventTimestamp: match.CreatedAt.Time,
Key: string(match.EventType),
ResourceHint: &taskExternalId,
Data: match.Data,
})
}
}
conditions = append(conditions, cond)
}
match.Conditions = conditions
resMatches = append(resMatches, match)
}
return resMatches, resCandidateEvents, nil
}
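// createTaskExpressionEvalOpt describes a single evaluated step expression to persist for a
// task; one of ValueStr or ValueInt is typically set, depending on the expression's type.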
type createTaskExpressionEvalOpt struct {
Key string
ValueStr *string
ValueInt *int
Kind sqlcv1.StepExpressionKind
}
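// createExpressionEvals persists evaluated step expressions for the given created tasks, where
// opts is keyed by the tasks' external ids.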
func (r *sharedRepository) createExpressionEvals(ctx context.Context, dbtx sqlcv1.DBTX, createdTasks []*V1TaskWithPayload, opts map[string][]createTaskExpressionEvalOpt) error {
if len(opts) == 0 {
return nil
}
// map tasks using their external id
taskExternalIds := make(map[string]*V1TaskWithPayload)
for _, task := range createdTasks {
taskExternalIds[sqlchelpers.UUIDToStr(task.ExternalID)] = task
}
taskIds := make([]int64, 0)
taskInsertedAts := make([]pgtype.Timestamptz, 0)
keys := make([]string, 0)
valuesStr := make([]pgtype.Text, 0)
valuesInt := make([]pgtype.Int4, 0)
kinds := make([]string, 0)
for externalId, optList := range opts {
task, ok := taskExternalIds[externalId]
if !ok {
r.l.Warn().Str("external_id", externalId).Msg("could not find task for expression eval")
continue
}
for _, opt := range optList {
taskIds = append(taskIds, task.ID)
taskInsertedAts = append(taskInsertedAts, task.InsertedAt)
keys = append(keys, opt.Key)
if opt.ValueStr != nil {
valuesStr = append(valuesStr, pgtype.Text{
String: *opt.ValueStr,
Valid: true,
})
} else {
valuesStr = append(valuesStr, pgtype.Text{})
}
if opt.ValueInt != nil {
valuesInt = append(valuesInt, pgtype.Int4{
Int32: int32(*opt.ValueInt),
Valid: true,
})
} else {
valuesInt = append(valuesInt, pgtype.Int4{})
}
kinds = append(kinds, string(opt.Kind))
}
}
return r.queries.CreateTaskExpressionEvals(
ctx,
dbtx,
sqlcv1.CreateTaskExpressionEvalsParams{
Taskids: taskIds,
Taskinsertedats: taskInsertedAts,
Keys: keys,
Valuesstr: valuesStr,
Valuesint: valuesInt,
Kinds: kinds,
},
)
}
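// uniqueSet deduplicates task references by (id, retry count), preserving first-seen order.
// note that inserted-at is not part of the dedupe key.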
func uniqueSet(taskIdRetryCounts []TaskIdInsertedAtRetryCount) []TaskIdInsertedAtRetryCount {
unique := make(map[string]struct{})
res := make([]TaskIdInsertedAtRetryCount, 0)
for _, task := range taskIdRetryCounts {
k := fmt.Sprintf("%d:%d", task.Id, task.RetryCount)
if _, ok := unique[k]; !ok {
unique[k] = struct{}{}
res = append(res, task)
}
}
return res
}
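// ListTaskParentOutputs returns, for each task belonging to a DAG, the parsed output events of
// the tasks in the same workflow run, keyed by the task's internal id. output payloads are read
// from the payload store, falling back to the inline output column when missing.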
func (r *TaskRepositoryImpl) ListTaskParentOutputs(ctx context.Context, tenantId string, tasks []*sqlcv1.V1Task) (map[int64][]*TaskOutputEvent, error) {
taskIds := make([]int64, 0)
taskInsertedAts := make([]pgtype.Timestamptz, 0)
for _, task := range tasks {
if task.DagID.Valid {
taskIds = append(taskIds, task.ID)
taskInsertedAts = append(taskInsertedAts, task.InsertedAt)
}
}
resMap := make(map[int64][]*TaskOutputEvent)
if len(taskIds) == 0 {
return resMap, nil
}
res, err := r.queries.ListTaskParentOutputs(ctx, r.pool, sqlcv1.ListTaskParentOutputsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Taskids: taskIds,
Taskinsertedats: taskInsertedAts,
})
if err != nil {
return nil, err
}
retrieveOpts := make([]RetrievePayloadOpts, 0, len(res))
retrieveOptsToWorkflowRunId := make(map[RetrievePayloadOpts]pgtype.UUID, len(res))
retrieveOptToPayload := make(map[RetrievePayloadOpts][]byte)
for _, outputTask := range res {
if !outputTask.WorkflowRunID.Valid {
continue
}
opt := RetrievePayloadOpts{
Id: outputTask.TaskEventID,
InsertedAt: outputTask.TaskEventInsertedAt,
Type: sqlcv1.V1PayloadTypeTASKEVENTDATA,
TenantId: sqlchelpers.UUIDFromStr(tenantId),
}
retrieveOpts = append(retrieveOpts, opt)
retrieveOptsToWorkflowRunId[opt] = outputTask.WorkflowRunID
retrieveOptToPayload[opt] = outputTask.Output
}
payloads, err := r.payloadStore.BulkRetrieve(ctx, retrieveOpts...)
if err != nil {
return nil, fmt.Errorf("failed to retrieve task output payloads: %w", err)
}
workflowRunIdsToOutputs := make(map[string][]*TaskOutputEvent)
for opt, workflowRunId := range retrieveOptsToWorkflowRunId {
wrId := sqlchelpers.UUIDToStr(workflowRunId)
payload, ok := payloads[opt]
if !ok {
r.l.Error().Msgf("ListTaskParentOutputs: workflow run %s with task event ID %d and inserted_at %s has empty payload, falling back to inline output", wrId, opt.Id, opt.InsertedAt.Time)
payload = retrieveOptToPayload[opt]
}
e, err := newTaskEventFromBytes(payload)
if err != nil {
r.l.Warn().Msgf("failed to parse task output: %v", err)
continue
}
workflowRunIdsToOutputs[wrId] = append(workflowRunIdsToOutputs[wrId], e)
}
for _, task := range tasks {
if task.WorkflowRunID.Valid {
wrId := sqlchelpers.UUIDToStr(task.WorkflowRunID)
if events, ok := workflowRunIdsToOutputs[wrId]; ok {
resMap[task.ID] = events
}
}
}
return resMap, nil
}
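// ListSignalCompletedEvents returns the sqlcv1.V1TaskEventTypeSIGNALCOMPLETED task events
// matching the given (task id, inserted at, signal key) tuples, reading payloads from the
// payload store and falling back to the inline event data when missing.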
func (r *TaskRepositoryImpl) ListSignalCompletedEvents(ctx context.Context, tenantId string, tasks []TaskIdInsertedAtSignalKey) ([]*V1TaskEventWithPayload, error) {
taskIds := make([]int64, 0)
taskInsertedAts := make([]pgtype.Timestamptz, 0)
eventKeys := make([]string, 0)
for _, task := range tasks {
taskIds = append(taskIds, task.Id)
taskInsertedAts = append(taskInsertedAts, task.InsertedAt)
eventKeys = append(eventKeys, task.SignalKey)
}
signalEvents, err := r.queries.ListMatchingSignalEvents(ctx, r.pool, sqlcv1.ListMatchingSignalEventsParams{
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
Eventtype: sqlcv1.V1TaskEventTypeSIGNALCOMPLETED,
Taskids: taskIds,
Taskinsertedats: taskInsertedAts,
Eventkeys: eventKeys,
})
if err != nil {
return nil, fmt.Errorf("failed to list matching signal events: %w", err)
}
retrieveOpts := make([]RetrievePayloadOpts, len(signalEvents))
for i, event := range signalEvents {
retrieveOpt := RetrievePayloadOpts{
Id: event.ID,
InsertedAt: event.InsertedAt,
Type: sqlcv1.V1PayloadTypeTASKEVENTDATA,
TenantId: sqlchelpers.UUIDFromStr(tenantId),
}
retrieveOpts[i] = retrieveOpt
}
payloads, err := r.payloadStore.BulkRetrieve(ctx, retrieveOpts...)
if err != nil {
return nil, fmt.Errorf("failed to retrieve task event payloads: %w", err)
}
res := make([]*V1TaskEventWithPayload, len(signalEvents))
for i, event := range signalEvents {
retrieveOpt := RetrievePayloadOpts{
Id: event.ID,
InsertedAt: event.InsertedAt,
Type: sqlcv1.V1PayloadTypeTASKEVENTDATA,
TenantId: sqlchelpers.UUIDFromStr(tenantId),
}
payload, ok := payloads[retrieveOpt]
if !ok {
r.l.Error().Msgf("ListenForDurableEvent: task %s with ID %d and inserted_at %s has empty payload, falling back to input", event.ExternalID, event.ID, event.InsertedAt.Time)
payload = event.Data
}
res[i] = &V1TaskEventWithPayload{
V1TaskEvent: event,
Payload: payload,
}
}
return res, nil
}
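// AnalyzeTaskTables runs ANALYZE on the v1_task, v1_task_event, and v1_payload tables. a
// try-advisory lock ensures that only one instance performs the analysis at a time; if the lock
// is already held, the call is a no-op.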
func (r *TaskRepositoryImpl) AnalyzeTaskTables(ctx context.Context) error {
const timeout = 1000 * 60 * 60 // 60 minutes (in milliseconds)
tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l, timeout)
if err != nil {
return fmt.Errorf("error beginning transaction: %w", err)
}
defer rollback()
acquired, err := r.queries.TryAdvisoryLock(ctx, tx, hash("analyze-task-tables"))
if err != nil {
return fmt.Errorf("error acquiring advisory lock: %w", err)
}
if !acquired {
r.l.Info().Msg("advisory lock already held, skipping task table analysis")
return nil
}
err = r.queries.AnalyzeV1Task(ctx, tx)
if err != nil {
return fmt.Errorf("error analyzing v1_task: %w", err)
}
err = r.queries.AnalyzeV1TaskEvent(ctx, tx)
if err != nil {
return fmt.Errorf("error analyzing v1_task_event: %w", err)
}
err = r.queries.AnalyzeV1Payload(ctx, tx)
if err != nil {
return fmt.Errorf("error analyzing v1_payload: %w", err)
}
if err := commit(ctx); err != nil {
return fmt.Errorf("error committing transaction: %w", err)
}
return nil
}