diff --git a/.github/workflows/llm-rag-with-langchain4j-spring-boot.yml b/.github/workflows/rag-langchain4j-AllMiniLmL6V2-llm.yml
similarity index 78%
rename from .github/workflows/llm-rag-with-langchain4j-spring-boot.yml
rename to .github/workflows/rag-langchain4j-AllMiniLmL6V2-llm.yml
index 838889f..13c29d5 100644
--- a/.github/workflows/llm-rag-with-langchain4j-spring-boot.yml
+++ b/.github/workflows/rag-langchain4j-AllMiniLmL6V2-llm.yml
@@ -1,13 +1,13 @@
-name: llm-rag-with-langchain4j-spring-boot CI Build
+name: rag-langchain4j-AllMiniLmL6V2-llm CI Build
on:
push:
paths:
- - "llm-rag-with-langchain4j-spring-boot/**"
+ - "rag-langchain4j-AllMiniLmL6V2-llm/**"
branches: [main]
pull_request:
paths:
- - "llm-rag-with-langchain4j-spring-boot/**"
+ - "rag-langchain4j-AllMiniLmL6V2-llm/**"
types:
- opened
- synchronize
@@ -19,7 +19,7 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
- working-directory: llm-rag-with-langchain4j-spring-boot
+ working-directory: rag-langchain4j-AllMiniLmL6V2-llm
strategy:
matrix:
distribution: [ 'temurin' ]
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/controller/CustomerSupportController.java b/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/controller/CustomerSupportController.java
deleted file mode 100644
index d80fa77..0000000
--- a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/controller/CustomerSupportController.java
+++ /dev/null
@@ -1,29 +0,0 @@
-package com.learning.ai.controller;
-
-import com.learning.ai.config.AICustomerSupportAgent;
-import com.learning.ai.domain.AICustomerSupportResponse;
-import org.springframework.web.bind.annotation.GetMapping;
-import org.springframework.web.bind.annotation.RequestMapping;
-import org.springframework.web.bind.annotation.RequestParam;
-import org.springframework.web.bind.annotation.RestController;
-
-@RestController
-@RequestMapping("/api")
-public class CustomerSupportController {
-
- private final AICustomerSupportAgent aiCustomerSupportAgent;
-
- public CustomerSupportController(AICustomerSupportAgent aiCustomerSupportAgent) {
- this.aiCustomerSupportAgent = aiCustomerSupportAgent;
- }
-
- @GetMapping("/chat")
- public AICustomerSupportResponse customerSupportChat(
- @RequestParam(
- value = "message",
- defaultValue =
- "what should I know about the transition to consumer direct care network washington?")
- String message) {
- return aiCustomerSupportAgent.chat(message);
- }
-}
diff --git a/llm-rag-with-langchain4j-spring-boot/.gitignore b/rag-langchain4j-AllMiniLmL6V2-llm/.gitignore
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/.gitignore
rename to rag-langchain4j-AllMiniLmL6V2-llm/.gitignore
diff --git a/llm-rag-with-langchain4j-spring-boot/.mvn/wrapper/maven-wrapper.jar b/rag-langchain4j-AllMiniLmL6V2-llm/.mvn/wrapper/maven-wrapper.jar
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/.mvn/wrapper/maven-wrapper.jar
rename to rag-langchain4j-AllMiniLmL6V2-llm/.mvn/wrapper/maven-wrapper.jar
diff --git a/llm-rag-with-langchain4j-spring-boot/.mvn/wrapper/maven-wrapper.properties b/rag-langchain4j-AllMiniLmL6V2-llm/.mvn/wrapper/maven-wrapper.properties
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/.mvn/wrapper/maven-wrapper.properties
rename to rag-langchain4j-AllMiniLmL6V2-llm/.mvn/wrapper/maven-wrapper.properties
diff --git a/llm-rag-with-langchain4j-spring-boot/README.md b/rag-langchain4j-AllMiniLmL6V2-llm/README.md
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/README.md
rename to rag-langchain4j-AllMiniLmL6V2-llm/README.md
diff --git a/llm-rag-with-langchain4j-spring-boot/docker/docker-compose.yml b/rag-langchain4j-AllMiniLmL6V2-llm/docker/docker-compose.yml
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/docker/docker-compose.yml
rename to rag-langchain4j-AllMiniLmL6V2-llm/docker/docker-compose.yml
diff --git a/llm-rag-with-langchain4j-spring-boot/docker/docker_pgadmin_servers.json b/rag-langchain4j-AllMiniLmL6V2-llm/docker/docker_pgadmin_servers.json
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/docker/docker_pgadmin_servers.json
rename to rag-langchain4j-AllMiniLmL6V2-llm/docker/docker_pgadmin_servers.json
diff --git a/llm-rag-with-langchain4j-spring-boot/mvnw b/rag-langchain4j-AllMiniLmL6V2-llm/mvnw
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/mvnw
rename to rag-langchain4j-AllMiniLmL6V2-llm/mvnw
diff --git a/llm-rag-with-langchain4j-spring-boot/mvnw.cmd b/rag-langchain4j-AllMiniLmL6V2-llm/mvnw.cmd
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/mvnw.cmd
rename to rag-langchain4j-AllMiniLmL6V2-llm/mvnw.cmd
diff --git a/llm-rag-with-langchain4j-spring-boot/pom.xml b/rag-langchain4j-AllMiniLmL6V2-llm/pom.xml
similarity index 97%
rename from llm-rag-with-langchain4j-spring-boot/pom.xml
rename to rag-langchain4j-AllMiniLmL6V2-llm/pom.xml
index d80e13c..aaf29d9 100644
--- a/llm-rag-with-langchain4j-spring-boot/pom.xml
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/pom.xml
@@ -9,9 +9,9 @@
org.example.ai
- llm-rag-with-langchain4j-spring-boot
+ rag-langchain4j-AllMiniLmL6V2-llm
0.0.1-SNAPSHOT
- llm-rag-with-langchain4j-spring-boot
+ rag-langchain4j-AllMiniLmL6V2-llm
Demo project for Spring Boot
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/LLMRagWithSpringBoot.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/LLMRagWithSpringBoot.java
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/LLMRagWithSpringBoot.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/LLMRagWithSpringBoot.java
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/AIConfig.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AIConfig.java
similarity index 91%
rename from llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/AIConfig.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AIConfig.java
index 9435a47..d6fa248 100644
--- a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/AIConfig.java
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AIConfig.java
@@ -2,7 +2,6 @@
import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument;
-import com.zaxxer.hikari.HikariDataSource;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
@@ -22,7 +21,7 @@
import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
import java.io.IOException;
import java.net.URI;
-import javax.sql.DataSource;
+import org.springframework.boot.autoconfigure.jdbc.JdbcConnectionDetails;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
@@ -62,13 +61,13 @@ EmbeddingModel embeddingModel() {
@Bean
EmbeddingStore embeddingStore(
- EmbeddingModel embeddingModel, ResourceLoader resourceLoader, DataSource dataSource) throws IOException {
+ EmbeddingModel embeddingModel, ResourceLoader resourceLoader, JdbcConnectionDetails jdbcConnectionDetails)
+ throws IOException {
// Normally, you would already have your embedding store filled with your data.
// However, for the purpose of this demonstration, we will:
- HikariDataSource hikariDataSource = (HikariDataSource) dataSource;
- String jdbcUrl = hikariDataSource.getJdbcUrl();
+ String jdbcUrl = jdbcConnectionDetails.getJdbcUrl();
URI uri = URI.create(jdbcUrl.substring(5));
String host = uri.getHost();
int dbPort = uri.getPort();
@@ -78,8 +77,8 @@ EmbeddingStore embeddingStore(
EmbeddingStore embeddingStore = PgVectorEmbeddingStore.builder()
.host(host)
.port(dbPort != -1 ? dbPort : 5432)
- .user(hikariDataSource.getUsername())
- .password(hikariDataSource.getPassword())
+ .user(jdbcConnectionDetails.getUsername())
+ .password(jdbcConnectionDetails.getPassword())
.database(path.substring(1))
.table("ai_vector_store")
.dimension(384)
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java
similarity index 93%
rename from llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java
index 9349096..fe480e2 100644
--- a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/AICustomerSupportAgent.java
@@ -1,6 +1,6 @@
package com.learning.ai.config;
-import com.learning.ai.domain.AICustomerSupportResponse;
+import com.learning.ai.domain.response.AICustomerSupportResponse;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.V;
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/ChatTools.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/ChatTools.java
similarity index 92%
rename from llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/ChatTools.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/ChatTools.java
index f44e5a4..bf500e2 100644
--- a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/ChatTools.java
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/ChatTools.java
@@ -12,7 +12,7 @@ public class ChatTools {
/**
* This tool is available to {@link AICustomerSupportAgent}
*/
- @Tool
+ @Tool("chatAssistantTools")
String currentTime() {
log.info("Inside ChatTools");
return LocalTime.now().toString();
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/SwaggerConfig.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/SwaggerConfig.java
similarity index 69%
rename from llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/SwaggerConfig.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/SwaggerConfig.java
index bf616bd..93eed82 100644
--- a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/config/SwaggerConfig.java
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/config/SwaggerConfig.java
@@ -6,5 +6,7 @@
import org.springframework.context.annotation.Configuration;
@Configuration(proxyBeanMethods = false)
-@OpenAPIDefinition(info = @Info(title = "llm-rag-with-langchain4j", version = "v1.0.0"), servers = @Server(url = "/"))
+@OpenAPIDefinition(
+ info = @Info(title = "rag-langchain4j-AllMiniLmL6V2-llm", version = "v1.0.0"),
+ servers = @Server(url = "/"))
public class SwaggerConfig {}
diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/controller/CustomerSupportController.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/controller/CustomerSupportController.java
new file mode 100644
index 0000000..6860ad5
--- /dev/null
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/controller/CustomerSupportController.java
@@ -0,0 +1,28 @@
+package com.learning.ai.controller;
+
+import com.learning.ai.config.AICustomerSupportAgent;
+import com.learning.ai.domain.request.AIChatRequest;
+import com.learning.ai.domain.response.AICustomerSupportResponse;
+import jakarta.validation.Valid;
+import org.springframework.validation.annotation.Validated;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+
+@RestController
+@RequestMapping("/api/ai")
+@Validated
+public class CustomerSupportController {
+
+ private final AICustomerSupportAgent aiCustomerSupportAgent;
+
+ public CustomerSupportController(AICustomerSupportAgent aiCustomerSupportAgent) {
+ this.aiCustomerSupportAgent = aiCustomerSupportAgent;
+ }
+
+ @PostMapping("/chat")
+ public AICustomerSupportResponse customerSupportChat(@RequestBody @Valid AIChatRequest aiChatRequest) {
+ return aiCustomerSupportAgent.chat(aiChatRequest.question());
+ }
+}
diff --git a/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/request/AIChatRequest.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/request/AIChatRequest.java
new file mode 100644
index 0000000..c4e5d34
--- /dev/null
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/request/AIChatRequest.java
@@ -0,0 +1,13 @@
+package com.learning.ai.domain.request;
+
+import jakarta.validation.constraints.NotBlank;
+import jakarta.validation.constraints.Pattern;
+import jakarta.validation.constraints.Size;
+import java.io.Serializable;
+
+public record AIChatRequest(
+ @NotBlank(message = "Query cannot be empty")
+ @Size(max = 800, message = "Query exceeds maximum length")
+ @Pattern(regexp = "^[a-zA-Z0-9 ?]*$", message = "Invalid characters in query")
+ String question)
+ implements Serializable {}
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/domain/AICustomerSupportResponse.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/response/AICustomerSupportResponse.java
similarity index 59%
rename from llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/domain/AICustomerSupportResponse.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/response/AICustomerSupportResponse.java
index 5f40d87..033cc28 100644
--- a/llm-rag-with-langchain4j-spring-boot/src/main/java/com/learning/ai/domain/AICustomerSupportResponse.java
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/domain/response/AICustomerSupportResponse.java
@@ -1,3 +1,3 @@
-package com.learning.ai.domain;
+package com.learning.ai.domain.response;
public record AICustomerSupportResponse(String response) {}
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/resources/application.properties b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/application.properties
similarity index 91%
rename from llm-rag-with-langchain4j-spring-boot/src/main/resources/application.properties
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/application.properties
index 1ae5910..5b5987e 100644
--- a/llm-rag-with-langchain4j-spring-boot/src/main/resources/application.properties
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/application.properties
@@ -1,3 +1,5 @@
+spring.application.name=rag-langchain4j-AllMiniLmL6V2-llm
+
langchain4j.open-ai.chat-model.api-key=demo
langchain4j.open-ai.chat-model.model-name=gpt-3.5-turbo
langchain4j.open-ai.chat-model.temperature=0.7
diff --git a/llm-rag-with-langchain4j-spring-boot/src/main/resources/medicaid-wa-faqs.pdf b/rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/medicaid-wa-faqs.pdf
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/src/main/resources/medicaid-wa-faqs.pdf
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/main/resources/medicaid-wa-faqs.pdf
diff --git a/llm-rag-with-langchain4j-spring-boot/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java
similarity index 61%
rename from llm-rag-with-langchain4j-spring-boot/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java
index 1ad1761..0e4ecc2 100644
--- a/llm-rag-with-langchain4j-spring-boot/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java
+++ b/rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/LLMRagWithSpringBootTest.java
@@ -1,10 +1,11 @@
package com.learning.ai;
import static io.restassured.RestAssured.given;
-import static io.restassured.RestAssured.when;
import static org.hamcrest.Matchers.notNullValue;
+import com.learning.ai.domain.request.AIChatRequest;
import io.restassured.RestAssured;
+import io.restassured.http.ContentType;
import io.restassured.http.Method;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
@@ -26,15 +27,22 @@ public void setUp() {
}
@Test
- void whenRequestGet_thenOK() {
- when().request(Method.GET, "/api/chat").then().statusCode(HttpStatus.SC_OK);
+ void whenRequestPost_thenOK() {
+ given().contentType(ContentType.JSON)
+ .body(new AIChatRequest(
+ "what should I know about the transition to consumer direct care network washington?"))
+ .when()
+ .request(Method.POST, "/api/ai/chat")
+ .then()
+ .statusCode(HttpStatus.SC_OK);
}
@Test
void whenRequestGetTime_thenOK() {
- given().param("message", "What is the time now?")
+ given().contentType(ContentType.JSON)
+ .body(new AIChatRequest("What is the time now?"))
.when()
- .request(Method.GET, "/api/chat")
+ .request(Method.POST, "/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK)
.body("response", notNullValue());
diff --git a/llm-rag-with-langchain4j-spring-boot/src/test/java/com/learning/ai/TestLLMRagWithSpringBoot.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/TestLLMRagWithSpringBoot.java
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/src/test/java/com/learning/ai/TestLLMRagWithSpringBoot.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/TestLLMRagWithSpringBoot.java
diff --git a/llm-rag-with-langchain4j-spring-boot/src/test/java/com/learning/ai/config/ContainersConfig.java b/rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/config/ContainersConfig.java
similarity index 100%
rename from llm-rag-with-langchain4j-spring-boot/src/test/java/com/learning/ai/config/ContainersConfig.java
rename to rag-langchain4j-AllMiniLmL6V2-llm/src/test/java/com/learning/ai/config/ContainersConfig.java
diff --git a/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/controller/AiController.java b/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/controller/AiController.java
index 0cd2830..1945559 100644
--- a/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/controller/AiController.java
+++ b/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/controller/AiController.java
@@ -1,14 +1,13 @@
package com.learning.ai.llmragwithspringai.controller;
+import com.learning.ai.llmragwithspringai.model.request.AIChatRequest;
+import com.learning.ai.llmragwithspringai.model.response.AIChatResponse;
import com.learning.ai.llmragwithspringai.service.AIChatService;
-import jakarta.validation.constraints.NotBlank;
-import jakarta.validation.constraints.Pattern;
-import jakarta.validation.constraints.Size;
-import java.util.Map;
+import jakarta.validation.Valid;
import org.springframework.validation.annotation.Validated;
-import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
-import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@RestController
@@ -22,14 +21,9 @@ public AiController(AIChatService aiChatService) {
this.aiChatService = aiChatService;
}
- @GetMapping("/chat")
- Map ragService(
- @RequestParam
- @NotBlank(message = "Query cannot be empty")
- @Size(max = 255, message = "Query exceeds maximum length")
- @Pattern(regexp = "^[a-zA-Z0-9 ]*$", message = "Invalid characters in query")
- String question) {
- String chatResponse = aiChatService.chat(question);
- return Map.of("response", chatResponse);
+ @PostMapping("/chat")
+ AIChatResponse ragService(@Valid @RequestBody AIChatRequest aiChatRequest) {
+ String chatResponse = aiChatService.chat(aiChatRequest.question());
+ return new AIChatResponse(chatResponse);
}
}
diff --git a/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/model/request/AIChatRequest.java b/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/model/request/AIChatRequest.java
new file mode 100644
index 0000000..ffc9efe
--- /dev/null
+++ b/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/model/request/AIChatRequest.java
@@ -0,0 +1,13 @@
+package com.learning.ai.llmragwithspringai.model.request;
+
+import jakarta.validation.constraints.NotBlank;
+import jakarta.validation.constraints.Pattern;
+import jakarta.validation.constraints.Size;
+import java.io.Serializable;
+
+public record AIChatRequest(
+ @NotBlank(message = "Query cannot be empty")
+ @Size(max = 800, message = "Query exceeds maximum length")
+ @Pattern(regexp = "^[a-zA-Z0-9 ?]*$", message = "Invalid characters in query")
+ String question)
+ implements Serializable {}
diff --git a/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/model/response/AIChatResponse.java b/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/model/response/AIChatResponse.java
new file mode 100644
index 0000000..c298a0b
--- /dev/null
+++ b/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/model/response/AIChatResponse.java
@@ -0,0 +1,5 @@
+package com.learning.ai.llmragwithspringai.model.response;
+
+import java.io.Serializable;
+
+public record AIChatResponse(String response) implements Serializable {}
diff --git a/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java b/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java
index dc41a0f..66a76a5 100644
--- a/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java
+++ b/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java
@@ -18,14 +18,13 @@ public class AIChatService {
private static final String template =
"""
+ You're assisting with questions about cricket
- You're assisting with questions about cricketers
Cricket is a bat-and-ball game that is played between two teams of eleven players on a field at the centre of which is a 22-yard (20-metre) pitch with a wicket at each end,
each comprising two bails balanced on three stumps.
Two players from the batting team (the striker and nonstriker) stand in front of either wicket,
with one player from the fielding team (the bowler) bowling the ball towards the striker's wicket from the opposite end of the pitch.
- The striker's goal is to hit the bowled ball and then switch places with the nonstriker,
- with the batting team scoring one run for each exchange.
+ The striker's goal is to hit the bowled ball and then switch places with the nonstriker, with the batting team scoring one run for each exchange.
Runs are also scored when the ball reaches or crosses the boundary of the field or when the ball is bowled illegally.
Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
@@ -44,16 +43,16 @@ public AIChatService(ChatClient aiClient, VectorStore vectorStore) {
this.vectorStore = vectorStore;
}
- public String chat(String message) {
+ public String chat(String searchQuery) {
// Querying the VectorStore using natural language looking for the information about info asked.
- List listOfSimilarDocuments = this.vectorStore.similaritySearch(message);
+ List listOfSimilarDocuments = this.vectorStore.similaritySearch(searchQuery);
String documents = listOfSimilarDocuments.stream()
.map(Document::getContent)
.collect(Collectors.joining(System.lineSeparator()));
// Constructing the systemMessage to indicate the AI model to use the passed information
// to answer the question.
Message systemMessage = new SystemPromptTemplate(template).createMessage(Map.of("documents", documents));
- UserMessage userMessage = new UserMessage(message);
+ UserMessage userMessage = new UserMessage(searchQuery);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse aiResponse = aiClient.call(prompt);
return aiResponse.getResult().getOutput().getContent();
diff --git a/rag-springai-openai-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java b/rag-springai-openai-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java
index 55d1977..6308f98 100644
--- a/rag-springai-openai-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java
+++ b/rag-springai-openai-llm/src/test/java/com/learning/ai/llmragwithspringai/LlmRagWithSpringAiApplicationIntTest.java
@@ -4,11 +4,15 @@
import static org.hamcrest.Matchers.*;
import com.learning.ai.llmragwithspringai.config.AbstractIntegrationTest;
+import com.learning.ai.llmragwithspringai.model.request.AIChatRequest;
import io.restassured.RestAssured;
+import io.restassured.http.ContentType;
+import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.springframework.boot.test.web.server.LocalServerPort;
+import org.springframework.http.MediaType;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class LlmRagWithSpringAiApplicationIntTest extends AbstractIntegrationTest {
@@ -23,30 +27,37 @@ void setUp() {
@Test
void testRag() {
- given().param("question", "What trophies did Rohit won")
+ given().contentType(MediaType.APPLICATION_JSON_VALUE)
+ .body(new AIChatRequest("What trophies did Rohit won?"))
.when()
- .get("/api/ai/chat")
+ .post("/api/ai/chat")
.then()
- .statusCode(200)
+ .statusCode(HttpStatus.SC_OK)
.body("response", containsString("2007 T20 World Cup"))
- .body("response", containsString("2013 ICC Champions Trophy"));
+ .body("response", containsString("2013 ICC Champions Trophy"))
+ .log()
+ .all();
}
@Test
void testRag2() {
- given().param("question", "Who is successful IPL captain")
+ given().contentType(MediaType.APPLICATION_JSON_VALUE)
+ .body(new AIChatRequest("Who is successful IPL captain?"))
.when()
- .get("/api/ai/chat")
+ .post("/api/ai/chat")
.then()
.statusCode(200)
- .body("response", containsString("Rohit Sharma"));
+ .body("response", containsString("Rohit Sharma"))
+ .log()
+ .all();
}
@Test
void testEmptyQuery() {
- given().param("question", "")
+ given().contentType(ContentType.JSON)
+ .body(new AIChatRequest(""))
.when()
- .get("/api/ai/chat")
+ .post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
@@ -54,16 +65,19 @@ void testEmptyQuery() {
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
.body("violations", hasSize(1))
- .body("violations[0].field", is("ragService.question"))
- .body("violations[0].message", containsString("Query cannot be empty"));
+ .body("violations[0].field", is("question"))
+ .body("violations[0].message", containsString("Query cannot be empty"))
+ .log()
+ .all();
}
@Test
void testLongQueryString() {
String longQuery = "a".repeat(1000); // Example of a very long query string
- given().param("question", longQuery)
+ given().contentType(ContentType.JSON)
+ .body(new AIChatRequest(longQuery))
.when()
- .get("/api/ai/chat")
+ .post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
@@ -71,15 +85,18 @@ void testLongQueryString() {
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
.body("violations", hasSize(1))
- .body("violations[0].field", is("ragService.question"))
- .body("violations[0].message", containsString("Query exceeds maximum length"));
+ .body("violations[0].field", is("question"))
+ .body("violations[0].message", containsString("Query exceeds maximum length"))
+ .log()
+ .all();
}
@Test
void testSpecialCharactersInQuery() {
- given().param("question", "@#$%^&*()")
+ given().contentType(ContentType.JSON)
+ .body(new AIChatRequest("@#$%^&*()"))
.when()
- .get("/api/ai/chat")
+ .post("/api/ai/chat")
.then()
.statusCode(400)
.header("Content-Type", is("application/problem+json"))
@@ -87,8 +104,9 @@ void testSpecialCharactersInQuery() {
.body("instance", is("/api/ai/chat"))
.body("title", is("Constraint Violation"))
.body("violations", hasSize(1))
- .body("violations[0].field", is("ragService.question"))
+ .body("violations[0].field", is("question"))
.body("violations[0].message", containsString("Invalid characters in query"))
- .log();
+ .log()
+ .all();
}
}