refactor: move llm & callback manager to core module (#1026)
himself65 authored Jul 8, 2024
1 parent f5c8ca7 commit c4bd0a5
Showing 75 changed files with 641 additions and 983 deletions.
1 change: 1 addition & 0 deletions examples/package.json
@@ -6,6 +6,7 @@
"@aws-crypto/sha256-js": "^5.2.0",
"@azure/identity": "^4.2.1",
"@datastax/astra-db-ts": "^1.2.1",
"@llamaindex/core": "^0.0.3",
"@notionhq/client": "^2.2.15",
"@pinecone-database/pinecone": "^2.2.2",
"@zilliz/milvus2-sdk-node": "^2.4.2",
17 changes: 9 additions & 8 deletions examples/qdrantdb/preFilters.ts
@@ -1,22 +1,23 @@
 import * as dotenv from "dotenv";
 import {
-  CallbackManager,
   Document,
   MetadataMode,
+  NodeWithScore,
   QdrantVectorStore,
   Settings,
   VectorStoreIndex,
   storageContextFromDefaults,
 } from "llamaindex";

 // Update callback manager
-Settings.callbackManager = new CallbackManager({
-  onRetrieve: (data) => {
-    console.log(
-      "The retrieved nodes are:",
-      data.nodes.map((node) => node.node.getContent(MetadataMode.NONE)),
-    );
-  },
+Settings.callbackManager.on("retrieve-end", (event) => {
+  const data = event.detail.payload;
+  console.log(
+    "The retrieved nodes are:",
+    data.nodes.map((node: NodeWithScore) =>
+      node.node.getContent(MetadataMode.NONE),
+    ),
+  );
 });

 // Load environment variables from local .env file
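Migration note: the example moves from the removed constructor-callback API (new CallbackManager({ onRetrieve })) to the EventTarget-style subscription API, where the payload travels on event.detail.payload. A side benefit is that handlers are now detachable; a minimal sketch, assuming the same "retrieve-end" event used above (off only works with the identical function reference passed to on):

const logRetrieved = (event: any) => {
  console.log("retrieved", event.detail.payload.nodes.length, "nodes");
};
Settings.callbackManager.on("retrieve-end", logRetrieved);
// ... run queries ...
Settings.callbackManager.off("retrieve-end", logRetrieved); // same reference, or it stays registered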
2 changes: 1 addition & 1 deletion examples/recipes/cost-analysis.ts
@@ -1,7 +1,7 @@
+import { extractText } from "@llamaindex/core/utils";
 import { encodingForModel } from "js-tiktoken";
 import { ChatMessage, OpenAI, type LLMStartEvent } from "llamaindex";
 import { Settings } from "llamaindex/Settings";
-import { extractText } from "llamaindex/llm/utils";

 const encoding = encodingForModel("gpt-4-0125-preview");

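The recipe now imports extractText from @llamaindex/core/utils rather than the deep llamaindex/llm/utils path. A hedged sketch of how the imported pieces combine for prompt-token counting; the payload shape follows the LLMStartEvent type introduced later in this commit, and the handler body is an illustration, not the file's verbatim code:

Settings.callbackManager.on("llm-start", (event: LLMStartEvent) => {
  const { messages } = event.detail.payload;
  const tokenCount = messages.reduce(
    (count, message) =>
      count + encoding.encode(extractText(message.content)).length,
    0,
  );
  console.log("prompt tokens:", tokenCount);
});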
4 changes: 2 additions & 2 deletions packages/community/package.json
@@ -42,11 +42,11 @@
"dev": "bunchee --watch"
},
"devDependencies": {
"@types/node": "^20.14.2",
"bunchee": "5.3.0-beta.0"
},
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.600.0",
"@types/node": "^20.14.2",
"llamaindex": "workspace:*"
"@llamaindex/core": "workspace:*"
}
}
25 changes: 13 additions & 12 deletions packages/community/src/llm/bedrock/base.ts
@@ -4,18 +4,19 @@ import {
   InvokeModelCommand,
   InvokeModelWithResponseStreamCommand,
 } from "@aws-sdk/client-bedrock-runtime";
-import type {
-  ChatMessage,
-  ChatResponse,
-  CompletionResponse,
-  LLMChatParamsNonStreaming,
-  LLMChatParamsStreaming,
-  LLMCompletionParamsNonStreaming,
-  LLMCompletionParamsStreaming,
-  LLMMetadata,
-  ToolCallLLMMessageOptions,
-} from "llamaindex";
-import { streamConverter, ToolCallLLM, wrapLLMEvent } from "llamaindex";
+import {
+  type ChatMessage,
+  type ChatResponse,
+  type CompletionResponse,
+  type LLMChatParamsNonStreaming,
+  type LLMChatParamsStreaming,
+  type LLMCompletionParamsNonStreaming,
+  type LLMCompletionParamsStreaming,
+  type LLMMetadata,
+  ToolCallLLM,
+  type ToolCallLLMMessageOptions,
+} from "@llamaindex/core/llms";
+import { streamConverter, wrapLLMEvent } from "@llamaindex/core/utils";
 import {
   type BedrockAdditionalChatOptions,
   type BedrockChatStreamResponse,
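Note the shape change: the old import type { ... } from "llamaindex" block becomes a value import with inline type modifiers, because ToolCallLLM is a runtime class now exported from @llamaindex/core/llms, while the runtime helpers move to @llamaindex/core/utils. The resulting split, reduced to its essentials (names as in the diff above):

import { ToolCallLLM, type LLMMetadata, type ToolCallLLMMessageOptions } from "@llamaindex/core/llms";
import { streamConverter, wrapLLMEvent } from "@llamaindex/core/utils";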
4 changes: 2 additions & 2 deletions packages/community/src/llm/bedrock/provider.ts
@@ -8,9 +8,9 @@ import {
   type ChatMessage,
   type ChatResponseChunk,
   type LLMMetadata,
-  streamConverter,
   type ToolCallLLMMessageOptions,
-} from "llamaindex";
+} from "@llamaindex/core/llms";
+import { streamConverter } from "@llamaindex/core/utils";
 import type { ToolChoice } from "./types";
 import { toUtf8 } from "./utils";

2 changes: 1 addition & 1 deletion packages/community/src/llm/bedrock/providers/anthropic.ts
@@ -10,7 +10,7 @@ import type {
   PartialToolCall,
   ToolCall,
   ToolCallLLMMessageOptions,
-} from "llamaindex";
+} from "@llamaindex/core/llms";
 import {
   type BedrockAdditionalChatOptions,
   type BedrockChatStreamResponse,
2 changes: 1 addition & 1 deletion packages/community/src/llm/bedrock/providers/meta.ts
@@ -2,7 +2,7 @@ import type {
   InvokeModelCommandInput,
   InvokeModelWithResponseStreamCommandInput,
 } from "@aws-sdk/client-bedrock-runtime";
-import type { ChatMessage, LLMMetadata } from "llamaindex";
+import type { ChatMessage, LLMMetadata } from "@llamaindex/core/llms";
 import type { MetaNoneStreamingResponse, MetaStreamEvent } from "../types";
 import {
   mapChatMessagesToMetaLlama2Messages,
4 changes: 2 additions & 2 deletions packages/community/src/llm/bedrock/utils.ts
@@ -1,13 +1,13 @@
+import type { JSONObject } from "@llamaindex/core/global";
 import type {
   BaseTool,
   ChatMessage,
-  JSONObject,
   MessageContent,
   MessageContentDetail,
   MessageContentTextDetail,
   ToolCallLLMMessageOptions,
   ToolMetadata,
-} from "llamaindex";
+} from "@llamaindex/core/llms";
 import type {
   AnthropicContent,
   AnthropicImageContent,
14 changes: 14 additions & 0 deletions packages/core/package.json
@@ -59,6 +59,20 @@
"types": "./dist/schema/index.d.ts",
"default": "./dist/schema/index.js"
}
},
"./utils": {
"require": {
"types": "./dist/utils/index.d.cts",
"default": "./dist/utils/index.cjs"
},
"import": {
"types": "./dist/utils/index.d.ts",
"default": "./dist/utils/index.js"
},
"default": {
"types": "./dist/utils/index.d.ts",
"default": "./dist/utils/index.js"
}
}
},
"files": [
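The new "./utils" entry mirrors the package's existing export conditions: require resolves to the CJS build (.cjs with .d.cts types) while import and the default fallback resolve to ESM (.js with .d.ts types). Consumers then pull helpers from the subpath, as the examples and community diffs in this commit do:

import { extractText, streamConverter } from "@llamaindex/core/utils";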
11 changes: 11 additions & 0 deletions packages/core/src/global/index.ts
@@ -1 +1,12 @@
 export { Settings } from "./settings";
+export { CallbackManager } from "./settings/callback-manager";
+export type {
+  BaseEvent,
+  LLMEndEvent,
+  LLMStartEvent,
+  LLMStreamEvent,
+  LLMToolCallEvent,
+  LLMToolResultEvent,
+  LlamaIndexEventMaps,
+} from "./settings/callback-manager";
+export type { JSONArray, JSONObject, JSONValue } from "./type";
21 changes: 21 additions & 0 deletions packages/core/src/global/settings.ts
@@ -1,3 +1,9 @@
+import {
+  type CallbackManager,
+  getCallbackManager,
+  setCallbackManager,
+  withCallbackManager,
+} from "./settings/callback-manager";
 import {
   getChunkSize,
   setChunkSize,
@@ -14,4 +20,19 @@ export const Settings = {
   withChunkSize<Result>(chunkSize: number, fn: () => Result): Result {
     return withChunkSize(chunkSize, fn);
   },
+
+  get callbackManager(): CallbackManager {
+    return getCallbackManager();
+  },
+
+  set callbackManager(callbackManager: CallbackManager) {
+    setCallbackManager(callbackManager);
+  },
+
+  withCallbackManager<Result>(
+    callbackManager: CallbackManager,
+    fn: () => Result,
+  ): Result {
+    return withCallbackManager(callbackManager, fn);
+  },
 };
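Settings now proxies three access patterns onto the callback-manager module: a getter, a setter, and a scoped override. Because withCallbackManager delegates to AsyncLocalStorage.run (see the implementation below), the override is visible only inside fn and the async work it spawns. A small usage sketch, assuming the exports of @llamaindex/core/global added in this commit:

import { CallbackManager, Settings } from "@llamaindex/core/global";

const tracing = new CallbackManager().on("llm-start", (event) => {
  console.log("llm-start", event.detail.payload.id);
});

Settings.withCallbackManager(tracing, () => {
  // code in this scope observes `tracing` via Settings.callbackManager
});
// outside the scope, lookups fall back to the set or global manager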
131 changes: 131 additions & 0 deletions packages/core/src/global/settings/callback-manager.ts
@@ -0,0 +1,131 @@
import { AsyncLocalStorage, CustomEvent } from "@llamaindex/env";
import type {
ChatMessage,
ChatResponse,
ChatResponseChunk,
ToolCall,
ToolOutput,
} from "../../llms";
import type { UUID } from "../type";

export type BaseEvent<Payload> = CustomEvent<{
payload: Readonly<Payload>;
}>;

export type LLMStartEvent = BaseEvent<{
id: UUID;
messages: ChatMessage[];
}>;
export type LLMToolCallEvent = BaseEvent<{
toolCall: ToolCall;
}>;
export type LLMToolResultEvent = BaseEvent<{
toolCall: ToolCall;
toolResult: ToolOutput;
}>;
export type LLMEndEvent = BaseEvent<{
id: UUID;
response: ChatResponse;
}>;
export type LLMStreamEvent = BaseEvent<{
id: UUID;
chunk: ChatResponseChunk;
}>;

export interface LlamaIndexEventMaps {
"llm-start": LLMStartEvent;
"llm-end": LLMEndEvent;
"llm-tool-call": LLMToolCallEvent;
"llm-tool-result": LLMToolResultEvent;
"llm-stream": LLMStreamEvent;
}

export class LlamaIndexCustomEvent<T = any> extends CustomEvent<T> {
private constructor(event: string, options?: CustomEventInit) {
super(event, options);
}

static fromEvent<Type extends keyof LlamaIndexEventMaps>(
type: Type,
detail: LlamaIndexEventMaps[Type]["detail"],
) {
return new LlamaIndexCustomEvent(type, {
detail: detail,
});
}
}

type EventHandler<Event> = (event: Event) => void;

export class CallbackManager {
#handlers = new Map<keyof LlamaIndexEventMaps, EventHandler<CustomEvent>[]>();

on<K extends keyof LlamaIndexEventMaps>(
event: K,
handler: EventHandler<LlamaIndexEventMaps[K]>,
) {
if (!this.#handlers.has(event)) {
this.#handlers.set(event, []);
}
this.#handlers.get(event)!.push(handler);
return this;
}

off<K extends keyof LlamaIndexEventMaps>(
event: K,
handler: EventHandler<LlamaIndexEventMaps[K]>,
) {
if (!this.#handlers.has(event)) {
return this;
}
const cbs = this.#handlers.get(event)!;
const index = cbs.indexOf(handler);
if (index > -1) {
cbs.splice(index, 1);
}
return this;
}

dispatchEvent<K extends keyof LlamaIndexEventMaps>(
event: K,
detail: LlamaIndexEventMaps[K]["detail"],
) {
const cbs = this.#handlers.get(event);
if (!cbs) {
return;
}
queueMicrotask(() => {
cbs.forEach((handler) =>
handler(
LlamaIndexCustomEvent.fromEvent(event, structuredClone(detail)),
),
);
});
}
}

export const globalCallbackManager = new CallbackManager();

const callbackManagerAsyncLocalStorage =
new AsyncLocalStorage<CallbackManager>();

let currentCallbackManager: CallbackManager | null = null;

export function getCallbackManager(): CallbackManager {
return (
callbackManagerAsyncLocalStorage.getStore() ??
currentCallbackManager ??
globalCallbackManager
);
}

export function setCallbackManager(callbackManager: CallbackManager) {
currentCallbackManager = callbackManager;
}

export function withCallbackManager<Result>(
callbackManager: CallbackManager,
fn: () => Result,
): Result {
return callbackManagerAsyncLocalStorage.run(callbackManager, fn);
}
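Two details of this implementation are easy to miss: dispatchEvent defers handlers with queueMicrotask, so they run after the dispatching call returns, and each handler receives a structuredClone of the detail, so listeners cannot mutate the producer's payload. getCallbackManager resolves in order: the AsyncLocalStorage scope, then the last setCallbackManager value, then globalCallbackManager. A minimal end-to-end sketch (the payload values are illustrative only):

const manager = new CallbackManager();

manager.on("llm-start", (event) => {
  // runs on a microtask, after dispatchEvent has returned
  console.log("started", event.detail.payload.id);
});

manager.dispatchEvent("llm-start", {
  payload: { id: "11111111-2222-3333-4444-555555555555", messages: [] },
});
console.log("dispatched"); // logs before the handler output above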
packages/llamaindex/src/llm/base.ts → packages/core/src/llms/base.ts
@@ -1,3 +1,5 @@
+import { streamConverter } from "../utils";
+import { extractText } from "../utils/llms";
 import type {
   ChatResponse,
   ChatResponseChunk,
@@ -9,8 +11,7 @@ import type {
   LLMCompletionParamsStreaming,
   LLMMetadata,
   ToolCallLLMMessageOptions,
-} from "@llamaindex/core/llms";
-import { extractText, streamConverter } from "./utils.js";
+} from "./type";

 export abstract class BaseLLM<
   AdditionalChatOptions extends object = object,
1 change: 1 addition & 0 deletions packages/core/src/llms/index.ts
@@ -1,3 +1,4 @@
+export { BaseLLM, ToolCallLLM } from "./base";
 export type {
   BaseTool,
   BaseToolWithCall,
packages/llamaindex/src/internal/context/EventCaller.ts
@@ -1,5 +1,14 @@
 import { AsyncLocalStorage, randomUUID } from "@llamaindex/env";
-import { isAsyncIterable, isIterable } from "../utils.js";

+export const isAsyncIterable = (
+  obj: unknown,
+): obj is AsyncIterable<unknown> => {
+  return obj != null && typeof obj === "object" && Symbol.asyncIterator in obj;
+};
+
+export const isIterable = (obj: unknown): obj is Iterable<unknown> => {
+  return obj != null && typeof obj === "object" && Symbol.iterator in obj;
+};

 const eventReasonAsyncLocalStorage = new AsyncLocalStorage<EventCaller>();
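The isAsyncIterable and isIterable type guards are now defined, and exported, directly in the event-caller module instead of being imported from the old ../utils.js barrel. Both narrow unknown by probing for the relevant well-known symbol. A short sketch of the narrowing they enable:

async function drain(value: unknown) {
  if (isAsyncIterable(value)) {
    for await (const item of value) console.log(item); // value: AsyncIterable<unknown>
  } else if (isIterable(value)) {
    for (const item of value) console.log(item); // value: Iterable<unknown>
  }
}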