Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting sparse semantic retrieval in neural search #333

Merged
Merged
Show file tree
Hide file tree
Changes from 69 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
5971c93
sparse mapper field and query builder
zhichao-aws Sep 5, 2023
0386a95
fix typo
zhichao-aws Sep 6, 2023
489b6e8
Add map result support in neural search for non text embedding models
zane-neo Aug 22, 2023
18022b8
Fix compilation failure issue
zane-neo Sep 1, 2023
886cdeb
Add more UTs
zane-neo Sep 1, 2023
d336f88
add sparse encoding processor
xinyual Sep 6, 2023
bb14947
add sparse encoding processor
xinyual Sep 7, 2023
5bb409b
remove guava in gradle
xinyual Sep 7, 2023
8e8df7e
modify access control
xinyual Sep 7, 2023
5e81ee5
Add map result support in neural search for non text embedding models
zane-neo Aug 22, 2023
20dd78e
Fix compilation failure issue
zane-neo Sep 1, 2023
05c3be8
change output logic
xinyual Sep 8, 2023
734fd50
create abstract
xinyual Sep 11, 2023
c00a4cf
create abstract proccesor
xinyual Sep 11, 2023
a973b42
add abstract class
xinyual Sep 11, 2023
30fc444
remove duplicate code
xinyual Sep 11, 2023
e2a30de
remove duplicate code
xinyual Sep 11, 2023
6b94a17
remove dl process
xinyual Sep 11, 2023
a3d09bd
move static to abstract class
xinyual Sep 11, 2023
f10c94d
update query rewrite logic
zhichao-aws Sep 8, 2023
589b1c0
modify header
zhichao-aws Sep 11, 2023
ec3f426
merge conflict
xinyual Sep 22, 2023
a8520d3
delete index mapper, change to rank_features
zhichao-aws Sep 13, 2023
b964d6c
remove unused import
zhichao-aws Sep 13, 2023
be45f86
list return result
zhichao-aws Sep 13, 2023
dbe00fd
refactor type and listTypeNestedMapKey, tidy
zhichao-aws Sep 14, 2023
c109666
forbid nested input. tidy.
zhichao-aws Sep 14, 2023
90516b2
tidy
zhichao-aws Sep 14, 2023
4d79cc4
enable nested
zhichao-aws Sep 15, 2023
79d861e
fix test
zhichao-aws Sep 15, 2023
9ab2e74
Add ut it to sparse encoding processor (#6)
xinyual Sep 18, 2023
84915f0
utils, tidy
zhichao-aws Sep 15, 2023
3bb95e3
rename to sparse_encoding query
zhichao-aws Sep 15, 2023
b4156f0
add validation and ut
zhichao-aws Sep 18, 2023
4771cd1
sparse encoding query builder ut
zhichao-aws Sep 18, 2023
5d12758
rename
zhichao-aws Sep 19, 2023
51a9ef3
UT for utils
zhichao-aws Sep 19, 2023
77eb300
enrich sparse encoding IT mappings
zhichao-aws Sep 19, 2023
854e9c4
add it
zhichao-aws Sep 19, 2023
03ff8b8
add it
zhichao-aws Sep 19, 2023
e791807
add integ test
zhichao-aws Sep 20, 2023
10e599a
rename resource file
zhichao-aws Sep 20, 2023
1e14a26
tidy
zhichao-aws Sep 20, 2023
473c68d
remove BoundedLinearQuery and TokenScoreUpperBound
zhichao-aws Sep 20, 2023
fa11056
tidy
zhichao-aws Sep 20, 2023
439d628
add delta to loose the equal
zhichao-aws Sep 20, 2023
99a739d
move SparseEncodingQueryBuilder to upper level path
zhichao-aws Sep 20, 2023
65b1e4f
tidy
zhichao-aws Sep 20, 2023
916b3cf
add it
zhichao-aws Sep 20, 2023
283a7a3
Update src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAc…
zhichao-aws Sep 25, 2023
c3d9fd3
Update src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil…
zhichao-aws Sep 25, 2023
11cf97d
restore gradle.propeties
zhichao-aws Sep 25, 2023
057f435
add release notes
zhichao-aws Sep 25, 2023
351bae9
change field modifier to private for NLPProcessor
zhichao-aws Sep 25, 2023
f58a073
add comments
zhichao-aws Sep 25, 2023
791c6ca
use StringUtils to check
zhichao-aws Sep 25, 2023
6878c01
null check
zhichao-aws Sep 25, 2023
c6c631e
modify changelog
zhichao-aws Sep 26, 2023
ec70c34
nit
zhichao-aws Sep 26, 2023
9223e31
nit
zhichao-aws Sep 26, 2023
ba10e27
remove query tokens from user interface
zhichao-aws Sep 26, 2023
9647ac9
fix test
zhichao-aws Sep 26, 2023
169934a
tidy
zhichao-aws Sep 26, 2023
a47c8b6
update function name
zhichao-aws Sep 26, 2023
b48091f
add javadoc
zhichao-aws Sep 26, 2023
cfc847d
remove debug log including inference result
zhichao-aws Sep 27, 2023
508b462
make query text and model id required
zhichao-aws Sep 27, 2023
aae62d4
minor changes based on comments
zhichao-aws Sep 27, 2023
2d51bb9
add locale to String.format
zhichao-aws Sep 27, 2023
9611411
update mock model url
zhichao-aws Sep 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x)
### Features
Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ dependencies {
runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12'
runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA'
runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}"
runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
runtimeOnly group: 'org.json', name: 'json', version: '20230227'
Comment on lines +154 to +155
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need both?
and what is the usage of these dependencies?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When calling inferenceSentencesWithMapResult using remote inference, it will use gson to load the response from remote endpoint. And it can throw JSONException during running

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't add these two dependencies, ClassNotFoundException will be throwed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confirming, I've seen same behavior while working on multimodal

}

// In order to add the jar to the classpath, we need to unzip the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand Down Expand Up @@ -100,10 +102,38 @@ public void inferenceSentences(
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
}

private void inferenceSentencesWithRetry(
public void inferenceSentencesWithMapResult(
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<Map<String, ?>>> listener
) {
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we want to do no retires?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are getting consistent with the existing inferenceSentences method. @zane-neo Could you please help answer this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 0 here doesn't mean no retry, it's only a initial value, this value will be increased till max retry time(3), this can be optimized to decrease in the future to make it more intuitive though.

}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<Map<String, ?>>> listener
) {
MLInput mlInput = createMLInput(null, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
listener.onResponse(result);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private void retryableInferenceSentencesWithVectorResult(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
Expand All @@ -113,12 +143,11 @@ private void inferenceSentencesWithRetry(
MLInput mlInput = createMLInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
Expand All @@ -144,4 +173,22 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return vector;
}

private List<Map<String, ?>> buildMapResultFromResponse(MLOutput mlOutput) {
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this cast safe? Or should we check?

Copy link
Member Author

@zhichao-aws zhichao-aws Sep 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the above one. Maybe we can use consistent code between the 2 use case. And If they need to be fixed, we can create another PR to fix them.

final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can fix the other one in the future, but I want to understand how we know for certain this wont cause class cast on invalid input.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zane-neo could you please help answer this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two inference method we created in neural-search(inferenceWithVectorResult & inferenceWithMapResult) are for text-embedding and sparse-embedding only, and they both return ModelTensorOutput. For later one in the future, it could be expanded to other use cases, e.g. remote inference, but in this case, it's still returning ModelTensorOutput. So it's safe to cast this output to ModelTensorOutput.

final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
throw new IllegalStateException(
"Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]"
);
}
List<Map<String, ?>> resultMaps = new ArrayList<>();
for (ModelTensors tensors : tensorOutputList) {
List<ModelTensor> tensorList = tensors.getMlModelTensors();
for (ModelTensor tensor : tensorList) {
resultMaps.add(tensor.getDataAsMap());
}
}
return resultMaps;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -31,15 +30,18 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
Expand All @@ -62,7 +64,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
private MLCommonsClientAccessor clientAccessor;
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();;
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();

@Override
public Collection<Object> createComponents(
Expand All @@ -79,6 +81,7 @@ public Collection<Object> createComponents(
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralQueryBuilder.initialize(clientAccessor);
SparseEncodingQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}
Expand All @@ -87,14 +90,20 @@ public Collection<Object> createComponents(
public List<QuerySpec<?>> getQueries() {
return Arrays.asList(
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent)
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent),
new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent)
);
}

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
);
}

@Override
Expand Down
Loading