Skip to content

Commit

Permalink
#1199 - Add UT for recommendation service (#1207)
Browse files Browse the repository at this point in the history
* #1199 - add UT

* #1199 - sonar fix

---------

Co-authored-by: Duy Le Van <[email protected]>
  • Loading branch information
duylv27 and Duy Le Van authored Nov 1, 2024
1 parent fd49ac9 commit 9e9e80f
Show file tree
Hide file tree
Showing 14 changed files with 2,451 additions and 17 deletions.
1 change: 0 additions & 1 deletion .github/workflows/recommendation-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ jobs:
- name: Test Results
uses: dorny/test-reporter@v1
if: ${{ env.FROM_ORIGINAL_REPOSITORY == 'true' && (success() || failure()) }}
continue-on-error: true # TODO: remove once defining UT
with:
name: Recommendation-Service-Unit-Test-Results
path: "recommendation/**/*-reports/TEST*.xml"
Expand Down
1 change: 1 addition & 0 deletions common-library/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-tx</artifactId>
Expand Down
16 changes: 16 additions & 0 deletions common-library/src/it/java/common/container/ContainerFactory.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package common.container;

import dasniko.testcontainers.keycloak.KeycloakContainer;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.context.annotation.Bean;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.testcontainers.containers.KafkaContainer;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.utility.DockerImageName;

/**
Expand Down Expand Up @@ -33,6 +36,7 @@ public static KafkaContainer kafkaContainer(DynamicPropertyRegistry registry, St
DockerImageName.parse("confluentinc/cp-kafka:%s".formatted(version))
);
registry.add("spring.kafka.bootstrap-servers", kafkaContainer::getBootstrapServers);
registry.add("bootstrap.servers", kafkaContainer::getBootstrapServers);

// Consumer properties
registry.add("auto.offset.reset", () -> "earliest");
Expand All @@ -43,4 +47,16 @@ public static KafkaContainer kafkaContainer(DynamicPropertyRegistry registry, St
return kafkaContainer;
}

public static PostgreSQLContainer pgvector(DynamicPropertyRegistry registry, String version) {
var image = DockerImageName.parse("pgvector/pgvector:%s".formatted(version))
.asCompatibleSubstituteFor("postgres");
var postgres = new PostgreSQLContainer<>(image);
postgres.start();

registry.add("spring.datasource.url", postgres::getJdbcUrl);
registry.add("spring.datasource.username", postgres::getUsername);
registry.add("spring.datasource.password", postgres::getPassword);
return postgres;
}

}
15 changes: 15 additions & 0 deletions recommendation/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
</dependency>

<!-- Test -->
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>kafka</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.yas</groupId>
<artifactId>common-library</artifactId>
<version>${revision}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public abstract class SimpleVectorRepository<D extends BaseDocument, E> implemen
* @param vectorStore vector store service.
*/
@SneakyThrows
public SimpleVectorRepository(Class<D> docType, VectorStore vectorStore) {
protected SimpleVectorRepository(Class<D> docType, VectorStore vectorStore) {
Assert.isTrue(docType.isAnnotationPresent(DocumentMetadata.class),
"Document must be annotated by '@DocumentFormat'");
this.docType = docType;
Expand All @@ -55,15 +55,6 @@ public SimpleVectorRepository(Class<D> docType, VectorStore vectorStore) {
this.documentFormatter = documentMetadata.documentFormatter().getDeclaredConstructor().newInstance();
}

/**
* Retrieves the entity data for a given product ID. It used for
* {@link SimpleVectorRepository#add(Long)}, and {@link SimpleVectorRepository#search(Long)} operation.
*
* @param id the ID.
* @return a map of entity attributes where keys are attribute names and values are their corresponding values.
*/
public abstract E getEntity(Long id);

/**
* Add a record to the vector database by fetching data from an external source.
* This method retrieves the entity using
Expand Down Expand Up @@ -93,8 +84,8 @@ public void add(Long entityId) {
* @param entityId the ID of the entity to be deleted from the vector store
*/
public void delete(Long entityId) {
DefaultIdGenerator defaultIdGenerator = new DefaultIdGenerator(documentMetadata.docIdPrefix(), entityId);
var docId = defaultIdGenerator.generateId();
IdGenerator idGenerator = getIdGenerator(entityId);
var docId = idGenerator.generateId();
vectorStore.delete(List.of(docId));
}

Expand Down Expand Up @@ -133,7 +124,7 @@ public List<D> search(Long id) {
.toList();
}

protected IdGenerator getIdGenerator(Long entityId) {
public IdGenerator getIdGenerator(Long entityId) {
return new DefaultIdGenerator(documentMetadata.docIdPrefix(), entityId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import java.math.BigDecimal;
import java.util.List;
import lombok.Getter;

@Getter
@lombok.Setter
@lombok.Getter
@JsonIgnoreProperties(ignoreUnknown = true)
public class RelatedProductVm {

@JsonProperty("id")
private Integer productId;
private Long productId;

private String name;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.yas.recommendation.config;

import common.container.ContainerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.context.annotation.Bean;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.testcontainers.containers.KafkaContainer;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.utility.DockerImageName;

@TestConfiguration
public class KafkaIntegrationTestConfiguration {

@Value("${kafka.version}")
private String kafkaVersion;

@Value("${pgvector.version}")
private String pgVectorVersion;

@Bean
@ServiceConnection
public KafkaContainer kafkaContainer(DynamicPropertyRegistry registry) {
return ContainerFactory.kafkaContainer(registry, kafkaVersion);
}

@Bean
@ServiceConnection
public PostgreSQLContainer pgvectorContainer(DynamicPropertyRegistry registry) {
return ContainerFactory.pgvector(registry, pgVectorVersion);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package com.yas.recommendation.query;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

import com.yas.recommendation.config.KafkaIntegrationTestConfiguration;
import com.yas.recommendation.configuration.EmbeddingSearchConfiguration;
import com.yas.recommendation.service.ProductService;
import com.yas.recommendation.vector.product.query.RelatedProductQuery;
import com.yas.recommendation.vector.product.store.ProductVectorRepository;
import com.yas.recommendation.viewmodel.ProductDetailVm;
import com.yas.recommendation.viewmodel.RelatedProductVm;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Import;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.test.context.TestPropertySource;
import org.testcontainers.junit.jupiter.Testcontainers;

@Testcontainers
@SpringBootTest
@Import(KafkaIntegrationTestConfiguration.class)
@TestPropertySource("classpath:application-test.properties")
public class VectorQueryTest {

@Autowired
private JdbcTemplate jdbcClient;

@Autowired
private VectorStore vectorStore;

@Autowired
private RelatedProductQuery relatedProductQuery;

@Autowired
private ProductVectorRepository productVectorRepository;

@MockBean
private EmbeddingModel embeddingModel;

@MockBean
private ProductService productService;

@MockBean
private EmbeddingSearchConfiguration embeddingSearchConfiguration;

@AfterEach
public void tearDown() {
jdbcClient.execute("DELETE FROM vector_store;");
}

@Test
public void testSimilaritySearch() {
// Given
var productId = -1L;
var similarProductId = -2L;
ProductDetailVm searchedProduct = getProductDetailVm(productId);
ProductDetailVm similarProduct = getProductDetailVm(similarProductId);

// When
when(embeddingSearchConfiguration.topK()).thenReturn(10);
when(embeddingSearchConfiguration.similarityThreshold()).thenReturn(-1D); // force to query all data, not depend on vector compare operation
when(productService.getProductDetail(productId)).thenReturn(searchedProduct);
when(productService.getProductDetail(similarProductId)).thenReturn(similarProduct);
when(embeddingModel.embed(any(Document.class))).thenReturn(randomEmbed());
productVectorRepository.add(productId);
productVectorRepository.add(similarProductId);
List<RelatedProductVm> relatedProductVms = relatedProductQuery.similaritySearch(-2L);

// Then
assertFalse(relatedProductVms.isEmpty());
}

private static float @NotNull [] randomEmbed() {
int size = 1536;
float[] floatArray = new float[size];
Random random = new Random();
for (int i = 0; i < size; i++) {
floatArray[i] = random.nextFloat();
}
return floatArray;
}

private static @NotNull ProductDetailVm getProductDetailVm(long productId) {
return new ProductDetailVm(
productId,
"IPhone 14 Pro",
"Latest iPhone model",
"The iPhone 14 Pro comes with the latest technology...",
"6.1-inch display, A16 Bionic chip, 128GB Storage",
"IPH14PRO",
"0123456789012",
"iphone-14-pro",
true,
true,
true,
true,
true,
999.99,
101L,
Collections.emptyList(),
"iPhone 14 Pro",
"iPhone, Apple, Smartphone",
"Buy the latest iPhone 14 Pro...",
1L,
"Apple",
Collections.emptyList(),
null,
null,
null
);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package com.yas.recommendation.store;

import static com.yas.recommendation.vector.common.store.SimpleVectorRepository.TYPE_METADATA;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.yas.recommendation.configuration.EmbeddingSearchConfiguration;
import com.yas.recommendation.vector.common.document.BaseDocument;
import com.yas.recommendation.vector.common.document.DocumentMetadata;
import com.yas.recommendation.vector.common.formatter.DocumentFormatter;
import java.util.Map;
import lombok.Getter;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Assertions;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.beans.factory.annotation.Autowired;

public class BaseVectorRepositoryTest<D extends BaseDocument, E> {

private final Class<D> docClass;
private final DocumentMetadata documentMetadata;

@Getter
private final DocumentFormatter documentFormatter;

@Getter
@Autowired
private ObjectMapper objectMapper;

@Autowired
private EmbeddingSearchConfiguration embeddingSearchConf;

@SneakyThrows
public BaseVectorRepositoryTest(Class<D> docClass) {
Assertions.assertNotNull(docClass, "Document must not be 'null'");
this.docClass = docClass;
this.documentMetadata = getDocumentMetadata();
this.documentFormatter = documentMetadata
.documentFormatter()
.getDeclaredConstructor()
.newInstance();
}

public DocumentMetadata getDocumentMetadata() {
assertTrue(
docClass.isAnnotationPresent(DocumentMetadata.class),
"Document must be annotated by 'DocumentMetadata'"
);
return docClass.getAnnotation(DocumentMetadata.class);
}

public void assertDocumentData(Document createdDoc, E entity) {
var expectedContent = getFormatEntity(entity);
assertEquals(expectedContent, createdDoc.getContent(), "Document format must be formated at declared metadata");
assertNotNull(createdDoc.getMetadata(), "Document's metadata must not be null");
assertFalse(createdDoc.getMetadata().isEmpty(), "Document's metadata must not be empty");

var expectedMetadata = objectMapper.convertValue(entity, Map.class);
expectedMetadata.put(TYPE_METADATA, documentMetadata.docIdPrefix());
assertEquals(expectedMetadata.keySet(), createdDoc.getMetadata().keySet());
}

public void assertSearchRequest(SearchRequest searchRequest, E entity) {
assertNotNull(searchRequest.query, "Search query must be created");
assertEquals(getFormatEntity(entity), searchRequest.query, "Search's Query must be formatted correctly");
assertEquals(searchRequest.getTopK(), embeddingSearchConf.topK(), "Search's top K must be configured");
assertEquals(
searchRequest.getSimilarityThreshold(),
embeddingSearchConf.similarityThreshold(),
"Search's top K must be configured"
);
assertNotNull(searchRequest.getFilterExpression(), "Search filter default must be specified");
assertEquals(
searchRequest.getFilterExpression().type(),
Filter.ExpressionType.NE,
"Search filter default must be correctly"
);

Filter.Key key = (Filter.Key) searchRequest.getFilterExpression().left();
assertEquals(key.key(), "id", "Search filter default must be correctly");
}

private String getFormatEntity(E entity) {
return documentFormatter.format(
objectMapper.convertValue(entity, Map.class),
documentMetadata.contentFormat(),
objectMapper
);
}

}
Loading

0 comments on commit 9e9e80f

Please sign in to comment.