Skip to content

Commit

Permalink
tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Dec 13, 2024
1 parent 0d84d70 commit 7d84ae7
Show file tree
Hide file tree
Showing 47 changed files with 141 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import java.io.IOException;
import java.util.Iterator;

public interface InferenceChunks {
public interface ChunkedInference {

/**
* Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ void chunkedInfer(
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<List<InferenceChunks>> listener
ActionListener<List<ChunkedInference>> listener
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;

Expand All @@ -17,13 +17,12 @@
import java.util.Iterator;
import java.util.List;

public record ChunkedInferenceEmbeddingByte(List<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk> chunks) implements InferenceChunks {
public record ChunkedInferenceEmbeddingByte(List<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk> chunks) implements ChunkedInference {

@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
var asChunk = new ArrayList<Chunk>();
for (var chunk : chunks)
{
for (var chunk : chunks) {
asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding())));
}
return asChunk.iterator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@
package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;

public record ChunkedInferenceEmbeddingFloat(List<FloatEmbeddingChunk> chunks) implements InferenceChunks {
public record ChunkedInferenceEmbeddingFloat(List<FloatEmbeddingChunk> chunks) implements ChunkedInference {

@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
Expand All @@ -43,21 +41,5 @@ private static BytesReference toBytesReference(XContent xContent, float[] value)
return BytesReference.bytes(b);
}

public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset) {
//
// @Override
// public int hashCode() {
// return Objects.hash(Arrays.hashCode(embedding), matchedText, offset);
// }
//
// @Override
// public boolean equals(Object o) {
// if (this == o) return true;
// if (o == null || getClass() != o.getClass()) return false;
// FloatEmbeddingChunk that = (FloatEmbeddingChunk) o;
// return this.matchedText.equals(that.matchedText)
// && this.offset.equals(that.offset)
// && Arrays.equals(this.embedding, that.embedding);
// }
}
public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -21,12 +21,12 @@

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;

public record ChunkedInferenceEmbeddingSparse(List<SparseEmbeddingChunk> chunks) implements InferenceChunks {
public record ChunkedInferenceEmbeddingSparse(List<SparseEmbeddingChunk> chunks) implements ChunkedInference {

public static List<InferenceChunks> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());

var results = new ArrayList<InferenceChunks>(inputs.size());
var results = new ArrayList<ChunkedInference>(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
results.add(
new ChunkedInferenceEmbeddingSparse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;

import java.util.Iterator;
import java.util.stream.Stream;

public record ChunkedInferenceError(Exception exception) implements InferenceChunks {
public record ChunkedInferenceError(Exception exception) implements ChunkedInference {

@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.EmptySettingsConfiguration;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -151,7 +151,7 @@ public void chunkedInfer(
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<List<InferenceChunks>> listener
ActionListener<List<ChunkedInference>> listener
) {
switch (model.getConfigurations().getTaskType()) {
case ANY, TEXT_EMBEDDING -> {
Expand All @@ -176,18 +176,18 @@ private InferenceTextEmbeddingFloatResults makeResults(List<String> input, int d
return new InferenceTextEmbeddingFloatResults(embeddings);
}

private List<InferenceChunks> makeChunkedResults(List<String> input, int dimensions) {
private List<ChunkedInference> makeChunkedResults(List<String> input, int dimensions) {
InferenceTextEmbeddingFloatResults nonChunkedResults = makeResults(input, dimensions);

var results = new ArrayList<InferenceChunks>();
var results = new ArrayList<ChunkedInference>();
for (int i = 0; i < input.size(); i++) {
results.add(
new ChunkedInferenceEmbeddingFloat(
List.of(
new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
nonChunkedResults.embeddings().get(i).values(),
input.get(i),
new InferenceChunks.TextOffset(0, input.get(i).length())
new ChunkedInference.TextOffset(0, input.get(i).length())
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.EmptySettingsConfiguration;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -139,7 +139,7 @@ public void chunkedInfer(
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<List<InferenceChunks>> listener
ActionListener<List<ChunkedInference>> listener
) {
listener.onFailure(
new ElasticsearchStatusException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.EmptySettingsConfiguration;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -141,7 +141,7 @@ public void chunkedInfer(
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<List<InferenceChunks>> listener
ActionListener<List<ChunkedInference>> listener
) {
switch (model.getConfigurations().getTaskType()) {
case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input));
Expand All @@ -166,8 +166,8 @@ private SparseEmbeddingResults makeResults(List<String> input) {
return new SparseEmbeddingResults(embeddings);
}

private List<InferenceChunks> makeChunkedResults(List<String> input) {
List<InferenceChunks> results = new ArrayList<>();
private List<ChunkedInference> makeChunkedResults(List<String> input) {
List<ChunkedInference> results = new ArrayList<>();
for (int i = 0; i < input.size(); i++) {
var tokens = new ArrayList<WeightedToken>();
for (int j = 0; j < 5; j++) {
Expand All @@ -179,7 +179,7 @@ private List<InferenceChunks> makeChunkedResults(List<String> input) {
new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(
tokens,
input.get(i),
new InferenceChunks.TextOffset(0, input.get(i).length())
new ChunkedInference.TextOffset(0, input.get(i).length())
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.EmptySettingsConfiguration;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -234,7 +234,7 @@ public void chunkedInfer(
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<List<InferenceChunks>> listener
ActionListener<List<ChunkedInference>> listener
) {
listener.onFailure(
new ElasticsearchStatusException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
Expand Down Expand Up @@ -108,7 +104,6 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);

addInferenceResultsNamedWriteables(namedWriteables);
addChunkedInferenceResultsNamedWriteables(namedWriteables);

// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));
Expand Down Expand Up @@ -433,37 +428,6 @@ private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry
);
}

private static void addChunkedInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
ErrorChunkedInferenceResults.NAME,
ErrorChunkedInferenceResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
InferenceChunkedSparseEmbeddingResults.NAME,
InferenceChunkedSparseEmbeddingResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
InferenceChunkedTextEmbeddingFloatResults.NAME,
InferenceChunkedTextEmbeddingFloatResults::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceServiceResults.class,
InferenceChunkedTextEmbeddingByteResults.NAME,
InferenceChunkedTextEmbeddingByteResults::new
)
);
}

private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceChunks;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InputType;
Expand Down Expand Up @@ -142,7 +142,7 @@ private record FieldInferenceResponse(
int inputOrder,
boolean isOriginalFieldInput,
Model model,
InferenceChunks chunkedResults
ChunkedInference chunkedResults
) {}

private record FieldInferenceResponseAccumulator(
Expand Down Expand Up @@ -274,12 +274,12 @@ public void onFailure(Exception exc) {
final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize);
final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size());
final List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
ActionListener<List<InferenceChunks>> completionListener = new ActionListener<>() {
ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
@Override
public void onResponse(List<InferenceChunks> results) {
public void onResponse(List<ChunkedInference> results) {
try {
var requestsIterator = requests.iterator();
for (InferenceChunks result : results) {
for (ChunkedInference result : results) {
var request = requestsIterator.next();
var acc = inferenceResults.get(request.index);
if (result instanceof ChunkedInferenceError error) {
Expand Down Expand Up @@ -377,7 +377,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
// ensure that the order in the original field is consistent in case of multiple inputs
Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
List<String> inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList());
List<InferenceChunks> results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList());
List<ChunkedInference> results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList());
var result = new SemanticTextField(
fieldName,
inputs,
Expand Down
Loading

0 comments on commit 7d84ae7

Please sign in to comment.