mirror of
https://github.com/hatchet-dev/hatchet.git
synced 2026-04-24 11:18:35 -05:00
feat: go worker assignment (#741)
* feat: create worker with label * feat: worker context * feat: dynamic labels * feat: affinity * fix: ptr * fix: nil labels * feat: sticky dag * feat: sticky docs * feat: sticky children * chore: lint * fix: tests * fix: possibly nil workerId * chore: cleanup unneeded pointers
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
"github.com/hatchet-dev/hatchet/pkg/cmdutils"
|
||||
)
|
||||
|
||||
type userCreateEvent struct {
|
||||
Username string `json:"username"`
|
||||
UserID string `json:"user_id"`
|
||||
Data map[string]string `json:"data"`
|
||||
}
|
||||
|
||||
type stepOneOutput struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ch := cmdutils.InterruptChan()
|
||||
cleanup, err := run()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
<-ch
|
||||
|
||||
if err := cleanup(); err != nil {
|
||||
panic(fmt.Errorf("cleanup() error = %v", err))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/hatchet-dev/hatchet/pkg/client"
|
||||
"github.com/hatchet-dev/hatchet/pkg/client/types"
|
||||
"github.com/hatchet-dev/hatchet/pkg/worker"
|
||||
)
|
||||
|
||||
func run() (func() error, error) {
|
||||
c, err := client.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client: %w", err)
|
||||
}
|
||||
|
||||
w, err := worker.NewWorker(
|
||||
worker.WithClient(
|
||||
c,
|
||||
),
|
||||
worker.WithLabels(map[string]interface{}{
|
||||
"model": "fancy-ai-model-v2",
|
||||
"memory": 512,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating worker: %w", err)
|
||||
}
|
||||
|
||||
err = w.RegisterWorkflow(
|
||||
&worker.WorkflowJob{
|
||||
On: worker.Events("user:create:affinity"),
|
||||
Name: "affinity",
|
||||
Description: "affinity",
|
||||
Steps: []*worker.WorkflowStep{
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
|
||||
model := ctx.Worker().GetLabels()["model"]
|
||||
|
||||
if model != "fancy-ai-model-v3" {
|
||||
ctx.Worker().UpsertLabels(map[string]interface{}{
|
||||
"model": nil,
|
||||
})
|
||||
// Do something to load the model
|
||||
ctx.Worker().UpsertLabels(map[string]interface{}{
|
||||
"model": "fancy-ai-model-v3",
|
||||
})
|
||||
}
|
||||
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).
|
||||
SetName("step-one").
|
||||
SetDesiredLabels(map[string]*types.DesiredWorkerLabel{
|
||||
"model": {
|
||||
Value: "fancy-ai-model-v3",
|
||||
Weight: 10,
|
||||
},
|
||||
"memory": {
|
||||
Value: 512,
|
||||
Required: true,
|
||||
Comparator: types.ComparatorPtr(types.WorkerLabelComparator_GREATER_THAN),
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error registering workflow: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("pushing event")
|
||||
|
||||
testEvent := userCreateEvent{
|
||||
Username: "echo-test",
|
||||
UserID: "1234",
|
||||
Data: map[string]string{
|
||||
"test": "test",
|
||||
},
|
||||
}
|
||||
|
||||
// push an event
|
||||
err := c.Event().Push(
|
||||
context.Background(),
|
||||
"user:create:affinity",
|
||||
testEvent,
|
||||
)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("error pushing event: %w", err))
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
}()
|
||||
|
||||
cleanup, err := w.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error starting worker: %w", err)
|
||||
}
|
||||
|
||||
return cleanup, nil
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
"github.com/hatchet-dev/hatchet/pkg/cmdutils"
|
||||
)
|
||||
|
||||
type userCreateEvent struct {
|
||||
Username string `json:"username"`
|
||||
UserID string `json:"user_id"`
|
||||
Data map[string]string `json:"data"`
|
||||
}
|
||||
|
||||
type stepOneOutput struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ch := cmdutils.InterruptChan()
|
||||
cleanup, err := run()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
<-ch
|
||||
|
||||
if err := cleanup(); err != nil {
|
||||
panic(fmt.Errorf("cleanup() error = %v", err))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/hatchet-dev/hatchet/pkg/client"
|
||||
"github.com/hatchet-dev/hatchet/pkg/client/types"
|
||||
"github.com/hatchet-dev/hatchet/pkg/worker"
|
||||
)
|
||||
|
||||
func run() (func() error, error) {
|
||||
c, err := client.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client: %w", err)
|
||||
}
|
||||
|
||||
w, err := worker.NewWorker(
|
||||
worker.WithClient(
|
||||
c,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating worker: %w", err)
|
||||
}
|
||||
|
||||
err = w.RegisterWorkflow(
|
||||
&worker.WorkflowJob{
|
||||
On: worker.Events("user:create:sticky"),
|
||||
Name: "sticky",
|
||||
Description: "sticky",
|
||||
StickyStrategy: types.StickyStrategyPtr(types.StickyStrategy_SOFT),
|
||||
Steps: []*worker.WorkflowStep{
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
|
||||
sticky := true
|
||||
|
||||
_, err = ctx.SpawnWorkflow("step-one", nil, &worker.SpawnWorkflowOpts{
|
||||
Sticky: &sticky,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error spawning workflow: %w", err)
|
||||
}
|
||||
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).SetName("step-one"),
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).SetName("step-two"),
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).SetName("step-three").AddParents("step-one", "step-two"),
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error registering workflow: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("pushing event")
|
||||
|
||||
testEvent := userCreateEvent{
|
||||
Username: "echo-test",
|
||||
UserID: "1234",
|
||||
Data: map[string]string{
|
||||
"test": "test",
|
||||
},
|
||||
}
|
||||
|
||||
// push an event
|
||||
err := c.Event().Push(
|
||||
context.Background(),
|
||||
"user:create:sticky",
|
||||
testEvent,
|
||||
)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("error pushing event: %w", err))
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
}()
|
||||
|
||||
cleanup, err := w.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error starting worker: %w", err)
|
||||
}
|
||||
|
||||
return cleanup, nil
|
||||
}
|
||||
@@ -70,7 +70,31 @@ const myWorkflow: Workflow = {
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```go
|
||||
// TODO add go example
|
||||
err = w.RegisterWorkflow(
|
||||
&worker.WorkflowJob{
|
||||
On: worker.Events("user:create:sticky"),
|
||||
Name: "sticky",
|
||||
Description: "sticky",
|
||||
StickyStrategy: types.StickyStrategyPtr(types.StickyStrategy_SOFT),
|
||||
Steps: []*worker.WorkflowStep{
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).SetName("step-one"),
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).SetName("step-two"),
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).SetName("step-three").AddParents("step-one", "step-two"),
|
||||
},
|
||||
},
|
||||
)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
@@ -143,7 +167,33 @@ const parentWorkflow: Workflow = {
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```go
|
||||
// TODO go example
|
||||
err = w.RegisterWorkflow(
|
||||
&worker.WorkflowJob{
|
||||
On: worker.Events("user:create:sticky"),
|
||||
Name: "sticky",
|
||||
Description: "sticky",
|
||||
Steps: []*worker.WorkflowStep{
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
|
||||
sticky := true
|
||||
|
||||
_, err = ctx.SpawnWorkflow("sticky-child", nil, &worker.SpawnWorkflowOpts{
|
||||
Sticky: &sticky,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error spawning workflow: %w", err)
|
||||
}
|
||||
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).SetName("step-one"),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
```
|
||||
|
||||
@@ -38,7 +38,15 @@ worker = hatchet.worker(
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```go
|
||||
// TODO add go example
|
||||
w, err := worker.NewWorker(
|
||||
worker.WithClient(
|
||||
c,
|
||||
),
|
||||
worker.WithLabels(map[string]interface{}{
|
||||
"model": "fancy-ai-model-v2",
|
||||
"memory": 512,
|
||||
}),
|
||||
)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
@@ -107,7 +115,32 @@ const affinity: Workflow = {
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```go
|
||||
// TODO go example
|
||||
err = w.RegisterWorkflow(
|
||||
&worker.WorkflowJob{
|
||||
On: worker.Events("user:create:affinity"),
|
||||
Name: "affinity",
|
||||
Description: "affinity",
|
||||
Steps: []*worker.WorkflowStep{
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).
|
||||
SetName("step-one").
|
||||
SetDesiredLabels(map[string]*types.DesiredWorkerLabel{
|
||||
"model": {
|
||||
Value: "fancy-ai-model-v2",
|
||||
Weight: 10,
|
||||
},
|
||||
"memory": {
|
||||
Value: 512,
|
||||
Required: true,
|
||||
Comparator: types.ComparatorPtr(types.WorkerLabelComparator_GREATER_THAN),
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
@@ -129,7 +162,10 @@ Labels can also be set dynamically on workers using the `upsertLabels` method. T
|
||||
```python
|
||||
@hatchet.step(
|
||||
desired_worker_labels={
|
||||
"model": {"value": "fancy-vision-model", "weight": 10},
|
||||
"model": {
|
||||
"value": "fancy-vision-model",
|
||||
"weight": 10
|
||||
},
|
||||
"memory": {
|
||||
"value": 256,
|
||||
"required": True,
|
||||
@@ -181,7 +217,49 @@ Labels can also be set dynamically on workers using the `upsertLabels` method. T
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```go
|
||||
// TODO add go example
|
||||
err = w.RegisterWorkflow(
|
||||
&worker.WorkflowJob{
|
||||
On: worker.Events("user:create:affinity"),
|
||||
Name: "affinity",
|
||||
Description: "affinity",
|
||||
Steps: []*worker.WorkflowStep{
|
||||
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
|
||||
|
||||
model := ctx.Worker().GetLabels()["model"]
|
||||
|
||||
if model != "fancy-vision-model" {
|
||||
ctx.Worker().UpsertLabels(map[string]interface{}{
|
||||
"model": nil,
|
||||
})
|
||||
// Do something to load the model
|
||||
evictModel();
|
||||
loadNewModel("fancy-vision-model");
|
||||
ctx.Worker().UpsertLabels(map[string]interface{}{
|
||||
"model": "fancy-vision-model",
|
||||
})
|
||||
}
|
||||
|
||||
return &stepOneOutput{
|
||||
Message: ctx.Worker().ID(),
|
||||
}, nil
|
||||
}).
|
||||
SetName("step-one").
|
||||
SetDesiredLabels(map[string]*types.DesiredWorkerLabel{
|
||||
"model": {
|
||||
Value: "fancy-vision-model",
|
||||
Weight: 10,
|
||||
},
|
||||
"memory": {
|
||||
Value: 512,
|
||||
Required: true,
|
||||
Comparator: types.WorkerLabelComparator_GREATER_THAN,
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
```
|
||||
|
||||
@@ -478,10 +478,10 @@ func getCreateJobOpts(req *contracts.CreateWorkflowJobOpts, kind string) (*repos
|
||||
|
||||
retries := int(stepCp.Retries)
|
||||
|
||||
var affinity *map[string]repository.DesiredWorkerLabelOpts
|
||||
var affinity map[string]repository.DesiredWorkerLabelOpts
|
||||
|
||||
if stepCp.WorkerLabels != nil {
|
||||
affinity = &map[string]repository.DesiredWorkerLabelOpts{}
|
||||
affinity = map[string]repository.DesiredWorkerLabelOpts{}
|
||||
for k, v := range stepCp.WorkerLabels {
|
||||
|
||||
var c *string
|
||||
@@ -491,7 +491,7 @@ func getCreateJobOpts(req *contracts.CreateWorkflowJobOpts, kind string) (*repos
|
||||
c = &cPtr
|
||||
}
|
||||
|
||||
(*affinity)[k] = repository.DesiredWorkerLabelOpts{
|
||||
(affinity)[k] = repository.DesiredWorkerLabelOpts{
|
||||
Key: k,
|
||||
StrValue: v.StrValue,
|
||||
IntValue: v.IntValue,
|
||||
|
||||
@@ -765,7 +765,7 @@ func (ec *JobsControllerImpl) runStepRunReassignTenant(ctx context.Context, tena
|
||||
EventReason: repository.StepRunEventReasonPtr(dbsqlc.StepRunEventReasonREASSIGNED),
|
||||
EventSeverity: repository.StepRunEventSeverityPtr(dbsqlc.StepRunEventSeverityCRITICAL),
|
||||
EventMessage: repository.StringPtr("Worker has become inactive"),
|
||||
EventData: &eventData,
|
||||
EventData: eventData,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -21,6 +21,7 @@ type ChildWorkflowOpts struct {
|
||||
ParentStepRunId string
|
||||
ChildIndex int
|
||||
ChildKey *string
|
||||
DesiredWorkerId *string
|
||||
}
|
||||
|
||||
type AdminClient interface {
|
||||
@@ -216,6 +217,7 @@ func (a *adminClientImpl) RunChildWorkflow(workflowName string, input interface{
|
||||
ParentStepRunId: &opts.ParentStepRunId,
|
||||
ChildIndex: &childIndex,
|
||||
ChildKey: opts.ChildKey,
|
||||
DesiredWorkerId: opts.DesiredWorkerId,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -264,6 +266,11 @@ func (a *adminClientImpl) getPutRequest(workflow *types.Workflow) (*admincontrac
|
||||
CronTriggers: workflow.Triggers.Cron,
|
||||
}
|
||||
|
||||
if workflow.StickyStrategy != nil {
|
||||
s := admincontracts.StickyStrategy(*workflow.StickyStrategy)
|
||||
opts.Sticky = &s
|
||||
}
|
||||
|
||||
if workflow.Concurrency != nil {
|
||||
opts.Concurrency = &admincontracts.WorkflowConcurrencyOpts{
|
||||
Action: workflow.Concurrency.ActionID,
|
||||
@@ -356,6 +363,40 @@ func (a *adminClientImpl) getJobOpts(jobName string, job *types.WorkflowJob) (*a
|
||||
})
|
||||
}
|
||||
|
||||
if step.DesiredLabels != nil {
|
||||
stepOpt.WorkerLabels = make(map[string]*admincontracts.DesiredWorkerLabels, len(step.DesiredLabels))
|
||||
for key, desiredLabel := range step.DesiredLabels {
|
||||
stepOpt.WorkerLabels[key] = &admincontracts.DesiredWorkerLabels{
|
||||
Required: &desiredLabel.Required,
|
||||
Weight: &desiredLabel.Weight,
|
||||
}
|
||||
|
||||
switch value := desiredLabel.Value.(type) {
|
||||
case string:
|
||||
strValue := value
|
||||
stepOpt.WorkerLabels[key].StrValue = &strValue
|
||||
case int:
|
||||
intValue := int32(value)
|
||||
stepOpt.WorkerLabels[key].IntValue = &intValue
|
||||
case int32:
|
||||
stepOpt.WorkerLabels[key].IntValue = &value
|
||||
case int64:
|
||||
intValue := int32(value)
|
||||
stepOpt.WorkerLabels[key].IntValue = &intValue
|
||||
default:
|
||||
// For any other type, convert to string
|
||||
strValue := fmt.Sprintf("%v", value)
|
||||
stepOpt.WorkerLabels[key].StrValue = &strValue
|
||||
}
|
||||
|
||||
if desiredLabel.Comparator != nil {
|
||||
c := admincontracts.WorkerLabelComparator(*desiredLabel.Comparator)
|
||||
stepOpt.WorkerLabels[key].Comparator = &c
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
stepOpts[i] = stepOpt
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
type DispatcherClient interface {
|
||||
GetActionListener(ctx context.Context, req *GetActionListenerRequest) (WorkerActionListener, error)
|
||||
GetActionListener(ctx context.Context, req *GetActionListenerRequest) (WorkerActionListener, *string, error)
|
||||
|
||||
SendStepActionEvent(ctx context.Context, in *ActionEvent) (*ActionEventResponse, error)
|
||||
|
||||
@@ -26,6 +26,8 @@ type DispatcherClient interface {
|
||||
ReleaseSlot(ctx context.Context, stepRunId string) error
|
||||
|
||||
RefreshTimeout(ctx context.Context, stepRunId string, incrementTimeoutBy string) error
|
||||
|
||||
UpsertWorkerLabels(ctx context.Context, workerId string, labels map[string]interface{}) error
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -39,6 +41,7 @@ type GetActionListenerRequest struct {
|
||||
Services []string
|
||||
Actions []string
|
||||
MaxRuns *int
|
||||
Labels map[string]interface{}
|
||||
}
|
||||
|
||||
// ActionPayload unmarshals the action payload into the target. It also validates the resulting target.
|
||||
@@ -179,10 +182,10 @@ type actionListenerImpl struct {
|
||||
listenerStrategy ListenerStrategy
|
||||
}
|
||||
|
||||
func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetActionListenerRequest) (*actionListenerImpl, error) {
|
||||
func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetActionListenerRequest) (*actionListenerImpl, *string, error) {
|
||||
// validate the request
|
||||
if err := d.v.Validate(req); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
registerReq := &dispatchercontracts.WorkerRegisterRequest{
|
||||
@@ -191,6 +194,11 @@ func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetAc
|
||||
Services: req.Services,
|
||||
}
|
||||
|
||||
if req.Labels != nil {
|
||||
|
||||
registerReq.Labels = mapLabels(req.Labels)
|
||||
}
|
||||
|
||||
if req.MaxRuns != nil {
|
||||
mr := int32(*req.MaxRuns)
|
||||
registerReq.MaxRuns = &mr
|
||||
@@ -200,7 +208,7 @@ func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetAc
|
||||
resp, err := d.client.Register(d.ctx.newContext(ctx), registerReq)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not register the worker: %w", err)
|
||||
return nil, nil, fmt.Errorf("could not register the worker: %w", err)
|
||||
}
|
||||
|
||||
d.l.Debug().Msgf("Registered worker with id: %s", resp.WorkerId)
|
||||
@@ -211,7 +219,7 @@ func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetAc
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not subscribe to the worker: %w", err)
|
||||
return nil, nil, fmt.Errorf("could not subscribe to the worker: %w", err)
|
||||
}
|
||||
|
||||
return &actionListenerImpl{
|
||||
@@ -223,7 +231,7 @@ func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetAc
|
||||
tenantId: d.tenantId,
|
||||
ctx: d.ctx,
|
||||
listenerStrategy: ListenerStrategyV2,
|
||||
}, nil
|
||||
}, &resp.WorkerId, nil
|
||||
}
|
||||
|
||||
func (a *actionListenerImpl) Actions(ctx context.Context) (<-chan *Action, error) {
|
||||
@@ -391,7 +399,7 @@ func (a *actionListenerImpl) Unregister() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dispatcherClientImpl) GetActionListener(ctx context.Context, req *GetActionListenerRequest) (WorkerActionListener, error) {
|
||||
func (d *dispatcherClientImpl) GetActionListener(ctx context.Context, req *GetActionListenerRequest) (WorkerActionListener, *string, error) {
|
||||
return d.newActionListener(ctx, req)
|
||||
}
|
||||
|
||||
@@ -511,3 +519,47 @@ func (a *dispatcherClientImpl) RefreshTimeout(ctx context.Context, stepRunId str
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *dispatcherClientImpl) UpsertWorkerLabels(ctx context.Context, workerId string, req map[string]interface{}) error {
|
||||
labels := mapLabels(req)
|
||||
|
||||
_, err := a.client.UpsertWorkerLabels(a.ctx.newContext(ctx), &dispatchercontracts.UpsertWorkerLabelsRequest{
|
||||
WorkerId: workerId,
|
||||
Labels: labels,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapLabels(req map[string]interface{}) map[string]*dispatchercontracts.WorkerLabels {
|
||||
labels := map[string]*dispatchercontracts.WorkerLabels{}
|
||||
|
||||
for k, v := range req {
|
||||
label := dispatchercontracts.WorkerLabels{}
|
||||
|
||||
switch value := v.(type) {
|
||||
case string:
|
||||
strValue := value
|
||||
label.StrValue = &strValue
|
||||
case int:
|
||||
intValue := int32(value)
|
||||
label.IntValue = &intValue
|
||||
case int32:
|
||||
label.IntValue = &value
|
||||
case int64:
|
||||
intValue := int32(value)
|
||||
label.IntValue = &intValue
|
||||
default:
|
||||
// For any other type, convert to string
|
||||
strValue := fmt.Sprintf("%v", value)
|
||||
label.StrValue = &strValue
|
||||
}
|
||||
|
||||
labels[k] = &label
|
||||
}
|
||||
return labels
|
||||
}
|
||||
|
||||
@@ -9,6 +9,17 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type StickyStrategy int32
|
||||
|
||||
const (
|
||||
StickyStrategy_SOFT StickyStrategy = 0
|
||||
StickyStrategy_HARD StickyStrategy = 1
|
||||
)
|
||||
|
||||
func StickyStrategyPtr(v StickyStrategy) *StickyStrategy {
|
||||
return &v
|
||||
}
|
||||
|
||||
type Workflow struct {
|
||||
Name string `yaml:"name,omitempty"`
|
||||
|
||||
@@ -25,6 +36,8 @@ type Workflow struct {
|
||||
Jobs map[string]WorkflowJob `yaml:"jobs"`
|
||||
|
||||
OnFailureJob *WorkflowJob `yaml:"onFailureJob,omitempty"`
|
||||
|
||||
StickyStrategy *StickyStrategy `yaml:"sticky,omitempty"`
|
||||
}
|
||||
|
||||
type WorkflowConcurrencyLimitStrategy string
|
||||
@@ -70,15 +83,38 @@ type WorkflowJob struct {
|
||||
Steps []WorkflowStep `yaml:"steps"`
|
||||
}
|
||||
|
||||
type WorkerLabelComparator int32
|
||||
|
||||
const (
|
||||
WorkerLabelComparator_EQUAL WorkerLabelComparator = 0
|
||||
WorkerLabelComparator_NOT_EQUAL WorkerLabelComparator = 1
|
||||
WorkerLabelComparator_GREATER_THAN WorkerLabelComparator = 2
|
||||
WorkerLabelComparator_GREATER_THAN_OR_EQUAL WorkerLabelComparator = 3
|
||||
WorkerLabelComparator_LESS_THAN WorkerLabelComparator = 4
|
||||
WorkerLabelComparator_LESS_THAN_OR_EQUAL WorkerLabelComparator = 5
|
||||
)
|
||||
|
||||
func ComparatorPtr(v WorkerLabelComparator) *WorkerLabelComparator {
|
||||
return &v
|
||||
}
|
||||
|
||||
type DesiredWorkerLabel struct {
|
||||
Value any `yaml:"value,omitempty"`
|
||||
Required bool `yaml:"required,omitempty"`
|
||||
Weight int32 `yaml:"weight,omitempty"`
|
||||
Comparator *WorkerLabelComparator `yaml:"comparator,omitempty"`
|
||||
}
|
||||
|
||||
type WorkflowStep struct {
|
||||
Name string `yaml:"name,omitempty"`
|
||||
ID string `yaml:"id,omitempty"`
|
||||
ActionID string `yaml:"action"`
|
||||
Timeout string `yaml:"timeout,omitempty"`
|
||||
With map[string]interface{} `yaml:"with,omitempty"`
|
||||
Parents []string `yaml:"parents,omitempty"`
|
||||
Retries int `yaml:"retries"`
|
||||
RateLimits []RateLimit `yaml:"rateLimits,omitempty"`
|
||||
Name string `yaml:"name,omitempty"`
|
||||
ID string `yaml:"id,omitempty"`
|
||||
ActionID string `yaml:"action"`
|
||||
Timeout string `yaml:"timeout,omitempty"`
|
||||
With map[string]interface{} `yaml:"with,omitempty"`
|
||||
Parents []string `yaml:"parents,omitempty"`
|
||||
Retries int `yaml:"retries"`
|
||||
RateLimits []RateLimit `yaml:"rateLimits,omitempty"`
|
||||
DesiredLabels map[string]*DesiredWorkerLabel `yaml:"desiredLabels,omitempty"`
|
||||
}
|
||||
|
||||
type RateLimit struct {
|
||||
|
||||
@@ -839,10 +839,10 @@ func (r *workflowEngineRepository) createJobTx(ctx context.Context, tx pgx.Tx, t
|
||||
return "", err
|
||||
}
|
||||
|
||||
if stepOpts.DesiredWorkerLabels != nil && len(*stepOpts.DesiredWorkerLabels) > 0 {
|
||||
for i := range *stepOpts.DesiredWorkerLabels {
|
||||
key := (*stepOpts.DesiredWorkerLabels)[i].Key
|
||||
value := (*stepOpts.DesiredWorkerLabels)[i]
|
||||
if stepOpts.DesiredWorkerLabels != nil && len(stepOpts.DesiredWorkerLabels) > 0 {
|
||||
for i := range stepOpts.DesiredWorkerLabels {
|
||||
key := (stepOpts.DesiredWorkerLabels)[i].Key
|
||||
value := (stepOpts.DesiredWorkerLabels)[i]
|
||||
|
||||
if key == "" {
|
||||
continue
|
||||
|
||||
@@ -41,7 +41,7 @@ type CreateStepRunEventOpts struct {
|
||||
|
||||
EventSeverity *dbsqlc.StepRunEventSeverity
|
||||
|
||||
EventData *map[string]interface{}
|
||||
EventData map[string]interface{}
|
||||
}
|
||||
|
||||
type UpdateStepRunOpts struct {
|
||||
|
||||
@@ -128,7 +128,7 @@ type CreateWorkflowStepOpts struct {
|
||||
RateLimits []CreateWorkflowStepRateLimitOpts `validate:"dive"`
|
||||
|
||||
// (optional) desired worker affinity state for this step
|
||||
DesiredWorkerLabels *map[string]DesiredWorkerLabelOpts `validate:"omitempty"`
|
||||
DesiredWorkerLabels map[string]DesiredWorkerLabelOpts `validate:"omitempty"`
|
||||
}
|
||||
|
||||
type DesiredWorkerLabelOpts struct {
|
||||
|
||||
+84
-1
@@ -12,6 +12,20 @@ import (
|
||||
"github.com/hatchet-dev/hatchet/pkg/client"
|
||||
)
|
||||
|
||||
type HatchetWorkerContext interface {
|
||||
context.Context
|
||||
|
||||
SetContext(ctx context.Context)
|
||||
|
||||
GetContext() context.Context
|
||||
|
||||
ID() string
|
||||
|
||||
GetLabels() map[string]interface{}
|
||||
|
||||
UpsertLabels(labels map[string]interface{}) error
|
||||
}
|
||||
|
||||
type HatchetContext interface {
|
||||
context.Context
|
||||
|
||||
@@ -19,6 +33,8 @@ type HatchetContext interface {
|
||||
|
||||
GetContext() context.Context
|
||||
|
||||
Worker() HatchetWorkerContext
|
||||
|
||||
StepOutput(step string, target interface{}) error
|
||||
|
||||
TriggeredByEvent() bool
|
||||
@@ -76,6 +92,9 @@ type StepData map[string]interface{}
|
||||
|
||||
type hatchetContext struct {
|
||||
context.Context
|
||||
|
||||
w *hatchetWorkerContext
|
||||
|
||||
a *client.Action
|
||||
stepData *StepRunData
|
||||
c client.Client
|
||||
@@ -87,17 +106,29 @@ type hatchetContext struct {
|
||||
listenerMu sync.Mutex
|
||||
}
|
||||
|
||||
type hatchetWorkerContext struct {
|
||||
context.Context
|
||||
id *string
|
||||
worker *Worker
|
||||
}
|
||||
|
||||
func newHatchetContext(
|
||||
ctx context.Context,
|
||||
action *client.Action,
|
||||
client client.Client,
|
||||
l *zerolog.Logger,
|
||||
w *Worker,
|
||||
) (HatchetContext, error) {
|
||||
c := &hatchetContext{
|
||||
Context: ctx,
|
||||
a: action,
|
||||
c: client,
|
||||
l: l,
|
||||
w: &hatchetWorkerContext{
|
||||
Context: ctx,
|
||||
id: w.id,
|
||||
worker: w,
|
||||
},
|
||||
}
|
||||
|
||||
if action.GetGroupKeyRunId != "" {
|
||||
@@ -125,6 +156,10 @@ func (h *hatchetContext) action() *client.Action {
|
||||
return h.a
|
||||
}
|
||||
|
||||
func (h *hatchetContext) Worker() HatchetWorkerContext {
|
||||
return h.w
|
||||
}
|
||||
|
||||
func (h *hatchetContext) SetContext(ctx context.Context) {
|
||||
h.Context = ctx
|
||||
}
|
||||
@@ -212,7 +247,8 @@ func (h *hatchetContext) inc() {
|
||||
}
|
||||
|
||||
type SpawnWorkflowOpts struct {
|
||||
Key *string
|
||||
Key *string
|
||||
Sticky *bool
|
||||
}
|
||||
|
||||
func (h *hatchetContext) saveOrLoadListener() (*client.WorkflowRunsListener, error) {
|
||||
@@ -239,6 +275,16 @@ func (h *hatchetContext) SpawnWorkflow(workflowName string, input any, opts *Spa
|
||||
opts = &SpawnWorkflowOpts{}
|
||||
}
|
||||
|
||||
var desiredWorker *string
|
||||
|
||||
if opts.Sticky != nil {
|
||||
if _, exists := h.w.worker.registered_workflows[workflowName]; !exists {
|
||||
return nil, fmt.Errorf("cannot run with sticky: workflow %s is not registered on this worker", workflowName)
|
||||
}
|
||||
|
||||
desiredWorker = h.w.id
|
||||
}
|
||||
|
||||
listener, err := h.saveOrLoadListener()
|
||||
|
||||
if err != nil {
|
||||
@@ -257,6 +303,7 @@ func (h *hatchetContext) SpawnWorkflow(workflowName string, input any, opts *Spa
|
||||
ParentStepRunId: h.StepRunId(),
|
||||
ChildIndex: h.index(),
|
||||
ChildKey: opts.Key,
|
||||
DesiredWorkerId: desiredWorker,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -331,3 +378,39 @@ func toTarget(data interface{}, target interface{}) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wc *hatchetWorkerContext) SetContext(ctx context.Context) {
|
||||
wc.Context = ctx
|
||||
}
|
||||
|
||||
func (wc *hatchetWorkerContext) GetContext() context.Context {
|
||||
return wc.Context
|
||||
}
|
||||
|
||||
func (wc *hatchetWorkerContext) ID() string {
|
||||
if wc.id == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return *wc.id
|
||||
}
|
||||
|
||||
func (wc *hatchetWorkerContext) GetLabels() map[string]interface{} {
|
||||
return wc.worker.labels
|
||||
}
|
||||
|
||||
func (wc *hatchetWorkerContext) UpsertLabels(labels map[string]interface{}) error {
|
||||
|
||||
if wc.id == nil {
|
||||
return fmt.Errorf("worker id is nil, cannot upsert labels (are on web worker?)")
|
||||
}
|
||||
|
||||
err := wc.worker.client.Dispatcher().UpsertWorkerLabels(wc.Context, *wc.id, labels)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert labels: %w", err)
|
||||
}
|
||||
|
||||
wc.worker.labels = labels
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -84,6 +84,10 @@ func (c *testHatchetContext) client() client.Client {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *testHatchetContext) Worker() HatchetWorkerContext {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func TestAddMiddleware(t *testing.T) {
|
||||
m := middlewares{}
|
||||
middlewareFunc := func(ctx HatchetContext, next func(HatchetContext) error) error {
|
||||
|
||||
@@ -112,7 +112,7 @@ func (w *Worker) WebhookHttpHandler(opts WebhookHandlerOptions, workflows ...wor
|
||||
return
|
||||
}
|
||||
|
||||
ctx, err := newHatchetContext(r.Context(), actionWithPayload, w.client, w.l)
|
||||
ctx, err := newHatchetContext(r.Context(), actionWithPayload, w.client, w.l, w)
|
||||
if err != nil {
|
||||
w.l.Error().Err(err).Msg("error creating context")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
+41
-11
@@ -70,6 +70,8 @@ type Worker struct {
|
||||
|
||||
actions map[string]Action
|
||||
|
||||
registered_workflows map[string]bool
|
||||
|
||||
l *zerolog.Logger
|
||||
|
||||
cancelMap sync.Map
|
||||
@@ -85,6 +87,10 @@ type Worker struct {
|
||||
maxRuns *int
|
||||
|
||||
initActionNames []string
|
||||
|
||||
labels map[string]interface{}
|
||||
|
||||
id *string
|
||||
}
|
||||
|
||||
type WorkerOpt func(*WorkerOpts)
|
||||
@@ -99,6 +105,8 @@ type WorkerOpts struct {
|
||||
maxRuns *int
|
||||
|
||||
actions []string
|
||||
|
||||
labels map[string]interface{}
|
||||
}
|
||||
|
||||
func defaultWorkerOpts() *WorkerOpts {
|
||||
@@ -161,6 +169,12 @@ func WithMaxRuns(maxRuns int) WorkerOpt {
|
||||
}
|
||||
}
|
||||
|
||||
func WithLabels(labels map[string]interface{}) WorkerOpt {
|
||||
return func(opts *WorkerOpts) {
|
||||
opts.labels = labels
|
||||
}
|
||||
}
|
||||
|
||||
// NewWorker creates a new worker instance
|
||||
func NewWorker(fs ...WorkerOpt) (*Worker, error) {
|
||||
opts := defaultWorkerOpts()
|
||||
@@ -171,15 +185,25 @@ func NewWorker(fs ...WorkerOpt) (*Worker, error) {
|
||||
|
||||
mws := newMiddlewares()
|
||||
|
||||
if opts.labels != nil {
|
||||
for _, value := range opts.labels {
|
||||
if reflect.TypeOf(value).Kind() != reflect.String && reflect.TypeOf(value).Kind() != reflect.Int {
|
||||
return nil, fmt.Errorf("invalid label value: %v", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w := &Worker{
|
||||
client: opts.client,
|
||||
name: opts.name,
|
||||
l: opts.l,
|
||||
actions: map[string]Action{},
|
||||
alerter: opts.alerter,
|
||||
middlewares: mws,
|
||||
maxRuns: opts.maxRuns,
|
||||
initActionNames: opts.actions,
|
||||
client: opts.client,
|
||||
name: opts.name,
|
||||
l: opts.l,
|
||||
actions: map[string]Action{},
|
||||
alerter: opts.alerter,
|
||||
middlewares: mws,
|
||||
maxRuns: opts.maxRuns,
|
||||
initActionNames: opts.actions,
|
||||
labels: opts.labels,
|
||||
registered_workflows: map[string]bool{},
|
||||
}
|
||||
|
||||
mws.add(w.panicMiddleware)
|
||||
@@ -229,6 +253,9 @@ func (w *Worker) RegisterWorkflow(workflow workflowConverter) error {
|
||||
if ok && wf.On == nil {
|
||||
return fmt.Errorf("workflow must have an trigger defined via the `On` field")
|
||||
}
|
||||
|
||||
w.registered_workflows[wf.Name] = true
|
||||
|
||||
return w.On(workflow.ToWorkflowTrigger(), workflow)
|
||||
}
|
||||
|
||||
@@ -312,12 +339,15 @@ func (w *Worker) Start() (func() error, error) {
|
||||
actionNames = append(actionNames, action.Name())
|
||||
}
|
||||
|
||||
listener, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
|
||||
listener, id, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
|
||||
WorkerName: w.name,
|
||||
Actions: actionNames,
|
||||
MaxRuns: w.maxRuns,
|
||||
Labels: w.labels,
|
||||
})
|
||||
|
||||
w.id = id
|
||||
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("could not get action listener: %w", err)
|
||||
@@ -408,7 +438,7 @@ func (w *Worker) startStepRun(ctx context.Context, assignedAction *client.Action
|
||||
|
||||
w.cancelMap.Store(assignedAction.StepRunId, cancel)
|
||||
|
||||
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l)
|
||||
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l, w)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create hatchet context: %w", err)
|
||||
@@ -506,7 +536,7 @@ func (w *Worker) startGetGroupKey(ctx context.Context, assignedAction *client.Ac
|
||||
|
||||
w.cancelConcurrencyMap.Store(assignedAction.WorkflowRunId, cancel)
|
||||
|
||||
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l)
|
||||
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l, w)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create hatchet context: %w", err)
|
||||
|
||||
@@ -54,7 +54,7 @@ type WebhookWorkerOpts struct {
|
||||
// TODO do not expose this to the end-user client somehow
|
||||
func (w *Worker) StartWebhook(ww WebhookWorkerOpts) (func() error, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
listener, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
|
||||
listener, _, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
|
||||
WorkerName: w.name,
|
||||
Actions: w.initActionNames,
|
||||
MaxRuns: w.maxRuns,
|
||||
|
||||
+20
-6
@@ -151,6 +151,8 @@ type WorkflowJob struct {
|
||||
OnFailure *WorkflowJob
|
||||
|
||||
ScheduleTimeout string
|
||||
|
||||
StickyStrategy *types.StickyStrategy
|
||||
}
|
||||
|
||||
type WorkflowConcurrency struct {
|
||||
@@ -217,6 +219,10 @@ func (j *WorkflowJob) ToWorkflow(svcName string, namespace string) types.Workflo
|
||||
}
|
||||
}
|
||||
|
||||
if j.StickyStrategy != nil {
|
||||
w.StickyStrategy = j.StickyStrategy
|
||||
}
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
@@ -284,6 +290,8 @@ type WorkflowStep struct {
|
||||
Retries int
|
||||
|
||||
RateLimit []RateLimit
|
||||
|
||||
DesiredLabels map[string]*types.DesiredWorkerLabel
|
||||
}
|
||||
|
||||
type RateLimit struct {
|
||||
@@ -307,6 +315,11 @@ func (w *WorkflowStep) SetName(name string) *WorkflowStep {
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *WorkflowStep) SetDesiredLabels(labels map[string]*types.DesiredWorkerLabel) *WorkflowStep {
|
||||
w.DesiredLabels = labels
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *WorkflowStep) SetRateLimit(rateLimit RateLimit) *WorkflowStep {
|
||||
w.RateLimit = append(w.RateLimit, rateLimit)
|
||||
return w
|
||||
@@ -375,12 +388,13 @@ func (w *WorkflowStep) ToWorkflowStep(svcName string, index int, namespace strin
|
||||
res.Id = w.GetStepId(index)
|
||||
|
||||
res.APIStep = types.WorkflowStep{
|
||||
Name: res.Id,
|
||||
ID: w.GetStepId(index),
|
||||
Timeout: w.Timeout,
|
||||
ActionID: w.GetActionId(svcName, index),
|
||||
Parents: []string{},
|
||||
Retries: w.Retries,
|
||||
Name: res.Id,
|
||||
ID: w.GetStepId(index),
|
||||
Timeout: w.Timeout,
|
||||
ActionID: w.GetActionId(svcName, index),
|
||||
Parents: []string{},
|
||||
Retries: w.Retries,
|
||||
DesiredLabels: w.DesiredLabels,
|
||||
}
|
||||
|
||||
for _, rateLimit := range w.RateLimit {
|
||||
|
||||
Reference in New Issue
Block a user