From ee6e7eeea0d2010724b1b31321067214508816e7 Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Sat, 28 Sep 2024 14:49:43 -0700 Subject: [PATCH] Refactor and Update unit test to include field with no live docs Refactored if/else to reduce nesting. Added unit test when one of the field doesn't have live docs. Signed-off-by: Vijayan Balasubramanian --- CHANGELOG.md | 1 + .../NativeEngines990KnnVectorsWriter.java | 32 ++++++++-------- ...eEngines990KnnVectorsWriterFlushTests.java | 38 ++++++++++++------- 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 67879bae7..5615509de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,3 +31,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Maintenance * Remove benchmarks folder from k-NN repo [#2127](https://github.com/opensearch-project/k-NN/pull/2127) ### Refactoring +* Minor refactoring and refactored some unit test [#2167](https://github.com/opensearch-project/k-NN/pull/2167) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 23cd2a4de..2f22565c9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -84,24 +84,24 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { final FieldInfo fieldInfo = field.getFieldInfo(); final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); int totalLiveDocs = field.getVectors().size(); - if (totalLiveDocs > 0) { - final Supplier> knnVectorValuesSupplier = () -> getVectorValues( - vectorDataType, - field.getDocsWithField(), - field.getVectors() - ); - final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); - final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - - StopWatch stopWatch = new StopWatch().start(); - writer.flushIndex(knnVectorValues, totalLiveDocs); - long time_in_millis = stopWatch.stop().totalTime().millis(); - KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); - log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); - } else { + if (totalLiveDocs == 0) { log.debug("[Flush] No live docs for field {}", fieldInfo.getName()); + continue; } + final Supplier> knnVectorValuesSupplier = () -> getVectorValues( + vectorDataType, + field.getDocsWithField(), + field.getVectors() + ); + final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); + final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + + StopWatch stopWatch = new StopWatch().start(); + writer.flushIndex(knnVectorValues, totalLiveDocs); + long time_in_millis = stopWatch.stop().totalTime().millis(); + KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); + log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index dbb564908..9f74b2c10 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -32,8 +32,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; import java.util.stream.IntStream; import static com.carrotsearch.randomizedtesting.RandomizedTest.$; @@ -44,6 +47,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -86,6 +90,7 @@ public static Collection data() { "Multi Field", List.of( Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }), + Collections.emptyMap(), Map.of( 0, new float[] { 1, 2, 3, 4 }, @@ -105,18 +110,16 @@ public static Collection data() { @SneakyThrows public void testFlush() { // Given - List> expectedVectorValues = new ArrayList<>(); - IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final List> expectedVectorValues = vectorsPerField.stream().map(vectors -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - new ArrayList<>(vectorsPerField.get(i).values()) + new ArrayList<>(vectors.values()) ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( VectorDataType.FLOAT, randomVectorValues ); - expectedVectorValues.add(knnVectorValues); - - }); + return knnVectorValues; + }).collect(Collectors.toList()); try ( MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); @@ -172,15 +175,19 @@ public void testFlush() { IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + if (vectorsPerField.get(i).isEmpty()) { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } } catch (Exception e) { throw new RuntimeException(e); } }); - + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(expectedVectorValues.size()) + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) ); } } @@ -264,16 +271,21 @@ public void testFlush_WithQuantization() { IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { - verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } } catch (Exception e) { throw new RuntimeException(e); } }); - + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(expectedVectorValues.size() * 2) + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) ); } }