diff --git a/api-contracts/dispatcher/dispatcher.proto b/api-contracts/dispatcher/dispatcher.proto
index 95f4db1d1..79b854f32 100644
--- a/api-contracts/dispatcher/dispatcher.proto
+++ b/api-contracts/dispatcher/dispatcher.proto
@@ -236,7 +236,13 @@ message ActionEventResponse {
message SubscribeToWorkflowEventsRequest {
// the id of the workflow run
- string workflowRunId = 1;
+ optional string workflowRunId = 1;
+
+ // the key of the additional meta field to subscribe to
+ optional string additionalMetaKey = 2;
+
+ // the value of the additional meta field to subscribe to
+ optional string additionalMetaValue = 3;
}
message SubscribeToWorkflowRunsRequest {
diff --git a/cmd/hatchet-engine/engine/run.go b/cmd/hatchet-engine/engine/run.go
index 0df51dcd2..3e24c5aa0 100644
--- a/cmd/hatchet-engine/engine/run.go
+++ b/cmd/hatchet-engine/engine/run.go
@@ -23,6 +23,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/telemetry"
"github.com/hatchet-dev/hatchet/pkg/config/loader"
"github.com/hatchet-dev/hatchet/pkg/config/server"
+ "github.com/hatchet-dev/hatchet/pkg/repository/cache"
)
type Teardown struct {
@@ -280,6 +281,9 @@ func RunWithConfig(ctx context.Context, sc *server.ServerConfig) ([]Teardown, er
}
if sc.HasService("grpc") {
+
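+ // shared cache used by the dispatcher; stopped in the teardown function below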
+ cacheInstance := cache.New(10 * time.Second)
+
// create the dispatcher
d, err := dispatcher.New(
dispatcher.WithAlerter(sc.Alerter),
@@ -287,6 +291,7 @@ func RunWithConfig(ctx context.Context, sc *server.ServerConfig) ([]Teardown, er
dispatcher.WithRepository(sc.EngineRepository),
dispatcher.WithLogger(sc.Logger),
dispatcher.WithEntitlementsRepository(sc.EntitlementRepository),
+ dispatcher.WithCache(cacheInstance),
)
if err != nil {
@@ -362,6 +367,8 @@ func RunWithConfig(ctx context.Context, sc *server.ServerConfig) ([]Teardown, er
if err != nil {
return fmt.Errorf("failed to cleanup dispatcher: %w", err)
}
+
+ cacheInstance.Stop()
return nil
})
diff --git a/examples/stream-event-by-meta/main.go b/examples/stream-event-by-meta/main.go
new file mode 100644
index 000000000..4e503618e
--- /dev/null
+++ b/examples/stream-event-by-meta/main.go
@@ -0,0 +1,110 @@
+package main
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/joho/godotenv"
+ "golang.org/x/exp/rand"
+
+ "github.com/hatchet-dev/hatchet/pkg/client"
+ "github.com/hatchet-dev/hatchet/pkg/cmdutils"
+ "github.com/hatchet-dev/hatchet/pkg/worker"
+)
+
+type streamEventInput struct {
+ Index int `json:"index"`
+}
+
+type stepOneOutput struct {
+ Message string `json:"message"`
+}
+
+func StepOne(ctx worker.HatchetContext) (result *stepOneOutput, err error) {
+ input := &streamEventInput{}
+
+ err = ctx.WorkflowInput(input)
+
+ if err != nil {
+ return nil, err
+ }
+
+ ctx.StreamEvent([]byte(fmt.Sprintf("This is a stream event %d", input.Index)))
+
+ return &stepOneOutput{
+ Message: fmt.Sprintf("This ran at %s", time.Now().String()),
+ }, nil
+}
+
+func main() {
+ err := godotenv.Load()
+
+ if err != nil {
+ panic(err)
+ }
+
+ c, err := client.New()
+
+ if err != nil {
+ panic(err)
+ }
+
+ w, err := worker.NewWorker(
+ worker.WithClient(
+ c,
+ ),
+ )
+
+ if err != nil {
+ panic(err)
+ }
+
+ err = w.On(
+ worker.NoTrigger(),
+ &worker.WorkflowJob{
+ Name: "stream-event-workflow",
+ Description: "This sends a stream event.",
+ Steps: []*worker.WorkflowStep{
+ worker.Fn(StepOne).SetName("step-one"),
+ },
+ },
+ )
+
+ if err != nil {
+ panic(err)
+ }
+
+ interruptCtx, cancel := cmdutils.InterruptContextFromChan(cmdutils.InterruptChan())
+ defer cancel()
+
+ _, err = w.Start()
+
+ if err != nil {
+ panic(fmt.Errorf("error cleaning up: %w", err))
+ }
+
+ // Generate a random stream key and value to identify all stream events for this workflow run
+ streamKey := "streamKey"
+ streamValue := fmt.Sprintf("stream-event-%d", rand.Intn(100)+1)
+
+ _, err = c.Admin().RunWorkflow("stream-event-workflow", &streamEventInput{
+ Index: 0,
+ },
+ client.WithRunMetadata(map[string]interface{}{
+ streamKey: streamValue,
+ }),
+ )
+
+ if err != nil {
+ panic(err)
+ }
+
+ err = c.Subscribe().StreamByAdditionalMetadata(interruptCtx, streamKey, streamValue, func(event client.StreamEvent) error {
+ fmt.Println(string(event.Message))
+ return nil
+ })
+
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/frontend/docs/pages/home/features/streaming.mdx b/frontend/docs/pages/home/features/streaming.mdx
index 61e1da646..cfb06acbd 100644
--- a/frontend/docs/pages/home/features/streaming.mdx
+++ b/frontend/docs/pages/home/features/streaming.mdx
@@ -27,7 +27,7 @@ Listeners are used to subscribe to the event stream for a specific workflow run.
Here's an example of how to create a listener:
-
+
```python
@@ -59,6 +59,26 @@ async function listen_for_files() {
```
+```go copy
+workflowRunId, err := c.Admin().RunWorkflow("stream-event-workflow", &streamEventInput{
+	Index: 0,
+})
+
+if err != nil {
+	panic(err)
+}
+
+err = c.Subscribe().Stream(interruptCtx, workflowRunId, func(event client.StreamEvent) error {
+	fmt.Println(string(event.Message))
+
+	return nil
+})
+
+if err != nil {
+	panic(err)
+}
+
+````
## Streaming from a Step Context
@@ -75,7 +95,7 @@ def step1(self, context: Context):
context.put_stream('hello from step1')
# continue with the step run...
return {"step1": "results"}
-```
+````
@@ -225,6 +245,111 @@ for await (const event of listener) {
+
+## Streaming by Additional Metadata
+
+Often it is helpful to stream from multiple workflows (i.e. child workflows spawned from a parent). To achieve this, you can specify an [additional metadata](/features/additional-metadata) key-value pair before running a workflow, which can then be used to subscribe to all events from workflows that share the same key-value pair.
+
+Since additional metadata is propagated from parent to child workflows, this can be used to track all events from a specific workflow run.
+
+Here's an example of how to create a listener:
+
+
+
+
+```python
+# Generate a random stream key to use to track all
+# stream events for this workflow run.
+streamKey = "streamKey"
+streamVal = f"sk-{random.randint(1, 100)}"
+
+# Specify the stream key as additional metadata
+# when running the workflow.
+
+# This key gets propagated to all child workflows
+# and can have an arbitrary property name.
+
+workflowRun = hatchet.admin.run_workflow(
+ "Parent",
+ {"n": 2},
+ options={"additional_metadata": {streamKey: streamVal}},
+)
+
+# Stream all events for the additional meta key value
+listener = hatchet.listener.stream_by_additional_metadata(streamKey, streamVal)
+
+async for event in listener:
+ print(event.type, event.payload)
+
+```
+
+
+
+
+```typescript copy
+// Generate a random stream key to use to track all
+// stream events for this workflow run.
+const streamKey = "streamKey";
+const streamVal = `sk-${Math.random().toString(36).substring(7)}`;
+
+// Specify the stream key as additional metadata
+// when running the workflow.
+
+// This key gets propagated to all child workflows
+// and can have an arbitrary property name.
+await hatchet.admin.runWorkflow(
+ "parent-workflow",
+ {},
+ { additionalMetadata: { [streamKey]: streamVal } },
+);
+
+// Stream all events for the additional meta key value
+const stream = await hatchet.listener.streamByAdditionalMeta(
+ streamKey,
+ streamVal,
+);
+
+for await (const event of stream) {
+ console.log("event received", event);
+}
+```
+
+
+
+
+```go copy
+// Generate a random stream key to use to track all
+// stream events for this workflow run.
+streamKey := "streamKey"
+streamValue := fmt.Sprintf("stream-event-%d", rand.Intn(100)+1)
+
+// Specify the stream key as additional metadata
+// when running the workflow.
+
+// This key gets propagated to all child workflows
+// and can have an arbitrary property name.
+_, err = c.Admin().RunWorkflow("stream-event-workflow", &streamEventInput{
+ Index: 0,
+ },
+ client.WithRunMetadata(map[string]interface{}{
+ streamKey: streamValue,
+ }),
+)
+
+if err != nil {
+ panic(err)
+}
+
+// Stream all events for the additional meta key value
+err = c.Subscribe().StreamByAdditionalMetadata(interruptCtx, streamKey, streamValue, func(event client.StreamEvent) error {
+ fmt.Println(string(event.Message))
+ return nil
+})
+
+if err != nil {
+	panic(err)
+}
+```
+
+
+
+
## Consuming Streams on Frontend
To consume a stream from the backend, create a Streaming Response endpoint to "proxy" the stream from the Hatchet workflow run.
diff --git a/internal/services/admin/server.go b/internal/services/admin/server.go
index da7f9405c..a775a6695 100644
--- a/internal/services/admin/server.go
+++ b/internal/services/admin/server.go
@@ -137,6 +137,10 @@ func (a *AdminServiceImpl) TriggerWorkflow(ctx context.Context, req *contracts.T
additionalMetadata,
parentAdditionalMeta,
)
+
+ if err != nil {
+ return nil, fmt.Errorf("could not create workflow run opts: %w", err)
+ }
} else {
createOpts, err = repository.GetCreateWorkflowRunOptsFromManual(workflowVersion, []byte(req.Input), additionalMetadata)
}
diff --git a/internal/services/controllers/jobs/controller.go b/internal/services/controllers/jobs/controller.go
index 120f6fe29..a26872715 100644
--- a/internal/services/controllers/jobs/controller.go
+++ b/internal/services/controllers/jobs/controller.go
@@ -1214,7 +1214,9 @@ func (ec *JobsControllerImpl) failStepRun(ctx context.Context, tenantId, stepRun
err = ec.mq.AddMessage(
ctx,
msgqueue.QueueTypeFromDispatcherID(dispatcherId),
- stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, *repository.StringPtr(eventMessage)),
+ stepRunCancelledTask(
+ tenantId, stepRunId, workerId, dispatcherId, *repository.StringPtr(eventMessage),
+ sqlchelpers.UUIDToStr(stepRun.WorkflowRunId), &stepRun.StepRetries, &stepRun.SRRetryCount),
)
if err != nil {
@@ -1326,7 +1328,9 @@ func (ec *JobsControllerImpl) cancelStepRun(ctx context.Context, tenantId, stepR
err = ec.mq.AddMessage(
ctx,
msgqueue.QueueTypeFromDispatcherID(dispatcherId),
- stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, reason),
+ stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, reason,
+ sqlchelpers.UUIDToStr(stepRun.WorkflowRunId), &stepRun.StepRetries, &stepRun.SRRetryCount,
+ ),
)
if err != nil {
@@ -1379,11 +1383,14 @@ func stepRunAssignedTask(tenantId, stepRunId, workerId, dispatcherId string) *ms
}
}
-func stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, cancelledReason string) *msgqueue.Message {
+func stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, cancelledReason string, runId string, retries *int32, retryCount *int32) *msgqueue.Message {
payload, _ := datautils.ToJSONMap(tasktypes.StepRunCancelledTaskPayload{
+ WorkflowRunId: runId,
StepRunId: stepRunId,
WorkerId: workerId,
CancelledReason: cancelledReason,
+ StepRetries: retries,
+ RetryCount: retryCount,
})
metadata, _ := datautils.ToJSONMap(tasktypes.StepRunCancelledTaskMetadata{
@@ -1391,6 +1398,7 @@ func stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, cancelled
DispatcherId: dispatcherId,
})
+ // TODO add additional metadata
return &msgqueue.Message{
ID: "step-run-cancelled",
Payload: payload,
diff --git a/internal/services/dispatcher/contracts/dispatcher.pb.go b/internal/services/dispatcher/contracts/dispatcher.pb.go
index 47346cb1b..b2bfd398a 100644
--- a/internal/services/dispatcher/contracts/dispatcher.pb.go
+++ b/internal/services/dispatcher/contracts/dispatcher.pb.go
@@ -1274,7 +1274,11 @@ type SubscribeToWorkflowEventsRequest struct {
unknownFields protoimpl.UnknownFields
// the id of the workflow run
- WorkflowRunId string `protobuf:"bytes,1,opt,name=workflowRunId,proto3" json:"workflowRunId,omitempty"`
+ WorkflowRunId *string `protobuf:"bytes,1,opt,name=workflowRunId,proto3,oneof" json:"workflowRunId,omitempty"`
+ // the key of the additional meta field to subscribe to
+ AdditionalMetaKey *string `protobuf:"bytes,2,opt,name=additionalMetaKey,proto3,oneof" json:"additionalMetaKey,omitempty"`
+ // the value of the additional meta field to subscribe to
+ AdditionalMetaValue *string `protobuf:"bytes,3,opt,name=additionalMetaValue,proto3,oneof" json:"additionalMetaValue,omitempty"`
}
func (x *SubscribeToWorkflowEventsRequest) Reset() {
@@ -1310,8 +1314,22 @@ func (*SubscribeToWorkflowEventsRequest) Descriptor() ([]byte, []int) {
}
func (x *SubscribeToWorkflowEventsRequest) GetWorkflowRunId() string {
- if x != nil {
- return x.WorkflowRunId
+ if x != nil && x.WorkflowRunId != nil {
+ return *x.WorkflowRunId
+ }
+ return ""
+}
+
+func (x *SubscribeToWorkflowEventsRequest) GetAdditionalMetaKey() string {
+ if x != nil && x.AdditionalMetaKey != nil {
+ return *x.AdditionalMetaKey
+ }
+ return ""
+}
+
+func (x *SubscribeToWorkflowEventsRequest) GetAdditionalMetaValue() string {
+ if x != nil && x.AdditionalMetaValue != nil {
+ return *x.AdditionalMetaValue
}
return ""
}
@@ -2189,12 +2207,23 @@ var file_dispatcher_proto_rawDesc = []byte{
0x12, 0x1a, 0x0a, 0x08, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x08, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1a, 0x0a, 0x08,
0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08,
- 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x49, 0x64, 0x22, 0x48, 0x0a, 0x20, 0x53, 0x75, 0x62, 0x73,
- 0x63, 0x72, 0x69, 0x62, 0x65, 0x54, 0x6f, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x45,
- 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0d,
- 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e, 0x49, 0x64, 0x18, 0x01, 0x20,
- 0x01, 0x28, 0x09, 0x52, 0x0d, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e,
- 0x49, 0x64, 0x22, 0x46, 0x0a, 0x1e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x54,
+ 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x49, 0x64, 0x22, 0xf7, 0x01, 0x0a, 0x20, 0x53, 0x75, 0x62,
+ 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x54, 0x6f, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77,
+ 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x29, 0x0a,
+ 0x0d, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e, 0x49, 0x64, 0x18, 0x01,
+ 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0d, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77,
+ 0x52, 0x75, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x31, 0x0a, 0x11, 0x61, 0x64, 0x64, 0x69,
+ 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20,
+ 0x01, 0x28, 0x09, 0x48, 0x01, 0x52, 0x11, 0x61, 0x64, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x61,
+ 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x4b, 0x65, 0x79, 0x88, 0x01, 0x01, 0x12, 0x35, 0x0a, 0x13, 0x61,
+ 0x64, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x56, 0x61, 0x6c,
+ 0x75, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x02, 0x52, 0x13, 0x61, 0x64, 0x64, 0x69,
+ 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x88,
+ 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52,
+ 0x75, 0x6e, 0x49, 0x64, 0x42, 0x14, 0x0a, 0x12, 0x5f, 0x61, 0x64, 0x64, 0x69, 0x74, 0x69, 0x6f,
+ 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x4b, 0x65, 0x79, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x61,
+ 0x64, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x56, 0x61, 0x6c,
+ 0x75, 0x65, 0x22, 0x46, 0x0a, 0x1e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x54,
0x6f, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e, 0x73, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77,
0x52, 0x75, 0x6e, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x77, 0x6f, 0x72,
@@ -2810,6 +2839,7 @@ func file_dispatcher_proto_init() {
file_dispatcher_proto_msgTypes[0].OneofWrappers = []interface{}{}
file_dispatcher_proto_msgTypes[1].OneofWrappers = []interface{}{}
file_dispatcher_proto_msgTypes[5].OneofWrappers = []interface{}{}
+ file_dispatcher_proto_msgTypes[12].OneofWrappers = []interface{}{}
file_dispatcher_proto_msgTypes[14].OneofWrappers = []interface{}{}
file_dispatcher_proto_msgTypes[16].OneofWrappers = []interface{}{}
type x struct{}
diff --git a/internal/services/dispatcher/dispatcher.go b/internal/services/dispatcher/dispatcher.go
index aa39a52df..e31db004e 100644
--- a/internal/services/dispatcher/dispatcher.go
+++ b/internal/services/dispatcher/dispatcher.go
@@ -21,6 +21,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/telemetry/servertel"
"github.com/hatchet-dev/hatchet/pkg/logger"
"github.com/hatchet-dev/hatchet/pkg/repository"
+ "github.com/hatchet-dev/hatchet/pkg/repository/cache"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/pkg/validator"
@@ -36,12 +37,13 @@ type Dispatcher interface {
type DispatcherImpl struct {
contracts.UnimplementedDispatcherServer
- s gocron.Scheduler
- mq msgqueue.MessageQueue
- l *zerolog.Logger
- dv datautils.DataDecoderValidator
- v validator.Validator
- repo repository.EngineRepository
+ s gocron.Scheduler
+ mq msgqueue.MessageQueue
+ l *zerolog.Logger
+ dv datautils.DataDecoderValidator
+ v validator.Validator
+ repo repository.EngineRepository
+ cache cache.Cacheable
entitlements repository.EntitlementsRepository
@@ -121,6 +123,7 @@ type DispatcherOpts struct {
entitlements repository.EntitlementsRepository
dispatcherId string
alerter hatcheterrors.Alerter
+ cache cache.Cacheable
}
func defaultDispatcherOpts() *DispatcherOpts {
@@ -177,6 +180,12 @@ func WithDispatcherId(dispatcherId string) DispatcherOpt {
}
}
+func WithCache(cache cache.Cacheable) DispatcherOpt {
+ return func(opts *DispatcherOpts) {
+ opts.cache = cache
+ }
+}
+
func New(fs ...DispatcherOpt) (*DispatcherImpl, error) {
opts := defaultDispatcherOpts()
@@ -196,6 +205,10 @@ func New(fs ...DispatcherOpt) (*DispatcherImpl, error) {
return nil, fmt.Errorf("entitlements repository is required. use WithEntitlementsRepository")
}
+ if opts.cache == nil {
+ return nil, fmt.Errorf("cache is required. use WithCache")
+ }
+
newLogger := opts.l.With().Str("service", "dispatcher").Logger()
opts.l = &newLogger
@@ -220,6 +233,7 @@ func New(fs ...DispatcherOpt) (*DispatcherImpl, error) {
workers: &workers{},
s: s,
a: a,
+ cache: opts.cache,
}, nil
}
diff --git a/internal/services/dispatcher/server.go b/internal/services/dispatcher/server.go
index 2e0958422..2c594f867 100644
--- a/internal/services/dispatcher/server.go
+++ b/internal/services/dispatcher/server.go
@@ -2,6 +2,8 @@ package dispatcher
import (
"context"
+ "crypto/sha256"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -23,6 +25,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes"
"github.com/hatchet-dev/hatchet/internal/telemetry"
"github.com/hatchet-dev/hatchet/pkg/repository"
+ "github.com/hatchet-dev/hatchet/pkg/repository/cache"
"github.com/hatchet-dev/hatchet/pkg/repository/metered"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers"
@@ -484,12 +487,118 @@ func (s *DispatcherImpl) ReleaseSlot(ctx context.Context, req *contracts.Release
return &contracts.ReleaseSlotResponse{}, nil
}
-// SubscribeToWorkflowEvents registers workflow events with the dispatcher
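+// SubscribeToWorkflowEvents subscribes a client to workflow events, either for a single workflow run id or for all runs matching an additional metadata key-value pair.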
func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeToWorkflowEventsRequest, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error {
+
+ fmt.Println("SubscribeToWorkflowEvents")
+ fmt.Println(request)
+
+ if request.WorkflowRunId != nil {
+ return s.subscribeToWorkflowEventsByWorkflowRunId(*request.WorkflowRunId, stream)
+ } else if request.AdditionalMetaKey != nil && request.AdditionalMetaValue != nil {
+ return s.subscribeToWorkflowEventsByAdditionalMeta(*request.AdditionalMetaKey, *request.AdditionalMetaValue, stream)
+ }
+
+ return status.Errorf(codes.InvalidArgument, "either workflow run id or additional meta key-value must be provided")
+}
+
+// subscribeToWorkflowEventsByAdditionalMeta streams workflow events for all workflow runs whose additional metadata contains the given key-value pair.
+func (s *DispatcherImpl) subscribeToWorkflowEventsByAdditionalMeta(key string, value string, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error {
tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)
tenantId := sqlchelpers.UUIDToStr(tenant.ID)
- s.l.Debug().Msgf("Received subscribe request for workflow: %s", request.WorkflowRunId)
+ s.l.Error().Msgf("Received subscribe request for additional meta key-value: {%s: %s}", key, value)
+
+ q, err := msgqueue.TenantEventConsumerQueue(tenantId)
+ if err != nil {
+ return err
+ }
+
+ ctx, cancel := context.WithCancel(stream.Context())
+ defer cancel()
+
+ wg := sync.WaitGroup{}
+
+ // Keep track of active workflow run IDs
+ activeRunIds := make(map[string]struct{})
+ var mu sync.Mutex // Mutex to protect activeRunIds
+
+ f := func(task *msgqueue.Message) error {
+ wg.Add(1)
+ defer wg.Done()
+
+ e, err := s.tenantTaskToWorkflowEventByAdditionalMeta(
+ task, tenantId, key, value,
+ func(e *contracts.WorkflowEvent) (bool, error) {
+ mu.Lock()
+ defer mu.Unlock()
+
+ if e.WorkflowRunId == "" {
+ return false, nil
+ }
+
+ if e.ResourceType != contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN &&
+ e.EventType != contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED {
+ // Add the run ID to active runs
+ activeRunIds[e.WorkflowRunId] = struct{}{}
+ } else {
+ // Remove the completed run from active runs
+ delete(activeRunIds, e.WorkflowRunId)
+ }
+
+ // Only return true to hang up if we've seen at least one run and all runs are completed
+ if len(activeRunIds) == 0 {
+ return true, nil
+ }
+
+ return false, nil
+ })
+
+ if err != nil {
+ s.l.Error().Err(err).Msgf("could not convert task to workflow event")
+ return nil
+ } else if e == nil {
+ return nil
+ }
+
+ // send the task to the client
+ err = stream.Send(e)
+
+ if err != nil {
+ cancel() // FIXME is this necessary?
+ s.l.Error().Err(err).Msgf("could not send workflow event to client")
+ return nil
+ }
+
+ if e.Hangup {
+ cancel()
+ }
+
+ return nil
+ }
+
+ // subscribe to the task queue for the tenant
+ cleanupQueue, err := s.mq.Subscribe(q, msgqueue.NoOpHook, f)
+
+ if err != nil {
+ return err
+ }
+
+ <-ctx.Done()
+ if err := cleanupQueue(); err != nil {
+ return fmt.Errorf("could not cleanup queue: %w", err)
+ }
+
+ waitFor(&wg, 60*time.Second, s.l)
+
+ return nil
+}
+
+// subscribeToWorkflowEventsByWorkflowRunId streams workflow events for a single workflow run.
+func (s *DispatcherImpl) subscribeToWorkflowEventsByWorkflowRunId(workflowRunId string, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error {
+ tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)
+ tenantId := sqlchelpers.UUIDToStr(tenant.ID)
+
+ s.l.Debug().Msgf("Received subscribe request for workflow: %s", workflowRunId)
q, err := msgqueue.TenantEventConsumerQueue(tenantId)
@@ -501,11 +610,11 @@ func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeT
defer cancel()
// if the workflow run is in a final state, hang up the connection
- workflowRun, err := s.repo.WorkflowRun().GetWorkflowRunById(ctx, tenantId, request.WorkflowRunId)
+ workflowRun, err := s.repo.WorkflowRun().GetWorkflowRunById(ctx, tenantId, workflowRunId)
if err != nil {
if errors.Is(err, repository.ErrWorkflowRunNotFound) {
- return status.Errorf(codes.NotFound, "workflow run %s not found", request.WorkflowRunId)
+ return status.Errorf(codes.NotFound, "workflow run %s not found", workflowRunId)
}
return err
@@ -521,7 +630,7 @@ func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeT
wg.Add(1)
defer wg.Done()
- e, err := s.tenantTaskToWorkflowEvent(task, tenantId, request.WorkflowRunId)
+ e, err := s.tenantTaskToWorkflowEventByRunId(task, tenantId, workflowRunId)
if err != nil {
s.l.Error().Err(err).Msgf("could not convert task to workflow event")
@@ -901,9 +1010,18 @@ func (s *DispatcherImpl) handleStepRunStarted(ctx context.Context, request *cont
startedAt := request.EventTimestamp.AsTime()
+ sr, err := s.repo.StepRun().GetStepRunForEngine(ctx, tenantId, request.StepRunId)
+
+ if err != nil {
+ return nil, err
+ }
+
payload, _ := datautils.ToJSONMap(tasktypes.StepRunStartedTaskPayload{
- StepRunId: request.StepRunId,
- StartedAt: startedAt.Format(time.RFC3339),
+ StepRunId: request.StepRunId,
+ StartedAt: startedAt.Format(time.RFC3339),
+ WorkflowRunId: sqlchelpers.UUIDToStr(sr.WorkflowRunId),
+ StepRetries: &sr.StepRetries,
+ RetryCount: &sr.SRRetryCount,
})
metadata, _ := datautils.ToJSONMap(tasktypes.StepRunStartedTaskMetadata{
@@ -911,7 +1029,7 @@ func (s *DispatcherImpl) handleStepRunStarted(ctx context.Context, request *cont
})
// send the event to the jobs queue
- err := s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
+ err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
ID: "step-run-started",
Payload: payload,
Metadata: metadata,
@@ -955,10 +1073,19 @@ func (s *DispatcherImpl) handleStepRunCompleted(ctx context.Context, request *co
finishedAt := request.EventTimestamp.AsTime()
+ meta, err := s.repo.StepRun().GetStepRunMetaForEngine(ctx, tenantId, request.StepRunId)
+
+ if err != nil {
+ return nil, err
+ }
+
payload, _ := datautils.ToJSONMap(tasktypes.StepRunFinishedTaskPayload{
+ WorkflowRunId: sqlchelpers.UUIDToStr(meta.WorkflowRunId),
StepRunId: request.StepRunId,
FinishedAt: finishedAt.Format(time.RFC3339),
StepOutputData: request.EventPayload,
+ StepRetries: &meta.Retries,
+ RetryCount: &meta.RetryCount,
})
metadata, _ := datautils.ToJSONMap(tasktypes.StepRunFinishedTaskMetadata{
@@ -966,7 +1093,7 @@ func (s *DispatcherImpl) handleStepRunCompleted(ctx context.Context, request *co
})
// send the event to the jobs queue
- err := s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
+ err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
ID: "step-run-finished",
Payload: payload,
Metadata: metadata,
@@ -991,10 +1118,19 @@ func (s *DispatcherImpl) handleStepRunFailed(ctx context.Context, request *contr
failedAt := request.EventTimestamp.AsTime()
+ meta, err := s.repo.StepRun().GetStepRunMetaForEngine(ctx, tenantId, request.StepRunId)
+
+ if err != nil {
+ return nil, err
+ }
+
payload, _ := datautils.ToJSONMap(tasktypes.StepRunFailedTaskPayload{
- StepRunId: request.StepRunId,
- FailedAt: failedAt.Format(time.RFC3339),
- Error: request.EventPayload,
+ WorkflowRunId: sqlchelpers.UUIDToStr(meta.WorkflowRunId),
+ StepRunId: request.StepRunId,
+ FailedAt: failedAt.Format(time.RFC3339),
+ Error: request.EventPayload,
+ StepRetries: &meta.Retries,
+ RetryCount: &meta.RetryCount,
})
metadata, _ := datautils.ToJSONMap(tasktypes.StepRunFailedTaskMetadata{
@@ -1002,7 +1138,7 @@ func (s *DispatcherImpl) handleStepRunFailed(ctx context.Context, request *contr
})
// send the event to the jobs queue
- err := s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
+ err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
ID: "step-run-failed",
Payload: payload,
Metadata: metadata,
@@ -1126,68 +1262,143 @@ func (s *DispatcherImpl) handleGetGroupKeyRunFailed(ctx context.Context, request
}, nil
}
-func (s *DispatcherImpl) tenantTaskToWorkflowEvent(task *msgqueue.Message, tenantId string, workflowRunIds ...string) (*contracts.WorkflowEvent, error) {
+func UnmarshalPayload[T any](payload interface{}) (T, error) {
+ var result T
+
+ // Convert the payload to JSON
+ jsonData, err := json.Marshal(payload)
+ if err != nil {
+ return result, fmt.Errorf("failed to marshal payload: %w", err)
+ }
+
+ // Unmarshal JSON into the desired type
+ err = json.Unmarshal(jsonData, &result)
+ if err != nil {
+ return result, fmt.Errorf("failed to unmarshal payload: %w", err)
+ }
+
+ return result, nil
+}
+
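+// taskToWorkflowEvent converts a queue task into a WorkflowEvent. The filter callback decides whether the event should be forwarded to the subscriber, and hangupFunc decides whether the stream should hang up after this event.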
+func (s *DispatcherImpl) taskToWorkflowEvent(task *msgqueue.Message, tenantId string, filter func(task *contracts.WorkflowEvent) (*bool, error), hangupFunc func(task *contracts.WorkflowEvent) (bool, error)) (*contracts.WorkflowEvent, error) {
workflowEvent := &contracts.WorkflowEvent{}
var stepRunId string
switch task.ID {
case "step-run-started":
- stepRunId = task.Payload["step_run_id"].(string)
- workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
- workflowEvent.ResourceId = stepRunId
- workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STARTED
- case "step-run-finished":
- stepRunId = task.Payload["step_run_id"].(string)
- workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
- workflowEvent.ResourceId = stepRunId
- workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED
- workflowEvent.EventPayload = task.Payload["step_output_data"].(string)
- case "step-run-failed":
- stepRunId = task.Payload["step_run_id"].(string)
- workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
- workflowEvent.ResourceId = stepRunId
- workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_FAILED
- workflowEvent.EventPayload = task.Payload["error"].(string)
- case "step-run-cancelled":
- stepRunId = task.Payload["step_run_id"].(string)
- workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
- workflowEvent.ResourceId = stepRunId
- workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_CANCELLED
- case "step-run-timed-out":
- stepRunId = task.Payload["step_run_id"].(string)
- workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
- workflowEvent.ResourceId = stepRunId
- workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_TIMED_OUT
- case "step-run-stream-event":
- stepRunId = task.Payload["step_run_id"].(string)
- workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
- workflowEvent.ResourceId = stepRunId
- workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM
- case "workflow-run-finished":
- workflowRunId := task.Payload["workflow_run_id"].(string)
- workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN
- workflowEvent.ResourceId = workflowRunId
- workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED
- workflowEvent.Hangup = true
- }
-
- if workflowEvent.ResourceType == contracts.ResourceType_RESOURCE_TYPE_STEP_RUN {
- // determine if this step run matches the workflow run id
- stepRun, err := s.repo.StepRun().GetStepRunForEngine(context.Background(), tenantId, stepRunId)
-
+ payload, err := UnmarshalPayload[tasktypes.StepRunStartedTaskPayload](task.Payload)
if err != nil {
return nil, err
}
-
- if !contains(workflowRunIds, sqlchelpers.UUIDToStr(stepRun.WorkflowRunId)) {
- // this is an expected error, so we don't return it
- return nil, nil
+ workflowEvent.WorkflowRunId = payload.WorkflowRunId
+ stepRunId = payload.StepRunId
+ workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
+ workflowEvent.ResourceId = stepRunId
+ workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STARTED
+ workflowEvent.StepRetries = payload.StepRetries
+ workflowEvent.RetryCount = payload.RetryCount
+ case "step-run-finished":
+ payload, err := UnmarshalPayload[tasktypes.StepRunFinishedTaskPayload](task.Payload)
+ if err != nil {
+ return nil, err
}
+ workflowEvent.WorkflowRunId = payload.WorkflowRunId
+ stepRunId = payload.StepRunId
+ workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
+ workflowEvent.ResourceId = stepRunId
+ workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED
+ workflowEvent.EventPayload = payload.StepOutputData
- workflowEvent.StepRetries = &stepRun.StepRetries
- workflowEvent.RetryCount = &stepRun.SRRetryCount
+ workflowEvent.StepRetries = payload.StepRetries
+ workflowEvent.RetryCount = payload.RetryCount
+ case "step-run-failed":
+ payload, err := UnmarshalPayload[tasktypes.StepRunFailedTaskPayload](task.Payload)
+ if err != nil {
+ return nil, err
+ }
+ workflowEvent.WorkflowRunId = payload.WorkflowRunId
+ stepRunId = payload.StepRunId
+ workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
+ workflowEvent.ResourceId = stepRunId
+ workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_FAILED
+ workflowEvent.EventPayload = payload.Error
+ workflowEvent.StepRetries = payload.StepRetries
+ workflowEvent.RetryCount = payload.RetryCount
+ case "step-run-cancelled":
+ payload, err := UnmarshalPayload[tasktypes.StepRunCancelledTaskPayload](task.Payload)
+ if err != nil {
+ return nil, err
+ }
+ workflowEvent.WorkflowRunId = payload.WorkflowRunId
+ stepRunId = payload.StepRunId
+ workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
+ workflowEvent.ResourceId = stepRunId
+ workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_CANCELLED
+
+ workflowEvent.StepRetries = payload.StepRetries
+ workflowEvent.RetryCount = payload.RetryCount
+ case "step-run-timed-out":
+ payload, err := UnmarshalPayload[tasktypes.StepRunTimedOutTaskPayload](task.Payload)
+ if err != nil {
+ return nil, err
+ }
+ workflowEvent.WorkflowRunId = payload.WorkflowRunId
+ stepRunId = payload.StepRunId
+ workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
+ workflowEvent.ResourceId = stepRunId
+ workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_TIMED_OUT
+
+ workflowEvent.StepRetries = payload.StepRetries
+ workflowEvent.RetryCount = payload.RetryCount
+ case "step-run-stream-event":
+ payload, err := UnmarshalPayload[tasktypes.StepRunStreamEventTaskPayload](task.Payload)
+ if err != nil {
+ return nil, err
+ }
+ workflowEvent.WorkflowRunId = payload.WorkflowRunId
+ stepRunId = payload.StepRunId
+ workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
+ workflowEvent.ResourceId = stepRunId
+ workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM
+
+ workflowEvent.StepRetries = payload.StepRetries
+ workflowEvent.RetryCount = payload.RetryCount
+ case "workflow-run-finished":
+ payload, err := UnmarshalPayload[tasktypes.WorkflowRunFinishedTask](task.Payload)
+ if err != nil {
+ return nil, err
+ }
+ workflowRunId := payload.WorkflowRunId
+ workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN
+ workflowEvent.ResourceId = workflowRunId
+ workflowEvent.WorkflowRunId = workflowRunId
+ workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED
+ }
+
+ match, err := filter(workflowEvent)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if match != nil && !*match {
+ // if not a match, we don't return it
+ return nil, nil
+ }
+
+ hangup, err := hangupFunc(workflowEvent)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if hangup {
+ workflowEvent.Hangup = true
+ return workflowEvent, nil
+ }
+
+ if workflowEvent.ResourceType == contracts.ResourceType_RESOURCE_TYPE_STEP_RUN {
if workflowEvent.EventType == contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM {
streamEventId, err := strconv.ParseInt(task.Metadata["stream_event_id"].(string), 10, 64)
if err != nil {
@@ -1202,13 +1413,96 @@ func (s *DispatcherImpl) tenantTaskToWorkflowEvent(task *msgqueue.Message, tenan
workflowEvent.EventPayload = string(streamEvent.Message)
}
+ }
- } else if workflowEvent.ResourceType == contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN {
- if !contains(workflowRunIds, workflowEvent.ResourceId) {
- return nil, nil
- }
+ return workflowEvent, nil
+}
- workflowEvent.Hangup = true
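+// tenantTaskToWorkflowEventByRunId converts a queue task into a WorkflowEvent if it belongs to one of the given workflow run ids, hanging up once the workflow run completes.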
+func (s *DispatcherImpl) tenantTaskToWorkflowEventByRunId(task *msgqueue.Message, tenantId string, workflowRunIds ...string) (*contracts.WorkflowEvent, error) {
+
+ workflowEvent, err := s.taskToWorkflowEvent(task, tenantId,
+ func(e *contracts.WorkflowEvent) (*bool, error) {
+ hit := contains(workflowRunIds, e.WorkflowRunId)
+ return &hit, nil
+ },
+ func(e *contracts.WorkflowEvent) (bool, error) {
+ // hangup on complete
+ return e.ResourceType == contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN &&
+ e.EventType == contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED, nil
+ },
+ )
+
+ if err != nil {
+ return nil, err
+ }
+
+ return workflowEvent, nil
+}
+
+func tinyHash(key, value string) string {
+ // Combine key and value
+ combined := fmt.Sprintf("%s:%s", key, value)
+
+ // Create SHA-256 hash
+ hash := sha256.Sum256([]byte(combined))
+
+ // Take first 8 bytes of the hash
+ shortHash := hash[:8]
+
+ // Encode to base64
+ encoded := base64.URLEncoding.EncodeToString(shortHash)
+
+ // Remove padding
+ return encoded[:11]
+}
+
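+// tenantTaskToWorkflowEventByAdditionalMeta converts a queue task into a WorkflowEvent if the originating workflow run's additional metadata contains the given key-value pair. Lookup results are cached per (workflow run, key, value) so the database is not queried for every event.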
+func (s *DispatcherImpl) tenantTaskToWorkflowEventByAdditionalMeta(task *msgqueue.Message, tenantId string, key string, value string, hangup func(e *contracts.WorkflowEvent) (bool, error)) (*contracts.WorkflowEvent, error) {
+ workflowEvent, err := s.taskToWorkflowEvent(
+ task,
+ tenantId,
+ func(e *contracts.WorkflowEvent) (*bool, error) {
+ return cache.MakeCacheable[bool](
+ s.cache,
+ fmt.Sprintf("wfram-%s-%s-%s", tenantId, e.WorkflowRunId, tinyHash(key, value)),
+ func() (*bool, error) {
+
+ if e.WorkflowRunId == "" {
+ return nil, nil
+ }
+
+ am, err := s.repo.WorkflowRun().GetWorkflowRunAdditionalMeta(context.Background(), tenantId, e.WorkflowRunId)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if am.AdditionalMetadata == nil {
+ f := false
+ return &f, nil
+ }
+
+ var additionalMetaMap map[string]interface{}
+ err = json.Unmarshal(am.AdditionalMetadata, &additionalMetaMap)
+ if err != nil {
+ return nil, err
+ }
+
+ if v, ok := (additionalMetaMap)[key]; ok && v == value {
+ t := true
+ return &t, nil
+ }
+
+ f := false
+ return &f, nil
+
+ },
+ )
+ },
+ hangup,
+ )
+
+ if err != nil {
+ return nil, err
}
return workflowEvent, nil
diff --git a/internal/services/ingestor/server.go b/internal/services/ingestor/server.go
index 81ae50feb..fb78e8958 100644
--- a/internal/services/ingestor/server.go
+++ b/internal/services/ingestor/server.go
@@ -102,6 +102,12 @@ func (i *IngestorImpl) PutStreamEvent(ctx context.Context, req *contracts.PutStr
return nil, status.Errorf(codes.InvalidArgument, "Invalid request: %s", err)
}
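+	// look up the workflow run id and retry counts for this step run so they can be attached to the stream event task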
+ meta, err := i.streamEventRepository.GetStreamEventMeta(ctx, tenantId, req.StepRunId)
+
+ if err != nil {
+ return nil, err
+ }
+
streamEvent, err := i.streamEventRepository.PutStreamEvent(ctx, tenantId, &opts)
if err != nil {
@@ -114,7 +120,9 @@ func (i *IngestorImpl) PutStreamEvent(ctx context.Context, req *contracts.PutStr
return nil, err
}
- err = i.mq.AddMessage(context.Background(), q, streamEventToTask(streamEvent))
+ e := streamEventToTask(streamEvent, sqlchelpers.UUIDToStr(meta.WorkflowRunId), &meta.RetryCount, &meta.Retries)
+
+ err = i.mq.AddMessage(context.Background(), q, e)
if err != nil {
return nil, err
@@ -176,13 +184,16 @@ func toEvent(e *dbsqlc.Event) (*contracts.Event, error) {
}, nil
}
-func streamEventToTask(e *dbsqlc.StreamEvent) *msgqueue.Message {
+func streamEventToTask(e *dbsqlc.StreamEvent, workflowRunId string, retryCount *int32, retries *int32) *msgqueue.Message {
tenantId := sqlchelpers.UUIDToStr(e.TenantId)
payloadTyped := tasktypes.StepRunStreamEventTaskPayload{
+ WorkflowRunId: workflowRunId,
StepRunId: sqlchelpers.UUIDToStr(e.StepRunId),
CreatedAt: e.CreatedAt.Time.String(),
StreamEventId: strconv.FormatInt(e.ID, 10),
+ RetryCount: retryCount,
+ StepRetries: retries,
}
payload, _ := datautils.ToJSONMap(payloadTyped)
diff --git a/internal/services/shared/tasktypes/step.go b/internal/services/shared/tasktypes/step.go
index 3cbeaef00..32660649a 100644
--- a/internal/services/shared/tasktypes/step.go
+++ b/internal/services/shared/tasktypes/step.go
@@ -3,19 +3,20 @@ package tasktypes
import (
"github.com/hatchet-dev/hatchet/internal/datautils"
"github.com/hatchet-dev/hatchet/internal/msgqueue"
- "github.com/hatchet-dev/hatchet/pkg/repository/prisma/db"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers"
)
type StepRunTaskPayload struct {
- StepRunId string `json:"step_run_id" validate:"required,uuid"`
- JobRunId string `json:"job_run_id" validate:"required,uuid"`
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
+ StepRunId string `json:"step_run_id" validate:"required,uuid"`
+ JobRunId string `json:"job_run_id" validate:"required,uuid"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunTaskMetadata struct {
- TenantId string `json:"tenant_id" validate:"required,uuid"`
-
+ TenantId string `json:"tenant_id" validate:"required,uuid"`
StepId string `json:"step_id" validate:"required,uuid"`
ActionId string `json:"action_id" validate:"required,actionId"`
JobId string `json:"job_id" validate:"required,uuid"`
@@ -34,9 +35,12 @@ type StepRunAssignedTaskMetadata struct {
}
type StepRunCancelledTaskPayload struct {
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
StepRunId string `json:"step_run_id" validate:"required,uuid"`
WorkerId string `json:"worker_id" validate:"required,uuid"`
CancelledReason string `json:"cancelled_reason" validate:"required"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunCancelledTaskMetadata struct {
@@ -53,8 +57,11 @@ type StepRunRequeueTaskMetadata struct {
}
type StepRunNotifyCancelTaskPayload struct {
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
StepRunId string `json:"step_run_id" validate:"required,uuid"`
CancelledReason string `json:"cancelled_reason" validate:"required"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunNotifyCancelTaskMetadata struct {
@@ -62,8 +69,11 @@ type StepRunNotifyCancelTaskMetadata struct {
}
type StepRunStartedTaskPayload struct {
- StepRunId string `json:"step_run_id" validate:"required,uuid"`
- StartedAt string `json:"started_at" validate:"required"`
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
+ StepRunId string `json:"step_run_id" validate:"required,uuid"`
+ StartedAt string `json:"started_at" validate:"required"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunStartedTaskMetadata struct {
@@ -71,9 +81,12 @@ type StepRunStartedTaskMetadata struct {
}
type StepRunFinishedTaskPayload struct {
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
StepRunId string `json:"step_run_id" validate:"required,uuid"`
FinishedAt string `json:"finished_at" validate:"required"`
StepOutputData string `json:"step_output_data"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunFinishedTaskMetadata struct {
@@ -81,9 +94,12 @@ type StepRunFinishedTaskMetadata struct {
}
type StepRunStreamEventTaskPayload struct {
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
StepRunId string `json:"step_run_id" validate:"required,uuid"`
CreatedAt string `json:"created_at" validate:"required"`
StreamEventId string `json:"stream_event_id"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunStreamEventTaskMetadata struct {
@@ -92,9 +108,12 @@ type StepRunStreamEventTaskMetadata struct {
}
type StepRunFailedTaskPayload struct {
- StepRunId string `json:"step_run_id" validate:"required,uuid"`
- FailedAt string `json:"failed_at" validate:"required"`
- Error string `json:"error" validate:"required"`
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
+ StepRunId string `json:"step_run_id" validate:"required,uuid"`
+ FailedAt string `json:"failed_at" validate:"required"`
+ Error string `json:"error" validate:"required"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunFailedTaskMetadata struct {
@@ -102,7 +121,10 @@ type StepRunFailedTaskMetadata struct {
}
type StepRunTimedOutTaskPayload struct {
- StepRunId string `json:"step_run_id" validate:"required,uuid"`
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
+ StepRunId string `json:"step_run_id" validate:"required,uuid"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunTimedOutTaskMetadata struct {
@@ -110,11 +132,15 @@ type StepRunTimedOutTaskMetadata struct {
}
type StepRunRetryTaskPayload struct {
- StepRunId string `json:"step_run_id" validate:"required,uuid"`
- JobRunId string `json:"job_run_id" validate:"required,uuid"`
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
+ StepRunId string `json:"step_run_id" validate:"required,uuid"`
+ JobRunId string `json:"job_run_id" validate:"required,uuid"`
// optional - if not provided, the step run will be retried with the same input
InputData string `json:"input_data,omitempty"`
+
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunRetryTaskMetadata struct {
@@ -122,43 +148,33 @@ type StepRunRetryTaskMetadata struct {
}
type StepRunReplayTaskPayload struct {
- StepRunId string `json:"step_run_id" validate:"required,uuid"`
- JobRunId string `json:"job_run_id" validate:"required,uuid"`
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
+ StepRunId string `json:"step_run_id" validate:"required,uuid"`
+ JobRunId string `json:"job_run_id" validate:"required,uuid"`
// optional - if not provided, the step run will be retried with the same input
- InputData string `json:"input_data,omitempty"`
+ InputData string `json:"input_data,omitempty"`
+ StepRetries *int32 `json:"step_retries,omitempty"`
+ RetryCount *int32 `json:"retry_count,omitempty"`
}
type StepRunReplayTaskMetadata struct {
TenantId string `json:"tenant_id" validate:"required,uuid"`
}
-func TenantToStepRunRequeueTask(tenant db.TenantModel) *msgqueue.Message {
- payload, _ := datautils.ToJSONMap(StepRunRequeueTaskPayload{
- TenantId: tenant.ID,
- })
-
- metadata, _ := datautils.ToJSONMap(StepRunRequeueTaskMetadata{
- TenantId: tenant.ID,
- })
-
- return &msgqueue.Message{
- ID: "step-run-requeue-ticker",
- Payload: payload,
- Metadata: metadata,
- Retries: 3,
- }
-}
-
func StepRunRetryToTask(stepRun *dbsqlc.GetStepRunForEngineRow, inputData []byte) *msgqueue.Message {
jobRunId := sqlchelpers.UUIDToStr(stepRun.JobRunId)
stepRunId := sqlchelpers.UUIDToStr(stepRun.SRID)
tenantId := sqlchelpers.UUIDToStr(stepRun.SRTenantId)
+ workflowRunId := sqlchelpers.UUIDToStr(stepRun.WorkflowRunId)
payload, _ := datautils.ToJSONMap(StepRunRetryTaskPayload{
- JobRunId: jobRunId,
- StepRunId: stepRunId,
- InputData: string(inputData),
+ WorkflowRunId: workflowRunId,
+ JobRunId: jobRunId,
+ StepRunId: stepRunId,
+ InputData: string(inputData),
+ StepRetries: &stepRun.StepRetries,
+ RetryCount: &stepRun.SRRetryCount,
})
metadata, _ := datautils.ToJSONMap(StepRunRetryTaskMetadata{
@@ -177,11 +193,15 @@ func StepRunReplayToTask(stepRun *dbsqlc.GetStepRunForEngineRow, inputData []byt
jobRunId := sqlchelpers.UUIDToStr(stepRun.JobRunId)
stepRunId := sqlchelpers.UUIDToStr(stepRun.SRID)
tenantId := sqlchelpers.UUIDToStr(stepRun.SRTenantId)
+ workflowRunId := sqlchelpers.UUIDToStr(stepRun.WorkflowRunId)
payload, _ := datautils.ToJSONMap(StepRunReplayTaskPayload{
- JobRunId: jobRunId,
- StepRunId: stepRunId,
- InputData: string(inputData),
+ WorkflowRunId: workflowRunId,
+ JobRunId: jobRunId,
+ StepRunId: stepRunId,
+ InputData: string(inputData),
+ StepRetries: &stepRun.StepRetries,
+ RetryCount: &stepRun.SRRetryCount,
})
metadata, _ := datautils.ToJSONMap(StepRunReplayTaskMetadata{
@@ -201,8 +221,11 @@ func StepRunCancelToTask(stepRun *dbsqlc.GetStepRunForEngineRow, reason string)
tenantId := sqlchelpers.UUIDToStr(stepRun.SRTenantId)
payload, _ := datautils.ToJSONMap(StepRunNotifyCancelTaskPayload{
+ WorkflowRunId: sqlchelpers.UUIDToStr(stepRun.WorkflowRunId),
StepRunId: stepRunId,
CancelledReason: reason,
+ StepRetries: &stepRun.StepRetries,
+ RetryCount: &stepRun.SRRetryCount,
})
metadata, _ := datautils.ToJSONMap(StepRunNotifyCancelTaskMetadata{
@@ -219,8 +242,11 @@ func StepRunCancelToTask(stepRun *dbsqlc.GetStepRunForEngineRow, reason string)
func StepRunQueuedToTask(stepRun *dbsqlc.GetStepRunForEngineRow) *msgqueue.Message {
payload, _ := datautils.ToJSONMap(StepRunTaskPayload{
- JobRunId: sqlchelpers.UUIDToStr(stepRun.JobRunId),
- StepRunId: sqlchelpers.UUIDToStr(stepRun.SRID),
+ WorkflowRunId: sqlchelpers.UUIDToStr(stepRun.WorkflowRunId),
+ JobRunId: sqlchelpers.UUIDToStr(stepRun.JobRunId),
+ StepRunId: sqlchelpers.UUIDToStr(stepRun.SRID),
+ StepRetries: &stepRun.StepRetries,
+ RetryCount: &stepRun.SRRetryCount,
})
metadata, _ := datautils.ToJSONMap(StepRunTaskMetadata{
diff --git a/internal/services/shared/tasktypes/workflow.go b/internal/services/shared/tasktypes/workflow.go
index b5b6cf523..74480e7b5 100644
--- a/internal/services/shared/tasktypes/workflow.go
+++ b/internal/services/shared/tasktypes/workflow.go
@@ -39,8 +39,9 @@ type WorkflowRunQueuedTaskMetadata struct {
}
type WorkflowRunFinishedTask struct {
- WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
- Status string `json:"status" validate:"required"`
+ WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"`
+ Status string `json:"status" validate:"required"`
+ AdditionalMetadata map[string]interface{} `json:"additional_metadata"`
}
type WorkflowRunFinishedTaskMetadata struct {
diff --git a/internal/services/ticker/step_run_timeout.go b/internal/services/ticker/step_run_timeout.go
deleted file mode 100644
index 5612ae8aa..000000000
--- a/internal/services/ticker/step_run_timeout.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package ticker
-
-import (
- "context"
- "time"
-
- "github.com/hatchet-dev/hatchet/internal/datautils"
- "github.com/hatchet-dev/hatchet/internal/msgqueue"
- "github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes"
- "github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers"
-)
-
-func (t *TickerImpl) runPollStepRuns(ctx context.Context) func() {
- return func() {
- ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
- defer cancel()
-
- t.l.Debug().Msgf("ticker: polling step runs")
-
- stepRuns, err := t.repo.Ticker().PollStepRuns(ctx, t.tickerId)
-
- if err != nil {
- t.l.Err(err).Msg("could not poll step runs")
- return
- }
-
- for _, stepRun := range stepRuns {
- tenantId := sqlchelpers.UUIDToStr(stepRun.TenantId)
- stepRunId := sqlchelpers.UUIDToStr(stepRun.ID)
-
- err := t.mq.AddMessage(
- ctx,
- msgqueue.JOB_PROCESSING_QUEUE,
- taskStepRunTimedOut(tenantId, stepRunId),
- )
-
- if err != nil {
- t.l.Err(err).Msg("could not add step run timeout task")
- }
- }
- }
-}
-
-func taskStepRunTimedOut(tenantId, stepRunId string) *msgqueue.Message {
- payload, _ := datautils.ToJSONMap(tasktypes.StepRunTimedOutTaskPayload{
- StepRunId: stepRunId,
- })
-
- metadata, _ := datautils.ToJSONMap(tasktypes.StepRunTimedOutTaskMetadata{
- TenantId: tenantId,
- })
-
- return &msgqueue.Message{
- ID: "step-run-timed-out",
- Payload: payload,
- Metadata: metadata,
- Retries: 3,
- }
-}
diff --git a/internal/services/ticker/ticker.go b/internal/services/ticker/ticker.go
index 9e14aff33..50de39da0 100644
--- a/internal/services/ticker/ticker.go
+++ b/internal/services/ticker/ticker.go
@@ -181,18 +181,6 @@ func (t *TickerImpl) Start() (func() error, error) {
return nil, fmt.Errorf("could not create update heartbeat job: %w", err)
}
- // _, err = t.s.NewJob(
- // gocron.DurationJob(time.Second*1),
- // gocron.NewTask(
- // t.runPollStepRuns(ctx),
- // ),
- // )
-
- // if err != nil {
- // cancel()
- // return nil, fmt.Errorf("could not create update heartbeat job: %w", err)
- // }
-
_, err = t.s.NewJob(
gocron.DurationJob(time.Second*1),
gocron.NewTask(
diff --git a/pkg/client/listener.go b/pkg/client/listener.go
index bc656dab0..f93c2fe42 100644
--- a/pkg/client/listener.go
+++ b/pkg/client/listener.go
@@ -228,6 +228,8 @@ type SubscribeClient interface {
Stream(ctx context.Context, workflowRunId string, handler StreamHandler) error
+ StreamByAdditionalMetadata(ctx context.Context, key string, value string, handler StreamHandler) error
+
SubscribeToWorkflowRunEvents(ctx context.Context) (*WorkflowRunsListener, error)
}
@@ -256,7 +258,7 @@ func newSubscribe(conn *grpc.ClientConn, opts *sharedClientOpts) SubscribeClient
func (r *subscribeClientImpl) On(ctx context.Context, workflowRunId string, handler RunHandler) error {
stream, err := r.client.SubscribeToWorkflowEvents(r.ctx.newContext(ctx), &dispatchercontracts.SubscribeToWorkflowEventsRequest{
- WorkflowRunId: workflowRunId,
+ WorkflowRunId: &workflowRunId,
})
if err != nil {
@@ -286,7 +288,40 @@ func (r *subscribeClientImpl) On(ctx context.Context, workflowRunId string, hand
func (r *subscribeClientImpl) Stream(ctx context.Context, workflowRunId string, handler StreamHandler) error {
stream, err := r.client.SubscribeToWorkflowEvents(r.ctx.newContext(ctx), &dispatchercontracts.SubscribeToWorkflowEventsRequest{
- WorkflowRunId: workflowRunId,
+ WorkflowRunId: &workflowRunId,
+ })
+
+ if err != nil {
+ return err
+ }
+
+ for {
+ event, err := stream.Recv()
+
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ return nil
+ }
+
+ return err
+ }
+
+ if event.EventType != dispatchercontracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM {
+ continue
+ }
+
+ if err := handler(StreamEvent{
+ Message: []byte(event.EventPayload),
+ }); err != nil {
+ return err
+ }
+ }
+}
+
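+// StreamByAdditionalMetadata subscribes to stream events for all workflow runs whose additional metadata contains the given key-value pair.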
+func (r *subscribeClientImpl) StreamByAdditionalMetadata(ctx context.Context, key string, value string, handler StreamHandler) error {
+ stream, err := r.client.SubscribeToWorkflowEvents(r.ctx.newContext(ctx), &dispatchercontracts.SubscribeToWorkflowEventsRequest{
+ AdditionalMetaKey: &key,
+ AdditionalMetaValue: &value,
})
if err != nil {
diff --git a/pkg/repository/prisma/dbsqlc/step_runs.sql b/pkg/repository/prisma/dbsqlc/step_runs.sql
index 04f8a7e04..3143de18f 100644
--- a/pkg/repository/prisma/dbsqlc/step_runs.sql
+++ b/pkg/repository/prisma/dbsqlc/step_runs.sql
@@ -31,6 +31,17 @@ WHERE
sr."id" = @id::uuid AND
sr."tenantId" = @tenantId::uuid;
+-- name: GetStepRunMeta :one
+SELECT
+ jr."workflowRunId" AS "workflowRunId",
+ sr."retryCount" AS "retryCount",
+ s."retries" as "retries"
+FROM "StepRun" sr
+JOIN "Step" s ON sr."stepId" = s."id"
+JOIN "JobRun" jr ON sr."jobRunId" = jr."id"
+WHERE sr."id" = @stepRunId::uuid
+AND sr."tenantId" = @tenantId::uuid;
+
-- name: GetStepRunForEngine :many
SELECT
DISTINCT ON (sr."id")
diff --git a/pkg/repository/prisma/dbsqlc/step_runs.sql.go b/pkg/repository/prisma/dbsqlc/step_runs.sql.go
index 7abf3ba8f..c2e113e97 100644
--- a/pkg/repository/prisma/dbsqlc/step_runs.sql.go
+++ b/pkg/repository/prisma/dbsqlc/step_runs.sql.go
@@ -1002,6 +1002,36 @@ func (q *Queries) GetStepRunForEngine(ctx context.Context, db DBTX, arg GetStepR
return items, nil
}
+const getStepRunMeta = `-- name: GetStepRunMeta :one
+SELECT
+ jr."workflowRunId" AS "workflowRunId",
+ sr."retryCount" AS "retryCount",
+ s."retries" as "retries"
+FROM "StepRun" sr
+JOIN "Step" s ON sr."stepId" = s."id"
+JOIN "JobRun" jr ON sr."jobRunId" = jr."id"
+WHERE sr."id" = $1::uuid
+AND sr."tenantId" = $2::uuid
+`
+
+type GetStepRunMetaParams struct {
+ Steprunid pgtype.UUID `json:"steprunid"`
+ Tenantid pgtype.UUID `json:"tenantid"`
+}
+
+type GetStepRunMetaRow struct {
+ WorkflowRunId pgtype.UUID `json:"workflowRunId"`
+ RetryCount int32 `json:"retryCount"`
+ Retries int32 `json:"retries"`
+}
+
+func (q *Queries) GetStepRunMeta(ctx context.Context, db DBTX, arg GetStepRunMetaParams) (*GetStepRunMetaRow, error) {
+ row := db.QueryRow(ctx, getStepRunMeta, arg.Steprunid, arg.Tenantid)
+ var i GetStepRunMetaRow
+ err := row.Scan(&i.WorkflowRunId, &i.RetryCount, &i.Retries)
+ return &i, err
+}
+
const listNonFinalChildStepRuns = `-- name: ListNonFinalChildStepRuns :many
WITH RECURSIVE currStepRun AS (
SELECT id, "createdAt", "updatedAt", "deletedAt", "tenantId", "jobRunId", "stepId", "order", "workerId", "tickerId", status, input, output, "requeueAfter", "scheduleTimeoutAt", error, "startedAt", "finishedAt", "timeoutAt", "cancelledAt", "cancelledReason", "cancelledError", "inputSchema", "callerFiles", "gitRepoBranch", "retryCount", "semaphoreReleased"
diff --git a/pkg/repository/prisma/dbsqlc/stream_event.sql b/pkg/repository/prisma/dbsqlc/stream_event.sql
index b61f83ff6..0295922c5 100644
--- a/pkg/repository/prisma/dbsqlc/stream_event.sql
+++ b/pkg/repository/prisma/dbsqlc/stream_event.sql
@@ -1,3 +1,14 @@
+-- name: GetStreamEventMeta :one
+SELECT
+ jr."workflowRunId" AS "workflowRunId",
+ sr."retryCount" AS "retryCount",
+ s."retries" as "retries"
+FROM "StepRun" sr
+JOIN "Step" s ON sr."stepId" = s."id"
+JOIN "JobRun" jr ON sr."jobRunId" = jr."id"
+WHERE sr."id" = @stepRunId::uuid
+AND sr."tenantId" = @tenantId::uuid;
+
-- name: CreateStreamEvent :one
INSERT INTO "StreamEvent" (
"createdAt",
diff --git a/pkg/repository/prisma/dbsqlc/stream_event.sql.go b/pkg/repository/prisma/dbsqlc/stream_event.sql.go
index 9b1a2597c..ec73bcf46 100644
--- a/pkg/repository/prisma/dbsqlc/stream_event.sql.go
+++ b/pkg/repository/prisma/dbsqlc/stream_event.sql.go
@@ -96,3 +96,33 @@ func (q *Queries) GetStreamEvent(ctx context.Context, db DBTX, arg GetStreamEven
)
return &i, err
}
+
+const getStreamEventMeta = `-- name: GetStreamEventMeta :one
+SELECT
+ jr."workflowRunId" AS "workflowRunId",
+ sr."retryCount" AS "retryCount",
+ s."retries" as "retries"
+FROM "StepRun" sr
+JOIN "Step" s ON sr."stepId" = s."id"
+JOIN "JobRun" jr ON sr."jobRunId" = jr."id"
+WHERE sr."id" = $1::uuid
+AND sr."tenantId" = $2::uuid
+`
+
+type GetStreamEventMetaParams struct {
+ Steprunid pgtype.UUID `json:"steprunid"`
+ Tenantid pgtype.UUID `json:"tenantid"`
+}
+
+type GetStreamEventMetaRow struct {
+ WorkflowRunId pgtype.UUID `json:"workflowRunId"`
+ RetryCount int32 `json:"retryCount"`
+ Retries int32 `json:"retries"`
+}
+
+func (q *Queries) GetStreamEventMeta(ctx context.Context, db DBTX, arg GetStreamEventMetaParams) (*GetStreamEventMetaRow, error) {
+ row := db.QueryRow(ctx, getStreamEventMeta, arg.Steprunid, arg.Tenantid)
+ var i GetStreamEventMetaRow
+ err := row.Scan(&i.WorkflowRunId, &i.RetryCount, &i.Retries)
+ return &i, err
+}
diff --git a/pkg/repository/prisma/dbsqlc/tickers.sql b/pkg/repository/prisma/dbsqlc/tickers.sql
index 8cf1fdf41..220da17a2 100644
--- a/pkg/repository/prisma/dbsqlc/tickers.sql
+++ b/pkg/repository/prisma/dbsqlc/tickers.sql
@@ -67,41 +67,6 @@ WHERE
"id" = sqlc.arg('id')::uuid
RETURNING *;
--- name: PollStepRuns :many
-WITH inactiveTickers AS (
- SELECT "id"
- FROM "Ticker"
- WHERE
- "isActive" = false OR
- "lastHeartbeatAt" < NOW() - INTERVAL '10 seconds'
-),
-stepRunsToTimeout AS (
- SELECT
- stepRun."id"
- FROM
- "StepRun" as stepRun
- LEFT JOIN inactiveTickers ON stepRun."tickerId" = inactiveTickers."id"
- WHERE
- ("status" = 'RUNNING' OR "status" = 'ASSIGNED')
- AND "deletedAt" IS NULL
- AND "timeoutAt" < NOW()
- AND (
- inactiveTickers."id" IS NOT NULL
- OR "tickerId" IS NULL
- )
- LIMIT 1000
- FOR UPDATE SKIP LOCKED
-)
-UPDATE
- "StepRun" as stepRuns
-SET
- "tickerId" = @tickerId::uuid
-FROM
- stepRunsToTimeout
-WHERE
- stepRuns."id" = stepRunsToTimeout."id"
-RETURNING stepRuns.*;
-
-- name: PollGetGroupKeyRuns :many
WITH getGroupKeyRunsToTimeout AS (
SELECT
diff --git a/pkg/repository/prisma/dbsqlc/tickers.sql.go b/pkg/repository/prisma/dbsqlc/tickers.sql.go
index 41575f1ed..d91b4b9d6 100644
--- a/pkg/repository/prisma/dbsqlc/tickers.sql.go
+++ b/pkg/repository/prisma/dbsqlc/tickers.sql.go
@@ -503,90 +503,6 @@ func (q *Queries) PollScheduledWorkflows(ctx context.Context, db DBTX, tickerid
return items, nil
}
-const pollStepRuns = `-- name: PollStepRuns :many
-WITH inactiveTickers AS (
- SELECT "id"
- FROM "Ticker"
- WHERE
- "isActive" = false OR
- "lastHeartbeatAt" < NOW() - INTERVAL '10 seconds'
-),
-stepRunsToTimeout AS (
- SELECT
- stepRun."id"
- FROM
- "StepRun" as stepRun
- LEFT JOIN inactiveTickers ON stepRun."tickerId" = inactiveTickers."id"
- WHERE
- ("status" = 'RUNNING' OR "status" = 'ASSIGNED')
- AND "deletedAt" IS NULL
- AND "timeoutAt" < NOW()
- AND (
- inactiveTickers."id" IS NOT NULL
- OR "tickerId" IS NULL
- )
- LIMIT 1000
- FOR UPDATE SKIP LOCKED
-)
-UPDATE
- "StepRun" as stepRuns
-SET
- "tickerId" = $1::uuid
-FROM
- stepRunsToTimeout
-WHERE
- stepRuns."id" = stepRunsToTimeout."id"
-RETURNING stepruns.id, stepruns."createdAt", stepruns."updatedAt", stepruns."deletedAt", stepruns."tenantId", stepruns."jobRunId", stepruns."stepId", stepruns."order", stepruns."workerId", stepruns."tickerId", stepruns.status, stepruns.input, stepruns.output, stepruns."requeueAfter", stepruns."scheduleTimeoutAt", stepruns.error, stepruns."startedAt", stepruns."finishedAt", stepruns."timeoutAt", stepruns."cancelledAt", stepruns."cancelledReason", stepruns."cancelledError", stepruns."inputSchema", stepruns."callerFiles", stepruns."gitRepoBranch", stepruns."retryCount", stepruns."semaphoreReleased"
-`
-
-func (q *Queries) PollStepRuns(ctx context.Context, db DBTX, tickerid pgtype.UUID) ([]*StepRun, error) {
- rows, err := db.Query(ctx, pollStepRuns, tickerid)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- var items []*StepRun
- for rows.Next() {
- var i StepRun
- if err := rows.Scan(
- &i.ID,
- &i.CreatedAt,
- &i.UpdatedAt,
- &i.DeletedAt,
- &i.TenantId,
- &i.JobRunId,
- &i.StepId,
- &i.Order,
- &i.WorkerId,
- &i.TickerId,
- &i.Status,
- &i.Input,
- &i.Output,
- &i.RequeueAfter,
- &i.ScheduleTimeoutAt,
- &i.Error,
- &i.StartedAt,
- &i.FinishedAt,
- &i.TimeoutAt,
- &i.CancelledAt,
- &i.CancelledReason,
- &i.CancelledError,
- &i.InputSchema,
- &i.CallerFiles,
- &i.GitRepoBranch,
- &i.RetryCount,
- &i.SemaphoreReleased,
- ); err != nil {
- return nil, err
- }
- items = append(items, &i)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return items, nil
-}
-
const pollTenantAlerts = `-- name: PollTenantAlerts :many
WITH active_tenant_alerts AS (
SELECT
diff --git a/pkg/repository/prisma/dbsqlc/workflow_runs.sql b/pkg/repository/prisma/dbsqlc/workflow_runs.sql
index 19c5757bb..0b43e7dfe 100644
--- a/pkg/repository/prisma/dbsqlc/workflow_runs.sql
+++ b/pkg/repository/prisma/dbsqlc/workflow_runs.sql
@@ -492,6 +492,16 @@ FROM workflow_version
WHERE workflow_version."sticky" IS NOT NULL
RETURNING *;
+-- name: GetWorkflowRunAdditionalMeta :one
+SELECT
+ "additionalMetadata",
+ "id"
+FROM
+ "WorkflowRun"
+WHERE
+ "id" = @workflowRunId::uuid AND
+ "tenantId" = @tenantId::uuid;
+
-- name: GetWorkflowRunStickyStateForUpdate :one
SELECT
*
diff --git a/pkg/repository/prisma/dbsqlc/workflow_runs.sql.go b/pkg/repository/prisma/dbsqlc/workflow_runs.sql.go
index 0beb051d2..8822f4743 100644
--- a/pkg/repository/prisma/dbsqlc/workflow_runs.sql.go
+++ b/pkg/repository/prisma/dbsqlc/workflow_runs.sql.go
@@ -824,6 +824,34 @@ func (q *Queries) GetWorkflowRun(ctx context.Context, db DBTX, arg GetWorkflowRu
return items, nil
}
+const getWorkflowRunAdditionalMeta = `-- name: GetWorkflowRunAdditionalMeta :one
+SELECT
+ "additionalMetadata",
+ "id"
+FROM
+ "WorkflowRun"
+WHERE
+ "id" = $1::uuid AND
+ "tenantId" = $2::uuid
+`
+
+type GetWorkflowRunAdditionalMetaParams struct {
+ Workflowrunid pgtype.UUID `json:"workflowrunid"`
+ Tenantid pgtype.UUID `json:"tenantid"`
+}
+
+type GetWorkflowRunAdditionalMetaRow struct {
+ AdditionalMetadata []byte `json:"additionalMetadata"`
+ ID pgtype.UUID `json:"id"`
+}
+
+func (q *Queries) GetWorkflowRunAdditionalMeta(ctx context.Context, db DBTX, arg GetWorkflowRunAdditionalMetaParams) (*GetWorkflowRunAdditionalMetaRow, error) {
+ row := db.QueryRow(ctx, getWorkflowRunAdditionalMeta, arg.Workflowrunid, arg.Tenantid)
+ var i GetWorkflowRunAdditionalMetaRow
+ err := row.Scan(&i.AdditionalMetadata, &i.ID)
+ return &i, err
+}
+
const getWorkflowRunInput = `-- name: GetWorkflowRunInput :one
SELECT jld."data" AS lookupData
FROM "JobRun" jr
diff --git a/pkg/repository/prisma/step_run.go b/pkg/repository/prisma/step_run.go
index e672b2ee4..cd9ff8da3 100644
--- a/pkg/repository/prisma/step_run.go
+++ b/pkg/repository/prisma/step_run.go
@@ -213,6 +213,13 @@ func NewStepRunEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *ze
}
}
+func (s *stepRunEngineRepository) GetStepRunMetaForEngine(ctx context.Context, tenantId, stepRunId string) (*dbsqlc.GetStepRunMetaRow, error) {
+ return s.queries.GetStepRunMeta(ctx, s.pool, dbsqlc.GetStepRunMetaParams{
+ Steprunid: sqlchelpers.UUIDFromStr(stepRunId),
+ Tenantid: sqlchelpers.UUIDFromStr(tenantId),
+ })
+}
+
func (s *stepRunEngineRepository) ListRunningStepRunsForTicker(ctx context.Context, tickerId string) ([]*dbsqlc.GetStepRunForEngineRow, error) {
tx, err := s.pool.Begin(ctx)
diff --git a/pkg/repository/prisma/stream_event.go b/pkg/repository/prisma/stream_event.go
index 3013b202b..77eb8f1fd 100644
--- a/pkg/repository/prisma/stream_event.go
+++ b/pkg/repository/prisma/stream_event.go
@@ -33,6 +33,13 @@ func NewStreamEventsEngineRepository(pool *pgxpool.Pool, v validator.Validator,
}
}
+func (r *streamEventEngineRepository) GetStreamEventMeta(ctx context.Context, tenantId string, stepRunId string) (*dbsqlc.GetStreamEventMetaRow, error) {
+ return r.queries.GetStreamEventMeta(ctx, r.pool, dbsqlc.GetStreamEventMetaParams{
+ Steprunid: sqlchelpers.UUIDFromStr(stepRunId),
+ Tenantid: sqlchelpers.UUIDFromStr(tenantId),
+ })
+}
+
func (r *streamEventEngineRepository) PutStreamEvent(ctx context.Context, tenantId string, opts *repository.CreateStreamEventOpts) (*dbsqlc.StreamEvent, error) {
if err := r.v.Validate(opts); err != nil {
return nil, err
diff --git a/pkg/repository/prisma/ticker.go b/pkg/repository/prisma/ticker.go
index 67f301e3e..d4ae248c3 100644
--- a/pkg/repository/prisma/ticker.go
+++ b/pkg/repository/prisma/ticker.go
@@ -85,10 +85,6 @@ func (t *tickerRepository) Delete(ctx context.Context, tickerId string) error {
return err
}
-func (t *tickerRepository) PollStepRuns(ctx context.Context, tickerId string) ([]*dbsqlc.StepRun, error) {
- return t.queries.PollStepRuns(ctx, t.pool, sqlchelpers.UUIDFromStr(tickerId))
-}
-
func (t *tickerRepository) PollGetGroupKeyRuns(ctx context.Context, tickerId string) ([]*dbsqlc.GetGroupKeyRun, error) {
return t.queries.PollGetGroupKeyRuns(ctx, t.pool, sqlchelpers.UUIDFromStr(tickerId))
}
diff --git a/pkg/repository/prisma/workflow_run.go b/pkg/repository/prisma/workflow_run.go
index 99be6163e..f8b97e26f 100644
--- a/pkg/repository/prisma/workflow_run.go
+++ b/pkg/repository/prisma/workflow_run.go
@@ -184,6 +184,13 @@ func (w *workflowRunEngineRepository) GetWorkflowRunById(ctx context.Context, te
return runs[0], nil
}
+func (w *workflowRunEngineRepository) GetWorkflowRunAdditionalMeta(ctx context.Context, tenantId, workflowRunId string) (*dbsqlc.GetWorkflowRunAdditionalMetaRow, error) {
+ return w.queries.GetWorkflowRunAdditionalMeta(ctx, w.pool, dbsqlc.GetWorkflowRunAdditionalMetaParams{
+ Tenantid: sqlchelpers.UUIDFromStr(tenantId),
+ Workflowrunid: sqlchelpers.UUIDFromStr(workflowRunId),
+ })
+}
+
func (w *workflowRunEngineRepository) ListWorkflowRuns(ctx context.Context, tenantId string, opts *repository.ListWorkflowRunsOpts) (*repository.ListWorkflowRunsResult, error) {
if err := w.v.Validate(opts); err != nil {
return nil, err
diff --git a/pkg/repository/step_run.go b/pkg/repository/step_run.go
index 914f61624..6585075cb 100644
--- a/pkg/repository/step_run.go
+++ b/pkg/repository/step_run.go
@@ -31,7 +31,9 @@ func IsFinalJobRunStatus(status dbsqlc.JobRunStatus) bool {
}
func IsFinalWorkflowRunStatus(status dbsqlc.WorkflowRunStatus) bool {
- return status != dbsqlc.WorkflowRunStatusPENDING && status != dbsqlc.WorkflowRunStatusRUNNING && status != dbsqlc.WorkflowRunStatusQUEUED
+ return status != dbsqlc.WorkflowRunStatusPENDING &&
+ status != dbsqlc.WorkflowRunStatusRUNNING &&
+ status != dbsqlc.WorkflowRunStatusQUEUED
}
type CreateStepRunEventOpts struct {
@@ -185,6 +187,8 @@ type StepRunEngineRepository interface {
GetStepRunDataForEngine(ctx context.Context, tenantId, stepRunId string) (*dbsqlc.GetStepRunDataForEngineRow, error)
+ GetStepRunMetaForEngine(ctx context.Context, tenantId, stepRunId string) (*dbsqlc.GetStepRunMetaRow, error)
+
// QueueStepRun is like UpdateStepRun, except that it will only update the step run if it is in
// a pending state.
QueueStepRun(ctx context.Context, tenantId, stepRunId string, opts *UpdateStepRunOpts) (*dbsqlc.GetStepRunForEngineRow, error)
diff --git a/pkg/repository/stream_event.go b/pkg/repository/stream_event.go
index 0684f0a8d..2b3e7aff6 100644
--- a/pkg/repository/stream_event.go
+++ b/pkg/repository/stream_event.go
@@ -30,4 +30,7 @@ type StreamEventsEngineRepository interface {
// CleanupStreamEvents deletes all stale StreamEvents.
CleanupStreamEvents(ctx context.Context) error
+
+ // GetStreamEventMeta returns the workflow run id, the step run's retry count, and the step's configured retries for the step run behind a stream event.
+ GetStreamEventMeta(ctx context.Context, tenantId string, stepRunId string) (*dbsqlc.GetStreamEventMetaRow, error)
}
diff --git a/pkg/repository/ticker.go b/pkg/repository/ticker.go
index 72b2414a0..181edf347 100644
--- a/pkg/repository/ticker.go
+++ b/pkg/repository/ticker.go
@@ -35,9 +35,6 @@ type TickerEngineRepository interface {
// Delete deletes a ticker.
Delete(ctx context.Context, tickerId string) error
- // PollStepRuns looks for step runs who are close to past their timeoutAt value and are in a running state
- PollStepRuns(ctx context.Context, tickerId string) ([]*dbsqlc.StepRun, error)
-
// PollJobRuns looks for get group key runs who are close to past their timeoutAt value and are in a running state
PollGetGroupKeyRuns(ctx context.Context, tickerId string) ([]*dbsqlc.GetGroupKeyRun, error)
diff --git a/pkg/repository/workflow_run.go b/pkg/repository/workflow_run.go
index 1ff63e233..d04131156 100644
--- a/pkg/repository/workflow_run.go
+++ b/pkg/repository/workflow_run.go
@@ -411,6 +411,8 @@ type WorkflowRunEngineRepository interface {
// GetWorkflowRunById returns a workflow run by id.
GetWorkflowRunById(ctx context.Context, tenantId, runId string) (*dbsqlc.GetWorkflowRunRow, error)
+ // GetWorkflowRunAdditionalMeta returns a workflow run's id and its additionalMetadata JSON.
+ GetWorkflowRunAdditionalMeta(ctx context.Context, tenantId, workflowRunId string) (*dbsqlc.GetWorkflowRunAdditionalMetaRow, error)
+
ReplayWorkflowRun(ctx context.Context, tenantId, workflowRunId string) (*dbsqlc.GetWorkflowRunRow, error)
ListActiveQueuedWorkflowVersions(ctx context.Context) ([]*dbsqlc.ListActiveQueuedWorkflowVersionsRow, error)