package dispatcher

import (
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strconv"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/rs/zerolog"
	telemetry_codes "go.opentelemetry.io/otel/codes"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/hatchet-dev/hatchet/internal/datautils"
	"github.com/hatchet-dev/hatchet/internal/msgqueue"
	"github.com/hatchet-dev/hatchet/internal/services/dispatcher/contracts"
	"github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes"
	"github.com/hatchet-dev/hatchet/pkg/repository"
	"github.com/hatchet-dev/hatchet/pkg/repository/cache"
	"github.com/hatchet-dev/hatchet/pkg/repository/metered"
	"github.com/hatchet-dev/hatchet/pkg/repository/postgres/dbsqlc"
	"github.com/hatchet-dev/hatchet/pkg/repository/postgres/sqlchelpers"
	"github.com/hatchet-dev/hatchet/pkg/telemetry"
)

func (s *DispatcherImpl) Register(ctx context.Context, request *contracts.WorkerRegisterRequest) (*contracts.WorkerRegisterResponse, error) {
	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	s.l.Debug().Msgf("Received register request from worker %s with actions %v", request.WorkerName, request.Actions)

	svcs := request.Services

	if len(svcs) == 0 {
		svcs = []string{"default"}
	}

	opts := &repository.CreateWorkerOpts{
		DispatcherId: s.dispatcherId,
		Name:         request.WorkerName,
		Actions:      request.Actions,
		Services:     svcs,
		WebhookId:    request.WebhookId,
	}

	if request.RuntimeInfo != nil {
		opts.RuntimeInfo = &repository.RuntimeInfo{
			SdkVersion:      request.RuntimeInfo.SdkVersion,
			Language:        request.RuntimeInfo.Language,
			LanguageVersion: request.RuntimeInfo.LanguageVersion,
			Os:              request.RuntimeInfo.Os,
			Extra:           request.RuntimeInfo.Extra,
		}
	}

	if request.MaxRuns != nil {
		mr := int(*request.MaxRuns)
		opts.MaxRuns = &mr
	}

	if apiErrors, err := s.v.ValidateAPI(opts); err != nil {
		return nil, err
	} else if apiErrors != nil {
		return nil, status.Errorf(codes.InvalidArgument, "Invalid request: %s", apiErrors.String())
	}

	// create a worker in the database
	worker, err := s.repo.Worker().CreateNewWorker(ctx, tenantId, opts)

	if err == metered.ErrResourceExhausted {
		return nil, status.Errorf(codes.ResourceExhausted, "resource exhausted: tenant worker limit or concurrency limit exceeded")
	}

	if err != nil {
		s.l.Error().Err(err).Msgf("could not create worker for tenant %s", tenantId)
		return nil, err
	}

	workerId := sqlchelpers.UUIDToStr(worker.ID)

	if request.Labels != nil {
		_, err = s.upsertLabels(ctx, worker.ID, request.Labels)

		if err != nil {
			return nil, err
		}
	}

	// return the worker id to the worker
	return &contracts.WorkerRegisterResponse{
		TenantId:   tenantId,
		WorkerId:   workerId,
		WorkerName: worker.Name,
	}, nil
}

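// A minimal client-side sketch of this RPC (hypothetical values; the request
// and response types are the contracts handled above):
//
//	resp, err := dispatcherClient.Register(ctx, &contracts.WorkerRegisterRequest{
//		WorkerName: "worker-1",
//		Actions:    []string{"default:step-one"},
//	})
//	// resp.WorkerId is then passed to Listen/ListenV2 and Heartbeat calls.
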
func (s *DispatcherImpl) UpsertWorkerLabels(ctx context.Context, request *contracts.UpsertWorkerLabelsRequest) (*contracts.UpsertWorkerLabelsResponse, error) {
	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)

	_, err := s.upsertLabels(ctx, sqlchelpers.UUIDFromStr(request.WorkerId), request.Labels)

	if err != nil {
		return nil, err
	}

	return &contracts.UpsertWorkerLabelsResponse{
		TenantId: sqlchelpers.UUIDToStr(tenant.ID),
		WorkerId: request.WorkerId,
	}, nil
}

func (s *DispatcherImpl) upsertLabels(ctx context.Context, workerId pgtype.UUID, request map[string]*contracts.WorkerLabels) ([]*dbsqlc.WorkerLabel, error) {
	affinities := make([]repository.UpsertWorkerLabelOpts, 0, len(request))

	for key, config := range request {
		err := s.v.Validate(config)

		if err != nil {
			return nil, status.Errorf(codes.InvalidArgument, "Invalid affinity config: %s", err.Error())
		}

		affinities = append(affinities, repository.UpsertWorkerLabelOpts{
			Key:      key,
			IntValue: config.IntValue,
			StrValue: config.StrValue,
		})
	}

	res, err := s.repo.Worker().UpsertWorkerLabels(ctx, workerId, affinities)

	if err != nil {
		s.l.Error().Err(err).Msgf("could not upsert worker affinities for worker %s", sqlchelpers.UUIDToStr(workerId))
		return nil, err
	}

	return res, nil
}

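// A sketch of the label map this helper consumes (hypothetical key and value;
// each label carries an optional string and/or int value):
//
//	region := "us-east-1"
//	labels := map[string]*contracts.WorkerLabels{
//		"region": {StrValue: &region},
//	}
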
// Listen handles a listen (subscribe) request from a worker and keeps the
// stream open for dispatching assigned actions back to it.
func (s *DispatcherImpl) Listen(request *contracts.WorkerListenRequest, stream contracts.Dispatcher_ListenServer) error {
	tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)
	sessionId := uuid.New().String()

	s.l.Debug().Msgf("Received subscribe request from ID: %s", request.WorkerId)

	ctx := stream.Context()

	worker, err := s.repo.Worker().GetWorkerForEngine(ctx, tenantId, request.WorkerId)

	if err != nil {
		s.l.Error().Err(err).Msgf("could not get worker %s", request.WorkerId)
		return err
	}

	shouldUpdateDispatcherId := !worker.DispatcherId.Valid || sqlchelpers.UUIDToStr(worker.DispatcherId) != s.dispatcherId

	// check the worker's dispatcher against the current dispatcher. if they don't match, then update the worker
	if shouldUpdateDispatcherId {
		_, err = s.repo.Worker().UpdateWorker(ctx, tenantId, request.WorkerId, &repository.UpdateWorkerOpts{
			DispatcherId: &s.dispatcherId,
		})

		if err != nil {
			if errors.Is(err, pgx.ErrNoRows) {
				return nil
			}

			s.l.Error().Err(err).Msgf("could not update worker %s dispatcher", request.WorkerId)
			return err
		}
	}

	fin := make(chan bool)

	s.workers.Add(request.WorkerId, sessionId, newSubscribedWorker(stream, fin, request.WorkerId, 20))

	defer func() {
		// non-blocking send
		select {
		case fin <- true:
		default:
		}

		s.workers.DeleteForSession(request.WorkerId, sessionId)
	}()

	// update the worker's last heartbeat time roughly every 4 seconds (checked on a 100ms ticker) as long as the worker is connected
	go func() {
		timer := time.NewTicker(100 * time.Millisecond)

		// set the last heartbeat to 6 seconds ago so the first heartbeat is sent immediately
		lastHeartbeat := time.Now().UTC().Add(-6 * time.Second)
		defer timer.Stop()

		for {
			select {
			case <-ctx.Done():
				s.l.Debug().Msgf("worker id %s has disconnected", request.WorkerId)
				return
			case <-fin:
				s.l.Debug().Msgf("closing stream for worker id: %s", request.WorkerId)
				return
			case <-timer.C:
				if now := time.Now().UTC(); lastHeartbeat.Add(4 * time.Second).Before(now) {
					s.l.Debug().Msgf("updating worker %s heartbeat", request.WorkerId)

					_, err := s.repo.Worker().UpdateWorker(ctx, tenantId, request.WorkerId, &repository.UpdateWorkerOpts{
						LastHeartbeatAt: &now,
						IsActive:        repository.BoolPtr(true),
					})

					if err != nil {
						if errors.Is(err, pgx.ErrNoRows) {
							return
						}

						s.l.Error().Err(err).Msgf("could not update worker %s heartbeat", request.WorkerId)
						return
					}

					lastHeartbeat = time.Now().UTC()
				}
			}
		}
	}()

	// Keep the connection alive for sending messages
	for {
		select {
		case <-fin:
			s.l.Debug().Msgf("closing stream for worker id: %s", request.WorkerId)
			return nil
		case <-ctx.Done():
			s.l.Debug().Msgf("worker id %s has disconnected", request.WorkerId)
			return nil
		}
	}
}

// ListenV2 is like Listen, but its implementation does not include heartbeats. It should only be used by SDKs
// running against engine version v0.18.1+.
func (s *DispatcherImpl) ListenV2(request *contracts.WorkerListenRequest, stream contracts.Dispatcher_ListenV2Server) error {
	tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)
	sessionId := uuid.New().String()

	ctx := stream.Context()

	s.l.Debug().Msgf("Received subscribe request from ID: %s", request.WorkerId)

	worker, err := s.repo.Worker().GetWorkerForEngine(ctx, tenantId, request.WorkerId)

	if err != nil {
		s.l.Error().Err(err).Msgf("could not get worker %s", request.WorkerId)
		return err
	}

	shouldUpdateDispatcherId := !worker.DispatcherId.Valid || sqlchelpers.UUIDToStr(worker.DispatcherId) != s.dispatcherId

	// check the worker's dispatcher against the current dispatcher. if they don't match, then update the worker
	if shouldUpdateDispatcherId {
		_, err = s.repo.Worker().UpdateWorker(ctx, tenantId, request.WorkerId, &repository.UpdateWorkerOpts{
			DispatcherId: &s.dispatcherId,
		})

		if err != nil {
			if errors.Is(err, pgx.ErrNoRows) {
				return nil
			}

			s.l.Error().Err(err).Msgf("could not update worker %s dispatcher", request.WorkerId)
			return err
		}
	}

	sessionEstablished := time.Now().UTC()

	_, err = s.repo.Worker().UpdateWorkerActiveStatus(ctx, tenantId, request.WorkerId, true, sessionEstablished)

	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil
		}

		lastSessionEstablished := "NULL"

		if worker.LastListenerEstablished.Valid {
			lastSessionEstablished = worker.LastListenerEstablished.Time.String()
		}

		s.l.Error().Err(err).Msgf("could not update worker %s active status to true (session established %s, last session established %s)", request.WorkerId, sessionEstablished.String(), lastSessionEstablished)
		return err
	}

	fin := make(chan bool)

	s.workers.Add(request.WorkerId, sessionId, newSubscribedWorker(stream, fin, request.WorkerId, s.defaultMaxWorkerBacklogSize))

	defer func() {
		// non-blocking send
		select {
		case fin <- true:
		default:
		}

		s.workers.DeleteForSession(request.WorkerId, sessionId)
	}()

	// Keep the connection alive for sending messages
	for {
		select {
		case <-fin:
			s.l.Debug().Msgf("closing stream for worker id: %s", request.WorkerId)

			_, err = s.repo.Worker().UpdateWorkerActiveStatus(ctx, tenantId, request.WorkerId, false, sessionEstablished)

			if err != nil && !errors.Is(err, pgx.ErrNoRows) {
				s.l.Error().Err(err).Msgf("could not update worker %s active status to false due to worker stream closing (session established %s)", request.WorkerId, sessionEstablished.String())
				return err
			}

			return nil
		case <-ctx.Done():
			s.l.Debug().Msgf("worker id %s has disconnected", request.WorkerId)

			// the stream context is already cancelled, so use a fresh context for the final status update
			ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
			defer cancel()

			_, err = s.repo.Worker().UpdateWorkerActiveStatus(ctx, tenantId, request.WorkerId, false, sessionEstablished)

			if err != nil && !errors.Is(err, pgx.ErrNoRows) {
				s.l.Error().Err(err).Msgf("could not update worker %s active status due to worker disconnecting (session established %s)", request.WorkerId, sessionEstablished.String())
				return err
			}

			return nil
		}
	}
}

const HeartbeatInterval = 4 * time.Second

// Heartbeat is used to update the last heartbeat time for a worker
func (s *DispatcherImpl) Heartbeat(ctx context.Context, req *contracts.HeartbeatRequest) (*contracts.HeartbeatResponse, error) {
	ctx, span := telemetry.NewSpan(ctx, "update-worker-heartbeat")
	defer span.End()

	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	heartbeatAt := time.Now().UTC()

	s.l.Debug().Msgf("Received heartbeat request from ID: %s", req.WorkerId)

	// if the client-reported heartbeat time lags the server clock by more than the expected interval, log a warning
	if req.HeartbeatAt.AsTime().Before(heartbeatAt.Add(-1 * HeartbeatInterval)) {
		s.l.Warn().Msgf("heartbeat from worker %s lags by more than the expected heartbeat interval", req.WorkerId)
	}

	worker, err := s.repo.Worker().GetWorkerForEngine(ctx, tenantId, req.WorkerId)

	if err != nil {
		span.RecordError(err)
		span.SetStatus(telemetry_codes.Error, "could not get worker")

		if errors.Is(err, pgx.ErrNoRows) {
			return nil, status.Errorf(codes.NotFound, "worker not found: %s", req.WorkerId)
		}

		return nil, err
	}

	// if the worker is not active, the listener should reconnect
	if worker.LastListenerEstablished.Valid && !worker.IsActive {
		span.RecordError(err)
		span.SetStatus(telemetry_codes.Error, "worker stream is not active")
		return nil, status.Errorf(codes.FailedPrecondition, "Heartbeat rejected: worker stream is not active: %s", req.WorkerId)
	}

	err = s.repo.Worker().UpdateWorkerHeartbeat(ctx, tenantId, req.WorkerId, heartbeatAt)

	if err != nil {
		span.RecordError(err)
		span.SetStatus(telemetry_codes.Error, "could not update worker heartbeat")

		if errors.Is(err, pgx.ErrNoRows) {
			s.l.Error().Msgf("could not update worker heartbeat: worker %s not found", req.WorkerId)
			return nil, err
		}

		return nil, err
	}

	return &contracts.HeartbeatResponse{}, nil
}

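// Worked example of the staleness check above: with HeartbeatInterval = 4s,
// a request whose HeartbeatAt is 10:00:00 arriving when the server clock
// reads 10:00:05 satisfies 10:00:00 < (10:00:05 - 4s) = 10:00:01, so the
// warning fires; the same request arriving at 10:00:03 does not.
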
func (s *DispatcherImpl) ReleaseSlot(ctx context.Context, req *contracts.ReleaseSlotRequest) (*contracts.ReleaseSlotResponse, error) {
	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)

	switch tenant.Version {
	case dbsqlc.TenantMajorEngineVersionV0:
		return s.releaseSlotV0(ctx, tenant, req)
	case dbsqlc.TenantMajorEngineVersionV1:
		return s.releaseSlotV1(ctx, tenant, req)
	default:
		return nil, status.Errorf(codes.Unimplemented, "ReleaseSlot is not implemented in engine version %s", string(tenant.Version))
	}
}

func (s *DispatcherImpl) releaseSlotV0(ctx context.Context, tenant *dbsqlc.Tenant, req *contracts.ReleaseSlotRequest) (*contracts.ReleaseSlotResponse, error) {
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	if req.StepRunId == "" {
		return nil, fmt.Errorf("step run id is required")
	}

	err := s.repo.StepRun().ReleaseStepRunSemaphore(ctx, tenantId, req.StepRunId, true)

	if err != nil {
		return nil, err
	}

	return &contracts.ReleaseSlotResponse{}, nil
}

func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeToWorkflowEventsRequest, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error {
	tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)

	switch tenant.Version {
	case dbsqlc.TenantMajorEngineVersionV0:
		return s.subscribeToWorkflowEventsV0(request, stream)
	case dbsqlc.TenantMajorEngineVersionV1:
		return s.subscribeToWorkflowEventsV1(request, stream)
	default:
		return status.Errorf(codes.Unimplemented, "SubscribeToWorkflowEvents is not implemented in engine version %s", string(tenant.Version))
	}
}

func (s *DispatcherImpl) subscribeToWorkflowEventsV0(request *contracts.SubscribeToWorkflowEventsRequest, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error {
	if request.WorkflowRunId != nil {
		return s.subscribeToWorkflowEventsByWorkflowRunId(*request.WorkflowRunId, stream)
	} else if request.AdditionalMetaKey != nil && request.AdditionalMetaValue != nil {
		return s.subscribeToWorkflowEventsByAdditionalMeta(*request.AdditionalMetaKey, *request.AdditionalMetaValue, stream)
	}

	return status.Errorf(codes.InvalidArgument, "either workflow run id or additional meta key-value must be provided")
}

// subscribeToWorkflowEventsByAdditionalMeta streams workflow events for all runs matching an additional-metadata key/value pair
func (s *DispatcherImpl) subscribeToWorkflowEventsByAdditionalMeta(key string, value string, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error {
	tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	ctx, cancel := context.WithCancel(stream.Context())
	defer cancel()

	wg := sync.WaitGroup{}

	// Keep track of active workflow run IDs
	activeRunIds := make(map[string]struct{})
	var mu sync.Mutex     // Mutex to protect activeRunIds
	var sendMu sync.Mutex // Mutex to protect sending messages

	f := func(task *msgqueue.Message) error {
		wg.Add(1)
		defer wg.Done()

		e, err := s.tenantTaskToWorkflowEventByAdditionalMeta(
			task, tenantId, key, value,
			func(e *contracts.WorkflowEvent) (bool, error) {
				mu.Lock()
				defer mu.Unlock()

				if e.WorkflowRunId == "" {
					return false, nil
				}

				isWorkflowRunCompletedEvent := e.ResourceType == contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN &&
					e.EventType == contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED

				if !isWorkflowRunCompletedEvent {
					// Add the run ID to active runs
					activeRunIds[e.WorkflowRunId] = struct{}{}
				} else {
					// Remove the completed run from active runs
					delete(activeRunIds, e.WorkflowRunId)
				}

				// Return true to hang up once no active (incomplete) runs remain
				if len(activeRunIds) == 0 {
					return true, nil
				}

				return false, nil
			})

		if err != nil {
			s.l.Error().Err(err).Msgf("could not convert task to workflow event")
			return nil
		} else if e == nil || e.WorkflowRunId == "" {
			return nil
		}

		// send the task to the client
		sendMu.Lock()
		err = stream.Send(e)
		sendMu.Unlock()

		if err != nil {
			cancel() // FIXME is this necessary?
			s.l.Error().Err(err).Msgf("could not send workflow event to client")
			return nil
		}

		if e.Hangup {
			cancel()
		}

		return nil
	}

	// subscribe to the task queue for the tenant
	cleanupQueue, err := s.sharedReader.Subscribe(tenantId, f)

	if err != nil {
		return err
	}

	<-ctx.Done()

	if err := cleanupQueue(); err != nil {
		return fmt.Errorf("could not cleanup queue: %w", err)
	}

	waitFor(&wg, 60*time.Second, s.l)

	return nil
}

// subscribeToWorkflowEventsByWorkflowRunId streams events for a single workflow run to the client
func (s *DispatcherImpl) subscribeToWorkflowEventsByWorkflowRunId(workflowRunId string, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error {
	tenant := stream.Context().Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	s.l.Debug().Msgf("Received subscribe request for workflow: %s", workflowRunId)

	ctx, cancel := context.WithCancel(stream.Context())
	defer cancel()

	// if the workflow run is in a final state, hang up the connection
	workflowRun, err := s.repo.WorkflowRun().GetWorkflowRunById(ctx, tenantId, workflowRunId)

	if err != nil {
		if errors.Is(err, repository.ErrWorkflowRunNotFound) {
			return status.Errorf(codes.NotFound, "workflow run %s not found", workflowRunId)
		}

		return err
	}

	if repository.IsFinalWorkflowRunStatus(workflowRun.WorkflowRun.Status) {
		return nil
	}

	wg := sync.WaitGroup{}

	sendMu := sync.Mutex{}

	f := func(task *msgqueue.Message) error {
		wg.Add(1)
		defer wg.Done()

		e, err := s.tenantTaskToWorkflowEventByRunId(task, tenantId, workflowRunId)

		if err != nil {
			s.l.Error().Err(err).Msgf("could not convert task to workflow event")
			return nil
		} else if e == nil {
			return nil
		}

		// send the task to the client
		sendMu.Lock()
		err = stream.Send(e)
		sendMu.Unlock()

		if err != nil {
			cancel() // FIXME is this necessary?
			s.l.Error().Err(err).Msgf("could not send workflow event to client")
			return nil
		}

		if e.Hangup {
			cancel()
		}

		return nil
	}

	// subscribe to the task queue for the tenant
	cleanupQueue, err := s.sharedReader.Subscribe(tenantId, f)

	if err != nil {
		return err
	}

	<-ctx.Done()

	if err := cleanupQueue(); err != nil {
		return fmt.Errorf("could not cleanup queue: %w", err)
	}

	waitFor(&wg, 60*time.Second, s.l)

	return nil
}

// workflowRunAcks maps workflow run ids to whether a "finished" message has
// been sent (acked) for that run; acked runs are removed from the map.
type workflowRunAcks struct {
	acks map[string]bool
	mu   sync.RWMutex
}

func (w *workflowRunAcks) addWorkflowRun(id string) {
	w.mu.Lock()
	defer w.mu.Unlock()

	w.acks[id] = false
}

func (w *workflowRunAcks) getNonAckdWorkflowRuns() []string {
	w.mu.RLock()
	defer w.mu.RUnlock()

	ids := make([]string, 0, len(w.acks))

	for id := range w.acks {
		if !w.acks[id] {
			ids = append(ids, id)
		}
	}

	return ids
}

func (w *workflowRunAcks) ackWorkflowRun(id string) {
	w.mu.Lock()
	defer w.mu.Unlock()

	delete(w.acks, id)
}

func (w *workflowRunAcks) hasWorkflowRun(id string) bool {
	w.mu.RLock()
	defer w.mu.RUnlock()

	_, ok := w.acks[id]
	return ok
}

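// Lifecycle sketch for the ack tracker above (hypothetical run id):
//
//	acks := &workflowRunAcks{acks: make(map[string]bool)}
//	acks.addWorkflowRun("run-1")  // tracked, not yet acked
//	acks.getNonAckdWorkflowRuns() // -> ["run-1"]
//	acks.ackWorkflowRun("run-1")  // removed from the map
//	acks.hasWorkflowRun("run-1")  // -> false
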
// sendTimeFilter rate-limits sends: canSend returns true at most once per
// ~10ms window by holding the mutex for that long in a background goroutine.
type sendTimeFilter struct {
	mu sync.Mutex
}

func (s *sendTimeFilter) canSend() bool {
	if !s.mu.TryLock() {
		return false
	}

	go func() {
		time.Sleep(10 * time.Millisecond)
		s.mu.Unlock()
	}()

	return true
}

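// Usage sketch: callers treat a false return as "skip this attempt" rather
// than blocking, e.g.
//
//	filter := &sendTimeFilter{}
//	if filter.canSend() {
//		// at most one caller reaches here per ~10ms window
//	}
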
func (d *DispatcherImpl) cleanResults(results []*contracts.StepRunResult) []*contracts.StepRunResult {
	totalSize, sizeOfOutputs, _ := calculateResultsSize(results)

	if totalSize < d.payloadSizeThreshold {
		return results
	}

	if sizeOfOutputs >= d.payloadSizeThreshold {
		return nil
	}

	// otherwise, attempt to clean the results by removing large error fields
	cleanedResults := make([]*contracts.StepRunResult, 0, len(results))

	fieldThreshold := (d.payloadSizeThreshold - sizeOfOutputs) / len(results) // how much overhead we'd have per result or error field, in the worst case

	for _, result := range results {
		if result == nil {
			continue
		}

		// we only try to clean the error field at the moment, as modifying the output is more risky
		if result.Error != nil && len(*result.Error) > fieldThreshold {
			result.Error = repository.StringPtr("Error is too large to send over the Hatchet stream.")
		}

		cleanedResults = append(cleanedResults, result)
	}

	// if we are still over the limit, we just return nil
	if totalSize, _, _ := calculateResultsSize(cleanedResults); totalSize > d.payloadSizeThreshold {
		return nil
	}

	return cleanedResults
}

func calculateResultsSize(results []*contracts.StepRunResult) (totalSize int, sizeOfOutputs int, sizeOfErrors int) {
	for _, result := range results {
		if result != nil && result.Output != nil {
			totalSize += len(*result.Output)
			sizeOfOutputs += len(*result.Output)
		}

		if result != nil && result.Error != nil {
			totalSize += len(*result.Error)
			sizeOfErrors += len(*result.Error)
		}
	}

	return
}

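// Worked example for cleanResults (hypothetical numbers): with a 3MB
// payloadSizeThreshold, 4 results whose outputs total 1MB and whose errors
// total 5MB give totalSize = 6MB >= 3MB, so cleaning kicks in with
// fieldThreshold = (3MB - 1MB) / 4 = 512KB; any error string longer than
// that is replaced by a short placeholder before re-measuring.
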
func (s *DispatcherImpl) SubscribeToWorkflowRuns(server contracts.Dispatcher_SubscribeToWorkflowRunsServer) error {
	tenant := server.Context().Value("tenant").(*dbsqlc.Tenant)

	switch tenant.Version {
	case dbsqlc.TenantMajorEngineVersionV0:
		return s.subscribeToWorkflowRunsV0(server)
	case dbsqlc.TenantMajorEngineVersionV1:
		return s.subscribeToWorkflowRunsV1(server)
	default:
		return status.Errorf(codes.Unimplemented, "SubscribeToWorkflowRuns is not implemented in engine version %s", string(tenant.Version))
	}
}

// subscribeToWorkflowRunsV0 streams workflow run results to the client as subscribed runs reach a final state
func (s *DispatcherImpl) subscribeToWorkflowRunsV0(server contracts.Dispatcher_SubscribeToWorkflowRunsServer) error {
	tenant := server.Context().Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	s.l.Debug().Msgf("Received subscribe request for tenant: %s", tenantId)

	acks := &workflowRunAcks{
		acks: make(map[string]bool),
	}

	ctx, cancel := context.WithCancel(server.Context())
	defer cancel()

	wg := sync.WaitGroup{}
	sendMu := sync.Mutex{}

	sendEvent := func(e *contracts.WorkflowRunEvent) error {
		results := s.cleanResults(e.Results)

		if results == nil {
			s.l.Warn().Msgf("results size for workflow run %s exceeds the payload size threshold and cannot be reduced", e.WorkflowRunId)
			e.Results = nil
		}

		// send the task to the client
		sendMu.Lock()
		err := server.Send(e)
		sendMu.Unlock()

		if err != nil {
			s.l.Error().Err(err).Msgf("could not subscribe to workflow events for run %s", e.WorkflowRunId)
			return err
		}

		acks.ackWorkflowRun(e.WorkflowRunId)

		return nil
	}

	immediateSendFilter := &sendTimeFilter{}
	iterSendFilter := &sendTimeFilter{}

	iter := func(workflowRunIds []string) error {
		limit := 1000

		workflowRuns, err := s.repo.WorkflowRun().ListWorkflowRuns(ctx, tenantId, &repository.ListWorkflowRunsOpts{
			Ids:   workflowRunIds,
			Limit: &limit,
		})

		if err != nil {
			s.l.Error().Err(err).Msg("could not get workflow runs")
			return nil
		}

		events, err := s.toWorkflowRunEvent(tenantId, workflowRuns.Rows)

		if err != nil {
			s.l.Error().Err(err).Msg("could not convert workflow run to event")
			return nil
		} else if events == nil {
			return nil
		}

		for _, event := range events {
			err := sendEvent(event)

			if err != nil {
				return err
			}
		}

		return nil
	}

	// start a new goroutine to handle client-side streaming
	go func() {
		for {
			req, err := server.Recv()

			if err != nil {
				cancel()

				if errors.Is(err, io.EOF) || status.Code(err) == codes.Canceled {
					return
				}

				s.l.Error().Err(err).Msg("could not receive message from client")
				return
			}

			if _, parseErr := uuid.Parse(req.WorkflowRunId); parseErr != nil {
				s.l.Warn().Err(parseErr).Msg("invalid workflow run id")
				continue
			}

			acks.addWorkflowRun(req.WorkflowRunId)

			if immediateSendFilter.canSend() {
				if err := iter([]string{req.WorkflowRunId}); err != nil {
					s.l.Error().Err(err).Msg("could not iterate over workflow runs")
				}
			}
		}
	}()

	f := func(task *msgqueue.Message) error {
		wg.Add(1)
		defer wg.Done()

		workflowRunIds := acks.getNonAckdWorkflowRuns()

		if matchedWorkflowRunId, ok := s.isMatchingWorkflowRun(task, workflowRunIds...); ok {
			if immediateSendFilter.canSend() {
				if err := iter([]string{matchedWorkflowRunId}); err != nil {
					s.l.Error().Err(err).Msg("could not iterate over workflow runs")
				}
			}
		}

		return nil
	}

	// subscribe to the task queue for the tenant
	cleanupQueue, err := s.sharedReader.Subscribe(tenantId, f)

	if err != nil {
		return err
	}

	// new goroutine to poll every second for finished workflow runs which are not acked
	go func() {
		ticker := time.NewTicker(1 * time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				if !iterSendFilter.canSend() {
					continue
				}

				workflowRunIds := acks.getNonAckdWorkflowRuns()

				if len(workflowRunIds) == 0 {
					continue
				}

				if err := iter(workflowRunIds); err != nil {
					s.l.Error().Err(err).Msg("could not iterate over workflow runs")
				}
			}
		}
	}()

	<-ctx.Done()

	if err := cleanupQueue(); err != nil {
		return fmt.Errorf("could not cleanup queue: %w", err)
	}

	waitFor(&wg, 60*time.Second, s.l)

	return nil
}

func waitFor(wg *sync.WaitGroup, timeout time.Duration, l *zerolog.Logger) {
	done := make(chan struct{})

	go func() {
		defer close(done)
		wg.Wait()
	}()

	select {
	case <-done:
	case <-time.After(timeout):
		l.Error().Msg("timed out waiting for wait group")
	}
}

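// Usage sketch: bound the shutdown wait on in-flight queue callbacks, e.g.
//
//	var wg sync.WaitGroup
//	// ... callbacks call wg.Add(1)/wg.Done() ...
//	waitFor(&wg, 60*time.Second, logger) // returns after Wait or 60s, whichever is first
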
func (s *DispatcherImpl) SendStepActionEvent(ctx context.Context, request *contracts.StepActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)

	switch tenant.Version {
	case dbsqlc.TenantMajorEngineVersionV0:
		return s.sendStepActionEventV0(ctx, request)
	case dbsqlc.TenantMajorEngineVersionV1:
		return s.sendStepActionEventV1(ctx, request)
	default:
		return nil, status.Errorf(codes.Unimplemented, "SendStepActionEvent is not implemented in engine version %s", string(tenant.Version))
	}
}

func (s *DispatcherImpl) sendStepActionEventV0(ctx context.Context, request *contracts.StepActionEvent) (*contracts.ActionEventResponse, error) {
	switch request.EventType {
	case contracts.StepActionEventType_STEP_EVENT_TYPE_STARTED:
		return s.handleStepRunStarted(ctx, request)
	case contracts.StepActionEventType_STEP_EVENT_TYPE_ACKNOWLEDGED:
		return s.handleStepRunAcked(ctx, request)
	case contracts.StepActionEventType_STEP_EVENT_TYPE_COMPLETED:
		return s.handleStepRunCompleted(ctx, request)
	case contracts.StepActionEventType_STEP_EVENT_TYPE_FAILED:
		return s.handleStepRunFailed(ctx, request)
	}

	return nil, fmt.Errorf("unknown event type %s", request.EventType)
}

func (s *DispatcherImpl) SendGroupKeyActionEvent(ctx context.Context, request *contracts.GroupKeyActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)

	if tenant.Version == dbsqlc.TenantMajorEngineVersionV1 {
		return nil, status.Errorf(codes.Unimplemented, "SendGroupKeyActionEvent is not implemented in engine version v1")
	}

	switch request.EventType {
	case contracts.GroupKeyActionEventType_GROUP_KEY_EVENT_TYPE_STARTED:
		return s.handleGetGroupKeyRunStarted(ctx, request)
	case contracts.GroupKeyActionEventType_GROUP_KEY_EVENT_TYPE_COMPLETED:
		return s.handleGetGroupKeyRunCompleted(ctx, request)
	case contracts.GroupKeyActionEventType_GROUP_KEY_EVENT_TYPE_FAILED:
		return s.handleGetGroupKeyRunFailed(ctx, request)
	}

	return nil, fmt.Errorf("unknown event type %s", request.EventType)
}

func (s *DispatcherImpl) PutOverridesData(ctx context.Context, request *contracts.OverridesData) (*contracts.OverridesDataResponse, error) {
	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// ensure step run id
	if request.StepRunId == "" {
		return nil, fmt.Errorf("step run id is required")
	}

	opts := &repository.UpdateStepRunOverridesDataOpts{
		OverrideKey: request.Path,
		Data:        []byte(request.Value),
	}

	if request.CallerFilename != "" {
		opts.CallerFile = &request.CallerFilename
	}

	_, err := s.repo.StepRun().UpdateStepRunOverridesData(ctx, tenantId, request.StepRunId, opts)

	if err != nil {
		return nil, err
	}

	return &contracts.OverridesDataResponse{}, nil
}

func (s *DispatcherImpl) Unsubscribe(ctx context.Context, request *contracts.WorkerUnsubscribeRequest) (*contracts.WorkerUnsubscribeResponse, error) {
	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// remove the worker from the connection pool
	s.workers.Delete(request.WorkerId)

	return &contracts.WorkerUnsubscribeResponse{
		TenantId: tenantId,
		WorkerId: request.WorkerId,
	}, nil
}

func (d *DispatcherImpl) RefreshTimeout(ctx context.Context, request *contracts.RefreshTimeoutRequest) (*contracts.RefreshTimeoutResponse, error) {
	tenant := ctx.Value("tenant").(*dbsqlc.Tenant)

	switch tenant.Version {
	case dbsqlc.TenantMajorEngineVersionV0:
		return d.refreshTimeoutV0(ctx, tenant, request)
	case dbsqlc.TenantMajorEngineVersionV1:
		return d.refreshTimeoutV1(ctx, tenant, request)
	default:
		return nil, status.Errorf(codes.Unimplemented, "RefreshTimeout is not implemented in engine version %s", string(tenant.Version))
	}
}

func (d *DispatcherImpl) refreshTimeoutV0(ctx context.Context, tenant *dbsqlc.Tenant, request *contracts.RefreshTimeoutRequest) (*contracts.RefreshTimeoutResponse, error) {
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	opts := repository.RefreshTimeoutBy{
		IncrementTimeoutBy: request.IncrementTimeoutBy,
	}

	if apiErrors, err := d.v.ValidateAPI(opts); err != nil {
		return nil, err
	} else if apiErrors != nil {
		return nil, status.Errorf(codes.InvalidArgument, "Invalid request: %s", apiErrors.String())
	}

	pgTimeoutAt, err := d.repo.StepRun().RefreshTimeoutBy(ctx, tenantId, request.StepRunId, opts)

	if err != nil {
		return nil, err
	}

	timeoutAt := &timestamppb.Timestamp{
		Seconds: pgTimeoutAt.Time.Unix(),
		Nanos:   int32(pgTimeoutAt.Time.Nanosecond()), // nolint:gosec
	}

	return &contracts.RefreshTimeoutResponse{
		TimeoutAt: timeoutAt,
	}, nil
}

func (s *DispatcherImpl) handleStepRunStarted(inputCtx context.Context, request *contracts.StepActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := inputCtx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// run the rest on a separate context to always send to job controller
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	s.l.Debug().Msgf("Received step started event for step run %s", request.StepRunId)

	startedAt := request.EventTimestamp.AsTime()

	sr, err := s.repo.StepRun().GetStepRunForEngine(ctx, tenantId, request.StepRunId)

	if err != nil {
		return nil, err
	}

	err = s.repo.StepRun().StepRunStarted(ctx, tenantId, sqlchelpers.UUIDToStr(sr.WorkflowRunId), request.StepRunId, startedAt)

	if err != nil {
		return nil, fmt.Errorf("could not mark step run started: %w", err)
	}

	payload, _ := datautils.ToJSONMap(tasktypes.StepRunStartedTaskPayload{
		StepRunId:     request.StepRunId,
		StartedAt:     startedAt.Format(time.RFC3339),
		WorkflowRunId: sqlchelpers.UUIDToStr(sr.WorkflowRunId),
		StepRetries:   &sr.StepRetries,
		RetryCount:    &sr.SRRetryCount,
	})

	metadata, _ := datautils.ToJSONMap(tasktypes.StepRunStartedTaskMetadata{
		TenantId: tenantId,
	})

	// we send the event directly to the tenant's event queue
	err = s.mq.AddMessage(ctx, msgqueue.TenantEventConsumerQueue(tenantId), &msgqueue.Message{
		ID:       "step-run-started",
		Payload:  payload,
		Metadata: metadata,
		Retries:  3,
	})

	if err != nil {
		return nil, err
	}

	return &contracts.ActionEventResponse{
		TenantId: tenantId,
		WorkerId: request.WorkerId,
	}, nil
}

func (s *DispatcherImpl) handleStepRunAcked(inputCtx context.Context, request *contracts.StepActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := inputCtx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// run the rest on a separate context to always send to job controller
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	s.l.Debug().Msgf("Received step ack event for step run %s", request.StepRunId)

	startedAt := request.EventTimestamp.AsTime()

	sr, err := s.repo.StepRun().GetStepRunForEngine(ctx, tenantId, request.StepRunId)

	if err != nil {
		return nil, err
	}

	payload, _ := datautils.ToJSONMap(tasktypes.StepRunStartedTaskPayload{
		StepRunId:     request.StepRunId,
		StartedAt:     startedAt.Format(time.RFC3339),
		WorkflowRunId: sqlchelpers.UUIDToStr(sr.WorkflowRunId),
		StepRetries:   &sr.StepRetries,
		RetryCount:    &sr.SRRetryCount,
	})

	metadata, _ := datautils.ToJSONMap(tasktypes.StepRunStartedTaskMetadata{
		TenantId: tenantId,
	})

	// send the event to the jobs queue
	err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
		ID:       "step-run-acked",
		Payload:  payload,
		Metadata: metadata,
		Retries:  3,
	})

	if err != nil {
		return nil, err
	}

	return &contracts.ActionEventResponse{
		TenantId: tenantId,
		WorkerId: request.WorkerId,
	}, nil
}

func (s *DispatcherImpl) handleStepRunCompleted(inputCtx context.Context, request *contracts.StepActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := inputCtx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// run the rest on a separate context to always send to job controller
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	err := s.repo.StepRun().ReleaseStepRunSemaphore(ctx, tenantId, request.StepRunId, false)

	if err != nil {
		s.l.Error().Err(err).Msgf("could not release semaphore for step run %s", request.StepRunId)
		return nil, err
	}

	s.l.Debug().Msgf("Received step completed event for step run %s", request.StepRunId)

	// verify that the event payload can be unmarshalled into a map type
	if request.EventPayload != "" {
		res := make(map[string]interface{})

		if err := json.Unmarshal([]byte(request.EventPayload), &res); err != nil {
			// if the payload starts with a [, then it is an array which we don't currently support
			if request.EventPayload[0] == '[' {
				return nil, status.Errorf(codes.InvalidArgument, "Return value is an array, which is not supported")
			}

			// if the payload starts with a ", then it is a string which we don't currently support
			if request.EventPayload[0] == '"' {
				return nil, status.Errorf(codes.InvalidArgument, "Return value is a string, which is not supported")
			}

			return nil, status.Errorf(codes.InvalidArgument, "Return value is not a valid JSON object")
		}
	}

	finishedAt := request.EventTimestamp.AsTime()

	meta, err := s.repo.StepRun().GetStepRunMetaForEngine(ctx, tenantId, request.StepRunId)

	if err != nil {
		return nil, err
	}

	payload, _ := datautils.ToJSONMap(tasktypes.StepRunFinishedTaskPayload{
		WorkflowRunId:  sqlchelpers.UUIDToStr(meta.WorkflowRunId),
		StepRunId:      request.StepRunId,
		FinishedAt:     finishedAt.Format(time.RFC3339),
		StepOutputData: request.EventPayload,
		StepRetries:    &meta.Retries,
		RetryCount:     &meta.RetryCount,
	})

	metadata, _ := datautils.ToJSONMap(tasktypes.StepRunFinishedTaskMetadata{
		TenantId: tenantId,
	})

	// send the event to the jobs queue
	err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
		ID:       "step-run-finished",
		Payload:  payload,
		Metadata: metadata,
		Retries:  3,
	})

	if err != nil {
		return nil, err
	}

	return &contracts.ActionEventResponse{
		TenantId: tenantId,
		WorkerId: request.WorkerId,
	}, nil
}

func (s *DispatcherImpl) handleStepRunFailed(inputCtx context.Context, request *contracts.StepActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := inputCtx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// run the rest on a separate context to always send to job controller
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	err := s.repo.StepRun().ReleaseStepRunSemaphore(ctx, tenantId, request.StepRunId, false)

	if err != nil {
		s.l.Error().Err(err).Msgf("could not release semaphore for step run %s", request.StepRunId)
		return nil, err
	}

	s.l.Debug().Msgf("Received step failed event for step run %s", request.StepRunId)

	failedAt := request.EventTimestamp.AsTime()

	meta, err := s.repo.StepRun().GetStepRunMetaForEngine(ctx, tenantId, request.StepRunId)

	if err != nil {
		return nil, err
	}

	payload, _ := datautils.ToJSONMap(tasktypes.StepRunFailedTaskPayload{
		WorkflowRunId: sqlchelpers.UUIDToStr(meta.WorkflowRunId),
		StepRunId:     request.StepRunId,
		FailedAt:      failedAt.Format(time.RFC3339),
		Error:         request.EventPayload,
		StepRetries:   &meta.Retries,
		RetryCount:    &meta.RetryCount,
	})

	metadata, _ := datautils.ToJSONMap(tasktypes.StepRunFailedTaskMetadata{
		TenantId: tenantId,
	})

	// send the event to the jobs queue
	err = s.mq.AddMessage(ctx, msgqueue.JOB_PROCESSING_QUEUE, &msgqueue.Message{
		ID:       "step-run-failed",
		Payload:  payload,
		Metadata: metadata,
		Retries:  3,
	})

	if err != nil {
		return nil, err
	}

	return &contracts.ActionEventResponse{
		TenantId: tenantId,
		WorkerId: request.WorkerId,
	}, nil
}

func (s *DispatcherImpl) handleGetGroupKeyRunStarted(inputCtx context.Context, request *contracts.GroupKeyActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := inputCtx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// run the rest on a separate context to always send to job controller
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	s.l.Debug().Msgf("Received group key run started event for run %s", request.GetGroupKeyRunId)

	startedAt := request.EventTimestamp.AsTime()

	payload, _ := datautils.ToJSONMap(tasktypes.GetGroupKeyRunStartedTaskPayload{
		GetGroupKeyRunId: request.GetGroupKeyRunId,
		StartedAt:        startedAt.Format(time.RFC3339),
	})

	metadata, _ := datautils.ToJSONMap(tasktypes.GetGroupKeyRunStartedTaskMetadata{
		TenantId: tenantId,
	})

	// send the event to the workflows queue
	err := s.mq.AddMessage(ctx, msgqueue.WORKFLOW_PROCESSING_QUEUE, &msgqueue.Message{
		ID:       "get-group-key-run-started",
		Payload:  payload,
		Metadata: metadata,
		Retries:  3,
	})

	if err != nil {
		return nil, err
	}

	return &contracts.ActionEventResponse{
		TenantId: tenantId,
		WorkerId: request.WorkerId,
	}, nil
}

func (s *DispatcherImpl) handleGetGroupKeyRunCompleted(inputCtx context.Context, request *contracts.GroupKeyActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := inputCtx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// run the rest on a separate context to always send to job controller
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	s.l.Debug().Msgf("Received group key run completed event for run %s", request.GetGroupKeyRunId)

	finishedAt := request.EventTimestamp.AsTime()

	payload, _ := datautils.ToJSONMap(tasktypes.GetGroupKeyRunFinishedTaskPayload{
		GetGroupKeyRunId: request.GetGroupKeyRunId,
		FinishedAt:       finishedAt.Format(time.RFC3339),
		GroupKey:         request.EventPayload,
	})

	metadata, _ := datautils.ToJSONMap(tasktypes.GetGroupKeyRunFinishedTaskMetadata{
		TenantId: tenantId,
	})

	// send the event to the workflows queue
	err := s.mq.AddMessage(ctx, msgqueue.WORKFLOW_PROCESSING_QUEUE, &msgqueue.Message{
		ID:       "get-group-key-run-finished",
		Payload:  payload,
		Metadata: metadata,
		Retries:  3,
	})

	if err != nil {
		return nil, err
	}

	return &contracts.ActionEventResponse{
		TenantId: tenantId,
		WorkerId: request.WorkerId,
	}, nil
}

func (s *DispatcherImpl) handleGetGroupKeyRunFailed(inputCtx context.Context, request *contracts.GroupKeyActionEvent) (*contracts.ActionEventResponse, error) {
	tenant := inputCtx.Value("tenant").(*dbsqlc.Tenant)
	tenantId := sqlchelpers.UUIDToStr(tenant.ID)

	// run the rest on a separate context to always send to job controller
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	s.l.Debug().Msgf("Received group key run failed event for run %s", request.GetGroupKeyRunId)

	failedAt := request.EventTimestamp.AsTime()

	payload, _ := datautils.ToJSONMap(tasktypes.GetGroupKeyRunFailedTaskPayload{
		GetGroupKeyRunId: request.GetGroupKeyRunId,
		FailedAt:         failedAt.Format(time.RFC3339),
		Error:            request.EventPayload,
	})

	metadata, _ := datautils.ToJSONMap(tasktypes.GetGroupKeyRunFailedTaskMetadata{
		TenantId: tenantId,
	})

	// send the event to the workflows queue
	err := s.mq.AddMessage(ctx, msgqueue.WORKFLOW_PROCESSING_QUEUE, &msgqueue.Message{
		ID:       "get-group-key-run-failed",
		Payload:  payload,
		Metadata: metadata,
		Retries:  3,
	})

	if err != nil {
		return nil, err
	}

	return &contracts.ActionEventResponse{
		TenantId: tenantId,
		WorkerId: request.WorkerId,
	}, nil
}

func UnmarshalPayload[T any](payload interface{}) (T, error) {
	var result T

	// Convert the payload to JSON
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return result, fmt.Errorf("failed to marshal payload: %w", err)
	}

	// Unmarshal JSON into the desired type
	err = json.Unmarshal(jsonData, &result)
	if err != nil {
		return result, fmt.Errorf("failed to unmarshal payload: %w", err)
	}

	return result, nil
}

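// UnmarshalPayload round-trips a loosely-typed payload (such as the
// map[string]interface{} carried by msgqueue.Message) through JSON into a
// concrete task type, as used throughout this file:
//
//	payload, err := UnmarshalPayload[tasktypes.StepRunFailedTaskPayload](task.Payload)
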
func (s *DispatcherImpl) taskToWorkflowEvent(task *msgqueue.Message, tenantId string, filter func(task *contracts.WorkflowEvent) (*bool, error), hangupFunc func(task *contracts.WorkflowEvent) (bool, error)) (*contracts.WorkflowEvent, error) {
	workflowEvent := &contracts.WorkflowEvent{}

	var stepRunId string

	switch task.ID {
	case "step-run-started":
		payload, err := UnmarshalPayload[tasktypes.StepRunStartedTaskPayload](task.Payload)
		if err != nil {
			return nil, err
		}
		workflowEvent.WorkflowRunId = payload.WorkflowRunId
		stepRunId = payload.StepRunId
		workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
		workflowEvent.ResourceId = stepRunId
		workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STARTED
		workflowEvent.StepRetries = payload.StepRetries
		workflowEvent.RetryCount = payload.RetryCount
	case "step-run-finished":
		payload, err := UnmarshalPayload[tasktypes.StepRunFinishedTaskPayload](task.Payload)
		if err != nil {
			return nil, err
		}
		workflowEvent.WorkflowRunId = payload.WorkflowRunId
		stepRunId = payload.StepRunId
		workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
		workflowEvent.ResourceId = stepRunId
		workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED
		workflowEvent.EventPayload = payload.StepOutputData

		workflowEvent.StepRetries = payload.StepRetries
		workflowEvent.RetryCount = payload.RetryCount
	case "step-run-failed":
		payload, err := UnmarshalPayload[tasktypes.StepRunFailedTaskPayload](task.Payload)
		if err != nil {
			return nil, err
		}
		workflowEvent.WorkflowRunId = payload.WorkflowRunId
		stepRunId = payload.StepRunId
		workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
		workflowEvent.ResourceId = stepRunId
		workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_FAILED
		workflowEvent.EventPayload = payload.Error

		workflowEvent.StepRetries = payload.StepRetries
		workflowEvent.RetryCount = payload.RetryCount
	case "step-run-cancelled":
		payload, err := UnmarshalPayload[tasktypes.StepRunCancelledTaskPayload](task.Payload)
		if err != nil {
			return nil, err
		}
		workflowEvent.WorkflowRunId = payload.WorkflowRunId
		stepRunId = payload.StepRunId
		workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
		workflowEvent.ResourceId = stepRunId
		workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_CANCELLED

		workflowEvent.StepRetries = payload.StepRetries
		workflowEvent.RetryCount = payload.RetryCount
	case "step-run-timed-out":
		payload, err := UnmarshalPayload[tasktypes.StepRunTimedOutTaskPayload](task.Payload)
		if err != nil {
			return nil, err
		}
		workflowEvent.WorkflowRunId = payload.WorkflowRunId
		stepRunId = payload.StepRunId
		workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
		workflowEvent.ResourceId = stepRunId
		workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_TIMED_OUT

		workflowEvent.StepRetries = payload.StepRetries
		workflowEvent.RetryCount = payload.RetryCount
	case "step-run-stream-event":
		payload, err := UnmarshalPayload[tasktypes.StepRunStreamEventTaskPayload](task.Payload)
		if err != nil {
			return nil, err
		}
		workflowEvent.WorkflowRunId = payload.WorkflowRunId
		stepRunId = payload.StepRunId
		workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_STEP_RUN
		workflowEvent.ResourceId = stepRunId
		workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM

		workflowEvent.StepRetries = payload.StepRetries
		workflowEvent.RetryCount = payload.RetryCount
	case "workflow-run-finished":
		payload, err := UnmarshalPayload[tasktypes.WorkflowRunFinishedTask](task.Payload)
		if err != nil {
			return nil, err
		}
		workflowRunId := payload.WorkflowRunId
		workflowEvent.ResourceType = contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN
		workflowEvent.ResourceId = workflowRunId
		workflowEvent.WorkflowRunId = workflowRunId
		workflowEvent.EventType = contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED
	}

	match, err := filter(workflowEvent)

	if err != nil {
		return nil, err
	}

	if match != nil && !*match {
		// if not a match, we don't return it
		return nil, nil
	}

	hangup, err := hangupFunc(workflowEvent)

	if err != nil {
		return nil, err
	}

	if hangup {
		workflowEvent.Hangup = true
		return workflowEvent, nil
	}

	if workflowEvent.ResourceType == contracts.ResourceType_RESOURCE_TYPE_STEP_RUN {
		if workflowEvent.EventType == contracts.ResourceEventType_RESOURCE_EVENT_TYPE_STREAM {
			streamEventId, err := strconv.ParseInt(task.Metadata["stream_event_id"].(string), 10, 64)
			if err != nil {
				return nil, err
			}

			streamEvent, err := s.repo.StreamEvent().GetStreamEvent(context.Background(), tenantId, streamEventId)

			if err != nil {
				return nil, err
			}

			workflowEvent.EventPayload = string(streamEvent.Message)
		}
	}

	return workflowEvent, nil
}

func (s *DispatcherImpl) tenantTaskToWorkflowEventByRunId(task *msgqueue.Message, tenantId string, workflowRunIds ...string) (*contracts.WorkflowEvent, error) {
	workflowEvent, err := s.taskToWorkflowEvent(task, tenantId,
		func(e *contracts.WorkflowEvent) (*bool, error) {
			hit := contains(workflowRunIds, e.WorkflowRunId)
			return &hit, nil
		},
		func(e *contracts.WorkflowEvent) (bool, error) {
			// hangup on complete
			return e.ResourceType == contracts.ResourceType_RESOURCE_TYPE_WORKFLOW_RUN &&
				e.EventType == contracts.ResourceEventType_RESOURCE_EVENT_TYPE_COMPLETED, nil
		},
	)

	if err != nil {
		return nil, err
	}

	return workflowEvent, nil
}

func tinyHash(key, value string) string {
	// Combine key and value
	combined := fmt.Sprintf("%s:%s", key, value)

	// Create SHA-256 hash
	hash := sha256.Sum256([]byte(combined))

	// Take first 8 bytes of the hash
	shortHash := hash[:8]

	// Encode to base64
	encoded := base64.URLEncoding.EncodeToString(shortHash)

	// Remove padding
	return encoded[:11]
}

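// Length arithmetic for the slice above: 8 bytes base64-encode to
// ceil(8/3)*4 = 12 characters, of which the last is always '=' padding, so
// encoded[:11] keeps all 11 data characters and drops only the pad.
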
func (s *DispatcherImpl) tenantTaskToWorkflowEventByAdditionalMeta(task *msgqueue.Message, tenantId string, key string, value string, hangup func(e *contracts.WorkflowEvent) (bool, error)) (*contracts.WorkflowEvent, error) {
	workflowEvent, err := s.taskToWorkflowEvent(
		task,
		tenantId,
		func(e *contracts.WorkflowEvent) (*bool, error) {
			return cache.MakeCacheable[bool](
				s.cache,
				fmt.Sprintf("wfram-%s-%s-%s", tenantId, e.WorkflowRunId, tinyHash(key, value)),
				func() (*bool, error) {
					if e.WorkflowRunId == "" {
						return nil, nil
					}

					am, err := s.repo.WorkflowRun().GetWorkflowRunAdditionalMeta(context.Background(), tenantId, e.WorkflowRunId)

					if err != nil {
						return nil, err
					}

					if am.AdditionalMetadata == nil {
						f := false
						return &f, nil
					}

					var additionalMetaMap map[string]interface{}
					err = json.Unmarshal(am.AdditionalMetadata, &additionalMetaMap)
					if err != nil {
						return nil, err
					}

					if v, ok := additionalMetaMap[key]; ok && v == value {
						t := true
						return &t, nil
					}

					f := false
					return &f, nil
				},
			)
		},
		hangup,
	)

	if err != nil {
		return nil, err
	}

	return workflowEvent, nil
}

func (s *DispatcherImpl) isMatchingWorkflowRun(task *msgqueue.Message, workflowRunIds ...string) (string, bool) {
	if task.ID != "workflow-run-finished" {
		return "", false
	}

	workflowRunId, ok := task.Payload["workflow_run_id"].(string)

	if !ok {
		return "", false
	}

	if contains(workflowRunIds, workflowRunId) {
		return workflowRunId, true
	}

	return "", false
}

func (s *DispatcherImpl) toWorkflowRunEvent(tenantId string, workflowRuns []*dbsqlc.ListWorkflowRunsRow) ([]*contracts.WorkflowRunEvent, error) {
	finalWorkflowRuns := make([]*dbsqlc.ListWorkflowRunsRow, 0)

	for _, workflowRun := range workflowRuns {
		wrCopy := workflowRun

		if !repository.IsFinalWorkflowRunStatus(wrCopy.WorkflowRun.Status) {
			continue
		}

		finalWorkflowRuns = append(finalWorkflowRuns, wrCopy)
	}

	res := make([]*contracts.WorkflowRunEvent, 0)

	// get step run results for each workflow run
	mappedStepRunResults, err := s.getStepResultsForWorkflowRun(tenantId, finalWorkflowRuns)

	if err != nil {
		return nil, err
	}

	for workflowRunId, stepRunResults := range mappedStepRunResults {
		res = append(res, &contracts.WorkflowRunEvent{
			WorkflowRunId:  workflowRunId,
			EventType:      contracts.WorkflowRunEventType_WORKFLOW_RUN_EVENT_TYPE_FINISHED,
			EventTimestamp: timestamppb.Now(),
			Results:        stepRunResults,
		})
	}

	return res, nil
}

func (s *DispatcherImpl) getStepResultsForWorkflowRun(tenantId string, workflowRuns []*dbsqlc.ListWorkflowRunsRow) (map[string][]*contracts.StepRunResult, error) {
	workflowRunIds := make([]string, 0)
	workflowRunToOnFailureJobIds := make(map[string]string)

	for _, workflowRun := range workflowRuns {
		workflowRunIds = append(workflowRunIds, sqlchelpers.UUIDToStr(workflowRun.WorkflowRun.ID))

		if workflowRun.WorkflowVersion.OnFailureJobId.Valid {
			workflowRunToOnFailureJobIds[sqlchelpers.UUIDToStr(workflowRun.WorkflowRun.ID)] = sqlchelpers.UUIDToStr(workflowRun.WorkflowVersion.OnFailureJobId)
		}
	}

	stepRuns, err := s.repo.StepRun().ListStepRuns(context.Background(), tenantId, &repository.ListStepRunsOpts{
		WorkflowRunIds: workflowRunIds,
	})

	if err != nil {
		return nil, err
	}

	res := make(map[string][]*contracts.StepRunResult)

	for _, stepRun := range stepRuns {
		data, err := s.repo.StepRun().GetStepRunDataForEngine(context.Background(), tenantId, sqlchelpers.UUIDToStr(stepRun.SRID))

		if err != nil {
			return nil, fmt.Errorf("could not get step run data for %s: %w", sqlchelpers.UUIDToStr(stepRun.SRID), err)
		}

		workflowRunId := sqlchelpers.UUIDToStr(stepRun.WorkflowRunId)
		jobId := sqlchelpers.UUIDToStr(stepRun.JobId)

		hasError := data.Error.Valid || stepRun.SRCancelledReason.Valid
		isOnFailureJob := workflowRunToOnFailureJobIds[workflowRunId] == jobId

		// skip errored/cancelled step runs that belong to the workflow's on-failure job
		if hasError && isOnFailureJob {
			continue
		}

		resStepRun := &contracts.StepRunResult{
			StepRunId:      sqlchelpers.UUIDToStr(stepRun.SRID),
			StepReadableId: stepRun.StepReadableId.String,
			JobRunId:       sqlchelpers.UUIDToStr(stepRun.JobRunId),
		}

		if data.Error.Valid {
			resStepRun.Error = &data.Error.String
		}

		if stepRun.SRCancelledReason.Valid {
			errString := fmt.Sprintf("this step run was cancelled due to %s", stepRun.SRCancelledReason.String)
			resStepRun.Error = &errString
		}

		if data.Output != nil {
			resStepRun.Output = repository.StringPtr(string(data.Output))
		}

		if currResults, ok := res[workflowRunId]; ok {
			res[workflowRunId] = append(currResults, resStepRun)
		} else {
			res[workflowRunId] = []*contracts.StepRunResult{resStepRun}
		}
	}

	return res, nil
}

func contains(s []string, e string) bool {
	for _, a := range s {
		if a == e {
			return true
		}
	}

	return false
}

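// Note: on Go 1.21+, slices.Contains from the standard library provides the
// same behavior as the helper above.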