rejig the query for creating multiple sticky states (#973)

* rejig the query for creating multiple sticky states

* fix: sticky strategy of soft and improve query

* fix: sort method was using indexes that didn't necessarilly correspond to original indexes, leading to inconsistent behavior

---------

Co-authored-by: Sean Reilly <sean@hatchet.run>
Co-authored-by: Alexander Belanger <alexander@hatchet.run>
This commit is contained in:
Sean Reilly
2024-10-17 06:29:19 -07:00
committed by GitHub
parent 17dc80cad8
commit ecb9ce1e1e
5 changed files with 198 additions and 77 deletions
+25 -18
View File
@@ -606,14 +606,26 @@ FROM workflow_version
WHERE workflow_version."sticky" IS NOT NULL
RETURNING *;
-- name: CreateMultipleWorkflowRunStickyStates :many
WITH workflow_version AS (
SELECT DISTINCT
"id" AS workflow_version_id,
"sticky"
FROM "WorkflowVersion"
WHERE "id" = ANY(@workflowVersionIds::uuid[])
-- name: CreateMultipleWorkflowRunStickyStates :exec
WITH input_rows AS (
SELECT
UNNEST(@tenantId::uuid[]) as "tenantId",
UNNEST(@workflowRunIds::uuid[]) as "workflowRunId",
UNNEST(@desiredWorkerIds::uuid[]) as "desiredWorkerId",
UNNEST(@workflowVersionIds::uuid[]) as "workflowVersionId"
), valid_rows AS (
SELECT
ir."tenantId",
ir."workflowRunId",
ir."desiredWorkerId",
ir."workflowVersionId",
wv."sticky"
FROM
input_rows ir
JOIN
"WorkflowVersion" wv ON wv."id" = ir."workflowVersionId"
WHERE
wv."sticky" IS NOT NULL
)
INSERT INTO "WorkflowRunStickyState" (
"createdAt",
@@ -626,16 +638,11 @@ INSERT INTO "WorkflowRunStickyState" (
SELECT
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
UNNEST(@tenantId::uuid[]),
UNNEST(@workflowRunIds::uuid[]),
UNNEST(@desiredWorkerIds::uuid[]),
workflow_version."sticky"
FROM workflow_version
JOIN UNNEST(@workflowVersionIds::uuid[]) AS wv(workflow_version_id)
ON workflow_version.workflow_version_id = wv.workflow_version_id
WHERE workflow_version."sticky" IS NOT NULL
RETURNING *;
vr."tenantId",
vr."workflowRunId",
vr."desiredWorkerId",
vr."sticky"
FROM valid_rows vr;
-- name: GetWorkflowRunAdditionalMeta :one
SELECT
@@ -576,13 +576,26 @@ func (q *Queries) CreateManyJobRuns(ctx context.Context, db DBTX, arg CreateMany
return items, nil
}
const createMultipleWorkflowRunStickyStates = `-- name: CreateMultipleWorkflowRunStickyStates :many
WITH workflow_version AS (
SELECT DISTINCT
"id" AS workflow_version_id,
"sticky"
FROM "WorkflowVersion"
WHERE "id" = ANY($4::uuid[])
const createMultipleWorkflowRunStickyStates = `-- name: CreateMultipleWorkflowRunStickyStates :exec
WITH input_rows AS (
SELECT
UNNEST($1::uuid[]) as "tenantId",
UNNEST($2::uuid[]) as "workflowRunId",
UNNEST($3::uuid[]) as "desiredWorkerId",
UNNEST($4::uuid[]) as "workflowVersionId"
), valid_rows AS (
SELECT
ir."tenantId",
ir."workflowRunId",
ir."desiredWorkerId",
ir."workflowVersionId",
wv."sticky"
FROM
input_rows ir
JOIN
"WorkflowVersion" wv ON wv."id" = ir."workflowVersionId"
WHERE
wv."sticky" IS NOT NULL
)
INSERT INTO "WorkflowRunStickyState" (
"createdAt",
@@ -595,15 +608,11 @@ INSERT INTO "WorkflowRunStickyState" (
SELECT
CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP,
UNNEST($1::uuid[]),
UNNEST($2::uuid[]),
UNNEST($3::uuid[]),
workflow_version."sticky"
FROM workflow_version
JOIN UNNEST($4::uuid[]) AS wv(workflow_version_id)
ON workflow_version.workflow_version_id = wv.workflow_version_id
WHERE workflow_version."sticky" IS NOT NULL
RETURNING id, "createdAt", "updatedAt", "tenantId", "workflowRunId", "desiredWorkerId", strategy
vr."tenantId",
vr."workflowRunId",
vr."desiredWorkerId",
vr."sticky"
FROM valid_rows vr
`
type CreateMultipleWorkflowRunStickyStatesParams struct {
@@ -613,37 +622,14 @@ type CreateMultipleWorkflowRunStickyStatesParams struct {
Workflowversionids []pgtype.UUID `json:"workflowversionids"`
}
func (q *Queries) CreateMultipleWorkflowRunStickyStates(ctx context.Context, db DBTX, arg CreateMultipleWorkflowRunStickyStatesParams) ([]*WorkflowRunStickyState, error) {
rows, err := db.Query(ctx, createMultipleWorkflowRunStickyStates,
func (q *Queries) CreateMultipleWorkflowRunStickyStates(ctx context.Context, db DBTX, arg CreateMultipleWorkflowRunStickyStatesParams) error {
_, err := db.Exec(ctx, createMultipleWorkflowRunStickyStates,
arg.Tenantid,
arg.Workflowrunids,
arg.Desiredworkerids,
arg.Workflowversionids,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []*WorkflowRunStickyState
for rows.Next() {
var i WorkflowRunStickyState
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.TenantId,
&i.WorkflowRunId,
&i.DesiredWorkerId,
&i.Strategy,
); err != nil {
return nil, err
}
items = append(items, &i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
return err
}
const createStepRun = `-- name: CreateStepRun :one
+2 -9
View File
@@ -1586,22 +1586,15 @@ func createNewWorkflowRuns(ctx context.Context, pool *pgxpool.Pool, queries *dbs
desiredWorkerIds := make([]pgtype.UUID, 0)
tenantIds := make([]pgtype.UUID, 0)
workflowVersionIdMap := make(map[string]pgtype.UUID)
for _, stickyInfo := range stickyInfos {
stickyWorkflowRunIds = append(stickyWorkflowRunIds, stickyInfo.workflowRunId)
// we want distinct workflowVersionIds
workflowVersionIdMap[sqlchelpers.UUIDToStr(stickyInfo.workflowVersionId)] = stickyInfo.workflowVersionId
workflowVersionIds = append(workflowVersionIds, stickyInfo.workflowVersionId)
desiredWorkerIds = append(desiredWorkerIds, stickyInfo.desiredWorkerId)
tenantIds = append(tenantIds, stickyInfo.tenantId)
}
for _, value := range workflowVersionIdMap {
workflowVersionIds = append(workflowVersionIds, value)
}
_, err = queries.CreateMultipleWorkflowRunStickyStates(tx1Ctx, tx, dbsqlc.CreateMultipleWorkflowRunStickyStatesParams{
err = queries.CreateMultipleWorkflowRunStickyStates(tx1Ctx, tx, dbsqlc.CreateMultipleWorkflowRunStickyStatesParams{
Tenantid: tenantIds,
Workflowrunids: stickyWorkflowRunIds,
Workflowversionids: workflowVersionIds,
+34 -8
View File
@@ -1,7 +1,7 @@
package v2
import (
"sort"
"slices"
"sync"
"time"
@@ -154,13 +154,27 @@ func (r *rankedValidSlots) addSlot(slot *slot, rank int) {
r.workerSeenCount[workerId]++
}
func (r *rankedValidSlots) less(i, j int) bool {
func (r *rankedValidSlots) less(a, b *slot) int {
idxA := slices.Index(r.validSlots, a)
idxB := slices.Index(r.validSlots, b)
intA := r.slotRanking[idxA]
intB := r.slotRanking[idxB]
// if we have the same rank, sort by worker seen count
if r.slotRanking[i] == r.slotRanking[j] {
return r.workerSlotCountRank[i] > r.workerSlotCountRank[j]
if intA == intB {
intA = r.workerSlotCountRank[idxA]
intB = r.workerSlotCountRank[idxB]
}
return r.slotRanking[i] > r.slotRanking[j]
switch {
case intA == intB:
return 0
case intA > intB:
return -1
default:
return 1
}
}
func (r *rankedValidSlots) order() []*slot {
@@ -174,7 +188,7 @@ func (r *rankedValidSlots) order() []*slot {
}
// sort the slots by rank
sort.Slice(nonNegativeSlots, r.less)
slices.SortStableFunc(nonNegativeSlots, r.less)
return nonNegativeSlots
}
@@ -195,8 +209,8 @@ func getRankedSlots(
continue
}
// if this is a sticky strategy, it can only be assigned to the desired worker if the desired
// worker id is set. otherwise, it can be assigned to any worker.
// if this is a HARD sticky strategy, it can only be assigned to the desired worker if the desired
// worker id is set. otherwise, it cannot be assigned.
if qi.Sticky.Valid && qi.Sticky.StickyStrategy == dbsqlc.StickyStrategyHARD {
if qi.DesiredWorkerId.Valid && workerId == sqlchelpers.UUIDToStr(qi.DesiredWorkerId) {
validSlots.addSlot(slot, 0)
@@ -205,6 +219,18 @@ func getRankedSlots(
continue
}
// if this is a SOFT sticky strategy, we should prefer the desired worker, but if it is not
// available, we can assign to any worker.
if qi.Sticky.Valid && qi.Sticky.StickyStrategy == dbsqlc.StickyStrategySOFT {
if qi.DesiredWorkerId.Valid && workerId == sqlchelpers.UUIDToStr(qi.DesiredWorkerId) {
validSlots.addSlot(slot, 1)
} else {
validSlots.addSlot(slot, 0)
}
continue
}
// if this step has affinity labels, check if the worker has the desired labels, and rank by
// the given affinity
if len(labels) > 0 {
+109
View File
@@ -0,0 +1,109 @@
package v2
import (
"testing"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/assert"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers"
)
var stableWorkerId1 = uuid.New().String()
var stableWorkerId2 = uuid.New().String()
func TestGetRankedSlots(t *testing.T) {
tests := []struct {
name string
qi *dbsqlc.QueueItem
labels []*dbsqlc.GetDesiredLabelsRow
slots []*slot
expectedWorker []string
}{
{
name: "HARD sticky strategy with desired worker available",
qi: &dbsqlc.QueueItem{
Sticky: dbsqlc.NullStickyStrategy{Valid: true, StickyStrategy: dbsqlc.StickyStrategyHARD},
DesiredWorkerId: sqlchelpers.UUIDFromStr(stableWorkerId1),
},
slots: []*slot{
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId1)}}, []string{}),
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(uuid.New().String())}}, []string{}),
},
expectedWorker: []string{stableWorkerId1},
},
{
name: "HARD sticky strategy without desired worker",
qi: &dbsqlc.QueueItem{
Sticky: dbsqlc.NullStickyStrategy{Valid: true, StickyStrategy: dbsqlc.StickyStrategyHARD},
DesiredWorkerId: sqlchelpers.UUIDFromStr(uuid.New().String()),
},
slots: []*slot{
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(uuid.New().String())}}, []string{}),
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(uuid.New().String())}}, []string{}),
},
expectedWorker: []string{},
},
{
name: "SOFT sticky strategy with desired worker available",
qi: &dbsqlc.QueueItem{
Sticky: dbsqlc.NullStickyStrategy{Valid: true, StickyStrategy: dbsqlc.StickyStrategySOFT},
DesiredWorkerId: sqlchelpers.UUIDFromStr(stableWorkerId1),
},
slots: []*slot{
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId2)}}, []string{}),
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId1)}}, []string{}),
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId1)}}, []string{}),
},
expectedWorker: []string{stableWorkerId1, stableWorkerId1, stableWorkerId2},
},
{
name: "Affinity labels with different worker weights",
qi: &dbsqlc.QueueItem{},
labels: []*dbsqlc.GetDesiredLabelsRow{
{
Key: "key1",
Weight: 1,
Required: false,
Comparator: dbsqlc.WorkerLabelComparatorGREATERTHAN,
IntValue: pgtype.Int4{Int32: 1, Valid: true},
},
{
Key: "key2",
Weight: 1,
Required: false,
Comparator: dbsqlc.WorkerLabelComparatorGREATERTHAN,
IntValue: pgtype.Int4{Int32: 1, Valid: true},
},
},
slots: []*slot{
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId1), Labels: []*dbsqlc.ListManyWorkerLabelsRow{{
Key: "key1",
IntValue: pgtype.Int4{Int32: 2, Valid: true},
}}}}, []string{}),
newSlot(&worker{ListActiveWorkersResult: &ListActiveWorkersResult{ID: sqlchelpers.UUIDFromStr(stableWorkerId2), Labels: []*dbsqlc.ListManyWorkerLabelsRow{{
Key: "key1",
IntValue: pgtype.Int4{Int32: 4, Valid: true},
}, {
Key: "key2",
IntValue: pgtype.Int4{Int32: 4, Valid: true},
}}}}, []string{}),
},
expectedWorker: []string{stableWorkerId2, stableWorkerId1},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actualSlots := getRankedSlots(tt.qi, tt.labels, tt.slots)
actualWorkerIds := make([]string, len(actualSlots))
for i, s := range actualSlots {
actualWorkerIds[i] = s.getWorkerId()
}
assert.Equal(t, tt.expectedWorker, actualWorkerIds)
})
}
}