Files
hatchet/internal/services/dispatcher/server.go

688 lines
21 KiB
Go

package dispatcher
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog"
telemetry_codes "go.opentelemetry.io/otel/codes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hatchet-dev/hatchet/internal/msgqueue"
"github.com/hatchet-dev/hatchet/internal/services/dispatcher/contracts"
"github.com/hatchet-dev/hatchet/pkg/analytics"
v1 "github.com/hatchet-dev/hatchet/pkg/repository"
"github.com/hatchet-dev/hatchet/pkg/repository/sqlcv1"
"github.com/hatchet-dev/hatchet/pkg/telemetry"
tasktypes "github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes/v1"
)
// Register creates a worker record for the tenant on the request context and
// returns the new worker's ID. The worker is bound to this dispatcher so that
// later Listen streams route to it.
func (s *DispatcherImpl) Register(ctx context.Context, request *contracts.WorkerRegisterRequest) (*contracts.WorkerRegisterResponse, error) {
	// tenant is set on the context upstream; a missing/mistyped value panics
	// here by design
	tenant := ctx.Value("tenant").(*sqlcv1.Tenant)
	tenantId := tenant.ID
	s.l.Debug().Ctx(ctx).Msgf("Received register request from ID %s with actions %v", request.WorkerName, request.Actions)
	svcs := request.Services
	if len(svcs) == 0 {
		svcs = []string{"default"}
	}
	opts := &v1.CreateWorkerOpts{
		DispatcherId: s.dispatcherId,
		Name:         request.WorkerName,
		Actions:      request.Actions,
		Services:     svcs,
	}
	if request.RuntimeInfo != nil {
		opts.RuntimeInfo = &v1.RuntimeInfo{
			SdkVersion:      request.RuntimeInfo.SdkVersion,
			Language:        request.RuntimeInfo.Language,
			LanguageVersion: request.RuntimeInfo.LanguageVersion,
			Os:              request.RuntimeInfo.Os,
			Extra:           request.RuntimeInfo.Extra,
		}
	}
	// slot_config and the deprecated slots field are mutually exclusive;
	// reject before computing any slot configuration.
	// fixme: deprecated remove in a future release feb6 2026
	if request.Slots != nil && len(request.SlotConfig) > 0 {
		return nil, status.Errorf(codes.InvalidArgument, "either slot_config or slots (deprecated) must be provided, not both")
	}
	switch {
	case len(request.SlotConfig) > 0:
		opts.SlotConfig = request.SlotConfig
	case request.Slots != nil:
		opts.SlotConfig = map[string]int32{v1.SlotTypeDefault: *request.Slots}
	default:
		// default to 100 slots
		opts.SlotConfig = map[string]int32{v1.SlotTypeDefault: 100}
	}
	if apiErrors, err := s.v.ValidateAPI(opts); err != nil {
		return nil, err
	} else if apiErrors != nil {
		return nil, status.Errorf(codes.InvalidArgument, "Invalid request: %s", apiErrors.String())
	}
	// create a worker in the database
	worker, err := s.repov1.Workers().CreateNewWorker(ctx, tenantId, opts)
	// errors.Is recognizes the sentinel even if the repository layer wraps it;
	// the previous == comparison would miss wrapped errors
	if errors.Is(err, v1.ErrResourceExhausted) {
		return nil, status.Errorf(codes.ResourceExhausted, "resource exhausted: tenant worker limit or concurrency limit exceeded")
	}
	if err != nil {
		s.l.Error().Ctx(ctx).Err(err).Msgf("could not create worker for tenant %s", tenantId)
		return nil, err
	}
	workerId := worker.ID.String()
	if request.Labels != nil {
		_, err = s.upsertLabels(ctx, worker.ID, request.Labels)
		if err != nil {
			return nil, err
		}
	}
	s.analytics.Count(ctx, analytics.Worker, analytics.Register, analytics.Props(
		"worker_name", request.WorkerName,
		"runtime_language", strings.ToLower(request.GetRuntimeInfo().GetLanguage().String()),
		"runtime_sdk_version", request.GetRuntimeInfo().GetSdkVersion(),
		"runtime_language_version", request.GetRuntimeInfo().GetLanguageVersion(),
		"runtime_os", request.GetRuntimeInfo().GetOs(),
		"runtime_extra", request.GetRuntimeInfo().GetExtra(),
		"has_labels", len(request.Labels) > 0,
		"has_webhook_id", request.WebhookId != nil,
		"has_runtime_info", request.RuntimeInfo != nil,
		"has_slot_config", len(request.SlotConfig) > 0,
		"has_custom_slots", request.Slots != nil,
		"has_services", len(request.Services) > 0,
	))
	return &contracts.WorkerRegisterResponse{
		TenantId:   tenantId.String(),
		WorkerId:   workerId,
		WorkerName: worker.Name,
	}, nil
}
// UpsertWorkerLabels validates the worker ID in the request and upserts the
// provided labels for that worker.
func (s *DispatcherImpl) UpsertWorkerLabels(ctx context.Context, request *contracts.UpsertWorkerLabelsRequest) (*contracts.UpsertWorkerLabelsResponse, error) {
	tenant := ctx.Value("tenant").(*sqlcv1.Tenant)
	s.analytics.Count(ctx, analytics.Worker, analytics.Create)
	parsedWorkerId, parseErr := uuid.Parse(request.WorkerId)
	if parseErr != nil {
		return nil, status.Errorf(codes.InvalidArgument, "invalid worker ID format: %s", request.WorkerId)
	}
	if _, upsertErr := s.upsertLabels(ctx, parsedWorkerId, request.Labels); upsertErr != nil {
		return nil, upsertErr
	}
	resp := &contracts.UpsertWorkerLabelsResponse{
		TenantId: tenant.ID.String(),
		WorkerId: request.WorkerId,
	}
	return resp, nil
}
// upsertLabels validates each label config and writes the label set for the
// given worker, returning the stored rows.
func (s *DispatcherImpl) upsertLabels(ctx context.Context, workerId uuid.UUID, request map[string]*contracts.WorkerLabels) ([]*sqlcv1.WorkerLabel, error) {
	upsertOpts := make([]v1.UpsertWorkerLabelOpts, 0, len(request))
	for labelKey, labelConfig := range request {
		if validationErr := s.v.Validate(labelConfig); validationErr != nil {
			return nil, status.Errorf(codes.InvalidArgument, "Invalid affinity config: %s", validationErr.Error())
		}
		upsertOpts = append(upsertOpts, v1.UpsertWorkerLabelOpts{
			Key:      labelKey,
			IntValue: labelConfig.IntValue,
			StrValue: labelConfig.StrValue,
		})
	}
	labels, err := s.repov1.Workers().UpsertWorkerLabels(ctx, workerId, upsertOpts)
	if err != nil {
		s.l.Error().Ctx(ctx).Err(err).Msgf("could not upsert worker affinities for worker %s", workerId.String())
		return nil, err
	}
	return labels, nil
}
// Listen handles a worker's long-lived listen stream. It registers the stream
// in the subscribed-worker pool, reassigns the worker to this dispatcher if
// needed, and keeps the worker's heartbeat fresh for as long as the stream is
// open. ListenV2 is the newer variant that leaves heartbeating to the
// Heartbeat RPC.
func (s *DispatcherImpl) Listen(request *contracts.WorkerListenRequest, stream contracts.Dispatcher_ListenServer) error {
	ctx := stream.Context()
	// tenant is set on the context upstream; a missing/mistyped value panics
	// here by design
	tenant := ctx.Value("tenant").(*sqlcv1.Tenant)
	tenantId := tenant.ID
	s.analytics.Count(ctx, analytics.Worker, analytics.Listen)
	// each Listen call gets its own session ID so multiple streams for the
	// same worker can be tracked and torn down independently
	sessionId := uuid.New().String()
	workerId, err := uuid.Parse(request.WorkerId)
	if err != nil {
		s.l.Error().Ctx(ctx).Err(err).Msgf("invalid worker ID format: %s", request.WorkerId)
		return status.Errorf(codes.InvalidArgument, "invalid worker ID format: %s", request.WorkerId)
	}
	s.l.Debug().Ctx(ctx).Msgf("Received subscribe request from ID: %s", request.WorkerId)
	worker, err := s.repov1.Workers().GetWorkerForEngine(ctx, tenantId, workerId)
	if err != nil {
		s.l.Error().Ctx(ctx).Err(err).Msgf("could not get worker %s", request.WorkerId)
		return err
	}
	shouldUpdateDispatcherId := worker.DispatcherId == nil || (worker.DispatcherId != nil && *worker.DispatcherId != s.dispatcherId)
	// check the worker's dispatcher against the current dispatcher. if they don't match, then update the worker
	if shouldUpdateDispatcherId {
		_, err = s.repov1.Workers().UpdateWorker(ctx, tenantId, workerId, &v1.UpdateWorkerOpts{
			DispatcherId: &s.dispatcherId,
		})
		if err != nil {
			if errors.Is(err, pgx.ErrNoRows) {
				// worker row no longer exists; treat as a clean disconnect
				return nil
			}
			s.l.Error().Ctx(ctx).Err(err).Msgf("could not update worker %s dispatcher", request.WorkerId)
			return err
		}
	}
	// fin signals stream teardown to both the heartbeat goroutine and the
	// keep-alive loop below
	fin := make(chan bool)
	s.workers.Add(workerId, sessionId, newSubscribedWorker(stream, fin, workerId, s.defaultMaxWorkerLockAcquisitionTime, s.pubBuffer))
	defer func() {
		// non-blocking send
		select {
		case fin <- true:
		default:
		}
		s.workers.DeleteForSession(workerId, sessionId)
	}()
	// update the worker with a last heartbeat time every 5 seconds as long as the worker is connected
	go func() {
		// the ticker fires every 100ms, but the DB write below is throttled
		// to roughly once per 4 seconds via lastHeartbeat
		timer := time.NewTicker(100 * time.Millisecond)
		// set the last heartbeat to 6 seconds ago so the first heartbeat is sent immediately
		lastHeartbeat := time.Now().UTC().Add(-6 * time.Second)
		defer timer.Stop()
		for {
			select {
			case <-ctx.Done():
				s.l.Debug().Ctx(ctx).Msgf("worker id %s has disconnected", request.WorkerId)
				return
			case <-fin:
				s.l.Debug().Ctx(ctx).Msgf("closing stream for worker id: %s", request.WorkerId)
				return
			case <-timer.C:
				if now := time.Now().UTC(); lastHeartbeat.Add(4 * time.Second).Before(now) {
					s.l.Debug().Ctx(ctx).Msgf("updating worker %s heartbeat", request.WorkerId)
					_, err := s.repov1.Workers().UpdateWorker(ctx, tenantId, workerId, &v1.UpdateWorkerOpts{
						LastHeartbeatAt: &now,
						IsActive:        v1.BoolPtr(true),
					})
					if err != nil {
						if errors.Is(err, pgx.ErrNoRows) {
							// worker was deleted; stop heartbeating quietly
							return
						}
						s.l.Error().Ctx(ctx).Err(err).Msgf("could not update worker %s heartbeat", request.WorkerId)
						return
					}
					lastHeartbeat = time.Now().UTC()
				}
			}
		}
	}()
	// Keep the connection alive for sending messages
	for {
		select {
		case <-fin:
			s.l.Debug().Ctx(ctx).Msgf("closing stream for worker id: %s", request.WorkerId)
			return nil
		case <-ctx.Done():
			s.l.Debug().Ctx(ctx).Msgf("worker id %s has disconnected", request.WorkerId)
			return nil
		}
	}
}
// ListenV2 is like Listen, but implementation does not include heartbeats. This should only used by SDKs
// against engine version v0.18.1+. Heartbeats are instead sent explicitly by
// the worker via the Heartbeat RPC; this handler only tracks the worker's
// active status for the lifetime of the stream.
func (s *DispatcherImpl) ListenV2(request *contracts.WorkerListenRequest, stream contracts.Dispatcher_ListenV2Server) error {
	ctx := stream.Context()
	// tenant is set on the context upstream; a missing/mistyped value panics
	// here by design
	tenant := ctx.Value("tenant").(*sqlcv1.Tenant)
	tenantId := tenant.ID
	s.analytics.Count(stream.Context(), analytics.Worker, analytics.Listen)
	// per-stream session ID so concurrent streams for one worker are tracked
	// independently
	sessionId := uuid.New().String()
	workerId, err := uuid.Parse(request.WorkerId)
	if err != nil {
		s.l.Error().Ctx(ctx).Err(err).Msgf("invalid worker ID format: %s", request.WorkerId)
		return status.Errorf(codes.InvalidArgument, "invalid worker ID format: %s", request.WorkerId)
	}
	s.l.Debug().Ctx(ctx).Msgf("Received subscribe request from ID: %s", request.WorkerId)
	worker, err := s.repov1.Workers().GetWorkerForEngine(ctx, tenantId, workerId)
	if err != nil {
		s.l.Error().Ctx(ctx).Err(err).Msgf("could not get worker %s", request.WorkerId)
		return err
	}
	shouldUpdateDispatcherId := worker.DispatcherId == nil || (worker.DispatcherId != nil && *worker.DispatcherId != s.dispatcherId)
	// check the worker's dispatcher against the current dispatcher. if they don't match, then update the worker
	if shouldUpdateDispatcherId {
		_, err = s.repov1.Workers().UpdateWorker(ctx, tenantId, workerId, &v1.UpdateWorkerOpts{
			DispatcherId: &s.dispatcherId,
		})
		if err != nil {
			if errors.Is(err, pgx.ErrNoRows) {
				// worker row no longer exists; treat as a clean disconnect
				return nil
			}
			s.l.Error().Ctx(ctx).Err(err).Msgf("could not update worker %s dispatcher", request.WorkerId)
			return err
		}
	}
	// mark the worker active for this new listener session
	sessionEstablished := time.Now().UTC()
	_, err = s.repov1.Workers().UpdateWorkerActiveStatus(ctx, tenantId, workerId, true, sessionEstablished)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil
		}
		lastSessionEstablished := "NULL"
		if worker.LastListenerEstablished.Valid {
			lastSessionEstablished = worker.LastListenerEstablished.Time.String()
		}
		s.l.Error().Ctx(ctx).Err(err).Msgf("could not update worker %s active status to true (session established %s, last session established %s)", request.WorkerId, sessionEstablished.String(), lastSessionEstablished)
		return err
	}
	fin := make(chan bool)
	s.workers.Add(workerId, sessionId, newSubscribedWorker(stream, fin, workerId, s.defaultMaxWorkerLockAcquisitionTime, s.pubBuffer))
	defer func() {
		// non-blocking send
		select {
		case fin <- true:
		default:
		}
		s.workers.DeleteForSession(workerId, sessionId)
	}()
	// Keep the connection alive for sending messages
	for {
		select {
		case <-fin:
			s.l.Debug().Ctx(ctx).Msgf("closing stream for worker id: %s", request.WorkerId)
			_, err = s.repov1.Workers().UpdateWorkerActiveStatus(ctx, tenantId, workerId, false, sessionEstablished)
			if err != nil && !errors.Is(err, pgx.ErrNoRows) {
				s.l.Error().Ctx(ctx).Err(err).Msgf("could not update worker %s active status to false due to worker stream closing (session established %s)", request.WorkerId, sessionEstablished.String())
				return err
			}
			return nil
		case <-ctx.Done():
			s.l.Debug().Ctx(ctx).Msgf("worker id %s has disconnected", request.WorkerId)
			// the stream context is already canceled at this point, so use a
			// fresh detached context (capped at 20s) for the final status
			// write; this deliberately shadows ctx for the rest of the case
			ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
			defer cancel()
			_, err = s.repov1.Workers().UpdateWorkerActiveStatus(ctx, tenantId, workerId, false, sessionEstablished)
			if err != nil && !errors.Is(err, pgx.ErrNoRows) {
				s.l.Error().Ctx(ctx).Err(err).Msgf("could not update worker %s active status due to worker disconnecting (session established %s)", request.WorkerId, sessionEstablished.String())
				return err
			}
			return nil
		}
	}
}
// HeartbeatInterval is the expected period between worker heartbeat requests;
// Heartbeat logs a warning when a request's timestamp lags by more than this.
const HeartbeatInterval = 4 * time.Second
// Heartbeat is used to update the last heartbeat time for a worker. When a
// worker reappears after 30+ seconds without a heartbeat, the tenant's
// scheduler partition is notified (best-effort, in the background) that a
// worker is available again.
func (s *DispatcherImpl) Heartbeat(ctx context.Context, req *contracts.HeartbeatRequest) (*contracts.HeartbeatResponse, error) {
	ctx, span := telemetry.NewSpan(ctx, "update-worker-heartbeat")
	defer span.End()
	tenant := ctx.Value("tenant").(*sqlcv1.Tenant)
	tenantId := tenant.ID
	workerId, err := uuid.Parse(req.WorkerId)
	if err != nil {
		s.l.Error().Ctx(ctx).Err(err).Msgf("invalid worker ID format: %s", req.WorkerId)
		return nil, status.Errorf(codes.InvalidArgument, "invalid worker ID format: %s", req.WorkerId)
	}
	heartbeatAt := time.Now().UTC()
	s.l.Debug().Ctx(ctx).Msgf("Received heartbeat request from ID: %s", req.WorkerId)
	// if heartbeat time is greater than expected heartbeat interval, show a warning
	if req.HeartbeatAt.AsTime().Before(heartbeatAt.Add(-1 * HeartbeatInterval)) {
		s.l.Warn().Ctx(ctx).Msgf("heartbeat time is greater than expected heartbeat interval")
	}
	worker, err := s.repov1.Workers().GetWorkerForEngine(ctx, tenantId, workerId)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(telemetry_codes.Error, "could not get worker")
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, status.Errorf(codes.NotFound, "worker not found: %s", req.WorkerId)
		}
		return nil, err
	}
	// if the worker is not active, the listener should reconnect
	if worker.LastListenerEstablished.Valid && !worker.IsActive {
		// build the rejection error first so the span records it; the
		// previous code called span.RecordError with the nil err from the
		// successful GetWorkerForEngine call, which is a no-op
		inactiveErr := status.Errorf(codes.FailedPrecondition, "Heartbeat rejected: worker stream is not active: %s", req.WorkerId)
		span.RecordError(inactiveErr)
		span.SetStatus(telemetry_codes.Error, "worker stream is not active")
		return nil, inactiveErr
	}
	err = s.repov1.Workers().UpdateWorkerHeartbeat(ctx, tenantId, workerId, heartbeatAt)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(telemetry_codes.Error, "could not update worker heartbeat")
		if errors.Is(err, pgx.ErrNoRows) {
			s.l.Error().Ctx(ctx).Msgf("could not update worker heartbeat: worker %s not found", req.WorkerId)
		}
		return nil, err
	}
	// if the worker doesn't have a previous heartbeat or hasn't heartbeat in 30 seconds, notify downstream components that a
	// new worker is available
	if !worker.LastHeartbeatAt.Valid || worker.LastHeartbeatAt.Time.Before(heartbeatAt.Add(-30*time.Second)) {
		if tenant.SchedulerPartitionId.Valid {
			go func() {
				// detached context so the notification is not canceled when
				// the RPC returns
				notifyCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				msg, err := tasktypes.NotifyNewWorker(tenantId, worker.ID)
				if err != nil {
					s.l.Err(err).Ctx(ctx).Str("scheduler_partition_id", tenant.SchedulerPartitionId.String).Msg("could not create message for notifying new worker")
				} else {
					err = s.mqv1.SendMessage(
						notifyCtx,
						msgqueue.QueueTypeFromPartitionIDAndController(tenant.SchedulerPartitionId.String, msgqueue.Scheduler),
						msg,
					)
					if err != nil {
						s.l.Err(err).Ctx(ctx).Str("scheduler_partition_id", tenant.SchedulerPartitionId.String).Msg("could not add message to scheduler partition queue")
					}
				}
			}()
		}
	}
	return &contracts.HeartbeatResponse{}, nil
}
// ReleaseSlot releases a slot for the tenant on the request context,
// delegating to the internal releaseSlot implementation.
func (s *DispatcherImpl) ReleaseSlot(ctx context.Context, req *contracts.ReleaseSlotRequest) (*contracts.ReleaseSlotResponse, error) {
	tenantRow := ctx.Value("tenant").(*sqlcv1.Tenant)
	s.analytics.Count(ctx, analytics.Worker, analytics.Release)
	return s.releaseSlot(ctx, tenantRow, req)
}
// RestoreEvictedTask restores an evicted durable task for the tenant on the
// request context, delegating to the internal restoreEvictedTask
// implementation.
func (s *DispatcherImpl) RestoreEvictedTask(ctx context.Context, req *contracts.RestoreEvictedTaskRequest) (*contracts.RestoreEvictedTaskResponse, error) {
	tenantRow := ctx.Value("tenant").(*sqlcv1.Tenant)
	s.analytics.Count(ctx, analytics.DurableTask, analytics.Restore)
	return s.restoreEvictedTask(ctx, tenantRow, req)
}
// SubscribeToWorkflowEvents streams workflow events to the caller, delegating
// to the v1 implementation. The analytics event is only emitted when a tenant
// is present on the stream context.
func (s *DispatcherImpl) SubscribeToWorkflowEvents(request *contracts.SubscribeToWorkflowEventsRequest, stream contracts.Dispatcher_SubscribeToWorkflowEventsServer) error {
	streamCtx := stream.Context()
	if _, hasTenant := streamCtx.Value("tenant").(*sqlcv1.Tenant); hasTenant {
		s.analytics.Count(streamCtx, analytics.WorkflowRun, analytics.Subscribe)
	}
	return s.subscribeToWorkflowEventsV1(request, stream)
}
// map of workflow run ids to whether the workflow runs are finished and have sent a message
// that the workflow run is finished
type workflowRunAcks struct {
	// acks maps workflow run ID -> acknowledged flag. In this file entries
	// are added as false (addWorkflowRun) and removed on ack
	// (ackWorkflowRun); a true value would have to be set elsewhere — TODO
	// confirm against the rest of the package.
	acks map[uuid.UUID]bool
	// mu guards acks
	mu sync.RWMutex
}
// addWorkflowRun begins tracking id as a workflow run whose completion has
// not yet been acknowledged.
func (w *workflowRunAcks) addWorkflowRun(id uuid.UUID) {
	w.mu.Lock()
	defer w.mu.Unlock()
	// false = pending; the entry is deleted once the run is acked
	w.acks[id] = false
}
// getNonAckdWorkflowRuns returns the IDs of all tracked workflow runs that
// have not yet been acknowledged.
func (w *workflowRunAcks) getNonAckdWorkflowRuns() []uuid.UUID {
	w.mu.RLock()
	defer w.mu.RUnlock()
	ids := make([]uuid.UUID, 0, len(w.acks))
	// range over key and value together to avoid a second map lookup per key
	for id, acked := range w.acks {
		if !acked {
			ids = append(ids, id)
		}
	}
	return ids
}
// ackWorkflowRun acknowledges a workflow run by removing its entry entirely,
// so the run no longer appears in getNonAckdWorkflowRuns or hasWorkflowRun.
func (w *workflowRunAcks) ackWorkflowRun(id uuid.UUID) {
	w.mu.Lock()
	defer w.mu.Unlock()
	delete(w.acks, id)
}
// hasWorkflowRun reports whether id is still being tracked (i.e. added and
// not yet acked).
func (w *workflowRunAcks) hasWorkflowRun(id uuid.UUID) bool {
	w.mu.RLock()
	defer w.mu.RUnlock()
	_, tracked := w.acks[id]
	return tracked
}
// sendTimeFilter throttles sends: canSend succeeds at most once per ~10ms
// window, using acquisition of mu as the gate (see canSend).
type sendTimeFilter struct {
	mu sync.Mutex
}
// canSend reports whether a send is currently allowed. A successful TryLock
// holds the mutex and releases it from a background goroutine after 10ms, so
// any concurrent or subsequent call within that window fails the TryLock and
// returns false — i.e. at most one send per ~10ms.
func (s *sendTimeFilter) canSend() bool {
	if !s.mu.TryLock() {
		return false
	}
	// release the gate after 10ms; until then every other canSend fails fast
	go func() {
		time.Sleep(10 * time.Millisecond)
		s.mu.Unlock()
	}()
	return true
}
// cleanResults trims step run results so the serialized payload stays under
// s.payloadSizeThreshold. If the results already fit they are returned
// unchanged; if the outputs alone exceed the threshold nothing can be
// salvaged and nil is returned. Otherwise oversized Error fields are replaced
// — note: in place, mutating the caller's results — with a placeholder
// message; if the total is still over the threshold afterwards, nil is
// returned. (Receiver renamed to s for consistency with the other
// *DispatcherImpl methods.)
func (s *DispatcherImpl) cleanResults(results []*contracts.StepRunResult) []*contracts.StepRunResult {
	totalSize, sizeOfOutputs, _ := calculateResultsSize(results)
	if totalSize < s.payloadSizeThreshold {
		return results
	}
	if sizeOfOutputs >= s.payloadSizeThreshold {
		return nil
	}
	// otherwise, attempt to clean the results by removing large error fields
	cleanedResults := make([]*contracts.StepRunResult, 0, len(results))
	// how much overhead we'd have per result or error field, in the worst
	// case. len(results) > 0 here: an empty slice always hits one of the two
	// early returns above, so this division cannot be by zero.
	fieldThreshold := (s.payloadSizeThreshold - sizeOfOutputs) / len(results)
	for _, result := range results {
		if result == nil {
			continue
		}
		// we only try to clean the error field at the moment, as modifying the output is more risky
		if result.Error != nil && len(*result.Error) > fieldThreshold {
			result.Error = v1.StringPtr("Error is too large to send over the Hatchet stream.")
		}
		cleanedResults = append(cleanedResults, result)
	}
	// if we are still over the limit, we just return nil
	if totalSize, _, _ := calculateResultsSize(cleanedResults); totalSize > s.payloadSizeThreshold {
		return nil
	}
	return cleanedResults
}
// calculateResultsSize sums the byte lengths of the Output and Error fields
// across results, skipping nil entries and nil fields. totalSize is always
// sizeOfOutputs + sizeOfErrors.
func calculateResultsSize(results []*contracts.StepRunResult) (totalSize int, sizeOfOutputs int, sizeOfErrors int) {
	for _, r := range results {
		if r == nil {
			continue
		}
		if r.Output != nil {
			n := len(*r.Output)
			totalSize += n
			sizeOfOutputs += n
		}
		if r.Error != nil {
			n := len(*r.Error)
			totalSize += n
			sizeOfErrors += n
		}
	}
	return totalSize, sizeOfOutputs, sizeOfErrors
}
// SubscribeToWorkflowRuns streams workflow run results to the caller,
// delegating to the v1 implementation.
func (s *DispatcherImpl) SubscribeToWorkflowRuns(server contracts.Dispatcher_SubscribeToWorkflowRunsServer) error {
	serverCtx := server.Context()
	s.analytics.Count(serverCtx, analytics.WorkflowRun, analytics.Subscribe)
	return s.subscribeToWorkflowRunsV1(server)
}
// waitFor blocks until wg is done or timeout elapses, logging an error on
// timeout. On timeout the helper goroutine (blocked in wg.Wait) is left
// running until the group eventually completes.
func waitFor(wg *sync.WaitGroup, timeout time.Duration, l *zerolog.Logger) {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		// close directly; the original deferred this immediately after Wait
		// returned, which was equivalent but misleading
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(timeout):
		l.Error().Msg("timed out waiting for wait group")
	}
}
// SendStepActionEvent records a step action event, delegating to the v1
// implementation.
func (s *DispatcherImpl) SendStepActionEvent(ctx context.Context, request *contracts.StepActionEvent) (*contracts.ActionEventResponse, error) {
	resp, err := s.sendStepActionEventV1(ctx, request)
	return resp, err
}
// SendGroupKeyActionEvent always rejects the call: group key action events
// are not supported by this engine version.
func (s *DispatcherImpl) SendGroupKeyActionEvent(ctx context.Context, request *contracts.GroupKeyActionEvent) (*contracts.ActionEventResponse, error) {
	err := status.Errorf(codes.Unimplemented, "SendGroupKeyActionEvent is not implemented in engine version v1")
	return nil, err
}
// PutOverridesData is a no-op in this engine version: the request is accepted
// and discarded, and an empty response is returned.
func (s *DispatcherImpl) PutOverridesData(ctx context.Context, request *contracts.OverridesData) (*contracts.OverridesDataResponse, error) {
	resp := &contracts.OverridesDataResponse{}
	return resp, nil
}
// Unsubscribe removes a worker (all of its sessions) from this dispatcher's
// connection pool.
func (s *DispatcherImpl) Unsubscribe(ctx context.Context, request *contracts.WorkerUnsubscribeRequest) (*contracts.WorkerUnsubscribeResponse, error) {
	tenant := ctx.Value("tenant").(*sqlcv1.Tenant)
	tenantId := tenant.ID
	s.analytics.Count(ctx, analytics.Worker, analytics.Delete)
	parsedWorkerId, parseErr := uuid.Parse(request.WorkerId)
	if parseErr != nil {
		return nil, status.Errorf(codes.InvalidArgument, "invalid worker ID format: %s", request.WorkerId)
	}
	// remove the worker from the connection pool
	s.workers.Delete(parsedWorkerId)
	resp := &contracts.WorkerUnsubscribeResponse{
		TenantId: tenantId.String(),
		WorkerId: request.WorkerId,
	}
	return resp, nil
}
// RefreshTimeout refreshes a task's timeout for the tenant on the request
// context, delegating to the v1 implementation. (Receiver renamed from d to s
// for consistency with the other *DispatcherImpl methods in this file.)
func (s *DispatcherImpl) RefreshTimeout(ctx context.Context, request *contracts.RefreshTimeoutRequest) (*contracts.RefreshTimeoutResponse, error) {
	tenant := ctx.Value("tenant").(*sqlcv1.Tenant)
	return s.refreshTimeoutV1(ctx, tenant, request)
}
// UnmarshalPayload converts payload into a value of type T by round-tripping
// it through JSON. The zero value of T is returned alongside any
// marshal/unmarshal error.
func UnmarshalPayload[T any](payload interface{}) (T, error) {
	var out T
	// Convert the payload to JSON
	encoded, err := json.Marshal(payload)
	if err != nil {
		return out, fmt.Errorf("failed to marshal payload: %w", err)
	}
	// Unmarshal JSON into the desired type
	if err := json.Unmarshal(encoded, &out); err != nil {
		return out, fmt.Errorf("failed to unmarshal payload: %w", err)
	}
	return out, nil
}
// GetVersion reports the engine version string this dispatcher was
// configured with.
func (s *DispatcherImpl) GetVersion(ctx context.Context, req *contracts.GetVersionRequest) (*contracts.GetVersionResponse, error) {
	resp := &contracts.GetVersionResponse{Version: s.version}
	return resp, nil
}