From 786b150e6ac440e850de624a503fa1983da9f031 Mon Sep 17 00:00:00 2001
From: Varun Jain
Date: Thu, 25 Jan 2024 15:49:29 -0800
Subject: [PATCH] Addressing heemin's comment

Signed-off-by: Varun Jain
---
 .../neuralsearch/bwc/HybridSearchIT.java      |  45 ++----
 .../neuralsearch/bwc/HybridSearchIT.java      |   2 +-
 .../neuralsearch/BaseNeuralSearchIT.java      | 146 +++++++++++-------
 3 files changed, 102 insertions(+), 91 deletions(-)

diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
index c23913259..48735182a 100644
--- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
+++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
@@ -38,56 +38,39 @@ public class HybridSearchIT extends AbstractRestartUpgradeRestTestCase {
     // Create Text Embedding Processor, Ingestion Pipeline, add document and search pipeline with normalization processor
     // Validate processor, pipeline and document count in restart-upgrade scenario
     public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() throws Exception {
-        waitForClusterHealthGreen(NODES_BWC_CLUSTER);
-        if (isRunningAgainstOldCluster()) {
-            String modelId = uploadTextEmbeddingModel();
-            loadModel(modelId);
-            createPipelineProcessor(modelId, PIPELINE_NAME);
-            createIndexWithConfiguration(
-                getIndexNameForTest(),
-                Files.readString(Path.of(classLoader.getResource("processor/IndexMappingMultipleShard.json").toURI())),
-                PIPELINE_NAME
-            );
-            addDocuments(getIndexNameForTest(), true);
-            createSearchPipeline(SEARCH_PIPELINE_NAME);
-        } else {
-            String modelId = null;
-            try {
-                modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR);
-                loadModel(modelId);
-                addDocuments(getIndexNameForTest(), false);
-                validateTestIndex(modelId, getIndexNameForTest(), SEARCH_PIPELINE_NAME);
-            } finally {
-                wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
-            }
-        }
+        validateNormalizationProcessor("processor/IndexMappingMultipleShard.json", PIPELINE_NAME, SEARCH_PIPELINE_NAME);
     }
 
     // Test restart-upgrade normalization processor when index with single shard
     // Create Text Embedding Processor, Ingestion Pipeline, add document and search pipeline with normalization processor
     // Validate processor, pipeline and document count in restart-upgrade scenario
     public void testNormalizationProcessor_whenIndexWithSingleShard_E2EFlow() throws Exception {
+        validateNormalizationProcessor("processor/IndexMappingSingleShard.json", PIPELINE1_NAME, SEARCH_PIPELINE1_NAME);
+    }
+
+    private void validateNormalizationProcessor(final String fileName, final String pipelineName, final String searchPipelineName)
+        throws Exception {
         waitForClusterHealthGreen(NODES_BWC_CLUSTER);
         if (isRunningAgainstOldCluster()) {
             String modelId = uploadTextEmbeddingModel();
             loadModel(modelId);
-            createPipelineProcessor(modelId, PIPELINE1_NAME);
+            createPipelineProcessor(modelId, pipelineName);
             createIndexWithConfiguration(
                 getIndexNameForTest(),
-                Files.readString(Path.of(classLoader.getResource("processor/IndexMappingSingleShard.json").toURI())),
-                PIPELINE1_NAME
+                Files.readString(Path.of(classLoader.getResource(fileName).toURI())),
+                pipelineName
             );
             addDocuments(getIndexNameForTest(), true);
-            createSearchPipeline(SEARCH_PIPELINE1_NAME);
+            createSearchPipeline(searchPipelineName);
         } else {
             String modelId = null;
             try {
-                modelId = getModelId(getIngestionPipeline(PIPELINE1_NAME), TEXT_EMBEDDING_PROCESSOR);
+                modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
                 loadModel(modelId);
                 addDocuments(getIndexNameForTest(), false);
-                validateTestIndex(modelId, getIndexNameForTest(), SEARCH_PIPELINE1_NAME);
+                validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName);
             } finally {
-                wipeOfTestResources(getIndexNameForTest(), PIPELINE1_NAME, modelId, SEARCH_PIPELINE1_NAME);
+                wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
             }
         }
     }
@@ -127,7 +110,7 @@ private void validateTestIndex(final String modelId, final String index, final S
         }
     }
 
-    public HybridQueryBuilder getQueryBuilder(final String modelId) {
+    private HybridQueryBuilder getQueryBuilder(final String modelId) {
         NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
         neuralQueryBuilder.fieldName("passage_embedding");
         neuralQueryBuilder.modelId(modelId);
diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
index 1e1669c9b..292540820 100644
--- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
+++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java
@@ -103,7 +103,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
         }
     }
 
-    public HybridQueryBuilder getQueryBuilder(final String modelId) {
+    private HybridQueryBuilder getQueryBuilder(final String modelId) {
         NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
         neuralQueryBuilder.fieldName("passage_embedding");
         neuralQueryBuilder.modelId(modelId);
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
index e062e439e..786d96acf 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
@@ -119,7 +119,7 @@ protected void updateClusterSettings() {
     }
 
     @SneakyThrows
-    protected void updateClusterSettings(String settingKey, Object value) {
+    protected void updateClusterSettings(final String settingKey, final Object value) {
         XContentBuilder builder = XContentFactory.jsonBuilder()
             .startObject()
             .startObject("persistent")
@@ -138,13 +138,13 @@ protected void updateClusterSettings(String settingKey, Object value) {
         assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
     }
 
-    protected String registerModelGroupAndUploadModel(String requestBody) throws Exception {
+    protected String registerModelGroupAndUploadModel(final String requestBody) throws Exception {
         String modelGroupId = getModelGroupId();
         // model group id is dynamically generated, we need to update the model upload request body after the group is registered
         return uploadModel(String.format(LOCALE, requestBody, modelGroupId));
     }
 
-    protected String uploadModel(String requestBody) throws Exception {
+    protected String uploadModel(final String requestBody) throws Exception {
         Response uploadResponse = makeRequest(
             client(),
             "POST",
@@ -173,7 +173,7 @@ protected String uploadModel(String requestBody) throws Exception {
         return modelId;
     }
 
-    protected void loadModel(String modelId) throws Exception {
+    protected void loadModel(final String modelId) throws Exception {
         Response uploadResponse = makeRequest(
             client(),
"POST", @@ -236,7 +236,7 @@ protected String prepareSparseEncodingModel() { */ @SuppressWarnings("unchecked") @SneakyThrows - protected float[] runInference(String modelId, String queryText) { + protected float[] runInference(final String modelId, final String queryText) { Response inferenceResponse = makeRequest( client(), "POST", @@ -264,7 +264,8 @@ protected float[] runInference(String modelId, String queryText) { return vectorAsListToArray(data); } - protected void createIndexWithConfiguration(String indexName, String indexConfiguration, String pipelineName) throws Exception { + protected void createIndexWithConfiguration(final String indexName, String indexConfiguration, final String pipelineName) + throws Exception { if (StringUtils.isNotBlank(pipelineName)) { indexConfiguration = String.format(LOCALE, indexConfiguration, pipelineName); } @@ -285,12 +286,13 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig assertEquals(indexName, node.get("index").toString()); } - protected void createPipelineProcessor(String modelId, String pipelineName, ProcessorType processorType) throws Exception { + protected void createPipelineProcessor(final String modelId, final String pipelineName, final ProcessorType processorType) + throws Exception { String requestBody = Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGS_BY_TYPE.get(processorType)).toURI())); createPipelineProcessor(requestBody, pipelineName, modelId); } - protected void createPipelineProcessor(String requestBody, String pipelineName, String modelId) throws Exception { + protected void createPipelineProcessor(final String requestBody, final String pipelineName, final String modelId) throws Exception { Response pipelineCreateResponse = makeRequest( client(), "PUT", @@ -307,7 +309,7 @@ protected void createPipelineProcessor(String requestBody, String pipelineName, assertEquals("true", node.get("acknowledged").toString()); } - protected void createSearchRequestProcessor(String modelId, String pipelineName) throws Exception { + protected void createSearchRequestProcessor(final String modelId, final String pipelineName) throws Exception { Response pipelineCreateResponse = makeRequest( client(), "PUT", @@ -337,7 +339,7 @@ protected void createSearchRequestProcessor(String modelId, String pipelineName) * @return number of documents indexed to that index */ @SneakyThrows - protected int getDocCount(String indexName) { + protected int getDocCount(final String indexName) { Request request = new Request("GET", "/" + indexName + "/_count"); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -354,7 +356,7 @@ protected int getDocCount(String indexName) { * @param resultSize number of results to return in the search * @return Search results represented as a map */ - protected Map search(String index, QueryBuilder queryBuilder, int resultSize) { + protected Map search(final String index, final QueryBuilder queryBuilder, final int resultSize) { return search(index, queryBuilder, null, resultSize); } @@ -368,7 +370,12 @@ protected Map search(String index, QueryBuilder queryBuilder, in * @return Search results represented as a map */ @SneakyThrows - protected Map search(String index, QueryBuilder queryBuilder, QueryBuilder rescorer, int resultSize) { + protected Map search( + final String index, + final QueryBuilder queryBuilder, + final QueryBuilder rescorer, + final int 
+    ) {
         return search(index, queryBuilder, rescorer, resultSize, Map.of());
     }
 
@@ -384,11 +391,11 @@ protected Map<String, Object> search(
      * @return Search results represented as a map
      */
     @SneakyThrows
     protected Map<String, Object> search(
-        String index,
-        QueryBuilder queryBuilder,
-        QueryBuilder rescorer,
-        int resultSize,
-        Map<String, String> requestParams
+        final String index,
+        final QueryBuilder queryBuilder,
+        final QueryBuilder rescorer,
+        final int resultSize,
+        final Map<String, String> requestParams
     ) {
         XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query");
         queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
@@ -424,18 +431,18 @@ protected Map<String, Object> search(
      * @param vectorFieldNames List of vector fields to be added
      * @param vectors List of vectors corresponding to those fields
      */
-    protected void addKnnDoc(String index, String docId, List<String> vectorFieldNames, List<Object[]> vectors) {
+    protected void addKnnDoc(final String index, final String docId, final List<String> vectorFieldNames, final List<Object[]> vectors) {
         addKnnDoc(index, docId, vectorFieldNames, vectors, Collections.emptyList(), Collections.emptyList());
     }
 
     @SneakyThrows
     protected void addKnnDoc(
-        String index,
-        String docId,
-        List<String> vectorFieldNames,
-        List<Object[]> vectors,
-        List<String> textFieldNames,
-        List<String> texts
+        final String index,
+        final String docId,
+        final List<String> vectorFieldNames,
+        final List<Object[]> vectors,
+        final List<String> textFieldNames,
+        final List<String> texts
     ) {
         addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList());
     }
@@ -454,14 +461,14 @@ protected void addKnnDoc(
      */
     @SneakyThrows
     protected void addKnnDoc(
-        String index,
-        String docId,
-        List<String> vectorFieldNames,
-        List<Object[]> vectors,
-        List<String> textFieldNames,
-        List<String> texts,
-        List<String> nestedFieldNames,
-        List<Map<String, String>> nestedFields
+        final String index,
+        final String docId,
+        final List<String> vectorFieldNames,
+        final List<Object[]> vectors,
+        final List<String> textFieldNames,
+        final List<String> texts,
+        final List<String> nestedFieldNames,
+        final List<Map<String, String>> nestedFields
     ) {
         Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
         XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
@@ -490,18 +497,23 @@ protected void addKnnDoc(
     }
 
     @SneakyThrows
-    protected void addSparseEncodingDoc(String index, String docId, List<String> fieldNames, List<Map<String, Float>> docs) {
+    protected void addSparseEncodingDoc(
+        final String index,
+        final String docId,
+        final List<String> fieldNames,
+        final List<Map<String, Float>> docs
+    ) {
         addSparseEncodingDoc(index, docId, fieldNames, docs, Collections.emptyList(), Collections.emptyList());
     }
 
     @SneakyThrows
     protected void addSparseEncodingDoc(
-        String index,
-        String docId,
-        List<String> fieldNames,
-        List<Map<String, Float>> docs,
-        List<String> textFieldNames,
-        List<String> texts
+        final String index,
+        final String docId,
+        final List<String> fieldNames,
+        final List<Map<String, Float>> docs,
+        final List<String> textFieldNames,
+        final List<String> texts
    ) {
         Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
         XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
@@ -526,7 +538,7 @@ protected void addSparseEncodingDoc(
      * @return Map of first internal hit from the search
      */
     @SuppressWarnings("unchecked")
-    protected Map<String, Object> getFirstInnerHit(Map<String, Object> searchResponseAsMap) {
+    protected Map<String, Object> getFirstInnerHit(final Map<String, Object> searchResponseAsMap) {
         Map<String, Object> hits1map = (Map<String, Object>) searchResponseAsMap.get("hits");
         List<Object> hits2List = (List<Object>) hits1map.get("hits");
         assertTrue(hits2List.size() > 0);
@@ -540,7 +552,7 @@ protected Map<String, Object> getFirstInnerHit(Map<String, Object> searchRespons
      * @return number of hits from the search
      */
     @SuppressWarnings("unchecked")
-    protected int getHitCount(Map<String, Object> searchResponseAsMap) {
+    protected int getHitCount(final Map<String, Object> searchResponseAsMap) {
         Map<String, Object> hits1map = (Map<String, Object>) searchResponseAsMap.get("hits");
         List<Object> hits1List = (List<Object>) hits1map.get("hits");
         return hits1List.size();
@@ -553,7 +565,7 @@ protected int getHitCount(Map<String, Object> searchResponseAsMap) {
      * @return number of scores list from the search
      */
     @SuppressWarnings("unchecked")
-    protected List<Double> getNormalizationScoreList(Map<String, Object> searchResponseAsMap) {
+    protected List<Double> getNormalizationScoreList(final Map<String, Object> searchResponseAsMap) {
         Map<String, Object> hits1map = (Map<String, Object>) searchResponseAsMap.get("hits");
         List<Object> hitsList = (List<Object>) hits1map.get("hits");
         List<Double> scores = new ArrayList<>();
@@ -571,17 +583,17 @@ protected List<Double> getNormalizationScoreList(Map<String, Object> searchRespo
      * @param knnFieldConfigs list of configs specifying field
      */
     @SneakyThrows
-    protected void prepareKnnIndex(String indexName, List<KNNFieldConfig> knnFieldConfigs) {
+    protected void prepareKnnIndex(final String indexName, final List<KNNFieldConfig> knnFieldConfigs) {
         prepareKnnIndex(indexName, knnFieldConfigs, 3);
     }
 
     @SneakyThrows
-    protected void prepareKnnIndex(String indexName, List<KNNFieldConfig> knnFieldConfigs, int numOfShards) {
+    protected void prepareKnnIndex(final String indexName, final List<KNNFieldConfig> knnFieldConfigs, final int numOfShards) {
         createIndexWithConfiguration(indexName, buildIndexConfiguration(knnFieldConfigs, numOfShards), "");
     }
 
     @SneakyThrows
-    protected void prepareSparseEncodingIndex(String indexName, List<String> sparseEncodingFieldNames) {
+    protected void prepareSparseEncodingIndex(final String indexName, final List<String> sparseEncodingFieldNames) {
         XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("mappings").startObject("properties");
 
         for (String fieldName : sparseEncodingFieldNames) {
@@ -602,12 +614,17 @@ protected void prepareSparseEncodingIndex(String indexName, List<String> sparseE
      * @param queryText Text to produce query vector from
      * @return Expected OpenSearch score for this indexVector
      */
-    protected float computeExpectedScore(String modelId, float[] indexVector, SpaceType spaceType, String queryText) {
+    protected float computeExpectedScore(
+        final String modelId,
+        final float[] indexVector,
+        final SpaceType spaceType,
+        final String queryText
+    ) {
         float[] queryVector = runInference(modelId, queryText);
         return spaceType.getVectorSimilarityFunction().compare(queryVector, indexVector);
     }
 
-    protected Map<String, Object> getTaskQueryResponse(String taskId) throws Exception {
+    protected Map<String, Object> getTaskQueryResponse(final String taskId) throws Exception {
         Response taskQueryResponse = makeRequest(
             client(),
             "GET",
@@ -619,7 +636,7 @@ protected Map<String, Object> getTaskQueryResponse(String taskId) throws Excepti
         return XContentHelper.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(taskQueryResponse.getEntity()), false);
     }
 
-    protected boolean checkComplete(Map<String, Object> node) {
+    protected boolean checkComplete(final Map<String, Object> node) {
         Predicate<Map<String, Object>> predicate = x -> node.get("error") != null || "COMPLETED".equals(String.valueOf(node.get("state")));
         return predicate.test(node);
     }
@@ -917,7 +934,7 @@ private String getModelGroupId() {
         );
     }
 
-    protected String registerModelGroup(String modelGroupRegisterRequestBody) throws IOException, ParseException {
+    protected String registerModelGroup(final String modelGroupRegisterRequestBody) throws IOException, ParseException {
         Response modelGroupResponse = makeRequest(
             client(),
             "POST",
@@ -937,7 +954,7 @@ protected String registerModelGroup(String modelGroupRegisterRequestBody) throws
     }
 
     // Method that waits till the health of nodes in the cluster goes green
-    protected void waitForClusterHealthGreen(String numOfNodes) throws IOException {
+    protected void waitForClusterHealthGreen(final String numOfNodes) throws IOException {
         Request waitForGreen = new Request("GET", "/_cluster/health");
         waitForGreen.addParameter("wait_for_nodes", numOfNodes);
         waitForGreen.addParameter("wait_for_status", "green");
@@ -955,8 +972,14 @@ protected void waitForClusterHealthGreen(String numOfNodes) throws IOException {
      * @param imageText name of the image text
      *
      */
-    protected void addDocument(String index, String docId, String fieldName, String text, String imagefieldName, String imageText)
-        throws IOException {
+    protected void addDocument(
+        final String index,
+        final String docId,
+        final String fieldName,
+        final String text,
+        final String imagefieldName,
+        final String imageText
+    ) throws IOException {
         Request request = new Request("PUT", "/" + index + "/_doc/" + docId + "?refresh=true");
         XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
@@ -976,7 +999,7 @@ protected void addDocument(String index, String docId, String fieldName, String
      * @return get pipeline response as a map object
      */
     @SneakyThrows
-    protected Map<String, Object> getIngestionPipeline(String pipelineName) {
+    protected Map<String, Object> getIngestionPipeline(final String pipelineName) {
         Request request = new Request("GET", "/_ingest/pipeline/" + pipelineName);
         Response response = client().performRequest(request);
         assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
@@ -993,7 +1016,7 @@ protected Map<String, Object> getIngestionPipeline(String pipelineName) {
      * @return delete pipeline response as a map object
      */
     @SneakyThrows
-    protected Map<String, Object> deletePipeline(String pipelineName) {
+    protected Map<String, Object> deletePipeline(final String pipelineName) {
         Request request = new Request("DELETE", "/_ingest/pipeline/" + pipelineName);
         Response response = client().performRequest(request);
         assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
@@ -1002,12 +1025,12 @@ protected Map<String, Object> deletePipeline(String pipelineName) {
         return responseMap;
     }
 
-    protected float computeExpectedScore(String modelId, Map<String, Float> tokenWeightMap, String queryText) {
+    protected float computeExpectedScore(final String modelId, final Map<String, Float> tokenWeightMap, final String queryText) {
         Map<String, Float> queryTokens = runSparseModelInference(modelId, queryText);
         return computeExpectedScore(tokenWeightMap, queryTokens);
     }
 
-    protected float computeExpectedScore(Map<String, Float> tokenWeightMap, Map<String, Float> queryTokens) {
+    protected float computeExpectedScore(final Map<String, Float> tokenWeightMap, final Map<String, Float> queryTokens) {
         Float score = 0f;
         for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
             if (tokenWeightMap.containsKey(entry.getKey())) {
@@ -1018,7 +1041,7 @@ protected float computeExpectedScore(Map<String, Float> tokenWeightMap, Map<Stri
     }
 
     @SneakyThrows
-    protected Map<String, Float> runSparseModelInference(String modelId, String queryText) {
+    protected Map<String, Float> runSparseModelInference(final String modelId, final String queryText) {
         Response inferenceResponse = makeRequest(
             client(),
             "POST",
@@ -1049,7 +1072,7 @@ protected Map<String, Float> runSparseModelInference(String modelId, String quer
 
     // rank_features use lucene FeatureField, which will compress the Float number to 16 bit
     // this function simulates the encoding and decoding process in lucene FeatureField
-    protected Float getFeatureFieldCompressedNumber(Float originNumber) {
+    protected Float getFeatureFieldCompressedNumber(final Float originNumber) {
         int freqBits = Float.floatToIntBits(originNumber);
         freqBits = freqBits >> 15;
         freqBits = ((int) ((float) freqBits)) << 15;
@@ -1057,7 +1080,12 @@ protected Float getFeatureFieldCompressedNumber(Float originNumber) {
     }
 
     // Wipe off all the resources after execution of the tests.
-    protected void wipeOfTestResources(String indexName, String ingestPipeline, String modelId, String searchPipeline) throws IOException {
+    protected void wipeOfTestResources(
+        final String indexName,
+        final String ingestPipeline,
+        final String modelId,
+        final String searchPipeline
+    ) throws IOException {
         if (ingestPipeline != null) {
             deletePipeline(ingestPipeline);
         }
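
Note for reviewers, outside the patch itself: the 16-bit truncation inside getFeatureFieldCompressedNumber is the subtlest piece of the fixture above, since expected sparse scores must be computed from the truncated token weight rather than the raw float. Below is a minimal, self-contained sketch of that round-trip; the class name and sample weight are illustrative only, but the bit manipulation is the same as in the helper.

    // Sketch: Lucene's FeatureField stores feature weights with reduced
    // precision, so a test's expected score has to use the truncated value.
    // The helper keeps the high bits of the float's IEEE-754 bit pattern
    // and zeroes the low 15 bits, which rounds a positive weight toward zero.
    public final class FeatureFieldCompressionSketch {

        static float compress(final float originNumber) {
            int freqBits = Float.floatToIntBits(originNumber); // raw IEEE-754 bits
            freqBits = freqBits >> 15;                         // drop the low 15 bits
            freqBits = ((int) ((float) freqBits)) << 15;       // shift back, low bits now zero
            return Float.intBitsToFloat(freqBits);             // reinterpret as float
        }

        public static void main(final String[] args) {
            final float raw = 0.123456789f;
            // Prints the raw weight and the value the index effectively scores with.
            System.out.println(raw + " -> " + compress(raw));
        }
    }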