From 16ef5dd63112fc5a686a4aaf55f0b4a53043295e Mon Sep 17 00:00:00 2001
From: Alex Yang
Date: Mon, 8 Jul 2024 16:44:54 -0700
Subject: [PATCH] feat: simplify callback manager (#1027)

---
 .changeset/clean-games-hear.md                 |  9 +++
 .changeset/tall-maps-camp.md                   |  5 ++
 .changeset/tricky-candles-notice.md            |  8 +++
 examples/anthropic/agent.ts                    |  2 +-
 examples/multimodal/context.ts                 |  5 +-
 examples/multimodal/rag.ts                     |  5 +-
 examples/qdrantdb/preFilters.ts                |  6 +-
 examples/recipes/cost-analysis.ts              |  8 +--
 packages/core/src/global/index.ts              |  1 -
 .../src/global/settings/callback-manager.ts    | 47 +++++++------
 packages/core/src/utils/wrap-llm-event.ts      | 24 +++----
 packages/core/tests/event-system.test.ts       | 67 +++++++++++++++++++
 .../nextjs-node-runtime/src/actions/openai.ts  |  4 +-
 packages/llamaindex/e2e/node/utils.ts          | 25 +++----
 packages/llamaindex/src/agent/base.ts          |  8 +--
 packages/llamaindex/src/agent/types.ts         |  9 ++-
 packages/llamaindex/src/agent/utils.ts         | 10 +--
 .../src/cloud/LlamaCloudRetriever.ts           |  6 +-
 packages/llamaindex/src/index.edge.ts          |  1 -
 .../src/indices/vectorStore/index.ts           | 10 +--
 packages/llamaindex/src/llm/types.ts           |  9 ++-
 21 files changed, 169 insertions(+), 100 deletions(-)
 create mode 100644 .changeset/clean-games-hear.md
 create mode 100644 .changeset/tall-maps-camp.md
 create mode 100644 .changeset/tricky-candles-notice.md
 create mode 100644 packages/core/tests/event-system.test.ts

diff --git a/.changeset/clean-games-hear.md b/.changeset/clean-games-hear.md
new file mode 100644
index 0000000000..44467bd6f4
--- /dev/null
+++ b/.changeset/clean-games-hear.md
@@ -0,0 +1,9 @@
+---
+"llamaindex": patch
+"@llamaindex/core": patch
+---
+
+refactor: move callback manager & llm to core module
+
+For people who import `llamaindex/llms/base` or `llamaindex/llms/utils`,
+use `@llamaindex/core/llms` and `@llamaindex/core/utils` instead.
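+
+A minimal before/after sketch of the import change (`extractText` is just an
+illustrative pick; apply the same substitution to whatever you import from the
+old subpaths):
+
+```ts
+// before
+import { extractText } from "llamaindex/llms/utils";
+
+// after
+import { extractText } from "@llamaindex/core/utils";
+```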
diff --git a/.changeset/tall-maps-camp.md b/.changeset/tall-maps-camp.md
new file mode 100644
index 0000000000..46f7b35471
--- /dev/null
+++ b/.changeset/tall-maps-camp.md
@@ -0,0 +1,5 @@
+---
+"@llamaindex/community": patch
+---
+
+refactor: depend on the core package instead of llamaindex
diff --git a/.changeset/tricky-candles-notice.md b/.changeset/tricky-candles-notice.md
new file mode 100644
index 0000000000..224c197676
--- /dev/null
+++ b/.changeset/tricky-candles-notice.md
@@ -0,0 +1,8 @@
+---
+"llamaindex": minor
+"@llamaindex/core": minor
+---
+
+refactor: simplify callback manager
+
+Change `event.detail.payload` to `event.detail`
diff --git a/examples/anthropic/agent.ts b/examples/anthropic/agent.ts
index f4f8324407..7678e3f825 100644
--- a/examples/anthropic/agent.ts
+++ b/examples/anthropic/agent.ts
@@ -2,7 +2,7 @@ import { Anthropic, FunctionTool, Settings, WikipediaTool } from "llamaindex";
 import { AnthropicAgent } from "llamaindex/agent/anthropic";
 
 Settings.callbackManager.on("llm-tool-call", (event) => {
-  console.log("llm-tool-call", event.detail.payload.toolCall);
+  console.log("llm-tool-call", event.detail.toolCall);
 });
 
 const anthropic = new Anthropic({
diff --git a/examples/multimodal/context.ts b/examples/multimodal/context.ts
index b8b72b7cc0..c4a28646cd 100644
--- a/examples/multimodal/context.ts
+++ b/examples/multimodal/context.ts
@@ -4,7 +4,6 @@ import {
   NodeWithScore,
   ObjectType,
   OpenAI,
-  RetrievalEndEvent,
   Settings,
   VectorStoreIndex,
 } from "llamaindex";
@@ -18,8 +17,8 @@ Settings.chunkOverlap = 20;
 Settings.llm = new OpenAI({ model: "gpt-4-turbo", maxTokens: 512 });
 
 // Update callbackManager
-Settings.callbackManager.on("retrieve-end", (event: RetrievalEndEvent) => {
-  const { nodes, query } = event.detail.payload;
+Settings.callbackManager.on("retrieve-end", (event) => {
+  const { nodes, query } = event.detail;
   const imageNodes = nodes.filter(
     (node: NodeWithScore) => node.node.type === ObjectType.IMAGE_DOCUMENT,
   );
diff --git a/examples/multimodal/rag.ts b/examples/multimodal/rag.ts
index 2c7a4f104f..7d9a10c9ee 100644
--- a/examples/multimodal/rag.ts
+++ b/examples/multimodal/rag.ts
@@ -1,7 +1,6 @@
 import {
   MultiModalResponseSynthesizer,
   OpenAI,
-  RetrievalEndEvent,
   Settings,
   VectorStoreIndex,
 } from "llamaindex";
@@ -15,8 +14,8 @@ Settings.chunkOverlap = 20;
 Settings.llm = new OpenAI({ model: "gpt-4-turbo", maxTokens: 512 });
 
 // Update callbackManager
-Settings.callbackManager.on("retrieve-end", (event: RetrievalEndEvent) => {
-  const { nodes, query } = event.detail.payload;
+Settings.callbackManager.on("retrieve-end", (event) => {
+  const { nodes, query } = event.detail;
   console.log(`Retrieved ${nodes.length} nodes for query: ${query}`);
 });
 
diff --git a/examples/qdrantdb/preFilters.ts b/examples/qdrantdb/preFilters.ts
index 133ac7122b..d60751d9fc 100644
--- a/examples/qdrantdb/preFilters.ts
+++ b/examples/qdrantdb/preFilters.ts
@@ -11,12 +11,10 @@ import {
 
 // Update callback manager
 Settings.callbackManager.on("retrieve-end", (event) => {
-  const data = event.detail.payload;
+  const { nodes } = event.detail;
   console.log(
     "The retrieved nodes are:",
-    data.nodes.map((node: NodeWithScore) =>
-      node.node.getContent(MetadataMode.NONE),
-    ),
+    nodes.map((node: NodeWithScore) => node.node.getContent(MetadataMode.NONE)),
   );
 });
 
diff --git a/examples/recipes/cost-analysis.ts b/examples/recipes/cost-analysis.ts
index 725d35d1c7..09da953080 100644
--- a/examples/recipes/cost-analysis.ts
+++ b/examples/recipes/cost-analysis.ts
@@ -1,6 +1,6 @@
 import { extractText } from "@llamaindex/core/utils";
from "@llamaindex/core/utils"; import { encodingForModel } from "js-tiktoken"; -import { ChatMessage, OpenAI, type LLMStartEvent } from "llamaindex"; +import { ChatMessage, OpenAI } from "llamaindex"; import { Settings } from "llamaindex/Settings"; const encoding = encodingForModel("gpt-4-0125-preview"); @@ -12,8 +12,8 @@ const llm = new OpenAI({ let tokenCount = 0; -Settings.callbackManager.on("llm-start", (event: LLMStartEvent) => { - const { messages } = event.detail.payload; +Settings.callbackManager.on("llm-start", (event) => { + const { messages } = event.detail; messages.reduce((count: number, message: ChatMessage) => { return count + encoding.encode(extractText(message.content)).length; }, 0); @@ -24,7 +24,7 @@ Settings.callbackManager.on("llm-start", (event: LLMStartEvent) => { }); Settings.callbackManager.on("llm-stream", (event) => { - const { chunk } = event.detail.payload; + const { chunk } = event.detail; const { delta } = chunk; tokenCount += encoding.encode(extractText(delta)).length; if (tokenCount > 20) { diff --git a/packages/core/src/global/index.ts b/packages/core/src/global/index.ts index 2c54f3a7e7..d3afd40b98 100644 --- a/packages/core/src/global/index.ts +++ b/packages/core/src/global/index.ts @@ -1,7 +1,6 @@ export { Settings } from "./settings"; export { CallbackManager } from "./settings/callback-manager"; export type { - BaseEvent, LLMEndEvent, LLMStartEvent, LLMStreamEvent, diff --git a/packages/core/src/global/settings/callback-manager.ts b/packages/core/src/global/settings/callback-manager.ts index d85f9bb791..53c87db660 100644 --- a/packages/core/src/global/settings/callback-manager.ts +++ b/packages/core/src/global/settings/callback-manager.ts @@ -6,31 +6,32 @@ import type { ToolCall, ToolOutput, } from "../../llms"; +import { EventCaller, getEventCaller } from "../../utils/event-caller"; import type { UUID } from "../type"; -export type BaseEvent = CustomEvent<{ - payload: Readonly; -}>; - -export type LLMStartEvent = BaseEvent<{ +export type LLMStartEvent = { id: UUID; messages: ChatMessage[]; -}>; -export type LLMToolCallEvent = BaseEvent<{ +}; + +export type LLMToolCallEvent = { toolCall: ToolCall; -}>; -export type LLMToolResultEvent = BaseEvent<{ +}; + +export type LLMToolResultEvent = { toolCall: ToolCall; toolResult: ToolOutput; -}>; -export type LLMEndEvent = BaseEvent<{ +}; + +export type LLMEndEvent = { id: UUID; response: ChatResponse; -}>; -export type LLMStreamEvent = BaseEvent<{ +}; + +export type LLMStreamEvent = { id: UUID; chunk: ChatResponseChunk; -}>; +}; export interface LlamaIndexEventMaps { "llm-start": LLMStartEvent; @@ -41,24 +42,32 @@ export interface LlamaIndexEventMaps { } export class LlamaIndexCustomEvent extends CustomEvent { - private constructor(event: string, options?: CustomEventInit) { + reason: EventCaller | null = null; + private constructor( + event: string, + options?: CustomEventInit & { + reason?: EventCaller | null; + }, + ) { super(event, options); + this.reason = options?.reason ?? 
   }
 
   static fromEvent<Type extends keyof LlamaIndexEventMaps>(
     type: Type,
-    detail: LlamaIndexEventMaps[Type]["detail"],
+    detail: LlamaIndexEventMaps[Type],
   ) {
     return new LlamaIndexCustomEvent(type, {
       detail: detail,
+      reason: getEventCaller(),
     });
   }
 }
 
-type EventHandler<Event extends CustomEvent> = (event: Event) => void;
+type EventHandler<Event> = (event: LlamaIndexCustomEvent<Event>) => void;
 
 export class CallbackManager {
-  #handlers = new Map<string, EventHandler<CustomEvent>[]>();
+  #handlers = new Map<string, EventHandler<any>[]>();
 
   on<K extends keyof LlamaIndexEventMaps>(
     event: K,
@@ -88,7 +97,7 @@ export class CallbackManager {
 
   dispatchEvent<K extends keyof LlamaIndexEventMaps>(
     event: K,
-    detail: LlamaIndexEventMaps[K]["detail"],
+    detail: LlamaIndexEventMaps[K],
   ) {
     const cbs = this.#handlers.get(event);
     if (!cbs) {
diff --git a/packages/core/src/utils/wrap-llm-event.ts b/packages/core/src/utils/wrap-llm-event.ts
index 41f85e601b..88e309a3f5 100644
--- a/packages/core/src/utils/wrap-llm-event.ts
+++ b/packages/core/src/utils/wrap-llm-event.ts
@@ -22,10 +22,8 @@ export function wrapLLMEvent<
 > {
   const id = randomUUID();
   getCallbackManager().dispatchEvent("llm-start", {
-    payload: {
-      id,
-      messages: params[0].messages,
-    },
+    id,
+    messages: params[0].messages,
   });
   const response = await originalMethod.call(this, ...params);
   if (Symbol.asyncIterator in response) {
@@ -58,29 +56,23 @@ export function wrapLLMEvent<
         };
       }
       getCallbackManager().dispatchEvent("llm-stream", {
-        payload: {
-          id,
-          chunk,
-        },
+        id,
+        chunk,
       });
       finalResponse.raw.push(chunk);
       yield chunk;
     }
     snapshot(() => {
       getCallbackManager().dispatchEvent("llm-end", {
-        payload: {
-          id,
-          response: finalResponse,
-        },
+        id,
+        response: finalResponse,
       });
     });
   };
 } else {
   getCallbackManager().dispatchEvent("llm-end", {
-    payload: {
-      id,
-      response,
-    },
+    id,
+    response,
   });
 }
 return response;
diff --git a/packages/core/tests/event-system.test.ts b/packages/core/tests/event-system.test.ts
new file mode 100644
index 0000000000..657391219e
--- /dev/null
+++ b/packages/core/tests/event-system.test.ts
@@ -0,0 +1,67 @@
+import { CallbackManager, Settings } from "@llamaindex/core/global";
+import { beforeEach, describe, expect, expectTypeOf, test, vi } from "vitest";
+
+declare module "@llamaindex/core/global" {
+  interface LlamaIndexEventMaps {
+    test: {
+      value: number;
+    };
+  }
+}
+
+describe("event system", () => {
+  beforeEach(() => {
+    Settings.callbackManager = new CallbackManager();
+  });
+
+  test("type system", () => {
+    Settings.callbackManager.on("test", (event) => {
+      const data = event.detail;
+      expectTypeOf(data).not.toBeAny();
+      expectTypeOf(data).toEqualTypeOf<{
+        value: number;
+      }>();
+    });
+  });
+
+  test("dispatch event", async () => {
+    let callback;
+    Settings.callbackManager.on(
+      "test",
+      (callback = vi.fn((event) => {
+        const data = event.detail;
+        expect(data.value).toBe(42);
+      })),
+    );
+
+    Settings.callbackManager.dispatchEvent("test", {
+      value: 42,
+    });
+    expect(callback).toHaveBeenCalledTimes(0);
+    await new Promise((resolve) => process.nextTick(resolve));
+    expect(callback).toHaveBeenCalledTimes(1);
+  });
+
+  // rollup doesn't support decorators for now
+  // test('wrap event caller', async () => {
+  //   class A {
+  //     @wrapEventCaller
+  //     fn() {
+  //       Settings.callbackManager.dispatchEvent('test', {
+  //         value: 42
+  //       });
+  //     }
+  //   }
+  //   const a = new A();
+  //   let callback;
+  //   Settings.callbackManager.on('test', callback = vi.fn((event) => {
+  //     const data = event.detail;
+  //     expect(event.reason!.caller).toBe(a);
+  //     expect(data.value).toBe(42);
+  //   }));
+  //   a.fn();
+  //   expect(callback).toHaveBeenCalledTimes(0);
+  //   await new Promise((resolve) => process.nextTick(resolve));
+  //   expect(callback).toHaveBeenCalledTimes(1);
+  // })
+});
diff --git a/packages/llamaindex/e2e/examples/nextjs-node-runtime/src/actions/openai.ts b/packages/llamaindex/e2e/examples/nextjs-node-runtime/src/actions/openai.ts
index 1621a09473..29c7733d85 100644
--- a/packages/llamaindex/e2e/examples/nextjs-node-runtime/src/actions/openai.ts
+++ b/packages/llamaindex/e2e/examples/nextjs-node-runtime/src/actions/openai.ts
@@ -19,10 +19,10 @@ Settings.embedModel = new HuggingFaceEmbedding({
   quantized: false,
 });
 Settings.callbackManager.on("llm-tool-call", (event) => {
-  console.log(event.detail.payload);
+  console.log(event.detail);
 });
 Settings.callbackManager.on("llm-tool-result", (event) => {
-  console.log(event.detail.payload);
+  console.log(event.detail);
 });
 
 export async function getOpenAIModelRequest(query: string) {
diff --git a/packages/llamaindex/e2e/node/utils.ts b/packages/llamaindex/e2e/node/utils.ts
index 717fed9b1e..7f216cccb3 100644
--- a/packages/llamaindex/e2e/node/utils.ts
+++ b/packages/llamaindex/e2e/node/utils.ts
@@ -5,15 +5,16 @@ import {
   type LLMStartEvent,
   type LLMStreamEvent,
 } from "@llamaindex/core/global";
+import { CustomEvent } from "@llamaindex/env";
 import { readFile, writeFile } from "node:fs/promises";
 import { join } from "node:path";
 import { type test } from "node:test";
 import { fileURLToPath } from "node:url";
 
 type MockStorage = {
-  llmEventStart: LLMStartEvent["detail"]["payload"][];
-  llmEventEnd: LLMEndEvent["detail"]["payload"][];
-  llmEventStream: LLMStreamEvent["detail"]["payload"][];
+  llmEventStart: LLMStartEvent[];
+  llmEventEnd: LLMEndEvent[];
+  llmEventStream: LLMStreamEvent[];
 };
 
 export const llmCompleteMockStorage: MockStorage = {
@@ -36,35 +37,35 @@ export async function mockLLMEvent(
     llmEventStream: [],
   };
 
-  function captureLLMStart(event: LLMStartEvent) {
-    idMap.set(event.detail.payload.id, `PRESERVE_${counter++}`);
+  function captureLLMStart(event: CustomEvent<LLMStartEvent>) {
+    idMap.set(event.detail.id, `PRESERVE_${counter++}`);
     newLLMCompleteMockStorage.llmEventStart.push({
-      ...event.detail.payload,
+      ...event.detail,
       // @ts-expect-error id is not UUID, but it is fine for testing
-      id: idMap.get(event.detail.payload.id)!,
+      id: idMap.get(event.detail.id)!,
     });
   }
 
-  function captureLLMEnd(event: LLMEndEvent) {
+  function captureLLMEnd(event: CustomEvent<LLMEndEvent>) {
     newLLMCompleteMockStorage.llmEventEnd.push({
-      ...event.detail.payload,
+      ...event.detail,
       // @ts-expect-error id is not UUID, but it is fine for testing
-      id: idMap.get(event.detail.payload.id)!,
+      id: idMap.get(event.detail.id)!,
       response: {
-        ...event.detail.payload.response,
+        ...event.detail.response,
         // hide raw object since it might be too big
         raw: null,
       },
     });
   }
 
-  function captureLLMStream(event: LLMStreamEvent) {
+  function captureLLMStream(event: CustomEvent<LLMStreamEvent>) {
     newLLMCompleteMockStorage.llmEventStream.push({
-      ...event.detail.payload,
+      ...event.detail,
      // @ts-expect-error id is not UUID, but it is fine for testing
-      id: idMap.get(event.detail.payload.id)!,
+      id: idMap.get(event.detail.id)!,
       chunk: {
-        ...event.detail.payload.chunk,
+        ...event.detail.chunk,
         // hide raw object since it might be too big
         raw: null,
       },
diff --git a/packages/llamaindex/src/agent/base.ts b/packages/llamaindex/src/agent/base.ts
index 464f37b73c..7d55b3aacf 100644
--- a/packages/llamaindex/src/agent/base.ts
+++ b/packages/llamaindex/src/agent/base.ts
@@ -69,9 +69,7 @@ export function createTaskOutputStream<
         controller.enqueue(output);
       };
       Settings.callbackManager.dispatchEvent("agent-start", {
-        payload: {
-          startStep: step,
-        },
+        startStep: step,
       });
       context.logger.log("Starting step(id, %s).", step.id);
 
@@ -93,9 +91,7 @@ export function createTaskOutputStream<
         step.id,
       );
       Settings.callbackManager.dispatchEvent("agent-end", {
-        payload: {
-          endStep: step,
-        },
+        endStep: step,
       });
       controller.close();
     }
diff --git a/packages/llamaindex/src/agent/types.ts b/packages/llamaindex/src/agent/types.ts
index bc14da7fdd..d1ee13aa8c 100644
--- a/packages/llamaindex/src/agent/types.ts
+++ b/packages/llamaindex/src/agent/types.ts
@@ -1,4 +1,3 @@
-import type { BaseEvent } from "@llamaindex/core/global";
 import type {
   BaseToolWithCall,
   ChatMessage,
@@ -90,9 +89,9 @@ export type TaskHandler<
   ) => void,
 ) => Promise<void>;
 
-export type AgentStartEvent = BaseEvent<{
+export type AgentStartEvent = {
   startStep: TaskStep;
-}>;
-export type AgentEndEvent = BaseEvent<{
+};
+export type AgentEndEvent = {
   endStep: TaskStep;
-}>;
+};
diff --git a/packages/llamaindex/src/agent/utils.ts b/packages/llamaindex/src/agent/utils.ts
index 48962a8bfc..ff87110077 100644
--- a/packages/llamaindex/src/agent/utils.ts
+++ b/packages/llamaindex/src/agent/utils.ts
@@ -225,9 +225,7 @@ export async function callTool(
   }
   try {
     Settings.callbackManager.dispatchEvent("llm-tool-call", {
-      payload: {
-        toolCall: { ...toolCall, input },
-      },
+      toolCall: { ...toolCall, input },
     });
     output = await call.call(tool, input);
     logger.log(
@@ -241,10 +239,8 @@ export async function callTool(
       isError: false,
     };
     Settings.callbackManager.dispatchEvent("llm-tool-result", {
-      payload: {
-        toolCall: { ...toolCall, input },
-        toolResult: { ...toolOutput },
-      },
+      toolCall: { ...toolCall, input },
+      toolResult: { ...toolOutput },
     });
     return toolOutput;
   } catch (e) {
diff --git a/packages/llamaindex/src/cloud/LlamaCloudRetriever.ts b/packages/llamaindex/src/cloud/LlamaCloudRetriever.ts
index 29b8a55a16..74f7041397 100644
--- a/packages/llamaindex/src/cloud/LlamaCloudRetriever.ts
+++ b/packages/llamaindex/src/cloud/LlamaCloudRetriever.ts
@@ -92,10 +92,8 @@ export class LlamaCloudRetriever implements BaseRetriever {
       results.retrieval_nodes,
     );
     Settings.callbackManager.dispatchEvent("retrieve-end", {
-      payload: {
-        query,
-        nodes: nodesWithScores,
-      },
+      query,
+      nodes: nodesWithScores,
     });
     return nodesWithScores;
   }
diff --git a/packages/llamaindex/src/index.edge.ts b/packages/llamaindex/src/index.edge.ts
index 95a512f16e..48cfcd00fc 100644
--- a/packages/llamaindex/src/index.edge.ts
+++ b/packages/llamaindex/src/index.edge.ts
@@ -13,7 +13,6 @@ declare module "@llamaindex/core/global" {
 
 export { CallbackManager } from "@llamaindex/core/global";
 export type {
-  BaseEvent,
   JSONArray,
   JSONObject,
   JSONValue,
diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts
index 0404ea6694..aab886649b 100644
--- a/packages/llamaindex/src/indices/vectorStore/index.ts
+++ b/packages/llamaindex/src/indices/vectorStore/index.ts
@@ -414,9 +414,7 @@ export class VectorIndexRetriever implements BaseRetriever {
     preFilters,
   }: RetrieveParams): Promise<NodeWithScore[]> {
     Settings.callbackManager.dispatchEvent("retrieve-start", {
-      payload: {
-        query,
-      },
+      query,
     });
     const vectorStores = this.index.vectorStores;
     let nodesWithScores: NodeWithScore[] = [];
@@ -433,10 +431,8 @@ export class VectorIndexRetriever implements BaseRetriever {
       );
     }
     Settings.callbackManager.dispatchEvent("retrieve-end", {
-      payload: {
-        query,
-        nodes: nodesWithScores,
-      },
+      query,
+      nodes: nodesWithScores,
     });
     return nodesWithScores;
   }
diff --git a/packages/llamaindex/src/llm/types.ts b/packages/llamaindex/src/llm/types.ts
index d8f95c790c..c947d80c48 100644
--- a/packages/llamaindex/src/llm/types.ts
+++ b/packages/llamaindex/src/llm/types.ts
@@ -1,11 +1,10 @@
-import type { BaseEvent } from "@llamaindex/core/global";
 import type { MessageContent } from "@llamaindex/core/llms";
 import type { NodeWithScore } from "@llamaindex/core/schema";
 
-export type RetrievalStartEvent = BaseEvent<{
+export type RetrievalStartEvent = {
   query: MessageContent;
-}>;
-export type RetrievalEndEvent = BaseEvent<{
+};
+export type RetrievalEndEvent = {
   query: MessageContent;
   nodes: NodeWithScore[];
-}>;
+};
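
---

A minimal usage sketch of the simplified event payloads (handler bodies and
the custom `test` event name are illustrative; the fields on `event.detail`
follow the types in this patch):

```ts
import { Settings } from "llamaindex";

// Handlers now read the payload directly from `event.detail`,
// with no intermediate `.payload` object.
Settings.callbackManager.on("llm-start", (event) => {
  const { id, messages } = event.detail; // was: event.detail.payload
  console.log(`llm-start ${id}: ${messages.length} message(s)`);
});

Settings.callbackManager.on("retrieve-end", (event) => {
  const { query, nodes } = event.detail;
  console.log(`retrieved ${nodes.length} node(s) for query`, query);
});

// Custom events still work via module augmentation, as in the new test:
declare module "@llamaindex/core/global" {
  interface LlamaIndexEventMaps {
    test: { value: number };
  }
}

Settings.callbackManager.dispatchEvent("test", { value: 42 });
```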