Skip to content

Commit

Permalink
Make TaskExecutor cx public and use TaskExecutor for concurrent HNSW …
Browse files Browse the repository at this point in the history
…graph build (apache#12799)

Make the TaskExecutor public which is currently pkg-private. At indexing time we concurrently create the hnsw graph (Concurrent HNSW Merge apache#12660). We could use the TaskExecutor implementation to do this for us.
Use TaskExecutor#invokeAll in HnswConcurrentMergeBuilder#build to run the workers concurrently.
  • Loading branch information
shubhamvishu authored and javanna committed Nov 21, 2023
1 parent 27c9736 commit e416d66
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 55 deletions.
5 changes: 5 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ API Changes
* GITHUB-12695: Remove public constructor of FSTCompiler. Please use FSTCompiler.Builder
instead. (Juan M. Caicedo)

* GITHUB#12735: Remove FSTCompiler#getTermCount() and FSTCompiler.UnCompiledNode#inputCount (Anh Dung Bui)

* GITHUB#12799: Make TaskExecutor constructor public and use TaskExecutor for concurrent
HNSW graph build. (Shubham Chaudhary)

New Features
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.hnsw.HnswGraph;

/**
Expand Down Expand Up @@ -60,7 +61,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
private final FlatVectorsFormat flatVectorsFormat;

private final int numMergeWorkers;
private final ExecutorService mergeExec;
private final TaskExecutor mergeExec;

/** Constructs a format using default graph construction parameters */
public Lucene99HnswScalarQuantizedVectorsFormat() {
Expand Down Expand Up @@ -121,7 +122,11 @@ public Lucene99HnswScalarQuantizedVectorsFormat(
"No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.hnsw.HnswGraph;

Expand Down Expand Up @@ -137,7 +138,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat();

private final int numMergeWorkers;
private final ExecutorService mergeExec;
private final TaskExecutor mergeExec;

/** Constructs a format using default graph construction parameters */
public Lucene99HnswVectorsFormat() {
Expand Down Expand Up @@ -192,7 +193,11 @@ public Lucene99HnswVectorsFormat(
"No executor service is needed as we'll use single thread to merge");
}
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
Expand All @@ -35,6 +34,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
Expand Down Expand Up @@ -67,7 +67,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private final int beamWidth;
private final FlatVectorsWriter flatVectorWriter;
private final int numMergeWorkers;
private final ExecutorService mergeExec;
private final TaskExecutor mergeExec;

private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
Expand All @@ -78,7 +78,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
int beamWidth,
FlatVectorsWriter flatVectorWriter,
int numMergeWorkers,
ExecutorService mergeExec)
TaskExecutor mergeExec)
throws IOException {
this.M = M;
this.flatVectorWriter = flatVectorWriter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ public final class TaskExecutor {

private final Executor executor;

TaskExecutor(Executor executor) {
/**
* Creates a TaskExecutor instance
*
* @param executor the executor to be used for running tasks concurrently
*/
public TaskExecutor(Executor executor) {
this.executor = Objects.requireNonNull(executor, "Executor is null");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;

/** This merger merges graph in a concurrent manner, by using {@link HnswConcurrentMergeBuilder} */
public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {

private final ExecutorService exec;
private final TaskExecutor taskExecutor;
private final int numWorker;

/**
Expand All @@ -38,10 +38,10 @@ public ConcurrentHnswMerger(
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
ExecutorService exec,
TaskExecutor taskExecutor,
int numWorker) {
super(fieldInfo, scorerSupplier, M, beamWidth);
this.exec = exec;
this.taskExecutor = taskExecutor;
this.numWorker = numWorker;
}

Expand All @@ -50,15 +50,21 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m
throws IOException {
if (initReader == null) {
return new HnswConcurrentMergeBuilder(
exec, numWorker, scorerSupplier, M, beamWidth, new OnHeapHnswGraph(M, maxOrd), null);
taskExecutor,
numWorker,
scorerSupplier,
M,
beamWidth,
new OnHeapHnswGraph(M, maxOrd),
null);
}

HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
BitSet initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);

return new HnswConcurrentMergeBuilder(
exec,
taskExecutor,
numWorker,
scorerSupplier,
M,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,12 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.ThreadInterruptedException;

/**
* A graph builder that manages multiple workers, it only supports adding the whole graph all at
Expand All @@ -41,20 +38,20 @@ public class HnswConcurrentMergeBuilder implements HnswBuilder {
private static final int DEFAULT_BATCH_SIZE =
2048; // number of vectors the worker handles sequentially at one batch

private final ExecutorService exec;
private final TaskExecutor taskExecutor;
private final ConcurrentMergeWorker[] workers;
private InfoStream infoStream = InfoStream.getDefault();

public HnswConcurrentMergeBuilder(
ExecutorService exec,
TaskExecutor taskExecutor,
int numWorker,
RandomVectorScorerSupplier scorerSupplier,
int M,
int beamWidth,
OnHeapHnswGraph hnsw,
BitSet initializedNodes)
throws IOException {
this.exec = exec;
this.taskExecutor = taskExecutor;
AtomicInteger workProgress = new AtomicInteger(0);
workers = new ConcurrentMergeWorker[numWorker];
for (int i = 0; i < numWorker; i++) {
Expand All @@ -77,42 +74,16 @@ public OnHeapHnswGraph build(int maxOrd) throws IOException {
HNSW_COMPONENT,
"build graph from " + maxOrd + " vectors, with " + workers.length + " workers");
}
List<Future<?>> futures = new ArrayList<>();
List<Callable<Void>> futures = new ArrayList<>();
for (int i = 0; i < workers.length; i++) {
int finalI = i;
futures.add(
exec.submit(
() -> {
try {
workers[finalI].run(maxOrd);
} catch (IOException e) {
throw new RuntimeException(e);
}
}));
}
Throwable exc = null;
for (Future<?> future : futures) {
try {
future.get();
} catch (InterruptedException e) {
var newException = new ThreadInterruptedException(e);
if (exc == null) {
exc = newException;
} else {
exc.addSuppressed(newException);
}
} catch (ExecutionException e) {
if (exc == null) {
exc = e.getCause();
} else {
exc.addSuppressed(e.getCause());
}
}
}
if (exc != null) {
// The error handling was copied from TaskExecutor. should we just use TaskExecutor instead?
throw IOUtils.rethrowAlways(exc);
() -> {
workers[finalI].run(maxOrd);
return null;
});
}
taskExecutor.invokeAll(futures);
return workers[0].getGraph();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
Expand Down Expand Up @@ -978,10 +979,11 @@ public void testConcurrentMergeBuilder() throws IOException {
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
TaskExecutor taskExecutor = new TaskExecutor(exec);
HnswGraphBuilder.randSeed = random().nextLong();
HnswConcurrentMergeBuilder builder =
new HnswConcurrentMergeBuilder(
exec, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
taskExecutor, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
builder.setBatchSize(100);
builder.build(size);
exec.shutdownNow();
Expand Down

0 comments on commit e416d66

Please sign in to comment.