Skip to content

Commit

Permalink
tidy
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Sep 26, 2023
1 parent 9647ac9 commit 169934a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
throw new IllegalStateException("Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]");
throw new IllegalStateException(
"Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]"
);
}
List<Map<String, ?>> resultMaps = new ArrayList<>();
for (ModelTensors tensors : tensorOutputList) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -100,10 +99,10 @@ public List<QuerySpec<?>> getQueries() {
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,20 +164,20 @@ Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge
}

private void buildMapWithProcessorKeyAndOriginalValueForMapType(
String parentKey,
Object processorKey,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> treeRes
String parentKey,
Object processorKey,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> treeRes
) {
if (processorKey == null || sourceAndMetadataMap == null) return;
if (processorKey instanceof Map) {
Map<String, Object> next = new LinkedHashMap<>();
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
buildMapWithProcessorKeyAndOriginalValueForMapType(
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
(Map<String, Object>) sourceAndMetadataMap.get(parentKey),
next
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
(Map<String, Object>) sourceAndMetadataMap.get(parentKey),
next
);
}
treeRes.put(parentKey, next);
Expand Down Expand Up @@ -214,9 +214,9 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl
validateListTypeValue(sourceKey, sourceValue);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
Expand Down Expand Up @@ -125,18 +124,9 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr
requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query");
requireValue(
sparseEncodingQueryBuilder.queryText(),
QUERY_TEXT_FIELD.getPreferredName()
+ " must be provided for "
+ NAME
+ " query"
);
requireValue(
sparseEncodingQueryBuilder.modelId(),
MODEL_ID_FIELD.getPreferredName()
+ " must be provided for "
+ NAME
+ " query"
QUERY_TEXT_FIELD.getPreferredName() + " must be provided for " + NAME + " query"
);
requireValue(sparseEncodingQueryBuilder.modelId(), MODEL_ID_FIELD.getPreferredName() + " must be provided for " + NAME + " query");

return sparseEncodingQueryBuilder;
}
Expand Down Expand Up @@ -215,19 +205,14 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
private static void validateForRewrite(String queryText, String modelId) {
if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException(
QUERY_TEXT_FIELD.getPreferredName()
+ " and "
+ MODEL_ID_FIELD.getPreferredName()
+ " cannot be null."
QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() + " cannot be null."
);
}
}

private static void validateFieldType(MappedFieldType fieldType) {
if (null == fieldType || !fieldType.typeName().equals("rank_features")) {
throw new IllegalArgumentException(
"[" + NAME + "] query only works on [rank_features] fields"
);
throw new IllegalArgumentException("[" + NAME + "] query only works on [rank_features] fields");
}
}

Expand All @@ -237,7 +222,9 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
}
for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
if (entry.getValue() <= 0) {
throw new IllegalArgumentException("Feature weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey());
throw new IllegalArgumentException(
"Feature weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey()
);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
Mockito.verify(resultListener).onFailure(argumentCaptor.capture());
assertEquals(
"Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]",
"Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]",
argumentCaptor.getValue().getMessage()
);
Mockito.verifyNoMoreInteractions(resultListener);
Expand Down

0 comments on commit 169934a

Please sign in to comment.