Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : convert from get to post endpoint #27

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package com.learning.ai.llmragwithspringai.controller;

import com.learning.ai.llmragwithspringai.model.request.AIChatRequest;
import com.learning.ai.llmragwithspringai.model.response.AIChatResponse;
import com.learning.ai.llmragwithspringai.service.AIChatService;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import jakarta.validation.constraints.Size;
import java.util.Map;
import jakarta.validation.Valid;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
Expand All @@ -22,14 +21,9 @@ public AiController(AIChatService aiChatService) {
this.aiChatService = aiChatService;
}

@GetMapping("/chat")
Map<String, String> ragService(
@RequestParam
@NotBlank(message = "Query cannot be empty")
@Size(max = 255, message = "Query exceeds maximum length")
@Pattern(regexp = "^[a-zA-Z0-9 ]*$", message = "Invalid characters in query")
String question) {
String chatResponse = aiChatService.chat(question);
return Map.of("response", chatResponse);
@PostMapping("/chat")
AIChatResponse ragService(@Valid @RequestBody AIChatRequest aiChatRequest) {
String chatResponse = aiChatService.chat(aiChatRequest.question());
return new AIChatResponse(chatResponse);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.learning.ai.llmragwithspringai.model.request;

import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import jakarta.validation.constraints.Size;
import java.io.Serializable;

public record AIChatRequest(
@NotBlank(message = "Query cannot be empty")
@Size(max = 800, message = "Query exceeds maximum length")
@Pattern(regexp = "^[a-zA-Z0-9 ?]*$", message = "Invalid characters in query")
String question)
implements Serializable {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.learning.ai.llmragwithspringai.model.response;

import java.io.Serializable;

public record AIChatResponse(String response) implements Serializable {}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ public class AIChatService {

private static final String template =
"""
You're assisting with questions about cricket

You're assisting with questions about cricketers
Cricket is a bat-and-ball game that is played between two teams of eleven players on a field at the centre of which is a 22-yard (20-metre) pitch with a wicket at each end,
each comprising two bails balanced on three stumps.
Two players from the batting team (the striker and nonstriker) stand in front of either wicket,
with one player from the fielding team (the bowler) bowling the ball towards the striker's wicket from the opposite end of the pitch.
The striker's goal is to hit the bowled ball and then switch places with the nonstriker,
with the batting team scoring one run for each exchange.
The striker's goal is to hit the bowled ball and then switch places with the nonstriker, with the batting team scoring one run for each exchange.
Runs are also scored when the ball reaches or crosses the boundary of the field or when the ball is bowled illegally.

Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
Expand All @@ -44,16 +43,16 @@ public AIChatService(ChatClient aiClient, VectorStore vectorStore) {
this.vectorStore = vectorStore;
}

public String chat(String message) {
public String chat(String searchQuery) {
// Querying the VectorStore using natural language looking for the information about info asked.
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(message);
List<Document> listOfSimilarDocuments = this.vectorStore.similaritySearch(searchQuery);
String documents = listOfSimilarDocuments.stream()
.map(Document::getContent)
.collect(Collectors.joining(System.lineSeparator()));
// Constructing the systemMessage to indicate the AI model to use the passed information
// to answer the question.
Message systemMessage = new SystemPromptTemplate(template).createMessage(Map.of("documents", documents));
UserMessage userMessage = new UserMessage(message);
UserMessage userMessage = new UserMessage(searchQuery);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse aiResponse = aiClient.call(prompt);
return aiResponse.getResult().getOutput().getContent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
import static org.hamcrest.Matchers.*;

import com.learning.ai.llmragwithspringai.config.AbstractIntegrationTest;
import com.learning.ai.llmragwithspringai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.http.MediaType;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class LlmRagWithSpringAiApplicationIntTest extends AbstractIntegrationTest {
Expand All @@ -23,72 +27,86 @@ void setUp() {

@Test
void testRag() {
given().param("question", "What trophies did Rohit won")
given().contentType(MediaType.APPLICATION_JSON_VALUE)
.body(new AIChatRequest("What trophies did Rohit won?"))
.when()
.get("/api/ai/chat")
.post("/api/ai/chat")
.then()
.statusCode(200)
.statusCode(HttpStatus.SC_OK)
.body("response", containsString("2007 T20 World Cup"))
.body("response", containsString("2013 ICC Champions Trophy"));
.body("response", containsString("2013 ICC Champions Trophy"))
.log()
.all();
}

@Test
void testRag2() {
given().param("question", "Who is successful IPL captain")
given().contentType(MediaType.APPLICATION_JSON_VALUE)
.body(new AIChatRequest("Who is successful IPL captain?"))
.when()
.get("/api/ai/chat")
.post("/api/ai/chat")
.then()
.statusCode(200)
.body("response", containsString("Rohit Sharma"));
.body("response", containsString("Rohit Sharma"))
.log()
.all();
}

@Test
void testEmptyQuery() {
given().param("question", "")
given().contentType(ContentType.JSON)
.body(new AIChatRequest(""))
.when()
.get("/api/ai/chat")
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
.body("violations", hasSize(1))
.body("violations[0].field", is("ragService.question"))
.body("violations[0].message", containsString("Query cannot be empty"));
.body("violations[0].field", is("question"))
.body("violations[0].message", containsString("Query cannot be empty"))
.log()
.all();
}

@Test
void testLongQueryString() {
String longQuery = "a".repeat(1000); // Example of a very long query string
given().param("question", longQuery)
given().contentType(ContentType.JSON)
.body(new AIChatRequest(longQuery))
.when()
.get("/api/ai/chat")
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
.body("violations", hasSize(1))
.body("violations[0].field", is("ragService.question"))
.body("violations[0].message", containsString("Query exceeds maximum length"));
.body("violations[0].field", is("question"))
.body("violations[0].message", containsString("Query exceeds maximum length"))
.log()
.all();
}

@Test
void testSpecialCharactersInQuery() {
given().param("question", "@#$%^&*()")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("@#$%^&*()"))
.when()
.get("/api/ai/chat")
.post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
.body("detail", is("Invalid request content."))
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
.body("violations", hasSize(1))
.body("violations[0].field", is("ragService.question"))
.body("violations[0].field", is("question"))
.body("violations[0].message", containsString("Invalid characters in query"))
.log();
.log()
.all();
}
}