Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1199 - Add UT for recommendation service #1207

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/recommendation-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ jobs:
- name: Test Results
uses: dorny/test-reporter@v1
if: ${{ env.FROM_ORIGINAL_REPOSITORY == 'true' && (success() || failure()) }}
continue-on-error: true # TODO: remove once defining UT
with:
name: Recommendation-Service-Unit-Test-Results
path: "recommendation/**/*-reports/TEST*.xml"
Expand Down
1 change: 1 addition & 0 deletions common-library/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-tx</artifactId>
Expand Down
16 changes: 16 additions & 0 deletions common-library/src/it/java/common/container/ContainerFactory.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package common.container;

import dasniko.testcontainers.keycloak.KeycloakContainer;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.context.annotation.Bean;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.testcontainers.containers.KafkaContainer;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.utility.DockerImageName;

/**
Expand Down Expand Up @@ -33,6 +36,7 @@ public static KafkaContainer kafkaContainer(DynamicPropertyRegistry registry, St
DockerImageName.parse("confluentinc/cp-kafka:%s".formatted(version))
);
registry.add("spring.kafka.bootstrap-servers", kafkaContainer::getBootstrapServers);
registry.add("bootstrap.servers", kafkaContainer::getBootstrapServers);

// Consumer properties
registry.add("auto.offset.reset", () -> "earliest");
Expand All @@ -43,4 +47,16 @@ public static KafkaContainer kafkaContainer(DynamicPropertyRegistry registry, St
return kafkaContainer;
}

public static PostgreSQLContainer pgvector(DynamicPropertyRegistry registry, String version) {
var image = DockerImageName.parse("pgvector/pgvector:%s".formatted(version))
.asCompatibleSubstituteFor("postgres");
var postgres = new PostgreSQLContainer<>(image);
postgres.start();

registry.add("spring.datasource.url", postgres::getJdbcUrl);
registry.add("spring.datasource.username", postgres::getUsername);
registry.add("spring.datasource.password", postgres::getPassword);
return postgres;
}

}
15 changes: 15 additions & 0 deletions recommendation/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
</dependency>

<!-- Test -->
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>kafka</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.yas</groupId>
<artifactId>common-library</artifactId>
<version>${revision}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public abstract class SimpleVectorRepository<D extends BaseDocument, E> implemen
* @param vectorStore vector store service.
*/
@SneakyThrows
public SimpleVectorRepository(Class<D> docType, VectorStore vectorStore) {
protected SimpleVectorRepository(Class<D> docType, VectorStore vectorStore) {
Assert.isTrue(docType.isAnnotationPresent(DocumentMetadata.class),
"Document must be annotated by '@DocumentFormat'");
this.docType = docType;
Expand All @@ -55,15 +55,6 @@ public SimpleVectorRepository(Class<D> docType, VectorStore vectorStore) {
this.documentFormatter = documentMetadata.documentFormatter().getDeclaredConstructor().newInstance();
}

/**
* Retrieves the entity data for a given product ID. It used for
* {@link SimpleVectorRepository#add(Long)}, and {@link SimpleVectorRepository#search(Long)} operation.
*
* @param id the ID.
* @return a map of entity attributes where keys are attribute names and values are their corresponding values.
*/
public abstract E getEntity(Long id);

/**
* Add a record to the vector database by fetching data from an external source.
* This method retrieves the entity using
Expand Down Expand Up @@ -93,8 +84,8 @@ public void add(Long entityId) {
* @param entityId the ID of the entity to be deleted from the vector store
*/
public void delete(Long entityId) {
DefaultIdGenerator defaultIdGenerator = new DefaultIdGenerator(documentMetadata.docIdPrefix(), entityId);
var docId = defaultIdGenerator.generateId();
IdGenerator idGenerator = getIdGenerator(entityId);
var docId = idGenerator.generateId();
vectorStore.delete(List.of(docId));
}

Expand Down Expand Up @@ -133,7 +124,7 @@ public List<D> search(Long id) {
.toList();
}

protected IdGenerator getIdGenerator(Long entityId) {
public IdGenerator getIdGenerator(Long entityId) {
return new DefaultIdGenerator(documentMetadata.docIdPrefix(), entityId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import java.math.BigDecimal;
import java.util.List;
import lombok.Getter;

@Getter
@lombok.Setter
@lombok.Getter
@JsonIgnoreProperties(ignoreUnknown = true)
public class RelatedProductVm {

@JsonProperty("id")
private Integer productId;
private Long productId;

private String name;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.yas.recommendation.config;

import common.container.ContainerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.context.annotation.Bean;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.testcontainers.containers.KafkaContainer;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.utility.DockerImageName;

@TestConfiguration
public class KafkaIntegrationTestConfiguration {

@Value("${kafka.version}")
private String kafkaVersion;

@Value("${pgvector.version}")
private String pgVectorVersion;

@Bean
@ServiceConnection
public KafkaContainer kafkaContainer(DynamicPropertyRegistry registry) {
return ContainerFactory.kafkaContainer(registry, kafkaVersion);
}

@Bean
@ServiceConnection
public PostgreSQLContainer pgvectorContainer(DynamicPropertyRegistry registry) {
return ContainerFactory.pgvector(registry, pgVectorVersion);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package com.yas.recommendation.query;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

import com.yas.recommendation.config.KafkaIntegrationTestConfiguration;
import com.yas.recommendation.configuration.EmbeddingSearchConfiguration;
import com.yas.recommendation.service.ProductService;
import com.yas.recommendation.vector.product.query.RelatedProductQuery;
import com.yas.recommendation.vector.product.store.ProductVectorRepository;
import com.yas.recommendation.viewmodel.ProductDetailVm;
import com.yas.recommendation.viewmodel.RelatedProductVm;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Import;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.test.context.TestPropertySource;
import org.testcontainers.junit.jupiter.Testcontainers;

@Testcontainers
@SpringBootTest
@Import(KafkaIntegrationTestConfiguration.class)
@TestPropertySource("classpath:application-test.properties")
public class VectorQueryTest {

@Autowired
private JdbcTemplate jdbcClient;

@Autowired
private VectorStore vectorStore;

@Autowired
private RelatedProductQuery relatedProductQuery;

@Autowired
private ProductVectorRepository productVectorRepository;

@MockBean
private EmbeddingModel embeddingModel;

@MockBean
private ProductService productService;

@MockBean
private EmbeddingSearchConfiguration embeddingSearchConfiguration;

@AfterEach
public void tearDown() {
jdbcClient.execute("DELETE FROM vector_store;");
}

@Test
public void testSimilaritySearch() {
// Given
var productId = -1L;
var similarProductId = -2L;
ProductDetailVm searchedProduct = getProductDetailVm(productId);
ProductDetailVm similarProduct = getProductDetailVm(similarProductId);

// When
when(embeddingSearchConfiguration.topK()).thenReturn(10);
when(embeddingSearchConfiguration.similarityThreshold()).thenReturn(-1D); // force to query all data, not depend on vector compare operation
when(productService.getProductDetail(productId)).thenReturn(searchedProduct);
when(productService.getProductDetail(similarProductId)).thenReturn(similarProduct);
when(embeddingModel.embed(any(Document.class))).thenReturn(randomEmbed());
productVectorRepository.add(productId);
productVectorRepository.add(similarProductId);
List<RelatedProductVm> relatedProductVms = relatedProductQuery.similaritySearch(-2L);

// Then
assertFalse(relatedProductVms.isEmpty());
}

private static float @NotNull [] randomEmbed() {
int size = 1536;
float[] floatArray = new float[size];
Random random = new Random();
for (int i = 0; i < size; i++) {
floatArray[i] = random.nextFloat();
}
return floatArray;
}

private static @NotNull ProductDetailVm getProductDetailVm(long productId) {
return new ProductDetailVm(
productId,
"IPhone 14 Pro",
"Latest iPhone model",
"The iPhone 14 Pro comes with the latest technology...",
"6.1-inch display, A16 Bionic chip, 128GB Storage",
"IPH14PRO",
"0123456789012",
"iphone-14-pro",
true,
true,
true,
true,
true,
999.99,
101L,
Collections.emptyList(),
"iPhone 14 Pro",
"iPhone, Apple, Smartphone",
"Buy the latest iPhone 14 Pro...",
1L,
"Apple",
Collections.emptyList(),
null,
null,
null
);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package com.yas.recommendation.store;

import static com.yas.recommendation.vector.common.store.SimpleVectorRepository.TYPE_METADATA;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.yas.recommendation.configuration.EmbeddingSearchConfiguration;
import com.yas.recommendation.vector.common.document.BaseDocument;
import com.yas.recommendation.vector.common.document.DocumentMetadata;
import com.yas.recommendation.vector.common.formatter.DocumentFormatter;
import java.util.Map;
import lombok.Getter;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Assertions;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.beans.factory.annotation.Autowired;

public class BaseVectorRepositoryTest<D extends BaseDocument, E> {

private final Class<D> docClass;
private final DocumentMetadata documentMetadata;

@Getter
private final DocumentFormatter documentFormatter;

@Getter
@Autowired
private ObjectMapper objectMapper;

@Autowired
private EmbeddingSearchConfiguration embeddingSearchConf;

@SneakyThrows
public BaseVectorRepositoryTest(Class<D> docClass) {
Assertions.assertNotNull(docClass, "Document must not be 'null'");
this.docClass = docClass;
this.documentMetadata = getDocumentMetadata();
this.documentFormatter = documentMetadata
.documentFormatter()
.getDeclaredConstructor()
.newInstance();
}

public DocumentMetadata getDocumentMetadata() {
assertTrue(
docClass.isAnnotationPresent(DocumentMetadata.class),
"Document must be annotated by 'DocumentMetadata'"
);
return docClass.getAnnotation(DocumentMetadata.class);
}

public void assertDocumentData(Document createdDoc, E entity) {
var expectedContent = getFormatEntity(entity);
assertEquals(expectedContent, createdDoc.getContent(), "Document format must be formated at declared metadata");
assertNotNull(createdDoc.getMetadata(), "Document's metadata must not be null");
assertFalse(createdDoc.getMetadata().isEmpty(), "Document's metadata must not be empty");

var expectedMetadata = objectMapper.convertValue(entity, Map.class);
expectedMetadata.put(TYPE_METADATA, documentMetadata.docIdPrefix());
assertEquals(expectedMetadata.keySet(), createdDoc.getMetadata().keySet());
}

public void assertSearchRequest(SearchRequest searchRequest, E entity) {
assertNotNull(searchRequest.query, "Search query must be created");
assertEquals(getFormatEntity(entity), searchRequest.query, "Search's Query must be formatted correctly");
assertEquals(searchRequest.getTopK(), embeddingSearchConf.topK(), "Search's top K must be configured");
assertEquals(
searchRequest.getSimilarityThreshold(),
embeddingSearchConf.similarityThreshold(),
"Search's top K must be configured"
);
assertNotNull(searchRequest.getFilterExpression(), "Search filter default must be specified");
assertEquals(
searchRequest.getFilterExpression().type(),
Filter.ExpressionType.NE,
"Search filter default must be correctly"
);

Filter.Key key = (Filter.Key) searchRequest.getFilterExpression().left();
assertEquals(key.key(), "id", "Search filter default must be correctly");
}

private String getFormatEntity(E entity) {
return documentFormatter.format(
objectMapper.convertValue(entity, Map.class),
documentMetadata.contentFormat(),
objectMapper
);
}

}
Loading
Loading