From ea9f7c6ba7e82edd96bdcd8f99a485a992819a88 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Sun, 2 Nov 2025 19:48:35 +0100 Subject: [PATCH] fix: pass tx down to retrieve --- internal/services/dispatcher/dispatcher_v1.go | 2 +- pkg/repository/v1/ids.go | 2 +- pkg/repository/v1/match.go | 2 +- pkg/repository/v1/payloadstore.go | 10 +++++++--- pkg/repository/v1/task.go | 8 ++++---- pkg/repository/v1/trigger.go | 2 +- 6 files changed, 15 insertions(+), 11 deletions(-) diff --git a/internal/services/dispatcher/dispatcher_v1.go b/internal/services/dispatcher/dispatcher_v1.go index ceb885fc3..c9ffbb53f 100644 --- a/internal/services/dispatcher/dispatcher_v1.go +++ b/internal/services/dispatcher/dispatcher_v1.go @@ -237,7 +237,7 @@ func (d *DispatcherImpl) handleTaskBulkAssignedTask(ctx context.Context, msg *ms } } - inputs, err := d.repov1.Payloads().Retrieve(ctx, retrievePayloadOpts...) + inputs, err := d.repov1.Payloads().Retrieve(ctx, nil, retrievePayloadOpts...) if err != nil { d.l.Error().Err(err).Msgf("could not bulk retrieve inputs for %d tasks", len(bulkDatas)) diff --git a/pkg/repository/v1/ids.go b/pkg/repository/v1/ids.go index 9e351000b..e12f7f9d6 100644 --- a/pkg/repository/v1/ids.go +++ b/pkg/repository/v1/ids.go @@ -180,7 +180,7 @@ func (s *sharedRepository) generateExternalIdsForChildWorkflows(ctx context.Cont } } - payloads, err := s.payloadStore.Retrieve(ctx, retrievePayloadOpts...) + payloads, err := s.payloadStore.Retrieve(ctx, tx, retrievePayloadOpts...) if err != nil { return err diff --git a/pkg/repository/v1/match.go b/pkg/repository/v1/match.go index d3a1692a4..eedf4ed56 100644 --- a/pkg/repository/v1/match.go +++ b/pkg/repository/v1/match.go @@ -495,7 +495,7 @@ func (m *sharedRepository) processEventMatches(ctx context.Context, tx sqlcv1.DB } } - payloads, err := m.payloadStore.Retrieve(ctx, retrievePayloadOpts...) + payloads, err := m.payloadStore.Retrieve(ctx, tx, retrievePayloadOpts...) if err != nil { return nil, fmt.Errorf("failed to retrieve dag input payloads: %w", err) diff --git a/pkg/repository/v1/payloadstore.go b/pkg/repository/v1/payloadstore.go index 985e87340..735b8bf44 100644 --- a/pkg/repository/v1/payloadstore.go +++ b/pkg/repository/v1/payloadstore.go @@ -53,7 +53,7 @@ type ExternalStore interface { type PayloadStoreRepository interface { Store(ctx context.Context, tx sqlcv1.DBTX, payloads ...StorePayloadOpts) error - Retrieve(ctx context.Context, opts ...RetrievePayloadOpts) (map[RetrievePayloadOpts][]byte, error) + Retrieve(ctx context.Context, tx sqlcv1.DBTX, opts ...RetrievePayloadOpts) (map[RetrievePayloadOpts][]byte, error) RetrieveFromExternal(ctx context.Context, keys ...ExternalPayloadLocationKey) (map[ExternalPayloadLocationKey][]byte, error) ProcessPayloadWAL(ctx context.Context, partitionNumber int64, pubBuffer *msgqueue.MQPubBuffer) (bool, error) ProcessPayloadExternalCutovers(ctx context.Context, partitionNumber int64) (bool, error) @@ -304,8 +304,12 @@ func (p *payloadStoreRepositoryImpl) Store(ctx context.Context, tx sqlcv1.DBTX, return err } -func (p *payloadStoreRepositoryImpl) Retrieve(ctx context.Context, opts ...RetrievePayloadOpts) (map[RetrievePayloadOpts][]byte, error) { - return p.retrieve(ctx, p.pool, opts...) +func (p *payloadStoreRepositoryImpl) Retrieve(ctx context.Context, tx sqlcv1.DBTX, opts ...RetrievePayloadOpts) (map[RetrievePayloadOpts][]byte, error) { + if tx == nil { + tx = p.pool + } + + return p.retrieve(ctx, tx, opts...) } func (p *payloadStoreRepositoryImpl) RetrieveFromExternal(ctx context.Context, keys ...ExternalPayloadLocationKey) (map[ExternalPayloadLocationKey][]byte, error) { diff --git a/pkg/repository/v1/task.go b/pkg/repository/v1/task.go index 92f3ab254..03755f9dc 100644 --- a/pkg/repository/v1/task.go +++ b/pkg/repository/v1/task.go @@ -1117,7 +1117,7 @@ func (r *TaskRepositoryImpl) listTaskOutputEvents(ctx context.Context, tx sqlcv1 matchedEventToRetrieveOpts[event] = opt } - payloads, err := r.payloadStore.Retrieve(ctx, retrieveOpts...) + payloads, err := r.payloadStore.Retrieve(ctx, tx, retrieveOpts...) if err != nil { return nil, err @@ -2883,7 +2883,7 @@ func (r *TaskRepositoryImpl) ReplayTasks(ctx context.Context, tenantId string, t } } - payloads, err := r.payloadStore.Retrieve(ctx, retrieveOpts...) + payloads, err := r.payloadStore.Retrieve(ctx, tx, retrieveOpts...) if err != nil { return nil, fmt.Errorf("failed to bulk retrieve task inputs: %w", err) @@ -3491,7 +3491,7 @@ func (r *TaskRepositoryImpl) ListTaskParentOutputs(ctx context.Context, tenantId retrieveOptToPayload[opt] = outputTask.Output } - payloads, err := r.payloadStore.Retrieve(ctx, retrieveOpts...) + payloads, err := r.payloadStore.Retrieve(ctx, r.pool, retrieveOpts...) if err != nil { return nil, fmt.Errorf("failed to retrieve task output payloads: %w", err) @@ -3567,7 +3567,7 @@ func (r *TaskRepositoryImpl) ListSignalCompletedEvents(ctx context.Context, tena retrieveOpts[i] = retrieveOpt } - payloads, err := r.payloadStore.Retrieve(ctx, retrieveOpts...) + payloads, err := r.payloadStore.Retrieve(ctx, r.pool, retrieveOpts...) if err != nil { return nil, fmt.Errorf("failed to retrieve task event payloads: %w", err) diff --git a/pkg/repository/v1/trigger.go b/pkg/repository/v1/trigger.go index 713746d32..d3cbe726f 100644 --- a/pkg/repository/v1/trigger.go +++ b/pkg/repository/v1/trigger.go @@ -1512,7 +1512,7 @@ func (r *TriggerRepositoryImpl) registerChildWorkflows( } } - payloads, err := r.payloadStore.Retrieve(ctx, retrievePayloadOpts...) + payloads, err := r.payloadStore.Retrieve(ctx, tx, retrievePayloadOpts...) if err != nil { return nil, fmt.Errorf("failed to retrieve payloads for signal created events: %w", err)