diff --git a/demos/roms-vss/src/main/resources/application.properties b/demos/roms-vss/src/main/resources/application.properties index 950277c9..989e21dc 100644 --- a/demos/roms-vss/src/main/resources/application.properties +++ b/demos/roms-vss/src/main/resources/application.properties @@ -1,4 +1,4 @@ spring.mvc.hiddenmethod.filter.enabled=true com.redis.om.vss.useLocalImages=false com.redis.om.vss.maxLines=300 -redis.om.spring.ai.djl.enabled=true +redis.om.spring.ai.enabled=true diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisAiConfiguration.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisAiConfiguration.java index 0e5ef1ac..328fc1b0 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisAiConfiguration.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisAiConfiguration.java @@ -57,7 +57,7 @@ import java.time.*; import java.util.Map; -@ConditionalOnProperty(name = "redis.om.spring.ai.djl.enabled") +@ConditionalOnProperty(name = "redis.om.spring.ai.enabled") @Configuration @EnableConfigurationProperties({ RedisOMAiProperties.class }) public class RedisAiConfiguration { @@ -71,10 +71,10 @@ public ImageFactory imageFactory() { @Bean(name = "djlImageEmbeddingModelCriteria") public Criteria imageEmbeddingModelCriteria(RedisOMAiProperties properties) { - return properties.getDjl().isEnabled() ? Criteria.builder().setTypes(Image.class, byte[].class) // + return Criteria.builder().setTypes(Image.class, byte[].class) // .optEngine(properties.getDjl().getImageEmbeddingModelEngine()) // .optModelUrls(properties.getDjl().getImageEmbeddingModelModelUrls()) // - .build() : null; + .build(); } @Bean(name = "djlFaceDetectionTranslator") @@ -93,20 +93,19 @@ public Criteria faceDetectionModelCriteria( // @Qualifier("djlFaceDetectionTranslator") Translator translator, // RedisOMAiProperties properties) { - return properties.getDjl().isEnabled() ? Criteria.builder().setTypes(Image.class, DetectedObjects.class) // + return Criteria.builder().setTypes(Image.class, DetectedObjects.class) // .optModelUrls(properties.getDjl().getFaceDetectionModelModelUrls()) // .optModelName(properties.getDjl().getFaceDetectionModelName()) // .optTranslator(translator) // .optEngine(properties.getDjl().getFaceDetectionModelEngine()) // - .build() : null; + .build(); } @Bean(name = "djlFaceDetectionModel") public ZooModel faceDetectionModel( - @Nullable @Qualifier("djlFaceDetectionModelCriteria") Criteria criteria, - RedisOMAiProperties properties) { + @Nullable @Qualifier("djlFaceDetectionModelCriteria") Criteria criteria) { try { - return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null; + return criteria != null ? ModelZoo.loadModel(criteria) : null; } catch (IOException | ModelNotFoundException | MalformedModelException ex) { logger.warn("Error retrieving default DJL face detection model", ex); return null; @@ -123,20 +122,19 @@ public Criteria faceEmbeddingModelCriteria( // @Qualifier("djlFaceEmbeddingTranslator") Translator translator, // RedisOMAiProperties properties) { - return properties.getDjl().isEnabled() ? Criteria.builder() // + return Criteria.builder() // .setTypes(Image.class, float[].class).optModelUrls(properties.getDjl().getFaceEmbeddingModelModelUrls()) // .optModelName(properties.getDjl().getFaceEmbeddingModelName()) // .optTranslator(translator) // .optEngine(properties.getDjl().getFaceEmbeddingModelEngine()) // - .build() : null; + .build(); } @Bean(name = "djlFaceEmbeddingModel") public ZooModel faceEmbeddingModel( - @Nullable @Qualifier("djlFaceEmbeddingModelCriteria") Criteria criteria, // - RedisOMAiProperties properties) { + @Nullable @Qualifier("djlFaceEmbeddingModelCriteria") Criteria criteria) { try { - return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null; + return criteria != null ? ModelZoo.loadModel(criteria) : null; } catch (Exception e) { logger.warn("Error retrieving default DJL face embeddings model", e); return null; @@ -145,46 +143,39 @@ public ZooModel faceEmbeddingModel( @Bean(name = "djlImageEmbeddingModel") public ZooModel imageModel( - @Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria criteria, - RedisOMAiProperties properties) throws MalformedModelException, ModelNotFoundException, IOException { - return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null; + @Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria criteria) throws MalformedModelException, ModelNotFoundException, IOException { + return criteria != null ? ModelZoo.loadModel(criteria) : null; } @Bean(name = "djlDefaultImagePipeline") public Pipeline defaultImagePipeline(RedisOMAiProperties properties) { - if (properties.getDjl().isEnabled()) { - Pipeline pipeline = new Pipeline(); - if (properties.getDjl().isDefaultImagePipelineCenterCrop()) { - pipeline.add(new CenterCrop()); - } - return pipeline // - .add(new Resize( // - properties.getDjl().getDefaultImagePipelineResizeWidth(), // - properties.getDjl().getDefaultImagePipelineResizeHeight() // - )) // - .add(new ToTensor()); - } else - return null; + Pipeline pipeline = new Pipeline(); + if (properties.getDjl().isDefaultImagePipelineCenterCrop()) { + pipeline.add(new CenterCrop()); + } + return pipeline // + .add(new Resize( // + properties.getDjl().getDefaultImagePipelineResizeWidth(), // + properties.getDjl().getDefaultImagePipelineResizeHeight() // + )) // + .add(new ToTensor()); } @Bean(name = "djlSentenceTokenizer") public HuggingFaceTokenizer sentenceTokenizer(RedisOMAiProperties properties) { - if (properties.getDjl().isEnabled()) { - Map options = Map.of( // - "maxLength", properties.getDjl().getSentenceTokenizerMaxLength(), // - "modelMaxLength", properties.getDjl().getSentenceTokenizerModelMaxLength() // - ); - - try { - //noinspection ResultOfMethodCallIgnored - InetAddress.getByName("www.huggingface.co").isReachable(5000); - return HuggingFaceTokenizer.newInstance(properties.getDjl().getSentenceTokenizerModel(), options); - } catch (IOException ioe) { - logger.warn("Error retrieving default DJL sentence tokenizer"); - return null; - } - } else + Map options = Map.of( // + "maxLength", properties.getDjl().getSentenceTokenizerMaxLength(), // + "modelMaxLength", properties.getDjl().getSentenceTokenizerModelMaxLength() // + ); + + try { + //noinspection ResultOfMethodCallIgnored + InetAddress.getByName("www.huggingface.co").isReachable(5000); + return HuggingFaceTokenizer.newInstance(properties.getDjl().getSentenceTokenizerModel(), options); + } catch (IOException ioe) { + logger.warn("Error retrieving default DJL sentence tokenizer"); return null; + } } @ConditionalOnMissingBean diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java index cbb23ead..a7506535 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java @@ -274,7 +274,7 @@ public void registerReferenceSerializer(ContextRefreshedEvent cre) { registrar.registerReferencesFor(RedisHash.class); } - @ConditionalOnProperty(name = "redis.om.spring.ai.djl.enabled", havingValue = "false", matchIfMissing = true) + @ConditionalOnProperty(name = "redis.om.spring.ai.enabled", havingValue = "false", matchIfMissing = true) @Bean(name = "featureExtractor") public Embedder featureExtractor() { return new NoopEmbedder(); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisOMAiProperties.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisOMAiProperties.java index 01e4c1e5..fbfeeab0 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisOMAiProperties.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisOMAiProperties.java @@ -5,11 +5,12 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.ConfigurationProperties; -@ConditionalOnProperty(name = "redis.om.spring.ai.djl.enabled") +@ConditionalOnProperty(name = "redis.om.spring.ai.enabled") @ConfigurationProperties( prefix = "redis.om.spring.ai", ignoreInvalidFields = true ) public class RedisOMAiProperties { + private boolean enabled = false; private final Djl djl = new Djl(); private final OpenAi openAi = new OpenAi(); private final AzureOpenAi azureOpenAi = new AzureOpenAi(); @@ -18,6 +19,14 @@ public class RedisOMAiProperties { private final BedrockTitan bedrockTitan = new BedrockTitan(); private final Ollama ollama = new Ollama(); + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + public Djl getDjl() { return djl; } @@ -49,7 +58,6 @@ public Ollama getOllama() { // DJL properties public static class Djl { private static final String DEFAULT_ENGINE = "PyTorch"; - private boolean enabled = false; // image embedding settings @NotNull private String imageEmbeddingModelEngine = DEFAULT_ENGINE; @@ -86,14 +94,6 @@ public static class Djl { public Djl() { } - public boolean isEnabled() { - return this.enabled; - } - - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - public @NotNull String getImageEmbeddingModelEngine() { return this.imageEmbeddingModelEngine; } @@ -207,7 +207,7 @@ public void setFaceEmbeddingModelModelUrls(@NotNull String faceEmbeddingModelMod } public String toString() { - return "RedisOMSpringProperties.Djl(enabled=" + this.isEnabled() + ", imageEmbeddingModelEngine=" + this.getImageEmbeddingModelEngine() + ", imageEmbeddingModelModelUrls=" + this.getImageEmbeddingModelModelUrls() + ", defaultImagePipelineResizeWidth=" + this.getDefaultImagePipelineResizeWidth() + ", defaultImagePipelineResizeHeight=" + this.getDefaultImagePipelineResizeHeight() + ", defaultImagePipelineCenterCrop=" + this.isDefaultImagePipelineCenterCrop() + ", sentenceTokenizerMaxLength=" + this.getSentenceTokenizerMaxLength() + ", sentenceTokenizerModelMaxLength=" + this.getSentenceTokenizerModelMaxLength() + ", sentenceTokenizerModel=" + this.getSentenceTokenizerModel() + ", faceDetectionModelEngine=" + this.getFaceDetectionModelEngine() + ", faceDetectionModelName=" + this.getFaceDetectionModelName() + ", faceDetectionModelModelUrls=" + this.getFaceDetectionModelModelUrls() + ", faceEmbeddingModelEngine=" + this.getFaceEmbeddingModelEngine() + ", faceEmbeddingModelName=" + this.getFaceEmbeddingModelName() + ", faceEmbeddingModelModelUrls=" + this.getFaceEmbeddingModelModelUrls() + ")"; + return "RedisOMSpringProperties.Ai.Djl(imageEmbeddingModelEngine=" + this.getImageEmbeddingModelEngine() + ", imageEmbeddingModelModelUrls=" + this.getImageEmbeddingModelModelUrls() + ", defaultImagePipelineResizeWidth=" + this.getDefaultImagePipelineResizeWidth() + ", defaultImagePipelineResizeHeight=" + this.getDefaultImagePipelineResizeHeight() + ", defaultImagePipelineCenterCrop=" + this.isDefaultImagePipelineCenterCrop() + ", sentenceTokenizerMaxLength=" + this.getSentenceTokenizerMaxLength() + ", sentenceTokenizerModelMaxLength=" + this.getSentenceTokenizerModelMaxLength() + ", sentenceTokenizerModel=" + this.getSentenceTokenizerModel() + ", faceDetectionModelEngine=" + this.getFaceDetectionModelEngine() + ", faceDetectionModelName=" + this.getFaceDetectionModelName() + ", faceDetectionModelModelUrls=" + this.getFaceDetectionModelModelUrls() + ", faceEmbeddingModelEngine=" + this.getFaceEmbeddingModelEngine() + ", faceEmbeddingModelName=" + this.getFaceEmbeddingModelName() + ", faceEmbeddingModelModelUrls=" + this.getFaceEmbeddingModelModelUrls() + ")"; } } diff --git a/redis-om-spring/src/test/resources/vss_on.yaml b/redis-om-spring/src/test/resources/vss_on.yaml index 4bfe1b6f..82b0f80c 100644 --- a/redis-om-spring/src/test/resources/vss_on.yaml +++ b/redis-om-spring/src/test/resources/vss_on.yaml @@ -2,5 +2,4 @@ redis: om: spring: ai: - djl: - enabled: true \ No newline at end of file + \enabled: true \ No newline at end of file