Skip to content

Commit

Permalink
feat : polish spring chat model
Browse files Browse the repository at this point in the history
  • Loading branch information
rajadilipkolli committed Apr 7, 2024
1 parent ef10569 commit b87de52
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 25 deletions.
24 changes: 24 additions & 0 deletions chatmodel-springai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<properties>
<java.version>17</java.version>
<spring-ai.version>0.8.1</spring-ai.version>
<spotless.version>2.43.0</spotless.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -72,6 +73,29 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>com.diffplug.spotless</groupId>
<artifactId>spotless-maven-plugin</artifactId>
<version>${spotless.version}</version>
<configuration>
<java>
<palantirJavaFormat>
<version>2.40.0</version>
</palantirJavaFormat>
<importOrder />
<removeUnusedImports />
<formatAnnotations />
</java>
</configuration>
<executions>
<execution>
<phase>compile</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
package com.example.ai.controller;

import com.example.ai.model.request.AIChatRequest;
import com.example.ai.model.response.AIChatResponse;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
Expand All @@ -20,32 +23,42 @@ public class ChatController {

private final ChatClient chatClient;

ChatController(ChatClient chatClient) {
private final EmbeddingClient embeddingClient;

ChatController(ChatClient chatClient, EmbeddingClient embeddingClient) {
this.chatClient = chatClient;
this.embeddingClient = embeddingClient;
}

@GetMapping("/chat")
Map<String, String> chat(@RequestParam String question) {
var response = chatClient.call(question);
return Map.of("question", question, "answer", response);
@PostMapping("/chat")
AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) {
var answer = chatClient.call(aiChatRequest.query());
return new AIChatResponse(answer);
}

@GetMapping("/chat-with-prompt")
AIChatResponse chatWithPrompt(@RequestParam String subject) {
@PostMapping("/chat-with-prompt")
AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) {
PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}");
Prompt prompt = promptTemplate.create(Map.of("subject", subject));
Prompt prompt = promptTemplate.create(Map.of("subject", aiChatRequest.query()));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
Generation generation = response.getResult();
String answer = (generation != null) ? generation.getOutput().getContent() : "";
return new AIChatResponse(answer);
}

@GetMapping("/chat-with-system-prompt")
AIChatResponse chatWithSystemPrompt(@RequestParam String subject) {
@PostMapping("/chat-with-system-prompt")
AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) {
SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot");
UserMessage userMessage = new UserMessage("Tell me a joke about " + subject);
UserMessage userMessage = new UserMessage("Tell me a joke about " + aiChatRequest.query());
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
return new AIChatResponse(answer);
}

@PostMapping("/emebedding-client-conversion")
AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) {
List<Double> embed = embeddingClient.embed(aiChatRequest.query());
return new AIChatResponse(embed.toString());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.example.ai.model.request;

public record AIChatRequest(String query) {}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ spring.ai.openai.chat.options.model=gpt-3.5-turbo
spring.ai.openai.chat.options.temperature=0.2
spring.ai.openai.chat.options.responseFormat=json_object

spring.ai.openai.embedding.enabled=false
spring.ai.openai.embedding.enabled=true

##logging
logging.level.org.apache.hc.client5.http=INFO
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.containsStringIgnoringCase;

import com.example.ai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
Expand All @@ -25,30 +26,32 @@ void setUp() {

@Test
void testChat() {
given().param("question", "Hello?")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Hello?"))
.when()
.get("/api/ai/chat")
.post("/api/ai/chat")
.then()
.statusCode(200)
.body("question", containsStringIgnoringCase("Hello?"))
.body("answer", containsString("Hello!"));
}

@Test
void chatWithPrompt() {
given().param("subject", "java")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("java"))
.when()
.get("/api/ai/chat-with-prompt")
.post("/api/ai/chat-with-prompt")
.then()
.statusCode(200)
.body("answer", containsString("Java"));
}

@Test
void chatWithSystemPrompt() {
given().param("subject", "cricket")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("cricket"))
.when()
.get("/api/ai/chat-with-system-prompt")
.post("/api/ai/chat-with-system-prompt")
.then()
.statusCode(200)
.body("answer", containsString("cricket"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
Expand Down Expand Up @@ -64,6 +65,7 @@ public String chat(String query) {
LOGGER.info("Calling ai with prompt :{}", prompt);
ChatResponse aiResponse = aiClient.call(prompt);
LOGGER.info("Response received from call :{}", aiResponse);
return aiResponse.getResult().getOutput().getContent();
Generation generation = aiResponse.getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.stream.Collectors;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
Expand Down Expand Up @@ -55,6 +56,7 @@ public String chat(String searchQuery) {
UserMessage userMessage = new UserMessage(searchQuery);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse aiResponse = aiClient.call(prompt);
return aiResponse.getResult().getOutput().getContent();
Generation generation = aiResponse.getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
}

0 comments on commit b87de52

Please sign in to comment.