From 643991a69a155c6e2156bb6bbb09034b787204ff Mon Sep 17 00:00:00 2001
From: Mishig Davaadorj
Date: Mon, 13 May 2024 13:43:30 +0200
Subject: [PATCH] [widgets] use `chatCompletionStream`

---
 packages/tasks/src/index.ts                   |  1 -
 packages/tasks/src/widget-example.ts          |  9 +--
 packages/widgets/package.json                 |  1 -
 packages/widgets/pnpm-lock.yaml               |  3 -
 .../WidgetOutputConvo.svelte                  | 10 ++-
 .../InferenceWidget/shared/inputValidation.ts |  4 +-
 .../ConversationalWidget.svelte               | 79 ++++++++-----------
 7 files changed, 45 insertions(+), 62 deletions(-)
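
Note for reviewers: the widget now calls the OpenAI-compatible chat-completion
API instead of rendering a Jinja chat template client-side and calling
text-generation. Below is a minimal, self-contained sketch of the call pattern
this patch adopts, assuming an ESM script with a valid access token; the model
id, prompt, and token source are illustrative, not taken from the widget code:

    import { HfInference } from "@huggingface/inference";

    const hf = new HfInference(process.env.HF_TOKEN ?? ""); // illustrative token source

    // Streaming path (mirrors the TGI branch below): accumulate delta chunks.
    let reply = "";
    for await (const chunk of hf.chatCompletionStream({
      model: "HuggingFaceH4/zephyr-7b-beta", // illustrative model id
      messages: [{ role: "user", content: "What is the capital of France?" }],
      max_tokens: 100,
    })) {
      reply += chunk.choices[0]?.delta?.content ?? "";
    }

    // Synchronous path (mirrors the non-TGI branch below): one complete message.
    const output = await hf.chatCompletion({
      model: "HuggingFaceH4/zephyr-7b-beta",
      messages: [{ role: "user", content: "What is the capital of France?" }],
      max_tokens: 100,
    });
    console.log(output.choices.at(-1)?.message.content ?? reply);
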
diff --git a/packages/tasks/src/index.ts b/packages/tasks/src/index.ts
index ceb3b04d1..9904fe9d5 100644
--- a/packages/tasks/src/index.ts
+++ b/packages/tasks/src/index.ts
@@ -19,7 +19,6 @@ export type { LibraryUiElement, ModelLibraryKey } from "./model-libraries";
 export type { ModelData, TransformersInfo } from "./model-data";
 export type { AddedToken, SpecialTokensMap, TokenizerConfig } from "./tokenizer-data";
 export type {
-	ChatMessage,
 	WidgetExample,
 	WidgetExampleAttribute,
 	WidgetExampleAssetAndPromptInput,
diff --git a/packages/tasks/src/widget-example.ts b/packages/tasks/src/widget-example.ts
index 2cfd3e2f4..3c47530f4 100644
--- a/packages/tasks/src/widget-example.ts
+++ b/packages/tasks/src/widget-example.ts
@@ -2,6 +2,8 @@
  * See default-widget-inputs.ts for the default widget inputs, this files only contains the types
  */
 
+import type { ChatCompletionInputMessage } from "./tasks";
+
 type TableData = Record<string, (string | number)[]>;
 
 //#region outputs
@@ -51,13 +53,8 @@ export interface WidgetExampleBase<TOutput> {
 	output?: TOutput;
 }
 
-export interface ChatMessage {
-	role: "user" | "assistant" | "system";
-	content: string;
-}
-
 export interface WidgetExampleChatInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
-	messages: ChatMessage[];
+	messages: ChatCompletionInputMessage[];
 }
 
 export interface WidgetExampleTextInput<TOutput = WidgetExampleOutput> extends WidgetExampleBase<TOutput> {
diff --git a/packages/widgets/package.json b/packages/widgets/package.json
index 17bab136f..8232c66d4 100644
--- a/packages/widgets/package.json
+++ b/packages/widgets/package.json
@@ -46,7 +46,6 @@
 	],
 	"dependencies": {
 		"@huggingface/inference": "workspace:^",
-		"@huggingface/jinja": "workspace:^",
 		"@huggingface/tasks": "workspace:^",
 		"marked": "^12.0.2"
 	},
diff --git a/packages/widgets/pnpm-lock.yaml b/packages/widgets/pnpm-lock.yaml
index c98fa8723..21e41edeb 100644
--- a/packages/widgets/pnpm-lock.yaml
+++ b/packages/widgets/pnpm-lock.yaml
@@ -8,9 +8,6 @@ dependencies:
   '@huggingface/inference':
     specifier: workspace:^
     version: link:../inference
-  '@huggingface/jinja':
-    specifier: workspace:^
-    version: link:../jinja
   '@huggingface/tasks':
     specifier: workspace:^
     version: link:../tasks
diff --git a/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOutputConvo/WidgetOutputConvo.svelte b/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOutputConvo/WidgetOutputConvo.svelte
index 737febb34..7d6ddc2aa 100644
--- a/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOutputConvo/WidgetOutputConvo.svelte
+++ b/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOutputConvo/WidgetOutputConvo.svelte
@@ -3,11 +3,11 @@
 	import { isFullyScrolled, scrollToMax } from "../../../../utils/ViewUtils.js";
 	import WidgetOutputConvoBubble from "../WidgetOuputConvoBubble/WidgetOutputConvoBubble.svelte";
-	import type { ChatMessage, SpecialTokensMap } from "@huggingface/tasks";
+	import type { ChatCompletionInputMessage, SpecialTokensMap } from "@huggingface/tasks";
 	import { widgetStates } from "../../stores.js";
 
 	export let modelId: string;
-	export let messages: ChatMessage[];
+	export let messages: ChatCompletionInputMessage[];
 	export let specialTokensMap: SpecialTokensMap | undefined = undefined;
 
 	let wrapperEl: HTMLElement;
@@ -30,8 +30,10 @@
 	{#each messages as message}
-		{@const position = message.role === "user" ? "right" : message.role === "assistant" ? "left" : "center"}
-		<WidgetOutputConvoBubble {position} text={message.content} />
+		{#if message.content}
+			{@const position = message.role === "user" ? "right" : message.role === "assistant" ? "left" : "center"}
+			<WidgetOutputConvoBubble {position} text={message.content} />
+		{/if}
 	{/each}
 </div>
diff --git a/packages/widgets/src/lib/components/InferenceWidget/shared/inputValidation.ts b/packages/widgets/src/lib/components/InferenceWidget/shared/inputValidation.ts
index a90f8c822..dc274f418 100644
--- a/packages/widgets/src/lib/components/InferenceWidget/shared/inputValidation.ts
+++ b/packages/widgets/src/lib/components/InferenceWidget/shared/inputValidation.ts
@@ -1,5 +1,5 @@
 import type {
-	ChatMessage,
+	ChatCompletionInputMessage,
 	WidgetExampleAssetAndPromptInput,
 	WidgetExampleAssetAndTextInput,
 	WidgetExampleAssetAndZeroShotInput,
@@ -104,7 +104,7 @@ export function isChatInput(sample: unknown): sample is WidgetExampleChatInput {
 		"messages" in sample &&
 		Array.isArray(sample.messages) &&
 		sample.messages.every(
-			(message): message is ChatMessage =>
+			(message): message is ChatCompletionInputMessage =>
 				isObject(message) &&
 				"role" in message &&
 				"content" in message &&
diff --git a/packages/widgets/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte b/packages/widgets/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte
index d6510d929..36ed0af34 100644
--- a/packages/widgets/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte
+++ b/packages/widgets/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte
@@ -2,21 +2,20 @@
 	import { onMount, tick } from "svelte";
 	import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types.js";
 	import type { Options } from "@huggingface/inference";
-	import { Template } from "@huggingface/jinja";
 	import type {
 		SpecialTokensMap,
 		TokenizerConfig,
 		WidgetExampleTextInput,
-		TextGenerationInput,
+		ChatCompletionInput,
 		WidgetExampleOutputText,
 		WidgetExampleChatInput,
 		WidgetExample,
 		AddedToken,
+		ChatCompletionInputMessage,
 	} from "@huggingface/tasks";
 	import { SPECIAL_TOKENS_ATTRIBUTES } from "@huggingface/tasks";
 	import { HfInference } from "@huggingface/inference";
-	import type { ChatMessage } from "@huggingface/tasks";
 	import WidgetOutputConvo from "../../shared/WidgetOutputConvo/WidgetOutputConvo.svelte";
 	import WidgetQuickInput from "../../shared/WidgetQuickInput/WidgetQuickInput.svelte";
 	import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
@@ -40,13 +39,12 @@
 	$: isDisabled = $widgetStates?.[model.id]?.isDisabled;
 
-	let messages: ChatMessage[] = [];
+	let messages: ChatCompletionInputMessage[] = [];
 	let error: string = "";
 	let isLoading: boolean = false;
 	let outputJson: string;
 	let text = "";
 
-	let compiledTemplate: Template;
 	let tokenizerConfig: TokenizerConfig;
 	let specialTokensMap: SpecialTokensMap | undefined = undefined;
 	let inferenceClient: HfInference | undefined = undefined;
@@ -54,7 +52,7 @@
 
 	$: inferenceClient = new HfInference(apiToken);
 
-	// Check config and compile template
+	// check config
 	onMount(() => {
 		const config = model.config;
 		if (config === undefined) {
@@ -81,12 +79,6 @@
 			error = "No chat template found in tokenizer config";
 			return;
 		}
-		try {
-			compiledTemplate = new Template(chatTemplate);
-		} catch (e) {
-			error = `Invalid chat template: "${(e as Error).message}"`;
-			return;
-		}
 	});
 
 	async function handleNewMessage(): Promise<void> {
@@ -125,33 +117,18 @@
 			await tick();
 			return;
 		}
-		if (!compiledTemplate) {
-			return;
-		}
 		if (!inferenceClient) {
 			error = "Inference client not ready";
 			return;
 		}
 
-		// Render chat template
+
 		specialTokensMap = extractSpecialTokensMap(tokenizerConfig);
-		let chatText;
-		try {
-			chatText = compiledTemplate.render({
-				messages,
-				add_generation_prompt: true,
-				...specialTokensMap,
-			});
-		} catch (e) {
-			error = `An error occurred while rendering the chat template: "${(e as Error).message}"`;
-			return;
-		}
+		const previousMessages = [...messages];
 
-		const input: TextGenerationInput & Required<Pick<TextGenerationInput, "parameters">> = {
-			inputs: chatText,
-			parameters: {
-				return_full_text: false,
-			},
+		const input: ChatCompletionInput = {
+			model: model.id,
+			messages: previousMessages,
 		};
 		addInferenceParameters(input, model);
@@ -171,32 +148,44 @@
 			tgiSupportedModels = await getTgiSupportedModels(apiUrl);
 			if ($tgiSupportedModels?.has(model.id)) {
-				console.debug("Starting text generation using the TGI streaming API");
+				console.debug("Starting chat completion using the TGI streaming API");
 
 				let newMessage = {
 					role: "assistant",
 					content: "",
-				} satisfies ChatMessage;
-				const previousMessages = [...messages];
-				const tokenStream = inferenceClient.textGenerationStream(
+				} satisfies ChatCompletionInputMessage;
+
+				const tokenStream = inferenceClient.chatCompletionStream(
 					{
-						...input,
-						model: model.id,
 						accessToken: apiToken,
+						...input,
 					},
 					opts
 				);
+
 				for await (const newToken of tokenStream) {
-					if (newToken.token.special) continue;
-					newMessage.content = newMessage.content + newToken.token.text;
+					const newTokenContent = newToken.choices?.[0].delta.content;
+					if (!newTokenContent) {
+						continue;
+					}
+					newMessage.content = newMessage.content + newTokenContent;
 					messages = [...previousMessages, newMessage];
 					await tick();
 				}
 			} else {
-				console.debug("Starting text generation using the synchronous API");
-				input.parameters.max_new_tokens = 100;
-				const output = await inferenceClient.textGeneration({ ...input, model: model.id, accessToken: apiToken }, opts);
-				messages = [...messages, { role: "assistant", content: output.generated_text }];
-				await tick();
+				console.debug("Starting chat completion using the synchronous API");
+				input.max_tokens = 100;
+				const output = await inferenceClient.chatCompletion(
+					{
+						accessToken: apiToken,
+						...input,
+					},
+					opts
+				);
+				const newAssistantMsg = output.choices.at(-1)?.message;
+				if (newAssistantMsg) {
+					messages = [...messages, newAssistantMsg];
+					await tick();
+				}
 			}
 		} catch (e) {
 			if (!isOnLoadCall) {