From 12b91fd90ed86ba2dd5ced6acb61250d212646bf Mon Sep 17 00:00:00 2001 From: cdxker Date: Fri, 23 Aug 2024 19:14:24 -0700 Subject: [PATCH] feature: dashboard makes call to /embedding_models route --- .../src/components/NewDatasetModal.tsx | 120 +++++++++++------- .../Dashboard/Dataset/DatasetSettingsPage.tsx | 26 +++- frontends/dashboard/src/types/apiTypes.ts | 33 ----- frontends/shared/types.ts | 33 ----- server/src/handlers/chunk_handler.rs | 52 ++++++++ server/src/lib.rs | 3 + 6 files changed, 151 insertions(+), 116 deletions(-) diff --git a/frontends/dashboard/src/components/NewDatasetModal.tsx b/frontends/dashboard/src/components/NewDatasetModal.tsx index 2e45b4d130..63665f495a 100644 --- a/frontends/dashboard/src/components/NewDatasetModal.tsx +++ b/frontends/dashboard/src/components/NewDatasetModal.tsx @@ -1,4 +1,5 @@ -import { Accessor, createMemo, createSignal, useContext, For } from "solid-js"; +import { Accessor, createMemo, createSignal, createEffect, useContext, For } from "solid-js"; +import { AiOutlineWarning } from "solid-icons/ai"; import { Dialog, DialogPanel, @@ -12,7 +13,6 @@ import { useNavigate } from "@solidjs/router"; import { ServerEnvsConfiguration, availableDistanceMetrics, - availableEmbeddingModels, } from "shared/types"; import { defaultServerEnvsConfiguration } from "../pages/Dashboard/Dataset/DatasetSettingsPage"; import { createToast } from "./ShowToasts"; @@ -35,6 +35,20 @@ export const NewDatasetModal = (props: NewDatasetModalProps) => { const [isLoading, setIsLoading] = createSignal(false); const [fillWithExampleData, setFillWithExampleData] = createSignal(false); + const api_host = import.meta.env.VITE_API_HOST as unknown as string; + + const [availableEmbeddingModels, setAvailableEmbeddingModels] = + createSignal([]); + + createEffect(() => { + fetch(`${api_host}/embedding_models`) + .then((resp) => resp.json()) + .then((json) => { + console.log(json.models); + setAvailableEmbeddingModels(json.models) + }) + }); + const selectedOrgnaization = createMemo(() => { const selectedOrgId = userContext.selectedOrganizationId?.(); if (!selectedOrgId) return null; @@ -186,51 +200,69 @@ export const NewDatasetModal = (props: NewDatasetModalProps) => { > Embedding Model - + // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access + model.id === + serverConfig().EMBEDDING_MODEL_NAME, // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - model.id === - serverConfig().EMBEDDING_MODEL_NAME, - // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - )?.name ?? availableEmbeddingModels[0].name - } - onChange={(e) => { - const selectedModel = availableEmbeddingModels.find( - (model) => model.name === e.currentTarget.value, - ); + )?.display_name ?? + availableEmbeddingModels()[0]?.display_name + } + onChange={(e) => { + const selectedModel = + availableEmbeddingModels()?.find( + (model) => + model.display_name === e.currentTarget.value, + ); - const embeddingSize = - selectedModel?.dimension ?? 1536; + const embeddingSize = + selectedModel?.dimension ?? 1536; - setServerConfig((prev) => { - return { - ...prev, - EMBEDDING_SIZE: embeddingSize, - EMBEDDING_MODEL_NAME: - selectedModel?.id ?? "jina-base-en", - EMBEDDING_QUERY_PREFIX: - selectedModel?.id === "jina-base-en" - ? "Search for:" - : "", - EMBEDDING_BASE_URL: - selectedModel?.url ?? - "https://api.openai.com/v1", - }; - }); - }} - > - - {(model) => ( - - )} - - + setServerConfig((prev) => { + return { + ...prev, + EMBEDDING_SIZE: embeddingSize, + EMBEDDING_MODEL_NAME: + selectedModel?.id ?? "jina-base-en", + EMBEDDING_QUERY_PREFIX: + selectedModel?.id === "jina-base-en" + ? "Search for:" + : "", + EMBEDDING_BASE_URL: + selectedModel?.url ?? + "https://api.openai.com/v1", + }; + }); + }} + > + + {(model) => ( + + )} + + + + +
+ +
+
+ No Embedding Models available +
+
+ Check server settings +
+
+
+
diff --git a/frontends/dashboard/src/pages/Dashboard/Dataset/DatasetSettingsPage.tsx b/frontends/dashboard/src/pages/Dashboard/Dataset/DatasetSettingsPage.tsx index 01ac79dd44..233c39ac81 100644 --- a/frontends/dashboard/src/pages/Dashboard/Dataset/DatasetSettingsPage.tsx +++ b/frontends/dashboard/src/pages/Dashboard/Dataset/DatasetSettingsPage.tsx @@ -12,7 +12,6 @@ import { DatasetContext } from "../../../contexts/DatasetContext"; import { ServerEnvsConfiguration, availableDistanceMetrics, - availableEmbeddingModels, } from "shared/types"; import { createToast } from "../../../components/ShowToasts"; import { AiOutlineInfoCircle } from "solid-icons/ai"; @@ -56,6 +55,21 @@ export const ServerSettingsForm = (props: { config: (prev: ServerEnvsConfiguration) => ServerEnvsConfiguration, ) => void; }) => { + + const api_host = import.meta.env.VITE_API_HOST as unknown as string; + + const [availableEmbeddingModels, setAvailableEmbeddingModels] = + createSignal([]); + + createEffect(() => { + fetch(`${api_host}/embedding_models`) + .then((resp) => resp.json()) + .then((json) => { + console.log(json.models); + setAvailableEmbeddingModels(json.models) + }) + }); + return (
{/* General LLM Settings */} @@ -554,14 +568,14 @@ export const ServerSettingsForm = (props: { name="embeddingSize" class="col-span-2 block w-full cursor-not-allowed rounded-md border-[0.5px] border-neutral-300 bg-white px-3 py-1.5 shadow-sm placeholder:text-neutral-400 focus:outline-magenta-500 sm:text-sm sm:leading-6" value={ - availableEmbeddingModels.find( + availableEmbeddingModels().find( (model) => model.id === props.serverConfig().EMBEDDING_MODEL_NAME, - )?.name ?? availableEmbeddingModels[0].name + )?.display_name ?? availableEmbeddingModels()[0]?.display_name } > - - {(model) => } + + {(model) => }
@@ -622,7 +636,7 @@ export const ServerSettingsForm = (props: { availableDistanceMetrics.find( (metric) => metric.id === props.serverConfig().DISTANCE_METRIC, - )?.name ?? availableEmbeddingModels[0].name + )?.name ?? availableDistanceMetrics[0].name } > diff --git a/frontends/dashboard/src/types/apiTypes.ts b/frontends/dashboard/src/types/apiTypes.ts index c52c7d714c..492e39f32f 100644 --- a/frontends/dashboard/src/types/apiTypes.ts +++ b/frontends/dashboard/src/types/apiTypes.ts @@ -256,39 +256,6 @@ export interface ApiKeyRespBody { updated_at: string; } -export const availableEmbeddingModels = [ - { - id: "jina-base-en", - name: "jina-base-en (securely hosted by Trieve)", - url: "https://embedding.trieve.ai", - dimension: 768, - }, - { - id: "bge-m3", - name: "bge-m3 (securely hosted by Trieve)", - url: "https://embedding.trieve.ai/bge-m3", - dimension: 1024, - }, - { - id: "jina-embeddings-v2-base-code", - name: "jina-embeddings-v2-base-code (securely hosted by Trieve)", - url: "https://embedding.trieve.ai/jina-code", - dimension: 768, - }, - { - id: "text-embedding-3-small", - name: "text-embedding-3-small (hosted by OpenAI)", - url: "https://api.openai.com/v1", - dimension: 1536, - }, - { - id: "text-embedding-3-large", - name: "text-embedding-3-large (hosted by OpenAI)", - url: "https://api.openai.com/v1", - dimension: 3072, - }, -]; - export interface EventDTO { events: Event[]; page_count: number; diff --git a/frontends/shared/types.ts b/frontends/shared/types.ts index ec61aea032..812191c9e1 100644 --- a/frontends/shared/types.ts +++ b/frontends/shared/types.ts @@ -266,39 +266,6 @@ export interface ApiKeyRespBody { updated_at: string; } -export const availableEmbeddingModels = [ - { - id: "jina-base-en", - name: "jina-base-en (securely hosted by Trieve)", - url: "https://embedding.trieve.ai", - dimension: 768, - }, - { - id: "bge-m3", - name: "bge-m3 (securely hosted by Trieve)", - url: "https://embedding.trieve.ai/bge-m3", - dimension: 1024, - }, - { - id: "jina-embeddings-v2-base-code", - name: "jina-embeddings-v2-base-code (securely hosted by Trieve)", - url: "https://embedding.trieve.ai/jina-code", - dimension: 768, - }, - { - id: "text-embedding-3-small", - name: "text-embedding-3-small (hosted by OpenAI)", - url: "https://api.openai.com/v1", - dimension: 1536, - }, - { - id: "text-embedding-3-large", - name: "text-embedding-3-large (hosted by OpenAI)", - url: "https://api.openai.com/v1", - dimension: 3072, - }, -]; - export const availableDistanceMetrics = [ { id: "cosine", diff --git a/server/src/handlers/chunk_handler.rs b/server/src/handlers/chunk_handler.rs index b588bb460c..2c3f960627 100644 --- a/server/src/handlers/chunk_handler.rs +++ b/server/src/handlers/chunk_handler.rs @@ -2634,3 +2634,55 @@ pub fn check_completion_param_validity( Ok(()) } + +#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)] +pub struct AvailableModelResponse { + models: Vec, +} + +#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)] +pub struct EmbeddingModel { + id: String, + display_name: String, + url: String, + dimension: u32, +} + +pub async fn get_available_models() -> Result { + let avail_models = AvailableModelResponse { + models: vec![ + EmbeddingModel { + id: "jina-base-en".into(), + display_name: "jina-base-en (securely hosted by Trieve)".into(), + url: "https://embedding.trieve.ai".into(), + dimension: 768, + }, + EmbeddingModel { + id: "bge-m3".into(), + display_name: "bge-m3 (securely hosted by Trieve)".into(), + url: "https://embedding.trieve.ai/bge-m3".into(), + dimension: 1024, + }, + EmbeddingModel { + id: "jina-embeddings-v2-base-code".into(), + display_name: "jina-embeddings-v2-base-code (securely hosted by Trieve)".into(), + url: "https://embedding.trieve.ai/jina-code".into(), + dimension: 768, + }, + EmbeddingModel { + id: "text-embedding-3-small".into(), + display_name: "text-embedding-3-small (hosted by OpenAI)".into(), + url: "https://api.openai.com/v1".into(), + dimension: 1536, + }, + EmbeddingModel { + id: "text-embedding-3-large".into(), + display_name: "text-embedding-3-large (hosted by OpenAI)".into(), + url: "https://api.openai.com/v1".into(), + dimension: 3072, + }, + ], + }; + + Ok(HttpResponse::Ok().json(avail_models)) +} diff --git a/server/src/lib.rs b/server/src/lib.rs index 7f882a6a2e..27eb1c320a 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -1127,6 +1127,9 @@ pub fn main() -> std::io::Result<()> { .route(web::put().to(handlers::analytics_handler::send_ctr_data)) .route(web::post().to(handlers::analytics_handler::get_ctr_analytics)), ) + ).service( + web::resource("/embedding_models") + .route(web::get().to(handlers::chunk_handler::get_available_models)) ), ) })