Skip to content

Commit

Permalink
test: added mock server for simulating open ai calls for integration …
Browse files Browse the repository at this point in the history
…test
  • Loading branch information
FeiRoN23 committed Apr 6, 2024
1 parent ab6007e commit a309c63
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 37 deletions.
31 changes: 31 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
<java.version>21</java.version>
<spring-ai.version>0.8.1</spring-ai.version>
<hilla.version>2.5.5</hilla.version>
<junit-jupiter-api.version>5.10.2</junit-jupiter-api.version>
<assertj-core.version>3.25.3</assertj-core.version>
<mockito.version>5.2.0</mockito.version>
<mockserver-netty.version>5.15.0</mockserver-netty.version>
</properties>


Expand Down Expand Up @@ -73,6 +77,33 @@
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>${junit-jupiter-api.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj-core.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mock-server</groupId>
<artifactId>mockserver-netty</artifactId>
<version>${mockserver-netty.version}</version>
<scope>test</scope>
</dependency>

</dependencies>

<build>
Expand Down
34 changes: 8 additions & 26 deletions src/main/java/dev/nano/mcc/MCCAssistant.java
Original file line number Diff line number Diff line change
@@ -1,42 +1,24 @@
package dev.nano.mcc;

import lombok.RequiredArgsConstructor;
import dev.nano.mcc.client.AIClientPort;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Map;

@Service
@RequiredArgsConstructor
@Slf4j
public class MCCAssistant {

@Value("classpath:/prompt/system-prompt.st")
private Resource systemPrompt;
private final AIClientPort aiClientPort;

private final OpenAIClient openAIClient;
public MCCAssistant(AIClientPort aiClientPort) {
this.aiClientPort = aiClientPort;
}

public String getRecipes(String dishName) {

SystemMessage systemMessage = new SystemMessage(this.systemPrompt);
UserMessage userMessage = new UserMessage("Can you provide a recipe for + " + dishName + "?");

Prompt prompt = new Prompt(List.of(systemMessage, userMessage));

return openAIClient.getOpenAiChatClient().call(prompt).getResult().getOutput().getContent();
return aiClientPort.generateRecipe("Can you provide a recipe for " + dishName + "?");
}

public String getDishImage(String dishName) {
ImagePrompt imagePrompt = new ImagePrompt("Generate an image of a Moroccan dish called " + dishName);
return openAIClient.getOpenAiImageClient().call(imagePrompt).getResult().getOutput().getUrl();
public String getDishImage(String dishImageRequest) {
return aiClientPort.generateDishImage("Generate an image of a Moroccan dish called " + dishImageRequest);
}
}
43 changes: 32 additions & 11 deletions src/main/java/dev/nano/mcc/OpenAIClient.java
Original file line number Diff line number Diff line change
@@ -1,38 +1,59 @@
package dev.nano.mcc;

import org.springframework.ai.image.ImageOptionsBuilder;
import dev.nano.mcc.client.AIClientPort;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Repository;
import org.springframework.web.client.RestClient;

@Component
public final class OpenAIClient {
import java.util.List;

@Repository
public class OpenAIClient implements AIClientPort {

@Value("${spring.ai.openai.api-key}")
String apiKey;

@Value("${spring.ai.openai.base-url}")
String baseUrl;

private final Resource systemPrompt;


public OpenAIClient(@Value("classpath:/prompt/system-prompt.st") Resource systemPrompt) {
this.systemPrompt = systemPrompt;
}

public OpenAiChatClient getOpenAiChatClient() {
OpenAiApi openAiApi = new OpenAiApi(apiKey);
var options = new OpenAiChatOptions.Builder()
public String generateRecipe(String instructionRecipe) {
Prompt prompt = new Prompt(List.of(new SystemMessage(this.systemPrompt), new UserMessage(instructionRecipe)));
OpenAiApi openAiApi = new OpenAiApi(baseUrl, apiKey);
OpenAiChatOptions options = new OpenAiChatOptions.Builder()
.withModel("gpt-4")
.build();
return new OpenAiChatClient(openAiApi, options);
return new OpenAiChatClient(openAiApi, options).call(prompt).getResult().getOutput().getContent();
}

public OpenAiImageClient getOpenAiImageClient() {
OpenAiImageApi openAiApi = new OpenAiImageApi(apiKey);
@Override
public String generateDishImage(String instructionDishImage) {
OpenAiImageApi openAiApi = new OpenAiImageApi(baseUrl, apiKey, RestClient.builder());
var options = OpenAiImageOptions.builder()
.withQuality("hd")
.withHeight(1024).withWidth(1024)
.withResponseFormat("url")
.withModel("dall-e-3")
.build();
return new OpenAiImageClient(openAiApi, options, RetryTemplate.builder().build());
OpenAiImageClient openAiImageClient = new OpenAiImageClient(openAiApi, options, RetryTemplate.builder().build());
return openAiImageClient.call(new ImagePrompt(instructionDishImage)).getResult().getOutput().getUrl();
}
}
7 changes: 7 additions & 0 deletions src/main/java/dev/nano/mcc/client/AIClientPort.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package dev.nano.mcc.client;

public interface AIClientPort {

String generateRecipe(String instructionRecipe);
String generateDishImage(String instructionDishImage);
}
94 changes: 94 additions & 0 deletions src/test/java/dev/nano/mcc/MCCApplicationIntegrationTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package dev.nano.mcc;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.handler.codec.http.HttpHeaderNames;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockserver.integration.ClientAndServer;
import org.mockserver.model.MediaType;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockserver.integration.ClientAndServer.startClientAndServer;
import static org.mockserver.model.HttpRequest.request;
import static org.mockserver.model.HttpResponse.response;

@SpringBootTest(classes = MCCApplication.class)
class MCCApplicationIntegrationTests {

@Autowired
MCCAssistant mccAssistant;

private final ObjectMapper objectMapper = new ObjectMapper();

private static final ClientAndServer mockServer = startClientAndServer(2445);

private static final String TEST_DISH_NAME = "couscous with seven vegetables";

@BeforeEach
void setup() {
mockServer.reset();
}

@Test
void generateDishImage() throws JsonProcessingException {

OpenAiImageApi.OpenAiImageResponse mockedResponse = new OpenAiImageApi.OpenAiImageResponse(
20L,
List.of(new OpenAiImageApi.Data(
"https://openai.com/image/dish_generated_url.png",
"base64_encoding_value",
"revised_prompt_value"
))
);

mockOpenAiGenerativeResponses("/v1/images/generations", objectMapper.writeValueAsString(mockedResponse));
String imageUrl = mccAssistant.getDishImage(TEST_DISH_NAME);
assertThat(imageUrl).isNotNull().isEqualTo("https://openai.com/image/dish_generated_url.png");
System.out.println("image url: " + imageUrl);
}



@Test
void generateRecipes() throws JsonProcessingException {

OpenAiApi.ChatCompletion mockedResponse = new OpenAiApi.ChatCompletion(
"id_value",
List.of(new OpenAiApi.ChatCompletion
.Choice(
OpenAiApi.ChatCompletionFinishReason.STOP,
1,
new OpenAiApi.ChatCompletionMessage("Detailed dish of a couscous recipe with seven vegetables", null),
null)),
10L,
"gpt-4",
"systemFingerPrint",
null,
new OpenAiApi.Usage(1, 2, 3)
);
mockOpenAiGenerativeResponses("/v1/chat/completions", objectMapper.writeValueAsString(mockedResponse));

String recipe = mccAssistant.getRecipes(TEST_DISH_NAME);
assertThat(recipe)
.isNotNull()
.isEqualTo("Detailed dish of a couscous recipe with seven vegetables");
}

private void mockOpenAiGenerativeResponses(String path, String objectMapper) throws JsonProcessingException {
mockServer.when(request().withMethod("POST").withPath(path))
.respond(
response()
.withStatusCode(200)
.withHeader(HttpHeaderNames.CONTENT_TYPE.toString(), MediaType.APPLICATION_JSON.toString())
.withBody(objectMapper));
}


}
50 changes: 50 additions & 0 deletions src/test/java/dev/nano/mcc/MCCAssistantTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package dev.nano.mcc;

import dev.nano.mcc.client.AIClientPort;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.mockito.BDDMockito.given;

@ExtendWith(MockitoExtension.class )
class MCCAssistantTest {
@InjectMocks
MCCAssistant mccAssistant;

@Mock
AIClientPort aiClientPort;


@Test
void shouldReturnDetailedRecipeWhenProvidingRecipeRequest() {

String recipeRequest = "Couscous with seven vegetables";

String expected = "Detailed couscous with seven vegetables recipe";

given(aiClientPort.generateRecipe("Can you provide a recipe for " + recipeRequest + "?")).willReturn(expected);

String result = mccAssistant.getRecipes(recipeRequest);

Assertions.assertThat(result).isEqualTo(expected);


}

@Test
void shouldReturnDishImageWhenProvidingDishImageRequest() {
String dishImageRequest = "Couscous with seven vegetables";

String expectedUrl = "https://openai.com/image/generated-dish.png";

given(aiClientPort.generateDishImage("Generate an image of a Moroccan dish called " + dishImageRequest)).willReturn(expectedUrl);

String result = mccAssistant.getDishImage(dishImageRequest);

Assertions.assertThat(result).isEqualTo(expectedUrl);
}
}
5 changes: 5 additions & 0 deletions src/test/resources/application.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
spring:
ai:
openai:
api-key: "dummy-key"
base-url: "http://localhost:2445"

0 comments on commit a309c63

Please sign in to comment.