Skip to content

Commit

Permalink
CosmosDB vector store auto configuration changes
Browse files Browse the repository at this point in the history
- Configurable BatchingStrategy via auto configuration
- Minor code cleanup
  • Loading branch information
sobychacko committed Oct 22, 2024
1 parent 745e718 commit c979d1c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
package org.springframework.ai.autoconfigure.vectorstore.cosmosdb;

import com.azure.cosmos.CosmosClientBuilder;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.CosmosDBVectorStore;
import org.springframework.ai.vectorstore.CosmosDBVectorStoreConfig;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
Expand All @@ -32,9 +35,9 @@

/**
* @author Theo van Kraay
* @author Soby Chacko
* @since 1.0.0
*/

@AutoConfiguration
@ConditionalOnClass({ CosmosDBVectorStore.class, EmbeddingModel.class, CosmosAsyncClient.class })
@EnableConfigurationProperties(CosmosDBVectorStoreProperties.class)
Expand All @@ -53,12 +56,18 @@ public CosmosAsyncClient cosmosClient(CosmosDBVectorStoreProperties properties)
.buildAsyncClient();
}

/**
 * Supplies the fallback {@link BatchingStrategy} bean.
 * <p>
 * Registered only when the application context does not already contain a
 * {@link BatchingStrategy} bean, so a user-defined strategy takes precedence.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
	BatchingStrategy defaultStrategy = new TokenCountBatchingStrategy();
	return defaultStrategy;
}

@Bean
@ConditionalOnMissingBean
public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
CosmosDBVectorStoreProperties properties, CosmosAsyncClient cosmosAsyncClient,
EmbeddingModel embeddingModel) {
EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) {

CosmosDBVectorStoreConfig config = new CosmosDBVectorStoreConfig();
config.setDatabaseName(properties.getDatabaseName());
Expand All @@ -67,7 +76,7 @@ public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRe
config.setVectorStoreThoughput(properties.getVectorStoreThoughput());
config.setVectorDimensions(properties.getVectorDimensions());
return new CosmosDBVectorStore(observationRegistry, customObservationConvention.getIfAvailable(),
cosmosAsyncClient, config, embeddingModel);
cosmosAsyncClient, config, embeddingModel, batchingStrategy);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@

/**
* @author Theo van Kraay
* @author Soby Chacko
* @since 1.0.0
*/

public class CosmosDBVectorStore extends AbstractObservationVectorStore implements AutoCloseable {

private static final Logger logger = LoggerFactory.getLogger(CosmosDBVectorStore.class);
Expand All @@ -65,18 +65,24 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen
/**
 * Creates a vector store that batches with a default
 * {@link TokenCountBatchingStrategy}.
 * <p>
 * Convenience overload: delegates to the constructor that takes an explicit
 * {@link BatchingStrategy}, passing a newly created
 * {@link TokenCountBatchingStrategy}.
 * @param observationRegistry forwarded unchanged to the delegate constructor
 * @param customObservationConvention forwarded unchanged to the delegate constructor
 * @param cosmosClient forwarded unchanged to the delegate constructor
 * @param properties forwarded unchanged to the delegate constructor
 * @param embeddingModel forwarded unchanged to the delegate constructor
 */
public CosmosDBVectorStore(ObservationRegistry observationRegistry,
VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient,
CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel) {
// Must stay a single this(...) call: supplies the default batching strategy.
this(observationRegistry, customObservationConvention, cosmosClient, properties, embeddingModel,
new TokenCountBatchingStrategy());
}

/**
 * Creates a vector store backed by the given Cosmos DB client, using the supplied
 * {@link BatchingStrategy}.
 * <p>
 * Construction performs remote work: it creates the configured database if it does
 * not exist (blocking until that call completes) and then initializes the container
 * from the values in {@code properties}.
 * @param observationRegistry passed to the superclass constructor
 * @param customObservationConvention passed to the superclass constructor
 * @param cosmosClient async Cosmos DB client used to create the database
 * @param properties supplies the database name, container name, throughput, vector
 * dimensions and partition key path
 * @param embeddingModel embedding model stored on this instance
 * @param batchingStrategy batching strategy stored on this instance
 */
public CosmosDBVectorStore(ObservationRegistry observationRegistry,
		VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient,
		CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) {
	super(observationRegistry, customObservationConvention);
	this.cosmosClient = cosmosClient;
	this.properties = properties;
	// Use the caller-provided strategy only; no default instance is created here
	// (the overload without a BatchingStrategy parameter supplies the default).
	this.batchingStrategy = batchingStrategy;

	// Blocks until the database exists before the container is set up.
	cosmosClient.createDatabaseIfNotExists(properties.getDatabaseName()).block();

	initializeContainer(properties.getContainerName(), properties.getDatabaseName(),
			properties.getVectorStoreThoughput(), properties.getVectorDimensions(),
			properties.getPartitionKeyPath());

	// Assigned after container initialization, preserving the original statement order.
	this.embeddingModel = embeddingModel;
}

private void initializeContainer(String containerName, String databaseName, int vectorStoreThoughput,
Expand All @@ -94,9 +100,7 @@ private void initializeContainer(String containerName, String databaseName, int
PartitionKeyDefinition subpartitionKeyDefinition = new PartitionKeyDefinition();
List<String> pathsfromCommaSeparatedList = new ArrayList<String>();
String[] subpartitionKeyPaths = partitionKeyPath.split(",");
for (String path : subpartitionKeyPaths) {
pathsfromCommaSeparatedList.add(path);
}
Collections.addAll(pathsfromCommaSeparatedList, subpartitionKeyPaths);
if (subpartitionKeyPaths.length > 1) {
subpartitionKeyDefinition.setPaths(pathsfromCommaSeparatedList);
subpartitionKeyDefinition.setKind(PartitionKind.MULTI_HASH);
Expand Down Expand Up @@ -180,7 +184,7 @@ public void doAdd(List<Document> documents) {
.getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId()));
return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID
// with the operation
}).collect(Collectors.toList());
}).toList();

try {
// Extract just the CosmosItemOperations from the pairs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
* @author Theo van Kraay
* @since 1.0.0
*/

@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+")
public class CosmosDBVectorStoreIT {
Expand Down

0 comments on commit c979d1c

Please sign in to comment.