-
Notifications
You must be signed in to change notification settings - Fork 824
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add Chain interface * Add OuputParser interface * Add Memory interface * Rework core prompt OO models * Change to use PromptTemplateActions as the interface for PromptTemplate
- Loading branch information
1 parent
ca26c30
commit 2608c95
Showing
15 changed files
with
311 additions
and
107 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
67 changes: 67 additions & 0 deletions
67
spring-ai-core/src/main/java/org/springframework/ai/core/chain/AbstractChain.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package org.springframework.ai.core.chain; | ||
|
||
import org.springframework.ai.core.memory.Memory; | ||
|
||
import java.util.*; | ||
|
||
public abstract class AbstractChain implements Chain { | ||
|
||
private Optional<Memory> memory = Optional.empty(); | ||
|
||
public Optional<Memory> getMemory() { | ||
return this.memory; | ||
} | ||
|
||
public void setMemory(Memory memory) { | ||
Objects.requireNonNull(memory, "Memory can not be null."); | ||
this.memory = Optional.of(memory); | ||
} | ||
|
||
@Override | ||
public abstract List<String> getInputKeys(); | ||
|
||
@Override | ||
public abstract List<String> getOutputKeys(); | ||
|
||
// TODO validation of input/outputs | ||
|
||
@Override | ||
public Map<String, Object> apply(Map<String, Object> inputMap) { | ||
Map<String, Object> inputMapToUse = processBeforeApply(inputMap); | ||
Map<String, Object> outputMap = doApply(inputMapToUse); | ||
Map<String, Object> outputMapToUse = processAfterApply(inputMapToUse, outputMap); | ||
return outputMapToUse; | ||
} | ||
|
||
protected Map<String, Object> processBeforeApply(Map<String, Object> inputMap) { | ||
validateInputs(inputMap); | ||
return inputMap; | ||
} | ||
|
||
protected abstract Map<String, Object> doApply(Map<String, Object> inputMap); | ||
|
||
private Map<String, Object> processAfterApply(Map<String, Object> inputMap, Map<String, Object> outputMap) { | ||
validateOutputs(outputMap); | ||
Map<String, Object> combindedMap = new HashMap<>(); | ||
combindedMap.putAll(inputMap); | ||
combindedMap.putAll(outputMap); | ||
return combindedMap; | ||
} | ||
|
||
protected void validateOutputs(Map<String, Object> outputMap) { | ||
Set<String> missingKeys = new HashSet<>(getOutputKeys()); | ||
missingKeys.removeAll(outputMap.keySet()); | ||
if (!missingKeys.isEmpty()) { | ||
throw new IllegalArgumentException("Missing some output keys: " + missingKeys); | ||
} | ||
} | ||
|
||
protected void validateInputs(Map<String, Object> inputMap) { | ||
Set<String> missingKeys = new HashSet<>(getInputKeys()); | ||
missingKeys.removeAll(inputMap.keySet()); | ||
if (!missingKeys.isEmpty()) { | ||
throw new IllegalArgumentException("Missing some input keys: " + missingKeys); | ||
} | ||
} | ||
|
||
} |
13 changes: 13 additions & 0 deletions
13
spring-ai-core/src/main/java/org/springframework/ai/core/chain/Chain.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package org.springframework.ai.core.chain; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.function.Function; | ||
|
||
public interface Chain extends Function<Map<String, Object>, Map<String, Object>> { | ||
|
||
List<String> getInputKeys(); | ||
|
||
List<String> getOutputKeys(); | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
22 changes: 22 additions & 0 deletions
22
spring-ai-core/src/main/java/org/springframework/ai/core/memory/Memory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package org.springframework.ai.core.memory; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
public interface Memory { | ||
|
||
/** | ||
* The keys that the memory will add to Chain inputs | ||
*/ | ||
List<String> getKeys(); | ||
|
||
/** | ||
* Return key-value pairs given the text input to the chain | ||
* @param inputs input of the chain | ||
* @return key-value pairs from memory | ||
*/ | ||
Map<String, Object> load(Map<String, Object> inputs); | ||
|
||
void save(Map<String, Object> inputs, Map<String, Object> outputs); | ||
|
||
} |
11 changes: 11 additions & 0 deletions
11
spring-ai-core/src/main/java/org/springframework/ai/core/parser/OutputParser.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
package org.springframework.ai.core.parser; | ||
|
||
import org.springframework.ai.core.llm.Generation; | ||
|
||
import java.util.List; | ||
|
||
public interface OutputParser<T> { | ||
|
||
T parse(List<Generation> output); | ||
|
||
} |
39 changes: 0 additions & 39 deletions
39
spring-ai-core/src/main/java/org/springframework/ai/core/prompt/AbstractPromptTemplate.java
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.