Skip to content

Commit

Permalink
feat : example of output parser (#37)
Browse files Browse the repository at this point in the history
* feat : example of output parser

* feat : change assertion
  • Loading branch information
rajadilipkolli authored Apr 7, 2024
1 parent c894cd0 commit f470727
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.example.ai.model.request.AIChatRequest;
import com.example.ai.model.response.AIChatResponse;
import com.example.ai.model.response.ActorsFilms;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.ChatClient;
Expand All @@ -12,9 +13,12 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.parser.BeanOutputParser;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
Expand Down Expand Up @@ -61,4 +65,21 @@ AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest)
List<Double> embed = embeddingClient.embed(aiChatRequest.query());
return new AIChatResponse(embed.toString());
}

@GetMapping("/output")
public ActorsFilms generate(@RequestParam(value = "actor", defaultValue = "Jr NTR") String actor) {
BeanOutputParser<ActorsFilms> outputParser = new BeanOutputParser<>(ActorsFilms.class);

String format = outputParser.getFormat();
String template = """
Generate the filmography for the actor {actor}.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
ChatResponse response = chatClient.call(prompt);
Generation generation = response.getResult();

return outputParser.parse(generation.getOutput().getContent());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.example.ai.model.response;

import java.util.List;

public record ActorsFilms(String actor, List<String> movies) {}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

import com.example.ai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
Expand Down Expand Up @@ -56,4 +60,16 @@ void chatWithSystemPrompt() {
.statusCode(200)
.body("answer", containsString("cricket"));
}

@Test
void outputParser() {
given().param("actor", "Jr NTR")
.when()
.get("/api/ai/output")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("actor", is("Jr NTR"))
.body("movies", hasSize(greaterThanOrEqualTo(25)));
}
}

0 comments on commit f470727

Please sign in to comment.