go/libraries/utils/pipeline: stage.go: Another pass at finalization strategy.

This commit is contained in:
Aaron Son
2022-05-18 12:17:35 -07:00
parent f4b6b6fe50
commit a8f5e44786
+25 -18
View File
@@ -85,11 +85,28 @@ func (s *Stage) start(eg *errgroup.Group, ctx context.Context) {
for i := 0; i < parallelism; i++ {
routineIndex := i
routineCtx := context.WithValue(ctx, localStorageKey, LocalStorage{})
eg.Go(func() error {
eg.Go(func() (rerr error) {
defer func() {
if atomic.AddInt32(&stageWorkers, -1) == 0 {
if s.outCh != nil {
close(s.outCh)
// To finalize our channel in the non-error case,
// we send a `nil` sentinel indicating we are done
// and then close the channel.
if rerr == nil {
select {
case <-ctx.Done():
case s.outCh <- nil:
}
close(s.outCh)
} else {
// In the error case, we do not want to close the
// channel until we are certain our consumer will
// see the failure in the context Err().
go func() {
<-ctx.Done()
close(s.outCh)
}()
}
}
}
}()
@@ -116,19 +133,12 @@ func (s *Stage) start(eg *errgroup.Group, ctx context.Context) {
func (s *Stage) runFirstStageInPipeline(ctx context.Context) error {
for {
if ctx.Err() != nil {
return nil
return ctx.Err()
}
iwp, err := s.stageFunc(ctx, nil)
if err == io.EOF {
// We send one last `nil` as an end-of-stream sentinel
// before we close the channel.
select {
case <-ctx.Done():
return nil
case s.outCh <- nil:
return nil
}
return nil
}
if err != nil {
return err
@@ -136,7 +146,7 @@ func (s *Stage) runFirstStageInPipeline(ctx context.Context) error {
select {
case <-ctx.Done():
return nil
return ctx.Err()
case s.outCh <- iwp:
}
}
@@ -147,15 +157,12 @@ func (s *Stage) runPipelineStage(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
return ctx.Err()
case inBatch, ok := <-s.inCh:
if !ok {
return nil
return ctx.Err()
}
err := s.transformBatch(ctx, inBatch)
if err == io.EOF {
return nil
}
if err != nil {
return err
}
@@ -178,7 +185,7 @@ func (s *Stage) transformBatch(ctx context.Context, inBatch []ItemWithProps) err
select {
case <-ctx.Done():
return nil
return ctx.Err()
case s.outCh <- currBatch:
}
}