Skip to content

Commit

Permalink
Add OutputParser
Browse files Browse the repository at this point in the history
  • Loading branch information
markpollack committed Aug 18, 2023
1 parent 678690a commit 36f3e96
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.springframework.ai.parser;

import org.springframework.core.convert.support.DefaultConversionService;

public abstract class AbstractConversionServiceOutputParser<T> implements OutputParser<T> {

private final DefaultConversionService conversionService;

public AbstractConversionServiceOutputParser(DefaultConversionService conversionService) {
this.conversionService = conversionService;
}

public DefaultConversionService getConversionService() {
return conversionService;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.springframework.ai.parser;

import org.springframework.core.convert.support.DefaultConversionService;

import java.util.List;

/**
* Parse out a List from a formatting request to convert a
*/
public class ListOutputParser extends AbstractConversionServiceOutputParser<List<String>> {

public ListOutputParser(DefaultConversionService defaultConversionService) {
super(defaultConversionService);
}

@Override
public String getFormat() {
return """
Your response should be a list of comma separated values
eg: `foo, bar, baz`
""";
}

@Override
public List<String> parse(String text) {
return getConversionService().convert(text, List.class);
}

}
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
package org.springframework.ai.parser;

import org.springframework.ai.client.Generation;
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import java.util.List;
package org.springframework.ai.parser;

public interface OutputParser<T> {
import org.springframework.ai.prompt.FormatProvider;

T parse(List<Generation> output);
public interface OutputParser<T> extends Parser<T>, FormatProvider {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.springframework.ai.parser;

import java.util.Locale;

@FunctionalInterface
public interface Parser<T> {

T parse(String text);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.springframework.ai.prompt;

public interface FormatProvider {

String getFormat();

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.antlr.runtime.Token;
import org.antlr.runtime.TokenStream;
import org.springframework.ai.parser.OutputParser;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.core.io.Resource;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package org.springframework.ai.prompt.parsers;

import org.junit.jupiter.api.Test;
import org.springframework.ai.parser.ListOutputParser;
import org.springframework.core.convert.support.DefaultConversionService;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

class ListOutputParserTest {

@Test
void csv() {
String csvAsString = "foo, bar, baz";
ListOutputParser listOutputParser = new ListOutputParser(new DefaultConversionService());
List<String> list = listOutputParser.parse(csvAsString);
assertThat(list).containsExactlyElementsOf(List.of("foo", "bar", "baz"));
}

}
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package org.springframework.ai.openai.client;

import org.junit.jupiter.api.Test;
import org.springframework.ai.client.AiClient;
import org.springframework.ai.client.Generation;
import org.springframework.ai.parser.OutputParser;
import org.springframework.ai.parser.ListOutputParser;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.PromptTemplate;
import org.springframework.ai.prompt.SystemPromptTemplate;
Expand All @@ -11,6 +14,7 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.io.Resource;

import java.util.List;
Expand Down Expand Up @@ -47,7 +51,6 @@ void roleTest() {
assertThat(response).isNotNull();

evaluateQuestionAndAnswer(request, response.getText());

}

private void evaluateQuestionAndAnswer(String question, String answer) {
Expand All @@ -61,4 +64,25 @@ private void evaluateQuestionAndAnswer(String question, String answer) {
assertThat(response.getText()).isEqualTo("YES");
}

@Test
void outputParser() {
DefaultConversionService conversionService = new DefaultConversionService();
ListOutputParser outputParser = new ListOutputParser(conversionService);

String format = outputParser.getFormat();
String template = """
List five {subject}
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = openAiClient.generate(prompt).getGeneration();

List<String> list = outputParser.parse(generation.getText());
System.out.println(list);
assertThat(list).hasSize(5);

}

}

0 comments on commit 36f3e96

Please sign in to comment.