[ML] Include the chunk text offsets in chunked inference response (el…
davidkyle authored Dec 16, 2024
1 parent 6453a66 commit c4e964e
Showing 56 changed files with 610 additions and 1,412 deletions.
@@ -12,23 +12,27 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.xcontent.XContent;

import java.io.IOException;
import java.util.Iterator;

public interface ChunkedInferenceServiceResults extends InferenceServiceResults {
public interface ChunkedInference {

/**
* Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields.
* The iterator iterates over all the chunks stored in the {@link ChunkedInferenceServiceResults}.
*
* @param xcontent provided by the SemanticTextField
* @return an iterator of the serialized {@link Chunk} which includes the matched text (input) and bytes reference (output/embedding).
*/
Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent);
Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException;

/**
* A chunk of inference results containing matched text and the bytes reference.
* A chunk of inference results containing matched text, the substring location
* in the original text and the bytes reference.
* @param matchedText
* @param textOffset
* @param bytesReference
*/
record Chunk(String matchedText, BytesReference bytesReference) {}
record Chunk(String matchedText, TextOffset textOffset, BytesReference bytesReference) {}

record TextOffset(int start, int end) {}
}
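
For orientation, a minimal, hedged sketch of how a consumer of the new interface might walk the chunks and read their offsets. The class name is hypothetical and JSON is assumed as the XContent format; the Chunk and TextOffset accessors are those declared above.

import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.json.JsonXContent;

import java.io.IOException;
import java.util.Iterator;

class ChunkOffsetConsumerSketch {
    // Walks every chunk and prints where its matched text sits in the original input.
    static void printChunkOffsets(ChunkedInference inference) throws IOException {
        Iterator<ChunkedInference.Chunk> chunks = inference.chunksAsMatchedTextAndByteReference(JsonXContent.jsonXContent);
        while (chunks.hasNext()) {
            ChunkedInference.Chunk chunk = chunks.next();
            ChunkedInference.TextOffset offset = chunk.textOffset();
            System.out.println("[" + offset.start() + ", " + offset.end() + ") -> " + chunk.matchedText());
        }
    }
}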
@@ -144,7 +144,7 @@ void chunkedInfer(
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<List<ChunkedInferenceServiceResults>> listener
ActionListener<List<ChunkedInference>> listener
);

/**
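
Only the tail of the chunkedInfer signature is visible in this hunk; the trailing "/**" opens the Javadoc of the next method, which the diff view has collapsed. The key change is that the listener now carries List<ChunkedInference> rather than List<ChunkedInferenceServiceResults>. Below is a hedged caller-side sketch with a hypothetical helper class; ActionListener.wrap is the standard Elasticsearch listener utility, and the one-result-per-input shape mirrors the listOf helper shown further down.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkedInference;

import java.util.List;

class ChunkedInferListenerSketch {
    // Builds a listener of the new type; error handling is left as a stub.
    static ActionListener<List<ChunkedInference>> newListener() {
        return ActionListener.wrap(results -> {
            // Typically one ChunkedInference per input string, in request order.
            for (ChunkedInference result : results) {
                // read result.chunksAsMatchedTextAndByteReference(...) here
            }
        }, e -> {
            // the inference request failed as a whole
        });
    }
}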
@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public record ChunkedInferenceEmbeddingByte(List<ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk> chunks) implements ChunkedInference {

@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
var asChunk = new ArrayList<Chunk>();
for (var chunk : chunks) {
asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding())));
}
return asChunk.iterator();
}

/**
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
XContentBuilder builder = XContentBuilder.builder(xContent);
builder.startArray();
for (byte v : value) {
builder.value(v);
}
builder.endArray();
return BytesReference.bytes(builder);
}

public record ByteEmbeddingChunk(byte[] embedding, String matchedText, TextOffset offset) {}
}
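
A hedged, self-contained sketch of what toBytesReference yields for a byte embedding when JSON is the chosen XContent; the class name, byte values, and matched text are made up for illustration.

import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte;

import java.io.IOException;
import java.util.List;

class ByteChunkSerializationSketch {
    public static void main(String[] args) throws IOException {
        String input = "some matched text";
        var result = new ChunkedInferenceEmbeddingByte(
            List.of(
                new ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk(
                    new byte[] { 12, -5, 7 },
                    input,
                    new ChunkedInference.TextOffset(0, input.length())
                )
            )
        );
        var chunk = result.chunksAsMatchedTextAndByteReference(JsonXContent.jsonXContent).next();
        // The embedding is rendered as a JSON array, e.g. [12,-5,7]
        System.out.println(chunk.bytesReference().utf8ToString());
    }
}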
@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public record ChunkedInferenceEmbeddingFloat(List<FloatEmbeddingChunk> chunks) implements ChunkedInference {

@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
var asChunk = new ArrayList<Chunk>();
for (var chunk : chunks) {
asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding())));
}
return asChunk.iterator();
}

/**
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startArray();
for (float v : value) {
b.value(v);
}
b.endArray();
return BytesReference.bytes(b);
}

public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset) {}
}
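
To show how the new TextOffset is meant to map chunks back into the original input, here is a hedged sketch that simply splits an input string in half; real chunkers pick word or sentence boundaries, and the float vectors would come from the model rather than being hard coded.

import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;

import java.util.List;

class FloatChunkOffsetSketch {
    // Produces two chunks whose offsets index into the original input string.
    static ChunkedInferenceEmbeddingFloat twoChunksOf(String input) {
        int mid = input.length() / 2;
        return new ChunkedInferenceEmbeddingFloat(
            List.of(
                new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
                    new float[] { 0.1f, 0.2f },
                    input.substring(0, mid),
                    new ChunkedInference.TextOffset(0, mid)
                ),
                new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
                    new float[] { 0.3f, 0.4f },
                    input.substring(mid),
                    new ChunkedInference.TextOffset(mid, input.length())
                )
            )
        );
    }
}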
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;

public record ChunkedInferenceEmbeddingSparse(List<SparseEmbeddingChunk> chunks) implements ChunkedInference {

public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());

var results = new ArrayList<ChunkedInference>(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
results.add(
new ChunkedInferenceEmbeddingSparse(
List.of(
new SparseEmbeddingChunk(
sparseEmbeddingResults.embeddings().get(i).tokens(),
inputs.get(i),
new TextOffset(0, inputs.get(i).length())
)
)
)
);
}

return results;
}

@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException {
var asChunk = new ArrayList<Chunk>();
for (var chunk : chunks) {
asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.weightedTokens())));
}
return asChunk.iterator();
}

private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startObject();
for (var weightedToken : tokens) {
weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
}
b.endObject();
return BytesReference.bytes(b);
}

public record SparseEmbeddingChunk(List<WeightedToken> weightedTokens, String matchedText, TextOffset offset) {}
}
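
And a hedged sketch of the sparse case, mirroring what listOf builds: one chunk per input whose offset spans the whole string, with the weighted tokens serialized as a JSON object. The token strings and weights are invented, and a (token, weight) constructor for WeightedToken is assumed.

import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;

import java.io.IOException;
import java.util.List;

class SparseChunkSerializationSketch {
    public static void main(String[] args) throws IOException {
        String input = "hello world";
        var sparse = new ChunkedInferenceEmbeddingSparse(
            List.of(
                new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(
                    List.of(new WeightedToken("hello", 0.9f), new WeightedToken("world", 0.4f)),
                    input,
                    new ChunkedInference.TextOffset(0, input.length())
                )
            )
        );
        var chunk = sparse.chunksAsMatchedTextAndByteReference(JsonXContent.jsonXContent).next();
        // Tokens render as a JSON object, e.g. {"hello":0.9,"world":0.4}
        System.out.println(chunk.bytesReference().utf8ToString());
    }
}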
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xcontent.XContent;

import java.util.Iterator;
import java.util.stream.Stream;

public record ChunkedInferenceError(Exception exception) implements ChunkedInference {

@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
return Stream.of(exception).map(e -> new Chunk(e.getMessage(), new TextOffset(0, 0), BytesArray.EMPTY)).iterator();
}
}

This file was deleted.
