diff --git a/CHANGELOG.md b/CHANGELOG.md index eab952cf4..02a86725f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x) ### Features - Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] +- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java b/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java index 7eca6287c..a17b0467b 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java @@ -29,6 +29,7 @@ public float compare(byte[] v1, byte[] v2) { @Override public VectorSimilarityFunction getVectorSimilarityFunction() { + // For binary vectors using Lucene engine we instead implement a custom BinaryVectorScorer throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space"); } }; diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 4827a4582..e97bd2dbf 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -30,7 +30,7 @@ /** * Enum contains data_type of vectors - * Lucene supports byte and float data type + * Lucene supports binary, byte and float data type * NMSLib supports only float data type * Faiss supports binary and float data type */ @@ -39,8 +39,10 @@ public enum VectorDataType { BINARY("binary") { @Override - public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { - throw new IllegalStateException("Unsupported method"); + public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) { + // For binary vectors using Lucene engine we instead implement a custom BinaryVectorScorer so the VectorSimilarityFunction will + // not be used. + return KnnByteVectorField.createFieldType(dimension / Byte.SIZE, VectorSimilarityFunction.EUCLIDEAN); } @Override @@ -68,8 +70,8 @@ public void freeNativeMemory(long memoryAddress) { BYTE("byte") { @Override - public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { - return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction); + public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) { + return KnnByteVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction()); } @Override @@ -97,8 +99,8 @@ public void freeNativeMemory(long memoryAddress) { FLOAT("float") { @Override - public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { - return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); + public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction) { + return KnnVectorField.createFieldType(dimension, knnVectorSimilarityFunction.getVectorSimilarityFunction()); } @Override @@ -129,11 +131,11 @@ public void freeNativeMemory(long memoryAddress) { * Creates a KnnVectorFieldType based on the VectorDataType using the provided dimension and * VectorSimilarityFunction. * - * @param dimension Dimension of the vector - * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType + * @param dimension Dimension of the vector + * @param knnVectorSimilarityFunction KNNVectorSimilarityFunction for a given spaceType * @return FieldType */ - public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); + public abstract FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunction knnVectorSimilarityFunction); /** * Deserializes float vector from BytesRef. diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 72187516f..f3a125838 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -114,7 +114,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { } } - KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams( + params, + defaultMaxConnections, + defaultBeamWidth, + knnMethodContext.getSpaceType() + ); log.debug( "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", field, diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java new file mode 100644 index 000000000..2b3723439 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120BinaryVectorScorer.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.opensearch.knn.index.KNNVectorSimilarityFunction; + +import java.io.IOException; + +/** + * A FlatVectorsScorer to be used for scoring binary vectors. Meant to be used with {@link KNN9120BinaryVectorScorer} + */ +public class KNN9120BinaryVectorScorer implements FlatVectorsScorer { + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues + ) throws IOException { + if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { + return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues); + } + throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues, + float[] queryVector + ) throws IOException { + throw new IllegalArgumentException("binary vectors do not support float[] targets"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues, + byte[] queryVector + ) throws IOException { + if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { + return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector); + } + throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + } + + static class BinaryRandomVectorScorer implements RandomVectorScorer { + private final RandomAccessVectorValues.Bytes vectorValues; + private final byte[] queryVector; + + BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) { + this.queryVector = query; + this.vectorValues = vectorValues; + } + + @Override + public float score(int node) throws IOException { + return KNNVectorSimilarityFunction.HAMMING.compare(queryVector, vectorValues.vectorValue(node)); + } + + @Override + public int maxOrd() { + return vectorValues.size(); + } + + @Override + public int ordToDoc(int ord) { + return vectorValues.ordToDoc(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return vectorValues.getAcceptOrds(acceptDocs); + } + } + + static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + protected final RandomAccessVectorValues.Bytes vectorValues; + protected final RandomAccessVectorValues.Bytes vectorValues1; + protected final RandomAccessVectorValues.Bytes vectorValues2; + + public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException { + this.vectorValues = vectorValues; + this.vectorValues1 = vectorValues.copy(); + this.vectorValues2 = vectorValues.copy(); + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] queryVector = vectorValues1.vectorValue(ord); + return new BinaryRandomVectorScorer(vectorValues2, queryVector); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinaryRandomVectorScorerSupplier(vectorValues.copy()); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120HnswBinaryVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120HnswBinaryVectorsFormat.java new file mode 100644 index 000000000..3d34d17cc --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120HnswBinaryVectorsFormat.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; +import static org.opensearch.knn.index.engine.KNNEngine.getMaxDimensionByEngine; + +/** + * Custom KnnVectorsFormat implementation to support binary vectors. This class is mostly identical to + * {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat}, however we use the custom {@link KNN9120BinaryVectorScorer} + * to perform hamming bit scoring. + */ +public final class KNN9120HnswBinaryVectorsFormat extends KnnVectorsFormat { + + private final int maxConn; + private final int beamWidth; + private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new KNN9120BinaryVectorScorer()); + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + private static final String NAME = "KNN990HnswBinaryVectorsFormat"; + + /** + * Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat()} + */ + public KNN9120HnswBinaryVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int)} + */ + public KNN9120HnswBinaryVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, 1, null); + } + + /** + * Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int, int, java.util.concurrent.ExecutorService)} + */ + public KNN9120HnswBinaryVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + this.maxConn, + this.beamWidth, + flatVectorsFormat.fieldsWriter(state), + this.numMergeWorkers, + this.mergeExec + ); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return getMaxDimensionByEngine(KNNEngine.LUCENE); + } + + @Override + public String toString() { + return "KNN990HnswBinaryVectorsFormat(name=KNN990HnswBinaryVectorsFormat, maxConn=" + + this.maxConn + + ", beamWidth=" + + this.beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java new file mode 100644 index 000000000..6e8fc767e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.Optional; + +/** + * Class provides per field format implementation for Lucene Knn vector type + */ +public class KNN9120PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat { + private static final int NUM_MERGE_WORKERS = 1; + + public KNN9120PerFieldKnnVectorsFormat(final Optional mapperService) { + super( + mapperService, + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + Lucene99HnswVectorsFormat::new, + knnVectorsFormatParams -> { + // There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that + // changes in the future. + if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) { + return new KNN9120HnswBinaryVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ); + } else { + return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth()); + } + }, + knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat( + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + knnScalarQuantizedVectorsFormatParams.getBeamWidth(), + NUM_MERGE_WORKERS, + knnScalarQuantizedVectorsFormatParams.getBits(), + knnScalarQuantizedVectorsFormatParams.isCompressFlag(), + knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), + null + ) + ); + } + + @Override + /** + * This method returns the maximum dimension allowed from KNNEngine for Lucene codec + * + * @param fieldName Name of the field, ignored + * @return Maximum constant dimension set by KNNEngine + */ + public int getMaxDimensions(String fieldName) { + return KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java index f565dfe5b..67ea7b544 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java @@ -24,7 +24,7 @@ public KNN990PerFieldKnnVectorsFormat(final Optional mapperServic mapperService, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, - () -> new Lucene99HnswVectorsFormat(), + Lucene99HnswVectorsFormat::new, knnVectorsFormatParams -> new Lucene99HnswVectorsFormat( knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth() diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index fb9af0109..3df040785 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -20,6 +20,7 @@ import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec; +import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec; import org.opensearch.knn.index.codec.KNN920Codec.KNN920PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; @@ -117,14 +118,14 @@ public enum KNNCodecVersion { V_9_12_0( "KNN9120Codec", new Lucene912Codec(), - new KNN990PerFieldKnnVectorsFormat(Optional.empty()), + new KNN9120PerFieldKnnVectorsFormat(Optional.empty()), (delegate) -> new KNNFormatFacade( new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), (userCodec, mapperService) -> KNN9120Codec.builder() .delegate(userCodec) - .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), KNN9120Codec::new ); diff --git a/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java index 52134bc7e..ebf985fbb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java @@ -7,6 +7,7 @@ import lombok.Getter; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.SpaceType; import java.util.Map; @@ -17,10 +18,16 @@ public class KNNVectorsFormatParams { private int maxConnections; private int beamWidth; + private final SpaceType spaceType; public KNNVectorsFormatParams(final Map params, int defaultMaxConnections, int defaultBeamWidth) { + this(params, defaultMaxConnections, defaultBeamWidth, SpaceType.UNDEFINED); + } + + public KNNVectorsFormatParams(final Map params, int defaultMaxConnections, int defaultBeamWidth, SpaceType spaceType) { initMaxConnections(params, defaultMaxConnections); initBeamWidth(params, defaultBeamWidth); + this.spaceType = spaceType; } public boolean validate(final Map params) { diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java index 57cc016a6..701f79768 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java @@ -30,13 +30,18 @@ */ public class LuceneHNSWMethod extends AbstractKNNMethod { - private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BYTE); + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of( + VectorDataType.FLOAT, + VectorDataType.BYTE, + VectorDataType.BINARY + ); public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, SpaceType.L2, SpaceType.COSINESIMIL, - SpaceType.INNER_PRODUCT + SpaceType.INNER_PRODUCT, + SpaceType.HAMMING ); final static Encoder SQ_ENCODER = new LuceneSQEncoder(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index fcf3aa034..7990fdcab 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -17,8 +17,8 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.common.Explicit; +import org.opensearch.knn.index.KNNVectorSimilarityFunction; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.engine.KNNEngine; @@ -100,11 +100,10 @@ private LuceneFieldMapper( KNNMethodContext resolvedKnnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); VectorDataType vectorDataType = mappedFieldType.getVectorDataType(); - final VectorSimilarityFunction vectorSimilarityFunction = resolvedKnnMethodContext.getSpaceType() - .getKnnVectorSimilarityFunction() - .getVectorSimilarityFunction(); + final KNNVectorSimilarityFunction knnVectorSimilarityFunction = resolvedKnnMethodContext.getSpaceType() + .getKnnVectorSimilarityFunction(); - this.fieldType = vectorDataType.createKnnVectorFieldType(knnMappingConfig.getDimension(), vectorSimilarityFunction); + this.fieldType = vectorDataType.createKnnVectorFieldType(knnMappingConfig.getDimension(), knnVectorSimilarityFunction); if (this.hasDocValues) { this.vectorFieldType = buildDocValuesFieldType(resolvedKnnMethodContext.getKnnEngine()); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index d01a9aff6..74b864f98 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -129,6 +129,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); switch (vectorDataType) { case BYTE: + case BINARY: return new LuceneEngineKnnVectorQuery( getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested) ); diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index d799c3869..0fa7314c8 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -10,3 +10,4 @@ # org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat +org.opensearch.knn.index.codec.KNN9120Codec.KNN9120HnswBinaryVectorsFormat diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index f760a6e88..73af608c1 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -13,7 +13,6 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.util.BytesRef; @@ -109,14 +108,6 @@ private void createKNNByteVectorDocument(Directory directory) throws IOException writer.close(); } - public void testCreateKnnVectorFieldType_whenBinary_thenException() { - Exception ex = expectThrows( - IllegalStateException.class, - () -> VectorDataType.BINARY.createKnnVectorFieldType(1, VectorSimilarityFunction.EUCLIDEAN) - ); - assertTrue(ex.getMessage().contains("Unsupported method")); - } - public void testGetVectorFromBytesRef_whenBinary_thenException() { byte[] vector = { 1, 2, 3 }; float[] expected = { 1, 2, 3 }; diff --git a/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java index c5979e576..bc72503a2 100644 --- a/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java @@ -283,13 +283,6 @@ public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() { } public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() { - validateValidateVectorDataType( - KNNEngine.LUCENE, - KNNConstants.METHOD_HNSW, - VectorDataType.BINARY, - SpaceType.HAMMING, - "UnsupportedMethod" - ); validateValidateVectorDataType( KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 714723a8e..9e637be9b 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -1528,8 +1528,7 @@ public void testTypeParser_whenBinaryFaissHNSWWithInvalidSpaceType_thenException } } - public void testTypeParser_whenBinaryNonFaiss_thenException() throws IOException { - testTypeParserWithBinaryDataType(KNNEngine.LUCENE, SpaceType.HAMMING, METHOD_HNSW, 8, "is not supported for vector data type"); + public void testTypeParser_whenBinaryNmslib_thenException() throws IOException { testTypeParserWithBinaryDataType(KNNEngine.NMSLIB, SpaceType.HAMMING, METHOD_HNSW, 8, "is not supported for vector data type"); } diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index e98a1d769..b42a4f77c 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.integ; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; import lombok.SneakyThrows; @@ -27,6 +28,7 @@ import java.io.IOException; import java.net.URL; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -34,13 +36,23 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; /** - * This class contains integration tests for binary index with HNSW in Faiss + * This class contains integration tests for binary index with HNSW in Faiss and Lucene */ @Log4j2 public class BinaryIndexIT extends KNNRestTestCase { private static TestUtils.TestData testData; private static final int NEVER_BUILD_GRAPH = -1; private static final int ALWAYS_BUILD_GRAPH = 0; + private final KNNEngine engine; + + public BinaryIndexIT(KNNEngine engine) { + this.engine = engine; + } + + @ParametersFactory + public static Collection parameters() { + return Arrays.asList(new Object[] { KNNEngine.LUCENE }, new Object[] { KNNEngine.FAISS }); + } @BeforeClass public static void setUpClass() throws IOException { @@ -66,9 +78,9 @@ public void cleanUp() { } @SneakyThrows - public void testFaissHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { + public void testHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 16); // Ingest Byte[] vector1 = { 0b00000001, 0b00000001 }; @@ -93,9 +105,9 @@ public void testFaissHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { } @SneakyThrows - public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { + public void testHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128); ingestTestData(INDEX_NAME, FIELD_NAME); int k = 100; @@ -110,9 +122,9 @@ public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { } @SneakyThrows - public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_thenBuildGraphBasedOnSetting() { + public void testHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH); ingestTestData(INDEX_NAME, FIELD_NAME); assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); @@ -133,9 +145,9 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ } @SneakyThrows - public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { + public void testHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); ingestTestData(INDEX_NAME, FIELD_NAME, false); assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); @@ -156,9 +168,9 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ } @SneakyThrows - public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { + public void testHnswBinary_whenRadialSearch_thenThrowException() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 16); // Query float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 }; diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java index 29e710ec1..a706dd0cd 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java @@ -46,11 +46,6 @@ public void cleanUp() { public static Collection parameters() throws IOException { return Arrays.asList( $$( - $( - "Creation of binary index with lucene engine should fail", - createKnnHnswBinaryIndexMapping(KNNEngine.LUCENE, FIELD_NAME, 16, null), - "Validation Failed" - ), $( "Creation of binary index with nmslib engine should fail", createKnnHnswBinaryIndexMapping(KNNEngine.NMSLIB, FIELD_NAME, 16, null), diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchBinaryIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchBinaryIT.java index 6d0c8fde4..226f31dc2 100644 --- a/src/test/java/org/opensearch/knn/integ/FilteredSearchBinaryIT.java +++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchBinaryIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.integ; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; @@ -19,12 +20,25 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import java.util.Arrays; +import java.util.Collection; import java.util.List; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; @Log4j2 public class FilteredSearchBinaryIT extends KNNRestTestCase { + private final KNNEngine engine; + + public FilteredSearchBinaryIT(KNNEngine engine) { + this.engine = engine; + } + + @ParametersFactory + public static Collection parameters() { + return Arrays.asList(new Object[] { KNNEngine.LUCENE }, new Object[] { KNNEngine.FAISS }); + } + @After public void cleanUp() { try { @@ -35,18 +49,18 @@ public void cleanUp() { } @SneakyThrows - public void testFilteredSearchWithFaissHnswBinary_whenDoingApproximateSearch_thenReturnCorrectResults() { - validateFilteredSearchWithFaissHnswBinary(INDEX_NAME, false); + public void testFilteredSearchHnswBinary_whenDoingApproximateSearch_thenReturnCorrectResults() { + validateFilteredSearchHnswBinary(INDEX_NAME, false); } @SneakyThrows - public void testFilteredSearchWithFaissHnswBinary_whenDoingExactSearch_thenReturnCorrectResults() { - validateFilteredSearchWithFaissHnswBinary(INDEX_NAME, true); + public void testFilteredSearchHnswBinary_whenDoingExactSearch_thenReturnCorrectResults() { + validateFilteredSearchHnswBinary(INDEX_NAME, true); } - private void validateFilteredSearchWithFaissHnswBinary(final String indexName, final boolean doExactSearch) throws Exception { + private void validateFilteredSearchHnswBinary(final String indexName, final boolean doExactSearch) throws Exception { String filterFieldName = "parking"; - createKnnBinaryIndex(indexName, FIELD_NAME, 24, KNNEngine.FAISS); + createKnnBinaryIndex(indexName, FIELD_NAME, 24, engine); for (byte i = 1; i < 4; i++) { addKnnDocWithAttributes( diff --git a/src/test/java/org/opensearch/knn/integ/NestedSearchBinaryIT.java b/src/test/java/org/opensearch/knn/integ/NestedSearchBinaryIT.java index bb3767370..521a89e9d 100644 --- a/src/test/java/org/opensearch/knn/integ/NestedSearchBinaryIT.java +++ b/src/test/java/org/opensearch/knn/integ/NestedSearchBinaryIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.integ; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.http.util.EntityUtils; @@ -19,12 +20,25 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import java.util.Arrays; +import java.util.Collection; import java.util.List; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; @Log4j2 public class NestedSearchBinaryIT extends KNNRestTestCase { + private final KNNEngine engine; + + public NestedSearchBinaryIT(KNNEngine engine) { + this.engine = engine; + } + + @ParametersFactory + public static Collection parameters() { + return Arrays.asList(new Object[] { KNNEngine.LUCENE }, new Object[] { KNNEngine.FAISS }); + } + @After public void cleanUp() { try { @@ -35,9 +49,10 @@ public void cleanUp() { } @SneakyThrows - public void testNestedSearchWithFaissHnswBinary_whenKIsTwo_thenReturnTwoResults() { + public void testNestedSearchHnswBinary_whenKIsTwo_thenReturnTwoResults() { + String nestedFieldName = "nested"; - createKnnBinaryIndexWithNestedField(INDEX_NAME, nestedFieldName, FIELD_NAME, 16, KNNEngine.FAISS); + createKnnBinaryIndexWithNestedField(INDEX_NAME, nestedFieldName, FIELD_NAME, 16, engine); int totalDocCount = 15; for (byte i = 0; i < totalDocCount; i++) { @@ -93,10 +108,10 @@ public void testNestedSearchWithFaissHnswBinary_whenKIsTwo_thenReturnTwoResults( * */ @SneakyThrows - public void testNestedSearchWithFaissHnswBinary_whenDoingExactSearch_thenReturnCorrectResults() { + public void testNestedSearchHnswBinary_whenDoingExactSearch_thenReturnCorrectResults() { String nestedFieldName = "nested"; String filterFieldName = "parking"; - createKnnBinaryIndexWithNestedField(INDEX_NAME, nestedFieldName, FIELD_NAME, 24, KNNEngine.FAISS); + createKnnBinaryIndexWithNestedField(INDEX_NAME, nestedFieldName, FIELD_NAME, 24, engine); for (byte i = 1; i < 4; i++) { String doc = NestedKnnDocBuilder.create(nestedFieldName)