Skip to content

Commit

Permalink
Adding initial Chain feature
Browse files Browse the repository at this point in the history
* Add Chain interface
* Add OuputParser interface
* Add Memory interface
* Rework core prompt OO models
* Change to use PromptTemplateActions as the interface for PromptTemplate
  • Loading branch information
markpollack committed Aug 12, 2023
1 parent ca26c30 commit 2608c95
Show file tree
Hide file tree
Showing 15 changed files with 311 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.core.llm.LLMResult;
import org.springframework.ai.core.llm.LlmClient;
import org.springframework.ai.core.prompt.Generation;
import org.springframework.ai.core.llm.Generation;
import org.springframework.ai.core.prompt.Prompt;
import org.springframework.ai.core.prompt.messages.Message;
import org.springframework.util.Assert;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package org.springframework.ai.core.chain;

import org.springframework.ai.core.memory.Memory;

import java.util.*;

public abstract class AbstractChain implements Chain {

private Optional<Memory> memory = Optional.empty();

public Optional<Memory> getMemory() {
return this.memory;
}

public void setMemory(Memory memory) {
Objects.requireNonNull(memory, "Memory can not be null.");
this.memory = Optional.of(memory);
}

@Override
public abstract List<String> getInputKeys();

@Override
public abstract List<String> getOutputKeys();

// TODO validation of input/outputs

@Override
public Map<String, Object> apply(Map<String, Object> inputMap) {
Map<String, Object> inputMapToUse = processBeforeApply(inputMap);
Map<String, Object> outputMap = doApply(inputMapToUse);
Map<String, Object> outputMapToUse = processAfterApply(inputMapToUse, outputMap);
return outputMapToUse;
}

protected Map<String, Object> processBeforeApply(Map<String, Object> inputMap) {
validateInputs(inputMap);
return inputMap;
}

protected abstract Map<String, Object> doApply(Map<String, Object> inputMap);

private Map<String, Object> processAfterApply(Map<String, Object> inputMap, Map<String, Object> outputMap) {
validateOutputs(outputMap);
Map<String, Object> combindedMap = new HashMap<>();
combindedMap.putAll(inputMap);
combindedMap.putAll(outputMap);
return combindedMap;
}

protected void validateOutputs(Map<String, Object> outputMap) {
Set<String> missingKeys = new HashSet<>(getOutputKeys());
missingKeys.removeAll(outputMap.keySet());
if (!missingKeys.isEmpty()) {
throw new IllegalArgumentException("Missing some output keys: " + missingKeys);
}
}

protected void validateInputs(Map<String, Object> inputMap) {
Set<String> missingKeys = new HashSet<>(getInputKeys());
missingKeys.removeAll(inputMap.keySet());
if (!missingKeys.isEmpty()) {
throw new IllegalArgumentException("Missing some input keys: " + missingKeys);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.springframework.ai.core.chain;

import java.util.List;
import java.util.Map;
import java.util.function.Function;

public interface Chain extends Function<Map<String, Object>, Map<String, Object>> {

List<String> getInputKeys();

List<String> getOutputKeys();

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.core.prompt;
package org.springframework.ai.core.llm;

import java.util.HashMap;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
package org.springframework.ai.core.llm;

import org.springframework.ai.core.prompt.Generation;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.springframework.ai.core.memory;

import java.util.List;
import java.util.Map;

public interface Memory {

/**
* The keys that the memory will add to Chain inputs
*/
List<String> getKeys();

/**
* Return key-value pairs given the text input to the chain
* @param inputs input of the chain
* @return key-value pairs from memory
*/
Map<String, Object> load(Map<String, Object> inputs);

void save(Map<String, Object> inputs, Map<String, Object> outputs);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.springframework.ai.core.parser;

import org.springframework.ai.core.llm.Generation;

import java.util.List;

public interface OutputParser<T> {

T parse(List<Generation> output);

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,70 @@

package org.springframework.ai.core.prompt;

import org.springframework.ai.core.prompt.messages.ChatMessage;
import org.springframework.ai.core.prompt.messages.MessageType;
import org.springframework.ai.core.prompt.messages.Message;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
* A PromptTemplate that lets you specify the role as a string should the current
* implementations and their roles not suffice for your needs.
*/
public class ChatPromptTemplate extends PromptTemplate {
public class ChatPromptTemplate implements PromptTemplateActions {

private MessageType messageType;
private final List<PromptTemplate> promptTemplates;

public ChatPromptTemplate(MessageType messageType, String template) {
super(template);
this.messageType = messageType;
public ChatPromptTemplate(List<PromptTemplate> promptTemplates) {
this.promptTemplates = promptTemplates;
}

@Override
public String render() {
StringBuilder sb = new StringBuilder();
for (PromptTemplate promptTemplate : promptTemplates) {
sb.append(promptTemplate.render());
}
return sb.toString();
}

@Override
public String render(Map<String, Object> model) {
StringBuilder sb = new StringBuilder();
for (PromptTemplate promptTemplate : promptTemplates) {
sb.append(promptTemplate.render(model));
}
return sb.toString();
}

@Override
public List<Message> createMessages() {
List<Message> messages = new ArrayList<>();
for (PromptTemplate promptTemplate : promptTemplates) {
messages.addAll(promptTemplate.createMessages());
}
return messages;
}

@Override
public List<Message> createMessages(Map<String, Object> model) {
List<Message> messages = new ArrayList<>();
for (PromptTemplate promptTemplate : promptTemplates) {
messages.addAll(promptTemplate.createMessages(model));
}
return messages;
}

@Override
public Prompt create() {
return new Prompt(new ChatMessage(this.messageType, render()));
List<Message> messages = createMessages();
return new Prompt(messages);
}

@Override
public Prompt create(Map<String, Object> model) {
return new Prompt(new ChatMessage(this.messageType, render(model)));
List<Message> messages = createMessages(model);
return new Prompt(messages);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.core.prompt;

import org.springframework.ai.core.llm.Generation;

import java.util.List;

public interface OutputParser {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,30 @@

import org.antlr.runtime.Token;
import org.antlr.runtime.TokenStream;
import org.springframework.ai.core.prompt.messages.Message;
import org.springframework.ai.core.prompt.messages.UserMessage;
import org.stringtemplate.v4.ST;
import org.stringtemplate.v4.compiler.STLexer;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.*;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class PromptTemplate extends AbstractPromptTemplate {
public class PromptTemplate implements PromptTemplateActions {

private ST st;

private Map<String, Object> dynamicModel = new HashMap<>();

protected String template;

protected TemplateFormat templateFormat = TemplateFormat.ST;

private OutputParser outputParser;

public PromptTemplate(String template) {
super(template);
this.template = template;
// If the template string is not valid, an exception will be thrown
try {
this.st = new ST(this.template, '{', '}');
Expand All @@ -46,12 +51,42 @@ public PromptTemplate(String template) {
}
}

@Override
public PromptTemplate(String template, Map<String, Object> model) {
this.template = template;
// If the template string is not valid, an exception will be thrown
try {
this.st = new ST(this.template, '{', '}');
for (Entry<String, Object> entry : model.entrySet()) {
add(entry.getKey(), entry.getValue());
}
}
catch (Exception ex) {
throw new IllegalArgumentException("The template string is not valid.", ex);
}
}

public OutputParser getOutputParser() {
return outputParser;
}

public void setOutputParser(OutputParser outputParser) {
Objects.requireNonNull(outputParser, "Output Parser can not be null");
this.outputParser = outputParser;
}

public void add(String name, Object value) {
this.st.add(name, value);
this.dynamicModel.put(name, value);
}

public String getTemplate() {
return this.template;
}

public TemplateFormat getTemplateFormat() {
return this.templateFormat;
}

// Render Methods
public String render() {
return st.render();
Expand All @@ -68,6 +103,16 @@ public String render(Map<String, Object> model) {
return st.render().trim();
}

@Override
public List<Message> createMessages() {
return List.of(new UserMessage(render()));
}

@Override
public List<Message> createMessages(Map<String, Object> model) {
return List.of(new UserMessage(render(model)));
}

@Override
public Prompt create() {
return new Prompt(render(new HashMap<>()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

package org.springframework.ai.core.prompt;

import java.util.Map;

public interface PromptOperations {

String getTemplate();
import org.springframework.ai.core.prompt.messages.Message;

TemplateFormat getTemplateFormat();
import java.util.List;
import java.util.Map;

void add(String name, Object value);
public interface PromptTemplateActions {

String render();

String render(Map<String, Object> model);

List<Message> createMessages();

List<Message> createMessages(Map<String, Object> model);

Prompt create();

Prompt create(Map<String, Object> model);
Expand Down
Loading

0 comments on commit 2608c95

Please sign in to comment.