From e416d66495a2dfea87e8798e30feef4d81f9f79d Mon Sep 17 00:00:00 2001 From: Shubham Chaudhary <36742242+shubhamvishu@users.noreply.github.com> Date: Tue, 21 Nov 2023 21:54:01 +0530 Subject: [PATCH] Make TaskExecutor cx public and use TaskExecutor for concurrent HNSW graph build (#12799) Make the TaskExecutor public which is currently pkg-private. At indexing time we concurrently create the hnsw graph (Concurrent HNSW Merge #12660). We could use the TaskExecutor implementation to do this for us. Use TaskExecutor#invokeAll in HnswConcurrentMergeBuilder#build to run the workers concurrently. --- lucene/CHANGES.txt | 5 ++ ...ene99HnswScalarQuantizedVectorsFormat.java | 9 +++- .../lucene99/Lucene99HnswVectorsFormat.java | 9 +++- .../lucene99/Lucene99HnswVectorsWriter.java | 6 +-- .../apache/lucene/search/TaskExecutor.java | 7 ++- .../util/hnsw/ConcurrentHnswMerger.java | 18 ++++--- .../util/hnsw/HnswConcurrentMergeBuilder.java | 51 ++++--------------- .../lucene/util/hnsw/HnswGraphTestCase.java | 4 +- 8 files changed, 54 insertions(+), 55 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index cfbcd6ca62d8..7d50f8939d4e 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -40,6 +40,11 @@ API Changes * GITHUB-12695: Remove public constructor of FSTCompiler. Please use FSTCompiler.Builder instead. (Juan M. Caicedo) +* GITHUB#12735: Remove FSTCompiler#getTermCount() and FSTCompiler.UnCompiledNode#inputCount (Anh Dung Bui) + +* GITHUB#12799: Make TaskExecutor constructor public and use TaskExecutor for concurrent + HNSW graph build. (Shubham Chaudhary) + New Features --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java index 1b45c7fe44b9..6023777ea944 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java @@ -31,6 +31,7 @@ import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.hnsw.HnswGraph; /** @@ -60,7 +61,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo private final FlatVectorsFormat flatVectorsFormat; private final int numMergeWorkers; - private final ExecutorService mergeExec; + private final TaskExecutor mergeExec; /** Constructs a format using default graph construction parameters */ public Lucene99HnswScalarQuantizedVectorsFormat() { @@ -121,7 +122,11 @@ public Lucene99HnswScalarQuantizedVectorsFormat( "No executor service is needed as we'll use single thread to merge"); } this.numMergeWorkers = numMergeWorkers; - this.mergeExec = mergeExec; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java index 85d65df55b99..e2e154a6c514 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.hnsw.HnswGraph; @@ -137,7 +138,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(); private final int numMergeWorkers; - private final ExecutorService mergeExec; + private final TaskExecutor mergeExec; /** Constructs a format using default graph construction parameters */ public Lucene99HnswVectorsFormat() { @@ -192,7 +193,11 @@ public Lucene99HnswVectorsFormat( "No executor service is needed as we'll use single thread to merge"); } this.numMergeWorkers = numMergeWorkers; - this.mergeExec = mergeExec; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index 230ee26564e2..f39069ba981b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -23,7 +23,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutorService; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.FlatVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; @@ -35,6 +34,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.InfoStream; @@ -67,7 +67,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { private final int beamWidth; private final FlatVectorsWriter flatVectorWriter; private final int numMergeWorkers; - private final ExecutorService mergeExec; + private final TaskExecutor mergeExec; private final List> fields = new ArrayList<>(); private boolean finished; @@ -78,7 +78,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { int beamWidth, FlatVectorsWriter flatVectorWriter, int numMergeWorkers, - ExecutorService mergeExec) + TaskExecutor mergeExec) throws IOException { this.M = M; this.flatVectorWriter = flatVectorWriter; diff --git a/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java b/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java index f2be51206e43..5a0447acd0b6 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java +++ b/lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java @@ -53,7 +53,12 @@ public final class TaskExecutor { private final Executor executor; - TaskExecutor(Executor executor) { + /** + * Creates a TaskExecutor instance + * + * @param executor the executor to be used for running tasks concurrently + */ + public TaskExecutor(Executor executor) { this.executor = Objects.requireNonNull(executor, "Executor is null"); } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java index 2253e735880d..38ecc38467b4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java @@ -17,17 +17,17 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; -import java.util.concurrent.ExecutorService; import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; /** This merger merges graph in a concurrent manner, by using {@link HnswConcurrentMergeBuilder} */ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger { - private final ExecutorService exec; + private final TaskExecutor taskExecutor; private final int numWorker; /** @@ -38,10 +38,10 @@ public ConcurrentHnswMerger( RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, - ExecutorService exec, + TaskExecutor taskExecutor, int numWorker) { super(fieldInfo, scorerSupplier, M, beamWidth); - this.exec = exec; + this.taskExecutor = taskExecutor; this.numWorker = numWorker; } @@ -50,7 +50,13 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m throws IOException { if (initReader == null) { return new HnswConcurrentMergeBuilder( - exec, numWorker, scorerSupplier, M, beamWidth, new OnHeapHnswGraph(M, maxOrd), null); + taskExecutor, + numWorker, + scorerSupplier, + M, + beamWidth, + new OnHeapHnswGraph(M, maxOrd), + null); } HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name); @@ -58,7 +64,7 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes); return new HnswConcurrentMergeBuilder( - exec, + taskExecutor, numWorker, scorerSupplier, M, diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java index 27e555a64adb..07d1ae2f698d 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java @@ -22,15 +22,12 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; +import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; -import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.InfoStream; -import org.apache.lucene.util.ThreadInterruptedException; /** * A graph builder that manages multiple workers, it only supports adding the whole graph all at @@ -41,12 +38,12 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder { private static final int DEFAULT_BATCH_SIZE = 2048; // number of vectors the worker handles sequentially at one batch - private final ExecutorService exec; + private final TaskExecutor taskExecutor; private final ConcurrentMergeWorker[] workers; private InfoStream infoStream = InfoStream.getDefault(); public HnswConcurrentMergeBuilder( - ExecutorService exec, + TaskExecutor taskExecutor, int numWorker, RandomVectorScorerSupplier scorerSupplier, int M, @@ -54,7 +51,7 @@ public HnswConcurrentMergeBuilder( OnHeapHnswGraph hnsw, BitSet initializedNodes) throws IOException { - this.exec = exec; + this.taskExecutor = taskExecutor; AtomicInteger workProgress = new AtomicInteger(0); workers = new ConcurrentMergeWorker[numWorker]; for (int i = 0; i < numWorker; i++) { @@ -77,42 +74,16 @@ public OnHeapHnswGraph build(int maxOrd) throws IOException { HNSW_COMPONENT, "build graph from " + maxOrd + " vectors, with " + workers.length + " workers"); } - List> futures = new ArrayList<>(); + List> futures = new ArrayList<>(); for (int i = 0; i < workers.length; i++) { int finalI = i; futures.add( - exec.submit( - () -> { - try { - workers[finalI].run(maxOrd); - } catch (IOException e) { - throw new RuntimeException(e); - } - })); - } - Throwable exc = null; - for (Future future : futures) { - try { - future.get(); - } catch (InterruptedException e) { - var newException = new ThreadInterruptedException(e); - if (exc == null) { - exc = newException; - } else { - exc.addSuppressed(newException); - } - } catch (ExecutionException e) { - if (exc == null) { - exc = e.getCause(); - } else { - exc.addSuppressed(e.getCause()); - } - } - } - if (exc != null) { - // The error handling was copied from TaskExecutor. should we just use TaskExecutor instead? - throw IOUtils.rethrowAlways(exc); + () -> { + workers[finalI].run(maxOrd); + return null; + }); } + taskExecutor.invokeAll(futures); return workers[0].getGraph(); } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 30105c2167b3..dde992d75dd4 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -66,6 +66,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; @@ -978,10 +979,11 @@ public void testConcurrentMergeBuilder() throws IOException { AbstractMockVectorValues vectors = vectorValues(size, dim); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge")); + TaskExecutor taskExecutor = new TaskExecutor(exec); HnswGraphBuilder.randSeed = random().nextLong(); HnswConcurrentMergeBuilder builder = new HnswConcurrentMergeBuilder( - exec, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null); + taskExecutor, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null); builder.setBatchSize(100); builder.build(size); exec.shutdownNow();