From 39061a012780cdaeea7de859d94475ca19a89cf7 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 16:30:02 -0700 Subject: [PATCH] Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. (#1931) (#1935) Signed-off-by: Navneet Verma (cherry picked from commit 967b21129121535f1f0b5d7268d2810b5d82fc47) Co-authored-by: Navneet Verma --- CHANGELOG.md | 3 +- .../codec/BasePerFieldKnnVectorsFormat.java | 6 +- .../index/mapper/KNNVectorFieldMapper.java | 103 ---------------- .../mapper/KNNVectorFieldMapperUtil.java | 4 +- .../knn/index/mapper/KNNVectorFieldType.java | 116 ++++++++++++++++++ .../knn/index/query/KNNQueryBuilder.java | 8 +- .../knn/plugin/script/KNNScoringSpace.java | 14 +-- .../plugin/script/KNNScoringSpaceUtil.java | 6 +- .../knn/index/codec/KNNCodecTestCase.java | 15 +-- .../mapper/KNNVectorFieldMapperTests.java | 2 +- .../mapper/KNNVectorFieldMapperUtilTests.java | 8 +- .../index/mapper/MethodFieldMapperTests.java | 2 +- .../knn/index/query/KNNQueryBuilderTests.java | 56 ++++----- .../knn/integ/KNNScriptScoringIT.java | 4 +- .../script/KNNScoringSpaceFactoryTests.java | 10 +- .../plugin/script/KNNScoringSpaceTests.java | 36 ++---- .../script/KNNScoringSpaceUtilTests.java | 10 +- 17 files changed, 192 insertions(+), 211 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f7431e9b7..6c761d27f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,4 +26,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824) * Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913) * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) -* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) \ No newline at end of file +* Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) +* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 2d423c26e..2a3732d7e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -12,8 +12,8 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.util.Optional; import java.util.function.Function; @@ -66,7 +66,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { ); return defaultFormatSupplier.get(); } - var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow( + var type = (KNNVectorFieldType) mapperService.orElseThrow( () -> new IllegalStateException( String.format("Cannot read field type for field [%s] because mapper service is not available", field) ) @@ -117,6 +117,6 @@ public int getMaxDimensions(String fieldName) { } private boolean isKnnVectorFieldType(final String field) { - return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType; + return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldType; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 14596189a..3b9487645 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -14,47 +14,33 @@ import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; -import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.search.DocValuesFieldExistsQuery; -import org.apache.lucene.search.Query; -import org.apache.lucene.util.BytesRef; import org.opensearch.Version; import org.opensearch.common.Explicit; -import org.opensearch.common.Nullable; import org.opensearch.common.ValidationException; import org.opensearch.common.xcontent.support.XContentMapValues; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.mapper.FieldMapper; -import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.Mapper; import org.opensearch.index.mapper.MapperParsingException; import org.opensearch.index.mapper.ParametrizedFieldMapper; import org.opensearch.index.mapper.ParseContext; -import org.opensearch.index.mapper.TextSearchInfo; -import org.opensearch.index.mapper.ValueFetcher; -import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.query.QueryShardException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KnnCircuitBreakerException; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.KNNVectorIndexFieldData; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelDao; -import org.opensearch.search.aggregations.support.CoreValuesSourceType; -import org.opensearch.search.lookup.SearchLookup; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; @@ -72,7 +58,6 @@ import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataType; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; @@ -467,94 +452,6 @@ public Mapper.Builder parse(String name, Map node, ParserCont } } - @Getter - public static class KNNVectorFieldType extends MappedFieldType { - int dimension; - String modelId; - KNNMethodContext knnMethodContext; - VectorDataType vectorDataType; - SpaceType spaceType; - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - VectorDataType vectorDataType, - SpaceType spaceType - ) { - this(name, meta, dimension, null, null, vectorDataType, spaceType); - } - - public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType()); - } - - public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { - this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); - } - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - KNNMethodContext knnMethodContext, - VectorDataType vectorDataType - ) { - this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType()); - } - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - @Nullable KNNMethodContext knnMethodContext, - @Nullable String modelId, - VectorDataType vectorDataType, - @Nullable SpaceType spaceType - ) { - super(name, false, false, true, TextSearchInfo.NONE, meta); - this.dimension = dimension; - this.modelId = modelId; - this.knnMethodContext = knnMethodContext; - this.vectorDataType = vectorDataType; - this.spaceType = spaceType; - } - - @Override - public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) { - throw new UnsupportedOperationException("KNN Vector do not support fields search"); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public Query existsQuery(QueryShardContext context) { - return new DocValuesFieldExistsQuery(name()); - } - - @Override - public Query termQuery(Object value, QueryShardContext context) { - throw new QueryShardException( - context, - String.format(Locale.ROOT, "KNN vector do not support exact searching, use KNN queries instead: [%s]", name()) - ); - } - - @Override - public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier searchLookup) { - failIfNoDocValues(); - return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType); - } - - @Override - public Object valueForDisplay(Object value) { - return deserializeStoredVector((BytesRef) value, vectorDataType); - } - } - protected Explicit ignoreMalformed; protected boolean stored; protected boolean hasDocValues; diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 07bf4fc2d..2adbbb695 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -236,7 +236,7 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy * @param knnVectorFieldType knn vector field type * @return expected vector length */ - public static int getExpectedVectorLength(final KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { + public static int getExpectedVectorLength(final KNNVectorFieldType knnVectorFieldType) { int expectedDimensions = knnVectorFieldType.getDimension(); if (isModelBasedIndex(expectedDimensions)) { ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); @@ -255,7 +255,7 @@ private static boolean isModelBasedIndex(int expectedDimensions) { * @param knnVectorField knn vector field * @return the model metadata from knnVectorField */ - private static ModelMetadata getModelMetadataForField(final KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { + private static ModelMetadata getModelMetadataForField(final KNNVectorFieldType knnVectorField) { String modelId = knnVectorField.getModelId(); if (modelId == null) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java new file mode 100644 index 000000000..8c3815c5f --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.Getter; +import org.apache.lucene.search.DocValuesFieldExistsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.util.BytesRef; +import org.opensearch.common.Nullable; +import org.opensearch.index.fielddata.IndexFieldData; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.TextSearchInfo; +import org.opensearch.index.mapper.ValueFetcher; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.QueryShardException; +import org.opensearch.knn.index.KNNVectorIndexFieldData; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.search.aggregations.support.CoreValuesSourceType; +import org.opensearch.search.lookup.SearchLookup; + +import java.util.Locale; +import java.util.Map; +import java.util.function.Supplier; + +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; + +/** + * A KNNVector field type to represent the vector field in Opensearch + */ +@Getter +public class KNNVectorFieldType extends MappedFieldType { + int dimension; + String modelId; + KNNMethodContext knnMethodContext; + VectorDataType vectorDataType; + SpaceType spaceType; + + public KNNVectorFieldType(String name, Map meta, int dimension, VectorDataType vectorDataType, SpaceType spaceType) { + this(name, meta, dimension, null, null, vectorDataType, spaceType); + } + + public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { + this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType()); + } + + public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { + this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); + } + + public KNNVectorFieldType( + String name, + Map meta, + int dimension, + KNNMethodContext knnMethodContext, + VectorDataType vectorDataType + ) { + this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType()); + } + + public KNNVectorFieldType( + String name, + Map meta, + int dimension, + @Nullable KNNMethodContext knnMethodContext, + @Nullable String modelId, + VectorDataType vectorDataType, + @Nullable SpaceType spaceType + ) { + super(name, false, false, true, TextSearchInfo.NONE, meta); + this.dimension = dimension; + this.modelId = modelId; + this.knnMethodContext = knnMethodContext; + this.vectorDataType = vectorDataType; + this.spaceType = spaceType; + } + + @Override + public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) { + throw new UnsupportedOperationException("KNN Vector do not support fields search"); + } + + @Override + public String typeName() { + return KNNVectorFieldMapper.CONTENT_TYPE; + } + + @Override + public Query existsQuery(QueryShardContext context) { + return new DocValuesFieldExistsQuery(name()); + } + + @Override + public Query termQuery(Object value, QueryShardContext context) { + throw new QueryShardException( + context, + String.format(Locale.ROOT, "KNN vector do not support exact searching, use KNN queries instead: [%s]", name()) + ); + } + + @Override + public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier searchLookup) { + failIfNoDocValues(); + return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType); + } + + @Override + public Object valueForDisplay(Object value) { + return deserializeStoredVector((BytesRef) value, vectorDataType); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index b05938c4a..6d57cb2dd 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -24,13 +24,13 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.engine.model.QueryContext; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorQueryType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; import org.opensearch.knn.index.engine.KNNLibrarySearchContext; import org.opensearch.knn.index.engine.KNNEngine; @@ -338,11 +338,11 @@ protected Query doToQuery(QueryShardContext context) { return new MatchNoDocsQuery(); } - if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) { + if (!(mappedFieldType instanceof KNNVectorFieldType)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not knn_vector type.", this.fieldName)); } - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType; + KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) mappedFieldType; int fieldDimension = knnVectorFieldType.getDimension(); KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); MethodComponentContext methodComponentContext = null; @@ -492,7 +492,7 @@ protected Query doToQuery(QueryShardContext context) { throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME)); } - private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { + private ModelMetadata getModelMetadataForField(KNNVectorFieldType knnVectorField) { String modelId = knnVectorField.getModelId(); if (modelId == null) { diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 850bff1ab..71616c9fd 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -11,8 +11,8 @@ import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.script.ScoreScript; import org.opensearch.search.lookup.SearchLookup; @@ -67,11 +67,7 @@ public KNNFieldSpace( final String spaceName, final Set supportingVectorDataTypes ) { - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = toKNNVectorFieldType( - fieldType, - spaceName, - supportingVectorDataTypes - ); + KNNVectorFieldType knnVectorFieldType = toKNNVectorFieldType(fieldType, spaceName, supportingVectorDataTypes); this.processedQuery = getProcessedQuery(query, knnVectorFieldType); this.scoringMethod = getScoringMethod(this.processedQuery); } @@ -86,7 +82,7 @@ public ScoreScript getScoreScript( return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); } - private KNNVectorFieldMapper.KNNVectorFieldType toKNNVectorFieldType( + private KNNVectorFieldType toKNNVectorFieldType( final MappedFieldType fieldType, final String spaceName, final Set supportingVectorDataTypes @@ -97,7 +93,7 @@ private KNNVectorFieldMapper.KNNVectorFieldType toKNNVectorFieldType( ); } - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) fieldType; + KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) fieldType; VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType() == null ? VectorDataType.FLOAT : knnVectorFieldType.getVectorDataType(); @@ -116,7 +112,7 @@ private KNNVectorFieldMapper.KNNVectorFieldType toKNNVectorFieldType( return knnVectorFieldType; } - protected float[] getProcessedQuery(final Object query, final KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { + protected float[] getProcessedQuery(final Object query, final KNNVectorFieldType knnVectorFieldType) { return parseToFloatArray( query, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index e2bade320..7a97fdb05 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -7,7 +7,7 @@ import java.util.List; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -52,7 +52,7 @@ public static boolean isBinaryFieldType(MappedFieldType fieldType) { * @return true if fieldType is of type KNNVectorFieldType; false otherwise */ public static boolean isKNNVectorFieldType(MappedFieldType fieldType) { - return fieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType; + return fieldType instanceof KNNVectorFieldType; } /** @@ -61,7 +61,7 @@ public static boolean isKNNVectorFieldType(MappedFieldType fieldType) { * @param fieldType KNN vector field type * @return true if the KNN field type is a binary vector data type */ - public static boolean isBinaryVectorDataType(final KNNVectorFieldMapper.KNNVectorFieldType fieldType) { + public static boolean isBinaryVectorDataType(final KNNVectorFieldType fieldType) { return VectorDataType.BINARY == fieldType.getVectorDataType(); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 70b055df4..a0b9b32d0 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -25,6 +25,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.query.KNNQueryFactory; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQuery; @@ -308,18 +309,8 @@ public void testKnnVectorIndex( SpaceType.L2, new MethodComponentContext(METHOD_HNSW, Map.of(HNSW_ALGO_M, 16, HNSW_ALGO_EF_CONSTRUCTION, 256)) ); - final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldMapper.KNNVectorFieldType( - FIELD_NAME_ONE, - Map.of(), - 3, - knnMethodContext - ); - final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldMapper.KNNVectorFieldType( - FIELD_NAME_TWO, - Map.of(), - 2, - knnMethodContext - ); + final KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldType(FIELD_NAME_ONE, Map.of(), 3, knnMethodContext); + final KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldType(FIELD_NAME_TWO, Map.of(), 2, knnMethodContext); when(mapperService.fieldType(eq(FIELD_NAME_ONE))).thenReturn(mappedFieldType1); when(mapperService.fieldType(eq(FIELD_NAME_TWO))).thenReturn(mappedFieldType2); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index e68d34e88..c95568be2 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -1058,7 +1058,7 @@ private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperIn new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( TEST_FIELD_NAME, Collections.emptyMap(), TEST_DIMENSION, diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index a4d597a41..31da12d66 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -60,14 +60,14 @@ public void testStoredFields_whenVectorIsFloatType_thenSucceed() { } public void testGetExpectedVectorLengthSuccess() { - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); when(knnVectorFieldType.getDimension()).thenReturn(3); - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); when(knnVectorFieldTypeBinary.getDimension()).thenReturn(8); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldType.class); when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); String modelId = "test-model"; when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); @@ -86,7 +86,7 @@ public void testGetExpectedVectorLengthSuccess() { } public void testGetExpectedVectorLengthFailure() { - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldType.class); when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); String modelId = "test-model"; when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); diff --git a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java index 9ae2fad9c..dcd255740 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java @@ -15,7 +15,7 @@ public class MethodFieldMapperTests extends TestCase { public void testMethodFieldMapper_whenVectorDataTypeIsGiven_thenSetItInFieldType() { - KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( "testField", Collections.emptyMap(), 1, diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 63d9a6c30..0241a9afb 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -27,12 +27,12 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -169,7 +169,7 @@ public void testDoToQuery_Normal() throws Exception { KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -191,7 +191,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -223,7 +223,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -250,7 +250,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -282,7 +282,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -309,7 +309,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(score).build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -336,7 +336,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(score).build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -367,7 +367,7 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -400,7 +400,7 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -428,7 +428,7 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() { .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(8); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); @@ -454,7 +454,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -488,7 +488,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -515,7 +515,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -537,7 +537,7 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); @@ -570,7 +570,7 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( @@ -593,7 +593,7 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); @@ -614,7 +614,7 @@ public void testDoToQuery_FromModel() { KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); // Dimension is -1. In this case, model metadata will need to provide dimension @@ -655,7 +655,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(-K); @@ -693,7 +693,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(-K); @@ -728,7 +728,7 @@ public void testDoToQuery_InvalidDimensions() { KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(400); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -753,7 +753,7 @@ public void testDoToQuery_InvalidZeroFloatVector() { KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); @@ -774,7 +774,7 @@ public void testDoToQuery_InvalidZeroByteVector() { KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BYTE); @@ -902,7 +902,7 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { .maxDistance(MAX_DISTANCE) .build(); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); @@ -929,7 +929,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti .methodParameters(Map.of("ef_search", EF_SEARCH)) .build(); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); @@ -955,7 +955,7 @@ public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { .methodParameters(Map.of("ef_search", EF_SEARCH)) .build(); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); @@ -976,7 +976,7 @@ public void testDoToQuery_whenBinary_thenValid() throws Exception { KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(32); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); @@ -992,7 +992,7 @@ public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(8); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index a1a6e3aa6..e67ad40bc 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -30,8 +30,8 @@ import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.core.rest.RestStatus; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; import org.opensearch.knn.plugin.script.KNNScoringSpace; import org.opensearch.knn.plugin.script.KNNScoringSpaceFactory; @@ -735,7 +735,7 @@ private Map createDataset( } private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector) { - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( FIELD_NAME, Collections.emptyMap(), SpaceType.HAMMING == spaceType ? queryVector.length * 8 : queryVector.length, diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java index e24acc483..823d21080 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java @@ -7,9 +7,9 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.SpaceType; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.util.List; @@ -18,9 +18,9 @@ public class KNNScoringSpaceFactoryTests extends KNNTestCase { public void testValidSpaces() { - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); when(knnVectorFieldType.getDimension()).thenReturn(3); - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); when(knnVectorFieldTypeBinary.getDimension()).thenReturn(24); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); NumberFieldMapper.NumberFieldType numberFieldType = new NumberFieldMapper.NumberFieldType( @@ -65,9 +65,9 @@ public void testValidSpaces() { public void testInvalidSpace() { List floatQueryObject = List.of(1.0f, 1.0f, 1.0f); - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); when(knnVectorFieldType.getDimension()).thenReturn(3); - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); when(knnVectorFieldTypeBinary.getDimension()).thenReturn(24); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 07385e55b..6c557c8dd 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -15,9 +15,9 @@ import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.math.BigInteger; import java.util.ArrayList; @@ -43,7 +43,7 @@ private void expectThrowsExceptionWithNonKNNField(Class clazz) throws NoSuchMeth private void expectThrowsExceptionWithKNNFieldWithBinaryDataType(Class clazz) throws NoSuchMethodException { Constructor constructor = clazz.getConstructor(Object.class, MappedFieldType.class); - KNNVectorFieldMapper.KNNVectorFieldType invalidFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType invalidFieldType = mock(KNNVectorFieldType.class); when(invalidFieldType.getVectorDataType()).thenReturn(VectorDataType.BINARY); Exception e = expectThrows(InvocationTargetException.class, () -> constructor.newInstance(null, invalidFieldType)); assertTrue(e.getCause() instanceof IllegalArgumentException); @@ -58,12 +58,7 @@ public void testL2_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType( - "test", - Collections.emptyMap(), - 3, - knnMethodContext - ); + KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); KNNScoringSpace.L2 l2 = new KNNScoringSpace.L2(arrayListQueryObject, fieldType); assertEquals(1F, l2.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F); } @@ -80,12 +75,7 @@ public void testCosineSimilarity_whenValid_thenSucceed() { float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f }; KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType( - "test", - Collections.emptyMap(), - 3, - knnMethodContext - ); + KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); assertEquals(2F, cosineSimilarity.getScoringMethod().apply(arrayFloat2, arrayFloat), 0.1F); @@ -103,12 +93,7 @@ public void testCosineSimilarity_whenValid_thenSucceed() { public void testCosineSimilarity_whenZeroVector_thenException() { KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType( - "test", - Collections.emptyMap(), - 3, - knnMethodContext - ); + KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f); IllegalArgumentException exception1 = expectThrows( @@ -133,12 +118,7 @@ public void testInnerProd_whenValid_thenSucceed() { float[] arrayFloat2_case1 = new float[] { 1.0f, 1.0f, 1.0f }; KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType( - "test", - Collections.emptyMap(), - 3, - knnMethodContext - ); + KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); KNNScoringSpace.InnerProd innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case1, fieldType); assertEquals(7.0F, innerProd.getScoringMethod().apply(arrayFloat_case1, arrayFloat2_case1), 0.001F); @@ -204,7 +184,7 @@ public void testHammingBit_Base64() { public void testHamming_whenKNNFieldType_thenSucceed() { List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + KNNVectorFieldType fieldType = new KNNVectorFieldType( "test", Collections.emptyMap(), 8 * arrayListQueryObject.size(), @@ -218,7 +198,7 @@ public void testHamming_whenKNNFieldType_thenSucceed() { } public void testHamming_whenNonBinaryVectorDataType_thenException() { - KNNVectorFieldMapper.KNNVectorFieldType invalidFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType invalidFieldType = mock(KNNVectorFieldType.class); when(invalidFieldType.getVectorDataType()).thenReturn(randomInt() % 2 == 0 ? VectorDataType.FLOAT : VectorDataType.BYTE); Exception e = expectThrows(IllegalArgumentException.class, () -> new KNNScoringSpace.Hamming(null, invalidFieldType)); assertTrue(e.getMessage(), e.getMessage().contains("The data type should be [BINARY]")); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index ace3dabc8..781ed2350 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -7,9 +7,9 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.math.BigInteger; import java.util.ArrayList; @@ -32,7 +32,7 @@ public void testFieldTypeCheck() { KNNScoringSpaceUtil.isBinaryFieldType(new NumberFieldMapper.NumberFieldType("field", NumberFieldMapper.NumberType.INTEGER)) ); - assertTrue(KNNScoringSpaceUtil.isKNNVectorFieldType(mock(KNNVectorFieldMapper.KNNVectorFieldType.class))); + assertTrue(KNNScoringSpaceUtil.isKNNVectorFieldType(mock(KNNVectorFieldType.class))); assertFalse(KNNScoringSpaceUtil.isKNNVectorFieldType(new BinaryFieldMapper.BinaryFieldType("test"))); } @@ -62,7 +62,7 @@ public void testParseKNNVectorQuery() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - KNNVectorFieldMapper.KNNVectorFieldType fieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType fieldType = mock(KNNVectorFieldType.class); when(fieldType.getDimension()).thenReturn(3); assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3, VectorDataType.FLOAT), 0.1f); @@ -77,13 +77,13 @@ public void testParseKNNVectorQuery() { } public void testIsBinaryVectorDataType_whenBinary_thenReturnTrue() { - KNNVectorFieldMapper.KNNVectorFieldType fieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType fieldType = mock(KNNVectorFieldType.class); when(fieldType.getVectorDataType()).thenReturn(VectorDataType.BINARY); assertTrue(KNNScoringSpaceUtil.isBinaryVectorDataType(fieldType)); } public void testIsBinaryVectorDataType_whenNonBinary_thenReturnFalse() { - KNNVectorFieldMapper.KNNVectorFieldType fieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNVectorFieldType fieldType = mock(KNNVectorFieldType.class); when(fieldType.getVectorDataType()).thenReturn(randomInt() % 2 == 0 ? VectorDataType.FLOAT : VectorDataType.BYTE); assertFalse(KNNScoringSpaceUtil.isBinaryVectorDataType(fieldType)); }