// LlamaCpp.ts
import { CompletionOptions, LLMOptions, ModelProvider } from "../../index.js";
import { BaseLLM } from "../index.js";
import { streamSse } from "../stream.js";

class LlamaCpp extends BaseLLM {
  static providerName: ModelProvider = "llama.cpp";
  static defaultOptions: Partial<LLMOptions> = {
    // The llama.cpp HTTP server listens on port 8080 by default.
    apiBase: "http://127.0.0.1:8080/",
  };

  // Translate CompletionOptions fields into the parameter names that the
  // llama.cpp server's /completion endpoint expects (e.g. maxTokens ->
  // n_predict). The prompt parameter is unused here; the prompt is sent
  // separately in the request body.
  private _convertArgs(options: CompletionOptions, prompt: string) {
    const finalOptions = {
      n_predict: options.maxTokens,
      frequency_penalty: options.frequencyPenalty,
      presence_penalty: options.presencePenalty,
      min_p: options.minP,
      mirostat: options.mirostat,
      stop: options.stop,
      top_k: options.topK,
      top_p: options.topP,
      temperature: options.temperature,
    };

    return finalOptions;
  }
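
  // For example, { maxTokens: 32, temperature: 0.7, stop: ["</s>"] } would be
  // serialized as { n_predict: 32, temperature: 0.7, stop: ["</s>"] } in the
  // request body below (illustrative values, not defaults from this file).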
  protected async *_streamComplete(
    prompt: string,
    options: CompletionOptions,
  ): AsyncGenerator<string> {
    const headers = {
      "Content-Type": "application/json",
      Authorization: `Bearer ${this.apiKey}`,
      ...this.requestOptions?.headers,
    };

    // POST to the server's /completion endpoint with stream: true, so the
    // response arrives as a stream of server-sent events.
    const resp = await this.fetch(new URL("completion", this.apiBase), {
      method: "POST",
      headers,
      body: JSON.stringify({
        prompt,
        stream: true,
        ...this._convertArgs(options, prompt),
      }),
    });

    // streamSse parses each SSE event into JSON; yield only the generated
    // text carried in the event's content field.
    for await (const value of streamSse(resp)) {
      if (value.content) {
        yield value.content;
      }
    }
  }
}

export default LlamaCpp;
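
// Usage sketch (not part of the original file): a standalone request that is
// roughly equivalent to what _streamComplete sends, assuming a llama.cpp
// server is running locally on port 8080. The SSE handling here is a
// simplified stand-in for the repo's streamSse helper (it splits incoming
// chunks on "data: " lines and may break if an event straddles two chunks).
//
// async function demo() {
//   const resp = await fetch("http://127.0.0.1:8080/completion", {
//     method: "POST",
//     headers: { "Content-Type": "application/json" },
//     body: JSON.stringify({ prompt: "Hello, ", stream: true, n_predict: 32 }),
//   });
//   const reader = resp.body!.getReader();
//   const decoder = new TextDecoder();
//   while (true) {
//     const { done, value } = await reader.read();
//     if (done) break;
//     for (const line of decoder.decode(value).split("\n")) {
//       if (line.startsWith("data: ")) {
//         const event = JSON.parse(line.slice("data: ".length));
//         if (event.content) process.stdout.write(event.content);
//       }
//     }
//   }
// }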