-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: added mock server for simulating open ai calls for integration …
…test
- Loading branch information
Showing
7 changed files
with
227 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,24 @@ | ||
package dev.nano.mcc; | ||
|
||
import lombok.RequiredArgsConstructor; | ||
import dev.nano.mcc.client.AIClientPort; | ||
import lombok.extern.slf4j.Slf4j; | ||
import org.springframework.ai.chat.messages.Message; | ||
import org.springframework.ai.chat.messages.SystemMessage; | ||
import org.springframework.ai.chat.messages.UserMessage; | ||
import org.springframework.ai.chat.prompt.Prompt; | ||
import org.springframework.ai.chat.prompt.SystemPromptTemplate; | ||
import org.springframework.ai.image.ImagePrompt; | ||
import org.springframework.beans.factory.annotation.Value; | ||
import org.springframework.core.io.Resource; | ||
import org.springframework.stereotype.Service; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
@Service | ||
@RequiredArgsConstructor | ||
@Slf4j | ||
public class MCCAssistant { | ||
|
||
@Value("classpath:/prompt/system-prompt.st") | ||
private Resource systemPrompt; | ||
private final AIClientPort aiClientPort; | ||
|
||
private final OpenAIClient openAIClient; | ||
public MCCAssistant(AIClientPort aiClientPort) { | ||
this.aiClientPort = aiClientPort; | ||
} | ||
|
||
public String getRecipes(String dishName) { | ||
|
||
SystemMessage systemMessage = new SystemMessage(this.systemPrompt); | ||
UserMessage userMessage = new UserMessage("Can you provide a recipe for + " + dishName + "?"); | ||
|
||
Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); | ||
|
||
return openAIClient.getOpenAiChatClient().call(prompt).getResult().getOutput().getContent(); | ||
return aiClientPort.generateRecipe("Can you provide a recipe for " + dishName + "?"); | ||
} | ||
|
||
public String getDishImage(String dishName) { | ||
ImagePrompt imagePrompt = new ImagePrompt("Generate an image of a Moroccan dish called " + dishName); | ||
return openAIClient.getOpenAiImageClient().call(imagePrompt).getResult().getOutput().getUrl(); | ||
public String getDishImage(String dishImageRequest) { | ||
return aiClientPort.generateDishImage("Generate an image of a Moroccan dish called " + dishImageRequest); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,59 @@ | ||
package dev.nano.mcc; | ||
|
||
import org.springframework.ai.image.ImageOptionsBuilder; | ||
import dev.nano.mcc.client.AIClientPort; | ||
import org.springframework.ai.chat.messages.SystemMessage; | ||
import org.springframework.ai.chat.messages.UserMessage; | ||
import org.springframework.ai.chat.prompt.Prompt; | ||
import org.springframework.ai.image.ImagePrompt; | ||
import org.springframework.ai.openai.OpenAiChatClient; | ||
import org.springframework.ai.openai.OpenAiChatOptions; | ||
import org.springframework.ai.openai.OpenAiImageClient; | ||
import org.springframework.ai.openai.OpenAiImageOptions; | ||
import org.springframework.ai.openai.api.OpenAiApi; | ||
import org.springframework.ai.openai.api.OpenAiImageApi; | ||
import org.springframework.beans.factory.annotation.Value; | ||
import org.springframework.core.io.Resource; | ||
import org.springframework.retry.support.RetryTemplate; | ||
import org.springframework.stereotype.Component; | ||
import org.springframework.stereotype.Repository; | ||
import org.springframework.web.client.RestClient; | ||
|
||
@Component | ||
public final class OpenAIClient { | ||
import java.util.List; | ||
|
||
@Repository | ||
public class OpenAIClient implements AIClientPort { | ||
|
||
@Value("${spring.ai.openai.api-key}") | ||
String apiKey; | ||
|
||
@Value("${spring.ai.openai.base-url}") | ||
String baseUrl; | ||
|
||
private final Resource systemPrompt; | ||
|
||
|
||
public OpenAIClient(@Value("classpath:/prompt/system-prompt.st") Resource systemPrompt) { | ||
this.systemPrompt = systemPrompt; | ||
} | ||
|
||
public OpenAiChatClient getOpenAiChatClient() { | ||
OpenAiApi openAiApi = new OpenAiApi(apiKey); | ||
var options = new OpenAiChatOptions.Builder() | ||
public String generateRecipe(String instructionRecipe) { | ||
Prompt prompt = new Prompt(List.of(new SystemMessage(this.systemPrompt), new UserMessage(instructionRecipe))); | ||
OpenAiApi openAiApi = new OpenAiApi(baseUrl, apiKey); | ||
OpenAiChatOptions options = new OpenAiChatOptions.Builder() | ||
.withModel("gpt-4") | ||
.build(); | ||
return new OpenAiChatClient(openAiApi, options); | ||
return new OpenAiChatClient(openAiApi, options).call(prompt).getResult().getOutput().getContent(); | ||
} | ||
|
||
public OpenAiImageClient getOpenAiImageClient() { | ||
OpenAiImageApi openAiApi = new OpenAiImageApi(apiKey); | ||
@Override | ||
public String generateDishImage(String instructionDishImage) { | ||
OpenAiImageApi openAiApi = new OpenAiImageApi(baseUrl, apiKey, RestClient.builder()); | ||
var options = OpenAiImageOptions.builder() | ||
.withQuality("hd") | ||
.withHeight(1024).withWidth(1024) | ||
.withResponseFormat("url") | ||
.withModel("dall-e-3") | ||
.build(); | ||
return new OpenAiImageClient(openAiApi, options, RetryTemplate.builder().build()); | ||
OpenAiImageClient openAiImageClient = new OpenAiImageClient(openAiApi, options, RetryTemplate.builder().build()); | ||
return openAiImageClient.call(new ImagePrompt(instructionDishImage)).getResult().getOutput().getUrl(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
package dev.nano.mcc.client; | ||
|
||
public interface AIClientPort { | ||
|
||
String generateRecipe(String instructionRecipe); | ||
String generateDishImage(String instructionDishImage); | ||
} |
94 changes: 94 additions & 0 deletions
94
src/test/java/dev/nano/mcc/MCCApplicationIntegrationTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package dev.nano.mcc; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import io.netty.handler.codec.http.HttpHeaderNames; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.Test; | ||
import org.mockserver.integration.ClientAndServer; | ||
import org.mockserver.model.MediaType; | ||
import org.springframework.ai.openai.api.OpenAiApi; | ||
import org.springframework.ai.openai.api.OpenAiImageApi; | ||
import org.springframework.beans.factory.annotation.Autowired; | ||
import org.springframework.boot.test.context.SpringBootTest; | ||
|
||
import java.util.List; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.mockserver.integration.ClientAndServer.startClientAndServer; | ||
import static org.mockserver.model.HttpRequest.request; | ||
import static org.mockserver.model.HttpResponse.response; | ||
|
||
@SpringBootTest(classes = MCCApplication.class) | ||
class MCCApplicationIntegrationTests { | ||
|
||
@Autowired | ||
MCCAssistant mccAssistant; | ||
|
||
private final ObjectMapper objectMapper = new ObjectMapper(); | ||
|
||
private static final ClientAndServer mockServer = startClientAndServer(2445); | ||
|
||
private static final String TEST_DISH_NAME = "couscous with seven vegetables"; | ||
|
||
@BeforeEach | ||
void setup() { | ||
mockServer.reset(); | ||
} | ||
|
||
@Test | ||
void generateDishImage() throws JsonProcessingException { | ||
|
||
OpenAiImageApi.OpenAiImageResponse mockedResponse = new OpenAiImageApi.OpenAiImageResponse( | ||
20L, | ||
List.of(new OpenAiImageApi.Data( | ||
"https://openai.com/image/dish_generated_url.png", | ||
"base64_encoding_value", | ||
"revised_prompt_value" | ||
)) | ||
); | ||
|
||
mockOpenAiGenerativeResponses("/v1/images/generations", objectMapper.writeValueAsString(mockedResponse)); | ||
String imageUrl = mccAssistant.getDishImage(TEST_DISH_NAME); | ||
assertThat(imageUrl).isNotNull().isEqualTo("https://openai.com/image/dish_generated_url.png"); | ||
System.out.println("image url: " + imageUrl); | ||
} | ||
|
||
|
||
|
||
@Test | ||
void generateRecipes() throws JsonProcessingException { | ||
|
||
OpenAiApi.ChatCompletion mockedResponse = new OpenAiApi.ChatCompletion( | ||
"id_value", | ||
List.of(new OpenAiApi.ChatCompletion | ||
.Choice( | ||
OpenAiApi.ChatCompletionFinishReason.STOP, | ||
1, | ||
new OpenAiApi.ChatCompletionMessage("Detailed dish of a couscous recipe with seven vegetables", null), | ||
null)), | ||
10L, | ||
"gpt-4", | ||
"systemFingerPrint", | ||
null, | ||
new OpenAiApi.Usage(1, 2, 3) | ||
); | ||
mockOpenAiGenerativeResponses("/v1/chat/completions", objectMapper.writeValueAsString(mockedResponse)); | ||
|
||
String recipe = mccAssistant.getRecipes(TEST_DISH_NAME); | ||
assertThat(recipe) | ||
.isNotNull() | ||
.isEqualTo("Detailed dish of a couscous recipe with seven vegetables"); | ||
} | ||
|
||
private void mockOpenAiGenerativeResponses(String path, String objectMapper) throws JsonProcessingException { | ||
mockServer.when(request().withMethod("POST").withPath(path)) | ||
.respond( | ||
response() | ||
.withStatusCode(200) | ||
.withHeader(HttpHeaderNames.CONTENT_TYPE.toString(), MediaType.APPLICATION_JSON.toString()) | ||
.withBody(objectMapper)); | ||
} | ||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package dev.nano.mcc; | ||
|
||
import dev.nano.mcc.client.AIClientPort; | ||
import org.assertj.core.api.Assertions; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.ExtendWith; | ||
import org.mockito.InjectMocks; | ||
import org.mockito.Mock; | ||
import org.mockito.junit.jupiter.MockitoExtension; | ||
|
||
import static org.mockito.BDDMockito.given; | ||
|
||
@ExtendWith(MockitoExtension.class ) | ||
class MCCAssistantTest { | ||
@InjectMocks | ||
MCCAssistant mccAssistant; | ||
|
||
@Mock | ||
AIClientPort aiClientPort; | ||
|
||
|
||
@Test | ||
void shouldReturnDetailedRecipeWhenProvidingRecipeRequest() { | ||
|
||
String recipeRequest = "Couscous with seven vegetables"; | ||
|
||
String expected = "Detailed couscous with seven vegetables recipe"; | ||
|
||
given(aiClientPort.generateRecipe("Can you provide a recipe for " + recipeRequest + "?")).willReturn(expected); | ||
|
||
String result = mccAssistant.getRecipes(recipeRequest); | ||
|
||
Assertions.assertThat(result).isEqualTo(expected); | ||
|
||
|
||
} | ||
|
||
@Test | ||
void shouldReturnDishImageWhenProvidingDishImageRequest() { | ||
String dishImageRequest = "Couscous with seven vegetables"; | ||
|
||
String expectedUrl = "https://openai.com/image/generated-dish.png"; | ||
|
||
given(aiClientPort.generateDishImage("Generate an image of a Moroccan dish called " + dishImageRequest)).willReturn(expectedUrl); | ||
|
||
String result = mccAssistant.getDishImage(dishImageRequest); | ||
|
||
Assertions.assertThat(result).isEqualTo(expectedUrl); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
spring: | ||
ai: | ||
openai: | ||
api-key: "dummy-key" | ||
base-url: "http://localhost:2445" |