diff --git a/go/store/nbs/archive_build.go b/go/store/nbs/archive_build.go
index adce1eca6c..fc40f6a0ba 100644
--- a/go/store/nbs/archive_build.go
+++ b/go/store/nbs/archive_build.go
@@ -24,6 +24,7 @@ import (
 	"os"
 	"path/filepath"
 	"sort"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -39,6 +40,7 @@ import (
 const defaultDictionarySize = 1 << 12 // NM4 - maybe just select the largest chunk. TBD.
 const maxSamples = 1000
 const minSamples = 25
+const fourMb = 1 << 22
 
 var errNotEnoughChunks = errors.New("Not enough samples to build default dictionary")
 
@@ -264,12 +266,7 @@ func convertTableFileToArchive(
 	//	cg.print(n, p)
 	//}
 
-	const fourMb = 1 << 22
-
-	// Allocate buffer used to compress chunks.
-	cmpBuff := make([]byte, 0, fourMb)
-
-	cmpBuff = gozstd.Compress(cmpBuff[:0], defaultDict)
+	cmpBuff := gozstd.Compress(nil, defaultDict)
 	// p("Default Dict Raw vs Compressed: %d , %d\n", len(defaultDict), len(cmpDefDict))
 
 	arcW, err := newArchiveWriter("")
@@ -282,7 +279,7 @@
 		return "", hash.Hash{}, err
 	}
 
-	_, grouped, singles, err := writeDataToArchive(ctx, cmpBuff[:0], allChunks, cgList, defaultDictByteSpanId, defaultCDict, arcW, progress, stats)
+	_, grouped, singles, err := writeDataToArchive(ctx, allChunks, cgList, defaultDictByteSpanId, defaultCDict, arcW, progress, stats)
 	if err != nil {
 		return "", hash.Hash{}, err
 	}
@@ -356,7 +353,6 @@ func indexFinalizeFlushArchive(arcW *archiveWriter, archivePath string, originTa
 
 func writeDataToArchive(
 	ctx context.Context,
-	cmpBuff []byte,
 	chunkCache *simpleChunkSourceCache,
 	cgList []*chunkGroup,
 	defaultSpanId uint32,
@@ -371,6 +367,9 @@
 		return 0, 0, 0, err
 	}
 
+	// Allocate buffer used to compress chunks.
+	cmpBuff := make([]byte, 0, fourMb)
+
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
@@ -419,42 +418,116 @@
 		}
 	}
 
-	ungroupedChunkCount := int32(len(allChunks))
-	ungroupedChunkProgress := int32(0)
+	ungroupedChunks, err := compressChunksInParallel(ctx, allChunks, chunkCache, arcW, defaultDict, defaultSpanId, progress, stats)
+	if err != nil {
+		return 0, 0, 0, err
+	}
 
-	// Any chunks remaining will be written out individually, using the default dictionary.
+	return groupCount, groupedChunkCount, ungroupedChunks, nil
+}
+
+// compressChunksInParallel compresses the remaining ungrouped chunks with the
+// default dictionary, fanning the CPU-bound compression out to a fixed pool of
+// workers while a single collector performs the serial archive writes.
+func compressChunksInParallel(
+	ctx context.Context,
+	allChunks hash.HashSet,
+	chunkCache *simpleChunkSourceCache,
+	arcW *archiveWriter,
+	defaultDict *gozstd.CDict,
+	defaultSpanId uint32,
+	progress chan<- interface{},
+	stats *Stats,
+) (uint32, error) {
+	type compressedChunk struct {
+		h    hash.Hash
+		data []byte
+	}
+
+	const workerCount = 32
+
+	workCh := make(chan hash.Hash, len(allChunks))
+	resultCh := make(chan compressedChunk, workerCount)
+	errCh := make(chan error, 1)
+	var wg sync.WaitGroup
+
+	// Prepopulate work channel
 	for h := range allChunks {
-		select {
-		case <-ctx.Done():
-			return 0, 0, 0, ctx.Err()
-		default:
-			dictId := uint32(0)
+		workCh <- h
+	}
+	close(workCh)
 
-			c, e2 := chunkCache.get(ctx, h, stats)
-			if e2 != nil {
-				return 0, 0, 0, e2
+	// Start worker goroutines
+	for i := 0; i < workerCount; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			// Reusable per-worker buffer for compression output.
+			cmpBuff := make([]byte, 0, fourMb)
+			for h := range workCh {
+				select {
+				case <-ctx.Done():
+					return
+				default:
+					c, e2 := chunkCache.get(ctx, h, stats)
+					if e2 != nil {
+						// Non-blocking send: only the first error is kept, and
+						// later failing workers must not block forever on errCh.
+						select {
+						case errCh <- e2:
+						default:
+						}
+						return
+					}
+					cmpBuff = gozstd.CompressDict(cmpBuff[:0], c.Data(), defaultDict)
+					// Copy out of the reusable buffer: the collector may still be
+					// reading this slice when the next iteration overwrites it.
+					data := append([]byte(nil), cmpBuff...)
+					select {
+					case resultCh <- compressedChunk{h: h, data: data}:
+					case <-ctx.Done():
+						return
+					}
+				}
 			}
+		}()
+	}
 
-			cmpBuff = gozstd.CompressDict(cmpBuff[:0], c.Data(), defaultDict)
-			dictId = defaultSpanId
+	// Close resultCh once all workers finish
+	go func() {
+		wg.Wait()
+		close(resultCh)
+	}()
 
-			id, err := arcW.writeByteSpan(cmpBuff)
-			if err != nil {
-				return 0, 0, 0, err
-			}
-			err = arcW.stageZStdChunk(h, dictId, id)
-			if err != nil {
-				return 0, 0, 0, err
-			}
+	// Collector: serial arcW calls
+	completed := int32(0)
+	totalChunks := int32(len(allChunks))
+	for cc := range resultCh {
+		id, err := arcW.writeByteSpan(cc.data)
+		if err != nil {
+			return 0, err
+		}
+		err = arcW.stageZStdChunk(cc.h, defaultSpanId, id)
+		if err != nil {
+			return 0, err
+		}
 
-			ungroupedChunkProgress++
-			progress <- ArchiveBuildProgressMsg{Stage: "Writing Ungrouped Chunks", Total: ungroupedChunkCount, Completed: ungroupedChunkProgress}
+		completed++
+		progress <- ArchiveBuildProgressMsg{
+			Stage:     "Writing Ungrouped Chunks",
+			Total:     totalChunks,
+			Completed: completed,
 		}
 	}
 
-	individualChunkCount = uint32(len(allChunks))
-
-	return
+	select {
+	case err := <-errCh:
+		return 0, err
+	case <-ctx.Done():
+		return 0, ctx.Err()
+	default:
+		return uint32(totalChunks), nil
+	}
 }
 
 // gatherAllChunks reads all the chunks from the chunk source and returns them in a map. The map is keyed by the hash of
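Reviewer note: the hand-rolled `workCh`/`errCh`/`sync.WaitGroup` plumbing above could also be expressed with `golang.org/x/sync/errgroup`, which cancels the shared context on the first failure and folds the error channel away. Below is a minimal, self-contained sketch of that shape, assuming in-memory inputs; `compressAll` and `item` are hypothetical names for illustration only, not identifiers from this codebase, and this is not the implementation the PR uses.

```go
package main

import (
	"context"
	"fmt"

	"github.com/valyala/gozstd"
	"golang.org/x/sync/errgroup"
)

type item struct {
	id   int
	data []byte
}

// compressAll fans compression out to a bounded pool and collects results
// in a single goroutine, mirroring compressChunksInParallel's shape.
func compressAll(ctx context.Context, dict *gozstd.CDict, inputs [][]byte) ([]item, error) {
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(8) // bounded pool, analogous to workerCount

	// Buffered to len(inputs) so workers never block on send.
	results := make(chan item, len(inputs))
	for i, raw := range inputs {
		i, raw := i, raw // capture loop variables (pre-Go 1.22)
		g.Go(func() error {
			if err := ctx.Err(); err != nil {
				return err // another worker already failed
			}
			// A fresh output slice per chunk sidesteps reusable-buffer races.
			results <- item{id: i, data: gozstd.CompressDict(nil, raw, dict)}
			return nil
		})
	}

	// Close results once every worker has returned, ending the collector loop.
	go func() {
		g.Wait()
		close(results)
	}()

	// Single consumer: the stand-in for the serial archiveWriter calls.
	var out []item
	for it := range results {
		out = append(out, it)
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	// zstd accepts raw content bytes as a dictionary.
	dict, err := gozstd.NewCDict([]byte("raw content dictionary bytes"))
	if err != nil {
		panic(err)
	}
	defer dict.Release()

	out, err := compressAll(context.Background(), dict, [][]byte{
		[]byte("first chunk of data"),
		[]byte("second chunk of data"),
	})
	if err != nil {
		panic(err)
	}
	fmt.Println("compressed", len(out), "chunks")
}
```

The trade-off is one extra dependency against less channel bookkeeping; either way, keeping all `arcW` calls on a single goroutine remains the load-bearing invariant.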