From a309c63d7829269387f6189acb61a85ed7d5e7d5 Mon Sep 17 00:00:00 2001 From: Nabil Benhammou Date: Sat, 6 Apr 2024 22:53:09 +0200 Subject: [PATCH] test: added mock server for simulating open ai calls for integration test --- pom.xml | 31 ++++++ src/main/java/dev/nano/mcc/MCCAssistant.java | 34 ++----- src/main/java/dev/nano/mcc/OpenAIClient.java | 43 ++++++--- .../dev/nano/mcc/client/AIClientPort.java | 7 ++ .../mcc/MCCApplicationIntegrationTests.java | 94 +++++++++++++++++++ .../java/dev/nano/mcc/MCCAssistantTest.java | 50 ++++++++++ src/test/resources/application.yaml | 5 + 7 files changed, 227 insertions(+), 37 deletions(-) create mode 100644 src/main/java/dev/nano/mcc/client/AIClientPort.java create mode 100644 src/test/java/dev/nano/mcc/MCCApplicationIntegrationTests.java create mode 100644 src/test/java/dev/nano/mcc/MCCAssistantTest.java create mode 100644 src/test/resources/application.yaml diff --git a/pom.xml b/pom.xml index 0afc157..9c4a8ce 100644 --- a/pom.xml +++ b/pom.xml @@ -18,6 +18,10 @@ 21 0.8.1 2.5.5 + 5.10.2 + 3.25.3 + 5.2.0 + 5.15.0 @@ -73,6 +77,33 @@ spring-boot-starter-test test + + + org.junit.jupiter + junit-jupiter-api + ${junit-jupiter-api.version} + test + + + + org.assertj + assertj-core + ${assertj-core.version} + test + + + org.mockito + mockito-inline + ${mockito.version} + test + + + org.mock-server + mockserver-netty + ${mockserver-netty.version} + test + + diff --git a/src/main/java/dev/nano/mcc/MCCAssistant.java b/src/main/java/dev/nano/mcc/MCCAssistant.java index 2b751cf..b2a42ee 100644 --- a/src/main/java/dev/nano/mcc/MCCAssistant.java +++ b/src/main/java/dev/nano/mcc/MCCAssistant.java @@ -1,42 +1,24 @@ package dev.nano.mcc; -import lombok.RequiredArgsConstructor; +import dev.nano.mcc.client.AIClientPort; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.image.ImagePrompt; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; -import java.util.List; -import java.util.Map; - @Service -@RequiredArgsConstructor @Slf4j public class MCCAssistant { - @Value("classpath:/prompt/system-prompt.st") - private Resource systemPrompt; + private final AIClientPort aiClientPort; - private final OpenAIClient openAIClient; + public MCCAssistant(AIClientPort aiClientPort) { + this.aiClientPort = aiClientPort; + } public String getRecipes(String dishName) { - - SystemMessage systemMessage = new SystemMessage(this.systemPrompt); - UserMessage userMessage = new UserMessage("Can you provide a recipe for + " + dishName + "?"); - - Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - - return openAIClient.getOpenAiChatClient().call(prompt).getResult().getOutput().getContent(); + return aiClientPort.generateRecipe("Can you provide a recipe for " + dishName + "?"); } - public String getDishImage(String dishName) { - ImagePrompt imagePrompt = new ImagePrompt("Generate an image of a Moroccan dish called " + dishName); - return openAIClient.getOpenAiImageClient().call(imagePrompt).getResult().getOutput().getUrl(); + public String getDishImage(String dishImageRequest) { + return aiClientPort.generateDishImage("Generate an image of a Moroccan dish called " + dishImageRequest); } } diff --git a/src/main/java/dev/nano/mcc/OpenAIClient.java b/src/main/java/dev/nano/mcc/OpenAIClient.java index 3118b20..1ad3701 100644 --- a/src/main/java/dev/nano/mcc/OpenAIClient.java +++ b/src/main/java/dev/nano/mcc/OpenAIClient.java @@ -1,6 +1,10 @@ package dev.nano.mcc; -import org.springframework.ai.image.ImageOptionsBuilder; +import dev.nano.mcc.client.AIClientPort; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiImageClient; @@ -8,31 +12,48 @@ import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.Resource; import org.springframework.retry.support.RetryTemplate; -import org.springframework.stereotype.Component; +import org.springframework.stereotype.Repository; +import org.springframework.web.client.RestClient; -@Component -public final class OpenAIClient { +import java.util.List; + +@Repository +public class OpenAIClient implements AIClientPort { @Value("${spring.ai.openai.api-key}") String apiKey; + + @Value("${spring.ai.openai.base-url}") + String baseUrl; + + private final Resource systemPrompt; + + + public OpenAIClient(@Value("classpath:/prompt/system-prompt.st") Resource systemPrompt) { + this.systemPrompt = systemPrompt; + } - public OpenAiChatClient getOpenAiChatClient() { - OpenAiApi openAiApi = new OpenAiApi(apiKey); - var options = new OpenAiChatOptions.Builder() + public String generateRecipe(String instructionRecipe) { + Prompt prompt = new Prompt(List.of(new SystemMessage(this.systemPrompt), new UserMessage(instructionRecipe))); + OpenAiApi openAiApi = new OpenAiApi(baseUrl, apiKey); + OpenAiChatOptions options = new OpenAiChatOptions.Builder() .withModel("gpt-4") .build(); - return new OpenAiChatClient(openAiApi, options); + return new OpenAiChatClient(openAiApi, options).call(prompt).getResult().getOutput().getContent(); } - public OpenAiImageClient getOpenAiImageClient() { - OpenAiImageApi openAiApi = new OpenAiImageApi(apiKey); + @Override + public String generateDishImage(String instructionDishImage) { + OpenAiImageApi openAiApi = new OpenAiImageApi(baseUrl, apiKey, RestClient.builder()); var options = OpenAiImageOptions.builder() .withQuality("hd") .withHeight(1024).withWidth(1024) .withResponseFormat("url") .withModel("dall-e-3") .build(); - return new OpenAiImageClient(openAiApi, options, RetryTemplate.builder().build()); + OpenAiImageClient openAiImageClient = new OpenAiImageClient(openAiApi, options, RetryTemplate.builder().build()); + return openAiImageClient.call(new ImagePrompt(instructionDishImage)).getResult().getOutput().getUrl(); } } diff --git a/src/main/java/dev/nano/mcc/client/AIClientPort.java b/src/main/java/dev/nano/mcc/client/AIClientPort.java new file mode 100644 index 0000000..a219ec4 --- /dev/null +++ b/src/main/java/dev/nano/mcc/client/AIClientPort.java @@ -0,0 +1,7 @@ +package dev.nano.mcc.client; + +public interface AIClientPort { + + String generateRecipe(String instructionRecipe); + String generateDishImage(String instructionDishImage); +} diff --git a/src/test/java/dev/nano/mcc/MCCApplicationIntegrationTests.java b/src/test/java/dev/nano/mcc/MCCApplicationIntegrationTests.java new file mode 100644 index 0000000..8397606 --- /dev/null +++ b/src/test/java/dev/nano/mcc/MCCApplicationIntegrationTests.java @@ -0,0 +1,94 @@ +package dev.nano.mcc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.netty.handler.codec.http.HttpHeaderNames; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.MediaType; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockserver.integration.ClientAndServer.startClientAndServer; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +@SpringBootTest(classes = MCCApplication.class) +class MCCApplicationIntegrationTests { + + @Autowired + MCCAssistant mccAssistant; + + private final ObjectMapper objectMapper = new ObjectMapper(); + + private static final ClientAndServer mockServer = startClientAndServer(2445); + + private static final String TEST_DISH_NAME = "couscous with seven vegetables"; + + @BeforeEach + void setup() { + mockServer.reset(); + } + + @Test + void generateDishImage() throws JsonProcessingException { + + OpenAiImageApi.OpenAiImageResponse mockedResponse = new OpenAiImageApi.OpenAiImageResponse( + 20L, + List.of(new OpenAiImageApi.Data( + "https://openai.com/image/dish_generated_url.png", + "base64_encoding_value", + "revised_prompt_value" + )) + ); + + mockOpenAiGenerativeResponses("/v1/images/generations", objectMapper.writeValueAsString(mockedResponse)); + String imageUrl = mccAssistant.getDishImage(TEST_DISH_NAME); + assertThat(imageUrl).isNotNull().isEqualTo("https://openai.com/image/dish_generated_url.png"); + System.out.println("image url: " + imageUrl); + } + + + + @Test + void generateRecipes() throws JsonProcessingException { + + OpenAiApi.ChatCompletion mockedResponse = new OpenAiApi.ChatCompletion( + "id_value", + List.of(new OpenAiApi.ChatCompletion + .Choice( + OpenAiApi.ChatCompletionFinishReason.STOP, + 1, + new OpenAiApi.ChatCompletionMessage("Detailed dish of a couscous recipe with seven vegetables", null), + null)), + 10L, + "gpt-4", + "systemFingerPrint", + null, + new OpenAiApi.Usage(1, 2, 3) + ); + mockOpenAiGenerativeResponses("/v1/chat/completions", objectMapper.writeValueAsString(mockedResponse)); + + String recipe = mccAssistant.getRecipes(TEST_DISH_NAME); + assertThat(recipe) + .isNotNull() + .isEqualTo("Detailed dish of a couscous recipe with seven vegetables"); + } + + private void mockOpenAiGenerativeResponses(String path, String objectMapper) throws JsonProcessingException { + mockServer.when(request().withMethod("POST").withPath(path)) + .respond( + response() + .withStatusCode(200) + .withHeader(HttpHeaderNames.CONTENT_TYPE.toString(), MediaType.APPLICATION_JSON.toString()) + .withBody(objectMapper)); + } + + +} diff --git a/src/test/java/dev/nano/mcc/MCCAssistantTest.java b/src/test/java/dev/nano/mcc/MCCAssistantTest.java new file mode 100644 index 0000000..649ab24 --- /dev/null +++ b/src/test/java/dev/nano/mcc/MCCAssistantTest.java @@ -0,0 +1,50 @@ +package dev.nano.mcc; + +import dev.nano.mcc.client.AIClientPort; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import static org.mockito.BDDMockito.given; + +@ExtendWith(MockitoExtension.class ) +class MCCAssistantTest { + @InjectMocks + MCCAssistant mccAssistant; + + @Mock + AIClientPort aiClientPort; + + + @Test + void shouldReturnDetailedRecipeWhenProvidingRecipeRequest() { + + String recipeRequest = "Couscous with seven vegetables"; + + String expected = "Detailed couscous with seven vegetables recipe"; + + given(aiClientPort.generateRecipe("Can you provide a recipe for " + recipeRequest + "?")).willReturn(expected); + + String result = mccAssistant.getRecipes(recipeRequest); + + Assertions.assertThat(result).isEqualTo(expected); + + + } + + @Test + void shouldReturnDishImageWhenProvidingDishImageRequest() { + String dishImageRequest = "Couscous with seven vegetables"; + + String expectedUrl = "https://openai.com/image/generated-dish.png"; + + given(aiClientPort.generateDishImage("Generate an image of a Moroccan dish called " + dishImageRequest)).willReturn(expectedUrl); + + String result = mccAssistant.getDishImage(dishImageRequest); + + Assertions.assertThat(result).isEqualTo(expectedUrl); + } +} diff --git a/src/test/resources/application.yaml b/src/test/resources/application.yaml new file mode 100644 index 0000000..8615869 --- /dev/null +++ b/src/test/resources/application.yaml @@ -0,0 +1,5 @@ +spring: + ai: + openai: + api-key: "dummy-key" + base-url: "http://localhost:2445" \ No newline at end of file