Skip to content

Commit

Permalink
Moving byte embeddings to text_embedding_bytes field (#105290)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner authored Feb 8, 2024
1 parent 8cfcb70 commit 563c3e6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
- import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

Expand Down Expand Up @@ -43,7 +42,7 @@
*/
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements InferenceServiceResults, TextEmbedding {
public static final String NAME = "text_embedding_service_byte_results";
- public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
+ public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";

public TextEmbeddingByteResults(StreamInput in) throws IOException {
this(in.readCollectionAsList(Embedding::new));
Expand All @@ -56,7 +55,7 @@ public int getFirstEmbeddingSize() {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
- builder.startArray(TEXT_EMBEDDING);
+ builder.startArray(TEXT_EMBEDDING_BYTES);
for (Embedding embedding : embeddings) {
embedding.toXContent(builder, params);
}
Expand All @@ -78,7 +77,7 @@ public String getWriteableName() {
public List<? extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
.map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray())
- .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false))
+ .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING_BYTES, values, false))
.toList();
}

Expand All @@ -94,7 +93,7 @@ public List<? extends InferenceResults> transformToLegacyFormat() {

public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
- map.put(TEXT_EMBEDDING, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList()));
+ map.put(TEXT_EMBEDDING_BYTES, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList()));

return map;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
entity.asMap(),
is(
Map.of(
- TextEmbeddingByteResults.TEXT_EMBEDDING,
+ TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
List.of(Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 23)))
)
)
Expand All @@ -58,7 +58,7 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
{
- "text_embedding" : [
+ "text_embedding_bytes" : [
{
"embedding" : [
23
Expand All @@ -78,7 +78,7 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
entity.asMap(),
is(
Map.of(
- TextEmbeddingByteResults.TEXT_EMBEDDING,
+ TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
List.of(
Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 23)),
Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 24))
Expand All @@ -90,7 +90,7 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
String xContentResult = Strings.toString(entity, true, true);
assertThat(xContentResult, is("""
{
- "text_embedding" : [
+ "text_embedding_bytes" : [
{
"embedding" : [
23
Expand Down Expand Up @@ -118,12 +118,12 @@ public void testTransformToCoordinationFormat() {
is(
List.of(
new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
- TextEmbeddingByteResults.TEXT_EMBEDDING,
+ TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
new double[] { 23F, 24F },
false
),
new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
- TextEmbeddingByteResults.TEXT_EMBEDDING,
+ TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
new double[] { 25F, 26F },
false
)
Expand Down Expand Up @@ -158,7 +158,7 @@ protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults insta

public static Map<String, Object> buildExpectation(List<List<Byte>> embeddings) {
return Map.of(
- TextEmbeddingByteResults.TEXT_EMBEDDING,
+ TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList()
);
}
Expand Down

0 comments on commit 563c3e6

Please sign in to comment.