diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisEnhancedKeyValueAdapter.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisEnhancedKeyValueAdapter.java index ea0b0b2f..8d4a9e64 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisEnhancedKeyValueAdapter.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisEnhancedKeyValueAdapter.java @@ -7,7 +7,7 @@ import com.redis.om.spring.indexing.RediSearchIndexer; import com.redis.om.spring.ops.RedisModulesOperations; import com.redis.om.spring.ops.search.SearchOperations; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.springframework.data.convert.CustomConversions; import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.redis.connection.RedisConnection; @@ -41,7 +41,7 @@ public class RedisEnhancedKeyValueAdapter extends RedisKeyValueAdapter { private final RedisModulesOperations modulesOperations; private final RediSearchIndexer indexer; private final EntityAuditor auditor; - private final FeatureExtractor featureExtractor; + private final Embedder embedder; private final RedisOMProperties redisOMProperties; /** @@ -56,9 +56,9 @@ public RedisEnhancedKeyValueAdapter( // RedisOperations redisOps, // RedisModulesOperations rmo, // RediSearchIndexer indexer, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties redisOMProperties) { - this(redisOps, rmo, new RedisMappingContext(), indexer, featureExtractor, redisOMProperties); + this(redisOps, rmo, new RedisMappingContext(), indexer, embedder, redisOMProperties); } /** @@ -75,9 +75,9 @@ public RedisEnhancedKeyValueAdapter( // RedisModulesOperations rmo, // RedisMappingContext mappingContext, // RediSearchIndexer indexer, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties redisOMProperties) { - this(redisOps, rmo, mappingContext, new RedisOMCustomConversions(), indexer, featureExtractor, redisOMProperties); + this(redisOps, rmo, mappingContext, new RedisOMCustomConversions(), indexer, embedder, redisOMProperties); } /** @@ -96,7 +96,7 @@ public RedisEnhancedKeyValueAdapter( // RedisMappingContext mappingContext, // @Nullable CustomConversions customConversions, // RediSearchIndexer indexer, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties redisOMProperties) { super(redisOps, mappingContext, customConversions); @@ -114,7 +114,7 @@ public RedisEnhancedKeyValueAdapter( // this.modulesOperations = (RedisModulesOperations) rmo; this.indexer = indexer; this.auditor = new EntityAuditor(this.redisOperations); - this.featureExtractor = featureExtractor; + this.embedder = embedder; this.redisOMProperties = redisOMProperties; } @@ -141,7 +141,7 @@ public Object put(Object id, Object item, String keyspace) { } byte[] redisKey = createKey(sanitizeKeyspace(keyspace), idAsString); auditor.processEntity(redisKey, item); - featureExtractor.processEntity(item); + embedder.processEntity(item); rdo = new RedisData(); converter.write(item, rdo); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisJSONKeyValueAdapter.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisJSONKeyValueAdapter.java index c43942fd..8c497356 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisJSONKeyValueAdapter.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisJSONKeyValueAdapter.java @@ -11,7 +11,7 @@ import com.redis.om.spring.ops.json.JSONOperations; import com.redis.om.spring.ops.search.SearchOperations; import com.redis.om.spring.util.ObjectUtils; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.*; @@ -53,7 +53,7 @@ public class RedisJSONKeyValueAdapter extends RedisKeyValueAdapter { private final RediSearchIndexer indexer; private final GsonBuilder gsonBuilder; private final EntityAuditor auditor; - private final FeatureExtractor featureExtractor; + private final Embedder embedder; private final RedisOMProperties redisOMProperties; /** @@ -72,7 +72,7 @@ public RedisJSONKeyValueAdapter( // RedisMappingContext mappingContext, // RediSearchIndexer indexer, // GsonBuilder gsonBuilder, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties redisOMProperties) { super(redisOps, mappingContext, new RedisOMCustomConversions()); this.modulesOperations = (RedisModulesOperations) rmo; @@ -82,7 +82,7 @@ public RedisJSONKeyValueAdapter( // this.indexer = indexer; this.auditor = new EntityAuditor(this.redisOperations); this.gsonBuilder = gsonBuilder; - this.featureExtractor = featureExtractor; + this.embedder = embedder; this.redisOMProperties = redisOMProperties; } @@ -102,7 +102,7 @@ public Object put(Object id, Object item, String keyspace) { processVersion(key, item); auditor.processEntity(key, item); - featureExtractor.processEntity(item); + embedder.processEntity(item); Optional maybeTtl = getTTLForEntity(item); ops.set(key, item); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java index 269636fc..6875d1bc 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java @@ -31,9 +31,9 @@ import com.redis.om.spring.search.stream.EntityStream; import com.redis.om.spring.search.stream.EntityStreamImpl; import com.redis.om.spring.serialization.gson.*; -import com.redis.om.spring.vectorize.DefaultFeatureExtractor; -import com.redis.om.spring.vectorize.FeatureExtractor; -import com.redis.om.spring.vectorize.NoopFeatureExtractor; +import com.redis.om.spring.vectorize.DefaultEmbedder; +import com.redis.om.spring.vectorize.Embedder; +import com.redis.om.spring.vectorize.NoopEmbedder; import com.redis.om.spring.vectorize.face.FaceDetectionTranslator; import com.redis.om.spring.vectorize.face.FaceFeatureTranslator; import org.apache.commons.lang3.ObjectUtils; @@ -562,7 +562,7 @@ BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel(RedisOMProperties properti } @Bean(name = "featureExtractor") - public FeatureExtractor featureExtractor( + public Embedder featureExtractor( @Nullable @Qualifier("djlImageEmbeddingModel") ZooModel imageEmbeddingModel, @Nullable @Qualifier("djlFaceEmbeddingModel") ZooModel faceEmbeddingModel, @Nullable @Qualifier("djlImageFactory") ImageFactory imageFactory, @@ -574,10 +574,10 @@ public FeatureExtractor featureExtractor( @Nullable BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel, RedisOMProperties properties, ApplicationContext ac) { return properties.getDjl().isEnabled() ? - new DefaultFeatureExtractor(ac, imageEmbeddingModel, faceEmbeddingModel, imageFactory, defaultImagePipeline, + new DefaultEmbedder(ac, imageEmbeddingModel, faceEmbeddingModel, imageFactory, defaultImagePipeline, sentenceTokenizer, openAITextVectorizer, azureOpenAIClient, vertexAiPaLm2EmbeddingModel, bedrockCohereEmbeddingModel, bedrockTitanEmbeddingModel, properties) : - new NoopFeatureExtractor(); + new NoopEmbedder(); } @Bean(name = "redisJSONKeyValueAdapter") @@ -588,9 +588,9 @@ RedisJSONKeyValueAdapter getRedisJSONKeyValueAdapter( // RediSearchIndexer indexer, // @Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, // RedisOMProperties properties, // - @Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor) { + @Nullable @Qualifier("featureExtractor") Embedder embedder) { return new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, - featureExtractor, properties); + embedder, properties); } @Bean(name = "redisJSONKeyValueTemplate") @@ -601,10 +601,10 @@ public CustomRedisKeyValueTemplate getRedisJSONKeyValueTemplate( // RediSearchIndexer indexer, // @Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, // RedisOMProperties properties, // - @Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor) { + @Nullable @Qualifier("featureExtractor") Embedder embedder) { return new CustomRedisKeyValueTemplate( - new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, - featureExtractor, properties), mappingContext); + new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, embedder, + properties), mappingContext); } @Bean(name = "redisCustomKeyValueTemplate") @@ -614,9 +614,9 @@ public CustomRedisKeyValueTemplate getKeyValueTemplate( // RedisMappingContext mappingContext, // RediSearchIndexer indexer, // RedisOMProperties properties, // - @Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor) { + @Nullable @Qualifier("featureExtractor") Embedder embedder) { return new CustomRedisKeyValueTemplate( - new RedisEnhancedKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, featureExtractor, + new RedisEnhancedKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, embedder, properties), // mappingContext); } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactory.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactory.java index bc37e439..ac9f5ccf 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactory.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactory.java @@ -5,7 +5,7 @@ import com.redis.om.spring.indexing.RediSearchIndexer; import com.redis.om.spring.ops.RedisModulesOperations; import com.redis.om.spring.repository.query.RediSearchQuery; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.springframework.beans.BeanUtils; import org.springframework.data.keyvalue.core.KeyValueOperations; import org.springframework.data.keyvalue.repository.query.KeyValuePartTreeQuery; @@ -42,7 +42,7 @@ public class RedisDocumentRepositoryFactory extends KeyValueRepositoryFactory { private final RediSearchIndexer indexer; private final GsonBuilder gsonBuilder; private final RedisMappingContext mappingContext; - private final FeatureExtractor featureExtractor; + private final Embedder embedder; private final RedisOMProperties properties; /** @@ -54,7 +54,7 @@ public class RedisDocumentRepositoryFactory extends KeyValueRepositoryFactory { * @param indexer must not be {@literal null}. * @param mappingContext must not be {@literal null}. * @param gsonBuilder must not be {@literal null}. - * @param featureExtractor must not be {@literal null}. + * @param embedder must not be {@literal null}. * @param properties must not be {@literal null}. */ public RedisDocumentRepositoryFactory( // @@ -63,7 +63,7 @@ public RedisDocumentRepositoryFactory( // RediSearchIndexer indexer, // RedisMappingContext mappingContext, // GsonBuilder gsonBuilder, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties properties // ) { this( // @@ -73,7 +73,7 @@ public RedisDocumentRepositoryFactory( // DEFAULT_QUERY_CREATOR, // mappingContext, // gsonBuilder, // - featureExtractor, // + embedder, // properties // ); // } @@ -96,7 +96,7 @@ public RedisDocumentRepositoryFactory( // Class> queryCreator, // RedisMappingContext mappingContext, // GsonBuilder gsonBuilder, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties properties // ) { this( // @@ -107,7 +107,7 @@ public RedisDocumentRepositoryFactory( // RediSearchQuery.class, // mappingContext, // gsonBuilder, // - featureExtractor, // + embedder, // properties // ); } @@ -123,7 +123,7 @@ public RedisDocumentRepositoryFactory( // * @param repositoryQueryType must not be {@literal null}. * @param mappingContext must not be {@literal null}. * @param gsonBuilder must not be {@literal null}. - * @param featureExtractor must not be {@literal null}. + * @param embedder must not be {@literal null}. * @param properties must not be {@literal null}. */ public RedisDocumentRepositoryFactory( // @@ -134,7 +134,7 @@ public RedisDocumentRepositoryFactory( // Class repositoryQueryType, // RedisMappingContext mappingContext, // GsonBuilder gsonBuilder, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties properties // ) { @@ -145,7 +145,7 @@ public RedisDocumentRepositoryFactory( // Assert.notNull(indexer, "RediSearchIndexer must not be null!"); Assert.notNull(queryCreator, "Query creator type must not be null!"); Assert.notNull(repositoryQueryType, "RepositoryQueryType type must not be null!"); - Assert.notNull(featureExtractor, "FeatureExtractor type must not be null!"); + Assert.notNull(embedder, "FeatureExtractor type must not be null!"); Assert.notNull(properties, "RedisOMSpringProperties type must not be null!"); this.keyValueOperations = keyValueOperations; @@ -155,7 +155,7 @@ public RedisDocumentRepositoryFactory( // this.repositoryQueryType = repositoryQueryType; this.mappingContext = mappingContext; this.gsonBuilder = gsonBuilder; - this.featureExtractor = featureExtractor; + this.embedder = embedder; this.properties = properties; } @@ -170,7 +170,7 @@ protected Object getTargetRepository(RepositoryInformation repositoryInformation indexer, // mappingContext, // gsonBuilder, // - featureExtractor, // + embedder, // properties // ); } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactoryBean.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactoryBean.java index 6efa7667..dd6314ae 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactoryBean.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactoryBean.java @@ -4,7 +4,7 @@ import com.redis.om.spring.RedisOMProperties; import com.redis.om.spring.indexing.RediSearchIndexer; import com.redis.om.spring.ops.RedisModulesOperations; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.keyvalue.core.KeyValueOperations; import org.springframework.data.keyvalue.repository.support.KeyValueRepositoryFactoryBean; @@ -27,7 +27,7 @@ public class RedisDocumentRepositoryFactoryBean, S, @Autowired private GsonBuilder gsonBuilder; @Autowired - private @Nullable FeatureExtractor featureExtractor; + private @Nullable Embedder embedder; @Autowired private RedisOMProperties properties; @@ -49,7 +49,7 @@ protected final RedisDocumentRepositoryFactory createRepositoryFactory( // Class repositoryQueryType // ) { return new RedisDocumentRepositoryFactory(operations, rmo, indexer, queryCreator, repositoryQueryType, - this.mappingContext, this.gsonBuilder, this.featureExtractor, this.properties); + this.mappingContext, this.gsonBuilder, this.embedder, this.properties); } @Override diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisEnhancedRepositoryFactory.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisEnhancedRepositoryFactory.java index 23b749e9..535ec17f 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisEnhancedRepositoryFactory.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisEnhancedRepositoryFactory.java @@ -4,7 +4,7 @@ import com.redis.om.spring.indexing.RediSearchIndexer; import com.redis.om.spring.ops.RedisModulesOperations; import com.redis.om.spring.repository.query.RedisEnhancedQuery; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.springframework.beans.BeanUtils; import org.springframework.data.keyvalue.core.KeyValueOperations; import org.springframework.data.keyvalue.repository.query.KeyValuePartTreeQuery; @@ -45,7 +45,7 @@ public class RedisEnhancedRepositoryFactory extends RepositoryFactorySupport { private final MappingContext context; private final Class> queryCreator; private final Class repositoryQueryType; - private final FeatureExtractor featureExtractor; + private final Embedder embedder; private final RedisOMProperties properties; /** @@ -56,7 +56,7 @@ public class RedisEnhancedRepositoryFactory extends RepositoryFactorySupport { * @param redisOperations must not be {@literal null}. * @param rmo must not be {@literal null}. * @param indexer must not be {@literal null}. - * @param featureExtractor must not be {@literal null}. + * @param embedder must not be {@literal null}. * @param properties must not be {@literal null}. */ public RedisEnhancedRepositoryFactory( // @@ -64,9 +64,9 @@ public RedisEnhancedRepositoryFactory( // RedisOperations redisOperations, // RedisModulesOperations rmo, // RediSearchIndexer indexer, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties properties) { - this(keyValueOperations, redisOperations, rmo, indexer, featureExtractor, DEFAULT_QUERY_CREATOR, properties); + this(keyValueOperations, redisOperations, rmo, indexer, embedder, DEFAULT_QUERY_CREATOR, properties); } /** @@ -77,7 +77,7 @@ public RedisEnhancedRepositoryFactory( // * @param redisOperations must not be {@literal null}. * @param rmo must not be {@literal null}. * @param indexer must not be {@literal null}. - * @param featureExtractor must not be {@literal null}. + * @param embedder must not be {@literal null}. * @param queryCreator must not be {@literal null}. * @param properties must not be {@literal null}. */ @@ -86,7 +86,7 @@ public RedisEnhancedRepositoryFactory( // RedisOperations redisOperations, // RedisModulesOperations rmo, // RediSearchIndexer indexer, // - FeatureExtractor featureExtractor, // + Embedder embedder, // Class> queryCreator, // RedisOMProperties properties) { @@ -95,7 +95,7 @@ public RedisEnhancedRepositoryFactory( // redisOperations, // rmo, // indexer, // - featureExtractor, // + embedder, // queryCreator, // RedisEnhancedQuery.class, // properties // @@ -118,7 +118,7 @@ public RedisEnhancedRepositoryFactory( // RedisOperations redisOperations, // RedisModulesOperations rmo, // RediSearchIndexer indexer, // - FeatureExtractor featureExtractor, // + Embedder embedder, // Class> queryCreator, // Class repositoryQueryType, // RedisOMProperties properties) { @@ -128,7 +128,7 @@ public RedisEnhancedRepositoryFactory( // Assert.notNull(rmo, "RedisModulesOperations must not be null!"); Assert.notNull(queryCreator, "Query creator type must not be null!"); Assert.notNull(repositoryQueryType, "RepositoryQueryType type must not be null!"); - Assert.notNull(featureExtractor, "FeatureExtractor type must not be null!"); + Assert.notNull(embedder, "FeatureExtractor type must not be null!"); Assert.notNull(properties, "RedisOMSpringProperties type must not be null!"); this.keyValueOperations = keyValueOperations; @@ -138,7 +138,7 @@ public RedisEnhancedRepositoryFactory( // this.context = keyValueOperations.getMappingContext(); this.queryCreator = queryCreator; this.repositoryQueryType = repositoryQueryType; - this.featureExtractor = featureExtractor; + this.embedder = embedder; this.properties = properties; } @@ -166,7 +166,7 @@ public EntityInformation getEntityInformation(Class domainClas protected Object getTargetRepository(RepositoryInformation repositoryInformation) { EntityInformation entityInformation = getEntityInformation(repositoryInformation.getDomainType()); return super.getTargetRepositoryViaReflection(repositoryInformation, entityInformation, keyValueOperations, rmo, - indexer, featureExtractor, properties); + indexer, embedder, properties); } /* (non-Javadoc) diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisEnhancedRepositoryFactoryBean.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisEnhancedRepositoryFactoryBean.java index 49e67a2e..a9a0508c 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisEnhancedRepositoryFactoryBean.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisEnhancedRepositoryFactoryBean.java @@ -4,7 +4,7 @@ import com.redis.om.spring.indexing.RediSearchIndexer; import com.redis.om.spring.ops.RedisModulesOperations; import com.redis.om.spring.repository.query.RedisEnhancedQuery; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.springframework.data.keyvalue.core.KeyValueOperations; import org.springframework.data.keyvalue.repository.config.QueryCreatorType; import org.springframework.data.mapping.context.MappingContext; @@ -27,7 +27,7 @@ public class RedisEnhancedRepositoryFactoryBean, S, private @Nullable RediSearchIndexer indexer; private @Nullable Class> queryCreator; private @Nullable Class repositoryQueryType; - private @Nullable FeatureExtractor featureExtractor; + private @Nullable Embedder embedder; private RedisOMProperties properties; @@ -45,20 +45,20 @@ public RedisEnhancedRepositoryFactoryBean( // RedisOperations redisOperations, // RedisModulesOperations rmo, // RediSearchIndexer indexer, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties properties) { super(repositoryInterface); setRedisModulesOperations(rmo); setRedisOperations(redisOperations); setKeyspaceToIndexMap(indexer); - setFeatureExtractor(featureExtractor); + setFeatureExtractor(embedder); setRedisOMSpringProperties(properties); } - private void setFeatureExtractor(FeatureExtractor featureExtractor) { + private void setFeatureExtractor(Embedder embedder) { Assert.notNull(rmo, "FeatureExtractor must not be null!"); - this.featureExtractor = featureExtractor; + this.embedder = embedder; } /** @@ -162,7 +162,7 @@ protected RedisEnhancedRepositoryFactory createRepositoryFactory( // Class> queryCreator, // Class repositoryQueryType // ) { - return new RedisEnhancedRepositoryFactory(operations, redisOperations, rmo, indexer, featureExtractor, queryCreator, + return new RedisEnhancedRepositoryFactory(operations, redisOperations, rmo, indexer, embedder, queryCreator, RedisEnhancedQuery.class, properties); } @@ -179,7 +179,7 @@ public void afterPropertiesSet() { Assert.notNull(queryCreator, "Query creator must not be null!"); Assert.notNull(repositoryQueryType, "RepositoryQueryType must not be null!"); Assert.notNull(indexer, "RediSearchIndexer type must not be null"); - Assert.notNull(featureExtractor, "FeatureExtractor type must not be null!"); + Assert.notNull(embedder, "FeatureExtractor type must not be null!"); super.afterPropertiesSet(); } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java index 1cfb0a5a..0662ce33 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java @@ -22,7 +22,7 @@ import com.redis.om.spring.search.stream.SearchStream; import com.redis.om.spring.serialization.gson.GsonListOfType; import com.redis.om.spring.util.ObjectUtils; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeanWrapperImpl; import org.springframework.beans.PropertyAccessor; @@ -79,7 +79,7 @@ public class SimpleRedisDocumentRepository extends SimpleKeyValueReposito protected final RediSearchIndexer indexer; protected final MappingRedisOMConverter mappingConverter; protected final EntityAuditor auditor; - protected final FeatureExtractor featureExtractor; + protected final Embedder embedder; private final GsonBuilder gsonBuilder; private final ULIDIdentifierGenerator generator; private final RedisOMProperties properties; @@ -94,7 +94,7 @@ public SimpleRedisDocumentRepository( // RediSearchIndexer indexer, // RedisMappingContext mappingContext, // GsonBuilder gsonBuilder, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties properties) { super(metadata, operations); this.modulesOperations = (RedisModulesOperations) rmo; @@ -106,7 +106,7 @@ public SimpleRedisDocumentRepository( // this.gsonBuilder = gsonBuilder; this.mappingContext = mappingContext; this.auditor = new EntityAuditor(modulesOperations.template()); - this.featureExtractor = featureExtractor; + this.embedder = embedder; this.properties = properties; this.entityStream = new EntityStreamImpl(modulesOperations, modulesOperations.gsonBuilder(), indexer); } @@ -192,7 +192,7 @@ public List saveAll(Iterable entities) { // process entity pre-save mutation auditor.processEntity(entity, isNew); - featureExtractor.processEntity(entity); + embedder.processEntity(entity); Optional maybeTtl = getTTLForEntity(entity); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java index 7f90350d..ce4b3048 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java @@ -18,7 +18,7 @@ import com.redis.om.spring.search.stream.RedisFluentQueryByExample; import com.redis.om.spring.search.stream.SearchStream; import com.redis.om.spring.util.ObjectUtils; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.data.domain.*; @@ -62,7 +62,7 @@ public class SimpleRedisEnhancedRepository extends SimpleKeyValueReposito protected final MappingRedisOMConverter mappingConverter; protected final RedisEnhancedKeyValueAdapter enhancedKeyValueAdapter; protected final EntityAuditor auditor; - protected final FeatureExtractor featureExtractor; + protected final Embedder embedder; private final ULIDIdentifierGenerator generator; private final RedisOMProperties properties; @@ -75,7 +75,7 @@ public SimpleRedisEnhancedRepository( // KeyValueOperations operations, // @Qualifier("redisModulesOperations") RedisModulesOperations rmo, // RediSearchIndexer indexer, // - FeatureExtractor featureExtractor, // + Embedder embedder, // RedisOMProperties properties // ) { super(metadata, operations); @@ -84,11 +84,11 @@ public SimpleRedisEnhancedRepository( // this.operations = operations; this.indexer = indexer; this.mappingConverter = new MappingRedisOMConverter(null, new ReferenceResolverImpl(modulesOperations.template())); - this.enhancedKeyValueAdapter = new RedisEnhancedKeyValueAdapter(rmo.template(), rmo, indexer, featureExtractor, + this.enhancedKeyValueAdapter = new RedisEnhancedKeyValueAdapter(rmo.template(), rmo, indexer, embedder, properties); this.generator = ULIDIdentifierGenerator.INSTANCE; this.auditor = new EntityAuditor(modulesOperations.template()); - this.featureExtractor = featureExtractor; + this.embedder = embedder; this.properties = properties; this.entityStream = new EntityStreamImpl(modulesOperations, modulesOperations.gsonBuilder(), indexer); } @@ -262,7 +262,7 @@ public List saveAll(Iterable entities) { // process entity pre-save mutation auditor.processEntity(entity, isNew); - featureExtractor.processEntity(entity); + embedder.processEntity(entity); RedisData rdo = new RedisData(); mappingConverter.write(entity, rdo); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java b/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java index 0c920e19..4e145b11 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java @@ -700,16 +700,10 @@ public static String getSchemaFieldType(Schema.Field field) { } public static byte[] doubleListToByteArray(List doubleList) { - // Allocate ByteBuffer of appropriate size - ByteBuffer buffer = ByteBuffer.allocate(doubleList.size() * Double.BYTES); - - // Fill the ByteBuffer with double values from the list - for (Double value : doubleList) { - buffer.putDouble(value); - } - - // Convert ByteBuffer to byte[] - return buffer.array(); + byte[] bytes = new byte[Float.BYTES * doubleList.size()]; + float[] input = doubleListToFloatArray(doubleList); + ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input); + return bytes; } public static float[] doubleListToFloatArray(List doubleList) { diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/DefaultEmbedder.java b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/DefaultEmbedder.java new file mode 100644 index 00000000..22ad459e --- /dev/null +++ b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/DefaultEmbedder.java @@ -0,0 +1,482 @@ +package com.redis.om.spring.vectorize; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.translator.ImageFeatureExtractor; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.Pipeline; +import ai.djl.translate.TranslateException; +import com.azure.ai.openai.OpenAIClient; +import com.redis.om.spring.RedisOMProperties; +import com.redis.om.spring.annotations.Document; +import com.redis.om.spring.annotations.Vectorize; +import com.redis.om.spring.metamodel.MetamodelField; +import com.redis.om.spring.util.ObjectUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; +import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingOptions; +import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingModel; +import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; +import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingModel; +import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel; +import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; +import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.ollama.OllamaEmbeddingModel; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.OpenAiEmbeddingOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.vertexai.palm2.VertexAiPaLm2EmbeddingModel; +import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api; +import org.springframework.beans.PropertyAccessor; +import org.springframework.beans.PropertyAccessorFactory; +import org.springframework.context.ApplicationContext; +import org.springframework.core.io.Resource; +import org.springframework.web.client.RestClient; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static com.redis.om.spring.annotations.EmbeddingType.SENTENCE; +import static com.redis.om.spring.util.ObjectUtils.byteArrayToFloatArray; +import static com.redis.om.spring.util.ObjectUtils.longArrayToFloatArray; + +public class DefaultEmbedder implements Embedder { + private static final Log logger = LogFactory.getLog(DefaultEmbedder.class); + public final Pipeline imagePipeline; + public final HuggingFaceTokenizer sentenceTokenizer; + private final ZooModel imageEmbeddingModel; + private final ZooModel faceEmbeddingModel; + private final ImageFactory imageFactory; + private final ApplicationContext applicationContext; + private final ImageFeatureExtractor imageFeatureExtractor; + private final OpenAiEmbeddingModel defaultOpenAITextVectorizer; + private final OllamaEmbeddingModel defaultOllamaEmbeddingModel; + private final RedisOMProperties properties; + private final OllamaApi ollamaApi; + private final OpenAIClient azureOpenAIClient; + private final VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel; + private final BedrockCohereEmbeddingModel bedrockCohereEmbeddingModel; + private final BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel; + + public DefaultEmbedder( // + ApplicationContext applicationContext, // + ZooModel imageEmbeddingModel, // + ZooModel faceEmbeddingModel, // + ImageFactory imageFactory, // + Pipeline imagePipeline, // + HuggingFaceTokenizer sentenceTokenizer, // + OpenAiEmbeddingModel openAITextVectorizer, // + OpenAIClient azureOpenAIClient, // + VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel, // + BedrockCohereEmbeddingModel bedrockCohereEmbeddingModel, // + BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel, // + RedisOMProperties properties // + ) { + this.applicationContext = applicationContext; + this.imageEmbeddingModel = imageEmbeddingModel; + this.faceEmbeddingModel = faceEmbeddingModel; + this.imageFactory = imageFactory; + this.imagePipeline = imagePipeline; + this.sentenceTokenizer = sentenceTokenizer; + + // feature extractor + this.imageFeatureExtractor = ImageFeatureExtractor.builder().setPipeline(imagePipeline).build(); + this.defaultOpenAITextVectorizer = openAITextVectorizer; + this.azureOpenAIClient = azureOpenAIClient; + this.vertexAiPaLm2EmbeddingModel = vertexAiPaLm2EmbeddingModel; + this.bedrockCohereEmbeddingModel = bedrockCohereEmbeddingModel; + this.bedrockTitanEmbeddingModel = bedrockTitanEmbeddingModel; + this.properties = properties; + + this.ollamaApi = new OllamaApi(properties.getOllama().getBaseUrl()); + + this.defaultOllamaEmbeddingModel = new OllamaEmbeddingModel(ollamaApi, + new OllamaOptions().withModel(OllamaOptions.DEFAULT_MODEL)); + } + + private byte[] getImageEmbeddingsAsByteArrayFor(InputStream is) { + try { + var img = imageFactory.fromInputStream(is); + Predictor predictor = imageEmbeddingModel.newPredictor(imageFeatureExtractor); + return predictor.predict(img); + } catch (IOException | TranslateException e) { + logger.warn("Error generating image embedding", e); + return new byte[] {}; + } + } + + private float[] getImageEmbeddingsAsFloatArrayFor(InputStream is) { + return byteArrayToFloatArray(getImageEmbeddingsAsByteArrayFor(is)); + } + + private byte[] getFacialImageEmbeddingsAsByteArrayFor(InputStream is) throws IOException, TranslateException { + return ObjectUtils.floatArrayToByteArray(getFacialImageEmbeddingsAsFloatArrayFor(is)); + } + + private float[] getFacialImageEmbeddingsAsFloatArrayFor(InputStream is) throws IOException, TranslateException { + try (Predictor predictor = faceEmbeddingModel.newPredictor()) { + var img = imageFactory.fromInputStream(is); + return predictor.predict(img); + } + } + + private List getSentenceEmbeddingsAsByteArrayFor(List texts) { + Encoding[] encodings = sentenceTokenizer.batchEncode(texts); + return Arrays.stream(encodings).map(e -> ObjectUtils.longArrayToByteArray(e.getIds())).toList(); + } + + private List getSentenceEmbeddingAsFloatArrayFor(List texts) { + Encoding[] encodings = sentenceTokenizer.batchEncode(texts); + return Arrays.stream(encodings).map(e -> ObjectUtils.longArrayToFloatArray(e.getIds())).toList(); + } + + private byte[] getSentenceEmbeddingsAsByteArrayFor(String text) { + Encoding encoding = sentenceTokenizer.encode(text); + return ObjectUtils.longArrayToByteArray(encoding.getIds()); + } + + private float[] getSentenceEmbeddingAsFloatArrayFor(String text) { + Encoding encoding = sentenceTokenizer.encode(text); + return longArrayToFloatArray(encoding.getIds()); + } + + private List getEmbeddingsAsByteArrayFor(List texts, EmbeddingModel model) { + EmbeddingResponse embeddingResponse = model.embedForResponse(texts); + List embeddings = embeddingResponse.getResults(); + return embeddings.stream().map(e -> ObjectUtils.doubleListToByteArray(e.getOutput())).toList(); + } + + private List getEmbeddingAsFloatArrayFor(List texts, EmbeddingModel model) { + EmbeddingResponse embeddingResponse = model.embedForResponse(texts); + List embeddings = embeddingResponse.getResults(); + return embeddings.stream().map(e -> ObjectUtils.doubleListToFloatArray(e.getOutput())).toList(); + } + + private byte[] getEmbeddingsAsByteArrayFor(String text, EmbeddingModel model) { + EmbeddingResponse embeddingResponse = model.embedForResponse(List.of(text)); + Embedding embedding = embeddingResponse.getResult(); + return ObjectUtils.doubleListToByteArray(embedding.getOutput()); + } + + private float[] getEmbeddingAsFloatArrayFor(String text, EmbeddingModel model) { + EmbeddingResponse embeddingResponse = model.embedForResponse(List.of(text)); + Embedding embedding = embeddingResponse.getResult(); + return ObjectUtils.doubleListToFloatArray(embedding.getOutput()); + } + + @Override + public void processEntity(Object item) { + if (!isReady()) { + return; + } + List fields = ObjectUtils.getFieldsWithAnnotation(item.getClass(), Vectorize.class); + if (!fields.isEmpty()) { + PropertyAccessor accessor = PropertyAccessorFactory.forBeanPropertyAccess(item); + fields.forEach(f -> { + Vectorize vectorize = f.getAnnotation(Vectorize.class); + Object fieldValue = accessor.getPropertyValue(f.getName()); + boolean isDocument = item.getClass().isAnnotationPresent(Document.class); + + if (fieldValue != null) { + switch (vectorize.embeddingType()) { + case IMAGE -> processImageEmbedding(accessor, vectorize, fieldValue, isDocument); + case WORD -> { + //TODO: implement me! + } + case FACE -> processFaceEmbedding(accessor, vectorize, fieldValue, isDocument); + case SENTENCE -> processSentenceEmbedding(accessor, vectorize, fieldValue, isDocument); + } + } + }); + } + } + + private void processImageEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + Resource resource = applicationContext.getResource(fieldValue.toString()); + try { + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), + getImageEmbeddingsAsFloatArrayFor(resource.getInputStream())); + } else { + accessor.setPropertyValue(vectorize.destination(), getImageEmbeddingsAsByteArrayFor(resource.getInputStream())); + } + } catch (IOException e) { + logger.warn("Error generating image embedding", e); + } + } + + private void processFaceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + Resource resource = applicationContext.getResource(fieldValue.toString()); + try { + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), + getFacialImageEmbeddingsAsFloatArrayFor(resource.getInputStream())); + } else { + accessor.setPropertyValue(vectorize.destination(), + getFacialImageEmbeddingsAsByteArrayFor(resource.getInputStream())); + } + } catch (IOException | TranslateException e) { + logger.warn("Error generating facial image embedding", e); + } + } + + private void processSentenceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + switch (vectorize.provider()) { + case DJL -> processDjlSentenceEmbedding(accessor, vectorize, fieldValue, isDocument); + case OPENAI -> processOpenAiSentenceEmbedding(accessor, vectorize, fieldValue, isDocument); + case OLLAMA -> processOllamaSentenceEmbedding(accessor, vectorize, fieldValue, isDocument); + case AZURE_OPENAI -> processAzureOpenAiSentenceEmbedding(accessor, vectorize, fieldValue, isDocument); + case VERTEX_AI -> processVertexAiSentenceEmbedding(accessor, vectorize, fieldValue, isDocument); + case AMAZON_BEDROCK_COHERE -> processBedrockCohereSentenceEmbedding(accessor, vectorize, fieldValue, isDocument); + case AMAZON_BEDROCK_TITAN -> processBedrockTitanSentenceEmbedding(accessor, vectorize, fieldValue, isDocument); + } + } + + private void processDjlSentenceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getSentenceEmbeddingAsFloatArrayFor(fieldValue.toString())); + } else { + accessor.setPropertyValue(vectorize.destination(), getSentenceEmbeddingsAsByteArrayFor(fieldValue.toString())); + } + } + + private void processOpenAiSentenceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + OpenAiEmbeddingModel model = getOpenAiEmbeddingModel(vectorize); + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingAsFloatArrayFor(fieldValue.toString(), model)); + } else { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingsAsByteArrayFor(fieldValue.toString(), model)); + } + } + + private void processOllamaSentenceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + OllamaEmbeddingModel model = getOllamaEmbeddingModel(vectorize); + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingAsFloatArrayFor(fieldValue.toString(), model)); + } else { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingsAsByteArrayFor(fieldValue.toString(), model)); + } + } + + private void processAzureOpenAiSentenceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + AzureOpenAiEmbeddingModel model = getAzureOpenAiEmbeddingModel(vectorize); + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingAsFloatArrayFor(fieldValue.toString(), model)); + } else { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingsAsByteArrayFor(fieldValue.toString(), model)); + } + } + + private void processVertexAiSentenceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + VertexAiPaLm2EmbeddingModel model = getVertexAiPaLm2EmbeddingModel(vectorize); + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingAsFloatArrayFor(fieldValue.toString(), model)); + } else { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingsAsByteArrayFor(fieldValue.toString(), model)); + } + } + + private void processBedrockCohereSentenceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + BedrockCohereEmbeddingModel model = getBedrockCohereEmbeddingModel(vectorize); + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingAsFloatArrayFor(fieldValue.toString(), model)); + } else { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingsAsByteArrayFor(fieldValue.toString(), model)); + } + } + + private void processBedrockTitanSentenceEmbedding(PropertyAccessor accessor, Vectorize vectorize, Object fieldValue, + boolean isDocument) { + BedrockTitanEmbeddingModel model = getBedrockTitanEmbeddingModel(vectorize); + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingAsFloatArrayFor(fieldValue.toString(), model)); + } else { + accessor.setPropertyValue(vectorize.destination(), getEmbeddingsAsByteArrayFor(fieldValue.toString(), model)); + } + } + + private OpenAiEmbeddingModel getOpenAiEmbeddingModel(Vectorize vectorize) { + if (vectorize.openAiEmbeddingModel() != OpenAiApi.EmbeddingModel.TEXT_EMBEDDING_ADA_002) { + var openAiApi = new OpenAiApi(properties.getOpenAi().getApiKey()); + return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, + OpenAiEmbeddingOptions.builder().withModel(vectorize.openAiEmbeddingModel().getValue()).build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + return this.defaultOpenAITextVectorizer; + } + + private OllamaEmbeddingModel getOllamaEmbeddingModel(Vectorize vectorize) { + if (!vectorize.ollamaEmbeddingModel().id().equals(OllamaOptions.DEFAULT_MODEL)) { + return new OllamaEmbeddingModel(ollamaApi, new OllamaOptions().withModel(vectorize.ollamaEmbeddingModel().id())); + } + return this.defaultOllamaEmbeddingModel; + } + + private AzureOpenAiEmbeddingModel getAzureOpenAiEmbeddingModel(Vectorize vectorize) { + AzureOpenAiEmbeddingOptions options = AzureOpenAiEmbeddingOptions.builder() + .withDeploymentName(vectorize.azureOpenAiDeploymentName()).build(); + return new AzureOpenAiEmbeddingModel(this.azureOpenAIClient, MetadataMode.EMBED, options); + } + + private VertexAiPaLm2EmbeddingModel getVertexAiPaLm2EmbeddingModel(Vectorize vectorize) { + if (!vectorize.vertexAiPaLm2ApiModel().equals(VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL)) { + VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(properties.getVertexAi().getEndPoint(), + properties.getVertexAi().getApiKey(), VertexAiPaLm2Api.DEFAULT_GENERATE_MODEL, + vectorize.vertexAiPaLm2ApiModel(), RestClient.builder()); + return new VertexAiPaLm2EmbeddingModel(vertexAiApi); + } + return this.vertexAiPaLm2EmbeddingModel; + } + + private BedrockCohereEmbeddingModel getBedrockCohereEmbeddingModel(Vectorize vectorize) { + if (!vectorize.cohereEmbeddingModel().equals(CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1)) { + AwsCredentials credentials = AwsBasicCredentials.create(properties.getBedrockCohere().getAccessKey(), + properties.getBedrockCohere().getSecretKey()); + var cohereEmbeddingApi = new CohereEmbeddingBedrockApi(vectorize.cohereEmbeddingModel().id(), + StaticCredentialsProvider.create(credentials), properties.getBedrockCohere().getRegion(), + ModelOptionsUtils.OBJECT_MAPPER); + return new BedrockCohereEmbeddingModel(cohereEmbeddingApi); + } + return this.bedrockCohereEmbeddingModel; + } + + private BedrockTitanEmbeddingModel getBedrockTitanEmbeddingModel(Vectorize vectorize) { + if (!vectorize.titanEmbeddingModel().equals(TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1)) { + AwsCredentials credentials = AwsBasicCredentials.create(properties.getBedrockCohere().getAccessKey(), + properties.getBedrockCohere().getSecretKey()); + var titanEmbeddingApi = new TitanEmbeddingBedrockApi(vectorize.cohereEmbeddingModel().id(), + StaticCredentialsProvider.create(credentials), properties.getBedrockTitan().getRegion(), + ModelOptionsUtils.OBJECT_MAPPER, Duration.ofMinutes(5L)); + return new BedrockTitanEmbeddingModel(titanEmbeddingApi); + } + return this.bedrockTitanEmbeddingModel; + } + + @Override + public boolean isReady() { + return this.faceEmbeddingModel != null && this.sentenceTokenizer != null; + } + + @Override + public List getTextEmbeddingsAsBytes(List texts, Field field) { + if (field.isAnnotationPresent(Vectorize.class)) { + Vectorize vectorize = field.getAnnotation(Vectorize.class); + return vectorize.embeddingType() == SENTENCE ? + getSentenceEmbeddingAsBytes(texts, vectorize) : + Collections.emptyList(); + } else { + return Collections.emptyList(); + } + } + + private List getSentenceEmbeddingAsBytes(List texts, Vectorize vectorize) { + return switch (vectorize.provider()) { + case DJL -> getSentenceEmbeddingsAsByteArrayFor(texts); + case OPENAI -> { + OpenAiEmbeddingModel model = getOpenAiEmbeddingModel(vectorize); + yield getEmbeddingsAsByteArrayFor(texts, model); + } + case OLLAMA -> { + OllamaEmbeddingModel model = getOllamaEmbeddingModel(vectorize); + yield getEmbeddingsAsByteArrayFor(texts, model); + } + case AZURE_OPENAI -> { + AzureOpenAiEmbeddingModel model = getAzureOpenAiEmbeddingModel(vectorize); + yield getEmbeddingsAsByteArrayFor(texts, model); + } + case VERTEX_AI -> { + VertexAiPaLm2EmbeddingModel model = getVertexAiPaLm2EmbeddingModel(vectorize); + yield getEmbeddingsAsByteArrayFor(texts, model); + } + case AMAZON_BEDROCK_COHERE -> { + BedrockCohereEmbeddingModel model = getBedrockCohereEmbeddingModel(vectorize); + yield getEmbeddingsAsByteArrayFor(texts, model); + } + case AMAZON_BEDROCK_TITAN -> { + BedrockTitanEmbeddingModel model = getBedrockTitanEmbeddingModel(vectorize); + yield getEmbeddingsAsByteArrayFor(texts, model); + } + }; + } + + private List getSentenceEmbeddingAsFloats(List texts, Vectorize vectorize) { + return switch (vectorize.provider()) { + case DJL -> getSentenceEmbeddingAsFloatArrayFor(texts); + case OPENAI -> { + OpenAiEmbeddingModel model = getOpenAiEmbeddingModel(vectorize); + yield getEmbeddingAsFloatArrayFor(texts, model); + } + case OLLAMA -> { + OllamaEmbeddingModel model = getOllamaEmbeddingModel(vectorize); + yield getEmbeddingAsFloatArrayFor(texts, model); + } + case AZURE_OPENAI -> { + AzureOpenAiEmbeddingModel model = getAzureOpenAiEmbeddingModel(vectorize); + yield getEmbeddingAsFloatArrayFor(texts, model); + } + case VERTEX_AI -> { + VertexAiPaLm2EmbeddingModel model = getVertexAiPaLm2EmbeddingModel(vectorize); + yield getEmbeddingAsFloatArrayFor(texts, model); + } + case AMAZON_BEDROCK_COHERE -> { + BedrockCohereEmbeddingModel model = getBedrockCohereEmbeddingModel(vectorize); + yield getEmbeddingAsFloatArrayFor(texts, model); + } + case AMAZON_BEDROCK_TITAN -> { + BedrockTitanEmbeddingModel model = getBedrockTitanEmbeddingModel(vectorize); + yield getEmbeddingAsFloatArrayFor(texts, model); + } + }; + } + + @Override + public List getTextEmbeddingsAsFloats(List texts, Field field) { + if (field.isAnnotationPresent(Vectorize.class)) { + Vectorize vectorize = field.getAnnotation(Vectorize.class); + return vectorize.embeddingType() == SENTENCE ? + getSentenceEmbeddingAsFloats(texts, vectorize) : + Collections.emptyList(); + } else { + return Collections.emptyList(); + } + } + + @Override + public List getTextEmbeddingsAsBytes(List texts, MetamodelField metamodelField) { + return getTextEmbeddingsAsBytes(texts, metamodelField.getSearchFieldAccessor().getField()); + } + + @Override + public List getTextEmbeddingsAsFloats(List texts, MetamodelField metamodelField) { + return getTextEmbeddingsAsFloats(texts, metamodelField.getSearchFieldAccessor().getField()); + } +} diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/Embedder.java b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/Embedder.java new file mode 100644 index 00000000..a6189f84 --- /dev/null +++ b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/Embedder.java @@ -0,0 +1,21 @@ +package com.redis.om.spring.vectorize; + +import com.redis.om.spring.metamodel.MetamodelField; + +import java.lang.reflect.Field; +import java.util.List; + +public interface Embedder { + + void processEntity(Object item); + + boolean isReady(); + + List getTextEmbeddingsAsBytes(List texts, Field field); + + List getTextEmbeddingsAsFloats(List texts, Field field); + + List getTextEmbeddingsAsBytes(List texts, MetamodelField field); + + List getTextEmbeddingsAsFloats(List texts, MetamodelField field); +} diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/NoopEmbedder.java b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/NoopEmbedder.java new file mode 100644 index 00000000..e736ac75 --- /dev/null +++ b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/NoopEmbedder.java @@ -0,0 +1,40 @@ +package com.redis.om.spring.vectorize; + +import com.redis.om.spring.metamodel.MetamodelField; + +import java.lang.reflect.Field; +import java.util.List; + +public class NoopEmbedder implements Embedder { + + @Override + public void processEntity(Object item) { + // NOOP + } + + @Override + public boolean isReady() { + return false; + } + + @Override + public List getTextEmbeddingsAsBytes(List texts, Field field) { + return List.of(); + } + + @Override + public List getTextEmbeddingsAsFloats(List texts, Field field) { + return List.of(); + } + + @Override + public List getTextEmbeddingsAsBytes(List description, MetamodelField field) { + return List.of(); + } + + @Override + public List getTextEmbeddingsAsFloats(List texts, MetamodelField field) { + return List.of(); + } + +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/AbstractBaseOMTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/AbstractBaseOMTest.java index 0b14e06b..e47a12e3 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/AbstractBaseOMTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/AbstractBaseOMTest.java @@ -3,7 +3,7 @@ import com.google.gson.GsonBuilder; import com.redis.om.spring.indexing.RediSearchIndexer; import com.redis.om.spring.ops.RedisModulesOperations; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import com.redis.testcontainers.RedisStackContainer; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -32,7 +32,7 @@ public abstract class AbstractBaseOMTest { @Autowired @Qualifier("featureExtractor") - public FeatureExtractor featureExtractor; + public Embedder embedder; @Autowired protected StringRedisTemplate template; @Autowired diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeDocumentTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeDocumentTest.java index 319f5549..5389cea2 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeDocumentTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeDocumentTest.java @@ -8,7 +8,7 @@ import com.redis.om.spring.search.stream.SearchStream; import com.redis.om.spring.tuple.Fields; import com.redis.om.spring.tuple.Pair; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; @@ -29,7 +29,7 @@ class VectorizeDocumentTest extends AbstractBaseDocumentTest { EntityStream entityStream; @Autowired - FeatureExtractor featureExtractor; + Embedder embedder; @BeforeEach void loadTestData() throws IOException { @@ -166,4 +166,20 @@ void testKnnSentenceSimilaritySearchWithScores() { .containsExactly(0.0, 0.6704, 0.7162, 0.7705, 0.8107) // ); } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testEmbedderCanVectorizeSentence() { + Optional maybeCat = repository.findFirstByName("cat"); + assertThat(maybeCat).isPresent(); + Product cat = maybeCat.get(); + var catEmbedding = cat.getSentenceEmbedding(); + List embeddings = embedder.getTextEmbeddingsAsFloats(List.of(cat.getDescription()), Product$.DESCRIPTION); + assertAll( // + () -> assertThat(embeddings).isNotEmpty(), // + () -> assertThat(embeddings.get(0)).isEqualTo(catEmbedding)); + } } diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeOllamaDocumentTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeOllamaDocumentTest.java index e429280d..9d7afb67 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeOllamaDocumentTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeOllamaDocumentTest.java @@ -8,21 +8,15 @@ import com.redis.om.spring.search.stream.SearchStream; import com.redis.om.spring.tuple.Fields; import com.redis.om.spring.tuple.Pair; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.context.annotation.Bean; -import org.springframework.test.context.junit.jupiter.EnabledIf; import org.springframework.test.context.junit.jupiter.DisabledIf; -import org.springframework.web.client.RestTemplate; +import org.springframework.test.context.junit.jupiter.EnabledIf; -import java.io.BufferedReader; import java.io.IOException; -import java.io.InputStreamReader; -import java.net.HttpURLConnection; -import java.net.URL; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -32,7 +26,7 @@ @DisabledIfEnvironmentVariable(named = "GITHUB_ACTIONS", matches = "true") @DisabledIf( - expression = "#{!T(com.redis.om.spring.annotations.document.vectorize.VectorizeOllamaDocumentTest).isOllamaRunning()}", + expression = "#{!T(com.redis.om.spring.util.Utils).isOllamaRunning()}", reason = "Disabled if Ollama is not running locally" ) class VectorizeOllamaDocumentTest extends AbstractBaseDocumentTest { @@ -43,7 +37,7 @@ class VectorizeOllamaDocumentTest extends AbstractBaseDocumentTest { EntityStream entityStream; @Autowired - FeatureExtractor featureExtractor; + Embedder embedder; @BeforeEach void loadTestData() throws IOException { @@ -68,7 +62,6 @@ void loadTestData() throws IOException { ) void testSentenceIsVectorized() { Optional cat = repository.findFirstByName("cat"); - System.out.println("TEXT EMBEDDING SIZE: " + cat.get().getTextEmbedding().length); assertAll( // () -> assertThat(cat).isPresent(), // () -> assertThat(cat.get()).extracting("textEmbedding").isNotNull(), // @@ -147,23 +140,19 @@ void testKnnSentenceSimilaritySearchWithScores() { ); } - public static boolean isOllamaRunning() { - try { - URL url = new URL("http://localhost:11434"); - HttpURLConnection connection = (HttpURLConnection) url.openConnection(); - connection.setRequestMethod("GET"); - connection.connect(); - - int responseCode = connection.getResponseCode(); - if (responseCode == HttpURLConnection.HTTP_OK) { - BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream())); - String response = reader.readLine(); - reader.close(); - return response != null && response.contains("Ollama is running"); - } - } catch (IOException e) { - // Handle the exception if needed - } - return false; + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testEmbedderCanVectorizeSentence() { + Optional maybeCat = repository.findFirstByName("cat"); + assertThat(maybeCat).isPresent(); + DocWithOllamaEmbedding cat = maybeCat.get(); + var catEmbedding = cat.getTextEmbedding(); + List embeddings = embedder.getTextEmbeddingsAsFloats(List.of(cat.getText()), DocWithOllamaEmbedding$.TEXT); + assertAll( // + () -> assertThat(embeddings).isNotEmpty(), // + () -> assertThat(embeddings.get(0)).isEqualTo(catEmbedding)); } } diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeOpenAIDocumentTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeOpenAIDocumentTest.java index 22877206..f98ddc84 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeOpenAIDocumentTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeOpenAIDocumentTest.java @@ -10,7 +10,7 @@ import com.redis.om.spring.search.stream.SearchStream; import com.redis.om.spring.tuple.Fields; import com.redis.om.spring.tuple.Pair; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable; @@ -42,7 +42,7 @@ class VectorizeOpenAIDocumentTest extends AbstractBaseDocumentTest { EntityStream entityStream; @Autowired - FeatureExtractor featureExtractor; + Embedder embedder; @BeforeEach void loadTestData() throws IOException { @@ -76,7 +76,6 @@ void loadTestData() throws IOException { ) void testSentenceIsVectorized() { Optional cat = repository.findFirstByName("cat"); - System.out.println("TEXT EMBEDDING SIZE: " + cat.get().getTextEmbedding().length); assertAll( // () -> assertThat(cat).isPresent(), // () -> assertThat(cat.get()).extracting("textEmbedding").isNotNull(), // @@ -91,7 +90,6 @@ void testSentenceIsVectorized() { ) void testSentenceIsVectorizedWithCustomModel() { Optional cat = repository2.findFirstByName("cat"); - System.out.println("TEXT EMBEDDING SIZE: " + cat.get().getTextEmbedding().length); assertAll( // () -> assertThat(cat).isPresent(), // () -> assertThat(cat.get()).extracting("textEmbedding").isNotNull(), // diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/serialization/SerializationTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/serialization/SerializationTest.java index ee3e7d8e..b914a5e1 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/serialization/SerializationTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/serialization/SerializationTest.java @@ -5,8 +5,8 @@ import com.redis.om.spring.AbstractBaseEnhancedRedisTest; import com.redis.om.spring.fixtures.hash.model.KitchenSink; import com.redis.om.spring.fixtures.hash.repository.KitchenSinkRepository; -import com.redis.om.spring.vectorize.DefaultFeatureExtractor; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.DefaultEmbedder; +import com.redis.om.spring.vectorize.Embedder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; @@ -32,7 +32,7 @@ class SerializationTest extends AbstractBaseEnhancedRedisTest { KitchenSinkRepository repository; @Autowired - FeatureExtractor featureExtractor; + Embedder embedder; @Autowired private ApplicationContext applicationContext; @@ -69,11 +69,11 @@ public void cleanUp() throws IOException, NoSuchMethodException, InvocationTarge ulid = UlidCreator.getMonotonicUlid(); byteArray = "Hello World!".getBytes(); - java.lang.reflect.Method method = DefaultFeatureExtractor.class.getDeclaredMethod( + java.lang.reflect.Method method = DefaultEmbedder.class.getDeclaredMethod( "getImageEmbeddingsAsByteArrayFor", InputStream.class); method.setAccessible(true); - byteArray2 = (byte[]) method.invoke(featureExtractor, + byteArray2 = (byte[]) method.invoke(embedder, applicationContext.getResource("classpath:/images/cat.jpg").getInputStream()); yearMonth = YearMonth.of(1972, 6); diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeHashTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeHashTest.java index a1f0c9f2..ec3a58d2 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeHashTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeHashTest.java @@ -8,7 +8,7 @@ import com.redis.om.spring.search.stream.SearchStream; import com.redis.om.spring.tuple.Fields; import com.redis.om.spring.tuple.Pair; -import com.redis.om.spring.vectorize.FeatureExtractor; +import com.redis.om.spring.vectorize.Embedder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; @@ -29,7 +29,7 @@ class VectorizeHashTest extends AbstractBaseEnhancedRedisTest { EntityStream entityStream; @Autowired - FeatureExtractor featureExtractor; + Embedder embedder; @BeforeEach void loadTestData() throws IOException { @@ -165,4 +165,21 @@ void testKnnSentenceSimilaritySearchWithScores() { .containsExactly(0.0, 0.6704, 0.7162, 0.7705, 0.8107) // ); } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testEmbedderCanVectorizeSentence() { + Optional maybeCat = repository.findFirstByName("cat"); + assertThat(maybeCat).isPresent(); + Product cat = maybeCat.get(); + var catEmbedding = cat.getSentenceEmbedding(); + List embeddings = embedder.getTextEmbeddingsAsBytes(List.of(cat.getDescription()), Product$.DESCRIPTION); + assertAll( // + () -> assertThat(embeddings).isNotEmpty(), // + () -> assertThat(embeddings.get(0)).isEqualTo(catEmbedding)); + } + } diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeOllamaHashTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeOllamaHashTest.java new file mode 100644 index 00000000..11937bb1 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeOllamaHashTest.java @@ -0,0 +1,159 @@ +package com.redis.om.spring.annotations.hash.vectorize; + +import com.redis.om.spring.AbstractBaseEnhancedRedisTest; +import com.redis.om.spring.fixtures.hash.model.HashWithOllamaEmbedding; +import com.redis.om.spring.fixtures.hash.model.HashWithOllamaEmbedding$; +import com.redis.om.spring.fixtures.hash.repository.HashWithOllamaEmbeddingRepository; +import com.redis.om.spring.search.stream.EntityStream; +import com.redis.om.spring.search.stream.SearchStream; +import com.redis.om.spring.tuple.Fields; +import com.redis.om.spring.tuple.Pair; +import com.redis.om.spring.vectorize.Embedder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.junit.jupiter.DisabledIf; +import org.springframework.test.context.junit.jupiter.EnabledIf; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; + +@DisabledIfEnvironmentVariable(named = "GITHUB_ACTIONS", matches = "true") +@DisabledIf( + expression = "#{!T(com.redis.om.spring.util.Utils).isOllamaRunning()}", + reason = "Disabled if Ollama is not running locally" +) +class VectorizeOllamaHashTest extends AbstractBaseEnhancedRedisTest { + @Autowired + HashWithOllamaEmbeddingRepository repository; + + @Autowired + EntityStream entityStream; + + @Autowired + Embedder embedder; + + @BeforeEach + void loadTestData() throws IOException { + if (repository.count() == 0) { + repository.save(HashWithOllamaEmbedding.of("cat", + "The cat is a small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws.")); + repository.save(HashWithOllamaEmbedding.of("dog", + "A dog is a domesticated mammal of the family Canidae, characterized by its loyalty, playfulness, and friendly demeanor.")); + repository.save(HashWithOllamaEmbedding.of("lion", + "The lion is a large cat of the genus Panthera native to Africa and India, known for its muscular body, deep roar, and mane on the male.")); + repository.save(HashWithOllamaEmbedding.of("elephant", + "Elephants are the largest existing land animals, characterized by their long trunk, tusks, and large ears.")); + repository.save(HashWithOllamaEmbedding.of("giraffe", + "The giraffe is an African even-toed ungulate mammal, the tallest living terrestrial animal, and the largest ruminant.")); + } + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testSentenceIsVectorized() { + Optional cat = repository.findFirstByName("cat"); + assertAll( // + () -> assertThat(cat).isPresent(), // + () -> assertThat(cat.get()).extracting("textEmbedding").isNotNull(), // + () -> assertThat(cat.get().getTextEmbedding()).hasSize(4096*Float.BYTES) // + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testKnnSentenceSimilaritySearch() { + HashWithOllamaEmbedding cat = repository.findFirstByName("cat").get(); + int K = 5; + + SearchStream stream = entityStream.of(HashWithOllamaEmbedding.class); + + List results = stream // + .filter(HashWithOllamaEmbedding$.TEXT_EMBEDDING.knn(K, cat.getTextEmbedding())) // + .sorted(HashWithOllamaEmbedding$._TEXT_EMBEDDING_SCORE) // + .limit(K) // + .collect(Collectors.toList()); + + assertThat(results).hasSize(5).map(HashWithOllamaEmbedding::getName).containsExactly( // + "cat", "dog", "lion", "elephant", "giraffe" // + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testKnnHybridSentenceSimilaritySearch() { + HashWithOllamaEmbedding cat = repository.findFirstByName("cat").get(); + int K = 5; + + SearchStream stream = entityStream.of(HashWithOllamaEmbedding.class); + + List results = stream // + .filter(HashWithOllamaEmbedding$.NAME.in("cat", "lion", "dog")) // + .filter(HashWithOllamaEmbedding$.TEXT_EMBEDDING.knn(K, cat.getTextEmbedding())) // + .sorted(HashWithOllamaEmbedding$._TEXT_EMBEDDING_SCORE) // + .limit(K) // + .collect(Collectors.toList()); + + assertThat(results).hasSize(3).map(HashWithOllamaEmbedding::getName).containsExactly( // + "cat", "dog", "lion" // + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testKnnSentenceSimilaritySearchWithScores() { + HashWithOllamaEmbedding cat = repository.findFirstByName("cat").get(); + int K = 5; + + SearchStream stream = entityStream.of(HashWithOllamaEmbedding.class); + + List> results = stream // + .filter(HashWithOllamaEmbedding$.TEXT_EMBEDDING.knn(K, cat.getTextEmbedding())) // + .sorted(HashWithOllamaEmbedding$._TEXT_EMBEDDING_SCORE) // + .limit(K) // + .map(Fields.of(HashWithOllamaEmbedding$._THIS, HashWithOllamaEmbedding$._TEXT_EMBEDDING_SCORE)) // + .collect(Collectors.toList()); + + assertAll( // + () -> assertThat(results).hasSize(5).map(Pair::getFirst).map(HashWithOllamaEmbedding::getName) + .containsExactly("cat", "dog", "lion", "elephant", "giraffe"), // + () -> assertThat(results).hasSize(5).map(Pair::getSecond).usingElementComparator(closeToComparator) + .containsExactly(1.78813934326E-7, 0.301205277443, 0.315115869045, 0.338551998138, 0.407371640205) // + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testEmbedderCanVectorizeSentence() { + Optional maybeCat = repository.findFirstByName("cat"); + assertThat(maybeCat).isPresent(); + HashWithOllamaEmbedding cat = maybeCat.get(); + var catEmbedding = cat.getTextEmbedding(); + List embeddings = embedder.getTextEmbeddingsAsBytes(List.of(cat.getText()), HashWithOllamaEmbedding$.TEXT); + assertAll( // + () -> assertThat(embeddings).isNotEmpty(), // + () -> assertThat(embeddings.get(0)).isEqualTo(catEmbedding) + ); + } +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeOpenAIHashTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeOpenAIHashTest.java new file mode 100644 index 00000000..99acbb7b --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeOpenAIHashTest.java @@ -0,0 +1,170 @@ +package com.redis.om.spring.annotations.hash.vectorize; + +import com.redis.om.spring.AbstractBaseEnhancedRedisTest; +import com.redis.om.spring.fixtures.hash.model.HashWithCustomModelOpenAIEmbedding; +import com.redis.om.spring.fixtures.hash.model.HashWithOpenAIEmbedding; +import com.redis.om.spring.fixtures.hash.model.HashWithOpenAIEmbedding$; +import com.redis.om.spring.fixtures.hash.repository.HashWithCustomModelOpenAIEmbeddingRepository; +import com.redis.om.spring.fixtures.hash.repository.HashWithOpenAIEmbeddingRepository; +import com.redis.om.spring.search.stream.EntityStream; +import com.redis.om.spring.search.stream.SearchStream; +import com.redis.om.spring.tuple.Fields; +import com.redis.om.spring.tuple.Pair; +import com.redis.om.spring.vectorize.Embedder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.junit.jupiter.DisabledIf; +import org.springframework.test.context.junit.jupiter.EnabledIf; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; + +@DisabledIfEnvironmentVariable(named = "GITHUB_ACTIONS", matches = "true") +@DisabledIf( + expression = "#{systemEnvironment['OPENAI_API_KEY'] == null}", + reason = "Disabled if OPENAI_API_KEY environment variable is not set" +) +class VectorizeOpenAIHashTest extends AbstractBaseEnhancedRedisTest { + @Autowired + HashWithOpenAIEmbeddingRepository repository; + + @Autowired + HashWithCustomModelOpenAIEmbeddingRepository repository2; + + @Autowired + EntityStream entityStream; + + @Autowired + Embedder embedder; + + @BeforeEach + void loadTestData() throws IOException { + if (repository.count() == 0) { + repository.save( + HashWithOpenAIEmbedding.of("cat", "The cat (Felis catus) is a domestic species of small carnivorous mammal.")); + repository.save(HashWithOpenAIEmbedding.of("cat2", + "It is the only domesticated species in the family Felidae and is commonly referred to as the domestic cat or house cat")); + repository.save(HashWithOpenAIEmbedding.of("catdog", "This is a picture of a cat and a dog together")); + repository.save(HashWithOpenAIEmbedding.of("face", "Three years later, the coffin was still full of Jello.")); + repository.save( + HashWithOpenAIEmbedding.of("face2", "The person box was packed with jelly many dozens of months later.")); + } + + if (repository2.count() == 0) { + repository2.save( + HashWithCustomModelOpenAIEmbedding.of("cat", "The cat (Felis catus) is a domestic species of small carnivorous mammal.")); + repository2.save(HashWithCustomModelOpenAIEmbedding.of("cat2", + "It is the only domesticated species in the family Felidae and is commonly referred to as the domestic cat or house cat")); + repository2.save(HashWithCustomModelOpenAIEmbedding.of("catdog", "This is a picture of a cat and a dog together")); + repository2.save(HashWithCustomModelOpenAIEmbedding.of("face", "Three years later, the coffin was still full of Jello.")); + repository2.save( + HashWithCustomModelOpenAIEmbedding.of("face2", "The person box was packed with jelly many dozens of months later.")); + } + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testSentenceIsVectorized() { + Optional cat = repository.findFirstByName("cat"); + assertAll( // + () -> assertThat(cat).isPresent(), // + () -> assertThat(cat.get()).extracting("textEmbedding").isNotNull(), // + () -> assertThat(cat.get().getTextEmbedding()).hasSize(1536*Float.BYTES) // + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testSentenceIsVectorizedWithCustomModel() { + Optional cat = repository2.findFirstByName("cat"); + assertAll( // + () -> assertThat(cat).isPresent(), // + () -> assertThat(cat.get()).extracting("textEmbedding").isNotNull(), // + () -> assertThat(cat.get().getTextEmbedding()).hasSize(3072*Float.BYTES) // + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testKnnSentenceSimilaritySearch() { + HashWithOpenAIEmbedding cat = repository.findFirstByName("cat").get(); + int K = 5; + + SearchStream stream = entityStream.of(HashWithOpenAIEmbedding.class); + + List results = stream // + .filter(HashWithOpenAIEmbedding$.TEXT_EMBEDDING.knn(K, cat.getTextEmbedding())) // + .sorted(HashWithOpenAIEmbedding$._TEXT_EMBEDDING_SCORE) // + .limit(K) // + .collect(Collectors.toList()); + + assertThat(results).hasSize(5).map(HashWithOpenAIEmbedding::getName).containsExactly( // + "cat", "cat2", "catdog", "face2", "face" // + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testKnnHybridSentenceSimilaritySearch() { + HashWithOpenAIEmbedding cat = repository.findFirstByName("cat").get(); + int K = 5; + + SearchStream stream = entityStream.of(HashWithOpenAIEmbedding.class); + + List results = stream // + .filter(HashWithOpenAIEmbedding$.NAME.startsWith("cat")) // + .filter(HashWithOpenAIEmbedding$.TEXT_EMBEDDING.knn(K, cat.getTextEmbedding())) // + .sorted(HashWithOpenAIEmbedding$._TEXT_EMBEDDING_SCORE) // + .limit(K) // + .collect(Collectors.toList()); + + assertThat(results).hasSize(3).map(HashWithOpenAIEmbedding::getName).containsExactly( // + "cat", "cat2", "catdog" // + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testKnnSentenceSimilaritySearchWithScores() { + HashWithOpenAIEmbedding cat = repository.findFirstByName("cat").get(); + int K = 5; + + SearchStream stream = entityStream.of(HashWithOpenAIEmbedding.class); + + List> results = stream // + .filter(HashWithOpenAIEmbedding$.TEXT_EMBEDDING.knn(K, cat.getTextEmbedding())) // + .sorted(HashWithOpenAIEmbedding$._TEXT_EMBEDDING_SCORE) // + .limit(K) // + .map(Fields.of(HashWithOpenAIEmbedding$._THIS, HashWithOpenAIEmbedding$._TEXT_EMBEDDING_SCORE)) // + .collect(Collectors.toList()); + + assertAll( // + () -> assertThat(results).hasSize(5).map(Pair::getFirst).map(HashWithOpenAIEmbedding::getName) + .containsExactly("cat", "cat2", "catdog", "face2", "face"), // + () -> assertThat(results).hasSize(5).map(Pair::getSecond).usingElementComparator(closeToComparator) + .containsExactly(7.15255737305E-7, 0.0800130963326, 0.163947761059, 0.261719405651, 0.288997769356) // + ); + } +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithCustomModelOpenAIEmbedding.java b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithCustomModelOpenAIEmbedding.java new file mode 100644 index 00000000..e19cafa7 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithCustomModelOpenAIEmbedding.java @@ -0,0 +1,45 @@ +package com.redis.om.spring.fixtures.hash.model; + +import com.redis.om.spring.annotations.*; +import com.redis.om.spring.indexing.DistanceMetric; +import com.redis.om.spring.indexing.VectorType; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import org.springframework.ai.openai.api.OpenAiApi.EmbeddingModel; +import org.springframework.data.annotation.Id; +import org.springframework.data.redis.core.RedisHash; +import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; + +@Data +@RequiredArgsConstructor(staticName = "of") +@NoArgsConstructor(force = true) +@RedisHash +public class HashWithCustomModelOpenAIEmbedding { + @Id + private String id; + + @Indexed + @NonNull + private String name; + + @Indexed( // + schemaFieldType = SchemaFieldType.VECTOR, // + algorithm = VectorAlgorithm.HNSW, // + type = VectorType.FLOAT32, // + dimension = 3072, // + distanceMetric = DistanceMetric.COSINE, // + initialCapacity = 10 + ) + private byte[] textEmbedding; + + @Vectorize( // + destination = "textEmbedding", // + embeddingType = EmbeddingType.SENTENCE, // + provider = EmbeddingProvider.OPENAI, // + openAiEmbeddingModel = EmbeddingModel.TEXT_EMBEDDING_3_LARGE + ) + @NonNull + private String text; +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithOllamaEmbedding.java b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithOllamaEmbedding.java new file mode 100644 index 00000000..14fa8679 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithOllamaEmbedding.java @@ -0,0 +1,39 @@ +package com.redis.om.spring.fixtures.hash.model; + +import com.redis.om.spring.annotations.*; +import com.redis.om.spring.indexing.DistanceMetric; +import com.redis.om.spring.indexing.VectorType; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import org.springframework.data.annotation.Id; +import org.springframework.data.redis.core.RedisHash; +import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; + +@Data +@RequiredArgsConstructor(staticName = "of") +@NoArgsConstructor(force = true) +@RedisHash +public class HashWithOllamaEmbedding { + @Id + private String id; + + @Indexed + @NonNull + private String name; + + @Indexed(// + schemaFieldType = SchemaFieldType.VECTOR, // + algorithm = VectorAlgorithm.HNSW, // + type = VectorType.FLOAT32, // + dimension = 4096, // + distanceMetric = DistanceMetric.COSINE, // + initialCapacity = 10 + ) + private byte[] textEmbedding; + + @Vectorize(destination = "textEmbedding", embeddingType = EmbeddingType.SENTENCE, provider = EmbeddingProvider.OLLAMA) + @NonNull + private String text; +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithOpenAIEmbedding.java b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithOpenAIEmbedding.java new file mode 100644 index 00000000..252e7ea1 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/model/HashWithOpenAIEmbedding.java @@ -0,0 +1,39 @@ +package com.redis.om.spring.fixtures.hash.model; + +import com.redis.om.spring.annotations.*; +import com.redis.om.spring.indexing.DistanceMetric; +import com.redis.om.spring.indexing.VectorType; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import org.springframework.data.annotation.Id; +import org.springframework.data.redis.core.RedisHash; +import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; + +@Data +@RequiredArgsConstructor(staticName = "of") +@NoArgsConstructor(force = true) +@RedisHash +public class HashWithOpenAIEmbedding { + @Id + private String id; + + @Indexed + @NonNull + private String name; + + @Indexed(// + schemaFieldType = SchemaFieldType.VECTOR, // + algorithm = VectorAlgorithm.HNSW, // + type = VectorType.FLOAT32, // + dimension = 1536, // + distanceMetric = DistanceMetric.COSINE, // + initialCapacity = 10 + ) + private byte[] textEmbedding; + + @Vectorize(destination = "textEmbedding", embeddingType = EmbeddingType.SENTENCE, provider = EmbeddingProvider.OPENAI) + @NonNull + private String text; +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithCustomModelOpenAIEmbeddingRepository.java b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithCustomModelOpenAIEmbeddingRepository.java new file mode 100644 index 00000000..6cd9f9d5 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithCustomModelOpenAIEmbeddingRepository.java @@ -0,0 +1,11 @@ +package com.redis.om.spring.fixtures.hash.repository; + +import com.redis.om.spring.fixtures.document.model.DocWithCustomModelOpenAIEmbedding; +import com.redis.om.spring.fixtures.hash.model.HashWithCustomModelOpenAIEmbedding; +import com.redis.om.spring.repository.RedisDocumentRepository; + +import java.util.Optional; + +public interface HashWithCustomModelOpenAIEmbeddingRepository extends RedisDocumentRepository { + Optional findFirstByName(String name); +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithOllamaEmbeddingRepository.java b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithOllamaEmbeddingRepository.java new file mode 100644 index 00000000..6c5bdf8d --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithOllamaEmbeddingRepository.java @@ -0,0 +1,10 @@ +package com.redis.om.spring.fixtures.hash.repository; + +import com.redis.om.spring.fixtures.hash.model.HashWithOllamaEmbedding; +import com.redis.om.spring.repository.RedisEnhancedRepository; + +import java.util.Optional; + +public interface HashWithOllamaEmbeddingRepository extends RedisEnhancedRepository { + Optional findFirstByName(String name); +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithOpenAIEmbeddingRepository.java b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithOpenAIEmbeddingRepository.java new file mode 100644 index 00000000..d30cbd9e --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/fixtures/hash/repository/HashWithOpenAIEmbeddingRepository.java @@ -0,0 +1,12 @@ +package com.redis.om.spring.fixtures.hash.repository; + +import com.redis.om.spring.fixtures.document.model.DocWithOpenAIEmbedding; +import com.redis.om.spring.fixtures.hash.model.HashWithOpenAIEmbedding; +import com.redis.om.spring.repository.RedisDocumentRepository; +import com.redis.om.spring.repository.RedisEnhancedRepository; + +import java.util.Optional; + +public interface HashWithOpenAIEmbeddingRepository extends RedisEnhancedRepository { + Optional findFirstByName(String name); +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/repository/support/QueryByExampleDocumentRepositoryIntegrationTests.java b/redis-om-spring/src/test/java/com/redis/om/spring/repository/support/QueryByExampleDocumentRepositoryIntegrationTests.java index 4e01a62a..1b7579d3 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/repository/support/QueryByExampleDocumentRepositoryIntegrationTests.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/repository/support/QueryByExampleDocumentRepositoryIntegrationTests.java @@ -55,7 +55,7 @@ void before() { indexer, // mappingContext, // gsonBuilder, // - featureExtractor, + embedder, new RedisOMProperties()); repository.deleteAll(); diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/repository/support/QueryByExampleHashRepositoryIntegrationTests.java b/redis-om-spring/src/test/java/com/redis/om/spring/repository/support/QueryByExampleHashRepositoryIntegrationTests.java index b7dc96dd..a7b9cd67 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/repository/support/QueryByExampleHashRepositoryIntegrationTests.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/repository/support/QueryByExampleHashRepositoryIntegrationTests.java @@ -50,7 +50,7 @@ class QueryByExampleHashRepositoryIntegrationTests extends AbstractBaseEnhancedR void before() { repository = new SimpleRedisEnhancedRepository<>(getEntityInformation(Person.class), - new KeyValueTemplate(new RedisKeyValueAdapter(template)), modulesOperations, indexer, featureExtractor, + new KeyValueTemplate(new RedisKeyValueAdapter(template)), modulesOperations, indexer, embedder, new RedisOMProperties()); repository.deleteAll();