ContextChatEngine.ts
import type {
  BaseChatEngine,
  NonStreamingChatEngineParams,
  StreamingChatEngineParams,
} from "@llamaindex/core/chat-engine";
import { wrapEventCaller } from "@llamaindex/core/decorator";
import type {
  ChatMessage,
  LLM,
  MessageContent,
  MessageType,
} from "@llamaindex/core/llms";
import { BaseMemory, ChatMemoryBuffer } from "@llamaindex/core/memory";
import {
  type ContextSystemPrompt,
  type ModuleRecord,
  PromptMixin,
  type PromptsRecord,
} from "@llamaindex/core/prompts";
import type { BaseRetriever } from "@llamaindex/core/retriever";
import { EngineResponse, MetadataMode } from "@llamaindex/core/schema";
import {
  extractText,
  streamConverter,
  streamReducer,
} from "@llamaindex/core/utils";
import { Settings } from "../../Settings.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import { DefaultContextGenerator } from "./DefaultContextGenerator.js";
import type { ContextGenerator } from "./types.js";
/**
 * ContextChatEngine uses the Index to get the appropriate context for each query.
 * The context is stored in the system prompt, and the chat history is preserved,
 * ideally allowing the appropriate context to be surfaced for each query.
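 *
 * @example
 * A hypothetical usage sketch (assumes a retriever built elsewhere, e.g. via
 * `index.asRetriever()` on a VectorStoreIndex; none of that is defined in this file):
 * ```ts
 * const chatEngine = new ContextChatEngine({ retriever });
 * const response = await chatEngine.chat({ message: "What does the doc say?" });
 * console.log(response.message.content);
 * ```
 */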
export class ContextChatEngine extends PromptMixin implements BaseChatEngine {
  chatModel: LLM;
  memory: BaseMemory;
  contextGenerator: ContextGenerator & PromptMixin;
  systemPrompt?: string | undefined;
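  /** The messages currently held in memory, as returned by `memory.getMessages()`. */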
  get chatHistory() {
    return this.memory.getMessages();
  }
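  // The chat model falls back to the global Settings.llm; context generation is
  // delegated to a DefaultContextGenerator built around the provided retriever.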
  constructor(init: {
    retriever: BaseRetriever;
    chatModel?: LLM | undefined;
    chatHistory?: ChatMessage[] | undefined;
    contextSystemPrompt?: ContextSystemPrompt | undefined;
    nodePostprocessors?: BaseNodePostprocessor[] | undefined;
    systemPrompt?: string | undefined;
    contextRole?: MessageType | undefined;
  }) {
    super();
    this.chatModel = init.chatModel ?? Settings.llm;
    this.memory = new ChatMemoryBuffer({ chatHistory: init?.chatHistory });
    this.contextGenerator = new DefaultContextGenerator({
      retriever: init.retriever,
      contextSystemPrompt: init?.contextSystemPrompt,
      nodePostprocessors: init?.nodePostprocessors,
      contextRole: init?.contextRole,
      metadataMode: MetadataMode.LLM,
    });
    this.systemPrompt = init.systemPrompt;
  }
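  // PromptMixin plumbing: expose and update the underlying context generator's
  // prompts (e.g. its context system prompt) through this engine.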
  protected _getPrompts(): PromptsRecord {
    return {
      ...this.contextGenerator.getPrompts(),
    };
  }

  protected _updatePrompts(prompts: {
    contextSystemPrompt: ContextSystemPrompt;
  }): void {
    this.contextGenerator.updatePrompts(prompts);
  }

  protected _getPromptModules(): ModuleRecord {
    return {
      contextGenerator: this.contextGenerator,
    };
  }
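  // Overloads: with `stream: true` the caller receives an async iterable of
  // incremental EngineResponse values; otherwise a single EngineResponse.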
  chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>;
  chat(
    params: StreamingChatEngineParams,
  ): Promise<AsyncIterable<EngineResponse>>;
  @wrapEventCaller
  async chat(
    params: StreamingChatEngineParams | NonStreamingChatEngineParams,
  ): Promise<EngineResponse | AsyncIterable<EngineResponse>> {
    const { message, stream } = params;
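    // A caller-supplied chat history (plain messages or a BaseMemory) overrides
    // the engine's own memory for this call.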
    const chatHistory = params.chatHistory
      ? new ChatMemoryBuffer({
          chatHistory:
            params.chatHistory instanceof BaseMemory
              ? await params.chatHistory.getMessages()
              : params.chatHistory,
        })
      : this.memory;
    const requestMessages = await this.prepareRequestMessages(
      message,
      chatHistory,
    );
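    // Streaming: accumulate the deltas with streamReducer so the complete
    // assistant message can be written back to the chat history on finish.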
    if (stream) {
      const stream = await this.chatModel.chat({
        messages: requestMessages.messages,
        stream: true,
      });
      return streamConverter(
        streamReducer({
          stream,
          initialValue: "",
          reducer: (accumulator, part) => (accumulator += part.delta),
          finished: (accumulator) => {
            chatHistory.put({ content: accumulator, role: "assistant" });
          },
        }),
        (r) => EngineResponse.fromChatResponseChunk(r, requestMessages.nodes),
      );
    }
    const response = await this.chatModel.chat({
      messages: requestMessages.messages,
    });
    chatHistory.put(response.message);
    return EngineResponse.fromChatResponse(response, requestMessages.nodes);
  }
  reset() {
    this.memory.reset();
  }
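  // Records the user message in memory, generates retrieval context from its
  // text, and returns the final message list plus the retrieved source nodes.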
  private async prepareRequestMessages(
    message: MessageContent,
    chatHistory: BaseMemory,
  ) {
    chatHistory.put({
      content: message,
      role: "user",
    });
    const textOnly = extractText(message);
    const context = await this.contextGenerator.generate(textOnly);
    const systemMessage = this.prependSystemPrompt(context.message);
    const messages = await chatHistory.getMessages([systemMessage]);
    return { nodes: context.nodes, messages };
  }
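  // Merges the engine-level system prompt (if any) into the generated context
  // message so both are delivered as a single message.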
  private prependSystemPrompt(message: ChatMessage): ChatMessage {
    if (!this.systemPrompt) return message;
    return {
      ...message,
      content: this.systemPrompt.trim() + "\n" + extractText(message.content),
    };
  }
}