Skip to content

Commit

Permalink
feat(ui): allow selecting model in answer engine (#3304)
Browse files Browse the repository at this point in the history
* llm select FE sketch

* fetch model array & select model drop down

* finish model select function & to be polished

* fix selectedModel init value and style

* fix lint

* rename DropdownMenuItems

* refine feature toggle

* chore(answer): set chat model's name by selection

* refine model select dropdown position

* [autofix.ci] apply automated fixes

* properly fill model name

* rename selectedModelName to modelName

* uplift modelName state

* clean log

* format code

* fix lint

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* [autofix.ci] apply automated fixes (attempt 3/3)

* feat(ui): show toolbar in search page

* fix: fix ui test

* refactor(chat-ui): persist modelName and show modelName selection in followup

* handle selectedModel not in models api

* fix: throw warning when request model is not in supported_models

* fix: fix modelInfo?.chat not supported case and check request.model is supported in BE

* fix: using warn! to print warning

* fix: fix ui lint

* update: adjust style and check if model is valid

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Meng Zhang <[email protected]>
Co-authored-by: liangfung <[email protected]>
  • Loading branch information
4 people authored Nov 5, 2024
1 parent 8e29952 commit 28a5328
Show file tree
Hide file tree
Showing 17 changed files with 446 additions and 147 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/http-api-bindings/src/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
let mut builder = ExtendedOpenAIConfig::builder();
builder
.base(config)
.supported_models(model.supported_models.clone())
.model_name(model.model_name.as_deref().expect("Model name is required"));

if model.kind == "openai/chat" {
Expand Down
1 change: 1 addition & 0 deletions crates/tabby-inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ trie-rs = "0.1.1"
async-openai.workspace = true
secrecy = "0.8"
reqwest.workspace = true
tracing.workspace = true
16 changes: 15 additions & 1 deletion crates/tabby-inference/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use async_openai::{
};
use async_trait::async_trait;
use derive_builder::Builder;
use tracing::warn;

#[async_trait]
pub trait ChatCompletionStream: Sync + Send {
Expand Down Expand Up @@ -34,6 +35,9 @@ pub struct ExtendedOpenAIConfig {
#[builder(setter(into))]
model_name: String,

#[builder(setter(into))]
supported_models: Option<Vec<String>>,

#[builder(default)]
fields_to_remove: Vec<OpenAIRequestFieldEnum>,
}
Expand All @@ -54,7 +58,17 @@ impl ExtendedOpenAIConfig {
&self,
mut request: CreateChatCompletionRequest,
) -> CreateChatCompletionRequest {
request.model = self.model_name.clone();
if request.model.is_empty() {
request.model = self.model_name.clone();
} else if let Some(supported_models) = &self.supported_models {
if !supported_models.contains(&request.model) {
warn!(
"Warning: {} model is not supported, falling back to {}",
request.model, self.model_name
);
request.model = self.model_name.clone();
}
}

for field in &self.fields_to_remove {
match field {
Expand Down
5 changes: 3 additions & 2 deletions ee/tabby-schema/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ input CreateMessageInput {

input CreateThreadAndRunInput {
thread: CreateThreadInput!
options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false}
options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false, modelName: null}
}

input CreateThreadInput {
Expand All @@ -144,7 +144,7 @@ input CreateThreadInput {
input CreateThreadRunInput {
threadId: ID!
additionalUserMessage: CreateMessageInput!
options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false}
options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false, modelName: null}
}

input CreateUserGroupInput {
Expand Down Expand Up @@ -216,6 +216,7 @@ input ThreadRunDebugOptionsInput {
}

input ThreadRunOptionsInput {
modelName: String = null
docQuery: DocQueryInput = null
codeQuery: CodeQueryInput = null
generateRelevantQuestions: Boolean! = false
Expand Down
3 changes: 3 additions & 0 deletions ee/tabby-schema/src/schema/thread/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ fn validate_code_query_input(input: &CodeQueryInput) -> Result<(), ValidationErr

#[derive(GraphQLInputObject, Validate, Default, Clone)]
pub struct ThreadRunOptionsInput {
#[graphql(default)]
pub model_name: Option<String>,

#[validate(nested)]
#[graphql(default)]
pub doc_query: Option<DocQueryInput>,
Expand Down
12 changes: 12 additions & 0 deletions ee/tabby-ui/app/(home)/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import { useQuery } from 'urql'
import { SESSION_STORAGE_KEY } from '@/lib/constants'
import { useHealth } from '@/lib/hooks/use-health'
import { useMe } from '@/lib/hooks/use-me'
import { useSelectedModel } from '@/lib/hooks/use-models'
import { useIsChatEnabled } from '@/lib/hooks/use-server-info'
import { useStore } from '@/lib/hooks/use-store'
import { updateSelectedModel } from '@/lib/stores/chat-actions'
import {
clearHomeScrollPosition,
setHomeScrollPosition,
Expand Down Expand Up @@ -47,6 +49,8 @@ function MainPanel() {
})
const scrollY = useStore(useScrollStore, state => state.homePage)

const { selectedModel, isModelLoading, models } = useSelectedModel()

// Prefetch the search page
useEffect(() => {
router.prefetch('/search')
Expand All @@ -69,6 +73,10 @@ function MainPanel() {
resettingScroller.current = true
}, [])

const handleSelectModel = (model: string) => {
updateSelectedModel(model)
}

if (!healthInfo || !data?.me) return <></>

const onSearch = (question: string, ctx?: ThreadRunContexts) => {
Expand Down Expand Up @@ -138,6 +146,10 @@ function MainPanel() {
cleanAfterSearch={false}
contextInfo={contextInfoData?.contextInfo}
fetchingContextInfo={fetchingContextInfo}
modelName={selectedModel}
onModelSelect={handleSelectModel}
isModelLoading={isModelLoading}
models={models}
/>
</AnimationWrapper>
)}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
'use client'

import './search.css'

import { MouseEventHandler, useContext, useMemo, useState } from 'react'
import { zodResolver } from '@hookform/resolvers/zod'
import DOMPurify from 'dompurify'
Expand Down
3 changes: 0 additions & 3 deletions ee/tabby-ui/app/search/components/search.css

This file was deleted.

83 changes: 47 additions & 36 deletions ee/tabby-ui/app/search/components/search.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,48 @@ import {
useRef,
useState
} from 'react'
import Link from 'next/link'
import { useRouter } from 'next/navigation'
import slugify from '@sindresorhus/slugify'
import { compact, pick, some, uniq, uniqBy } from 'lodash-es'
import { nanoid } from 'nanoid'
import { ImperativePanelHandle } from 'react-resizable-panels'
import { toast } from 'sonner'
import { useQuery } from 'urql'

import {
ERROR_CODE_NOT_FOUND,
SESSION_STORAGE_KEY,
SLUG_TITLE_MAX_LENGTH
} from '@/lib/constants'
import { useEnableDeveloperMode } from '@/lib/experiment-flags'
import { graphql } from '@/lib/gql/generates'
import {
CodeQueryInput,
ContextInfo,
DocQueryInput,
InputMaybe,
Maybe,
Message,
Role
} from '@/lib/gql/generates/graphql'
import { useCopyToClipboard } from '@/lib/hooks/use-copy-to-clipboard'
import { useCurrentTheme } from '@/lib/hooks/use-current-theme'
import { useDebounceValue } from '@/lib/hooks/use-debounce'
import { useLatest } from '@/lib/hooks/use-latest'
import { useMe } from '@/lib/hooks/use-me'
import { useSelectedModel } from '@/lib/hooks/use-models'
import useRouterStuff from '@/lib/hooks/use-router-stuff'
import { useIsChatEnabled } from '@/lib/hooks/use-server-info'
import { ExtendedCombinedError, useThreadRun } from '@/lib/hooks/use-thread-run'
import { updateSelectedModel } from '@/lib/stores/chat-actions'
import { clearHomeScrollPosition } from '@/lib/stores/scroll-store'
import { useMutation } from '@/lib/tabby/gql'
import {
contextInfoQuery,
listThreadMessages,
listThreads
} from '@/lib/tabby/query'
import {
AttachmentCodeItem,
AttachmentDocItem,
Expand All @@ -46,6 +76,7 @@ import {
ResizablePanelGroup
} from '@/components/ui/resizable'
import { ScrollArea } from '@/components/ui/scroll-area'
import { Separator } from '@/components/ui/separator'
import { ButtonScrollToBottom } from '@/components/button-scroll-to-bottom'
import { ClientOnly } from '@/components/client-only'
import { BANNER_HEIGHT, useShowDemoBanner } from '@/components/demo-banner'
Expand All @@ -54,39 +85,6 @@ import { ThemeToggle } from '@/components/theme-toggle'
import { MyAvatar } from '@/components/user-avatar'
import UserPanel from '@/components/user-panel'

import './search.css'

import Link from 'next/link'
import slugify from '@sindresorhus/slugify'
import { compact, pick, some, uniq, uniqBy } from 'lodash-es'
import { ImperativePanelHandle } from 'react-resizable-panels'
import { toast } from 'sonner'
import { useQuery } from 'urql'

import { graphql } from '@/lib/gql/generates'
import {
CodeQueryInput,
ContextInfo,
DocQueryInput,
InputMaybe,
Maybe,
Message,
Role
} from '@/lib/gql/generates/graphql'
import { useCopyToClipboard } from '@/lib/hooks/use-copy-to-clipboard'
import { useDebounceValue } from '@/lib/hooks/use-debounce'
import { useMe } from '@/lib/hooks/use-me'
import useRouterStuff from '@/lib/hooks/use-router-stuff'
import { ExtendedCombinedError, useThreadRun } from '@/lib/hooks/use-thread-run'
import { clearHomeScrollPosition } from '@/lib/stores/scroll-store'
import { useMutation } from '@/lib/tabby/gql'
import {
contextInfoQuery,
listThreadMessages,
listThreads
} from '@/lib/tabby/query'
import { Separator } from '@/components/ui/separator'

import { AssistantMessageSection } from './assistant-message-section'
import { DevPanel } from './dev-panel'
import { MessagesSkeleton } from './messages-skeleton'
Expand Down Expand Up @@ -319,6 +317,8 @@ export function Search() {

const isLoadingRef = useLatest(isLoading)

const { selectedModel, isModelLoading, models } = useSelectedModel()

const currentMessageForDev = useMemo(() => {
return messages.find(item => item.id === messageIdForDev)
}, [messageIdForDev, messages])
Expand Down Expand Up @@ -376,6 +376,7 @@ export function Search() {
if (initialMessage) {
sessionStorage.removeItem(SESSION_STORAGE_KEY.SEARCH_INITIAL_MSG)
sessionStorage.removeItem(SESSION_STORAGE_KEY.SEARCH_INITIAL_CONTEXTS)

setIsReady(true)
onSubmitSearch(initialMessage, initialThreadRunContext)
return
Expand Down Expand Up @@ -571,7 +572,8 @@ export function Search() {
{
generateRelevantQuestions: true,
codeQuery,
docQuery
docQuery,
modelName: ctx?.modelName
}
)
}
Expand Down Expand Up @@ -638,7 +640,8 @@ export function Search() {
threadRunOptions: {
generateRelevantQuestions: true,
codeQuery,
docQuery
docQuery,
modelName: selectedModel
}
})
}
Expand Down Expand Up @@ -696,6 +699,10 @@ export function Search() {
)
}

const onModelSelect = (model: string) => {
updateSelectedModel(model)
}

const hasThreadError = useMemo(() => {
if (!isReady || fetchingThread || !threadIdFromURL) return undefined
if (threadError || !threadData?.threads?.edges?.length) {
Expand Down Expand Up @@ -867,10 +874,14 @@ export function Search() {
onSearch={onSubmitSearch}
className="min-h-[5rem] lg:max-w-4xl"
placeholder="Ask a follow up question"
isLoading={isLoading}
isFollowup
isLoading={isLoading}
contextInfo={contextInfoData?.contextInfo}
fetchingContextInfo={fetchingContextInfo}
modelName={selectedModel}
onModelSelect={onModelSelect}
isModelLoading={isModelLoading}
models={models}
/>
</div>
)}
Expand Down
Loading

0 comments on commit 28a5328

Please sign in to comment.