-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat : implement simple RAG * update assertion
- Loading branch information
1 parent
f470727
commit a2d689d
Showing
5 changed files
with
179 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
package com.example.ai.service; | ||
|
||
import com.example.ai.model.response.AIChatResponse; | ||
import com.example.ai.model.response.ActorsFilms; | ||
import java.io.IOException; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.ai.chat.ChatClient; | ||
import org.springframework.ai.chat.ChatResponse; | ||
import org.springframework.ai.chat.Generation; | ||
import org.springframework.ai.chat.messages.SystemMessage; | ||
import org.springframework.ai.chat.messages.UserMessage; | ||
import org.springframework.ai.chat.prompt.Prompt; | ||
import org.springframework.ai.chat.prompt.PromptTemplate; | ||
import org.springframework.ai.document.Document; | ||
import org.springframework.ai.embedding.EmbeddingClient; | ||
import org.springframework.ai.parser.BeanOutputParser; | ||
import org.springframework.ai.vectorstore.SearchRequest; | ||
import org.springframework.ai.vectorstore.SimpleVectorStore; | ||
import org.springframework.beans.factory.annotation.Value; | ||
import org.springframework.core.io.Resource; | ||
import org.springframework.stereotype.Service; | ||
|
||
@Service | ||
public class ChatService { | ||
|
||
private static final Logger logger = LoggerFactory.getLogger(ChatService.class); | ||
|
||
@Value("classpath:/data/restaurants.json") | ||
private Resource restaurantsResource; | ||
|
||
@Value("classpath:/rag-prompt-template.st") | ||
private Resource ragPromptTemplate; | ||
|
||
private final EmbeddingClient embeddingClient; | ||
private final ChatClient chatClient; | ||
|
||
public ChatService(EmbeddingClient embeddingClient, ChatClient chatClient) { | ||
this.embeddingClient = embeddingClient; | ||
this.chatClient = chatClient; | ||
} | ||
|
||
public AIChatResponse chat(String query) { | ||
String answer = chatClient.call(query); | ||
return new AIChatResponse(answer); | ||
} | ||
|
||
public AIChatResponse chatWithPrompt(String query) { | ||
PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}"); | ||
Prompt prompt = promptTemplate.create(Map.of("subject", query)); | ||
ChatResponse response = chatClient.call(prompt); | ||
Generation generation = response.getResult(); | ||
String answer = (generation != null) ? generation.getOutput().getContent() : ""; | ||
return new AIChatResponse(answer); | ||
} | ||
|
||
public AIChatResponse chatWithSystemPrompt(String query) { | ||
SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot"); | ||
UserMessage userMessage = new UserMessage("Tell me a joke about " + query); | ||
Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); | ||
ChatResponse response = chatClient.call(prompt); | ||
String answer = response.getResult().getOutput().getContent(); | ||
return new AIChatResponse(answer); | ||
} | ||
|
||
public AIChatResponse getEmbeddings(String query) { | ||
List<Double> embed = embeddingClient.embed(query); | ||
return new AIChatResponse(embed.toString()); | ||
} | ||
|
||
public ActorsFilms generateAsBean(String actor) { | ||
BeanOutputParser<ActorsFilms> outputParser = new BeanOutputParser<>(ActorsFilms.class); | ||
|
||
String format = outputParser.getFormat(); | ||
String template = """ | ||
Generate the filmography for the actor {actor}. | ||
{format} | ||
"""; | ||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format)); | ||
Prompt prompt = new Prompt(promptTemplate.createMessage()); | ||
ChatResponse response = chatClient.call(prompt); | ||
Generation generation = response.getResult(); | ||
|
||
return outputParser.parse(generation.getOutput().getContent()); | ||
} | ||
|
||
public AIChatResponse ragGenerate(String query) throws IOException { | ||
|
||
// Step 1 - Load JSON document as Documents and save | ||
logger.info("Loading JSON as Documents and save"); | ||
SimpleVectorStore simpleVectorStore = new SimpleVectorStore(embeddingClient); | ||
List<Document> linesDocuments = new ArrayList<>(); | ||
|
||
if (restaurantsResource.exists()) { // load existing vector store if exists | ||
String contentAsString = restaurantsResource.getContentAsString(StandardCharsets.UTF_8); | ||
// Convert lines to Documents in parallel | ||
linesDocuments = | ||
contentAsString.lines().parallel().map(Document::new).toList(); | ||
simpleVectorStore.accept(linesDocuments); | ||
} | ||
|
||
// Step 2 retrieve related documents to query | ||
logger.info("Retrieving relevant documents"); | ||
List<Document> similarDocuments = | ||
simpleVectorStore.similaritySearch(SearchRequest.query(query).withTopK(2)); | ||
logger.info(String.format("Found %s relevant documents.", similarDocuments.size())); | ||
|
||
List<String> contentList = | ||
similarDocuments.stream().map(Document::getContent).toList(); | ||
PromptTemplate promptTemplate = new PromptTemplate(ragPromptTemplate); | ||
Map<String, Object> promptParameters = new HashMap<>(); | ||
|
||
promptParameters.put("input", query); | ||
promptParameters.put("documents", String.join("\n", contentList)); | ||
Prompt prompt = promptTemplate.create(promptParameters); | ||
|
||
ChatResponse response = chatClient.call(prompt); | ||
Generation generation = response.getResult(); | ||
String answer = (generation != null) ? generation.getOutput().getContent() : ""; | ||
simpleVectorStore.delete(linesDocuments.stream().map(Document::getId).toList()); | ||
return new AIChatResponse(answer); | ||
} | ||
} |
10 changes: 10 additions & 0 deletions
10
chatmodel-springai/src/main/resources/data/restaurants.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{"address": {"building": "1007", "coord": [-73.856077, 40.848447], "street": "Morris Park Ave", "zipcode": "10462"}, "borough": "Bronx", "cuisine": "Bakery", "grades": [{"date": {"$date": 1393804800000}, "grade": "A", "score": 2}, {"date": {"$date": 1378857600000}, "grade": "A", "score": 6}, {"date": {"$date": 1358985600000}, "grade": "A", "score": 10}, {"date": {"$date": 1322006400000}, "grade": "A", "score": 9}, {"date": {"$date": 1299715200000}, "grade": "B", "score": 14}], "name": "Morris Park Bake Shop", "restaurant_id": "30075445"} | ||
{"address": {"building": "469", "coord": [-73.961704, 40.662942], "street": "Flatbush Avenue", "zipcode": "11225"}, "borough": "Brooklyn", "cuisine": "Hamburgers", "grades": [{"date": {"$date": 1419897600000}, "grade": "A", "score": 8}, {"date": {"$date": 1404172800000}, "grade": "B", "score": 23}, {"date": {"$date": 1367280000000}, "grade": "A", "score": 12}, {"date": {"$date": 1336435200000}, "grade": "A", "score": 12}], "name": "Wendy'S", "restaurant_id": "30112340"} | ||
{"address": {"building": "351", "coord": [-73.98513559999999, 40.7676919], "street": "West 57 Street", "zipcode": "10019"}, "borough": "Manhattan", "cuisine": "Irish", "grades": [{"date": {"$date": 1409961600000}, "grade": "A", "score": 2}, {"date": {"$date": 1374451200000}, "grade": "A", "score": 11}, {"date": {"$date": 1343692800000}, "grade": "A", "score": 12}, {"date": {"$date": 1325116800000}, "grade": "A", "score": 12}], "name": "Dj Reynolds Pub And Restaurant", "restaurant_id": "30191841"} | ||
{"address": {"building": "2780", "coord": [-73.98241999999999, 40.579505], "street": "Stillwell Avenue", "zipcode": "11224"}, "borough": "Brooklyn", "cuisine": "American ", "grades": [{"date": {"$date": 1402358400000}, "grade": "A", "score": 5}, {"date": {"$date": 1370390400000}, "grade": "A", "score": 7}, {"date": {"$date": 1334275200000}, "grade": "A", "score": 12}, {"date": {"$date": 1318377600000}, "grade": "A", "score": 12}], "name": "Riviera Caterer", "restaurant_id": "40356018"} | ||
{"address": {"building": "97-22", "coord": [-73.8601152, 40.7311739], "street": "63 Road", "zipcode": "11374"}, "borough": "Queens", "cuisine": "Jewish/Kosher", "grades": [{"date": {"$date": 1416787200000}, "grade": "Z", "score": 20}, {"date": {"$date": 1358380800000}, "grade": "A", "score": 13}, {"date": {"$date": 1343865600000}, "grade": "A", "score": 13}, {"date": {"$date": 1323907200000}, "grade": "B", "score": 25}], "name": "Tov Kosher Kitchen", "restaurant_id": "40356068"} | ||
{"address": {"building": "8825", "coord": [-73.8803827, 40.7643124], "street": "Astoria Boulevard", "zipcode": "11369"}, "borough": "Queens", "cuisine": "American ", "grades": [{"date": {"$date": 1416009600000}, "grade": "Z", "score": 38}, {"date": {"$date": 1398988800000}, "grade": "A", "score": 10}, {"date": {"$date": 1362182400000}, "grade": "A", "score": 7}, {"date": {"$date": 1328832000000}, "grade": "A", "score": 13}], "name": "Brunos On The Boulevard", "restaurant_id": "40356151"} | ||
{"address": {"building": "2206", "coord": [-74.1377286, 40.6119572], "street": "Victory Boulevard", "zipcode": "10314"}, "borough": "Staten Island", "cuisine": "Jewish/Kosher", "grades": [{"date": {"$date": 1412553600000}, "grade": "A", "score": 9}, {"date": {"$date": 1400544000000}, "grade": "A", "score": 12}, {"date": {"$date": 1365033600000}, "grade": "A", "score": 12}, {"date": {"$date": 1327363200000}, "grade": "A", "score": 9}], "name": "Kosher Island", "restaurant_id": "40356442"} | ||
{"address": {"building": "7114", "coord": [-73.9068506, 40.6199034], "street": "Avenue U", "zipcode": "11234"}, "borough": "Brooklyn", "cuisine": "Delicatessen", "grades": [{"date": {"$date": 1401321600000}, "grade": "A", "score": 10}, {"date": {"$date": 1389657600000}, "grade": "A", "score": 10}, {"date": {"$date": 1375488000000}, "grade": "A", "score": 8}, {"date": {"$date": 1342569600000}, "grade": "A", "score": 10}, {"date": {"$date": 1331251200000}, "grade": "A", "score": 13}, {"date": {"$date": 1318550400000}, "grade": "A", "score": 9}], "name": "Wilken'S Fine Food", "restaurant_id": "40356483"} | ||
{"address": {"building": "6409", "coord": [-74.00528899999999, 40.628886], "street": "11 Avenue", "zipcode": "11219"}, "borough": "Brooklyn", "cuisine": "American ", "grades": [{"date": {"$date": 1405641600000}, "grade": "A", "score": 12}, {"date": {"$date": 1375142400000}, "grade": "A", "score": 12}, {"date": {"$date": 1360713600000}, "grade": "A", "score": 11}, {"date": {"$date": 1345075200000}, "grade": "A", "score": 2}, {"date": {"$date": 1313539200000}, "grade": "A", "score": 11}], "name": "Regina Caterers", "restaurant_id": "40356649"} | ||
{"address": {"building": "1839", "coord": [-73.9482609, 40.6408271], "street": "Nostrand Avenue", "zipcode": "11226"}, "borough": "Brooklyn", "cuisine": "Ice Cream, Gelato, Yogurt, Ices", "grades": [{"date": {"$date": 1405296000000}, "grade": "A", "score": 12}, {"date": {"$date": 1373414400000}, "grade": "A", "score": 8}, {"date": {"$date": 1341964800000}, "grade": "A", "score": 5}, {"date": {"$date": 1329955200000}, "grade": "A", "score": 8}], "name": "Taste The Tropics Ice Cream", "restaurant_id": "40356731"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
You are a helpful assistant, conversing with a user about the subjects contained in a set of documents. | ||
Use the information from the DOCUMENTS section to provide accurate answers. If unsure or if the answer | ||
isn't found in the DOCUMENTS section, simply state that you don't know the answer. | ||
|
||
QUESTION: | ||
{input} | ||
|
||
DOCUMENTS: | ||
{documents} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters