diff --git a/go/libraries/utils/pipeline/stage.go b/go/libraries/utils/pipeline/stage.go index d2051d6f1a..0933d6f3e8 100644 --- a/go/libraries/utils/pipeline/stage.go +++ b/go/libraries/utils/pipeline/stage.go @@ -85,11 +85,28 @@ func (s *Stage) start(eg *errgroup.Group, ctx context.Context) { for i := 0; i < parallelism; i++ { routineIndex := i routineCtx := context.WithValue(ctx, localStorageKey, LocalStorage{}) - eg.Go(func() error { + eg.Go(func() (rerr error) { defer func() { if atomic.AddInt32(&stageWorkers, -1) == 0 { if s.outCh != nil { - close(s.outCh) + // To finalize our channel in the non-error case, + // we send a `nil` sentinel indicating we are done + // and then close the channel. + if rerr == nil { + select { + case <-ctx.Done(): + case s.outCh <- nil: + } + close(s.outCh) + } else { + // In the error case, we do not want to close the + // channel until we are certain our consumer will + // see the failure in the context Err(). + go func() { + <-ctx.Done() + close(s.outCh) + }() + } } } }() @@ -116,19 +133,12 @@ func (s *Stage) start(eg *errgroup.Group, ctx context.Context) { func (s *Stage) runFirstStageInPipeline(ctx context.Context) error { for { if ctx.Err() != nil { - return nil + return ctx.Err() } iwp, err := s.stageFunc(ctx, nil) if err == io.EOF { - // We send one last `nil` as an end-of-stream sentinel - // before we close the channel. - select { - case <-ctx.Done(): - return nil - case s.outCh <- nil: - return nil - } + return nil } if err != nil { return err @@ -136,7 +146,7 @@ func (s *Stage) runFirstStageInPipeline(ctx context.Context) error { select { case <-ctx.Done(): - return nil + return ctx.Err() case s.outCh <- iwp: } } @@ -147,15 +157,12 @@ func (s *Stage) runPipelineStage(ctx context.Context) error { for { select { case <-ctx.Done(): - return nil + return ctx.Err() case inBatch, ok := <-s.inCh: if !ok { - return nil + return ctx.Err() } err := s.transformBatch(ctx, inBatch) - if err == io.EOF { - return nil - } if err != nil { return err } @@ -178,7 +185,7 @@ func (s *Stage) transformBatch(ctx context.Context, inBatch []ItemWithProps) err select { case <-ctx.Done(): - return nil + return ctx.Err() case s.outCh <- currBatch: } }