diff --git a/src/helpers/completion.ts b/src/helpers/completion.ts index 85c1591..1b2be7b 100644 --- a/src/helpers/completion.ts +++ b/src/helpers/completion.ts @@ -1,4 +1,4 @@ -import { OpenAIApi, Configuration, ChatCompletionRequestMessage } from 'openai'; +import { OpenAIApi, Configuration, ChatCompletionRequestMessage, Model } from 'openai'; import dedent from 'dedent'; import { IncomingMessage } from 'http'; import { KnownError } from './error'; @@ -313,3 +313,10 @@ function getRevisionPrompt(prompt: string, code: string) { ${generationDetails} `; } + +export async function getModels(key: string, apiEndpoint: string): Promise { + const openAi = getOpenAi(key, apiEndpoint); + const response = await openAi.listModels(); + + return response.data.data.filter(model => model.object === 'model'); +} diff --git a/src/helpers/config.ts b/src/helpers/config.ts index a536270..603e8ff 100644 --- a/src/helpers/config.ts +++ b/src/helpers/config.ts @@ -8,6 +8,8 @@ import { KnownError, handleCliError } from './error'; import * as p from '@clack/prompts'; import { red } from 'kolorist'; import i18n from './i18n'; +import { getModels } from './completion'; +import { Model } from 'openai'; const { hasOwnProperty } = Object.prototype; export const hasOwn = (object: unknown, key: PropertyKey) => @@ -186,9 +188,18 @@ export const showConfigUI = async () => { if (p.isCancel(silentMode)) return; await setConfigs([['SILENT_MODE', silentMode ? 'true' : 'false']]); } else if (choice === 'MODEL') { - const model = await p.text({ - message: i18n.t('Enter the model you want to use'), - }); + const { + OPENAI_KEY: key, + OPENAI_API_ENDPOINT: apiEndpoint, + } = await getConfig(); + const models = await getModels(key, apiEndpoint); + const model = (await p.select({ + message: 'Pick a model.', + options: models.map((m: Model) => { + return { value: m.id, label: m.id }; + }) + })) as string; + if (p.isCancel(model)) return; await setConfigs([['MODEL', model]]); } else if (choice === 'LANGUAGE') {