From aae62d48d84b23ab5033265a66d491da4ea0ae8b Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 27 Sep 2023 09:07:23 +0800 Subject: [PATCH] minor changes based on comments Signed-off-by: zhichao-aws --- CHANGELOG.md | 1 - .../processor/SparseEncodingProcessor.java | 2 +- .../processor/TextEmbeddingProcessor.java | 2 +- .../SparseEncodingProcessorFactory.java | 3 ++ .../TextEmbeddingProcessorFactory.java | 3 ++ .../query/SparseEncodingQueryBuilder.java | 28 ++++++++++--------- .../neuralsearch/util/TokenWeightUtil.java | 5 +++- 7 files changed, 27 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 038fe41e5..da2ae9ec9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features -Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 62857541e..275117809 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -22,7 +22,7 @@ * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the sparse encoding results. */ @Log4j2 -public class SparseEncodingProcessor extends NLPProcessor { +public final class SparseEncodingProcessor extends NLPProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 354b53945..1df60baea 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -21,7 +21,7 @@ * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results. */ @Log4j2 -public class TextEmbeddingProcessor extends NLPProcessor { +public final class TextEmbeddingProcessor extends NLPProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index dff56e9c8..104418ec5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -18,6 +18,9 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; +/** + * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. + */ @Log4j2 public class SparseEncodingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index f805b29e1..0c9a6fa2c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -16,6 +16,9 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +/** + * Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. + */ public class TextEmbeddingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index 07f581a9d..a8c2baaf7 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -126,22 +126,24 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr if (parser.nextToken() != XContentParser.Token.END_OBJECT) { throw new ParsingException( parser.getTokenLocation(), - "[" - + NAME - + "] query doesn't support multiple fields, found [" - + sparseEncodingQueryBuilder.fieldName() - + "] and [" - + parser.currentName() - + "]" + String.format( + "[%s] query doesn't support multiple fields, found [%s] and [%s]", + NAME, + sparseEncodingQueryBuilder.fieldName(), + parser.currentName() + ) ); } requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query"); requireValue( sparseEncodingQueryBuilder.queryText(), - QUERY_TEXT_FIELD.getPreferredName() + " must be provided for " + NAME + " query" + String.format("%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME) + ); + requireValue( + sparseEncodingQueryBuilder.modelId(), + String.format("%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) ); - requireValue(sparseEncodingQueryBuilder.modelId(), MODEL_ID_FIELD.getPreferredName() + " must be provided for " + NAME + " query"); return sparseEncodingQueryBuilder; } @@ -164,13 +166,13 @@ private static void parseQueryParams(XContentParser parser, SparseEncodingQueryB } else { throw new ParsingException( parser.getTokenLocation(), - "[" + NAME + "] query does not support [" + currentFieldName + "]" + String.format("[%s] query does not support [%s] field", NAME, currentFieldName) ); } } else { throw new ParsingException( parser.getTokenLocation(), - "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + String.format("[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName) ); } } @@ -220,7 +222,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { private static void validateForRewrite(String queryText, String modelId) { if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) { throw new IllegalArgumentException( - QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() + " cannot be null." + String.format("%s and %s cannot be null", QUERY_TEXT_FIELD.getPreferredName(), MODEL_ID_FIELD.getPreferredName()) ); } } @@ -238,7 +240,7 @@ private static void validateQueryTokens(Map queryTokens) { for (Map.Entry entry : queryTokens.entrySet()) { if (entry.getValue() <= 0) { throw new IllegalArgumentException( - "Feature weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey() + "Feature weight must be larger than 0, feature [" + entry.getValue() + "] has negative weight." ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index db249de0f..76ce0fa16 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -46,6 +46,9 @@ public class TokenWeightUtil { * @param mapResultList {@link Map} which is the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} */ public static List> fetchListOfTokenWeightMap(List> mapResultList) { + if (null == mapResultList || mapResultList.isEmpty()) { + throw new IllegalArgumentException("The inference result can not be null or empty."); + } List results = new ArrayList<>(); for (Map map : mapResultList) { if (!map.containsKey(RESPONSE_KEY)) { @@ -66,7 +69,7 @@ private static Map buildTokenWeightMap(Object uncastedMap) { Map result = new HashMap<>(); for (Map.Entry entry : ((Map) uncastedMap).entrySet()) { if (!String.class.isAssignableFrom(entry.getKey().getClass()) || !Number.class.isAssignableFrom(entry.getValue().getClass())) { - throw new IllegalArgumentException("The expected inference result is a Map with String keys and " + " Float values."); + throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); }