Skip to content

Commit

Permalink
feat: Add message history truncation (#364)
Browse files Browse the repository at this point in the history
* feat: Add message history truncation

* fix: Update other tests
  • Loading branch information
johnjcsmith authored Dec 26, 2024
1 parent ee4afb5 commit 2489bf7
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 18 deletions.
1 change: 1 addition & 0 deletions control-plane/src/modules/models/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const mockCreate = jest.fn(() => ({
}));

jest.mock("./routing", () => ({
...jest.requireActual("./routing"),
getRouting: jest.fn(() => ({
buildClient: jest.fn(() => ({
messages: {
Expand Down
3 changes: 3 additions & 0 deletions control-plane/src/modules/models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Anthropic from "@anthropic-ai/sdk";
import { ToolUseBlock } from "@anthropic-ai/sdk/resources";
import {
ChatIdentifiers,
CONTEXT_WINDOW,
EmbeddingIdentifiers,
getEmbeddingRouting,
getRouting,
Expand Down Expand Up @@ -47,6 +48,7 @@ export type Model = {
options: T,
) => Promise<StructuredCallOutput>;
identifier: ChatIdentifiers | EmbeddingIdentifiers;
contextWindow?: number;
embedQuery: (input: string) => Promise<number[]>;
};

Expand All @@ -72,6 +74,7 @@ export const buildModel = ({

return {
identifier,
contextWindow: CONTEXT_WINDOW[identifier],
embedQuery: async (input: string) => {
if (!isEmbeddingIdentifier(identifier)) {
throw new Error(`${identifier} is not an embedding model`);
Expand Down
5 changes: 5 additions & 0 deletions control-plane/src/modules/models/routing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ import { logger } from "../observability/logger";
import { BedrockCohereEmbeddings } from "../embeddings/bedrock-cohere-embeddings";
import { CohereEmbeddings } from "@langchain/cohere";

// Maximum context window size, in tokens, keyed by chat model family.
// Used to decide when message history must be truncated before a model call.
// NOTE(review): 200k values presumably mirror Anthropic's published limits
// for these Claude families — confirm against the provider docs when adding models.
export const CONTEXT_WINDOW: Record<string, number> = {
  "claude-3-5-sonnet": 200_000,
  "claude-3-haiku": 200_000,
};

const routingOptions = {
"claude-3-5-sonnet": [
...(env.BEDROCK_AVAILABLE
Expand Down
31 changes: 17 additions & 14 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { ReleventToolLookup } from '../agent';
import { toAnthropicMessages } from '../../workflow-messages';
import { toAnthropicMessage, toAnthropicMessages } from '../../workflow-messages';
import { logger } from '../../../observability/logger';
import { WorkflowAgentState, WorkflowAgentStateMessage } from '../state';
import { addAttributes, withSpan } from '../../../observability/tracer';
Expand All @@ -14,6 +14,7 @@ import { ToolUseBlock } from '@anthropic-ai/sdk/resources';
import { Schema, Validator } from 'jsonschema';
import { buildModelSchema, ModelOutput } from './model-output';
import { getSystemPrompt } from './system-prompt';
import { handleContextWindowOverflow } from '../overflow';

type WorkflowStateUpdate = Partial<WorkflowAgentState>;

Expand All @@ -32,16 +33,15 @@ const _handleModelCall = async (
findRelevantTools: ReleventToolLookup
): Promise<WorkflowStateUpdate> => {
detectCycle(state.messages);
const relevantSchemas = await findRelevantTools(state);

const relevantTools = await findRelevantTools(state);

addAttributes({
'model.relevant_tools': relevantSchemas.map(tool => tool.name),
'model.relevant_tools': relevantTools.map(tool => tool.name),
'model.available_tools': state.allAvailableTools,
'model.identifier': model.identifier,
});

const renderedMessages = toAnthropicMessages(state.messages);

if (!!state.workflow.resultSchema) {
const resultSchemaErrors = validateFunctionSchema(
state.workflow.resultSchema as JsonSchemaInput
Expand All @@ -53,31 +53,33 @@ const _handleModelCall = async (

const schema = buildModelSchema({
state,
relevantSchemas,
relevantSchemas: relevantTools,
resultSchema: state.workflow.resultSchema as JsonSchemaInput,
});

const schemaString = relevantSchemas.map(tool => {
return `${tool.name} - ${tool.description} ${tool.schema}`;
});
const systemPrompt = getSystemPrompt(state, relevantTools);

const systemPrompt = getSystemPrompt(state, schemaString);
const truncatedMessages = await handleContextWindowOverflow({
messages: state.messages,
systemPrompt: systemPrompt + JSON.stringify(schema),
modelContextWindow: model.contextWindow,
render: (m) => JSON.stringify(toAnthropicMessage(m)),
});

if (state.workflow.debug) {
addAttributes({
'model.input.additional_context': state.additionalContext,
'model.input.systemPrompt': systemPrompt,
'model.input.messages': JSON.stringify(
state.messages.map(m => ({
truncatedMessages.map(m => ({
id: m.id,
type: m.type,
}))
),
'model.input.rendered_messages': JSON.stringify(renderedMessages),
});
}

const response = await model.structured({
messages: renderedMessages,
messages: toAnthropicMessages(truncatedMessages),
system: systemPrompt,
schema,
});
Expand Down Expand Up @@ -258,6 +260,7 @@ const _handleModelCall = async (
};
};


const detectCycle = (messages: WorkflowAgentStateMessage[]) => {
if (messages.length >= 100) {
throw new AgentError('Maximum workflow message length exceeded.');
Expand Down
10 changes: 7 additions & 3 deletions control-plane/src/modules/workflows/agent/nodes/system-prompt.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { WorkflowAgentState } from "../state";
import { AgentTool } from "../tool";

export const getSystemPrompt = (
state: WorkflowAgentState,
schemaString: string[],
tools: AgentTool[],
): string => {
const basePrompt = [
"You are a helpful assistant with access to a set of tools designed to assist in completing tasks.",
Expand Down Expand Up @@ -36,16 +37,19 @@ export const getSystemPrompt = (
basePrompt.push(state.additionalContext);
}


// Add tool schemas
basePrompt.push("<TOOLS_SCHEMAS>");
basePrompt.push(...schemaString);
basePrompt.push(...tools.map(tool => {
return `${tool.name} - ${tool.description} ${tool.schema}`;
}));
basePrompt.push("</TOOLS_SCHEMAS>");

// Add other available tools
basePrompt.push("<OTHER_AVAILABLE_TOOLS>");
basePrompt.push(
...state.allAvailableTools.filter(
(t) => !schemaString.find((s) => s.includes(t)),
(t) => !tools.find((s) => s.name === t),
),
);
basePrompt.push("</OTHER_AVAILABLE_TOOLS>");
Expand Down
236 changes: 236 additions & 0 deletions control-plane/src/modules/workflows/agent/overflow.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
import { AgentError } from '../../../utilities/errors';
import { WorkflowAgentStateMessage } from './state';
import { handleContextWindowOverflow } from './overflow';
import { estimateTokenCount } from './utils';

// Replace the real token estimator with a bare jest.fn() so each test can
// script the token counts returned for the system prompt and for the message
// history after each successive removal.
jest.mock('./utils', () => ({
  estimateTokenCount: jest.fn(),
}));

// Unit tests for handleContextWindowOverflow. estimateTokenCount is mocked
// (see jest.mock above), so each test scripts, in call order, the token count
// returned for the system prompt and for the message list after each removal.
describe('handleContextWindowOverflow', () => {
  beforeEach(() => {
    jest.clearAllMocks();
  });

  it('should throw if system prompt exceeds threshold', async () => {
    const systemPrompt = 'System prompt';
    const messages: WorkflowAgentStateMessage[] = [];
    const modelContextWindow = 1000;

    // The system prompt alone may use at most 70% of the context window.
    (estimateTokenCount as jest.Mock)
      .mockResolvedValueOnce(701); // system prompt (0.7 * 1000)

    await expect(
      handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      })
    ).rejects.toThrow(new AgentError('System prompt can not exceed 700 tokens'));
  });

  it('should not modify messages if total tokens are under threshold', async () => {
    const systemPrompt = 'System prompt';
    const messages: WorkflowAgentStateMessage[] = [
      { type: 'human', data: { message: 'Hello' } } as any,
      { type: 'agent', data: { message: 'Hi' } } as any,
    ];
    const modelContextWindow = 1000;

    (estimateTokenCount as jest.Mock)
      .mockResolvedValueOnce(100) // system prompt
      .mockResolvedValueOnce(200); // messages

    const result = await handleContextWindowOverflow({
      systemPrompt,
      messages,
      modelContextWindow,
    });

    expect(estimateTokenCount).toHaveBeenCalledTimes(2);

    expect(result).toEqual(messages);
    expect(messages).toHaveLength(2);
  });

  it('should handle empty messages array', async () => {
    const systemPrompt = 'System prompt';
    const messages: WorkflowAgentStateMessage[] = [];
    const modelContextWindow = 1000;

    (estimateTokenCount as jest.Mock)
      .mockResolvedValueOnce(200) // system prompt
      .mockResolvedValueOnce(0); // empty messages

    const result = await handleContextWindowOverflow({
      systemPrompt,
      messages,
      modelContextWindow,
    });

    expect(estimateTokenCount).toHaveBeenCalledTimes(2);

    expect(result).toEqual(messages);
    expect(messages).toHaveLength(0);
  });

  describe('truncate strategy', () => {
    it('should remove messages until total tokens are under threshold', async () => {
      const systemPrompt = 'System prompt';
      // NOTE: Array.fill reuses a single object instance for all five slots;
      // acceptable here because the fixture is never mutated per-element.
      const messages: WorkflowAgentStateMessage[] = Array(5).fill({
        type: 'human',
        data: { message: 'Message' },
      });
      const modelContextWindow = 600;

      (estimateTokenCount as jest.Mock)
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(900) // initial messages
        .mockResolvedValueOnce(700) // after first removal
        .mockResolvedValueOnce(500) // after second removal
        .mockResolvedValueOnce(300); // after third removal

      const result = await handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      });

      expect(estimateTokenCount).toHaveBeenCalledTimes(5);

      expect(result).toHaveLength(2);
    });

    it('should throw if a single message exceeds the context window', async () => {
      const systemPrompt = 'System prompt';
      const messages: WorkflowAgentStateMessage[] = [
        { type: 'human', data: { message: 'Message' } } as any,
      ];
      const modelContextWindow = 400;

      (estimateTokenCount as jest.Mock)
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(400); // message

      await expect(
        handleContextWindowOverflow({
          systemPrompt,
          messages,
          modelContextWindow,
        })
      ).rejects.toThrow(AgentError);

      expect(estimateTokenCount).toHaveBeenCalledTimes(2);
    });

    it('should remove tool invocation result when removing agent message', async () => {
      const systemPrompt = 'System prompt';
      const messages: WorkflowAgentStateMessage[] = [
        {
          id: '123',
          type: 'agent',
          data: {
            message: 'Hi',
            invocations: [
              { id: 'toolCallId1' },
              { id: 'toolCallId2' },
              { id: 'toolCallId3' },
            ],
          },
        } as any,
        { id: '456', type: 'invocation-result', data: { id: 'toolCallId1' } } as any,
        { id: '457', type: 'invocation-result', data: { id: 'toolCallId2' } } as any,
        { id: '458', type: 'invocation-result', data: { id: 'toolCallId3' } } as any,
        { id: '789', type: 'human', data: { message: 'Hello' } } as any,
      ];

      // Dropping the agent message must cascade to its three invocation
      // results so no orphaned results remain; per the expectations below,
      // four messages are removed and only the trailing human survives.
      const modelContextWindow = 1100;
      (estimateTokenCount as jest.Mock)
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(1000) // initial messages
        .mockResolvedValueOnce(800) // after first removal
        .mockResolvedValueOnce(600) // after second removal
        .mockResolvedValueOnce(400) // after third removal
        .mockResolvedValueOnce(200); // after fourth removal

      const result = await handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      });

      expect(estimateTokenCount).toHaveBeenCalledTimes(6);

      expect(result).toHaveLength(1);
      expect(result[0].type).toBe('human');
    });

    it('should remove agent message when removing tool invocation result', async () => {
      const systemPrompt = 'System prompt';
      const messages: WorkflowAgentStateMessage[] = [
        { id: '456', type: 'invocation-result', data: { id: 'toolCallId1' } } as any,
        {
          id: '123',
          type: 'agent',
          data: {
            message: 'Hi',
            invocations: [{ id: 'toolCallId1' }],
          },
        } as any,
        { id: '789', type: 'human', data: { message: 'Hello' } } as any,
      ];

      // Dropping the leading invocation result must cascade to the agent
      // message that produced it; two messages are removed and only the
      // human message survives.
      const modelContextWindow = 1100;
      (estimateTokenCount as jest.Mock)
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(1000) // initial messages
        .mockResolvedValueOnce(800) // after first removal
        .mockResolvedValueOnce(600); // after second removal

      const result = await handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      });

      expect(estimateTokenCount).toHaveBeenCalledTimes(4);

      expect(result).toHaveLength(1);
      expect(result[0].type).toBe('human');
    });

    it('should ensure first message is human', async () => {
      const systemPrompt = 'System prompt';
      const messages: WorkflowAgentStateMessage[] = [
        { id: '789', type: 'human', data: { message: 'Hello' } } as any,
        {
          id: '123',
          type: 'agent',
          data: {
            message: 'Hi',
            invocations: [{ id: 'toolCallId1' }],
          },
        } as any,
        { id: '790', type: 'human', data: { message: 'Hello' } } as any,
      ];

      // Only one removal is needed to satisfy the context window, but the
      // now-leading agent message is also dropped so the remaining history
      // still starts with a human message.
      const modelContextWindow = 1100;
      (estimateTokenCount as jest.Mock)
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(1000) // initial messages
        .mockResolvedValueOnce(800) // after first removal
        .mockResolvedValueOnce(600); // after second removal

      const result = await handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      });

      expect(estimateTokenCount).toHaveBeenCalledTimes(4);

      expect(result).toHaveLength(1);
      expect(result[0].type).toBe('human');
    });
  });
});
Loading

0 comments on commit 2489bf7

Please sign in to comment.