Batch the chunks
davidkyle committed Oct 23, 2024
1 parent 57532e7 commit 2370870
Showing 3 changed files with 192 additions and 33 deletions.
@@ -62,6 +62,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;

@@ -668,25 +669,13 @@ public void chunkedInfer(
).batchRequestsWithListeners(listener);
}

for (var batch : batchedRequests) {
var inferenceRequest = buildInferenceRequest(
esModel.mlNodeDeploymentId(),
EmptyConfigUpdate.INSTANCE,
batch.batch().inputs(),
inputType,
timeout
);

ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
.delegateFailureAndWrap(
(l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
);

var maybeDeployListener = mlResultsListener.delegateResponse(
(l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
);

client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
if (batchedRequests.isEmpty()) {
listener.onResponse(List.of());
} else {
// Avoid filling the inference queue by executing the batches in series
// Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests
var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
sequentialRunner.doNextRequest();
}
} else {
listener.onFailure(notElasticsearchModelException(model));
@@ -1004,4 +993,58 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
return null;
}
}

// Iterates over the batches, sending one request at a time to avoid
// filling the ML node inference queue.
class BatchIterator {
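// index of the next batch to send; advanced after each batch's results are handled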
private final AtomicInteger index = new AtomicInteger();
private final ElasticsearchInternalModel esModel;
private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
private final InputType inputType;
private final TimeValue timeout;

BatchIterator(
ElasticsearchInternalModel esModel,
InputType inputType,
TimeValue timeout,
List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
) {
this.esModel = esModel;
this.requestAndListeners = requestAndListeners;
this.inputType = inputType;
this.timeout = timeout;
}

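// Submits the batch at the current index to the inference executor. Called
// once to start and then again from the listener as each batch completes.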
void doNextRequest() {
inferenceExecutor.execute(() -> inferOnBatch(requestAndListeners.get(index.get())));
}

private void inferOnBatch(EmbeddingRequestChunker.BatchRequestAndListener batch) {
var inferenceRequest = buildInferenceRequest(
esModel.mlNodeDeploymentId(),
EmptyConfigUpdate.INSTANCE,
batch.batch().inputs(),
inputType,
timeout
);

ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
.delegateFailureAndWrap(
(l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
);

// schedule the next request once the results have been processed
var scheduleNextListener = ActionListener.runAfter(mlResultsListener, () -> {
if (index.incrementAndGet() < requestAndListeners.size()) {
doNextRequest();
}
});

var maybeDeployListener = scheduleNextListener.delegateResponse(
(l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, scheduleNextListener)
);

client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
}
}
}
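
The new BatchIterator keeps a single inference request in flight at a time: each batch's listener schedules the next batch only after the previous response (or failure) has been handled, so the ML node's inference queue is never flooded. Below is a minimal, self-contained sketch of that pattern under simplified assumptions; SerialBatchRunner and its AsyncCall interface are illustrative stand-ins, not the Elasticsearch Client/ActionListener APIs used in the change above.

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

// Illustrative serial runner: one request in flight at a time.
final class SerialBatchRunner<REQ, RESP> {

    @FunctionalInterface
    interface AsyncCall<Q, P> {
        void execute(Q request, BiConsumer<P, Exception> onDone);
    }

    private final AtomicInteger index = new AtomicInteger();
    private final List<REQ> batches;
    private final AsyncCall<REQ, RESP> client;

    SerialBatchRunner(List<REQ> batches, AsyncCall<REQ, RESP> client) {
        this.batches = batches;
        this.client = client;
    }

    void run() {
        if (batches.isEmpty() == false) {
            sendNext(); // kick off the first batch; the rest follow one by one
        }
    }

    private void sendNext() {
        REQ request = batches.get(index.get());
        client.execute(request, (response, error) -> {
            // handle the response or error for this batch, then submit the
            // next batch only once this one has completed (success or failure)
            if (index.incrementAndGet() < batches.size()) {
                sendNext();
            }
        });
    }
}

The production code gets the same ordering from ActionListener.runAfter, which runs the scheduling step after either onResponse or onFailure, so a failed batch does not stall the batches behind it.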
@@ -24,12 +24,25 @@
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;

public class EmbeddingRequestChunkerTests extends ESTestCase {

public void testEmptyInput() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
assertThat(batches, empty());
}

public void testBlankInput() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
}

public void testShortInputsAreSingleBatch() {
String input = "one chunk";
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
@@ -12,6 +12,7 @@
import org.apache.logging.log4j.Level;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
@@ -60,6 +61,7 @@
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.junit.After;
import org.junit.Before;
@@ -73,6 +75,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@@ -936,17 +939,17 @@ public void testParsePersistedConfig() {
}
}

public void testChunkInfer_E5WithNullChunkingSettings() {
public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
testChunkInfer_e5(null);
}

public void testChunkInfer_E5ChunkingSettingsSetAndFeatureFlagEnabled() {
public void testChunkInfer_E5ChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
testChunkInfer_e5(ChunkingSettingsTests.createRandomChunkingSettings());
}

@SuppressWarnings("unchecked")
private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -994,6 +997,9 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.chunkedInfer(
model,
null,
@@ -1002,23 +1008,24 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

public void testChunkInfer_SparseWithNullChunkingSettings() {
public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException {
testChunkInfer_Sparse(null);
}

public void testChunkInfer_SparseWithChunkingSettingsSetAndFeatureFlagEnabled() {
public void testChunkInfer_SparseWithChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings());
}

@SuppressWarnings("unchecked")
private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1042,6 +1049,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
var service = createService(client);

var gotResults = new AtomicBoolean();

var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
assertThat(chunkedResponse, hasSize(2));
assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
@@ -1061,6 +1069,9 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.chunkedInfer(
model,
null,
@@ -1069,23 +1080,24 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

public void testChunkInfer_ElserWithNullChunkingSettings() {
public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException {
testChunkInfer_Elser(null);
}

public void testChunkInfer_ElserWithChunkingSettingsSetAndFeatureFlagEnabled() {
public void testChunkInfer_ElserWithChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
testChunkInfer_Elser(ChunkingSettingsTests.createRandomChunkingSettings());
}

@SuppressWarnings("unchecked")
private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1129,6 +1141,9 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.chunkedInfer(
model,
null,
@@ -1137,9 +1152,10 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

@@ -1200,7 +1216,7 @@ public void testChunkInferSetsTokenization() {
}

@SuppressWarnings("unchecked")
public void testChunkInfer_FailsBatch() {
public void testChunkInfer_FailsBatch() throws InterruptedException {
var mlTrainedModelResults = new ArrayList<InferenceResults>();
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -1236,6 +1252,9 @@ public void testChunkInfer_FailsBatch() {
gotResults.set(true);
}, ESTestCase::fail);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

service.chunkedInfer(
model,
null,
@@ -1244,9 +1263,93 @@ public void testChunkInfer_FailsBatch() {
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

@SuppressWarnings("unchecked")
public void testChunkingLargeDocument() throws InterruptedException {
assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());

int wordsPerChunk = 10;
int numBatches = randomIntBetween(3, 6);
int numChunks = randomIntBetween(
((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE) + 1,
numBatches * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE
);
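// For illustration: if EMBEDDING_MAX_BATCH_SIZE were 10 and numBatches were 3,
// numChunks would be drawn from 21..30, which always requires exactly 3 batches.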

// build a doc with enough words to produce numChunks chunks
int numWords = numChunks * wordsPerChunk;
var docBuilder = new StringBuilder();
for (int i = 0; i < numWords; i++) {
docBuilder.append("word ");
}

// how many response objects to return in each batch
int[] numResponsesPerBatch = new int[numBatches];
for (int i = 0; i < numBatches - 1; i++) {
numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
}
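// whatever is left over goes in the final batch; computed by subtraction so that
// an exact multiple of the batch size still yields a full final batch rather than zero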
numResponsesPerBatch[numBatches - 1] = numChunks - ((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE);

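// tracks which batch the mocked client is currently answering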
var batchIndex = new AtomicInteger();
Client client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);

// mock the inference response
doAnswer(invocationOnMock -> {
var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];

var mlTrainedModelResults = new ArrayList<InferenceResults>();
for (int i = 0; i < numResponsesPerBatch[batchIndex.get()]; i++) {
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
}
batchIndex.incrementAndGet();
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
listener.onResponse(response);
return null;
}).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));

var service = createService(client);

var gotResults = new AtomicBoolean();
var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
assertThat(chunkedResponse, hasSize(1));
assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
var embeddingResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
assertThat(embeddingResults.chunks(), hasSize(numChunks));

gotResults.set(true);
}, ESTestCase::fail);

// Create model using the word boundary chunker.
var model = new MultilingualE5SmallModel(
"foo",
TaskType.TEXT_EMBEDDING,
"e5",
new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null),
new WordBoundaryChunkingSettings(wordsPerChunk, 0)
);

var latch = new CountDownLatch(1);
var latchedListener = new LatchedActionListener<>(resultsListener, latch);

// For the given input we know exactly how many requests will be made:
// numChunks chunks batched at EMBEDDING_MAX_BATCH_SIZE per request means numBatches requests
service.chunkedInfer(
model,
null,
List.of(docBuilder.toString()),
Map.of(),
InputType.SEARCH,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
latchedListener
);

latch.await();
assertTrue("Listener not called", gotResults.get());
}

