From 36f3e96745b8573f4db4e505ddb75a933b93e3db Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Fri, 18 Aug 2023 16:37:34 -0400 Subject: [PATCH] Add OutputParser --- ...AbstractConversionServiceOutputParser.java | 17 +++++++++++ .../ai/parser/ListOutputParser.java | 29 +++++++++++++++++++ .../ai/parser/OutputParser.java | 24 +++++++++++---- .../org/springframework/ai/parser/Parser.java | 10 +++++++ .../ai/prompt/FormatProvider.java | 7 +++++ .../ai/prompt/OutputParser.java | 27 ----------------- .../ai/prompt/PromptTemplate.java | 1 + .../prompt/parsers/ListOutputParserTest.java | 21 ++++++++++++++ .../openai/client/ClientIntegrationTests.java | 26 ++++++++++++++++- 9 files changed, 128 insertions(+), 34 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/parser/AbstractConversionServiceOutputParser.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/parser/ListOutputParser.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/parser/Parser.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/prompt/FormatProvider.java delete mode 100644 spring-ai-core/src/main/java/org/springframework/ai/prompt/OutputParser.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/prompt/parsers/ListOutputParserTest.java diff --git a/spring-ai-core/src/main/java/org/springframework/ai/parser/AbstractConversionServiceOutputParser.java b/spring-ai-core/src/main/java/org/springframework/ai/parser/AbstractConversionServiceOutputParser.java new file mode 100644 index 0000000000..13eab1370d --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/parser/AbstractConversionServiceOutputParser.java @@ -0,0 +1,17 @@ +package org.springframework.ai.parser; + +import org.springframework.core.convert.support.DefaultConversionService; + +public abstract class AbstractConversionServiceOutputParser implements OutputParser { + + private final DefaultConversionService conversionService; + + public AbstractConversionServiceOutputParser(DefaultConversionService conversionService) { + this.conversionService = conversionService; + } + + public DefaultConversionService getConversionService() { + return conversionService; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/parser/ListOutputParser.java b/spring-ai-core/src/main/java/org/springframework/ai/parser/ListOutputParser.java new file mode 100644 index 0000000000..f863d29c7e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/parser/ListOutputParser.java @@ -0,0 +1,29 @@ +package org.springframework.ai.parser; + +import org.springframework.core.convert.support.DefaultConversionService; + +import java.util.List; + +/** + * Parse out a List from a formatting request to convert a + */ +public class ListOutputParser extends AbstractConversionServiceOutputParser> { + + public ListOutputParser(DefaultConversionService defaultConversionService) { + super(defaultConversionService); + } + + @Override + public String getFormat() { + return """ + Your response should be a list of comma separated values + eg: `foo, bar, baz` + """; + } + + @Override + public List parse(String text) { + return getConversionService().convert(text, List.class); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/parser/OutputParser.java b/spring-ai-core/src/main/java/org/springframework/ai/parser/OutputParser.java index 5c363b407d..b2cd14d33f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/parser/OutputParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/parser/OutputParser.java @@ -1,11 +1,23 @@ -package org.springframework.ai.parser; - -import org.springframework.ai.client.Generation; +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import java.util.List; +package org.springframework.ai.parser; -public interface OutputParser { +import org.springframework.ai.prompt.FormatProvider; - T parse(List output); +public interface OutputParser extends Parser, FormatProvider { } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/parser/Parser.java b/spring-ai-core/src/main/java/org/springframework/ai/parser/Parser.java new file mode 100644 index 0000000000..6ac722e7df --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/parser/Parser.java @@ -0,0 +1,10 @@ +package org.springframework.ai.parser; + +import java.util.Locale; + +@FunctionalInterface +public interface Parser { + + T parse(String text); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/FormatProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/FormatProvider.java new file mode 100644 index 0000000000..efefae42ee --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/FormatProvider.java @@ -0,0 +1,7 @@ +package org.springframework.ai.prompt; + +public interface FormatProvider { + + String getFormat(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/OutputParser.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/OutputParser.java deleted file mode 100644 index 836b1c4efb..0000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/OutputParser.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright 2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.prompt; - -import org.springframework.ai.client.Generation; - -import java.util.List; - -public interface OutputParser { - - Object parseResut(List generations); - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java index ca47301aec..da065e2572 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java @@ -18,6 +18,7 @@ import org.antlr.runtime.Token; import org.antlr.runtime.TokenStream; +import org.springframework.ai.parser.OutputParser; import org.springframework.ai.prompt.messages.Message; import org.springframework.ai.prompt.messages.UserMessage; import org.springframework.core.io.Resource; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/parsers/ListOutputParserTest.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/parsers/ListOutputParserTest.java new file mode 100644 index 0000000000..b06bbb14ca --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/parsers/ListOutputParserTest.java @@ -0,0 +1,21 @@ +package org.springframework.ai.prompt.parsers; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.parser.ListOutputParser; +import org.springframework.core.convert.support.DefaultConversionService; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +class ListOutputParserTest { + + @Test + void csv() { + String csvAsString = "foo, bar, baz"; + ListOutputParser listOutputParser = new ListOutputParser(new DefaultConversionService()); + List list = listOutputParser.parse(csvAsString); + assertThat(list).containsExactlyElementsOf(List.of("foo", "bar", "baz")); + } + +} \ No newline at end of file diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIntegrationTests.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIntegrationTests.java index 5053556f11..d03974a4ec 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIntegrationTests.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIntegrationTests.java @@ -1,7 +1,10 @@ package org.springframework.ai.openai.client; import org.junit.jupiter.api.Test; +import org.springframework.ai.client.AiClient; import org.springframework.ai.client.Generation; +import org.springframework.ai.parser.OutputParser; +import org.springframework.ai.parser.ListOutputParser; import org.springframework.ai.prompt.Prompt; import org.springframework.ai.prompt.PromptTemplate; import org.springframework.ai.prompt.SystemPromptTemplate; @@ -11,6 +14,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.Resource; import java.util.List; @@ -47,7 +51,6 @@ void roleTest() { assertThat(response).isNotNull(); evaluateQuestionAndAnswer(request, response.getText()); - } private void evaluateQuestionAndAnswer(String question, String answer) { @@ -61,4 +64,25 @@ private void evaluateQuestionAndAnswer(String question, String answer) { assertThat(response.getText()).isEqualTo("YES"); } + @Test + void outputParser() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputParser outputParser = new ListOutputParser(conversionService); + + String format = outputParser.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = openAiClient.generate(prompt).getGeneration(); + + List list = outputParser.parse(generation.getText()); + System.out.println(list); + assertThat(list).hasSize(5); + + } + }