From 28a5328d9fa5d4064f1955d0be3adeec4647c106 Mon Sep 17 00:00:00 2001 From: antimonyGu Date: Tue, 5 Nov 2024 08:16:21 -0800 Subject: [PATCH] feat(ui): allow selecting model in answer engine (#3304) * llm select FE sketch * fetch model array & select model drop down * finish model select function & to be polished * fix selectedModel init value and style * fix lint * rename DropdownMenuItems * refine feature toggle * chore(answer): set chat model's name by selection * refine model select dropdown position * [autofix.ci] apply automated fixes * properly fill model name * rename selectedModelName to modelName * uplift modelName state * clean log * format code * fix lint * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * [autofix.ci] apply automated fixes (attempt 3/3) * feat(ui): show tool bar in search page * fix: fix ui test * refactor(chart-ui): persist modelName and show modelName selection in followup * handle selectedModel not in models api * fix: throw warning when request model is not in supported_models * fix: fix modelInfo?.chat not supported case and check request.mode is supported in BE * fix: using warn! to print warning * fix: fix ui lint * update: ajust style and check if model is valid * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Meng Zhang Co-authored-by: liangfung <1098486429@qq.com> --- Cargo.lock | 1 + crates/http-api-bindings/src/chat/mod.rs | 1 + crates/tabby-inference/Cargo.toml | 1 + crates/tabby-inference/src/chat.rs | 16 +- ee/tabby-schema/graphql/schema.graphql | 5 +- ee/tabby-schema/src/schema/thread/inputs.rs | 3 + ee/tabby-ui/app/(home)/page.tsx | 12 + .../components/assistant-message-section.tsx | 2 - ee/tabby-ui/app/search/components/search.css | 3 - ee/tabby-ui/app/search/components/search.tsx | 83 ++-- ee/tabby-ui/components/textarea-search.tsx | 355 +++++++++++++----- ee/tabby-ui/components/ui/dropdown-menu.tsx | 24 +- ee/tabby-ui/lib/hooks/use-models.tsx | 72 ++++ ee/tabby-ui/lib/stores/chat-actions.ts | 7 + ee/tabby-ui/lib/stores/chat-store.ts | 4 +- ee/tabby-ui/lib/types/chat.ts | 1 + ee/tabby-webserver/src/service/answer.rs | 3 +- 17 files changed, 446 insertions(+), 147 deletions(-) delete mode 100644 ee/tabby-ui/app/search/components/search.css create mode 100644 ee/tabby-ui/lib/hooks/use-models.tsx diff --git a/Cargo.lock b/Cargo.lock index 0df49890a4b5..3c8aae278b9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5203,6 +5203,7 @@ dependencies = [ "reqwest", "secrecy", "tabby-common", + "tracing", "trie-rs", ] diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs index 50ad471b8f46..b78bdec2c060 100644 --- a/crates/http-api-bindings/src/chat/mod.rs +++ b/crates/http-api-bindings/src/chat/mod.rs @@ -18,6 +18,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc { let mut builder = ExtendedOpenAIConfig::builder(); builder .base(config) + .supported_models(model.supported_models.clone()) .model_name(model.model_name.as_deref().expect("Model name is required")); if model.kind == "openai/chat" { diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml index d2532d542b57..c362b809d0dd 100644 --- a/crates/tabby-inference/Cargo.toml +++ b/crates/tabby-inference/Cargo.toml @@ -19,3 +19,4 @@ trie-rs = "0.1.1" async-openai.workspace = true secrecy = "0.8" reqwest.workspace = true +tracing.workspace = true diff --git a/crates/tabby-inference/src/chat.rs b/crates/tabby-inference/src/chat.rs index 252db5a260d8..5ef447daaa3a 100644 --- a/crates/tabby-inference/src/chat.rs +++ b/crates/tabby-inference/src/chat.rs @@ -7,6 +7,7 @@ use async_openai::{ }; use async_trait::async_trait; use derive_builder::Builder; +use tracing::warn; #[async_trait] pub trait ChatCompletionStream: Sync + Send { @@ -34,6 +35,9 @@ pub struct ExtendedOpenAIConfig { #[builder(setter(into))] model_name: String, + #[builder(setter(into))] + supported_models: Option>, + #[builder(default)] fields_to_remove: Vec, } @@ -54,7 +58,17 @@ impl ExtendedOpenAIConfig { &self, mut request: CreateChatCompletionRequest, ) -> CreateChatCompletionRequest { - request.model = self.model_name.clone(); + if request.model.is_empty() { + request.model = self.model_name.clone(); + } else if let Some(supported_models) = &self.supported_models { + if !supported_models.contains(&request.model) { + warn!( + "Warning: {} model is not supported, falling back to {}", + request.model, self.model_name + ); + request.model = self.model_name.clone(); + } + } for field in &self.fields_to_remove { match field { diff --git a/ee/tabby-schema/graphql/schema.graphql b/ee/tabby-schema/graphql/schema.graphql index 5cd8e3742a8a..6e47d4594323 100644 --- a/ee/tabby-schema/graphql/schema.graphql +++ b/ee/tabby-schema/graphql/schema.graphql @@ -134,7 +134,7 @@ input CreateMessageInput { input CreateThreadAndRunInput { thread: CreateThreadInput! - options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false} + options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false, modelName: null} } input CreateThreadInput { @@ -144,7 +144,7 @@ input CreateThreadInput { input CreateThreadRunInput { threadId: ID! additionalUserMessage: CreateMessageInput! - options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false} + options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false, modelName: null} } input CreateUserGroupInput { @@ -216,6 +216,7 @@ input ThreadRunDebugOptionsInput { } input ThreadRunOptionsInput { + modelName: String = null docQuery: DocQueryInput = null codeQuery: CodeQueryInput = null generateRelevantQuestions: Boolean! = false diff --git a/ee/tabby-schema/src/schema/thread/inputs.rs b/ee/tabby-schema/src/schema/thread/inputs.rs index e49a89325187..b495b153d3c8 100644 --- a/ee/tabby-schema/src/schema/thread/inputs.rs +++ b/ee/tabby-schema/src/schema/thread/inputs.rs @@ -70,6 +70,9 @@ fn validate_code_query_input(input: &CodeQueryInput) -> Result<(), ValidationErr #[derive(GraphQLInputObject, Validate, Default, Clone)] pub struct ThreadRunOptionsInput { + #[graphql(default)] + pub model_name: Option, + #[validate(nested)] #[graphql(default)] pub doc_query: Option, diff --git a/ee/tabby-ui/app/(home)/page.tsx b/ee/tabby-ui/app/(home)/page.tsx index 98fee9cb4242..d735cc7cbcfd 100644 --- a/ee/tabby-ui/app/(home)/page.tsx +++ b/ee/tabby-ui/app/(home)/page.tsx @@ -9,8 +9,10 @@ import { useQuery } from 'urql' import { SESSION_STORAGE_KEY } from '@/lib/constants' import { useHealth } from '@/lib/hooks/use-health' import { useMe } from '@/lib/hooks/use-me' +import { useSelectedModel } from '@/lib/hooks/use-models' import { useIsChatEnabled } from '@/lib/hooks/use-server-info' import { useStore } from '@/lib/hooks/use-store' +import { updateSelectedModel } from '@/lib/stores/chat-actions' import { clearHomeScrollPosition, setHomeScrollPosition, @@ -47,6 +49,8 @@ function MainPanel() { }) const scrollY = useStore(useScrollStore, state => state.homePage) + const { selectedModel, isModelLoading, models } = useSelectedModel() + // Prefetch the search page useEffect(() => { router.prefetch('/search') @@ -69,6 +73,10 @@ function MainPanel() { resettingScroller.current = true }, []) + const handleSelectModel = (model: string) => { + updateSelectedModel(model) + } + if (!healthInfo || !data?.me) return <> const onSearch = (question: string, ctx?: ThreadRunContexts) => { @@ -138,6 +146,10 @@ function MainPanel() { cleanAfterSearch={false} contextInfo={contextInfoData?.contextInfo} fetchingContextInfo={fetchingContextInfo} + modelName={selectedModel} + onModelSelect={handleSelectModel} + isModelLoading={isModelLoading} + models={models} /> )} diff --git a/ee/tabby-ui/app/search/components/assistant-message-section.tsx b/ee/tabby-ui/app/search/components/assistant-message-section.tsx index 7792cf05f7dd..d49da10157f9 100644 --- a/ee/tabby-ui/app/search/components/assistant-message-section.tsx +++ b/ee/tabby-ui/app/search/components/assistant-message-section.tsx @@ -1,7 +1,5 @@ 'use client' -import './search.css' - import { MouseEventHandler, useContext, useMemo, useState } from 'react' import { zodResolver } from '@hookform/resolvers/zod' import DOMPurify from 'dompurify' diff --git a/ee/tabby-ui/app/search/components/search.css b/ee/tabby-ui/app/search/components/search.css deleted file mode 100644 index 8bee68523e08..000000000000 --- a/ee/tabby-ui/app/search/components/search.css +++ /dev/null @@ -1,3 +0,0 @@ -.text-area-autosize::-webkit-scrollbar { - display: none; -} diff --git a/ee/tabby-ui/app/search/components/search.tsx b/ee/tabby-ui/app/search/components/search.tsx index 9c6234cd3f0a..adb54475d6f5 100644 --- a/ee/tabby-ui/app/search/components/search.tsx +++ b/ee/tabby-ui/app/search/components/search.tsx @@ -8,8 +8,14 @@ import { useRef, useState } from 'react' +import Link from 'next/link' import { useRouter } from 'next/navigation' +import slugify from '@sindresorhus/slugify' +import { compact, pick, some, uniq, uniqBy } from 'lodash-es' import { nanoid } from 'nanoid' +import { ImperativePanelHandle } from 'react-resizable-panels' +import { toast } from 'sonner' +import { useQuery } from 'urql' import { ERROR_CODE_NOT_FOUND, @@ -17,9 +23,33 @@ import { SLUG_TITLE_MAX_LENGTH } from '@/lib/constants' import { useEnableDeveloperMode } from '@/lib/experiment-flags' +import { graphql } from '@/lib/gql/generates' +import { + CodeQueryInput, + ContextInfo, + DocQueryInput, + InputMaybe, + Maybe, + Message, + Role +} from '@/lib/gql/generates/graphql' +import { useCopyToClipboard } from '@/lib/hooks/use-copy-to-clipboard' import { useCurrentTheme } from '@/lib/hooks/use-current-theme' +import { useDebounceValue } from '@/lib/hooks/use-debounce' import { useLatest } from '@/lib/hooks/use-latest' +import { useMe } from '@/lib/hooks/use-me' +import { useSelectedModel } from '@/lib/hooks/use-models' +import useRouterStuff from '@/lib/hooks/use-router-stuff' import { useIsChatEnabled } from '@/lib/hooks/use-server-info' +import { ExtendedCombinedError, useThreadRun } from '@/lib/hooks/use-thread-run' +import { updateSelectedModel } from '@/lib/stores/chat-actions' +import { clearHomeScrollPosition } from '@/lib/stores/scroll-store' +import { useMutation } from '@/lib/tabby/gql' +import { + contextInfoQuery, + listThreadMessages, + listThreads +} from '@/lib/tabby/query' import { AttachmentCodeItem, AttachmentDocItem, @@ -46,6 +76,7 @@ import { ResizablePanelGroup } from '@/components/ui/resizable' import { ScrollArea } from '@/components/ui/scroll-area' +import { Separator } from '@/components/ui/separator' import { ButtonScrollToBottom } from '@/components/button-scroll-to-bottom' import { ClientOnly } from '@/components/client-only' import { BANNER_HEIGHT, useShowDemoBanner } from '@/components/demo-banner' @@ -54,39 +85,6 @@ import { ThemeToggle } from '@/components/theme-toggle' import { MyAvatar } from '@/components/user-avatar' import UserPanel from '@/components/user-panel' -import './search.css' - -import Link from 'next/link' -import slugify from '@sindresorhus/slugify' -import { compact, pick, some, uniq, uniqBy } from 'lodash-es' -import { ImperativePanelHandle } from 'react-resizable-panels' -import { toast } from 'sonner' -import { useQuery } from 'urql' - -import { graphql } from '@/lib/gql/generates' -import { - CodeQueryInput, - ContextInfo, - DocQueryInput, - InputMaybe, - Maybe, - Message, - Role -} from '@/lib/gql/generates/graphql' -import { useCopyToClipboard } from '@/lib/hooks/use-copy-to-clipboard' -import { useDebounceValue } from '@/lib/hooks/use-debounce' -import { useMe } from '@/lib/hooks/use-me' -import useRouterStuff from '@/lib/hooks/use-router-stuff' -import { ExtendedCombinedError, useThreadRun } from '@/lib/hooks/use-thread-run' -import { clearHomeScrollPosition } from '@/lib/stores/scroll-store' -import { useMutation } from '@/lib/tabby/gql' -import { - contextInfoQuery, - listThreadMessages, - listThreads -} from '@/lib/tabby/query' -import { Separator } from '@/components/ui/separator' - import { AssistantMessageSection } from './assistant-message-section' import { DevPanel } from './dev-panel' import { MessagesSkeleton } from './messages-skeleton' @@ -319,6 +317,8 @@ export function Search() { const isLoadingRef = useLatest(isLoading) + const { selectedModel, isModelLoading, models } = useSelectedModel() + const currentMessageForDev = useMemo(() => { return messages.find(item => item.id === messageIdForDev) }, [messageIdForDev, messages]) @@ -376,6 +376,7 @@ export function Search() { if (initialMessage) { sessionStorage.removeItem(SESSION_STORAGE_KEY.SEARCH_INITIAL_MSG) sessionStorage.removeItem(SESSION_STORAGE_KEY.SEARCH_INITIAL_CONTEXTS) + setIsReady(true) onSubmitSearch(initialMessage, initialThreadRunContext) return @@ -571,7 +572,8 @@ export function Search() { { generateRelevantQuestions: true, codeQuery, - docQuery + docQuery, + modelName: ctx?.modelName } ) } @@ -638,7 +640,8 @@ export function Search() { threadRunOptions: { generateRelevantQuestions: true, codeQuery, - docQuery + docQuery, + modelName: selectedModel } }) } @@ -696,6 +699,10 @@ export function Search() { ) } + const onModelSelect = (model: string) => { + updateSelectedModel(model) + } + const hasThreadError = useMemo(() => { if (!isReady || fetchingThread || !threadIdFromURL) return undefined if (threadError || !threadData?.threads?.edges?.length) { @@ -867,10 +874,14 @@ export function Search() { onSearch={onSubmitSearch} className="min-h-[5rem] lg:max-w-4xl" placeholder="Ask a follow up question" - isLoading={isLoading} isFollowup + isLoading={isLoading} contextInfo={contextInfoData?.contextInfo} fetchingContextInfo={fetchingContextInfo} + modelName={selectedModel} + onModelSelect={onModelSelect} + isModelLoading={isModelLoading} + models={models} /> )} diff --git a/ee/tabby-ui/components/textarea-search.tsx b/ee/tabby-ui/components/textarea-search.tsx index 891ecc870a25..db7106736c72 100644 --- a/ee/tabby-ui/components/textarea-search.tsx +++ b/ee/tabby-ui/components/textarea-search.tsx @@ -2,6 +2,7 @@ import { useEffect, useMemo, useRef, useState } from 'react' import { Editor } from '@tiptap/react' +import { Maybe } from 'graphql/jsutils/Maybe' import { ContextInfo } from '@/lib/gql/generates/graphql' import { useCurrentTheme } from '@/lib/hooks/use-current-theme' @@ -12,19 +13,37 @@ import { getMentionsFromText, getThreadRunContextsFromMentions } from '@/lib/utils' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuRadioGroup, + DropdownMenuRadioItem, + DropdownMenuTrigger +} from '@/components/ui/dropdown-menu' import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip' +import LoadingWrapper from './loading-wrapper' import { PromptEditor, PromptEditorRef } from './prompt-editor' import { Button } from './ui/button' -import { IconArrowRight, IconAtSign, IconHash, IconSpinner } from './ui/icons' +import { + IconArrowRight, + IconAtSign, + IconBox, + IconCheck, + IconHash, + IconSpinner +} from './ui/icons' import { Separator } from './ui/separator' +import { Skeleton } from './ui/skeleton' export default function TextAreaSearch({ onSearch, + onModelSelect, + modelName, className, placeholder, showBetaBadge, @@ -34,9 +53,14 @@ export default function TextAreaSearch({ cleanAfterSearch = true, isFollowup, contextInfo, - fetchingContextInfo + fetchingContextInfo, + isModelLoading, + models }: { onSearch: (value: string, ctx: ThreadRunContexts) => void + onModelSelect: (v: string) => void + isModelLoading: boolean + modelName: string | undefined className?: string placeholder?: string showBetaBadge?: boolean @@ -48,11 +72,11 @@ export default function TextAreaSearch({ contextInfo?: ContextInfo fetchingContextInfo: boolean onValueChange?: (value: string | undefined) => void + models: Maybe> | undefined }) { const [isShow, setIsShow] = useState(false) const [isFocus, setIsFocus] = useState(false) const [value, setValue] = useState('') - const { theme } = useCurrentTheme() const editorRef = useRef(null) useEffect(() => { @@ -60,12 +84,23 @@ export default function TextAreaSearch({ setIsShow(true) }, []) - const onWrapperClick = () => { + const focusTextarea = () => { editorRef.current?.editor?.commands.focus() } + const onWrapperClick = () => { + focusTextarea() + } + + const handleSelectModel = (v: string) => { + onModelSelect(v) + setTimeout(() => { + focusTextarea() + }) + } + const handleSubmit = (editor: Editor | undefined | null) => { - if (!editor || isLoading) { + if (!editor || isLoading || isModelLoading) { return } @@ -73,7 +108,10 @@ export default function TextAreaSearch({ if (!text) return const mentions = getMentionsFromText(text, contextInfo?.sources) - const ctx = getThreadRunContextsFromMentions(mentions) + const ctx: ThreadRunContexts = { + ...getThreadRunContextsFromMentions(mentions), + modelName + } // do submit onSearch(text, ctx) @@ -81,6 +119,7 @@ export default function TextAreaSearch({ // clear content if (cleanAfterSearch) { editorRef.current?.editor?.chain().clearContent().focus().run() + setValue('') } } @@ -111,6 +150,8 @@ export default function TextAreaSearch({ return checkSourcesAvailability(contextInfo?.sources) }, [contextInfo?.sources]) + const showModelSelect = !!models?.length + return (
- {showBetaBadge && ( - - - - Beta - - - -

- Please note that the answer engine is still in its early stages, - and certain functionalities, such as finding the correct code - context and the quality of summarizations, still have room for - improvement. If you encounter an issue and believe it can be - enhanced, consider sharing it in our Slack community! -

-
-
- )} + {showBetaBadge && } +
- setIsFocus(true)} - onBlur={() => setIsFocus(false)} - onUpdate={({ editor }) => setValue(editor.getText().trim())} - ref={editorRef} - placement={isFollowup ? 'bottom' : 'top'} - className={cn( - 'text-area-autosize mr-1 flex-1 resize-none rounded-lg !border-none bg-transparent !shadow-none !outline-none !ring-0 !ring-offset-0', - { - '!h-[48px]': !isShow, - 'py-3': !showBetaBadge, - 'py-4': showBetaBadge +
+ setIsFocus(true)} + onBlur={() => setIsFocus(false)} + onUpdate={({ editor }) => setValue(editor.getText().trim())} + ref={editorRef} + placement={isFollowup ? 'bottom' : 'top'} + className={cn( + 'text-area-autosize resize-none rounded-lg !border-none bg-transparent !shadow-none !outline-none !ring-0 !ring-offset-0', + { + '!h-[48px]': !isShow && !isFollowup, + '!h-[24px]': !isShow && isFollowup, + 'py-3': !showBetaBadge, + 'py-4': showBetaBadge + } + )} + editorClassName={ + isFollowup && showModelSelect + ? 'min-h-[1.75rem]' + : 'min-h-[3.5em]' } + /> + {isFollowup && showModelSelect && ( +
+ +
)} - // editorClassName={isFollowup ? 'min-h-[3.45rem]' : 'min-h-[3.5em]'} - editorClassName="min-h-[3.5em]" - /> -
+
+
0, '!bg-muted !text-primary !cursor-default': - isLoading || value.length === 0, + isLoading || value.length === 0 || isModelLoading, 'mr-1.5': !showBetaBadge - // 'mb-4': !showBetaBadge, - // 'mb-5': showBetaBadge } )} onClick={() => handleSubmit(editorRef.current?.editor)} @@ -197,60 +233,179 @@ export default function TextAreaSearch({
-
e.stopPropagation()} - > - {/* llm select */} - {/* - */} + + +
+ } + > + {/* mention codebase */} + + + + + + Select a codebase to chat with + + + - - - - - - Select a codebase to chat with - - + {/* mention docs */} + + + + + + Select a document to bring into context + + + + {/* model select */} + {!!models?.length && ( + <> + + + + )} + +
+ )} +
+ ) +} + +interface ModelSelectProps { + models: Maybe> | undefined + value: string | undefined + onChange: (v: string) => void + isInitializing?: boolean +} + +function ModelSelect({ + models, + value, + onChange, + isInitializing +}: ModelSelectProps) { + const onModelSelect = (v: string) => { + onChange(v) + } - - - + return ( + + + + } + > + {!!models?.length && ( + + - - - Select a document to bring into context - - - - + + + + {models.map(model => { + const isSelected = model === value + return ( + { + onModelSelect(model) + e.stopPropagation() + }} + value={model} + key={model} + className="cursor-pointer py-2 pl-3" + > + + + {model} + + + ) + })} + + + + )} + + ) +} + +function BetaBadge() { + const { theme } = useCurrentTheme() + return ( + + + + Beta + + + +

+ Please note that the answer engine is still in its early stages, and + certain functionalities, such as finding the correct code context and + the quality of summarizations, still have room for improvement. If you + encounter an issue and believe it can be enhanced, consider sharing it + in our Slack community! +

+
+
) } diff --git a/ee/tabby-ui/components/ui/dropdown-menu.tsx b/ee/tabby-ui/components/ui/dropdown-menu.tsx index 184d4e6007ef..c22a45768fdb 100644 --- a/ee/tabby-ui/components/ui/dropdown-menu.tsx +++ b/ee/tabby-ui/components/ui/dropdown-menu.tsx @@ -17,6 +17,8 @@ const DropdownMenuSub = DropdownMenuPrimitive.Sub const DropdownMenuRadioGroup = DropdownMenuPrimitive.RadioGroup +const DropdownMenuIndicator = DropdownMenuPrimitive.ItemIndicator + const DropdownMenuSubContent = React.forwardRef< React.ElementRef, React.ComponentPropsWithoutRef @@ -69,6 +71,24 @@ const DropdownMenuItem = React.forwardRef< )) DropdownMenuItem.displayName = DropdownMenuPrimitive.Item.displayName +const DropdownMenuRadioItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, ...props }, ref) => ( + +)) +DropdownMenuRadioItem.displayName = DropdownMenuPrimitive.RadioItem.displayName + const DropdownMenuLabel = React.forwardRef< React.ElementRef, React.ComponentPropsWithoutRef & { @@ -124,5 +144,7 @@ export { DropdownMenuPortal, DropdownMenuSub, DropdownMenuSubContent, - DropdownMenuRadioGroup + DropdownMenuRadioGroup, + DropdownMenuRadioItem, + DropdownMenuIndicator } diff --git a/ee/tabby-ui/lib/hooks/use-models.tsx b/ee/tabby-ui/lib/hooks/use-models.tsx new file mode 100644 index 000000000000..47c0ee7593a6 --- /dev/null +++ b/ee/tabby-ui/lib/hooks/use-models.tsx @@ -0,0 +1,72 @@ +'use client' + +import { useEffect } from 'react' +import { Maybe } from 'graphql/jsutils/Maybe' +import useSWR, { SWRResponse } from 'swr' + +import fetcher from '@/lib/tabby/fetcher' + +import { updateSelectedModel } from '../stores/chat-actions' +import { useChatStore } from '../stores/chat-store' +import { useStore } from './use-store' + +export interface ModelInfo { + completion: Maybe> + chat: Maybe> +} + +export function useModels(): SWRResponse { + return useSWR( + '/v1beta/models', + (url: string) => { + return fetcher(url, { + errorHandler: () => { + throw new Error('Fetch supported model failed.') + } + }) + }, + { + shouldRetryOnError: false + } + ) +} + +export function useSelectedModel() { + const { data: modelData, isLoading: isFetchingModel } = useModels() + const isModelHydrated = useStore(useChatStore, state => state._hasHydrated) + + const selectedModel = useStore(useChatStore, state => state.selectedModel) + + // once model hydrated, try to init model + useEffect(() => { + if (isModelHydrated && !isFetchingModel) { + // check if current model is valid + const validModel = getModelFromModelInfo(selectedModel, modelData?.chat) + if (selectedModel !== validModel) { + updateSelectedModel(validModel) + } + } + }, [isModelHydrated, isFetchingModel]) + + return { + // fetching model data or trying to get selected model from localstorage + isModelLoading: isFetchingModel || !isModelHydrated, + selectedModel, + models: modelData?.chat + } +} + +export function getModelFromModelInfo( + model: string | undefined, + models: Maybe> | undefined +) { + if (!models?.length) return undefined + + const isValidModel = !!model && models.includes(model) + if (isValidModel) { + return model + } + + // return the first model by default + return models[0] +} diff --git a/ee/tabby-ui/lib/stores/chat-actions.ts b/ee/tabby-ui/lib/stores/chat-actions.ts index 83b0ad511b18..17cd7b21b704 100644 --- a/ee/tabby-ui/lib/stores/chat-actions.ts +++ b/ee/tabby-ui/lib/stores/chat-actions.ts @@ -74,3 +74,10 @@ export const updateChat = (id: string, chat: Partial) => { }) })) } + +export const updateSelectedModel = (model: string | undefined) => { + set(state => ({ + ...state, + selectedModel: model + })) +} diff --git a/ee/tabby-ui/lib/stores/chat-store.ts b/ee/tabby-ui/lib/stores/chat-store.ts index f6168661b6b7..cf3c959ae633 100644 --- a/ee/tabby-ui/lib/stores/chat-store.ts +++ b/ee/tabby-ui/lib/stores/chat-store.ts @@ -11,12 +11,14 @@ export interface ChatState { activeChatId: string | undefined _hasHydrated: boolean setHasHydrated: (state: boolean) => void + selectedModel: string | undefined } const initialState: Omit = { _hasHydrated: false, chats: undefined, - activeChatId: nanoid() + activeChatId: nanoid(), + selectedModel: undefined } export const useChatStore = create()( diff --git a/ee/tabby-ui/lib/types/chat.ts b/ee/tabby-ui/lib/types/chat.ts index 1c0f597f565d..0d6bbe8b74a7 100644 --- a/ee/tabby-ui/lib/types/chat.ts +++ b/ee/tabby-ui/lib/types/chat.ts @@ -67,6 +67,7 @@ type MergeUnionType = { } export type ThreadRunContexts = { + modelName?: string searchPublic?: boolean docSourceIds?: string[] codeSourceIds?: string[] diff --git a/ee/tabby-webserver/src/service/answer.rs b/ee/tabby-webserver/src/service/answer.rs index c0276bcb034b..3813d7016283 100644 --- a/ee/tabby-webserver/src/service/answer.rs +++ b/ee/tabby-webserver/src/service/answer.rs @@ -154,12 +154,12 @@ impl AnswerService { CreateChatCompletionRequestArgs::default() .messages(chat_messages) + .model(options.model_name.as_deref().unwrap_or("")) .presence_penalty(self.config.presence_penalty) .build() .expect("Failed to build chat completion request") }; - let s = match self.chat.chat_stream(request).await { Ok(s) => s, Err(err) => { @@ -1066,6 +1066,7 @@ mod tests { ), ]; let options = ThreadRunOptionsInput { + model_name: None, code_query: Some(make_code_query_input( Some(TEST_SOURCE_ID), Some(TEST_GIT_URL),