Add Highlighter for Semantic Text Fields
This PR introduces a new highlighter, `semantic`, tailored for semantic text fields.
It extracts the most relevant fragments by scoring nested chunks using the original semantic query.

In this initial version, the highlighter returns only the original chunks computed during ingestion. However, this is an implementation detail, and future enhancements could combine multiple chunks to generate the fragments.
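
As an illustration, the highlighter is enabled per field through the standard highlight API. A minimal sketch (index name and query text are hypothetical; the response is abbreviated):

```
POST my-index/_search
{
  "query": {
    "semantic": {
      "field": "my_semantic_field",
      "query": "where are the droids?"
    }
  },
  "highlight": {
    "fields": {
      "my_semantic_field": {
        "type": "semantic",
        "number_of_fragments": 2
      }
    }
  }
}
```

Each hit then carries its best-scoring chunks as ordinary highlight fragments:

```
"highlight": {
  "my_semantic_field": [
    "these are not the droids you're looking for"
  ]
}
```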
jimczi committed Dec 5, 2024
1 parent 2fe6b60 commit 534a96a
Showing 9 changed files with 1,026 additions and 43 deletions.
53 changes: 22 additions & 31 deletions docs/reference/mapping/types/semantic-text.asciidoc
@@ -112,50 +112,41 @@ Trying to <<delete-inference-api,delete an {infer} endpoint>> that is used on a
 {infer-cap} endpoints have a limit on the amount of text they can process.
 To allow for large amounts of text to be used in semantic search, `semantic_text` automatically generates smaller passages if needed, called _chunks_.
 
-Each chunk will include the text subpassage and the corresponding embedding generated from it.
+Each chunk refers to a passage of the text and the corresponding embedding generated from it.
 When querying, the individual passages will be automatically searched for each document, and the most relevant passage will be used to compute a score.
 
 For more details on chunking and how to configure chunking settings, see <<infer-chunking-config, Configuring chunking>> in the Inference API documentation.
 
+Refer to <<semantic-search-semantic-text,this tutorial>> to learn more about
+semantic search using `semantic_text` and the `semantic` query.
+
 [discrete]
-[[semantic-text-structure]]
-==== `semantic_text` structure
+[[semantic-text-highlighting]]
+==== Extracting Relevant Fragments from Semantic Text
 
-Once a document is ingested, a `semantic_text` field will have the following structure:
+You can extract the most relevant fragments from a semantic text field by using the <<highlighting,highlight parameter>> in the <<search-search-api-request-body,Search API>>.
 
-[source,console-result]
+[source,console]
 ------------------------------------------------------------
-"inference_field": {
-  "text": "these are not the droids you're looking for", <1>
-  "inference": {
-    "inference_id": "my-elser-endpoint", <2>
-    "model_settings": { <3>
-      "task_type": "sparse_embedding"
-    },
-    "chunks": [ <4>
-      {
-        "text": "these are not the droids you're looking for",
-        "embeddings": {
-          (...)
-        }
-      }
-    ]
-  }
-}
+POST test-index/_search
+{
+    "query": {
+        "semantic": {
+            "field": "my_semantic_field"
+        }
+    },
+    "highlight": {
+        "fields": {
+            "my_semantic_field": {
+                "type": "semantic",
+                "number_of_fragments": 2 <1>
+            }
+        }
+    }
+}
 ------------------------------------------------------------
-// TEST[skip:TBD]
-<1> The field will become an object structure to accommodate both the original
-text and the inference results.
-<2> The `inference_id` used to generate the embeddings.
-<3> Model settings, including the task type and dimensions/similarity if
-applicable.
-<4> Inference results will be grouped in chunks, each with its corresponding
-text and embeddings.
-
-Refer to <<semantic-search-semantic-text,this tutorial>> to learn more about
-semantic search using `semantic_text` and the `semantic` query.
+// TEST[skip:Requires inference endpoint]
+<1> Specifies the maximum number of fragments to return.
 
 [discrete]
 [[custom-indexing]]
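
Because each fragment is a scored chunk, fragments can also be ordered by relevance using the standard `order` highlighting option, which the implementation honors via `scoreOrdered`. A minimal sketch against the example above (the `query` text is illustrative; the `semantic` query needs the search text to score chunks):

[source,console]
------------------------------------------------------------
POST test-index/_search
{
    "query": {
        "semantic": {
            "field": "my_semantic_field",
            "query": "What are the key takeaways?"
        }
    },
    "highlight": {
        "fields": {
            "my_semantic_field": {
                "type": "semantic",
                "number_of_fragments": 2,
                "order": "score"
            }
        }
    }
}
------------------------------------------------------------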
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -37,6 +37,7 @@
 import org.elasticsearch.plugins.SystemIndexPlugin;
 import org.elasticsearch.rest.RestController;
 import org.elasticsearch.rest.RestHandler;
+import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
 import org.elasticsearch.search.rank.RankBuilder;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.threadpool.ExecutorBuilder;
@@ -67,6 +68,7 @@
 import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
+import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
@@ -411,4 +413,9 @@ public List<RetrieverSpec<?>> getRetrievers() {
             new RetrieverSpec<>(new ParseField(RandomRankBuilder.NAME), RandomRankRetrieverBuilder::fromXContent)
         );
     }
+
+    @Override
+    public Map<String, Highlighter> getHighlighters() {
+        return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter());
+    }
 }
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java
@@ -0,0 +1,208 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.highlight;

import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
 * A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}.
 * This highlighter extracts semantic queries and evaluates them against each chunk produced by the semantic text field.
 * It returns the top-scoring chunks as snippets, optionally sorted by their scores.
 */
public class SemanticTextHighlighter implements Highlighter {
    public static final String NAME = "semantic";

    private record OffsetAndScore(int offset, float score) {}

    @Override
    public boolean canHighlight(MappedFieldType fieldType) {
        return fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType;
    }

    @Override
    public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
        SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType;
        if (fieldType.getEmbeddingsField() == null) {
            // nothing indexed yet
            return null;
        }

        // Collect the inner vector queries that target this field's embeddings, depending on the task type.
        final List<Query> queries = switch (fieldType.getModelSettings().taskType()) {
            case SPARSE_EMBEDDING -> extractSparseVectorQueries(
                (SparseVectorFieldType) fieldType.getEmbeddingsField().fieldType(),
                fieldContext.query
            );
            case TEXT_EMBEDDING -> extractDenseVectorQueries(
                (DenseVectorFieldType) fieldType.getEmbeddingsField().fieldType(),
                fieldContext.query
            );
            default -> throw new IllegalStateException(
                "Wrong task type for a semantic text field, got [" + fieldType.getModelSettings().taskType().name() + "]"
            );
        };
        if (queries.isEmpty()) {
            // nothing to highlight
            return null;
        }

        int numberOfFragments = fieldContext.field.fieldOptions().numberOfFragments() <= 0
            ? 1 // we return the best fragment by default
            : fieldContext.field.fieldOptions().numberOfFragments();

        List<OffsetAndScore> chunks = extractOffsetAndScores(
            fieldContext.context.getSearchExecutionContext(),
            fieldContext.hitContext.reader(),
            fieldType,
            fieldContext.hitContext.docId(),
            queries
        );
        if (chunks.isEmpty()) {
            return null;
        }

        chunks.sort(Comparator.comparingDouble(OffsetAndScore::score).reversed());
        int size = Math.min(chunks.size(), numberOfFragments);
        if (fieldContext.field.fieldOptions().scoreOrdered() == false) {
            // restore the original chunk order when score ordering was not requested
            chunks = chunks.subList(0, size);
            chunks.sort(Comparator.comparingInt(c -> c.offset));
        }
        Text[] snippets = new Text[size];
        // Resolve each chunk offset against the nested `chunks` source to recover the original text.
        List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
            fieldType.getChunksField().fullPath(),
            fieldContext.hitContext.source().source()
        );
        for (int i = 0; i < size; i++) {
            var chunk = chunks.get(i);
            if (nestedSources.size() <= chunk.offset) {
                throw new IllegalStateException("Invalid content for field [" + fieldType.name() + "]");
            }
            String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD);
            if (content == null) {
                throw new IllegalStateException("Invalid content for field [" + fieldType.name() + "]");
            }
            snippets[i] = new Text(content);
        }
        return new HighlightField(fieldContext.fieldName, snippets);
    }

    private List<OffsetAndScore> extractOffsetAndScores(
        SearchExecutionContext context,
        LeafReader reader,
        SemanticTextFieldMapper.SemanticTextFieldType fieldType,
        int docId,
        List<Query> leafQueries
    ) throws IOException {
        // The nested chunks of this document are the child docs between the previous parent and docId.
        var bitSet = context.bitsetFilter(fieldType.getChunksField().parentTypeFilter()).getBitSet(reader.getContext());
        int previousParent = docId > 0 ? bitSet.prevSetBit(docId - 1) : -1;

        BooleanQuery.Builder bq = new BooleanQuery.Builder().add(fieldType.getChunksField().nestedTypeFilter(), BooleanClause.Occur.FILTER);
        leafQueries.forEach(q -> bq.add(q, BooleanClause.Occur.SHOULD));
        Weight weight = new IndexSearcher(reader).createWeight(bq.build(), ScoreMode.COMPLETE, 1);
        Scorer scorer = weight.scorer(reader.getContext());
        if (scorer == null) {
            // no chunk matches the query in this segment
            return List.of();
        }
        if (previousParent != -1) {
            if (scorer.iterator().advance(previousParent) == DocIdSetIterator.NO_MORE_DOCS) {
                return List.of();
            }
        } else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
            return List.of();
        }
        List<OffsetAndScore> results = new ArrayList<>();
        int offset = 0;
        // Walk the child (chunk) docs that precede the parent docId and record each chunk's score.
        while (scorer.docID() < docId) {
            results.add(new OffsetAndScore(offset++, scorer.score()));
            if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
                break;
            }
        }
        return results;
    }

    private List<Query> extractDenseVectorQueries(DenseVectorFieldType fieldType, Query querySection) {
        // TODO: Handle knn section when semantic text field can be used.
        List<Query> queries = new ArrayList<>();
        querySection.visit(new QueryVisitor() {
            @Override
            public boolean acceptField(String field) {
                // only visit queries that target the embeddings field
                return fieldType.name().equals(field);
            }

            @Override
            public void consumeTerms(Query query, Term... terms) {
                super.consumeTerms(query, terms);
            }

            @Override
            public void visitLeaf(Query query) {
                // Rewrite approximate knn queries into exact knn against the chunk embeddings.
                if (query instanceof KnnFloatVectorQuery knnQuery) {
                    queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats(knnQuery.getTargetCopy()), null));
                } else if (query instanceof KnnByteVectorQuery knnQuery) {
                    queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes(knnQuery.getTargetCopy()), null));
                }
            }
        });
        return queries;
    }

    private List<Query> extractSparseVectorQueries(SparseVectorFieldType fieldType, Query querySection) {
        List<Query> queries = new ArrayList<>();
        querySection.visit(new QueryVisitor() {
            @Override
            public boolean acceptField(String field) {
                // only visit queries that target the embeddings field
                return fieldType.name().equals(field);
            }

            @Override
            public void consumeTerms(Query query, Term... terms) {
                super.consumeTerms(query, terms);
            }

            @Override
            public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {
                // Extract the weighted-terms query wrapped by the sparse vector query.
                if (parent instanceof SparseVectorQueryWrapper sparseVectorQuery) {
                    queries.add(sparseVectorQuery.getTermsQuery());
                }
                return this;
            }
        });
        return queries;
    }
}
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
@@ -61,7 +61,7 @@ public record SemanticTextField(String fieldName, List<String> originalValues, I
     static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
     static final String CHUNKS_FIELD = "chunks";
     static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings";
-    static final String CHUNKED_TEXT_FIELD = "text";
+    public static final String CHUNKED_TEXT_FIELD = "text";
     static final String MODEL_SETTINGS_FIELD = "model_settings";
     static final String TASK_TYPE_FIELD = "task_type";
     static final String DIMENSIONS_FIELD = "dimensions";
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -46,7 +46,6 @@
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.NestedQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.SimilarityMeasure;
@@ -57,6 +56,7 @@
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -529,17 +529,8 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, Integer req
                     );
                 }
 
-                // TODO: Use WeightedTokensQueryBuilder
                 TextExpansionResults textExpansionResults = (TextExpansionResults) inferenceResults;
-                var boolQuery = QueryBuilders.boolQuery();
-                for (var weightedToken : textExpansionResults.getWeightedTokens()) {
-                    boolQuery.should(
-                        QueryBuilders.termQuery(inferenceResultsFieldName, weightedToken.token()).boost(weightedToken.weight())
-                    );
-                }
-                boolQuery.minimumShouldMatch(1);
-
-                yield boolQuery;
+                yield new SparseVectorQueryBuilder(name(), textExpansionResults.getWeightedTokens(), null, null, null, null);
             }
             case TEXT_EMBEDDING -> {
                 if (inferenceResults instanceof MlTextEmbeddingResults == false) {
