Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change query clause name to neural_sparse #416

Merged
merged 4 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
Expand Down Expand Up @@ -87,7 +87,7 @@
) {
NeuralSearchClusterUtil.instance().initialize(clusterService);
NeuralQueryBuilder.initialize(clientAccessor);
SparseEncodingQueryBuilder.initialize(clientAccessor);
NeuralSparseQueryBuilder.initialize(clientAccessor);

Check warning on line 90 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L90

Added line #L90 was not covered by tests
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}
Expand All @@ -97,7 +97,7 @@
return Arrays.asList(
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent),
new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent)
new QuerySpec<>(NeuralSparseQueryBuilder.NAME, NeuralSparseQueryBuilder::new, NeuralSparseQueryBuilder::fromXContent)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
* and set the target fields according to the field name map.
*/
@Log4j2
public abstract class NLPProcessor extends AbstractProcessor {
public abstract class InferenceProcessor extends AbstractProcessor {

public static final String MODEL_ID_FIELD = "model_id";
public static final String FIELD_MAP_FIELD = "field_map";
Expand All @@ -51,7 +51,7 @@ public abstract class NLPProcessor extends AbstractProcessor {

private final Environment environment;

public NLPProcessor(
public InferenceProcessor(
String tag,
String description,
String type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the sparse encoding results.
*/
@Log4j2
public final class SparseEncodingProcessor extends NLPProcessor {
public final class SparseEncodingProcessor extends InferenceProcessor {

public static final String TYPE = "sparse_encoding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results.
*/
@Log4j2
public final class TextEmbeddingProcessor extends NLPProcessor {
public final class TextEmbeddingProcessor extends InferenceProcessor {

public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import com.google.common.annotations.VisibleForTesting;

/**
* SparseEncodingQueryBuilder is responsible for handling "sparse_encoding" query types. It uses an ML SPARSE_ENCODING model
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML SPARSE_ENCODING model
* or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. Then it will be transformed
* to Lucene FeatureQuery wrapped by Lucene BooleanQuery.
*/
Expand All @@ -55,8 +55,8 @@
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
@AllArgsConstructor
public class SparseEncodingQueryBuilder extends AbstractQueryBuilder<SparseEncodingQueryBuilder> {
public static final String NAME = "sparse_encoding";
public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQueryBuilder> {
public static final String NAME = "neural_sparse";
@VisibleForTesting
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
@VisibleForTesting
Expand All @@ -65,7 +65,7 @@ public class SparseEncodingQueryBuilder extends AbstractQueryBuilder<SparseEncod
private static MLCommonsClientAccessor ML_CLIENT;

public static void initialize(MLCommonsClientAccessor mlClient) {
SparseEncodingQueryBuilder.ML_CLIENT = mlClient;
NeuralSparseQueryBuilder.ML_CLIENT = mlClient;
}

private String fieldName;
Expand All @@ -79,7 +79,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
* @param in StreamInput to initialize object from
* @throws IOException thrown if unable to read from input stream
*/
public SparseEncodingQueryBuilder(StreamInput in) throws IOException {
public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
Expand Down Expand Up @@ -115,8 +115,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* @return NeuralQueryBuilder
* @throws IOException can be thrown by parser
*/
public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) throws IOException {
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder();
public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throws IOException {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder();
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new ParsingException(parser.getTokenLocation(), "First token of " + NAME + "query must be START_OBJECT");
}
Expand Down Expand Up @@ -150,7 +150,7 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr
return sparseEncodingQueryBuilder;
}

private static void parseQueryParams(XContentParser parser, SparseEncodingQueryBuilder sparseEncodingQueryBuilder) throws IOException {
private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBuilder sparseEncodingQueryBuilder) throws IOException {
XContentParser.Token token;
String currentFieldName = "";
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -200,7 +200,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
}, actionListener::onFailure)
))
);
return new SparseEncodingQueryBuilder().fieldName(fieldName)
return new NeuralSparseQueryBuilder().fieldName(fieldName)
.queryText(queryText)
.modelId(modelId)
.queryTokensSupplier(queryTokensSetOnce::get);
Expand Down Expand Up @@ -254,7 +254,7 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
}

@Override
protected boolean doEquals(SparseEncodingQueryBuilder obj) {
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap;
import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.NAME;
import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -42,7 +42,7 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.test.OpenSearchTestCase;

public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase {
public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {

private static final String FIELD_NAME = "testField";
private static final String QUERY_TEXT = "Hello world!";
Expand Down Expand Up @@ -71,7 +71,7 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);

assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText());
Expand Down Expand Up @@ -102,7 +102,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);

assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText());
Expand Down Expand Up @@ -137,7 +137,7 @@ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(ParsingException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
expectThrows(ParsingException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
Expand All @@ -158,7 +158,7 @@ public void testFromXContent_whenBuildWithMissingQuery_thenFail() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
Expand All @@ -179,7 +179,7 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
Expand All @@ -206,13 +206,13 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IOException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
expectThrows(IOException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testToXContent() {
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.modelId(MODEL_ID)
.queryText(QUERY_TEXT);

Expand Down Expand Up @@ -243,7 +243,7 @@ public void testToXContent() {

@SneakyThrows
public void testStreams() {
SparseEncodingQueryBuilder original = new SparseEncodingQueryBuilder();
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
original.modelId(MODEL_ID);
Expand All @@ -260,7 +260,7 @@ public void testStreams() {
)
);

SparseEncodingQueryBuilder copy = new SparseEncodingQueryBuilder(filterStreamInput);
NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);
}

Expand All @@ -276,54 +276,54 @@ public void testHashAndEquals() {
String queryName1 = "query-1";
String queryName2 = "query-2";

SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except default boost and query name
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder().fieldName(
fieldName1
).queryText(queryText1).modelId(modelId1);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1);

// Identical to sparseEncodingQueryBuilder_baseline except diff field name
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder().fieldName(fieldName2)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder().fieldName(fieldName2)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff query text
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText2)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff model ID
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId2)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff boost
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost2)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff query name
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
Expand Down Expand Up @@ -356,7 +356,7 @@ public void testHashAndEquals() {

@SneakyThrows
public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() {
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID);
Map<String, Float> expectedMap = Map.of("1", 1f, "2", 2f);
Expand All @@ -366,7 +366,7 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
listener.onResponse(List.of(Map.of("response", List.of(expectedMap))));
return null;
}).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any());
SparseEncodingQueryBuilder.initialize(mlCommonsClientAccessor);
NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor);

final CountDownLatch inProgressLatch = new CountDownLatch(1);
QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
Expand All @@ -382,15 +382,15 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
return null;
}).when(queryRewriteContext).registerAsyncAction(any());

SparseEncodingQueryBuilder queryBuilder = (SparseEncodingQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext);
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext);
assertNotNull(queryBuilder.queryTokensSupplier());
assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS));
assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get());
}

@SneakyThrows
public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() {
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
Expand Down
Loading
Loading