diff --git a/api-contracts/dispatcher/dispatcher.proto b/api-contracts/dispatcher/dispatcher.proto index 95f4db1d1..79b854f32 100644 --- a/api-contracts/dispatcher/dispatcher.proto +++ b/api-contracts/dispatcher/dispatcher.proto @@ -236,7 +236,13 @@ message ActionEventResponse { message SubscribeToWorkflowEventsRequest { // the id of the workflow run - string workflowRunId = 1; + optional string workflowRunId = 1; + + // the key of the additional meta field to subscribe to + optional string additionalMetaKey = 2; + + // the value of the additional meta field to subscribe to + optional string additionalMetaValue = 3; } message SubscribeToWorkflowRunsRequest { diff --git a/cmd/hatchet-engine/engine/run.go b/cmd/hatchet-engine/engine/run.go index 0df51dcd2..3e24c5aa0 100644 --- a/cmd/hatchet-engine/engine/run.go +++ b/cmd/hatchet-engine/engine/run.go @@ -23,6 +23,7 @@ import ( "github.com/hatchet-dev/hatchet/internal/telemetry" "github.com/hatchet-dev/hatchet/pkg/config/loader" "github.com/hatchet-dev/hatchet/pkg/config/server" + "github.com/hatchet-dev/hatchet/pkg/repository/cache" ) type Teardown struct { @@ -280,6 +281,9 @@ func RunWithConfig(ctx context.Context, sc *server.ServerConfig) ([]Teardown, er } if sc.HasService("grpc") { + + cacheInstance := cache.New(10 * time.Second) + // create the dispatcher d, err := dispatcher.New( dispatcher.WithAlerter(sc.Alerter), @@ -287,6 +291,7 @@ func RunWithConfig(ctx context.Context, sc *server.ServerConfig) ([]Teardown, er dispatcher.WithRepository(sc.EngineRepository), dispatcher.WithLogger(sc.Logger), dispatcher.WithEntitlementsRepository(sc.EntitlementRepository), + dispatcher.WithCache(cacheInstance), ) if err != nil { @@ -362,6 +367,8 @@ func RunWithConfig(ctx context.Context, sc *server.ServerConfig) ([]Teardown, er if err != nil { return fmt.Errorf("failed to cleanup dispatcher: %w", err) } + + cacheInstance.Stop() return nil }) diff --git a/examples/stream-event-by-meta/main.go b/examples/stream-event-by-meta/main.go new file mode 100644 index 000000000..4e503618e --- /dev/null +++ b/examples/stream-event-by-meta/main.go @@ -0,0 +1,110 @@ +package main + +import ( + "fmt" + "time" + + "github.com/joho/godotenv" + "golang.org/x/exp/rand" + + "github.com/hatchet-dev/hatchet/pkg/client" + "github.com/hatchet-dev/hatchet/pkg/cmdutils" + "github.com/hatchet-dev/hatchet/pkg/worker" +) + +type streamEventInput struct { + Index int `json:"index"` +} + +type stepOneOutput struct { + Message string `json:"message"` +} + +func StepOne(ctx worker.HatchetContext) (result *stepOneOutput, err error) { + input := &streamEventInput{} + + err = ctx.WorkflowInput(input) + + if err != nil { + return nil, err + } + + ctx.StreamEvent([]byte(fmt.Sprintf("This is a stream event %d", input.Index))) + + return &stepOneOutput{ + Message: fmt.Sprintf("This ran at %s", time.Now().String()), + }, nil +} + +func main() { + err := godotenv.Load() + + if err != nil { + panic(err) + } + + c, err := client.New() + + if err != nil { + panic(err) + } + + w, err := worker.NewWorker( + worker.WithClient( + c, + ), + ) + + if err != nil { + panic(err) + } + + err = w.On( + worker.NoTrigger(), + &worker.WorkflowJob{ + Name: "stream-event-workflow", + Description: "This sends a stream event.", + Steps: []*worker.WorkflowStep{ + worker.Fn(StepOne).SetName("step-one"), + }, + }, + ) + + if err != nil { + panic(err) + } + + interruptCtx, cancel := cmdutils.InterruptContextFromChan(cmdutils.InterruptChan()) + defer cancel() + + _, err = w.Start() + + if err != nil { + 
panic(fmt.Errorf("error starting worker: %w", err))
+	}
+
+	// Generate a random number between 1 and 100
+	streamKey := "streamKey"
+	streamValue := fmt.Sprintf("stream-event-%d", rand.Intn(100)+1)
+
+	_, err = c.Admin().RunWorkflow("stream-event-workflow", &streamEventInput{
+		Index: 0,
+	},
+		client.WithRunMetadata(map[string]interface{}{
+			streamKey: streamValue,
+		}),
+	)
+
+	if err != nil {
+		panic(err)
+	}
+
+	err = c.Subscribe().StreamByAdditionalMetadata(interruptCtx, streamKey, streamValue, func(event client.StreamEvent) error {
+		fmt.Println(string(event.Message))
+		return nil
+	})
+
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/frontend/docs/pages/home/features/streaming.mdx b/frontend/docs/pages/home/features/streaming.mdx
index 61e1da646..cfb06acbd 100644
--- a/frontend/docs/pages/home/features/streaming.mdx
+++ b/frontend/docs/pages/home/features/streaming.mdx
@@ -27,7 +27,7 @@ Listeners are used to subscribe to the event stream for a specific workflow run. Here's an example of how to create a listener:
-
+
 ```python
@@ -59,6 +59,26 @@ async function listen_for_files() {
 ```
+```go copy
+workflowRunId, err := c.Admin().RunWorkflow("stream-event-workflow", &streamEventInput{
+	Index: 0,
+})
+
+if err != nil {
+	panic(err)
+}
+
+err = c.Subscribe().Stream(interruptCtx, workflowRunId, func(event client.StreamEvent) error {
+	fmt.Println(string(event.Message))
+
+	return nil
+})
+
+if err != nil {
+	panic(err)
+}
+
+````

 ## Streaming from a Step Context

@@ -75,7 +95,7 @@ def step1(self, context: Context):
     context.put_stream('hello from step1')
     # continue with the step run...
     return {"step1": "results"}
-```
+````

@@ -225,6 +245,111 @@ for await (const event of listener) {
+
+## Streaming by Additional Metadata
+
+Often it is helpful to stream from multiple workflows (i.e. child workflows spawned from a parent). To achieve this, you can specify an [additional meta](/features/additional-metadata) key-value pair before running a workflow, which can then be used to subscribe to all events from workflows that share the same key-value pair.
+
+Since additional metadata is propagated from parent to child workflows, this can be used to track all events from a specific workflow run.
+
+Here's an example of how to create a listener:
+
+
+
+```python
+# Generate a random stream key to use to track all
+# stream events for this workflow run.
+streamKey = "streamKey"
+streamVal = f"sk-{random.randint(1, 100)}"
+
+# Specify the stream key as additional metadata
+# when running the workflow.
+
+# This key gets propagated to all child workflows
+# and can have an arbitrary property name.
+
+workflowRun = hatchet.admin.run_workflow(
+    "Parent",
+    {"n": 2},
+    options={"additional_metadata": {streamKey: streamVal}},
+)
+
+# Stream all events for the additional meta key value
+listener = hatchet.listener.stream_by_additional_metadata(streamKey, streamVal)
+
+async for event in listener:
+    print(event.type, event.payload)
+
+```
+
+
+
+
+```typescript copy
+// Generate a random stream key to use to track all
+// stream events for this workflow run.
+const streamKey = "streamKey";
+const streamVal = `sk-${Math.random().toString(36).substring(7)}`;
+
+// Specify the stream key as additional metadata
+// when running the workflow.
+
+// This key gets propagated to all child workflows
+// and can have an arbitrary property name.
+await hatchet.admin.runWorkflow( + "parent-workflow", + {}, + { additionalMetadata: { [streamKey]: streamVal } }, +); + +// Stream all events for the additional meta key value +const stream = await hatchet.listener.streamByAdditionalMeta( + streamKey, + streamVal, +); + +for await (const event of stream) { + console.log("event received", event); +} +``` + + + + +```go copy +// Generate a random stream key to use to track all +// stream events for this workflow run. +streamKey := "streamKey" +streamValue := fmt.Sprintf("stream-event-%d", rand.Intn(100)+1) + +// Specify the stream key as additional metadata +// when running the workflow. + +// This key gets propagated to all child workflows +// and can have an arbitrary property name. +\_, err = c.Admin().RunWorkflow("stream-event-workflow", &streamEventInput{ + Index: 0, + }, + client.WithRunMetadata(map[string]interface{}{ + streamKey: streamValue, + }), +) + +if err != nil { + panic(err) +} + +// Stream all events for the additional meta key value +err = c.Subscribe().StreamByAdditionalMetadata(interruptCtx, streamKey, streamValue, func(event client.StreamEvent) error { + fmt.Println(string(event.Message)) + return nil +}) +``` + + + + ## Consuming Streams on Frontend To consume a stream from the backend, create a Streaming Response endpoint to "proxy" the stream from the Hatchet workflow run. diff --git a/internal/services/admin/server.go b/internal/services/admin/server.go index da7f9405c..a775a6695 100644 --- a/internal/services/admin/server.go +++ b/internal/services/admin/server.go @@ -137,6 +137,10 @@ func (a *AdminServiceImpl) TriggerWorkflow(ctx context.Context, req *contracts.T additionalMetadata, parentAdditionalMeta, ) + + if err != nil { + return nil, fmt.Errorf("could not create workflow run opts: %w", err) + } } else { createOpts, err = repository.GetCreateWorkflowRunOptsFromManual(workflowVersion, []byte(req.Input), additionalMetadata) } diff --git a/internal/services/controllers/jobs/controller.go b/internal/services/controllers/jobs/controller.go index 120f6fe29..a26872715 100644 --- a/internal/services/controllers/jobs/controller.go +++ b/internal/services/controllers/jobs/controller.go @@ -1214,7 +1214,9 @@ func (ec *JobsControllerImpl) failStepRun(ctx context.Context, tenantId, stepRun err = ec.mq.AddMessage( ctx, msgqueue.QueueTypeFromDispatcherID(dispatcherId), - stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, *repository.StringPtr(eventMessage)), + stepRunCancelledTask( + tenantId, stepRunId, workerId, dispatcherId, *repository.StringPtr(eventMessage), + sqlchelpers.UUIDToStr(stepRun.WorkflowRunId), &stepRun.StepRetries, &stepRun.SRRetryCount), ) if err != nil { @@ -1326,7 +1328,9 @@ func (ec *JobsControllerImpl) cancelStepRun(ctx context.Context, tenantId, stepR err = ec.mq.AddMessage( ctx, msgqueue.QueueTypeFromDispatcherID(dispatcherId), - stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, reason), + stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, reason, + sqlchelpers.UUIDToStr(stepRun.WorkflowRunId), &stepRun.StepRetries, &stepRun.SRRetryCount, + ), ) if err != nil { @@ -1379,11 +1383,14 @@ func stepRunAssignedTask(tenantId, stepRunId, workerId, dispatcherId string) *ms } } -func stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, cancelledReason string) *msgqueue.Message { +func stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, cancelledReason string, runId string, retries *int32, retryCount *int32) *msgqueue.Message { payload, _ 
:= datautils.ToJSONMap(tasktypes.StepRunCancelledTaskPayload{ + WorkflowRunId: runId, StepRunId: stepRunId, WorkerId: workerId, CancelledReason: cancelledReason, + StepRetries: retries, + RetryCount: retryCount, }) metadata, _ := datautils.ToJSONMap(tasktypes.StepRunCancelledTaskMetadata{ @@ -1391,6 +1398,7 @@ func stepRunCancelledTask(tenantId, stepRunId, workerId, dispatcherId, cancelled DispatcherId: dispatcherId, }) + // TODO add additional metadata return &msgqueue.Message{ ID: "step-run-cancelled", Payload: payload, diff --git a/internal/services/dispatcher/contracts/dispatcher.pb.go b/internal/services/dispatcher/contracts/dispatcher.pb.go index 47346cb1b..b2bfd398a 100644 --- a/internal/services/dispatcher/contracts/dispatcher.pb.go +++ b/internal/services/dispatcher/contracts/dispatcher.pb.go @@ -1274,7 +1274,11 @@ type SubscribeToWorkflowEventsRequest struct { unknownFields protoimpl.UnknownFields // the id of the workflow run - WorkflowRunId string `protobuf:"bytes,1,opt,name=workflowRunId,proto3" json:"workflowRunId,omitempty"` + WorkflowRunId *string `protobuf:"bytes,1,opt,name=workflowRunId,proto3,oneof" json:"workflowRunId,omitempty"` + // the key of the additional meta field to subscribe to + AdditionalMetaKey *string `protobuf:"bytes,2,opt,name=additionalMetaKey,proto3,oneof" json:"additionalMetaKey,omitempty"` + // the value of the additional meta field to subscribe to + AdditionalMetaValue *string `protobuf:"bytes,3,opt,name=additionalMetaValue,proto3,oneof" json:"additionalMetaValue,omitempty"` } func (x *SubscribeToWorkflowEventsRequest) Reset() { @@ -1310,8 +1314,22 @@ func (*SubscribeToWorkflowEventsRequest) Descriptor() ([]byte, []int) { } func (x *SubscribeToWorkflowEventsRequest) GetWorkflowRunId() string { - if x != nil { - return x.WorkflowRunId + if x != nil && x.WorkflowRunId != nil { + return *x.WorkflowRunId + } + return "" +} + +func (x *SubscribeToWorkflowEventsRequest) GetAdditionalMetaKey() string { + if x != nil && x.AdditionalMetaKey != nil { + return *x.AdditionalMetaKey + } + return "" +} + +func (x *SubscribeToWorkflowEventsRequest) GetAdditionalMetaValue() string { + if x != nil && x.AdditionalMetaValue != nil { + return *x.AdditionalMetaValue } return "" } @@ -2189,12 +2207,23 @@ var file_dispatcher_proto_rawDesc = []byte{ 0x12, 0x1a, 0x0a, 0x08, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x49, 0x64, 0x22, 0x48, 0x0a, 0x20, 0x53, 0x75, 0x62, 0x73, - 0x63, 0x72, 0x69, 0x62, 0x65, 0x54, 0x6f, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x45, - 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0d, - 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e, 0x49, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0d, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e, - 0x49, 0x64, 0x22, 0x46, 0x0a, 0x1e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x54, + 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x49, 0x64, 0x22, 0xf7, 0x01, 0x0a, 0x20, 0x53, 0x75, 0x62, + 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x54, 0x6f, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, + 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x29, 0x0a, + 0x0d, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e, 
0x49, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0d, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, + 0x52, 0x75, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x31, 0x0a, 0x11, 0x61, 0x64, 0x64, 0x69, + 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x48, 0x01, 0x52, 0x11, 0x61, 0x64, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x61, + 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x4b, 0x65, 0x79, 0x88, 0x01, 0x01, 0x12, 0x35, 0x0a, 0x13, 0x61, + 0x64, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x56, 0x61, 0x6c, + 0x75, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x02, 0x52, 0x13, 0x61, 0x64, 0x64, 0x69, + 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x88, + 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, + 0x75, 0x6e, 0x49, 0x64, 0x42, 0x14, 0x0a, 0x12, 0x5f, 0x61, 0x64, 0x64, 0x69, 0x74, 0x69, 0x6f, + 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x4b, 0x65, 0x79, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x61, + 0x64, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x74, 0x61, 0x56, 0x61, 0x6c, + 0x75, 0x65, 0x22, 0x46, 0x0a, 0x1e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x54, 0x6f, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x52, 0x75, 0x6e, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x77, 0x6f, 0x72, @@ -2810,6 +2839,7 @@ func file_dispatcher_proto_init() { file_dispatcher_proto_msgTypes[0].OneofWrappers = []interface{}{} file_dispatcher_proto_msgTypes[1].OneofWrappers = []interface{}{} file_dispatcher_proto_msgTypes[5].OneofWrappers = []interface{}{} + file_dispatcher_proto_msgTypes[12].OneofWrappers = []interface{}{} file_dispatcher_proto_msgTypes[14].OneofWrappers = []interface{}{} file_dispatcher_proto_msgTypes[16].OneofWrappers = []interface{}{} type x struct{} diff --git a/internal/services/dispatcher/dispatcher.go b/internal/services/dispatcher/dispatcher.go index aa39a52df..e31db004e 100644 --- a/internal/services/dispatcher/dispatcher.go +++ b/internal/services/dispatcher/dispatcher.go @@ -21,6 +21,7 @@ import ( "github.com/hatchet-dev/hatchet/internal/telemetry/servertel" "github.com/hatchet-dev/hatchet/pkg/logger" "github.com/hatchet-dev/hatchet/pkg/repository" + "github.com/hatchet-dev/hatchet/pkg/repository/cache" "github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc" "github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers" "github.com/hatchet-dev/hatchet/pkg/validator" @@ -36,12 +37,13 @@ type Dispatcher interface { type DispatcherImpl struct { contracts.UnimplementedDispatcherServer - s gocron.Scheduler - mq msgqueue.MessageQueue - l *zerolog.Logger - dv datautils.DataDecoderValidator - v validator.Validator - repo repository.EngineRepository + s gocron.Scheduler + mq msgqueue.MessageQueue + l *zerolog.Logger + dv datautils.DataDecoderValidator + v validator.Validator + repo repository.EngineRepository + cache cache.Cacheable entitlements repository.EntitlementsRepository @@ -121,6 +123,7 @@ type DispatcherOpts struct { entitlements repository.EntitlementsRepository dispatcherId string alerter hatcheterrors.Alerter + cache cache.Cacheable } func defaultDispatcherOpts() *DispatcherOpts { @@ -177,6 +180,12 @@ func WithDispatcherId(dispatcherId string) DispatcherOpt { } } +func WithCache(cache 
cache.Cacheable) DispatcherOpt { + return func(opts *DispatcherOpts) { + opts.cache = cache + } +} + func New(fs ...DispatcherOpt) (*DispatcherImpl, error) { opts := defaultDispatcherOpts() @@ -196,6 +205,10 @@ func New(fs ...DispatcherOpt) (*DispatcherImpl, error) { return nil, fmt.Errorf("entitlements repository is required. use WithEntitlementsRepository") } + if opts.cache == nil { + return nil, fmt.Errorf("cache is required. use WithCache") + } + newLogger := opts.l.With().Str("service", "dispatcher").Logger() opts.l = &newLogger @@ -220,6 +233,7 @@ func New(fs ...DispatcherOpt) (*DispatcherImpl, error) { workers: &workers{}, s: s, a: a, + cache: opts.cache, }, nil } diff --git a/internal/services/dispatcher/server.go b/internal/services/dispatcher/server.go index 2e0958422..2c594f867 100644 --- a/internal/services/dispatcher/server.go +++ b/internal/services/dispatcher/server.go @@ -2,6 +2,8 @@ package dispatcher import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -23,6 +25,7 @@ import ( "github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes" "github.com/hatchet-dev/hatchet/internal/telemetry" "github.com/hatchet-dev/hatchet/pkg/repository" + "github.com/hatchet-dev/hatchet/pkg/repository/cache" "github.com/hatchet-dev/hatchet/pkg/repository/metered" "github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc" "github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers" @@ -484,12 +487,118 @@ func (s *DispatcherImpl) ReleaseSlot(ctx context.Context, req *contracts.Release return &contracts.ReleaseSlotResponse{}, nil } -// SubscribeToWorkflowEvents registers workflow events with the dispatcher func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeToWorkflowEventsRequest, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error { + + fmt.Println("SubscribeToWorkflowEvents") + fmt.Println(request) + + if request.WorkflowRunId != nil { + return s.subscribeToWorkflowEventsByWorkflowRunId(*request.WorkflowRunId, stream) + } else if request.AdditionalMetaKey != nil && request.AdditionalMetaValue != nil { + return s.subscribeToWorkflowEventsByAdditionalMeta(*request.AdditionalMetaKey, *request.AdditionalMetaValue, stream) + } + + return status.Errorf(codes.InvalidArgument, "either workflow run id or additional meta key-value must be provided") +} + +// SubscribeToWorkflowEvents registers workflow events with the dispatcher +func (s *DispatcherImpl) subscribeToWorkflowEventsByAdditionalMeta(key string, value string, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error { tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant) tenantId := sqlchelpers.UUIDToStr(tenant.ID) - s.l.Debug().Msgf("Received subscribe request for workflow: %s", request.WorkflowRunId) + s.l.Error().Msgf("Received subscribe request for additional meta key-value: {%s: %s}", key, value) + + q, err := msgqueue.TenantEventConsumerQueue(tenantId) + if err != nil { + return err + } + + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + + wg := sync.WaitGroup{} + + // Keep track of active workflow run IDs + activeRunIds := make(map[string]struct{}) + var mu sync.Mutex // Mutex to protect activeRunIds + + f := func(task *msgqueue.Message) error { + wg.Add(1) + defer wg.Done() + + e, err := s.tenantTaskToWorkflowEventByAdditionalMeta( + task, tenantId, key, value, + func(e *contracts.WorkflowEvent) (bool, error) { + mu.Lock() + defer mu.Unlock() + + if e.WorkflowRunId == "" { + return false, nil + } + 
+ if e.ResourceType != contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN && + e.EventType != contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED { + // Add the run ID to active runs + activeRunIds[e.WorkflowRunId] = struct{}{} + } else { + // Remove the completed run from active runs + delete(activeRunIds, e.WorkflowRunId) + } + + // Only return true to hang up if we've seen at least one run and all runs are completed + if len(activeRunIds) == 0 { + return true, nil + } + + return false, nil + }) + + if err != nil { + s.l.Error().Err(err).Msgf("could not convert task to workflow event") + return nil + } else if e == nil { + return nil + } + + // send the task to the client + err = stream.Send(e) + + if err != nil { + cancel() // FIXME is this necessary? + s.l.Error().Err(err).Msgf("could not send workflow event to client") + return nil + } + + if e.Hangup { + cancel() + } + + return nil + } + + // subscribe to the task queue for the tenant + cleanupQueue, err := s.mq.Subscribe(q, msgqueue.NoOpHook, f) + + if err != nil { + return err + } + + <-ctx.Done() + if err := cleanupQueue(); err != nil { + return fmt.Errorf("could not cleanup queue: %w", err) + } + + waitFor(&wg, 60*time.Second, s.l) + + return nil +} + +// SubscribeToWorkflowEvents registers workflow events with the dispatcher +func (s *DispatcherImpl) subscribeToWorkflowEventsByWorkflowRunId(workflowRunId string, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error { + tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant) + tenantId := sqlchelpers.UUIDToStr(tenant.ID) + + s.l.Debug().Msgf("Received subscribe request for workflow: %s", workflowRunId) q, err := msgqueue.TenantEventConsumerQueue(tenantId) @@ -501,11 +610,11 @@ func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeT defer cancel() // if the workflow run is in a final state, hang up the connection - workflowRun, err := s.repo.WorkflowRun().GetWorkflowRunById(ctx, tenantId, request.WorkflowRunId) + workflowRun, err := s.repo.WorkflowRun().GetWorkflowRunById(ctx, tenantId, workflowRunId) if err != nil { if errors.Is(err, repository.ErrWorkflowRunNotFound) { - return status.Errorf(codes.NotFound, "workflow run %s not found", request.WorkflowRunId) + return status.Errorf(codes.NotFound, "workflow run %s not found", workflowRunId) } return err @@ -521,7 +630,7 @@ func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeT wg.Add(1) defer wg.Done() - e, err := s.tenantTaskToWorkflowEvent(task, tenantId, request.WorkflowRunId) + e, err := s.tenantTaskToWorkflowEventByRunId(task, tenantId, workflowRunId) if err != nil { s.l.Error().Err(err).Msgf("could not convert task to workflow event") @@ -901,9 +1010,18 @@ func (s *DispatcherImpl) handleStepRunStarted(ctx context.Context, request *cont startedAt := request.EventTimestamp.AsTime() + sr, err := s.repo.StepRun().GetStepRunForEngine(ctx, tenantId, request.StepRunId) + + if err != nil { + return nil, err + } + payload, _ := datautils.ToJSONMap(tasktypes.StepRunStartedTaskPayload{ - StepRunId: request.StepRunId, - StartedAt: startedAt.Format(time.RFC3339), + StepRunId: request.StepRunId, + StartedAt: startedAt.Format(time.RFC3339), + WorkflowRunId: sqlchelpers.UUIDToStr(sr.WorkflowRunId), + StepRetries: &sr.StepRetries, + RetryCount: &sr.SRRetryCount, }) metadata, _ := datautils.ToJSONMap(tasktypes.StepRunStartedTaskMetadata{ @@ -911,7 +1029,7 @@ func (s *DispatcherImpl) handleStepRunStarted(ctx context.Context, request *cont }) // send the event to the jobs 
queue - err := s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{ + err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{ ID: "step-run-started", Payload: payload, Metadata: metadata, @@ -955,10 +1073,19 @@ func (s *DispatcherImpl) handleStepRunCompleted(ctx context.Context, request *co finishedAt := request.EventTimestamp.AsTime() + meta, err := s.repo.StepRun().GetStepRunMetaForEngine(ctx, tenantId, request.StepRunId) + + if err != nil { + return nil, err + } + payload, _ := datautils.ToJSONMap(tasktypes.StepRunFinishedTaskPayload{ + WorkflowRunId: sqlchelpers.UUIDToStr(meta.WorkflowRunId), StepRunId: request.StepRunId, FinishedAt: finishedAt.Format(time.RFC3339), StepOutputData: request.EventPayload, + StepRetries: &meta.Retries, + RetryCount: &meta.RetryCount, }) metadata, _ := datautils.ToJSONMap(tasktypes.StepRunFinishedTaskMetadata{ @@ -966,7 +1093,7 @@ func (s *DispatcherImpl) handleStepRunCompleted(ctx context.Context, request *co }) // send the event to the jobs queue - err := s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{ + err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{ ID: "step-run-finished", Payload: payload, Metadata: metadata, @@ -991,10 +1118,19 @@ func (s *DispatcherImpl) handleStepRunFailed(ctx context.Context, request *contr failedAt := request.EventTimestamp.AsTime() + meta, err := s.repo.StepRun().GetStepRunMetaForEngine(ctx, tenantId, request.StepRunId) + + if err != nil { + return nil, err + } + payload, _ := datautils.ToJSONMap(tasktypes.StepRunFailedTaskPayload{ - StepRunId: request.StepRunId, - FailedAt: failedAt.Format(time.RFC3339), - Error: request.EventPayload, + WorkflowRunId: sqlchelpers.UUIDToStr(meta.WorkflowRunId), + StepRunId: request.StepRunId, + FailedAt: failedAt.Format(time.RFC3339), + Error: request.EventPayload, + StepRetries: &meta.Retries, + RetryCount: &meta.RetryCount, }) metadata, _ := datautils.ToJSONMap(tasktypes.StepRunFailedTaskMetadata{ @@ -1002,7 +1138,7 @@ func (s *DispatcherImpl) handleStepRunFailed(ctx context.Context, request *contr }) // send the event to the jobs queue - err := s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{ + err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{ ID: "step-run-failed", Payload: payload, Metadata: metadata, @@ -1126,68 +1262,143 @@ func (s *DispatcherImpl) handleGetGroupKeyRunFailed(ctx context.Context, request }, nil } -func (s *DispatcherImpl) tenantTaskToWorkflowEvent(task *msgqueue.Message, tenantId string, workflowRunIds ...string) (*contracts.WorkflowEvent, error) { +func UnmarshalPayload[T any](payload interface{}) (T, error) { + var result T + + // Convert the payload to JSON + jsonData, err := json.Marshal(payload) + if err != nil { + return result, fmt.Errorf("failed to marshal payload: %w", err) + } + + // Unmarshal JSON into the desired type + err = json.Unmarshal(jsonData, &result) + if err != nil { + return result, fmt.Errorf("failed to unmarshal payload: %w", err) + } + + return result, nil +} + +func (s *DispatcherImpl) taskToWorkflowEvent(task *msgqueue.Message, tenantId string, filter func(task *contracts.WorkflowEvent) (*bool, error), hangupFunc func(task *contracts.WorkflowEvent) (bool, error)) (*contracts.WorkflowEvent, error) { workflowEvent := &contracts.WorkflowEvent{} var stepRunId string switch task.ID { case "step-run-started": - stepRunId = task.Payload["step_run_id"].(string) - workflowEvent.ResourceType = 
contracts.ResourceType_RESOURCE_TYPE_STEP_RUN - workflowEvent.ResourceId = stepRunId - workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STARTED - case "step-run-finished": - stepRunId = task.Payload["step_run_id"].(string) - workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN - workflowEvent.ResourceId = stepRunId - workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED - workflowEvent.EventPayload = task.Payload["step_output_data"].(string) - case "step-run-failed": - stepRunId = task.Payload["step_run_id"].(string) - workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN - workflowEvent.ResourceId = stepRunId - workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_FAILED - workflowEvent.EventPayload = task.Payload["error"].(string) - case "step-run-cancelled": - stepRunId = task.Payload["step_run_id"].(string) - workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN - workflowEvent.ResourceId = stepRunId - workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_CANCELLED - case "step-run-timed-out": - stepRunId = task.Payload["step_run_id"].(string) - workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN - workflowEvent.ResourceId = stepRunId - workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_TIMED_OUT - case "step-run-stream-event": - stepRunId = task.Payload["step_run_id"].(string) - workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN - workflowEvent.ResourceId = stepRunId - workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM - case "workflow-run-finished": - workflowRunId := task.Payload["workflow_run_id"].(string) - workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN - workflowEvent.ResourceId = workflowRunId - workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED - workflowEvent.Hangup = true - } - - if workflowEvent.ResourceType == contracts.ResourceType_RESOURCE_TYPE_STEP_RUN { - // determine if this step run matches the workflow run id - stepRun, err := s.repo.StepRun().GetStepRunForEngine(context.Background(), tenantId, stepRunId) - + payload, err := UnmarshalPayload[tasktypes.StepRunStartedTaskPayload](task.Payload) if err != nil { return nil, err } - - if !contains(workflowRunIds, sqlchelpers.UUIDToStr(stepRun.WorkflowRunId)) { - // this is an expected error, so we don't return it - return nil, nil + workflowEvent.WorkflowRunId = payload.WorkflowRunId + workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN + workflowEvent.ResourceId = stepRunId + workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STARTED + workflowEvent.StepRetries = payload.StepRetries + workflowEvent.RetryCount = payload.RetryCount + case "step-run-finished": + payload, err := UnmarshalPayload[tasktypes.StepRunFinishedTaskPayload](task.Payload) + if err != nil { + return nil, err } + workflowEvent.WorkflowRunId = payload.WorkflowRunId + stepRunId = task.Payload["step_run_id"].(string) + workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN + workflowEvent.ResourceId = stepRunId + workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED + workflowEvent.EventPayload = payload.StepOutputData - workflowEvent.StepRetries = &stepRun.StepRetries - workflowEvent.RetryCount = &stepRun.SRRetryCount + 
workflowEvent.StepRetries = payload.StepRetries + workflowEvent.RetryCount = payload.RetryCount + case "step-run-failed": + payload, err := UnmarshalPayload[tasktypes.StepRunFailedTaskPayload](task.Payload) + if err != nil { + return nil, err + } + workflowEvent.WorkflowRunId = payload.WorkflowRunId + stepRunId = payload.StepRunId + workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN + workflowEvent.ResourceId = stepRunId + workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_FAILED + workflowEvent.EventPayload = payload.Error + workflowEvent.StepRetries = payload.StepRetries + workflowEvent.RetryCount = payload.RetryCount + case "step-run-cancelled": + payload, err := UnmarshalPayload[tasktypes.StepRunCancelledTaskPayload](task.Payload) + if err != nil { + return nil, err + } + workflowEvent.WorkflowRunId = payload.WorkflowRunId + stepRunId = payload.StepRunId + workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN + workflowEvent.ResourceId = stepRunId + workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_CANCELLED + + workflowEvent.StepRetries = payload.StepRetries + workflowEvent.RetryCount = payload.RetryCount + case "step-run-timed-out": + payload, err := UnmarshalPayload[tasktypes.StepRunTimedOutTaskPayload](task.Payload) + if err != nil { + return nil, err + } + workflowEvent.WorkflowRunId = payload.WorkflowRunId + stepRunId = payload.StepRunId + workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN + workflowEvent.ResourceId = stepRunId + workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_TIMED_OUT + + workflowEvent.StepRetries = payload.StepRetries + workflowEvent.RetryCount = payload.RetryCount + case "step-run-stream-event": + payload, err := UnmarshalPayload[tasktypes.StepRunStreamEventTaskPayload](task.Payload) + if err != nil { + return nil, err + } + workflowEvent.WorkflowRunId = payload.WorkflowRunId + stepRunId = payload.StepRunId + workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN + workflowEvent.ResourceId = stepRunId + workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM + + workflowEvent.StepRetries = payload.StepRetries + workflowEvent.RetryCount = payload.RetryCount + case "workflow-run-finished": + payload, err := UnmarshalPayload[tasktypes.WorkflowRunFinishedTask](task.Payload) + if err != nil { + return nil, err + } + workflowRunId := payload.WorkflowRunId + workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN + workflowEvent.ResourceId = workflowRunId + workflowEvent.WorkflowRunId = workflowRunId + workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED + } + + match, err := filter(workflowEvent) + + if err != nil { + return nil, err + } + + if match != nil && !*match { + // if not a match, we don't return it + return nil, nil + } + + hangup, err := hangupFunc(workflowEvent) + + if err != nil { + return nil, err + } + + if hangup { + workflowEvent.Hangup = true + return workflowEvent, nil + } + + if workflowEvent.ResourceType == contracts.ResourceType_RESOURCE_TYPE_STEP_RUN { if workflowEvent.EventType == contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM { streamEventId, err := strconv.ParseInt(task.Metadata["stream_event_id"].(string), 10, 64) if err != nil { @@ -1202,13 +1413,96 @@ func (s *DispatcherImpl) tenantTaskToWorkflowEvent(task *msgqueue.Message, tenan workflowEvent.EventPayload = string(streamEvent.Message) } + 
} - } else if workflowEvent.ResourceType == contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN { - if !contains(workflowRunIds, workflowEvent.ResourceId) { - return nil, nil - } + return workflowEvent, nil +} - workflowEvent.Hangup = true +func (s *DispatcherImpl) tenantTaskToWorkflowEventByRunId(task *msgqueue.Message, tenantId string, workflowRunIds ...string) (*contracts.WorkflowEvent, error) { + + workflowEvent, err := s.taskToWorkflowEvent(task, tenantId, + func(e *contracts.WorkflowEvent) (*bool, error) { + hit := contains(workflowRunIds, e.WorkflowRunId) + return &hit, nil + }, + func(e *contracts.WorkflowEvent) (bool, error) { + // hangup on complete + return e.ResourceType == contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN && + e.EventType == contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED, nil + }, + ) + + if err != nil { + return nil, err + } + + return workflowEvent, nil +} + +func tinyHash(key, value string) string { + // Combine key and value + combined := fmt.Sprintf("%s:%s", key, value) + + // Create SHA-256 hash + hash := sha256.Sum256([]byte(combined)) + + // Take first 8 bytes of the hash + shortHash := hash[:8] + + // Encode to base64 + encoded := base64.URLEncoding.EncodeToString(shortHash) + + // Remove padding + return encoded[:11] +} + +func (s *DispatcherImpl) tenantTaskToWorkflowEventByAdditionalMeta(task *msgqueue.Message, tenantId string, key string, value string, hangup func(e *contracts.WorkflowEvent) (bool, error)) (*contracts.WorkflowEvent, error) { + workflowEvent, err := s.taskToWorkflowEvent( + task, + tenantId, + func(e *contracts.WorkflowEvent) (*bool, error) { + return cache.MakeCacheable[bool]( + s.cache, + fmt.Sprintf("wfram-%s-%s-%s", tenantId, e.WorkflowRunId, tinyHash(key, value)), + func() (*bool, error) { + + if e.WorkflowRunId == "" { + return nil, nil + } + + am, err := s.repo.WorkflowRun().GetWorkflowRunAdditionalMeta(context.Background(), tenantId, e.WorkflowRunId) + + if err != nil { + return nil, err + } + + if am.AdditionalMetadata == nil { + f := false + return &f, nil + } + + var additionalMetaMap map[string]interface{} + err = json.Unmarshal(am.AdditionalMetadata, &additionalMetaMap) + if err != nil { + return nil, err + } + + if v, ok := (additionalMetaMap)[key]; ok && v == value { + t := true + return &t, nil + } + + f := false + return &f, nil + + }, + ) + }, + hangup, + ) + + if err != nil { + return nil, err } return workflowEvent, nil diff --git a/internal/services/ingestor/server.go b/internal/services/ingestor/server.go index 81ae50feb..fb78e8958 100644 --- a/internal/services/ingestor/server.go +++ b/internal/services/ingestor/server.go @@ -102,6 +102,12 @@ func (i *IngestorImpl) PutStreamEvent(ctx context.Context, req *contracts.PutStr return nil, status.Errorf(codes.InvalidArgument, "Invalid request: %s", err) } + meta, err := i.streamEventRepository.GetStreamEventMeta(ctx, tenantId, req.StepRunId) + + if err != nil { + return nil, err + } + streamEvent, err := i.streamEventRepository.PutStreamEvent(ctx, tenantId, &opts) if err != nil { @@ -114,7 +120,9 @@ func (i *IngestorImpl) PutStreamEvent(ctx context.Context, req *contracts.PutStr return nil, err } - err = i.mq.AddMessage(context.Background(), q, streamEventToTask(streamEvent)) + e := streamEventToTask(streamEvent, sqlchelpers.UUIDToStr(meta.WorkflowRunId), &meta.RetryCount, &meta.Retries) + + err = i.mq.AddMessage(context.Background(), q, e) if err != nil { return nil, err @@ -176,13 +184,16 @@ func toEvent(e *dbsqlc.Event) (*contracts.Event, error) { }, nil 
} -func streamEventToTask(e *dbsqlc.StreamEvent) *msgqueue.Message { +func streamEventToTask(e *dbsqlc.StreamEvent, workflowRunId string, retryCount *int32, retries *int32) *msgqueue.Message { tenantId := sqlchelpers.UUIDToStr(e.TenantId) payloadTyped := tasktypes.StepRunStreamEventTaskPayload{ + WorkflowRunId: workflowRunId, StepRunId: sqlchelpers.UUIDToStr(e.StepRunId), CreatedAt: e.CreatedAt.Time.String(), StreamEventId: strconv.FormatInt(e.ID, 10), + RetryCount: retryCount, + StepRetries: retries, } payload, _ := datautils.ToJSONMap(payloadTyped) diff --git a/internal/services/shared/tasktypes/step.go b/internal/services/shared/tasktypes/step.go index 3cbeaef00..32660649a 100644 --- a/internal/services/shared/tasktypes/step.go +++ b/internal/services/shared/tasktypes/step.go @@ -3,19 +3,20 @@ package tasktypes import ( "github.com/hatchet-dev/hatchet/internal/datautils" "github.com/hatchet-dev/hatchet/internal/msgqueue" - "github.com/hatchet-dev/hatchet/pkg/repository/prisma/db" "github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc" "github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers" ) type StepRunTaskPayload struct { - StepRunId string `json:"step_run_id" validate:"required,uuid"` - JobRunId string `json:"job_run_id" validate:"required,uuid"` + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` + StepRunId string `json:"step_run_id" validate:"required,uuid"` + JobRunId string `json:"job_run_id" validate:"required,uuid"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunTaskMetadata struct { - TenantId string `json:"tenant_id" validate:"required,uuid"` - + TenantId string `json:"tenant_id" validate:"required,uuid"` StepId string `json:"step_id" validate:"required,uuid"` ActionId string `json:"action_id" validate:"required,actionId"` JobId string `json:"job_id" validate:"required,uuid"` @@ -34,9 +35,12 @@ type StepRunAssignedTaskMetadata struct { } type StepRunCancelledTaskPayload struct { + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` StepRunId string `json:"step_run_id" validate:"required,uuid"` WorkerId string `json:"worker_id" validate:"required,uuid"` CancelledReason string `json:"cancelled_reason" validate:"required"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunCancelledTaskMetadata struct { @@ -53,8 +57,11 @@ type StepRunRequeueTaskMetadata struct { } type StepRunNotifyCancelTaskPayload struct { + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` StepRunId string `json:"step_run_id" validate:"required,uuid"` CancelledReason string `json:"cancelled_reason" validate:"required"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunNotifyCancelTaskMetadata struct { @@ -62,8 +69,11 @@ type StepRunNotifyCancelTaskMetadata struct { } type StepRunStartedTaskPayload struct { - StepRunId string `json:"step_run_id" validate:"required,uuid"` - StartedAt string `json:"started_at" validate:"required"` + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` + StepRunId string `json:"step_run_id" validate:"required,uuid"` + StartedAt string `json:"started_at" validate:"required"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunStartedTaskMetadata struct { @@ -71,9 +81,12 @@ type StepRunStartedTaskMetadata struct { } 
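A quick illustration of the payload changes in the `tasktypes` hunks around this point: step-run task payloads now carry `workflow_run_id`, `step_retries`, and `retry_count`, which is what lets the dispatcher's new `taskToWorkflowEvent` path resolve the owning workflow run straight from the message instead of calling `GetStepRunForEngine` for every event. A minimal, self-contained sketch of that round trip, assuming a locally declared struct that mirrors `StepRunStartedTaskPayload` (the real type lives in the internal `tasktypes` package) and placeholder IDs:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-in for tasktypes.StepRunStartedTaskPayload, for illustration only.
type stepRunStartedPayload struct {
	WorkflowRunId string `json:"workflow_run_id"`
	StepRunId     string `json:"step_run_id"`
	StartedAt     string `json:"started_at"`
	StepRetries   *int32 `json:"step_retries,omitempty"`
	RetryCount    *int32 `json:"retry_count,omitempty"`
}

func main() {
	retries, retryCount := int32(3), int32(1)

	// Producer side: the payload is serialized onto the message queue.
	msg, err := json.Marshal(stepRunStartedPayload{
		WorkflowRunId: "workflow-run-uuid", // placeholder
		StepRunId:     "step-run-uuid",     // placeholder
		StartedAt:     "2024-01-01T00:00:00Z",
		StepRetries:   &retries,
		RetryCount:    &retryCount,
	})
	if err != nil {
		panic(err)
	}

	// Consumer side: the workflow run id and retry info come straight from
	// the payload, with no step-run lookup needed to route the event.
	var decoded stepRunStartedPayload
	if err := json.Unmarshal(msg, &decoded); err != nil {
		panic(err)
	}
	fmt.Println(decoded.WorkflowRunId, *decoded.StepRetries, *decoded.RetryCount)
}
```

Because the retry fields are pointers tagged `omitempty`, older messages that predate this change still decode cleanly; the pointers are simply left nil.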
type StepRunFinishedTaskPayload struct { + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` StepRunId string `json:"step_run_id" validate:"required,uuid"` FinishedAt string `json:"finished_at" validate:"required"` StepOutputData string `json:"step_output_data"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunFinishedTaskMetadata struct { @@ -81,9 +94,12 @@ type StepRunFinishedTaskMetadata struct { } type StepRunStreamEventTaskPayload struct { + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` StepRunId string `json:"step_run_id" validate:"required,uuid"` CreatedAt string `json:"created_at" validate:"required"` StreamEventId string `json:"stream_event_id"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunStreamEventTaskMetadata struct { @@ -92,9 +108,12 @@ type StepRunStreamEventTaskMetadata struct { } type StepRunFailedTaskPayload struct { - StepRunId string `json:"step_run_id" validate:"required,uuid"` - FailedAt string `json:"failed_at" validate:"required"` - Error string `json:"error" validate:"required"` + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` + StepRunId string `json:"step_run_id" validate:"required,uuid"` + FailedAt string `json:"failed_at" validate:"required"` + Error string `json:"error" validate:"required"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunFailedTaskMetadata struct { @@ -102,7 +121,10 @@ type StepRunFailedTaskMetadata struct { } type StepRunTimedOutTaskPayload struct { - StepRunId string `json:"step_run_id" validate:"required,uuid"` + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` + StepRunId string `json:"step_run_id" validate:"required,uuid"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunTimedOutTaskMetadata struct { @@ -110,11 +132,15 @@ type StepRunTimedOutTaskMetadata struct { } type StepRunRetryTaskPayload struct { - StepRunId string `json:"step_run_id" validate:"required,uuid"` - JobRunId string `json:"job_run_id" validate:"required,uuid"` + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` + StepRunId string `json:"step_run_id" validate:"required,uuid"` + JobRunId string `json:"job_run_id" validate:"required,uuid"` // optional - if not provided, the step run will be retried with the same input InputData string `json:"input_data,omitempty"` + + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunRetryTaskMetadata struct { @@ -122,43 +148,33 @@ type StepRunRetryTaskMetadata struct { } type StepRunReplayTaskPayload struct { - StepRunId string `json:"step_run_id" validate:"required,uuid"` - JobRunId string `json:"job_run_id" validate:"required,uuid"` + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` + StepRunId string `json:"step_run_id" validate:"required,uuid"` + JobRunId string `json:"job_run_id" validate:"required,uuid"` // optional - if not provided, the step run will be retried with the same input - InputData string `json:"input_data,omitempty"` + InputData string `json:"input_data,omitempty"` + StepRetries *int32 `json:"step_retries,omitempty"` + RetryCount *int32 `json:"retry_count,omitempty"` } type StepRunReplayTaskMetadata struct { TenantId string `json:"tenant_id" 
validate:"required,uuid"` } -func TenantToStepRunRequeueTask(tenant db.TenantModel) *msgqueue.Message { - payload, _ := datautils.ToJSONMap(StepRunRequeueTaskPayload{ - TenantId: tenant.ID, - }) - - metadata, _ := datautils.ToJSONMap(StepRunRequeueTaskMetadata{ - TenantId: tenant.ID, - }) - - return &msgqueue.Message{ - ID: "step-run-requeue-ticker", - Payload: payload, - Metadata: metadata, - Retries: 3, - } -} - func StepRunRetryToTask(stepRun *dbsqlc.GetStepRunForEngineRow, inputData []byte) *msgqueue.Message { jobRunId := sqlchelpers.UUIDToStr(stepRun.JobRunId) stepRunId := sqlchelpers.UUIDToStr(stepRun.SRID) tenantId := sqlchelpers.UUIDToStr(stepRun.SRTenantId) + workflowRunId := sqlchelpers.UUIDToStr(stepRun.WorkflowRunId) payload, _ := datautils.ToJSONMap(StepRunRetryTaskPayload{ - JobRunId: jobRunId, - StepRunId: stepRunId, - InputData: string(inputData), + WorkflowRunId: workflowRunId, + JobRunId: jobRunId, + StepRunId: stepRunId, + InputData: string(inputData), + StepRetries: &stepRun.StepRetries, + RetryCount: &stepRun.SRRetryCount, }) metadata, _ := datautils.ToJSONMap(StepRunRetryTaskMetadata{ @@ -177,11 +193,15 @@ func StepRunReplayToTask(stepRun *dbsqlc.GetStepRunForEngineRow, inputData []byt jobRunId := sqlchelpers.UUIDToStr(stepRun.JobRunId) stepRunId := sqlchelpers.UUIDToStr(stepRun.SRID) tenantId := sqlchelpers.UUIDToStr(stepRun.SRTenantId) + workflowRunId := sqlchelpers.UUIDToStr(stepRun.WorkflowRunId) payload, _ := datautils.ToJSONMap(StepRunReplayTaskPayload{ - JobRunId: jobRunId, - StepRunId: stepRunId, - InputData: string(inputData), + WorkflowRunId: workflowRunId, + JobRunId: jobRunId, + StepRunId: stepRunId, + InputData: string(inputData), + StepRetries: &stepRun.StepRetries, + RetryCount: &stepRun.SRRetryCount, }) metadata, _ := datautils.ToJSONMap(StepRunReplayTaskMetadata{ @@ -201,8 +221,11 @@ func StepRunCancelToTask(stepRun *dbsqlc.GetStepRunForEngineRow, reason string) tenantId := sqlchelpers.UUIDToStr(stepRun.SRTenantId) payload, _ := datautils.ToJSONMap(StepRunNotifyCancelTaskPayload{ + WorkflowRunId: sqlchelpers.UUIDToStr(stepRun.WorkflowRunId), StepRunId: stepRunId, CancelledReason: reason, + StepRetries: &stepRun.StepRetries, + RetryCount: &stepRun.SRRetryCount, }) metadata, _ := datautils.ToJSONMap(StepRunNotifyCancelTaskMetadata{ @@ -219,8 +242,11 @@ func StepRunCancelToTask(stepRun *dbsqlc.GetStepRunForEngineRow, reason string) func StepRunQueuedToTask(stepRun *dbsqlc.GetStepRunForEngineRow) *msgqueue.Message { payload, _ := datautils.ToJSONMap(StepRunTaskPayload{ - JobRunId: sqlchelpers.UUIDToStr(stepRun.JobRunId), - StepRunId: sqlchelpers.UUIDToStr(stepRun.SRID), + WorkflowRunId: sqlchelpers.UUIDToStr(stepRun.WorkflowRunId), + JobRunId: sqlchelpers.UUIDToStr(stepRun.JobRunId), + StepRunId: sqlchelpers.UUIDToStr(stepRun.SRID), + StepRetries: &stepRun.StepRetries, + RetryCount: &stepRun.SRRetryCount, }) metadata, _ := datautils.ToJSONMap(StepRunTaskMetadata{ diff --git a/internal/services/shared/tasktypes/workflow.go b/internal/services/shared/tasktypes/workflow.go index b5b6cf523..74480e7b5 100644 --- a/internal/services/shared/tasktypes/workflow.go +++ b/internal/services/shared/tasktypes/workflow.go @@ -39,8 +39,9 @@ type WorkflowRunQueuedTaskMetadata struct { } type WorkflowRunFinishedTask struct { - WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` - Status string `json:"status" validate:"required"` + WorkflowRunId string `json:"workflow_run_id" validate:"required,uuid"` + Status string `json:"status" validate:"required"` + 
AdditionalMetadata map[string]interface{} `json:"additional_metadata"` } type WorkflowRunFinishedTaskMetadata struct { diff --git a/internal/services/ticker/step_run_timeout.go b/internal/services/ticker/step_run_timeout.go deleted file mode 100644 index 5612ae8aa..000000000 --- a/internal/services/ticker/step_run_timeout.go +++ /dev/null @@ -1,59 +0,0 @@ -package ticker - -import ( - "context" - "time" - - "github.com/hatchet-dev/hatchet/internal/datautils" - "github.com/hatchet-dev/hatchet/internal/msgqueue" - "github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes" - "github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers" -) - -func (t *TickerImpl) runPollStepRuns(ctx context.Context) func() { - return func() { - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - t.l.Debug().Msgf("ticker: polling step runs") - - stepRuns, err := t.repo.Ticker().PollStepRuns(ctx, t.tickerId) - - if err != nil { - t.l.Err(err).Msg("could not poll step runs") - return - } - - for _, stepRun := range stepRuns { - tenantId := sqlchelpers.UUIDToStr(stepRun.TenantId) - stepRunId := sqlchelpers.UUIDToStr(stepRun.ID) - - err := t.mq.AddMessage( - ctx, - msgqueue.JOB_PROCESSING_QUEUE, - taskStepRunTimedOut(tenantId, stepRunId), - ) - - if err != nil { - t.l.Err(err).Msg("could not add step run timeout task") - } - } - } -} - -func taskStepRunTimedOut(tenantId, stepRunId string) *msgqueue.Message { - payload, _ := datautils.ToJSONMap(tasktypes.StepRunTimedOutTaskPayload{ - StepRunId: stepRunId, - }) - - metadata, _ := datautils.ToJSONMap(tasktypes.StepRunTimedOutTaskMetadata{ - TenantId: tenantId, - }) - - return &msgqueue.Message{ - ID: "step-run-timed-out", - Payload: payload, - Metadata: metadata, - Retries: 3, - } -} diff --git a/internal/services/ticker/ticker.go b/internal/services/ticker/ticker.go index 9e14aff33..50de39da0 100644 --- a/internal/services/ticker/ticker.go +++ b/internal/services/ticker/ticker.go @@ -181,18 +181,6 @@ func (t *TickerImpl) Start() (func() error, error) { return nil, fmt.Errorf("could not create update heartbeat job: %w", err) } - // _, err = t.s.NewJob( - // gocron.DurationJob(time.Second*1), - // gocron.NewTask( - // t.runPollStepRuns(ctx), - // ), - // ) - - // if err != nil { - // cancel() - // return nil, fmt.Errorf("could not create update heartbeat job: %w", err) - // } - _, err = t.s.NewJob( gocron.DurationJob(time.Second*1), gocron.NewTask( diff --git a/pkg/client/listener.go b/pkg/client/listener.go index bc656dab0..f93c2fe42 100644 --- a/pkg/client/listener.go +++ b/pkg/client/listener.go @@ -228,6 +228,8 @@ type SubscribeClient interface { Stream(ctx context.Context, workflowRunId string, handler StreamHandler) error + StreamByAdditionalMetadata(ctx context.Context, key string, value string, handler StreamHandler) error + SubscribeToWorkflowRunEvents(ctx context.Context) (*WorkflowRunsListener, error) } @@ -256,7 +258,7 @@ func newSubscribe(conn *grpc.ClientConn, opts *sharedClientOpts) SubscribeClient func (r *subscribeClientImpl) On(ctx context.Context, workflowRunId string, handler RunHandler) error { stream, err := r.client.SubscribeToWorkflowEvents(r.ctx.newContext(ctx), &dispatchercontracts.SubscribeToWorkflowEventsRequest{ - WorkflowRunId: workflowRunId, + WorkflowRunId: &workflowRunId, }) if err != nil { @@ -286,7 +288,40 @@ func (r *subscribeClientImpl) On(ctx context.Context, workflowRunId string, hand func (r *subscribeClientImpl) Stream(ctx context.Context, workflowRunId string, handler StreamHandler) error 
{ stream, err := r.client.SubscribeToWorkflowEvents(r.ctx.newContext(ctx), &dispatchercontracts.SubscribeToWorkflowEventsRequest{ - WorkflowRunId: workflowRunId, + WorkflowRunId: &workflowRunId, + }) + + if err != nil { + return err + } + + for { + event, err := stream.Recv() + + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + + return err + } + + if event.EventType != dispatchercontracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM { + continue + } + + if err := handler(StreamEvent{ + Message: []byte(event.EventPayload), + }); err != nil { + return err + } + } +} + +func (r *subscribeClientImpl) StreamByAdditionalMetadata(ctx context.Context, key string, value string, handler StreamHandler) error { + stream, err := r.client.SubscribeToWorkflowEvents(r.ctx.newContext(ctx), &dispatchercontracts.SubscribeToWorkflowEventsRequest{ + AdditionalMetaKey: &key, + AdditionalMetaValue: &value, }) if err != nil { diff --git a/pkg/repository/prisma/dbsqlc/step_runs.sql b/pkg/repository/prisma/dbsqlc/step_runs.sql index 04f8a7e04..3143de18f 100644 --- a/pkg/repository/prisma/dbsqlc/step_runs.sql +++ b/pkg/repository/prisma/dbsqlc/step_runs.sql @@ -31,6 +31,17 @@ WHERE sr."id" = @id::uuid AND sr."tenantId" = @tenantId::uuid; +-- name: GetStepRunMeta :one +SELECT + jr."workflowRunId" AS "workflowRunId", + sr."retryCount" AS "retryCount", + s."retries" as "retries" +FROM "StepRun" sr +JOIN "Step" s ON sr."stepId" = s."id" +JOIN "JobRun" jr ON sr."jobRunId" = jr."id" +WHERE sr."id" = @stepRunId::uuid +AND sr."tenantId" = @tenantId::uuid; + -- name: GetStepRunForEngine :many SELECT DISTINCT ON (sr."id") diff --git a/pkg/repository/prisma/dbsqlc/step_runs.sql.go b/pkg/repository/prisma/dbsqlc/step_runs.sql.go index 7abf3ba8f..c2e113e97 100644 --- a/pkg/repository/prisma/dbsqlc/step_runs.sql.go +++ b/pkg/repository/prisma/dbsqlc/step_runs.sql.go @@ -1002,6 +1002,36 @@ func (q *Queries) GetStepRunForEngine(ctx context.Context, db DBTX, arg GetStepR return items, nil } +const getStepRunMeta = `-- name: GetStepRunMeta :one +SELECT + jr."workflowRunId" AS "workflowRunId", + sr."retryCount" AS "retryCount", + s."retries" as "retries" +FROM "StepRun" sr +JOIN "Step" s ON sr."stepId" = s."id" +JOIN "JobRun" jr ON sr."jobRunId" = jr."id" +WHERE sr."id" = $1::uuid +AND sr."tenantId" = $2::uuid +` + +type GetStepRunMetaParams struct { + Steprunid pgtype.UUID `json:"steprunid"` + Tenantid pgtype.UUID `json:"tenantid"` +} + +type GetStepRunMetaRow struct { + WorkflowRunId pgtype.UUID `json:"workflowRunId"` + RetryCount int32 `json:"retryCount"` + Retries int32 `json:"retries"` +} + +func (q *Queries) GetStepRunMeta(ctx context.Context, db DBTX, arg GetStepRunMetaParams) (*GetStepRunMetaRow, error) { + row := db.QueryRow(ctx, getStepRunMeta, arg.Steprunid, arg.Tenantid) + var i GetStepRunMetaRow + err := row.Scan(&i.WorkflowRunId, &i.RetryCount, &i.Retries) + return &i, err +} + const listNonFinalChildStepRuns = `-- name: ListNonFinalChildStepRuns :many WITH RECURSIVE currStepRun AS ( SELECT id, "createdAt", "updatedAt", "deletedAt", "tenantId", "jobRunId", "stepId", "order", "workerId", "tickerId", status, input, output, "requeueAfter", "scheduleTimeoutAt", error, "startedAt", "finishedAt", "timeoutAt", "cancelledAt", "cancelledReason", "cancelledError", "inputSchema", "callerFiles", "gitRepoBranch", "retryCount", "semaphoreReleased" diff --git a/pkg/repository/prisma/dbsqlc/stream_event.sql b/pkg/repository/prisma/dbsqlc/stream_event.sql index b61f83ff6..0295922c5 100644 --- 
a/pkg/repository/prisma/dbsqlc/stream_event.sql +++ b/pkg/repository/prisma/dbsqlc/stream_event.sql @@ -1,3 +1,14 @@ +-- name: GetStreamEventMeta :one +SELECT + jr."workflowRunId" AS "workflowRunId", + sr."retryCount" AS "retryCount", + s."retries" as "retries" +FROM "StepRun" sr +JOIN "Step" s ON sr."stepId" = s."id" +JOIN "JobRun" jr ON sr."jobRunId" = jr."id" +WHERE sr."id" = @stepRunId::uuid +AND sr."tenantId" = @tenantId::uuid; + -- name: CreateStreamEvent :one INSERT INTO "StreamEvent" ( "createdAt", diff --git a/pkg/repository/prisma/dbsqlc/stream_event.sql.go b/pkg/repository/prisma/dbsqlc/stream_event.sql.go index 9b1a2597c..ec73bcf46 100644 --- a/pkg/repository/prisma/dbsqlc/stream_event.sql.go +++ b/pkg/repository/prisma/dbsqlc/stream_event.sql.go @@ -96,3 +96,33 @@ func (q *Queries) GetStreamEvent(ctx context.Context, db DBTX, arg GetStreamEven ) return &i, err } + +const getStreamEventMeta = `-- name: GetStreamEventMeta :one +SELECT + jr."workflowRunId" AS "workflowRunId", + sr."retryCount" AS "retryCount", + s."retries" as "retries" +FROM "StepRun" sr +JOIN "Step" s ON sr."stepId" = s."id" +JOIN "JobRun" jr ON sr."jobRunId" = jr."id" +WHERE sr."id" = $1::uuid +AND sr."tenantId" = $2::uuid +` + +type GetStreamEventMetaParams struct { + Steprunid pgtype.UUID `json:"steprunid"` + Tenantid pgtype.UUID `json:"tenantid"` +} + +type GetStreamEventMetaRow struct { + WorkflowRunId pgtype.UUID `json:"workflowRunId"` + RetryCount int32 `json:"retryCount"` + Retries int32 `json:"retries"` +} + +func (q *Queries) GetStreamEventMeta(ctx context.Context, db DBTX, arg GetStreamEventMetaParams) (*GetStreamEventMetaRow, error) { + row := db.QueryRow(ctx, getStreamEventMeta, arg.Steprunid, arg.Tenantid) + var i GetStreamEventMetaRow + err := row.Scan(&i.WorkflowRunId, &i.RetryCount, &i.Retries) + return &i, err +} diff --git a/pkg/repository/prisma/dbsqlc/tickers.sql b/pkg/repository/prisma/dbsqlc/tickers.sql index 8cf1fdf41..220da17a2 100644 --- a/pkg/repository/prisma/dbsqlc/tickers.sql +++ b/pkg/repository/prisma/dbsqlc/tickers.sql @@ -67,41 +67,6 @@ WHERE "id" = sqlc.arg('id')::uuid RETURNING *; --- name: PollStepRuns :many -WITH inactiveTickers AS ( - SELECT "id" - FROM "Ticker" - WHERE - "isActive" = false OR - "lastHeartbeatAt" < NOW() - INTERVAL '10 seconds' -), -stepRunsToTimeout AS ( - SELECT - stepRun."id" - FROM - "StepRun" as stepRun - LEFT JOIN inactiveTickers ON stepRun."tickerId" = inactiveTickers."id" - WHERE - ("status" = 'RUNNING' OR "status" = 'ASSIGNED') - AND "deletedAt" IS NULL - AND "timeoutAt" < NOW() - AND ( - inactiveTickers."id" IS NOT NULL - OR "tickerId" IS NULL - ) - LIMIT 1000 - FOR UPDATE SKIP LOCKED -) -UPDATE - "StepRun" as stepRuns -SET - "tickerId" = @tickerId::uuid -FROM - stepRunsToTimeout -WHERE - stepRuns."id" = stepRunsToTimeout."id" -RETURNING stepRuns.*; - -- name: PollGetGroupKeyRuns :many WITH getGroupKeyRunsToTimeout AS ( SELECT diff --git a/pkg/repository/prisma/dbsqlc/tickers.sql.go b/pkg/repository/prisma/dbsqlc/tickers.sql.go index 41575f1ed..d91b4b9d6 100644 --- a/pkg/repository/prisma/dbsqlc/tickers.sql.go +++ b/pkg/repository/prisma/dbsqlc/tickers.sql.go @@ -503,90 +503,6 @@ func (q *Queries) PollScheduledWorkflows(ctx context.Context, db DBTX, tickerid return items, nil } -const pollStepRuns = `-- name: PollStepRuns :many -WITH inactiveTickers AS ( - SELECT "id" - FROM "Ticker" - WHERE - "isActive" = false OR - "lastHeartbeatAt" < NOW() - INTERVAL '10 seconds' -), -stepRunsToTimeout AS ( - SELECT - stepRun."id" - FROM - "StepRun" as stepRun 
-    LEFT JOIN inactiveTickers ON stepRun."tickerId" = inactiveTickers."id"
-    WHERE
-        ("status" = 'RUNNING' OR "status" = 'ASSIGNED')
-        AND "deletedAt" IS NULL
-        AND "timeoutAt" < NOW()
-        AND (
-            inactiveTickers."id" IS NOT NULL
-            OR "tickerId" IS NULL
-        )
-    LIMIT 1000
-    FOR UPDATE SKIP LOCKED
-)
-UPDATE
-    "StepRun" as stepRuns
-SET
-    "tickerId" = $1::uuid
-FROM
-    stepRunsToTimeout
-WHERE
-    stepRuns."id" = stepRunsToTimeout."id"
-RETURNING stepruns.id, stepruns."createdAt", stepruns."updatedAt", stepruns."deletedAt", stepruns."tenantId", stepruns."jobRunId", stepruns."stepId", stepruns."order", stepruns."workerId", stepruns."tickerId", stepruns.status, stepruns.input, stepruns.output, stepruns."requeueAfter", stepruns."scheduleTimeoutAt", stepruns.error, stepruns."startedAt", stepruns."finishedAt", stepruns."timeoutAt", stepruns."cancelledAt", stepruns."cancelledReason", stepruns."cancelledError", stepruns."inputSchema", stepruns."callerFiles", stepruns."gitRepoBranch", stepruns."retryCount", stepruns."semaphoreReleased"
-`
-
-func (q *Queries) PollStepRuns(ctx context.Context, db DBTX, tickerid pgtype.UUID) ([]*StepRun, error) {
-	rows, err := db.Query(ctx, pollStepRuns, tickerid)
-	if err != nil {
-		return nil, err
-	}
-	defer rows.Close()
-	var items []*StepRun
-	for rows.Next() {
-		var i StepRun
-		if err := rows.Scan(
-			&i.ID,
-			&i.CreatedAt,
-			&i.UpdatedAt,
-			&i.DeletedAt,
-			&i.TenantId,
-			&i.JobRunId,
-			&i.StepId,
-			&i.Order,
-			&i.WorkerId,
-			&i.TickerId,
-			&i.Status,
-			&i.Input,
-			&i.Output,
-			&i.RequeueAfter,
-			&i.ScheduleTimeoutAt,
-			&i.Error,
-			&i.StartedAt,
-			&i.FinishedAt,
-			&i.TimeoutAt,
-			&i.CancelledAt,
-			&i.CancelledReason,
-			&i.CancelledError,
-			&i.InputSchema,
-			&i.CallerFiles,
-			&i.GitRepoBranch,
-			&i.RetryCount,
-			&i.SemaphoreReleased,
-		); err != nil {
-			return nil, err
-		}
-		items = append(items, &i)
-	}
-	if err := rows.Err(); err != nil {
-		return nil, err
-	}
-	return items, nil
-}
-
 const pollTenantAlerts = `-- name: PollTenantAlerts :many
 WITH active_tenant_alerts AS (
     SELECT
diff --git a/pkg/repository/prisma/dbsqlc/workflow_runs.sql b/pkg/repository/prisma/dbsqlc/workflow_runs.sql
index 19c5757bb..0b43e7dfe 100644
--- a/pkg/repository/prisma/dbsqlc/workflow_runs.sql
+++ b/pkg/repository/prisma/dbsqlc/workflow_runs.sql
@@ -492,6 +492,16 @@ FROM workflow_version
 WHERE workflow_version."sticky" IS NOT NULL
 RETURNING *;
 
+-- name: GetWorkflowRunAdditionalMeta :one
+SELECT
+    "additionalMetadata",
+    "id"
+FROM
+    "WorkflowRun"
+WHERE
+    "id" = @workflowRunId::uuid AND
+    "tenantId" = @tenantId::uuid;
+
 -- name: GetWorkflowRunStickyStateForUpdate :one
 SELECT
     *
diff --git a/pkg/repository/prisma/dbsqlc/workflow_runs.sql.go b/pkg/repository/prisma/dbsqlc/workflow_runs.sql.go
index 0beb051d2..8822f4743 100644
--- a/pkg/repository/prisma/dbsqlc/workflow_runs.sql.go
+++ b/pkg/repository/prisma/dbsqlc/workflow_runs.sql.go
@@ -824,6 +824,34 @@ func (q *Queries) GetWorkflowRun(ctx context.Context, db DBTX, arg GetWorkflowRu
 	return items, nil
 }
+
+const getWorkflowRunAdditionalMeta = `-- name: GetWorkflowRunAdditionalMeta :one
+SELECT
+    "additionalMetadata",
+    "id"
+FROM
+    "WorkflowRun"
+WHERE
+    "id" = $1::uuid AND
+    "tenantId" = $2::uuid
+`
+
+type GetWorkflowRunAdditionalMetaParams struct {
+	Workflowrunid pgtype.UUID `json:"workflowrunid"`
+	Tenantid      pgtype.UUID `json:"tenantid"`
+}
+
+type GetWorkflowRunAdditionalMetaRow struct {
+	AdditionalMetadata []byte      `json:"additionalMetadata"`
+	ID                 pgtype.UUID `json:"id"`
+}
+
+func (q *Queries) GetWorkflowRunAdditionalMeta(ctx context.Context, db DBTX, arg GetWorkflowRunAdditionalMetaParams) (*GetWorkflowRunAdditionalMetaRow, error) {
+	row := db.QueryRow(ctx, getWorkflowRunAdditionalMeta, arg.Workflowrunid, arg.Tenantid)
+	var i GetWorkflowRunAdditionalMetaRow
+	err := row.Scan(&i.AdditionalMetadata, &i.ID)
+	return &i, err
+}
+
 const getWorkflowRunInput = `-- name: GetWorkflowRunInput :one
 SELECT jld."data" AS lookupData
 FROM "JobRun" jr
diff --git a/pkg/repository/prisma/step_run.go b/pkg/repository/prisma/step_run.go
index e672b2ee4..cd9ff8da3 100644
--- a/pkg/repository/prisma/step_run.go
+++ b/pkg/repository/prisma/step_run.go
@@ -213,6 +213,13 @@ func NewStepRunEngineRepository(pool *pgxpool.Pool, v validator.Validator, l *ze
 	}
 }
 
+func (s *stepRunEngineRepository) GetStepRunMetaForEngine(ctx context.Context, tenantId, stepRunId string) (*dbsqlc.GetStepRunMetaRow, error) {
+	return s.queries.GetStepRunMeta(ctx, s.pool, dbsqlc.GetStepRunMetaParams{
+		Steprunid: sqlchelpers.UUIDFromStr(stepRunId),
+		Tenantid:  sqlchelpers.UUIDFromStr(tenantId),
+	})
+}
+
 func (s *stepRunEngineRepository) ListRunningStepRunsForTicker(ctx context.Context, tickerId string) ([]*dbsqlc.GetStepRunForEngineRow, error) {
 	tx, err := s.pool.Begin(ctx)
diff --git a/pkg/repository/prisma/stream_event.go b/pkg/repository/prisma/stream_event.go
index 3013b202b..77eb8f1fd 100644
--- a/pkg/repository/prisma/stream_event.go
+++ b/pkg/repository/prisma/stream_event.go
@@ -33,6 +33,13 @@ func NewStreamEventsEngineRepository(pool *pgxpool.Pool, v validator.Validator,
 	}
 }
 
+func (r *streamEventEngineRepository) GetStreamEventMeta(ctx context.Context, tenantId string, stepRunId string) (*dbsqlc.GetStreamEventMetaRow, error) {
+	return r.queries.GetStreamEventMeta(ctx, r.pool, dbsqlc.GetStreamEventMetaParams{
+		Steprunid: sqlchelpers.UUIDFromStr(stepRunId),
+		Tenantid:  sqlchelpers.UUIDFromStr(tenantId),
+	})
+}
+
 func (r *streamEventEngineRepository) PutStreamEvent(ctx context.Context, tenantId string, opts *repository.CreateStreamEventOpts) (*dbsqlc.StreamEvent, error) {
 	if err := r.v.Validate(opts); err != nil {
 		return nil, err
diff --git a/pkg/repository/prisma/ticker.go b/pkg/repository/prisma/ticker.go
index 67f301e3e..d4ae248c3 100644
--- a/pkg/repository/prisma/ticker.go
+++ b/pkg/repository/prisma/ticker.go
@@ -85,10 +85,6 @@ func (t *tickerRepository) Delete(ctx context.Context, tickerId string) error {
 	return err
 }
 
-func (t *tickerRepository) PollStepRuns(ctx context.Context, tickerId string) ([]*dbsqlc.StepRun, error) {
-	return t.queries.PollStepRuns(ctx, t.pool, sqlchelpers.UUIDFromStr(tickerId))
-}
-
 func (t *tickerRepository) PollGetGroupKeyRuns(ctx context.Context, tickerId string) ([]*dbsqlc.GetGroupKeyRun, error) {
 	return t.queries.PollGetGroupKeyRuns(ctx, t.pool, sqlchelpers.UUIDFromStr(tickerId))
 }
diff --git a/pkg/repository/prisma/workflow_run.go b/pkg/repository/prisma/workflow_run.go
index 99be6163e..f8b97e26f 100644
--- a/pkg/repository/prisma/workflow_run.go
+++ b/pkg/repository/prisma/workflow_run.go
@@ -184,6 +184,13 @@ func (w *workflowRunEngineRepository) GetWorkflowRunById(ctx context.Context, te
 	return runs[0], nil
 }
 
+func (w *workflowRunEngineRepository) GetWorkflowRunAdditionalMeta(ctx context.Context, tenantId, workflowRunId string) (*dbsqlc.GetWorkflowRunAdditionalMetaRow, error) {
+	return w.queries.GetWorkflowRunAdditionalMeta(ctx, w.pool, dbsqlc.GetWorkflowRunAdditionalMetaParams{
+		Tenantid:      sqlchelpers.UUIDFromStr(tenantId),
+		Workflowrunid: sqlchelpers.UUIDFromStr(workflowRunId),
+	})
+}
+
 func (w *workflowRunEngineRepository) ListWorkflowRuns(ctx context.Context, tenantId string, opts *repository.ListWorkflowRunsOpts) (*repository.ListWorkflowRunsResult, error) {
 	if err := w.v.Validate(opts); err != nil {
 		return nil, err
diff --git a/pkg/repository/step_run.go b/pkg/repository/step_run.go
index 914f61624..6585075cb 100644
--- a/pkg/repository/step_run.go
+++ b/pkg/repository/step_run.go
@@ -31,7 +31,9 @@ func IsFinalJobRunStatus(status dbsqlc.JobRunStatus) bool {
 }
 
 func IsFinalWorkflowRunStatus(status dbsqlc.WorkflowRunStatus) bool {
-	return status != dbsqlc.WorkflowRunStatusPENDING && status != dbsqlc.WorkflowRunStatusRUNNING && status != dbsqlc.WorkflowRunStatusQUEUED
+	return status != dbsqlc.WorkflowRunStatusPENDING &&
+		status != dbsqlc.WorkflowRunStatusRUNNING &&
+		status != dbsqlc.WorkflowRunStatusQUEUED
 }
 
 type CreateStepRunEventOpts struct {
@@ -185,6 +187,8 @@ type StepRunEngineRepository interface {
 
 	GetStepRunDataForEngine(ctx context.Context, tenantId, stepRunId string) (*dbsqlc.GetStepRunDataForEngineRow, error)
 
+	GetStepRunMetaForEngine(ctx context.Context, tenantId, stepRunId string) (*dbsqlc.GetStepRunMetaRow, error)
+
 	// QueueStepRun is like UpdateStepRun, except that it will only update the step run if it is in
 	// a pending state.
 	QueueStepRun(ctx context.Context, tenantId, stepRunId string, opts *UpdateStepRunOpts) (*dbsqlc.GetStepRunForEngineRow, error)
diff --git a/pkg/repository/stream_event.go b/pkg/repository/stream_event.go
index 0684f0a8d..2b3e7aff6 100644
--- a/pkg/repository/stream_event.go
+++ b/pkg/repository/stream_event.go
@@ -30,4 +30,7 @@ type StreamEventsEngineRepository interface {
 
 	// CleanupStreamEvents deletes all stale StreamEvents.
 	CleanupStreamEvents(ctx context.Context) error
+
+	// GetStreamEventMeta returns the workflow run id and retry counts for the step run that produced a stream event.
+	GetStreamEventMeta(ctx context.Context, tenantId string, stepRunId string) (*dbsqlc.GetStreamEventMetaRow, error)
 }
diff --git a/pkg/repository/ticker.go b/pkg/repository/ticker.go
index 72b2414a0..181edf347 100644
--- a/pkg/repository/ticker.go
+++ b/pkg/repository/ticker.go
@@ -35,9 +35,6 @@ type TickerEngineRepository interface {
 	// Delete deletes a ticker.
 	Delete(ctx context.Context, tickerId string) error
 
-	// PollStepRuns looks for step runs who are close to past their timeoutAt value and are in a running state
-	PollStepRuns(ctx context.Context, tickerId string) ([]*dbsqlc.StepRun, error)
-
 	// PollJobRuns looks for get group key runs who are close to past their timeoutAt value and are in a running state
 	PollGetGroupKeyRuns(ctx context.Context, tickerId string) ([]*dbsqlc.GetGroupKeyRun, error)
diff --git a/pkg/repository/workflow_run.go b/pkg/repository/workflow_run.go
index 1ff63e233..d04131156 100644
--- a/pkg/repository/workflow_run.go
+++ b/pkg/repository/workflow_run.go
@@ -411,6 +411,8 @@ type WorkflowRunEngineRepository interface {
 	// GetWorkflowRunById returns a workflow run by id.
 	GetWorkflowRunById(ctx context.Context, tenantId, runId string) (*dbsqlc.GetWorkflowRunRow, error)
 
+	GetWorkflowRunAdditionalMeta(ctx context.Context, tenantId, workflowRunId string) (*dbsqlc.GetWorkflowRunAdditionalMetaRow, error)
+
 	ReplayWorkflowRun(ctx context.Context, tenantId, workflowRunId string) (*dbsqlc.GetWorkflowRunRow, error)
 
 	ListActiveQueuedWorkflowVersions(ctx context.Context) ([]*dbsqlc.ListActiveQueuedWorkflowVersionsRow, error)