Merge pull request #32 from dlants/prompt-caching
Prompt caching improvements, debug
dlants authored Jan 13, 2025
2 parents d1fc087 + 6e7cf38 commit 159e961
Showing 9 changed files with 256 additions and 92 deletions.
14 changes: 3 additions & 11 deletions node/chat/chat.spec.ts
@@ -35,7 +35,7 @@ describe("tea/chat.spec.ts", () => {
expect(
await buffer.getLines({ start: 0, end: -1 }),
"initial render of chat works",
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);
).toEqual(Chat.LOGO.split("\n") as Line[]);

app.dispatch({
type: "add-message",
@@ -144,7 +144,7 @@ describe("tea/chat.spec.ts", () => {
expect(
await buffer.getLines({ start: 0, end: -1 }),
"initial render of chat works",
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);
).toEqual(Chat.LOGO.split("\n") as Line[]);

app.dispatch({
type: "add-message",
@@ -172,10 +172,6 @@ describe("tea/chat.spec.ts", () => {
"Stopped (end_turn) [input: 0, output: 0]",
] as Line[]);

// expect(
// await extractMountTree(mountedApp.getMountedNode()),
// ).toMatchSnapshot();

app.dispatch({
type: "clear",
});
@@ -184,11 +180,7 @@
expect(
await buffer.getLines({ start: 0, end: -1 }),
"finished render is as expected",
).toEqual(["Stopped (end_turn) [input: 0, output: 0]"] as Line[]);

// expect(
// await extractMountTree(mountedApp.getMountedNode()),
// ).toMatchSnapshot();
).toEqual(Chat.LOGO.split("\n") as Line[]);
});
});
});
88 changes: 78 additions & 10 deletions node/chat/chat.ts
@@ -8,14 +8,14 @@ import {
type Update,
wrapThunk,
} from "../tea/tea.ts";
import { d, type View } from "../tea/view.ts";
import { d, withBindings, type View } from "../tea/view.ts";
import * as ToolManager from "../tools/toolManager.ts";
import { type Result } from "../utils/result.ts";
import { Counter } from "../utils/uniqueId.ts";
import type { Nvim } from "nvim-node";
import type { Lsp } from "../lsp.ts";
import {
getClient,
getClient as getProvider,
type ProviderMessage,
type ProviderMessageContent,
type ProviderName,
@@ -24,6 +24,7 @@
} from "../providers/provider.ts";
import { assertUnreachable } from "../utils/assertUnreachable.ts";
import { DEFAULT_OPTIONS, type MagentaOptions } from "../options.ts";
import { getOption } from "../nvim/nvim.ts";

export type Role = "user" | "assistant";

@@ -103,6 +104,9 @@ export type Msg =
| {
type: "set-opts";
options: MagentaOptions;
}
| {
type: "show-message-debug-info";
};

export function init({ nvim, lsp }: { nvim: Nvim; lsp: Lsp }) {
@@ -421,7 +425,7 @@ ${msg.error.stack}`,
model,
// eslint-disable-next-line @typescript-eslint/require-await
async () => {
getClient(nvim, model.activeProvider, model.options).abort();
getProvider(nvim, model.activeProvider, model.options).abort();
},
];
}
@@ -430,6 +434,10 @@ ${msg.error.stack}`,
return [{ ...model, options: msg.options }];
}

case "show-message-debug-info": {
return [model, () => showDebugInfo(model)];
}

default:
assertUnreachable(msg);
}
@@ -490,7 +498,7 @@ ${msg.error.stack}`,
});
let res;
try {
res = await getClient(
res = await getProvider(
nvim,
model.activeProvider,
model.options,
@@ -537,6 +545,13 @@ ${msg.error.stack}`,
model,
dispatch,
}) => {
if (
model.messages.length == 0 &&
Object.keys(model.contextManager.files).length == 0
) {
return d`${LOGO}`;
}

return d`${model.messages.map(
(m, idx) =>
d`${messageModel.view({
@@ -556,12 +571,17 @@
) % MESSAGE_ANIMATION.length
]
}`
: d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
model.conversation.usage.cacheHits !== undefined &&
model.conversation.usage.cacheMisses !== undefined
? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
: ""
}]`
: withBindings(
d`Stopped (${model.conversation.stopReason}) [input: ${model.conversation.usage.inputTokens.toString()}, output: ${model.conversation.usage.outputTokens.toString()}${
model.conversation.usage.cacheHits !== undefined &&
model.conversation.usage.cacheMisses !== undefined
? d`, cache hits: ${model.conversation.usage.cacheHits.toString()}, cache misses: ${model.conversation.usage.cacheMisses.toString()}`
: ""
}]`,
{
"<CR>": () => dispatch({ type: "show-message-debug-info" }),
},
)
}${
model.conversation.state == "stopped" &&
!contextManagerModel.isContextEmpty(model.contextManager)
@@ -634,10 +654,58 @@
return messages.map((m) => m.message);
}

async function showDebugInfo(model: Model) {
const messages = await getMessages(model);
const provider = getProvider(nvim, model.activeProvider, model.options);
const params = provider.createStreamParameters(messages);
const nTokens = await provider.countTokens(messages);

// Create a floating window
const bufnr = await nvim.call("nvim_create_buf", [false, true]);
await nvim.call("nvim_buf_set_option", [bufnr, "bufhidden", "wipe"]);
const [editorWidth, editorHeight] = (await Promise.all([
getOption("columns", nvim),
await getOption("lines", nvim),
])) as [number, number];
const width = 80;
const height = editorHeight - 20;
await nvim.call("nvim_open_win", [
bufnr,
true,
{
relative: "editor",
width,
height,
col: Math.floor((editorWidth - width) / 2),
row: Math.floor((editorHeight - height) / 2),
style: "minimal",
border: "single",
},
]);

const lines = JSON.stringify(params, null, 2).split("\n");
lines.push(`nTokens: ${nTokens}`);
await nvim.call("nvim_buf_set_lines", [bufnr, 0, -1, false, lines]);

// Set buffer options
await nvim.call("nvim_buf_set_option", [bufnr, "modifiable", false]);
await nvim.call("nvim_buf_set_option", [bufnr, "filetype", "json"]);
}

return {
initModel,
update,
view,
getMessages,
};
}

export const LOGO = `\
________
╱ ╲
╱ ╱
╱ ╱
╲__╱__╱__╱
# magenta.nvim`;
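
Taken together, the chat.ts changes wire a debug flow through the update loop: pressing `<CR>` on the stopped status line dispatches `show-message-debug-info`, and the handler opens a floating window containing the exact stream parameters plus a token count. A condensed, self-contained sketch of that flow (names and types simplified; not the literal code from the diff):

```ts
// Simplified stand-ins for the real Msg and Model types in chat.ts.
type DebugMsg = { type: "show-message-debug-info" };
type ChatModel = { activeProvider: string };

// The new update case leaves the model untouched and returns a thunk, so the
// debug window is opened as a side effect rather than a state change.
function update(
  model: ChatModel,
  msg: DebugMsg,
): [ChatModel, (() => Promise<void>)?] {
  if (msg.type === "show-message-debug-info") {
    return [model, () => showDebugInfo(model)];
  }
  return [model];
}

// Stand-in for the real showDebugInfo above, which serializes
// provider.createStreamParameters(messages) to JSON, appends the
// provider.countTokens(messages) result, and renders both in a centered
// floating window via nvim_create_buf / nvim_open_win.
async function showDebugInfo(model: ChatModel): Promise<void> {
  console.log(`debug view for provider: ${model.activeProvider}`);
}
```

Returning a thunk keeps `update` pure, which is what lets the `<CR>` binding added via `withBindings` stay a plain dispatch.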
5 changes: 2 additions & 3 deletions node/magenta.spec.ts
@@ -2,6 +2,7 @@ import { describe, expect, it } from "vitest";
import { withDriver } from "./test/preamble";
import { pollUntil } from "./utils/async";
import type { Position0Indexed } from "./nvim/window";
import { LOGO } from "./chat/chat";

describe("node/magenta.spec.ts", () => {
it("clear command should work", async () => {
@@ -25,9 +26,7 @@ sup?
Stopped (end_turn) [input: 0, output: 0]`);

await driver.clear();
await driver.assertDisplayBufferContent(
`Stopped (end_turn) [input: 0, output: 0]`,
);
await driver.assertDisplayBufferContent(LOGO);
await driver.inputMagentaText(`hello again`);
await driver.send();
await driver.mockAnthropic.respond({
117 changes: 76 additions & 41 deletions node/providers/anthropic.ts
@@ -47,34 +47,9 @@ export class AnthropicProvider implements Provider {
}
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
onError: (error: Error) => void,
): Promise<{
toolRequests: Result<ToolManager.ToolRequest, { rawRequest: unknown }>[];
stopReason: StopReason;
usage: Usage;
}> {
const buf: string[] = [];
let flushInProgress: boolean = false;

const flushBuffer = () => {
if (buf.length && !flushInProgress) {
const text = buf.join("");
buf.splice(0);

flushInProgress = true;

try {
onText(text);
} finally {
flushInProgress = false;
setInterval(flushBuffer, 1);
}
}
};

createStreamParameters(
messages: ProviderMessage[],
): Anthropic.Messages.MessageStreamParams {
const anthropicMessages = messages.map((m): MessageParam => {
let content: Anthropic.Messages.ContentBlockParam[];
if (typeof m.content == "string") {
@@ -116,7 +91,7 @@ export class AnthropicProvider implements Provider {
};
});

placeCacheBreakpoints(anthropicMessages);
const cacheControlItemsPlaced = placeCacheBreakpoints(anthropicMessages);

const tools: Anthropic.Tool[] = ToolManager.TOOL_SPECS.map(
(t): Anthropic.Tool => {
@@ -127,19 +102,77 @@
},
);

return {
messages: anthropicMessages,
model: this.options.model,
max_tokens: 4096,
system: [
{
type: "text",
text: DEFAULT_SYSTEM_PROMPT,
// the prompt appears in the following order:
// tools
// system
// messages
// This ensures the tools + system prompt (which is approx 1400 tokens) is cached.
cache_control:
cacheControlItemsPlaced < 4 ? { type: "ephemeral" } : null,
},
],
tool_choice: {
type: "auto",
disable_parallel_tool_use: false,
},
tools,
};
}

async countTokens(messages: Array<ProviderMessage>): Promise<number> {
const params = this.createStreamParameters(messages);
const lastMessage = params.messages[params.messages.length - 1];
if (!lastMessage || lastMessage.role != "user") {
params.messages.push({ role: "user", content: "test" });
}
const res = await this.client.messages.countTokens({
messages: params.messages,
model: params.model,
system: params.system as Anthropic.TextBlockParam[],
tools: params.tools as Anthropic.Tool[],
});
return res.input_tokens;
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
onError: (error: Error) => void,
): Promise<{
toolRequests: Result<ToolManager.ToolRequest, { rawRequest: unknown }>[];
stopReason: StopReason;
usage: Usage;
}> {
const buf: string[] = [];
let flushInProgress: boolean = false;

const flushBuffer = () => {
if (buf.length && !flushInProgress) {
const text = buf.join("");
buf.splice(0);

flushInProgress = true;

try {
onText(text);
} finally {
flushInProgress = false;
setInterval(flushBuffer, 1);
}
}
};

try {
this.request = this.client.messages
.stream({
messages: anthropicMessages,
model: this.options.model,
max_tokens: 4096,
system: DEFAULT_SYSTEM_PROMPT,
tool_choice: {
type: "auto",
disable_parallel_tool_use: false,
},
tools,
})
.stream(this.createStreamParameters(messages))
.on("text", (text: string) => {
buf.push(text);
flushBuffer();
@@ -247,7 +280,7 @@
}
}

export function placeCacheBreakpoints(messages: MessageParam[]) {
export function placeCacheBreakpoints(messages: MessageParam[]): number {
// when we scan the messages, keep track of where each part ends.
const blocks: { block: Anthropic.Messages.ContentBlockParam; acc: number }[] =
[];
@@ -315,6 +348,8 @@ export function placeCacheBreakpoints(messages: MessageParam[]) {
}
}
}

return powers.length;
}

const STR_CHARS_PER_TOKEN = 4;
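
The body of `placeCacheBreakpoints` is mostly collapsed in the diff above; the visible fragments are the `blocks` accumulator, the new `return powers.length`, and the `STR_CHARS_PER_TOKEN = 4` heuristic, while `createStreamParameters` caps total markers at four via `cacheControlItemsPlaced < 4`. A speculative reconstruction of the idea, assuming token counts are estimated at roughly four characters per token, that prefixes shorter than 1024 tokens are skipped (commonly the minimum cacheable size for Anthropic models), and that up to four ephemeral `cache_control` markers (the API maximum) land at the block boundaries nearest powers of two — a sketch, not the committed implementation:

```ts
import type Anthropic from "@anthropic-ai/sdk";

const STR_CHARS_PER_TOKEN = 4; // rough heuristic, visible in the diff

export function placeCacheBreakpointsSketch(
  messages: Anthropic.Messages.MessageParam[],
): number {
  // Flatten to content blocks with a running token estimate, as the diff's
  // { block, acc } accumulator suggests. JSON length is a stand-in for
  // however the real code measures block size.
  const blocks: {
    block: Anthropic.Messages.ContentBlockParam;
    acc: number;
  }[] = [];
  let acc = 0;
  for (const m of messages) {
    if (typeof m.content == "string") continue; // plain strings can't carry cache_control
    for (const block of m.content) {
      acc += Math.ceil(JSON.stringify(block).length / STR_CHARS_PER_TOKEN);
      blocks.push({ block, acc });
    }
  }

  // Candidate breakpoints: powers of two up to the total estimate,
  // largest first, capped at the assumed four-marker maximum.
  const powers: number[] = [];
  for (let p = 1024; p <= acc; p *= 2) powers.push(p);
  const chosen = powers.reverse().slice(0, 4);

  for (const power of chosen) {
    // Mark the last block whose cumulative estimate still fits under `power`.
    const over = blocks.findIndex((b) => b.acc > power);
    const target = over === -1 ? blocks[blocks.length - 1] : blocks[over - 1];
    if (target && target.block.type == "text") {
      target.block.cache_control = { type: "ephemeral" };
    }
  }
  return chosen.length; // mirrors the diff's new `return powers.length`
}
```

Anchoring breakpoints to power-of-two prefixes means a growing conversation keeps re-hitting the same cached prefixes rather than invalidating them on every turn.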
9 changes: 9 additions & 0 deletions node/providers/mock.ts
@@ -39,6 +39,15 @@
}
}

createStreamParameters(messages: Array<ProviderMessage>): unknown {
return messages;
}

// eslint-disable-next-line @typescript-eslint/require-await
async countTokens(messages: Array<ProviderMessage>): Promise<number> {
return messages.length;
}

async sendMessage(
messages: Array<ProviderMessage>,
onText: (text: string) => void,
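
The mock fills in the two new `Provider` methods with deliberately trivial bodies: `createStreamParameters` echoes the messages back and `countTokens` returns the message count. A hypothetical unit test showing what that buys (the `MockProvider` constructor arguments are not visible in this diff, so zero-arg construction is assumed here):

```ts
import { expect, it } from "vitest";
// Path and zero-arg constructor assumed for illustration.
import { MockProvider } from "./providers/mock";

it("stubs the new Provider surface", async () => {
  const provider = new MockProvider();
  const messages = [{ role: "user" as const, content: "hello" }];
  // createStreamParameters echoes its input; countTokens counts messages.
  expect(provider.createStreamParameters(messages)).toEqual(messages);
  expect(await provider.countTokens(messages)).toBe(1);
});
```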