Skip to content

Commit

Permalink
Changed approach to hardcoded fields for image and text
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Sep 29, 2023
1 parent c61db6a commit 8811b48
Show file tree
Hide file tree
Showing 19 changed files with 693 additions and 409 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ dependencies {
runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12'
runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA'
runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}"
runtimeOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
runtimeOnly group: 'org.json', name: 'json', version: '20230227'
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

package org.opensearch.neuralsearch.ml;

import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand All @@ -22,7 +24,6 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand Down Expand Up @@ -115,6 +116,14 @@ public void inferenceSentencesWithMapResult(
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
}

/**
 * Runs inference for a single multimodal input and passes the resulting embedding to the listener.
 * Delegates to {@code inferenceSentencesWithRetry} using {@code TARGET_RESPONSE_FILTERS} and a
 * retry counter that starts at zero.
 *
 * @param modelId id of the model the prediction is requested from
 * @param inputObjects input values; read under the {@code INPUT_TEXT} and {@code INPUT_IMAGE} keys
 * @param listener receives the embedding on success, or the exception once no further retry is attempted
 */
public void inferenceSentences(
    @NonNull final String modelId,
    @NonNull final Map<String, String> inputObjects,
    @NonNull final ActionListener<List<Float>> listener
) {
    // First attempt: no retries have happened yet.
    final int initialRetryCount = 0;
    inferenceSentencesWithRetry(TARGET_RESPONSE_FILTERS, modelId, inputObjects, initialRetryCount, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
Expand Down Expand Up @@ -198,4 +207,42 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
}
return resultMaps;
}

/**
 * Reduces the model output to a single embedding: the first vector returned by
 * {@code buildVectorFromResponse}, or a new empty list when that call yields no vectors.
 *
 * @param mlOutput prediction output to extract the embedding from
 * @return the first vector, or an empty list if there is none
 */
private List<Float> buildSingleVectorFromResponse(MLOutput mlOutput) {
    final List<List<Float>> allVectors = buildVectorFromResponse(mlOutput);
    if (allVectors.isEmpty()) {
        return new ArrayList<>();
    }
    return allVectors.get(0);
}

/**
 * Sends one prediction request built from {@code inputObjects} and reports the outcome to the listener.
 * On failure the request is re-issued with an incremented retry counter for as long as
 * {@code RetryUtil.shouldRetry} allows it; otherwise the exception is forwarded to the listener.
 *
 * @param targetResponseFilters filters passed into the model result filter of the request
 * @param modelId id of the model the prediction is requested from
 * @param inputObjects input values used to build the request
 * @param retryTime number of retries already performed for this request
 * @param listener receives the single embedding vector or the final failure
 */
private void inferenceSentencesWithRetry(
    @NonNull final List<String> targetResponseFilters,
    final String modelId,
    final Map<String, String> inputObjects,
    final int retryTime,
    final ActionListener<List<Float>> listener
) {
    final MLInput multimodalInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
    mlClient.predict(modelId, multimodalInput, ActionListener.wrap(prediction -> {
        final List<Float> embedding = buildSingleVectorFromResponse(prediction);
        log.debug("Inference Response for input sentence {} is : {} ", inputObjects, embedding);
        listener.onResponse(embedding);
    }, failure -> {
        // Give up and surface the error as soon as the retry policy says no.
        if (!RetryUtil.shouldRetry(failure, retryTime)) {
            listener.onFailure(failure);
            return;
        }
        inferenceSentencesWithRetry(targetResponseFilters, modelId, inputObjects, retryTime + 1, listener);
    }));
}

/**
 * Builds the {@code TEXT_EMBEDDING} request for a multimodal input.
 * The value stored under {@code INPUT_TEXT} is always the first document (it is {@code null} when
 * the key is absent); the value under {@code INPUT_IMAGE} is appended only when that key is present.
 *
 * @param targetResponseFilters filters passed to the model result filter
 * @param input input values keyed by {@code INPUT_TEXT} / {@code INPUT_IMAGE}
 * @return the ML input wrapping a text-docs dataset with the collected documents
 */
private MLInput createMLMultimodalInput(final List<String> targetResponseFilters, Map<String, String> input) {
    final List<String> docs = new ArrayList<>();
    // Text goes first so that the image, when supplied, is always the second document.
    docs.add(input.get(INPUT_TEXT));
    if (input.containsKey(INPUT_IMAGE)) {
        docs.add(input.get(INPUT_IMAGE));
    }
    final MLInputDataset dataset = new TextDocsInputDataSet(
        docs,
        new ModelResultFilter(false, true, targetResponseFilters, null)
    );
    return new MLInput(FunctionName.TEXT_EMBEDDING, null, dataset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.InferenceProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
Expand Down Expand Up @@ -105,8 +105,8 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
InferenceProcessor.TYPE,
new InferenceProcessorFactory(clientAccessor, parameters.env)
TextImageEmbeddingProcessor.TYPE,
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* and field_map can be used to indicate which fields need text embedding and the corresponding keys for the text embedding results.
*/
@Log4j2
public final class TextEmbeddingProcessor extends NLPProcessor {
public class TextEmbeddingProcessor extends NLPProcessor {

public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
Expand Down
Loading

0 comments on commit 8811b48

Please sign in to comment.