Skip to content

Commit

Permalink
feat: simplify callback manager (#1027)
Browse files Browse the repository at this point in the history
  • Loading branch information
himself65 authored Jul 8, 2024
1 parent c4bd0a5 commit 16ef5dd
Show file tree
Hide file tree
Showing 21 changed files with 169 additions and 100 deletions.
9 changes: 9 additions & 0 deletions .changeset/clean-games-hear.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
"llamaindex": patch
"@llamaindex/core": patch
---

refactor: move callback manager & llm to core module

For people who import `llamaindex/llms/base` or `llamaindex/llms/utils`,
use `@llamaindex/core/llms` and `@llamaindex/core/utils` instead.
5 changes: 5 additions & 0 deletions .changeset/tall-maps-camp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@llamaindex/community": patch
---

refactor: depends on core package instead of llamaindex
8 changes: 8 additions & 0 deletions .changeset/tricky-candles-notice.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"llamaindex": minor
"@llamaindex/core": minor
---

refactor: simplify callback manager

Change `event.detail.payload` to `event.detail`
2 changes: 1 addition & 1 deletion examples/anthropic/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Anthropic, FunctionTool, Settings, WikipediaTool } from "llamaindex";
import { AnthropicAgent } from "llamaindex/agent/anthropic";

Settings.callbackManager.on("llm-tool-call", (event) => {
console.log("llm-tool-call", event.detail.payload.toolCall);
console.log("llm-tool-call", event.detail.toolCall);
});

const anthropic = new Anthropic({
Expand Down
5 changes: 2 additions & 3 deletions examples/multimodal/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
NodeWithScore,
ObjectType,
OpenAI,
RetrievalEndEvent,
Settings,
VectorStoreIndex,
} from "llamaindex";
Expand All @@ -18,8 +17,8 @@ Settings.chunkOverlap = 20;
Settings.llm = new OpenAI({ model: "gpt-4-turbo", maxTokens: 512 });

// Update callbackManager
Settings.callbackManager.on("retrieve-end", (event: RetrievalEndEvent) => {
const { nodes, query } = event.detail.payload;
Settings.callbackManager.on("retrieve-end", (event) => {
const { nodes, query } = event.detail;
const imageNodes = nodes.filter(
(node: NodeWithScore) => node.node.type === ObjectType.IMAGE_DOCUMENT,
);
Expand Down
5 changes: 2 additions & 3 deletions examples/multimodal/rag.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import {
MultiModalResponseSynthesizer,
OpenAI,
RetrievalEndEvent,
Settings,
VectorStoreIndex,
} from "llamaindex";
Expand All @@ -15,8 +14,8 @@ Settings.chunkOverlap = 20;
Settings.llm = new OpenAI({ model: "gpt-4-turbo", maxTokens: 512 });

// Update callbackManager
Settings.callbackManager.on("retrieve-end", (event: RetrievalEndEvent) => {
const { nodes, query } = event.detail.payload;
Settings.callbackManager.on("retrieve-end", (event) => {
const { nodes, query } = event.detail;
console.log(`Retrieved ${nodes.length} nodes for query: ${query}`);
});

Expand Down
6 changes: 2 additions & 4 deletions examples/qdrantdb/preFilters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ import {

// Update callback manager
Settings.callbackManager.on("retrieve-end", (event) => {
const data = event.detail.payload;
const { nodes } = event.detail;
console.log(
"The retrieved nodes are:",
data.nodes.map((node: NodeWithScore) =>
node.node.getContent(MetadataMode.NONE),
),
nodes.map((node: NodeWithScore) => node.node.getContent(MetadataMode.NONE)),
);
});

Expand Down
8 changes: 4 additions & 4 deletions examples/recipes/cost-analysis.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { extractText } from "@llamaindex/core/utils";
import { encodingForModel } from "js-tiktoken";
import { ChatMessage, OpenAI, type LLMStartEvent } from "llamaindex";
import { ChatMessage, OpenAI } from "llamaindex";
import { Settings } from "llamaindex/Settings";

const encoding = encodingForModel("gpt-4-0125-preview");
Expand All @@ -12,8 +12,8 @@ const llm = new OpenAI({

let tokenCount = 0;

Settings.callbackManager.on("llm-start", (event: LLMStartEvent) => {
const { messages } = event.detail.payload;
Settings.callbackManager.on("llm-start", (event) => {
const { messages } = event.detail;
messages.reduce((count: number, message: ChatMessage) => {
return count + encoding.encode(extractText(message.content)).length;
}, 0);
Expand All @@ -24,7 +24,7 @@ Settings.callbackManager.on("llm-start", (event: LLMStartEvent) => {
});

Settings.callbackManager.on("llm-stream", (event) => {
const { chunk } = event.detail.payload;
const { chunk } = event.detail;
const { delta } = chunk;
tokenCount += encoding.encode(extractText(delta)).length;
if (tokenCount > 20) {
Expand Down
1 change: 0 additions & 1 deletion packages/core/src/global/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
export { Settings } from "./settings";
export { CallbackManager } from "./settings/callback-manager";
export type {
BaseEvent,
LLMEndEvent,
LLMStartEvent,
LLMStreamEvent,
Expand Down
47 changes: 28 additions & 19 deletions packages/core/src/global/settings/callback-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,32 @@ import type {
ToolCall,
ToolOutput,
} from "../../llms";
import { EventCaller, getEventCaller } from "../../utils/event-caller";
import type { UUID } from "../type";

export type BaseEvent<Payload> = CustomEvent<{
payload: Readonly<Payload>;
}>;

export type LLMStartEvent = BaseEvent<{
export type LLMStartEvent = {
id: UUID;
messages: ChatMessage[];
}>;
export type LLMToolCallEvent = BaseEvent<{
};

export type LLMToolCallEvent = {
toolCall: ToolCall;
}>;
export type LLMToolResultEvent = BaseEvent<{
};

export type LLMToolResultEvent = {
toolCall: ToolCall;
toolResult: ToolOutput;
}>;
export type LLMEndEvent = BaseEvent<{
};

export type LLMEndEvent = {
id: UUID;
response: ChatResponse;
}>;
export type LLMStreamEvent = BaseEvent<{
};

export type LLMStreamEvent = {
id: UUID;
chunk: ChatResponseChunk;
}>;
};

export interface LlamaIndexEventMaps {
"llm-start": LLMStartEvent;
Expand All @@ -41,24 +42,32 @@ export interface LlamaIndexEventMaps {
}

export class LlamaIndexCustomEvent<T = any> extends CustomEvent<T> {
private constructor(event: string, options?: CustomEventInit) {
reason: EventCaller | null = null;
private constructor(
event: string,
options?: CustomEventInit & {
reason?: EventCaller | null;
},
) {
super(event, options);
this.reason = options?.reason ?? null;
}

static fromEvent<Type extends keyof LlamaIndexEventMaps>(
type: Type,
detail: LlamaIndexEventMaps[Type]["detail"],
detail: LlamaIndexEventMaps[Type],
) {
return new LlamaIndexCustomEvent(type, {
detail: detail,
reason: getEventCaller(),
});
}
}

type EventHandler<Event> = (event: Event) => void;
type EventHandler<Event> = (event: LlamaIndexCustomEvent<Event>) => void;

export class CallbackManager {
#handlers = new Map<keyof LlamaIndexEventMaps, EventHandler<CustomEvent>[]>();
#handlers = new Map<keyof LlamaIndexEventMaps, EventHandler<any>[]>();

on<K extends keyof LlamaIndexEventMaps>(
event: K,
Expand Down Expand Up @@ -88,7 +97,7 @@ export class CallbackManager {

dispatchEvent<K extends keyof LlamaIndexEventMaps>(
event: K,
detail: LlamaIndexEventMaps[K]["detail"],
detail: LlamaIndexEventMaps[K],
) {
const cbs = this.#handlers.get(event);
if (!cbs) {
Expand Down
24 changes: 8 additions & 16 deletions packages/core/src/utils/wrap-llm-event.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ export function wrapLLMEvent<
> {
const id = randomUUID();
getCallbackManager().dispatchEvent("llm-start", {
payload: {
id,
messages: params[0].messages,
},
id,
messages: params[0].messages,
});
const response = await originalMethod.call(this, ...params);
if (Symbol.asyncIterator in response) {
Expand Down Expand Up @@ -58,29 +56,23 @@ export function wrapLLMEvent<
};
}
getCallbackManager().dispatchEvent("llm-stream", {
payload: {
id,
chunk,
},
id,
chunk,
});
finalResponse.raw.push(chunk);
yield chunk;
}
snapshot(() => {
getCallbackManager().dispatchEvent("llm-end", {
payload: {
id,
response: finalResponse,
},
id,
response: finalResponse,
});
});
};
} else {
getCallbackManager().dispatchEvent("llm-end", {
payload: {
id,
response,
},
id,
response,
});
}
return response;
Expand Down
67 changes: 67 additions & 0 deletions packages/core/tests/event-system.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { CallbackManager, Settings } from "@llamaindex/core/global";
import { beforeEach, describe, expect, expectTypeOf, test, vi } from "vitest";

// Module augmentation: extend the core package's event map with a custom
// "test" event so that `CallbackManager.on("test", …)` and
// `dispatchEvent("test", …)` are strongly typed in the tests below.
declare module "@llamaindex/core/global" {
  interface LlamaIndexEventMaps {
    test: {
      value: number;
    };
  }
}

describe("event system", () => {
  // Install a fresh CallbackManager before every test so handlers registered
  // in one test cannot leak into the next.
  beforeEach(() => {
    Settings.callbackManager = new CallbackManager();
  });

  test("type system", () => {
    // `event.detail` must be the exact payload declared in the
    // LlamaIndexEventMaps augmentation for "test" — never `any`.
    Settings.callbackManager.on("test", (event) => {
      const detail = event.detail;
      expectTypeOf(detail).not.toBeAny();
      expectTypeOf(detail).toEqualTypeOf<{
        value: number;
      }>();
    });
  });

  test("dispatch event", async () => {
    // Track invocations with a bare spy; the payload assertion lives in the
    // handler itself, which keeps the handler contextually typed by `on`.
    const spy = vi.fn();
    Settings.callbackManager.on("test", (event) => {
      spy();
      expect(event.detail.value).toBe(42);
    });

    Settings.callbackManager.dispatchEvent("test", {
      value: 42,
    });
    // Dispatch is deferred to a later tick, so nothing has run synchronously.
    expect(spy).toHaveBeenCalledTimes(0);
    await new Promise((resolve) => process.nextTick(resolve));
    expect(spy).toHaveBeenCalledTimes(1);
  });

  // rollup doesn't support decorators for now
  // test('wrap event caller', async () => {
  //   class A {
  //     @wrapEventCaller
  //     fn() {
  //       Settings.callbackManager.dispatchEvent('test', {
  //         value: 42
  //       });
  //     }
  //   }
  //   const a = new A();
  //   let callback;
  //   Settings.callbackManager.on('test', callback = vi.fn((event) => {
  //     const data = event.detail;
  //     expect(event.reason!.caller).toBe(a);
  //     expect(data.value).toBe(42);
  //   }));
  //   a.fn();
  //   expect(callback).toHaveBeenCalledTimes(0);
  //   await new Promise((resolve) => process.nextTick(resolve));
  //   expect(callback).toHaveBeenCalledTimes(1);
  // })
});
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ Settings.embedModel = new HuggingFaceEmbedding({
quantized: false,
});
Settings.callbackManager.on("llm-tool-call", (event) => {
console.log(event.detail.payload);
console.log(event.detail);
});
Settings.callbackManager.on("llm-tool-result", (event) => {
console.log(event.detail.payload);
console.log(event.detail);
});

export async function getOpenAIModelRequest(query: string) {
Expand Down
Loading

0 comments on commit 16ef5dd

Please sign in to comment.