mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-04-30 14:39:56 -05:00
rejig the query for creating multiple sticky states (#973)
* rejig the query for creating multiple sticky states * fix: sticky strategy of soft and improve query * fix: sort method was using indexes that didn't necessarilly correspond to original indexes, leading to inconsistent behavior --------- Co-authored-by: Sean Reilly <sean@hatchet.run> Co-authored-by: Alexander Belanger <alexander@hatchet.run>
This commit is contained in:
@@ -606,14 +606,26 @@ FROM workflow_version
|
||||
WHERE workflow_version."sticky" IS NOT NULL
|
||||
RETURNING *;
|
||||
|
||||
|
||||
-- name: CreateMultipleWorkflowRunStickyStates :many
|
||||
WITH workflow_version AS (
|
||||
SELECT DISTINCT
|
||||
"id" AS workflow_version_id,
|
||||
"sticky"
|
||||
FROM "WorkflowVersion"
|
||||
WHERE "id" = ANY(@workflowVersionIds::uuid[])
|
||||
-- name: CreateMultipleWorkflowRunStickyStates :exec
|
||||
WITH input_rows AS (
|
||||
SELECT
|
||||
UNNEST(@tenantId::uuid[]) as "tenantId",
|
||||
UNNEST(@workflowRunIds::uuid[]) as "workflowRunId",
|
||||
UNNEST(@desiredWorkerIds::uuid[]) as "desiredWorkerId",
|
||||
UNNEST(@workflowVersionIds::uuid[]) as "workflowVersionId"
|
||||
), valid_rows AS (
|
||||
SELECT
|
||||
ir."tenantId",
|
||||
ir."workflowRunId",
|
||||
ir."desiredWorkerId",
|
||||
ir."workflowVersionId",
|
||||
wv."sticky"
|
||||
FROM
|
||||
input_rows ir
|
||||
JOIN
|
||||
"WorkflowVersion" wv ON wv."id" = ir."workflowVersionId"
|
||||
WHERE
|
||||
wv."sticky" IS NOT NULL
|
||||
)
|
||||
INSERT INTO "WorkflowRunStickyState" (
|
||||
"createdAt",
|
||||
@@ -626,16 +638,11 @@ INSERT INTO "WorkflowRunStickyState" (
|
||||
SELECT
|
||||
CURRENT_TIMESTAMP,
|
||||
CURRENT_TIMESTAMP,
|
||||
UNNEST(@tenantId::uuid[]),
|
||||
UNNEST(@workflowRunIds::uuid[]),
|
||||
UNNEST(@desiredWorkerIds::uuid[]),
|
||||
workflow_version."sticky"
|
||||
FROM workflow_version
|
||||
JOIN UNNEST(@workflowVersionIds::uuid[]) AS wv(workflow_version_id)
|
||||
ON workflow_version.workflow_version_id = wv.workflow_version_id
|
||||
WHERE workflow_version."sticky" IS NOT NULL
|
||||
RETURNING *;
|
||||
|
||||
vr."tenantId",
|
||||
vr."workflowRunId",
|
||||
vr."desiredWorkerId",
|
||||
vr."sticky"
|
||||
FROM valid_rows vr;
|
||||
|
||||
-- name: GetWorkflowRunAdditionalMeta :one
|
||||
SELECT
|
||||
|
||||
@@ -576,13 +576,26 @@ func (q *Queries) CreateManyJobRuns(ctx context.Context, db DBTX, arg CreateMany
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const createMultipleWorkflowRunStickyStates = `-- name: CreateMultipleWorkflowRunStickyStates :many
|
||||
WITH workflow_version AS (
|
||||
SELECT DISTINCT
|
||||
"id" AS workflow_version_id,
|
||||
"sticky"
|
||||
FROM "WorkflowVersion"
|
||||
WHERE "id" = ANY($4::uuid[])
|
||||
const createMultipleWorkflowRunStickyStates = `-- name: CreateMultipleWorkflowRunStickyStates :exec
|
||||
WITH input_rows AS (
|
||||
SELECT
|
||||
UNNEST($1::uuid[]) as "tenantId",
|
||||
UNNEST($2::uuid[]) as "workflowRunId",
|
||||
UNNEST($3::uuid[]) as "desiredWorkerId",
|
||||
UNNEST($4::uuid[]) as "workflowVersionId"
|
||||
), valid_rows AS (
|
||||
SELECT
|
||||
ir."tenantId",
|
||||
ir."workflowRunId",
|
||||
ir."desiredWorkerId",
|
||||
ir."workflowVersionId",
|
||||
wv."sticky"
|
||||
FROM
|
||||
input_rows ir
|
||||
JOIN
|
||||
"WorkflowVersion" wv ON wv."id" = ir."workflowVersionId"
|
||||
WHERE
|
||||
wv."sticky" IS NOT NULL
|
||||
)
|
||||
INSERT INTO "WorkflowRunStickyState" (
|
||||
"createdAt",
|
||||
@@ -595,15 +608,11 @@ INSERT INTO "WorkflowRunStickyState" (
|
||||
SELECT
|
||||
CURRENT_TIMESTAMP,
|
||||
CURRENT_TIMESTAMP,
|
||||
UNNEST($1::uuid[]),
|
||||
UNNEST($2::uuid[]),
|
||||
UNNEST($3::uuid[]),
|
||||
workflow_version."sticky"
|
||||
FROM workflow_version
|
||||
JOIN UNNEST($4::uuid[]) AS wv(workflow_version_id)
|
||||
ON workflow_version.workflow_version_id = wv.workflow_version_id
|
||||
WHERE workflow_version."sticky" IS NOT NULL
|
||||
RETURNING id, "createdAt", "updatedAt", "tenantId", "workflowRunId", "desiredWorkerId", strategy
|
||||
vr."tenantId",
|
||||
vr."workflowRunId",
|
||||
vr."desiredWorkerId",
|
||||
vr."sticky"
|
||||
FROM valid_rows vr
|
||||
`
|
||||
|
||||
type CreateMultipleWorkflowRunStickyStatesParams struct {
|
||||
@@ -613,37 +622,14 @@ type CreateMultipleWorkflowRunStickyStatesParams struct {
|
||||
Workflowversionids []pgtype.UUID `json:"workflowversionids"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateMultipleWorkflowRunStickyStates(ctx context.Context, db DBTX, arg CreateMultipleWorkflowRunStickyStatesParams) ([]*WorkflowRunStickyState, error) {
|
||||
rows, err := db.Query(ctx, createMultipleWorkflowRunStickyStates,
|
||||
func (q *Queries) CreateMultipleWorkflowRunStickyStates(ctx context.Context, db DBTX, arg CreateMultipleWorkflowRunStickyStatesParams) error {
|
||||
_, err := db.Exec(ctx, createMultipleWorkflowRunStickyStates,
|
||||
arg.Tenantid,
|
||||
arg.Workflowrunids,
|
||||
arg.Desiredworkerids,
|
||||
arg.Workflowversionids,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []*WorkflowRunStickyState
|
||||
for rows.Next() {
|
||||
var i WorkflowRunStickyState
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.TenantId,
|
||||
&i.WorkflowRunId,
|
||||
&i.DesiredWorkerId,
|
||||
&i.Strategy,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, &i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
return err
|
||||
}
|
||||
|
||||
const createStepRun = `-- name: CreateStepRun :one
|
||||
|
||||
@@ -1586,22 +1586,15 @@ func createNewWorkflowRuns(ctx context.Context, pool *pgxpool.Pool, queries *dbs
|
||||
desiredWorkerIds := make([]pgtype.UUID, 0)
|
||||
tenantIds := make([]pgtype.UUID, 0)
|
||||
|
||||
workflowVersionIdMap := make(map[string]pgtype.UUID)
|
||||
|
||||
for _, stickyInfo := range stickyInfos {
|
||||
stickyWorkflowRunIds = append(stickyWorkflowRunIds, stickyInfo.workflowRunId)
|
||||
|
||||
// we want distinct workflowVersionIds
|
||||
workflowVersionIdMap[sqlchelpers.UUIDToStr(stickyInfo.workflowVersionId)] = stickyInfo.workflowVersionId
|
||||
workflowVersionIds = append(workflowVersionIds, stickyInfo.workflowVersionId)
|
||||
desiredWorkerIds = append(desiredWorkerIds, stickyInfo.desiredWorkerId)
|
||||
tenantIds = append(tenantIds, stickyInfo.tenantId)
|
||||
}
|
||||
|
||||
for _, value := range workflowVersionIdMap {
|
||||
workflowVersionIds = append(workflowVersionIds, value)
|
||||
}
|
||||
|
||||
_, err = queries.CreateMultipleWorkflowRunStickyStates(tx1Ctx, tx, dbsqlc.CreateMultipleWorkflowRunStickyStatesParams{
|
||||
err = queries.CreateMultipleWorkflowRunStickyStates(tx1Ctx, tx, dbsqlc.CreateMultipleWorkflowRunStickyStatesParams{
|
||||
Tenantid: tenantIds,
|
||||
Workflowrunids: stickyWorkflowRunIds,
|
||||
Workflowversionids: workflowVersionIds,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -154,13 +154,27 @@ func (r *rankedValidSlots) addSlot(slot *slot, rank int) {
|
||||
r.workerSeenCount[workerId]++
|
||||
}
|
||||
|
||||
func (r *rankedValidSlots) less(i, j int) bool {
|
||||
func (r *rankedValidSlots) less(a, b *slot) int {
|
||||
idxA := slices.Index(r.validSlots, a)
|
||||
idxB := slices.Index(r.validSlots, b)
|
||||
|
||||
intA := r.slotRanking[idxA]
|
||||
intB := r.slotRanking[idxB]
|
||||
|
||||
// if we have the same rank, sort by worker seen count
|
||||
if r.slotRanking[i] == r.slotRanking[j] {
|
||||
return r.workerSlotCountRank[i] > r.workerSlotCountRank[j]
|
||||
if intA == intB {
|
||||
intA = r.workerSlotCountRank[idxA]
|
||||
intB = r.workerSlotCountRank[idxB]
|
||||
}
|
||||
|
||||
return r.slotRanking[i] > r.slotRanking[j]
|
||||
switch {
|
||||
case intA == intB:
|
||||
return 0
|
||||
case intA > intB:
|
||||
return -1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rankedValidSlots) order() []*slot {
|
||||
@@ -174,7 +188,7 @@ func (r *rankedValidSlots) order() []*slot {
|
||||
}
|
||||
|
||||
// sort the slots by rank
|
||||
sort.Slice(nonNegativeSlots, r.less)
|
||||
slices.SortStableFunc(nonNegativeSlots, r.less)
|
||||
|
||||
return nonNegativeSlots
|
||||
}
|
||||
@@ -195,8 +209,8 @@ func getRankedSlots(
|
||||
continue
|
||||
}
|
||||
|
||||
// if this is a sticky strategy, it can only be assigned to the desired worker if the desired
|
||||
// worker id is set. otherwise, it can be assigned to any worker.
|
||||
// if this is a HARD sticky strategy, it can only be assigned to the desired worker if the desired
|
||||
// worker id is set. otherwise, it cannot be assigned.
|
||||
if qi.Sticky.Valid && qi.Sticky.StickyStrategy == dbsqlc.StickyStrategyHARD {
|
||||
if qi.DesiredWorkerId.Valid && workerId == sqlchelpers.UUIDToStr(qi.DesiredWorkerId) {
|
||||
validSlots.addSlot(slot, 0)
|
||||
@@ -205,6 +219,18 @@ func getRankedSlots(
|
||||
continue
|
||||
}
|
||||
|
||||
// if this is a SOFT sticky strategy, we should prefer the desired worker, but if it is not
|
||||
// available, we can assign to any worker.
|
||||
if qi.Sticky.Valid && qi.Sticky.StickyStrategy == dbsqlc.StickyStrategySOFT {
|
||||
if qi.DesiredWorkerId.Valid && workerId == sqlchelpers.UUIDToStr(qi.DesiredWorkerId) {
|
||||
validSlots.addSlot(slot, 1)
|
||||
} else {
|
||||
validSlots.addSlot(slot, 0)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// if this step has affinity labels, check if the worker has the desired labels, and rank by
|
||||
// the given affinity
|
||||
if len(labels) > 0 {
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
|
||||
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers"
|
||||
)
|
||||
|
||||
var stableWorkerId1 = uuid.New().String()
|
||||
var stableWorkerId2 = uuid.New().String()
|
||||
|
||||
func TestGetRankedSlots(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
qi *dbsqlc.QueueItem
|
||||
labels []*dbsqlc.GetDesiredLabelsRow
|
||||
slots []*slot
|
||||
expectedWorker []string
|
||||
}{
|
||||
{
|
||||
name: "HARD sticky strategy with desired worker available",
|
||||
qi: &dbsqlc.QueueItem{
|
||||
Sticky: dbsqlc.NullStickyStrategy{Valid: true, StickyStrategy: dbsqlc.StickyStrategyHARD},
|
||||
DesiredWorkerId: sqlchelpers.UUIDFromStr(stableWorkerId1),
|
||||
},
|
||||
slots: []*slot{
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId1)}}, []string{}),
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(uuid.New().String())}}, []string{}),
|
||||
},
|
||||
expectedWorker: []string{stableWorkerId1},
|
||||
},
|
||||
{
|
||||
name: "HARD sticky strategy without desired worker",
|
||||
qi: &dbsqlc.QueueItem{
|
||||
Sticky: dbsqlc.NullStickyStrategy{Valid: true, StickyStrategy: dbsqlc.StickyStrategyHARD},
|
||||
DesiredWorkerId: sqlchelpers.UUIDFromStr(uuid.New().String()),
|
||||
},
|
||||
slots: []*slot{
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(uuid.New().String())}}, []string{}),
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(uuid.New().String())}}, []string{}),
|
||||
},
|
||||
expectedWorker: []string{},
|
||||
},
|
||||
{
|
||||
name: "SOFT sticky strategy with desired worker available",
|
||||
qi: &dbsqlc.QueueItem{
|
||||
Sticky: dbsqlc.NullStickyStrategy{Valid: true, StickyStrategy: dbsqlc.StickyStrategySOFT},
|
||||
DesiredWorkerId: sqlchelpers.UUIDFromStr(stableWorkerId1),
|
||||
},
|
||||
slots: []*slot{
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId2)}}, []string{}),
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId1)}}, []string{}),
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId1)}}, []string{}),
|
||||
},
|
||||
expectedWorker: []string{stableWorkerId1, stableWorkerId1, stableWorkerId2},
|
||||
},
|
||||
{
|
||||
name: "Affinity labels with different worker weights",
|
||||
qi: &dbsqlc.QueueItem{},
|
||||
labels: []*dbsqlc.GetDesiredLabelsRow{
|
||||
{
|
||||
Key: "key1",
|
||||
Weight: 1,
|
||||
Required: false,
|
||||
Comparator: dbsqlc.WorkerLabelComparatorGREATERTHAN,
|
||||
IntValue: pgtype.Int4{Int32: 1, Valid: true},
|
||||
},
|
||||
{
|
||||
Key: "key2",
|
||||
Weight: 1,
|
||||
Required: false,
|
||||
Comparator: dbsqlc.WorkerLabelComparatorGREATERTHAN,
|
||||
IntValue: pgtype.Int4{Int32: 1, Valid: true},
|
||||
},
|
||||
},
|
||||
slots: []*slot{
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId1), Labels: []*dbsqlc.ListManyWorkerLabelsRow{{
|
||||
Key: "key1",
|
||||
IntValue: pgtype.Int4{Int32: 2, Valid: true},
|
||||
}}}}, []string{}),
|
||||
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId2), Labels: []*dbsqlc.ListManyWorkerLabelsRow{{
|
||||
Key: "key1",
|
||||
IntValue: pgtype.Int4{Int32: 4, Valid: true},
|
||||
}, {
|
||||
Key: "key2",
|
||||
IntValue: pgtype.Int4{Int32: 4, Valid: true},
|
||||
}}}}, []string{}),
|
||||
},
|
||||
expectedWorker: []string{stableWorkerId2, stableWorkerId1},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actualSlots := getRankedSlots(tt.qi, tt.labels, tt.slots)
|
||||
actualWorkerIds := make([]string, len(actualSlots))
|
||||
for i, s := range actualSlots {
|
||||
actualWorkerIds[i] = s.getWorkerId()
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectedWorker, actualWorkerIds)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user