diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java index 6498d233bd..ee55d705c7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java @@ -17,7 +17,10 @@ package org.springframework.ai.autoconfigure.vectorstore.cosmosdb; import com.azure.cosmos.CosmosClientBuilder; + +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.CosmosDBVectorStore; import org.springframework.ai.vectorstore.CosmosDBVectorStoreConfig; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; @@ -32,9 +35,9 @@ /** * @author Theo van Kraay + * @author Soby Chacko * @since 1.0.0 */ - @AutoConfiguration @ConditionalOnClass({ CosmosDBVectorStore.class, EmbeddingModel.class, CosmosAsyncClient.class }) @EnableConfigurationProperties(CosmosDBVectorStoreProperties.class) @@ -53,12 +56,18 @@ public CosmosAsyncClient cosmosClient(CosmosDBVectorStoreProperties properties) .buildAsyncClient(); } + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy batchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRegistry, ObjectProvider customObservationConvention, CosmosDBVectorStoreProperties properties, CosmosAsyncClient cosmosAsyncClient, - EmbeddingModel embeddingModel) { + EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) { CosmosDBVectorStoreConfig config = new CosmosDBVectorStoreConfig(); config.setDatabaseName(properties.getDatabaseName()); @@ -67,7 +76,7 @@ public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRe config.setVectorStoreThoughput(properties.getVectorStoreThoughput()); config.setVectorDimensions(properties.getVectorDimensions()); return new CosmosDBVectorStore(observationRegistry, customObservationConvention.getIfAvailable(), - cosmosAsyncClient, config, embeddingModel); + cosmosAsyncClient, config, embeddingModel, batchingStrategy); } } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java index fbbcb1f72a..8dc1f787a7 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java @@ -45,9 +45,9 @@ /** * @author Theo van Kraay + * @author Soby Chacko * @since 1.0.0 */ - public class CosmosDBVectorStore extends AbstractObservationVectorStore implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(CosmosDBVectorStore.class); @@ -65,10 +65,17 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen public CosmosDBVectorStore(ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient, CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel) { + this(observationRegistry, customObservationConvention, cosmosClient, properties, embeddingModel, + new TokenCountBatchingStrategy()); + } + + public CosmosDBVectorStore(ObservationRegistry observationRegistry, + VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient, + CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); this.cosmosClient = cosmosClient; this.properties = properties; - this.batchingStrategy = new TokenCountBatchingStrategy(); + this.batchingStrategy = batchingStrategy; cosmosClient.createDatabaseIfNotExists(properties.getDatabaseName()).block(); initializeContainer(properties.getContainerName(), properties.getDatabaseName(), @@ -76,7 +83,6 @@ public CosmosDBVectorStore(ObservationRegistry observationRegistry, properties.getPartitionKeyPath()); this.embeddingModel = embeddingModel; - } private void initializeContainer(String containerName, String databaseName, int vectorStoreThoughput, @@ -94,9 +100,7 @@ private void initializeContainer(String containerName, String databaseName, int PartitionKeyDefinition subpartitionKeyDefinition = new PartitionKeyDefinition(); List pathsfromCommaSeparatedList = new ArrayList(); String[] subpartitionKeyPaths = partitionKeyPath.split(","); - for (String path : subpartitionKeyPaths) { - pathsfromCommaSeparatedList.add(path); - } + Collections.addAll(pathsfromCommaSeparatedList, subpartitionKeyPaths); if (subpartitionKeyPaths.length > 1) { subpartitionKeyDefinition.setPaths(pathsfromCommaSeparatedList); subpartitionKeyDefinition.setKind(PartitionKind.MULTI_HASH); @@ -180,7 +184,7 @@ public void doAdd(List documents) { .getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId())); return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID // with the operation - }).collect(Collectors.toList()); + }).toList(); try { // Extract just the CosmosItemOperations from the pairs diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java index 4269a0d72d..1e05b6b8e2 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java @@ -42,7 +42,6 @@ * @author Theo van Kraay * @since 1.0.0 */ - @EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+") public class CosmosDBVectorStoreIT {