fix: failure after cancellation (#3243)

* fix: failure after cancellation

* chore: generate

* fix: list multiple concurrency strategies

* fix: lock concerns
This commit is contained in:
Gabe Ruttner
2026-03-11 18:11:15 -07:00
committed by GitHub
parent 18f8fc45f7
commit 7748898c59
10 changed files with 124 additions and 60 deletions
+1 -1
View File
@@ -53,7 +53,7 @@ func (t *TasksService) V1TaskGet(ctx echo.Context, request gen.V1TaskGetRequestO
return nil, err
}
workflowVersion, _, _, _, _, err := t.config.V1.Workflows().GetWorkflowVersionWithTriggers(ctx.Request().Context(), task.TenantID, taskWithData.WorkflowVersionID)
workflowVersion, _, _, _, _, _, err := t.config.V1.Workflows().GetWorkflowVersionWithTriggers(ctx.Request().Context(), task.TenantID, taskWithData.WorkflowVersionID)
if err != nil {
return nil, err
@@ -80,7 +80,7 @@ func (t *V1WorkflowRunsService) getWorkflowRunDetails(
return nil, err
}
workflowVersion, _, _, _, _, err := t.config.V1.Workflows().GetWorkflowVersionWithTriggers(ctx, tenantId, workflowRun.WorkflowVersionId)
workflowVersion, _, _, _, _, _, err := t.config.V1.Workflows().GetWorkflowVersionWithTriggers(ctx, tenantId, workflowRun.WorkflowVersionId)
if err != nil {
return nil, err
+1 -1
View File
@@ -17,7 +17,7 @@ func (t *WorkflowService) WorkflowGet(ctx echo.Context, request gen.WorkflowGetR
return gen.WorkflowGet404JSONResponse(gen.APIErrors{}), nil
}
version, _, _, _, _, err := t.config.V1.Workflows().GetWorkflowVersionWithTriggers(ctx.Request().Context(), tenantId, *workflow.WorkflowVersionId)
version, _, _, _, _, _, err := t.config.V1.Workflows().GetWorkflowVersionWithTriggers(ctx.Request().Context(), tenantId, *workflow.WorkflowVersionId)
if err != nil {
return nil, err
@@ -43,7 +43,7 @@ func (t *WorkflowService) WorkflowVersionGet(ctx echo.Context, request gen.Workf
workflowVersionId = *row.WorkflowVersionId
}
row, crons, events, scheduleT, stepConcurrency, err := t.config.V1.Workflows().GetWorkflowVersionWithTriggers(ctx.Request().Context(), tenantId, workflowVersionId)
row, crons, events, scheduleT, stepConcurrency, workflowConcurrency, err := t.config.V1.Workflows().GetWorkflowVersionWithTriggers(ctx.Request().Context(), tenantId, workflowVersionId)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
@@ -58,11 +58,7 @@ func (t *WorkflowService) WorkflowVersionGet(ctx echo.Context, request gen.Workf
resp := transformers.ToWorkflowVersion(
&row.WorkflowVersion,
&workflow.Workflow,
&transformers.WorkflowConcurrency{
MaxRuns: row.ConcurrencyMaxRuns,
LimitStrategy: row.ConcurrencyLimitStrategy,
Expression: row.ConcurrencyExpression.String,
},
workflowConcurrency,
crons,
events,
scheduleT,
+8 -15
View File
@@ -4,7 +4,6 @@ import (
"encoding/json"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/hatchet-dev/hatchet/api/v1/server/oas/gen"
"github.com/hatchet-dev/hatchet/pkg/repository/sqlcv1"
@@ -53,16 +52,10 @@ func ToWorkflowVersionMeta(version *sqlcv1.WorkflowVersion, workflow *sqlcv1.Wor
return res
}
type WorkflowConcurrency struct {
MaxRuns pgtype.Int4
LimitStrategy sqlcv1.NullV1ConcurrencyStrategy
Expression string
}
func ToWorkflowVersion(
version *sqlcv1.WorkflowVersion,
workflow *sqlcv1.Workflow,
concurrency *WorkflowConcurrency,
workflowConcurrency []*sqlcv1.ListWorkflowConcurrencyByVersionIdRow,
crons []*sqlcv1.WorkflowTriggerCronRef,
events []*sqlcv1.WorkflowTriggerEventRef,
schedules []*sqlcv1.WorkflowTriggerScheduledRef,
@@ -154,13 +147,13 @@ func ToWorkflowVersion(
}
res.Triggers = &triggersResp
res.V1Concurrency = ToV1Concurrency(concurrency, stepConcurrency)
res.V1Concurrency = ToV1Concurrency(workflowConcurrency, stepConcurrency)
return res
}
func ToV1Concurrency(workflowConcurrency *WorkflowConcurrency, taskConcurrencies []*sqlcv1.ListConcurrencyStrategiesByWorkflowVersionIdRow) *[]gen.ConcurrencySetting {
res := make([]gen.ConcurrencySetting, 0, len(taskConcurrencies)+1)
func ToV1Concurrency(workflowConcurrencies []*sqlcv1.ListWorkflowConcurrencyByVersionIdRow, taskConcurrencies []*sqlcv1.ListConcurrencyStrategiesByWorkflowVersionIdRow) *[]gen.ConcurrencySetting {
res := make([]gen.ConcurrencySetting, 0, len(taskConcurrencies)+len(workflowConcurrencies))
for _, c := range taskConcurrencies {
res = append(res, gen.ConcurrencySetting{
@@ -172,11 +165,11 @@ func ToV1Concurrency(workflowConcurrency *WorkflowConcurrency, taskConcurrencies
})
}
if workflowConcurrency != nil && workflowConcurrency.LimitStrategy.Valid {
for _, wc := range workflowConcurrencies {
res = append(res, gen.ConcurrencySetting{
Expression: workflowConcurrency.Expression,
LimitStrategy: gen.ConcurrencyLimitStrategy(workflowConcurrency.LimitStrategy.V1ConcurrencyStrategy),
MaxRuns: workflowConcurrency.MaxRuns.Int32,
Expression: wc.Expression,
LimitStrategy: gen.ConcurrencyLimitStrategy(wc.LimitStrategy),
MaxRuns: wc.MaxRuns,
Scope: gen.ConcurrencyScopeWORKFLOW,
})
}
+13 -4
View File
@@ -239,6 +239,12 @@ WITH input AS (
input i ON i.task_id = t.id AND i.task_inserted_at = t.inserted_at AND i.task_retry_count = t.retry_count
WHERE
t.tenant_id = @tenantId::uuid
-- only fail tasks which still have a v1_task_runtime for the current retry count.
-- a cancellation deletes the v1_task_runtime, so a late failure event should not trigger a retry.
AND EXISTS (
SELECT 1 FROM v1_task_runtime tr
WHERE tr.task_id = t.id AND tr.task_inserted_at = t.inserted_at AND tr.retry_count = t.retry_count
)
-- order by the task id to get a stable lock order
ORDER BY
id
@@ -276,7 +282,7 @@ RETURNING
v1_task.retry_max_backoff;
-- name: FailTaskInternalFailure :many
-- Fails a task due to an application-level error
-- Fails a task due to an internal error
WITH input AS (
SELECT
*
@@ -292,13 +298,16 @@ WITH input AS (
t.id
FROM
v1_task t
-- only fail tasks which have a v1_task_runtime equivalent to the current retry count. otherwise,
-- a cancellation which deletes the v1_task_runtime might lead to a future failure event, which triggers
-- a retry.
JOIN
input i ON i.task_id = t.id AND i.task_inserted_at = t.inserted_at AND i.task_retry_count = t.retry_count
WHERE
t.tenant_id = @tenantId::uuid
-- only fail tasks which still have a v1_task_runtime for the current retry count.
-- a cancellation deletes the v1_task_runtime, so a late failure event should not trigger a retry.
AND EXISTS (
SELECT 1 FROM v1_task_runtime tr
WHERE tr.task_id = t.id AND tr.task_inserted_at = t.inserted_at AND tr.retry_count = t.retry_count
)
-- order by the task id to get a stable lock order
ORDER BY
id
+13 -4
View File
@@ -374,6 +374,12 @@ WITH input AS (
input i ON i.task_id = t.id AND i.task_inserted_at = t.inserted_at AND i.task_retry_count = t.retry_count
WHERE
t.tenant_id = $5::uuid
-- only fail tasks which still have a v1_task_runtime for the current retry count.
-- a cancellation deletes the v1_task_runtime, so a late failure event should not trigger a retry.
AND EXISTS (
SELECT 1 FROM v1_task_runtime tr
WHERE tr.task_id = t.id AND tr.task_inserted_at = t.inserted_at AND tr.retry_count = t.retry_count
)
-- order by the task id to get a stable lock order
ORDER BY
id
@@ -478,13 +484,16 @@ WITH input AS (
t.id
FROM
v1_task t
-- only fail tasks which have a v1_task_runtime equivalent to the current retry count. otherwise,
-- a cancellation which deletes the v1_task_runtime might lead to a future failure event, which triggers
-- a retry.
JOIN
input i ON i.task_id = t.id AND i.task_inserted_at = t.inserted_at AND i.task_retry_count = t.retry_count
WHERE
t.tenant_id = $5::uuid
-- only fail tasks which still have a v1_task_runtime for the current retry count.
-- a cancellation deletes the v1_task_runtime, so a late failure event should not trigger a retry.
AND EXISTS (
SELECT 1 FROM v1_task_runtime tr
WHERE tr.task_id = t.id AND tr.task_inserted_at = t.inserted_at AND tr.retry_count = t.retry_count
)
-- order by the task id to get a stable lock order
ORDER BY
id
@@ -523,7 +532,7 @@ type FailTaskInternalFailureRow struct {
RetryCount int32 `json:"retry_count"`
}
// Fails a task due to an application-level error
// Fails a task due to an internal error
func (q *Queries) FailTaskInternalFailure(ctx context.Context, db DBTX, arg FailTaskInternalFailureParams) ([]*FailTaskInternalFailureRow, error) {
rows, err := db.Query(ctx, failTaskInternalFailure,
arg.Maxinternalretries,
+14 -6
View File
@@ -696,20 +696,28 @@ LIMIT 1;
-- name: GetWorkflowVersionById :one
SELECT
sqlc.embed(wv),
sqlc.embed(w),
wc.id as "concurrencyId",
wc.max_concurrency as "concurrencyMaxRuns",
wc.strategy as "concurrencyLimitStrategy",
wc.expression as "concurrencyExpression"
sqlc.embed(w)
FROM
"WorkflowVersion" as wv
JOIN "Workflow" as w on w."id" = wv."workflowId"
LEFT JOIN v1_workflow_concurrency as wc ON (wc.workflow_version_id, wc.workflow_id) = (wv."id", w."id")
WHERE
wv."id" = @id::uuid AND
wv."deletedAt" IS NULL
LIMIT 1;
-- name: ListWorkflowConcurrencyByVersionId :many
SELECT
wc.id,
wc.max_concurrency AS "maxRuns",
wc.strategy AS "limitStrategy",
wc.expression
FROM
v1_workflow_concurrency wc
WHERE
wc.workflow_version_id = @workflowVersionId::uuid AND
wc.workflow_id = @workflowId::uuid
ORDER BY wc.id ASC;
-- name: ListWorkflows :many
SELECT
sqlc.embed(workflows)
+54 -16
View File
@@ -1220,15 +1220,10 @@ func (q *Queries) GetWorkflowShape(ctx context.Context, db DBTX, workflowversion
const getWorkflowVersionById = `-- name: GetWorkflowVersionById :one
SELECT
wv.id, wv."createdAt", wv."updatedAt", wv."deletedAt", wv.version, wv."order", wv."workflowId", wv.checksum, wv."scheduleTimeout", wv."onFailureJobId", wv.sticky, wv.kind, wv."defaultPriority", wv."createWorkflowVersionOpts", wv."inputJsonSchema",
w.id, w."createdAt", w."updatedAt", w."deletedAt", w."tenantId", w.name, w.description, w."isPaused",
wc.id as "concurrencyId",
wc.max_concurrency as "concurrencyMaxRuns",
wc.strategy as "concurrencyLimitStrategy",
wc.expression as "concurrencyExpression"
w.id, w."createdAt", w."updatedAt", w."deletedAt", w."tenantId", w.name, w.description, w."isPaused"
FROM
"WorkflowVersion" as wv
JOIN "Workflow" as w on w."id" = wv."workflowId"
LEFT JOIN v1_workflow_concurrency as wc ON (wc.workflow_version_id, wc.workflow_id) = (wv."id", w."id")
WHERE
wv."id" = $1::uuid AND
wv."deletedAt" IS NULL
@@ -1236,12 +1231,8 @@ LIMIT 1
`
type GetWorkflowVersionByIdRow struct {
WorkflowVersion WorkflowVersion `json:"workflow_version"`
Workflow Workflow `json:"workflow"`
ConcurrencyId pgtype.Int8 `json:"concurrencyId"`
ConcurrencyMaxRuns pgtype.Int4 `json:"concurrencyMaxRuns"`
ConcurrencyLimitStrategy NullV1ConcurrencyStrategy `json:"concurrencyLimitStrategy"`
ConcurrencyExpression pgtype.Text `json:"concurrencyExpression"`
WorkflowVersion WorkflowVersion `json:"workflow_version"`
Workflow Workflow `json:"workflow"`
}
func (q *Queries) GetWorkflowVersionById(ctx context.Context, db DBTX, id uuid.UUID) (*GetWorkflowVersionByIdRow, error) {
@@ -1271,10 +1262,6 @@ func (q *Queries) GetWorkflowVersionById(ctx context.Context, db DBTX, id uuid.U
&i.Workflow.Name,
&i.Workflow.Description,
&i.Workflow.IsPaused,
&i.ConcurrencyId,
&i.ConcurrencyMaxRuns,
&i.ConcurrencyLimitStrategy,
&i.ConcurrencyExpression,
)
return &i, err
}
@@ -1812,6 +1799,57 @@ func (q *Queries) ListStepsByWorkflowVersionIds(ctx context.Context, db DBTX, ar
return items, nil
}
const listWorkflowConcurrencyByVersionId = `-- name: ListWorkflowConcurrencyByVersionId :many
SELECT
wc.id,
wc.max_concurrency AS "maxRuns",
wc.strategy AS "limitStrategy",
wc.expression
FROM
v1_workflow_concurrency wc
WHERE
wc.workflow_version_id = $1::uuid AND
wc.workflow_id = $2::uuid
ORDER BY wc.id ASC
`
// ListWorkflowConcurrencyByVersionIdParams carries the two positional arguments
// of the listWorkflowConcurrencyByVersionId query.
type ListWorkflowConcurrencyByVersionIdParams struct {
// Workflowversionid is bound to $1 and compared with wc.workflow_version_id.
Workflowversionid uuid.UUID `json:"workflowversionid"`
// Workflowid is bound to $2 and compared with wc.workflow_id.
Workflowid uuid.UUID `json:"workflowid"`
}
// ListWorkflowConcurrencyByVersionIdRow is one row returned by the
// listWorkflowConcurrencyByVersionId query against v1_workflow_concurrency.
// Fields are listed in the same order as the query's SELECT columns.
type ListWorkflowConcurrencyByVersionIdRow struct {
// ID is the wc.id column; the query orders its results by this value ascending.
ID int64 `json:"id"`
// MaxRuns is the wc.max_concurrency column (aliased "maxRuns").
MaxRuns int32 `json:"maxRuns"`
// LimitStrategy is the wc.strategy column (aliased "limitStrategy").
LimitStrategy V1ConcurrencyStrategy `json:"limitStrategy"`
// Expression is the wc.expression column.
Expression string `json:"expression"`
}
// ListWorkflowConcurrencyByVersionId runs the listWorkflowConcurrencyByVersionId
// query on db with arg.Workflowversionid as $1 and arg.Workflowid as $2, and
// returns one *ListWorkflowConcurrencyByVersionIdRow per result row, in the
// order produced by the query (ORDER BY wc.id ASC).
//
// When the query matches no rows the returned slice is nil and the error is nil.
// If the query, a row scan, or the row iteration fails, the error is returned
// unwrapped together with a nil slice.
func (q *Queries) ListWorkflowConcurrencyByVersionId(ctx context.Context, db DBTX, arg ListWorkflowConcurrencyByVersionIdParams) ([]*ListWorkflowConcurrencyByVersionIdRow, error) {
rows, err := db.Query(ctx, listWorkflowConcurrencyByVersionId, arg.Workflowversionid, arg.Workflowid)
if err != nil {
return nil, err
}
// Close the result set on every return path, including the early error returns below.
defer rows.Close()
// items stays nil until the first row is appended.
var items []*ListWorkflowConcurrencyByVersionIdRow
for rows.Next() {
// i is declared per iteration, so each appended pointer refers to a distinct row value.
var i ListWorkflowConcurrencyByVersionIdRow
// Scan destinations follow the SELECT column order: id, maxRuns, limitStrategy, expression.
if err := rows.Scan(
&i.ID,
&i.MaxRuns,
&i.LimitStrategy,
&i.Expression,
); err != nil {
return nil, err
}
items = append(items, &i)
}
// rows.Next returning false can also mean iteration stopped on an error; check it before returning.
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listWorkflowNamesByIds = `-- name: ListWorkflowNamesByIds :many
SELECT id, name
FROM "Workflow"
+17 -6
View File
@@ -217,6 +217,7 @@ type WorkflowRepository interface {
[]*sqlcv1.WorkflowTriggerEventRef,
[]*sqlcv1.WorkflowTriggerScheduledRef,
[]*sqlcv1.ListConcurrencyStrategiesByWorkflowVersionIdRow,
[]*sqlcv1.ListWorkflowConcurrencyByVersionIdRow,
error)
GetWorkflowVersionById(ctx context.Context, tenantId uuid.UUID, workflowId uuid.UUID) (*sqlcv1.GetWorkflowVersionForEngineRow, error)
@@ -1140,6 +1141,7 @@ func (r *workflowRepository) GetWorkflowVersionWithTriggers(ctx context.Context,
[]*sqlcv1.WorkflowTriggerEventRef,
[]*sqlcv1.WorkflowTriggerScheduledRef,
[]*sqlcv1.ListConcurrencyStrategiesByWorkflowVersionIdRow,
[]*sqlcv1.ListWorkflowConcurrencyByVersionIdRow,
error,
) {
row, err := r.queries.GetWorkflowVersionById(
@@ -1149,7 +1151,7 @@ func (r *workflowRepository) GetWorkflowVersionWithTriggers(ctx context.Context,
)
if err != nil {
return nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch workflow version: %w", err)
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch workflow version: %w", err)
}
crons, err := r.queries.GetWorkflowVersionCronTriggerRefs(
@@ -1159,7 +1161,7 @@ func (r *workflowRepository) GetWorkflowVersionWithTriggers(ctx context.Context,
)
if err != nil {
return nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch cron triggers: %w", err)
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch cron triggers: %w", err)
}
events, err := r.queries.GetWorkflowVersionEventTriggerRefs(
@@ -1169,7 +1171,7 @@ func (r *workflowRepository) GetWorkflowVersionWithTriggers(ctx context.Context,
)
if err != nil {
return nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch event triggers: %w", err)
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch event triggers: %w", err)
}
scheduled, err := r.queries.GetWorkflowVersionScheduleTriggerRefs(
@@ -1179,7 +1181,7 @@ func (r *workflowRepository) GetWorkflowVersionWithTriggers(ctx context.Context,
)
if err != nil {
return nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch scheduled triggers: %w", err)
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch scheduled triggers: %w", err)
}
stepConcurrency, err := r.queries.ListConcurrencyStrategiesByWorkflowVersionId(ctx, r.pool, sqlcv1.ListConcurrencyStrategiesByWorkflowVersionIdParams{
@@ -1189,10 +1191,19 @@ func (r *workflowRepository) GetWorkflowVersionWithTriggers(ctx context.Context,
})
if err != nil {
return nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch workflow concurrency strategies: %w", err)
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch step concurrency strategies: %w", err)
}
return row, crons, events, scheduled, stepConcurrency, nil
workflowConcurrency, err := r.queries.ListWorkflowConcurrencyByVersionId(ctx, r.pool, sqlcv1.ListWorkflowConcurrencyByVersionIdParams{
Workflowversionid: row.WorkflowVersion.ID,
Workflowid: row.Workflow.ID,
})
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to fetch workflow concurrency strategies: %w", err)
}
return row, crons, events, scheduled, stepConcurrency, workflowConcurrency, nil
}
func (r *workflowRepository) GetWorkflowVersionById(ctx context.Context, tenantId, workflowId uuid.UUID) (*sqlcv1.GetWorkflowVersionForEngineRow, error) {