-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Supporting sparse semantic retrieval in neural search #333
Changes from 69 commits
5971c93
0386a95
489b6e8
18022b8
886cdeb
d336f88
bb14947
5bb409b
8e8df7e
5e81ee5
20dd78e
05c3be8
734fd50
c00a4cf
a973b42
30fc444
e2a30de
6b94a17
a3d09bd
f10c94d
589b1c0
ec3f426
a8520d3
b964d6c
be45f86
dbe00fd
c109666
90516b2
4d79cc4
79d861e
9ab2e74
84915f0
3bb95e3
b4156f0
4771cd1
5d12758
51a9ef3
77eb300
854e9c4
03ff8b8
e791807
10e599a
1e14a26
473c68d
fa11056
439d628
99a739d
65b1e4f
916b3cf
283a7a3
c3d9fd3
11cf97d
057f435
351bae9
f58a073
791c6ca
6878c01
c6c631e
ec70c34
9223e31
ba10e27
9647ac9
169934a
a47c8b6
b48091f
cfc847d
508b462
aae62d4
2d51bb9
9611411
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -8,13 +8,15 @@ | |||
import java.util.ArrayList; | ||||
import java.util.Arrays; | ||||
import java.util.List; | ||||
import java.util.Map; | ||||
import java.util.stream.Collectors; | ||||
|
||||
import lombok.NonNull; | ||||
import lombok.RequiredArgsConstructor; | ||||
import lombok.extern.log4j.Log4j2; | ||||
|
||||
import org.opensearch.core.action.ActionListener; | ||||
import org.opensearch.core.common.util.CollectionUtils; | ||||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||||
import org.opensearch.ml.common.FunctionName; | ||||
import org.opensearch.ml.common.dataset.MLInputDataset; | ||||
|
@@ -100,10 +102,38 @@ public void inferenceSentences( | |||
@NonNull final List<String> inputText, | ||||
@NonNull final ActionListener<List<List<Float>>> listener | ||||
) { | ||||
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener); | ||||
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); | ||||
} | ||||
|
||||
private void inferenceSentencesWithRetry( | ||||
public void inferenceSentencesWithMapResult( | ||||
@NonNull final String modelId, | ||||
@NonNull final List<String> inputText, | ||||
@NonNull final ActionListener<List<Map<String, ?>>> listener | ||||
) { | ||||
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why we want to do no retires? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we are getting consistent with the existing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The 0 here doesn't mean no retry, it's only a initial value, this value will be increased till max retry time(3), this can be optimized to decrease in the future to make it more intuitive though. |
||||
} | ||||
|
||||
private void retryableInferenceSentencesWithMapResult( | ||||
final String modelId, | ||||
final List<String> inputText, | ||||
final int retryTime, | ||||
final ActionListener<List<Map<String, ?>>> listener | ||||
) { | ||||
MLInput mlInput = createMLInput(null, inputText); | ||||
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { | ||||
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput); | ||||
listener.onResponse(result); | ||||
}, e -> { | ||||
if (RetryUtil.shouldRetry(e, retryTime)) { | ||||
final int retryTimeAdd = retryTime + 1; | ||||
retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener); | ||||
} else { | ||||
listener.onFailure(e); | ||||
} | ||||
})); | ||||
} | ||||
|
||||
private void retryableInferenceSentencesWithVectorResult( | ||||
final List<String> targetResponseFilters, | ||||
final String modelId, | ||||
final List<String> inputText, | ||||
|
@@ -113,12 +143,11 @@ private void inferenceSentencesWithRetry( | |||
MLInput mlInput = createMLInput(targetResponseFilters, inputText); | ||||
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { | ||||
final List<List<Float>> vector = buildVectorFromResponse(mlOutput); | ||||
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector); | ||||
listener.onResponse(vector); | ||||
}, e -> { | ||||
if (RetryUtil.shouldRetry(e, retryTime)) { | ||||
final int retryTimeAdd = retryTime + 1; | ||||
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); | ||||
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); | ||||
} else { | ||||
listener.onFailure(e); | ||||
} | ||||
|
@@ -144,4 +173,22 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) { | |||
return vector; | ||||
} | ||||
|
||||
private List<Map<String, ?>> buildMapResultFromResponse(MLOutput mlOutput) { | ||||
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this cast safe? Or should we check? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as the above one. Maybe we can use consistent code between the 2 use case. And If they need to be fixed, we can create another PR to fix them. neural-search/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java Line 136 in 8484be9
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can fix the other one in the future, but I want to understand how we know for certain this wont cause class cast on invalid input. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @zane-neo could you please help answer this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The two inference method we created in neural-search(inferenceWithVectorResult & inferenceWithMapResult) are for text-embedding and sparse-embedding only, and they both return ModelTensorOutput. For later one in the future, it could be expanded to other use cases, e.g. remote inference, but in this case, it's still returning ModelTensorOutput. So it's safe to cast this output to ModelTensorOutput. |
||||
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs(); | ||||
if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { | ||||
throw new IllegalStateException( | ||||
"Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]" | ||||
); | ||||
} | ||||
List<Map<String, ?>> resultMaps = new ArrayList<>(); | ||||
for (ModelTensors tensors : tensorOutputList) { | ||||
List<ModelTensor> tensorList = tensors.getMlModelTensors(); | ||||
for (ModelTensor tensor : tensorList) { | ||||
resultMaps.add(tensor.getDataAsMap()); | ||||
} | ||||
} | ||||
return resultMaps; | ||||
} | ||||
|
||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why we need both?
and what is the usage of these dependencies?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When calling inferenceSentencesWithMapResult using remote inference, it will use gson to load the response from remote endpoint. And it can throw JSONException during running
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we don't add these two dependencies, ClassNotFoundException will be throwed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
confirming, I've seen same behavior while working on multimodal