feat: go worker assignment (#741)

* feat: create worker with label

* feat: worker context

* feat: dynamic labels

* feat: affinity

* fix: ptr

* fix: nil labels

* feat: sticky dag

* feat: sticky docs

* feat: sticky children

* chore: lint

* fix: tests

* fix: possibly nil workerId

* chore: cleanup unneeded pointers
This commit is contained in:
Gabe Ruttner
2024-07-26 10:19:11 -07:00
committed by GitHub
parent 1ea4dfc5de
commit fd947cb5bc
20 changed files with 720 additions and 51 deletions
+38
View File
@@ -0,0 +1,38 @@
package main
import (
"fmt"
"github.com/joho/godotenv"
"github.com/hatchet-dev/hatchet/pkg/cmdutils"
)
type userCreateEvent struct {
Username string `json:"username"`
UserID string `json:"user_id"`
Data map[string]string `json:"data"`
}
type stepOneOutput struct {
Message string `json:"message"`
}
func main() {
err := godotenv.Load()
if err != nil {
panic(err)
}
ch := cmdutils.InterruptChan()
cleanup, err := run()
if err != nil {
panic(err)
}
<-ch
if err := cleanup(); err != nil {
panic(fmt.Errorf("cleanup() error = %v", err))
}
}
+106
View File
@@ -0,0 +1,106 @@
package main
import (
"context"
"fmt"
"log"
"time"
"github.com/hatchet-dev/hatchet/pkg/client"
"github.com/hatchet-dev/hatchet/pkg/client/types"
"github.com/hatchet-dev/hatchet/pkg/worker"
)
func run() (func() error, error) {
c, err := client.New()
if err != nil {
return nil, fmt.Errorf("error creating client: %w", err)
}
w, err := worker.NewWorker(
worker.WithClient(
c,
),
worker.WithLabels(map[string]interface{}{
"model": "fancy-ai-model-v2",
"memory": 512,
}),
)
if err != nil {
return nil, fmt.Errorf("error creating worker: %w", err)
}
err = w.RegisterWorkflow(
&worker.WorkflowJob{
On: worker.Events("user:create:affinity"),
Name: "affinity",
Description: "affinity",
Steps: []*worker.WorkflowStep{
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
model := ctx.Worker().GetLabels()["model"]
if model != "fancy-ai-model-v3" {
ctx.Worker().UpsertLabels(map[string]interface{}{
"model": nil,
})
// Do something to load the model
ctx.Worker().UpsertLabels(map[string]interface{}{
"model": "fancy-ai-model-v3",
})
}
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).
SetName("step-one").
SetDesiredLabels(map[string]*types.DesiredWorkerLabel{
"model": {
Value: "fancy-ai-model-v3",
Weight: 10,
},
"memory": {
Value: 512,
Required: true,
Comparator: types.ComparatorPtr(types.WorkerLabelComparator_GREATER_THAN),
},
}),
},
},
)
if err != nil {
return nil, fmt.Errorf("error registering workflow: %w", err)
}
go func() {
log.Printf("pushing event")
testEvent := userCreateEvent{
Username: "echo-test",
UserID: "1234",
Data: map[string]string{
"test": "test",
},
}
// push an event
err := c.Event().Push(
context.Background(),
"user:create:affinity",
testEvent,
)
if err != nil {
panic(fmt.Errorf("error pushing event: %w", err))
}
time.Sleep(10 * time.Second)
}()
cleanup, err := w.Start()
if err != nil {
return nil, fmt.Errorf("error starting worker: %w", err)
}
return cleanup, nil
}
+38
View File
@@ -0,0 +1,38 @@
package main
import (
"fmt"
"github.com/joho/godotenv"
"github.com/hatchet-dev/hatchet/pkg/cmdutils"
)
type userCreateEvent struct {
Username string `json:"username"`
UserID string `json:"user_id"`
Data map[string]string `json:"data"`
}
type stepOneOutput struct {
Message string `json:"message"`
}
func main() {
err := godotenv.Load()
if err != nil {
panic(err)
}
ch := cmdutils.InterruptChan()
cleanup, err := run()
if err != nil {
panic(err)
}
<-ch
if err := cleanup(); err != nil {
panic(fmt.Errorf("cleanup() error = %v", err))
}
}
+99
View File
@@ -0,0 +1,99 @@
package main
import (
"context"
"fmt"
"log"
"time"
"github.com/hatchet-dev/hatchet/pkg/client"
"github.com/hatchet-dev/hatchet/pkg/client/types"
"github.com/hatchet-dev/hatchet/pkg/worker"
)
func run() (func() error, error) {
c, err := client.New()
if err != nil {
return nil, fmt.Errorf("error creating client: %w", err)
}
w, err := worker.NewWorker(
worker.WithClient(
c,
),
)
if err != nil {
return nil, fmt.Errorf("error creating worker: %w", err)
}
err = w.RegisterWorkflow(
&worker.WorkflowJob{
On: worker.Events("user:create:sticky"),
Name: "sticky",
Description: "sticky",
StickyStrategy: types.StickyStrategyPtr(types.StickyStrategy_SOFT),
Steps: []*worker.WorkflowStep{
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
sticky := true
_, err = ctx.SpawnWorkflow("step-one", nil, &worker.SpawnWorkflowOpts{
Sticky: &sticky,
})
if err != nil {
return nil, fmt.Errorf("error spawning workflow: %w", err)
}
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).SetName("step-one"),
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).SetName("step-two"),
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).SetName("step-three").AddParents("step-one", "step-two"),
},
},
)
if err != nil {
return nil, fmt.Errorf("error registering workflow: %w", err)
}
go func() {
log.Printf("pushing event")
testEvent := userCreateEvent{
Username: "echo-test",
UserID: "1234",
Data: map[string]string{
"test": "test",
},
}
// push an event
err := c.Event().Push(
context.Background(),
"user:create:sticky",
testEvent,
)
if err != nil {
panic(fmt.Errorf("error pushing event: %w", err))
}
time.Sleep(10 * time.Second)
}()
cleanup, err := w.Start()
if err != nil {
return nil, fmt.Errorf("error starting worker: %w", err)
}
return cleanup, nil
}
@@ -70,7 +70,31 @@ const myWorkflow: Workflow = {
</Tabs.Tab>
<Tabs.Tab>
```go
// TODO add go example
err = w.RegisterWorkflow(
&worker.WorkflowJob{
On: worker.Events("user:create:sticky"),
Name: "sticky",
Description: "sticky",
StickyStrategy: types.StickyStrategyPtr(types.StickyStrategy_SOFT),
Steps: []*worker.WorkflowStep{
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).SetName("step-one"),
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).SetName("step-two"),
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).SetName("step-three").AddParents("step-one", "step-two"),
},
},
)
```
</Tabs.Tab>
</Tabs>
@@ -143,7 +167,33 @@ const parentWorkflow: Workflow = {
</Tabs.Tab>
<Tabs.Tab>
```go
// TODO go example
err = w.RegisterWorkflow(
&worker.WorkflowJob{
On: worker.Events("user:create:sticky"),
Name: "sticky",
Description: "sticky",
Steps: []*worker.WorkflowStep{
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
sticky := true
_, err = ctx.SpawnWorkflow("sticky-child", nil, &worker.SpawnWorkflowOpts{
Sticky: &sticky,
})
if err != nil {
return nil, fmt.Errorf("error spawning workflow: %w", err)
}
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).SetName("step-one"),
},
},
)
```
</Tabs.Tab>
</Tabs>
```
@@ -38,7 +38,15 @@ worker = hatchet.worker(
</Tabs.Tab>
<Tabs.Tab>
```go
// TODO add go example
w, err := worker.NewWorker(
worker.WithClient(
c,
),
worker.WithLabels(map[string]interface{}{
"model": "fancy-ai-model-v2",
"memory": 512,
}),
)
```
</Tabs.Tab>
</Tabs>
@@ -107,7 +115,32 @@ const affinity: Workflow = {
</Tabs.Tab>
<Tabs.Tab>
```go
// TODO go example
err = w.RegisterWorkflow(
&worker.WorkflowJob{
On: worker.Events("user:create:affinity"),
Name: "affinity",
Description: "affinity",
Steps: []*worker.WorkflowStep{
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).
SetName("step-one").
SetDesiredLabels(map[string]*types.DesiredWorkerLabel{
"model": {
Value: "fancy-ai-model-v2",
Weight: 10,
},
"memory": {
Value: 512,
Required: true,
Comparator: types.ComparatorPtr(types.WorkerLabelComparator_GREATER_THAN),
},
}),
},
},
)
```
</Tabs.Tab>
</Tabs>
@@ -129,7 +162,10 @@ Labels can also be set dynamically on workers using the `upsertLabels` method. T
```python
@hatchet.step(
desired_worker_labels={
"model": {"value": "fancy-vision-model", "weight": 10},
"model": {
"value": "fancy-vision-model",
"weight": 10
},
"memory": {
"value": 256,
"required": True,
@@ -181,7 +217,49 @@ Labels can also be set dynamically on workers using the `upsertLabels` method. T
</Tabs.Tab>
<Tabs.Tab>
```go
// TODO add go example
err = w.RegisterWorkflow(
&worker.WorkflowJob{
On: worker.Events("user:create:affinity"),
Name: "affinity",
Description: "affinity",
Steps: []*worker.WorkflowStep{
worker.Fn(func(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
model := ctx.Worker().GetLabels()["model"]
if model != "fancy-vision-model" {
ctx.Worker().UpsertLabels(map[string]interface{}{
"model": nil,
})
// Do something to load the model
evictModel();
loadNewModel("fancy-vision-model");
ctx.Worker().UpsertLabels(map[string]interface{}{
"model": "fancy-vision-model",
})
}
return &stepOneOutput{
Message: ctx.Worker().ID(),
}, nil
}).
SetName("step-one").
SetDesiredLabels(map[string]*types.DesiredWorkerLabel{
"model": {
Value: "fancy-vision-model",
Weight: 10,
},
"memory": {
Value: 512,
Required: true,
Comparator: types.WorkerLabelComparator_GREATER_THAN,
},
}),
},
},
)
```
</Tabs.Tab>
</Tabs>
```
+3 -3
View File
@@ -478,10 +478,10 @@ func getCreateJobOpts(req *contracts.CreateWorkflowJobOpts, kind string) (*repos
retries := int(stepCp.Retries)
var affinity *map[string]repository.DesiredWorkerLabelOpts
var affinity map[string]repository.DesiredWorkerLabelOpts
if stepCp.WorkerLabels != nil {
affinity = &map[string]repository.DesiredWorkerLabelOpts{}
affinity = map[string]repository.DesiredWorkerLabelOpts{}
for k, v := range stepCp.WorkerLabels {
var c *string
@@ -491,7 +491,7 @@ func getCreateJobOpts(req *contracts.CreateWorkflowJobOpts, kind string) (*repos
c = &cPtr
}
(*affinity)[k] = repository.DesiredWorkerLabelOpts{
(affinity)[k] = repository.DesiredWorkerLabelOpts{
Key: k,
StrValue: v.StrValue,
IntValue: v.IntValue,
@@ -765,7 +765,7 @@ func (ec *JobsControllerImpl) runStepRunReassignTenant(ctx context.Context, tena
EventReason: repository.StepRunEventReasonPtr(dbsqlc.StepRunEventReasonREASSIGNED),
EventSeverity: repository.StepRunEventSeverityPtr(dbsqlc.StepRunEventSeverityCRITICAL),
EventMessage: repository.StringPtr("Worker has become inactive"),
EventData: &eventData,
EventData: eventData,
})
if err != nil {
+41
View File
@@ -21,6 +21,7 @@ type ChildWorkflowOpts struct {
ParentStepRunId string
ChildIndex int
ChildKey *string
DesiredWorkerId *string
}
type AdminClient interface {
@@ -216,6 +217,7 @@ func (a *adminClientImpl) RunChildWorkflow(workflowName string, input interface{
ParentStepRunId: &opts.ParentStepRunId,
ChildIndex: &childIndex,
ChildKey: opts.ChildKey,
DesiredWorkerId: opts.DesiredWorkerId,
})
if err != nil {
@@ -264,6 +266,11 @@ func (a *adminClientImpl) getPutRequest(workflow *types.Workflow) (*admincontrac
CronTriggers: workflow.Triggers.Cron,
}
if workflow.StickyStrategy != nil {
s := admincontracts.StickyStrategy(*workflow.StickyStrategy)
opts.Sticky = &s
}
if workflow.Concurrency != nil {
opts.Concurrency = &admincontracts.WorkflowConcurrencyOpts{
Action: workflow.Concurrency.ActionID,
@@ -356,6 +363,40 @@ func (a *adminClientImpl) getJobOpts(jobName string, job *types.WorkflowJob) (*a
})
}
if step.DesiredLabels != nil {
stepOpt.WorkerLabels = make(map[string]*admincontracts.DesiredWorkerLabels, len(step.DesiredLabels))
for key, desiredLabel := range step.DesiredLabels {
stepOpt.WorkerLabels[key] = &admincontracts.DesiredWorkerLabels{
Required: &desiredLabel.Required,
Weight: &desiredLabel.Weight,
}
switch value := desiredLabel.Value.(type) {
case string:
strValue := value
stepOpt.WorkerLabels[key].StrValue = &strValue
case int:
intValue := int32(value)
stepOpt.WorkerLabels[key].IntValue = &intValue
case int32:
stepOpt.WorkerLabels[key].IntValue = &value
case int64:
intValue := int32(value)
stepOpt.WorkerLabels[key].IntValue = &intValue
default:
// For any other type, convert to string
strValue := fmt.Sprintf("%v", value)
stepOpt.WorkerLabels[key].StrValue = &strValue
}
if desiredLabel.Comparator != nil {
c := admincontracts.WorkerLabelComparator(*desiredLabel.Comparator)
stepOpt.WorkerLabels[key].Comparator = &c
}
}
}
stepOpts[i] = stepOpt
}
+59 -7
View File
@@ -17,7 +17,7 @@ import (
)
type DispatcherClient interface {
GetActionListener(ctx context.Context, req *GetActionListenerRequest) (WorkerActionListener, error)
GetActionListener(ctx context.Context, req *GetActionListenerRequest) (WorkerActionListener, *string, error)
SendStepActionEvent(ctx context.Context, in *ActionEvent) (*ActionEventResponse, error)
@@ -26,6 +26,8 @@ type DispatcherClient interface {
ReleaseSlot(ctx context.Context, stepRunId string) error
RefreshTimeout(ctx context.Context, stepRunId string, incrementTimeoutBy string) error
UpsertWorkerLabels(ctx context.Context, workerId string, labels map[string]interface{}) error
}
const (
@@ -39,6 +41,7 @@ type GetActionListenerRequest struct {
Services []string
Actions []string
MaxRuns *int
Labels map[string]interface{}
}
// ActionPayload unmarshals the action payload into the target. It also validates the resulting target.
@@ -179,10 +182,10 @@ type actionListenerImpl struct {
listenerStrategy ListenerStrategy
}
func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetActionListenerRequest) (*actionListenerImpl, error) {
func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetActionListenerRequest) (*actionListenerImpl, *string, error) {
// validate the request
if err := d.v.Validate(req); err != nil {
return nil, err
return nil, nil, err
}
registerReq := &dispatchercontracts.WorkerRegisterRequest{
@@ -191,6 +194,11 @@ func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetAc
Services: req.Services,
}
if req.Labels != nil {
registerReq.Labels = mapLabels(req.Labels)
}
if req.MaxRuns != nil {
mr := int32(*req.MaxRuns)
registerReq.MaxRuns = &mr
@@ -200,7 +208,7 @@ func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetAc
resp, err := d.client.Register(d.ctx.newContext(ctx), registerReq)
if err != nil {
return nil, fmt.Errorf("could not register the worker: %w", err)
return nil, nil, fmt.Errorf("could not register the worker: %w", err)
}
d.l.Debug().Msgf("Registered worker with id: %s", resp.WorkerId)
@@ -211,7 +219,7 @@ func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetAc
})
if err != nil {
return nil, fmt.Errorf("could not subscribe to the worker: %w", err)
return nil, nil, fmt.Errorf("could not subscribe to the worker: %w", err)
}
return &actionListenerImpl{
@@ -223,7 +231,7 @@ func (d *dispatcherClientImpl) newActionListener(ctx context.Context, req *GetAc
tenantId: d.tenantId,
ctx: d.ctx,
listenerStrategy: ListenerStrategyV2,
}, nil
}, &resp.WorkerId, nil
}
func (a *actionListenerImpl) Actions(ctx context.Context) (<-chan *Action, error) {
@@ -391,7 +399,7 @@ func (a *actionListenerImpl) Unregister() error {
return nil
}
func (d *dispatcherClientImpl) GetActionListener(ctx context.Context, req *GetActionListenerRequest) (WorkerActionListener, error) {
func (d *dispatcherClientImpl) GetActionListener(ctx context.Context, req *GetActionListenerRequest) (WorkerActionListener, *string, error) {
return d.newActionListener(ctx, req)
}
@@ -511,3 +519,47 @@ func (a *dispatcherClientImpl) RefreshTimeout(ctx context.Context, stepRunId str
return nil
}
func (a *dispatcherClientImpl) UpsertWorkerLabels(ctx context.Context, workerId string, req map[string]interface{}) error {
labels := mapLabels(req)
_, err := a.client.UpsertWorkerLabels(a.ctx.newContext(ctx), &dispatchercontracts.UpsertWorkerLabelsRequest{
WorkerId: workerId,
Labels: labels,
})
if err != nil {
return err
}
return nil
}
func mapLabels(req map[string]interface{}) map[string]*dispatchercontracts.WorkerLabels {
labels := map[string]*dispatchercontracts.WorkerLabels{}
for k, v := range req {
label := dispatchercontracts.WorkerLabels{}
switch value := v.(type) {
case string:
strValue := value
label.StrValue = &strValue
case int:
intValue := int32(value)
label.IntValue = &intValue
case int32:
label.IntValue = &value
case int64:
intValue := int32(value)
label.IntValue = &intValue
default:
// For any other type, convert to string
strValue := fmt.Sprintf("%v", value)
label.StrValue = &strValue
}
labels[k] = &label
}
return labels
}
+44 -8
View File
@@ -9,6 +9,17 @@ import (
"gopkg.in/yaml.v3"
)
type StickyStrategy int32
const (
StickyStrategy_SOFT StickyStrategy = 0
StickyStrategy_HARD StickyStrategy = 1
)
func StickyStrategyPtr(v StickyStrategy) *StickyStrategy {
return &v
}
type Workflow struct {
Name string `yaml:"name,omitempty"`
@@ -25,6 +36,8 @@ type Workflow struct {
Jobs map[string]WorkflowJob `yaml:"jobs"`
OnFailureJob *WorkflowJob `yaml:"onFailureJob,omitempty"`
StickyStrategy *StickyStrategy `yaml:"sticky,omitempty"`
}
type WorkflowConcurrencyLimitStrategy string
@@ -70,15 +83,38 @@ type WorkflowJob struct {
Steps []WorkflowStep `yaml:"steps"`
}
type WorkerLabelComparator int32
const (
WorkerLabelComparator_EQUAL WorkerLabelComparator = 0
WorkerLabelComparator_NOT_EQUAL WorkerLabelComparator = 1
WorkerLabelComparator_GREATER_THAN WorkerLabelComparator = 2
WorkerLabelComparator_GREATER_THAN_OR_EQUAL WorkerLabelComparator = 3
WorkerLabelComparator_LESS_THAN WorkerLabelComparator = 4
WorkerLabelComparator_LESS_THAN_OR_EQUAL WorkerLabelComparator = 5
)
func ComparatorPtr(v WorkerLabelComparator) *WorkerLabelComparator {
return &v
}
type DesiredWorkerLabel struct {
Value any `yaml:"value,omitempty"`
Required bool `yaml:"required,omitempty"`
Weight int32 `yaml:"weight,omitempty"`
Comparator *WorkerLabelComparator `yaml:"comparator,omitempty"`
}
type WorkflowStep struct {
Name string `yaml:"name,omitempty"`
ID string `yaml:"id,omitempty"`
ActionID string `yaml:"action"`
Timeout string `yaml:"timeout,omitempty"`
With map[string]interface{} `yaml:"with,omitempty"`
Parents []string `yaml:"parents,omitempty"`
Retries int `yaml:"retries"`
RateLimits []RateLimit `yaml:"rateLimits,omitempty"`
Name string `yaml:"name,omitempty"`
ID string `yaml:"id,omitempty"`
ActionID string `yaml:"action"`
Timeout string `yaml:"timeout,omitempty"`
With map[string]interface{} `yaml:"with,omitempty"`
Parents []string `yaml:"parents,omitempty"`
Retries int `yaml:"retries"`
RateLimits []RateLimit `yaml:"rateLimits,omitempty"`
DesiredLabels map[string]*DesiredWorkerLabel `yaml:"desiredLabels,omitempty"`
}
type RateLimit struct {
+4 -4
View File
@@ -839,10 +839,10 @@ func (r *workflowEngineRepository) createJobTx(ctx context.Context, tx pgx.Tx, t
return "", err
}
if stepOpts.DesiredWorkerLabels != nil && len(*stepOpts.DesiredWorkerLabels) > 0 {
for i := range *stepOpts.DesiredWorkerLabels {
key := (*stepOpts.DesiredWorkerLabels)[i].Key
value := (*stepOpts.DesiredWorkerLabels)[i]
if stepOpts.DesiredWorkerLabels != nil && len(stepOpts.DesiredWorkerLabels) > 0 {
for i := range stepOpts.DesiredWorkerLabels {
key := (stepOpts.DesiredWorkerLabels)[i].Key
value := (stepOpts.DesiredWorkerLabels)[i]
if key == "" {
continue
+1 -1
View File
@@ -41,7 +41,7 @@ type CreateStepRunEventOpts struct {
EventSeverity *dbsqlc.StepRunEventSeverity
EventData *map[string]interface{}
EventData map[string]interface{}
}
type UpdateStepRunOpts struct {
+1 -1
View File
@@ -128,7 +128,7 @@ type CreateWorkflowStepOpts struct {
RateLimits []CreateWorkflowStepRateLimitOpts `validate:"dive"`
// (optional) desired worker affinity state for this step
DesiredWorkerLabels *map[string]DesiredWorkerLabelOpts `validate:"omitempty"`
DesiredWorkerLabels map[string]DesiredWorkerLabelOpts `validate:"omitempty"`
}
type DesiredWorkerLabelOpts struct {
+84 -1
View File
@@ -12,6 +12,20 @@ import (
"github.com/hatchet-dev/hatchet/pkg/client"
)
type HatchetWorkerContext interface {
context.Context
SetContext(ctx context.Context)
GetContext() context.Context
ID() string
GetLabels() map[string]interface{}
UpsertLabels(labels map[string]interface{}) error
}
type HatchetContext interface {
context.Context
@@ -19,6 +33,8 @@ type HatchetContext interface {
GetContext() context.Context
Worker() HatchetWorkerContext
StepOutput(step string, target interface{}) error
TriggeredByEvent() bool
@@ -76,6 +92,9 @@ type StepData map[string]interface{}
type hatchetContext struct {
context.Context
w *hatchetWorkerContext
a *client.Action
stepData *StepRunData
c client.Client
@@ -87,17 +106,29 @@ type hatchetContext struct {
listenerMu sync.Mutex
}
type hatchetWorkerContext struct {
context.Context
id *string
worker *Worker
}
func newHatchetContext(
ctx context.Context,
action *client.Action,
client client.Client,
l *zerolog.Logger,
w *Worker,
) (HatchetContext, error) {
c := &hatchetContext{
Context: ctx,
a: action,
c: client,
l: l,
w: &hatchetWorkerContext{
Context: ctx,
id: w.id,
worker: w,
},
}
if action.GetGroupKeyRunId != "" {
@@ -125,6 +156,10 @@ func (h *hatchetContext) action() *client.Action {
return h.a
}
func (h *hatchetContext) Worker() HatchetWorkerContext {
return h.w
}
func (h *hatchetContext) SetContext(ctx context.Context) {
h.Context = ctx
}
@@ -212,7 +247,8 @@ func (h *hatchetContext) inc() {
}
type SpawnWorkflowOpts struct {
Key *string
Key *string
Sticky *bool
}
func (h *hatchetContext) saveOrLoadListener() (*client.WorkflowRunsListener, error) {
@@ -239,6 +275,16 @@ func (h *hatchetContext) SpawnWorkflow(workflowName string, input any, opts *Spa
opts = &SpawnWorkflowOpts{}
}
var desiredWorker *string
if opts.Sticky != nil {
if _, exists := h.w.worker.registered_workflows[workflowName]; !exists {
return nil, fmt.Errorf("cannot run with sticky: workflow %s is not registered on this worker", workflowName)
}
desiredWorker = h.w.id
}
listener, err := h.saveOrLoadListener()
if err != nil {
@@ -257,6 +303,7 @@ func (h *hatchetContext) SpawnWorkflow(workflowName string, input any, opts *Spa
ParentStepRunId: h.StepRunId(),
ChildIndex: h.index(),
ChildKey: opts.Key,
DesiredWorkerId: desiredWorker,
},
)
@@ -331,3 +378,39 @@ func toTarget(data interface{}, target interface{}) error {
return nil
}
func (wc *hatchetWorkerContext) SetContext(ctx context.Context) {
wc.Context = ctx
}
func (wc *hatchetWorkerContext) GetContext() context.Context {
return wc.Context
}
func (wc *hatchetWorkerContext) ID() string {
if wc.id == nil {
return ""
}
return *wc.id
}
func (wc *hatchetWorkerContext) GetLabels() map[string]interface{} {
return wc.worker.labels
}
func (wc *hatchetWorkerContext) UpsertLabels(labels map[string]interface{}) error {
if wc.id == nil {
return fmt.Errorf("worker id is nil, cannot upsert labels (are on web worker?)")
}
err := wc.worker.client.Dispatcher().UpsertWorkerLabels(wc.Context, *wc.id, labels)
if err != nil {
return fmt.Errorf("failed to upsert labels: %w", err)
}
wc.worker.labels = labels
return nil
}
+4
View File
@@ -84,6 +84,10 @@ func (c *testHatchetContext) client() client.Client {
panic("not implemented")
}
func (c *testHatchetContext) Worker() HatchetWorkerContext {
panic("not implemented")
}
func TestAddMiddleware(t *testing.T) {
m := middlewares{}
middlewareFunc := func(ctx HatchetContext, next func(HatchetContext) error) error {
+1 -1
View File
@@ -112,7 +112,7 @@ func (w *Worker) WebhookHttpHandler(opts WebhookHandlerOptions, workflows ...wor
return
}
ctx, err := newHatchetContext(r.Context(), actionWithPayload, w.client, w.l)
ctx, err := newHatchetContext(r.Context(), actionWithPayload, w.client, w.l, w)
if err != nil {
w.l.Error().Err(err).Msg("error creating context")
writer.WriteHeader(http.StatusInternalServerError)
+41 -11
View File
@@ -70,6 +70,8 @@ type Worker struct {
actions map[string]Action
registered_workflows map[string]bool
l *zerolog.Logger
cancelMap sync.Map
@@ -85,6 +87,10 @@ type Worker struct {
maxRuns *int
initActionNames []string
labels map[string]interface{}
id *string
}
type WorkerOpt func(*WorkerOpts)
@@ -99,6 +105,8 @@ type WorkerOpts struct {
maxRuns *int
actions []string
labels map[string]interface{}
}
func defaultWorkerOpts() *WorkerOpts {
@@ -161,6 +169,12 @@ func WithMaxRuns(maxRuns int) WorkerOpt {
}
}
func WithLabels(labels map[string]interface{}) WorkerOpt {
return func(opts *WorkerOpts) {
opts.labels = labels
}
}
// NewWorker creates a new worker instance
func NewWorker(fs ...WorkerOpt) (*Worker, error) {
opts := defaultWorkerOpts()
@@ -171,15 +185,25 @@ func NewWorker(fs ...WorkerOpt) (*Worker, error) {
mws := newMiddlewares()
if opts.labels != nil {
for _, value := range opts.labels {
if reflect.TypeOf(value).Kind() != reflect.String && reflect.TypeOf(value).Kind() != reflect.Int {
return nil, fmt.Errorf("invalid label value: %v", value)
}
}
}
w := &Worker{
client: opts.client,
name: opts.name,
l: opts.l,
actions: map[string]Action{},
alerter: opts.alerter,
middlewares: mws,
maxRuns: opts.maxRuns,
initActionNames: opts.actions,
client: opts.client,
name: opts.name,
l: opts.l,
actions: map[string]Action{},
alerter: opts.alerter,
middlewares: mws,
maxRuns: opts.maxRuns,
initActionNames: opts.actions,
labels: opts.labels,
registered_workflows: map[string]bool{},
}
mws.add(w.panicMiddleware)
@@ -229,6 +253,9 @@ func (w *Worker) RegisterWorkflow(workflow workflowConverter) error {
if ok && wf.On == nil {
return fmt.Errorf("workflow must have an trigger defined via the `On` field")
}
w.registered_workflows[wf.Name] = true
return w.On(workflow.ToWorkflowTrigger(), workflow)
}
@@ -312,12 +339,15 @@ func (w *Worker) Start() (func() error, error) {
actionNames = append(actionNames, action.Name())
}
listener, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
listener, id, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
WorkerName: w.name,
Actions: actionNames,
MaxRuns: w.maxRuns,
Labels: w.labels,
})
w.id = id
if err != nil {
cancel()
return nil, fmt.Errorf("could not get action listener: %w", err)
@@ -408,7 +438,7 @@ func (w *Worker) startStepRun(ctx context.Context, assignedAction *client.Action
w.cancelMap.Store(assignedAction.StepRunId, cancel)
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l)
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l, w)
if err != nil {
return fmt.Errorf("could not create hatchet context: %w", err)
@@ -506,7 +536,7 @@ func (w *Worker) startGetGroupKey(ctx context.Context, assignedAction *client.Ac
w.cancelConcurrencyMap.Store(assignedAction.WorkflowRunId, cancel)
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l)
hCtx, err := newHatchetContext(runContext, assignedAction, w.client, w.l, w)
if err != nil {
return fmt.Errorf("could not create hatchet context: %w", err)
+1 -1
View File
@@ -54,7 +54,7 @@ type WebhookWorkerOpts struct {
// TODO do not expose this to the end-user client somehow
func (w *Worker) StartWebhook(ww WebhookWorkerOpts) (func() error, error) {
ctx, cancel := context.WithCancel(context.Background())
listener, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
listener, _, err := w.client.Dispatcher().GetActionListener(ctx, &client.GetActionListenerRequest{
WorkerName: w.name,
Actions: w.initActionNames,
MaxRuns: w.maxRuns,
+20 -6
View File
@@ -151,6 +151,8 @@ type WorkflowJob struct {
OnFailure *WorkflowJob
ScheduleTimeout string
StickyStrategy *types.StickyStrategy
}
type WorkflowConcurrency struct {
@@ -217,6 +219,10 @@ func (j *WorkflowJob) ToWorkflow(svcName string, namespace string) types.Workflo
}
}
if j.StickyStrategy != nil {
w.StickyStrategy = j.StickyStrategy
}
return w
}
@@ -284,6 +290,8 @@ type WorkflowStep struct {
Retries int
RateLimit []RateLimit
DesiredLabels map[string]*types.DesiredWorkerLabel
}
type RateLimit struct {
@@ -307,6 +315,11 @@ func (w *WorkflowStep) SetName(name string) *WorkflowStep {
return w
}
func (w *WorkflowStep) SetDesiredLabels(labels map[string]*types.DesiredWorkerLabel) *WorkflowStep {
w.DesiredLabels = labels
return w
}
func (w *WorkflowStep) SetRateLimit(rateLimit RateLimit) *WorkflowStep {
w.RateLimit = append(w.RateLimit, rateLimit)
return w
@@ -375,12 +388,13 @@ func (w *WorkflowStep) ToWorkflowStep(svcName string, index int, namespace strin
res.Id = w.GetStepId(index)
res.APIStep = types.WorkflowStep{
Name: res.Id,
ID: w.GetStepId(index),
Timeout: w.Timeout,
ActionID: w.GetActionId(svcName, index),
Parents: []string{},
Retries: w.Retries,
Name: res.Id,
ID: w.GetStepId(index),
Timeout: w.Timeout,
ActionID: w.GetActionId(svcName, index),
Parents: []string{},
Retries: w.Retries,
DesiredLabels: w.DesiredLabels,
}
for _, rateLimit := range w.RateLimit {