Skip to content

Commit

Permalink
Merge pull request ChatGPTNextWeb#5180 from frostime/contrib-modellist
Browse files Browse the repository at this point in the history
✨ feat: 调整模型列表,将自定义模型放在前面显示
  • Loading branch information
Dogtiti authored Aug 5, 2024
2 parents aa40015 + 3486954 commit a6b7432
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 4 deletions.
2 changes: 2 additions & 0 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ export interface LLMModel {
displayName?: string;
available: boolean;
provider: LLMModelProvider;
sorted: number;
}

export interface LLMModelProvider {
id: string;
providerName: string;
providerType: string;
sorted: number;
}

export abstract class LLMApi {
Expand Down
4 changes: 4 additions & 0 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,17 @@ export class ChatGPTApi implements LLMApi {
return [];
}

//由于目前 OpenAI 的 disableListModels 默认为 true,所以当前实际不会运行到这场
let seq = 1000; //同 Constant.ts 中的排序保持一致
return chatModels.map((m) => ({
name: m.id,
available: true,
sorted: seq++,
provider: {
id: "openai",
providerName: "OpenAI",
providerType: "openai",
sorted: 1,
},
}));
}
Expand Down
19 changes: 19 additions & 0 deletions app/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -320,86 +320,105 @@ const tencentModels = [

const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"];

let seq = 1000; // 内置的模型序号生成器从1000开始
export const DEFAULT_MODELS = [
...openaiModels.map((name) => ({
name,
available: true,
sorted: seq++, // Global sequence sort(index)
provider: {
id: "openai",
providerName: "OpenAI",
providerType: "openai",
sorted: 1, // 这里是固定的,确保顺序与之前内置的版本一致
},
})),
...openaiModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "azure",
providerName: "Azure",
providerType: "azure",
sorted: 2,
},
})),
...googleModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "google",
providerName: "Google",
providerType: "google",
sorted: 3,
},
})),
...anthropicModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "anthropic",
providerName: "Anthropic",
providerType: "anthropic",
sorted: 4,
},
})),
...baiduModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "baidu",
providerName: "Baidu",
providerType: "baidu",
sorted: 5,
},
})),
...bytedanceModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "bytedance",
providerName: "ByteDance",
providerType: "bytedance",
sorted: 6,
},
})),
...alibabaModes.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "alibaba",
providerName: "Alibaba",
providerType: "alibaba",
sorted: 7,
},
})),
...tencentModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "tencent",
providerName: "Tencent",
providerType: "tencent",
sorted: 8,
},
})),
...moonshotModes.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "moonshot",
providerName: "Moonshot",
providerType: "moonshot",
sorted: 9,
},
})),
] as const;
Expand Down
48 changes: 44 additions & 4 deletions app/utils/model.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
import { DEFAULT_MODELS } from "../constant";
import { LLMModel } from "../client/api";

const CustomSeq = {
val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts
cache: new Map<string, number>(),
next: (id: string) => {
if (CustomSeq.cache.has(id)) {
return CustomSeq.cache.get(id) as number;
} else {
let seq = CustomSeq.val++;
CustomSeq.cache.set(id, seq);
return seq;
}
},
};

const customProvider = (providerName: string) => ({
id: providerName.toLowerCase(),
providerName: providerName,
providerType: "custom",
sorted: CustomSeq.next(providerName),
});

/**
* Sorts an array of models based on specified rules.
*
* First, sorted by provider; if the same, sorted by model
*/
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
models.sort((a, b) => {
if (a.provider && b.provider) {
let cmp = a.provider.sorted - b.provider.sorted;
return cmp === 0 ? a.sorted - b.sorted : cmp;
} else {
return a.sorted - b.sorted;
}
});

export function collectModelTable(
models: readonly LLMModel[],
customModels: string,
Expand All @@ -17,6 +47,7 @@ export function collectModelTable(
available: boolean;
name: string;
displayName: string;
sorted: number;
provider?: LLMModel["provider"]; // Marked as optional
isDefault?: boolean;
}
Expand Down Expand Up @@ -84,6 +115,7 @@ export function collectModelTable(
displayName: displayName || customModelName,
available,
provider, // Use optional chaining
sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
};
}
}
Expand All @@ -99,13 +131,16 @@ export function collectModelTableWithDefaultModel(
) {
let modelTable = collectModelTable(models, customModels);
if (defaultModel && defaultModel !== "") {
if (defaultModel.includes('@')) {
if (defaultModel.includes("@")) {
if (defaultModel in modelTable) {
modelTable[defaultModel].isDefault = true;
}
} else {
for (const key of Object.keys(modelTable)) {
if (modelTable[key].available && key.split('@').shift() == defaultModel) {
if (
modelTable[key].available &&
key.split("@").shift() == defaultModel
) {
modelTable[key].isDefault = true;
break;
}
Expand All @@ -123,7 +158,9 @@ export function collectModels(
customModels: string,
) {
const modelTable = collectModelTable(models, customModels);
const allModels = Object.values(modelTable);
let allModels = Object.values(modelTable);

allModels = sortModelTable(allModels);

return allModels;
}
Expand All @@ -138,7 +175,10 @@ export function collectModelsWithDefaultModel(
customModels,
defaultModel,
);
const allModels = Object.values(modelTable);
let allModels = Object.values(modelTable);

allModels = sortModelTable(allModels);

return allModels;
}

Expand Down

0 comments on commit a6b7432

Please sign in to comment.