diff --git a/lib/shared/src/models/index.test.ts b/lib/shared/src/models/index.test.ts index 30770cfb8377..97209bc80bb7 100644 --- a/lib/shared/src/models/index.test.ts +++ b/lib/shared/src/models/index.test.ts @@ -1,4 +1,5 @@ -import { beforeEach, describe, expect, it } from 'vitest' +import { afterEach, beforeEach, describe, expect, it } from 'vitest' +import { mockAuthStatus } from '../auth/authStatus' import { AUTH_STATUS_FIXTURE_AUTHED, type AuthenticatedAuthStatus } from '../auth/types' import { Model, @@ -39,6 +40,9 @@ describe('Model Provider', () => { beforeEach(() => { modelsService = new ModelsService() }) + afterEach(() => { + modelsService.dispose() + }) describe('getContextWindowByID', () => { it('returns default token limit for unknown model', () => { @@ -135,7 +139,7 @@ describe('Model Provider', () => { }) beforeEach(() => { - modelsService.setAuthStatus(codyProAuthStatus) + mockAuthStatus(codyProAuthStatus) modelsService.setModels([model1chat, model2chat, model3all, model4edit]) }) @@ -254,7 +258,7 @@ describe('Model Provider', () => { beforeEach(async () => { storage = new TestStorage() modelsService.setStorage(storage) - modelsService.setAuthStatus(enterpriseAuthStatus) + mockAuthStatus(enterpriseAuthStatus) await modelsService.setServerSentModels(SERVER_MODELS) }) @@ -348,37 +352,37 @@ describe('Model Provider', () => { }) it('returns false for unknown model', () => { - modelsService.setAuthStatus(codyProAuthStatus) + mockAuthStatus(codyProAuthStatus) expect(modelsService.isModelAvailable('unknown-model')).toBe(false) }) it('allows enterprise user to use any model', () => { - modelsService.setAuthStatus(enterpriseAuthStatus) + mockAuthStatus(enterpriseAuthStatus) expect(modelsService.isModelAvailable(enterpriseModel)).toBe(true) expect(modelsService.isModelAvailable(proModel)).toBe(true) expect(modelsService.isModelAvailable(freeModel)).toBe(true) }) it('allows Cody Pro user to use Pro and Free models', () => { - modelsService.setAuthStatus(codyProAuthStatus) + mockAuthStatus(codyProAuthStatus) expect(modelsService.isModelAvailable(enterpriseModel)).toBe(false) expect(modelsService.isModelAvailable(proModel)).toBe(true) expect(modelsService.isModelAvailable(freeModel)).toBe(true) }) it('allows free user to use only Free models', () => { - modelsService.setAuthStatus(freeUserAuthStatus) + mockAuthStatus(freeUserAuthStatus) expect(modelsService.isModelAvailable(enterpriseModel)).toBe(false) expect(modelsService.isModelAvailable(proModel)).toBe(false) expect(modelsService.isModelAvailable(freeModel)).toBe(true) }) it('handles model passed as string', () => { - modelsService.setAuthStatus(freeUserAuthStatus) + mockAuthStatus(freeUserAuthStatus) expect(modelsService.isModelAvailable(freeModel.id)).toBe(true) expect(modelsService.isModelAvailable(proModel.id)).toBe(false) - modelsService.setAuthStatus(codyProAuthStatus) + mockAuthStatus(codyProAuthStatus) expect(modelsService.isModelAvailable(proModel.id)).toBe(true) }) }) diff --git a/lib/shared/src/models/index.ts b/lib/shared/src/models/index.ts index 1e16faafaef5..2d8aa99945b3 100644 --- a/lib/shared/src/models/index.ts +++ b/lib/shared/src/models/index.ts @@ -1,8 +1,11 @@ import { type Observable, Subject } from 'observable-fns' +import { authStatus, currentAuthStatus } from '../auth/authStatus' import { type AuthStatus, isCodyProUser, isEnterpriseUser } from '../auth/types' -import { type ClientConfiguration, CodyIDE } from '../configuration' +import { CodyIDE } from '../configuration' +import { resolvedConfig } from '../configuration/resolver' import { fetchLocalOllamaModels } from '../llm-providers/ollama/utils' import { logDebug, logError } from '../logger' +import { type Unsubscribable, combineLatest } from '../misc/observable' import { setSingleton, singletonNotYetSet } from '../singletons' import { CHAT_INPUT_TOKEN_BUDGET, CHAT_OUTPUT_TOKEN_BUDGET } from '../token/constants' import { ModelTag } from './tags' @@ -341,14 +344,57 @@ export class ModelsService { /** persistent storage to save user preferences and server defaults */ private storage: Storage | undefined - /** current system auth status */ - private authStatus: AuthStatus | undefined - /** Cache of users preferences and defaults across each endpoint they have used */ private _preferences: PerSitePreferences | undefined private static STORAGE_KEY = 'model-preferences' + /** + * Needs to be set at static initialization time by the `vscode/` codebase. + */ + public static syncModels: ((authStatus: AuthStatus) => Promise) | undefined + + private configSubscription: Unsubscribable + + constructor() { + this.configSubscription = combineLatest([resolvedConfig, authStatus]).subscribe( + async ([{ configuration }, authStatus]) => { + try { + if (!ModelsService.syncModels) { + throw new Error( + 'ModelsService.syncModels must be set at static initialization time' + ) + } + await ModelsService.syncModels(authStatus) + } catch (error) { + logError('ModelsService', 'Failed to sync models', error) + } + + try { + const isCodyWeb = configuration.agentIDE === CodyIDE.Web + + // Disable Ollama local models for cody web + this.localModels = !isCodyWeb ? await fetchLocalOllamaModels() : [] + } catch { + this.localModels = [] + } finally { + this.changeNotifications.next() + } + } + ) + } + + public dispose(): void { + this.configSubscription.unsubscribe() + } + + private changeNotifications = new Subject() + + /** + * An observable that emits whenever the list of models or any model in the list changes. + */ + public readonly changes: Observable = this.changeNotifications + // Get all the providers currently available to the user private get models(): Model[] { return this.primaryModels.concat(this.localModels) @@ -361,7 +407,7 @@ export class ModelsService { defaults: {}, selected: {}, } - const endpoint = this.authStatus?.endpoint + const endpoint = currentAuthStatus().endpoint if (!endpoint) { if (!process.env.VITEST) { logError('ModelsService::preferences', 'No auth status set') @@ -385,21 +431,6 @@ export class ModelsService { return empty } - public async onConfigChange(config: ClientConfiguration): Promise { - try { - const isCodyWeb = config.agentIDE === CodyIDE.Web - - // Disable Ollama local models for cody web - this.localModels = !isCodyWeb ? await fetchLocalOllamaModels() : [] - } catch { - this.localModels = [] - } - } - - public async setAuthStatus(authStatus: AuthStatus) { - this.authStatus = authStatus - } - public setStorage(storage: Storage): void { this.storage = storage } @@ -410,6 +441,7 @@ export class ModelsService { public setModels(models: Model[]): void { logDebug('ModelsService', `Setting primary models: ${JSON.stringify(models.map(m => m.id))}`) this.primaryModels = models + this.changeNotifications.next() } /** @@ -440,14 +472,9 @@ export class ModelsService { await this.flush() } - private readonly _selectedOrDefaultModelChanges = new Subject() - public get selectedOrDefaultModelChanges(): Observable { - return this._selectedOrDefaultModelChanges - } - private async flush(): Promise { await this.storage?.set(ModelsService.STORAGE_KEY, JSON.stringify(this._preferences)) - this._selectedOrDefaultModelChanges.next() + this.changeNotifications.next() } /** @@ -463,6 +490,7 @@ export class ModelsService { ...this.primaryModels, ...models.filter(model => !existingIds.has(model.id)), ] + this.changeNotifications.next() } private getModelsByType(usage: ModelUsage): Model[] { @@ -524,7 +552,7 @@ export class ModelsService { } public isModelAvailable(model: string | Model): boolean { - const status = this.authStatus + const status = currentAuthStatus() if (!status) { return false } diff --git a/vscode/src/chat/chat-view/ChatController.ts b/vscode/src/chat/chat-view/ChatController.ts index 4d66e769c8d2..e4e262e64bca 100644 --- a/vscode/src/chat/chat-view/ChatController.ts +++ b/vscode/src/chat/chat-view/ChatController.ts @@ -270,6 +270,35 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv ) ) .subscribe({}) + ), + subscriptionDisposable( + authStatus.subscribe(() => { + // Run this async because this method may be called during initialization + // and awaiting on this.postMessage may result in a deadlock + void this.sendConfig() + }) + ), + subscriptionDisposable( + combineLatest([ + modelsService.instance!.changes.pipe(startWith(undefined)), + authStatus, + ]).subscribe(([, authStatus]) => { + // Get the latest model list available to the current user to update the ChatModel. + logError( + 'ChatController', + 'updated authStatus', + JSON.stringify({ + authStatus, + defaultModelID: getDefaultModelID(), + currentModelID: this.chatModel.modelID, + }) + ) + // TODO!(sqs): here, we need to make sure syncModels has already run after *it* + // reacted to the authStatus change + if (authStatus.authenticated) { + this.chatModel.updateModel(getDefaultModelID()) + } + }) ) ) @@ -590,17 +619,6 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv // #region top-level view action handlers // ======================================================================= - public setAuthStatus(status: AuthStatus): void { - // Run this async because this method may be called during initialization - // and awaiting on this.postMessage may result in a deadlock - void this.sendConfig() - - // Get the latest model list available to the current user to update the ChatModel. - if (status.authenticated) { - this.chatModel.updateModel(getDefaultModelID()) - } - } - // When the webview sends the 'ready' message, respond by posting the view config private async handleReady(): Promise { await this.sendConfig() @@ -1668,9 +1686,7 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv models: () => combineLatest([ resolvedConfig, - modelsService.instance!.selectedOrDefaultModelChanges.pipe( - startWith(undefined) - ), + modelsService.instance!.changes.pipe(startWith(undefined)), ]).pipe(map(() => modelsService.instance!.getModels(ModelUsage.Chat))), highlights: parameters => promiseFactoryToObservable(() => diff --git a/vscode/src/chat/chat-view/ChatsController.ts b/vscode/src/chat/chat-view/ChatsController.ts index 77105c1a670b..99180e13d00b 100644 --- a/vscode/src/chat/chat-view/ChatsController.ts +++ b/vscode/src/chat/chat-view/ChatsController.ts @@ -2,7 +2,6 @@ import * as uuid from 'uuid' import * as vscode from 'vscode' import { - type AuthStatus, type AuthenticatedAuthStatus, CODY_PASSTHROUGH_VSCODE_OPEN_COMMAND_ID, type ChatClient, @@ -83,20 +82,20 @@ export class ChatsController implements vscode.Disposable { this.panel = this.createChatController() this.disposables.push( - subscriptionDisposable(authStatus.subscribe(authStatus => this.setAuthStatus(authStatus))) - ) - } - - private async setAuthStatus(authStatus: AuthStatus): Promise { - const hasLoggedOut = !authStatus.authenticated - const hasSwitchedAccount = - this.currentAuthAccount && this.currentAuthAccount.endpoint !== authStatus.endpoint - if (hasLoggedOut || hasSwitchedAccount) { - this.disposeAllChats() - } + subscriptionDisposable( + authStatus.subscribe(authStatus => { + const hasLoggedOut = !authStatus.authenticated + const hasSwitchedAccount = + this.currentAuthAccount && + this.currentAuthAccount.endpoint !== authStatus.endpoint + if (hasLoggedOut || hasSwitchedAccount) { + this.disposeAllChats() + } - this.currentAuthAccount = authStatus.authenticated ? { ...authStatus } : undefined - this.panel.setAuthStatus(authStatus) + this.currentAuthAccount = authStatus.authenticated ? { ...authStatus } : undefined + }) + ) + ) } public async restoreToPanel(panel: vscode.WebviewPanel, chatID: string): Promise { diff --git a/vscode/src/commands/scm/source-control.ts b/vscode/src/commands/scm/source-control.ts index 210c22039096..f7ba702502a0 100644 --- a/vscode/src/commands/scm/source-control.ts +++ b/vscode/src/commands/scm/source-control.ts @@ -1,5 +1,4 @@ import { - type AuthStatus, type ChatClient, type ChatMessage, type CompletionGeneratorValue, @@ -13,6 +12,8 @@ import { getSimplePreamble, modelsService, pluralize, + startWith, + subscriptionDisposable, telemetryRecorder, } from '@sourcegraph/cody-shared' import * as vscode from 'vscode' @@ -33,7 +34,14 @@ export class CodySourceControl implements vscode.Disposable { // Register commands this.disposables.push( vscode.commands.registerCommand('cody.command.generate-commit', scm => this.generate(scm)), - vscode.commands.registerCommand('cody.command.abort-commit', () => this.statusUpdate()) + vscode.commands.registerCommand('cody.command.abort-commit', () => this.statusUpdate()), + subscriptionDisposable( + modelsService.instance!.changes.pipe(startWith(undefined)).subscribe(() => { + const models = modelsService.instance!.getModels(ModelUsage.Chat) + const preferredModel = models.find(p => p.id.includes('claude-3-haiku')) + this.model = preferredModel ?? models[0] + }) + ) ) this.initializeGitAPI() } @@ -236,12 +244,6 @@ export class CodySourceControl implements vscode.Disposable { this.abortController = abortController } - public setAuthStatus(_: AuthStatus): void { - const models = modelsService.instance!.getModels(ModelUsage.Chat) - const preferredModel = models.find(p => p.id.includes('claude-3-haiku')) - this.model = preferredModel ?? models[0] - } - public dispose(): void { for (const disposable of this.disposables) { if (disposable) { diff --git a/vscode/src/main.ts b/vscode/src/main.ts index a07fca5077ca..e1af0651ddc0 100644 --- a/vscode/src/main.ts +++ b/vscode/src/main.ts @@ -254,7 +254,6 @@ const register = async ( subscriptionDisposable( authStatus.subscribe({ next: authStatus => { - sourceControl.setAuthStatus(authStatus) statusBar.setAuthStatus(authStatus) }, }) @@ -315,7 +314,6 @@ async function initializeSingletons( void localStorage.setConfig(config) graphqlClient.setConfig(config) void featureFlagProvider.refresh() - void modelsService.instance!.onConfigChange(config) upstreamHealthProvider.instance!.onConfigurationChange(config) defaultCodeCompletionsClient.instance!.onConfigurationChange(config) }, diff --git a/vscode/src/models/sync.test.ts b/vscode/src/models/sync.test.ts index 51b226977b76..42ee903666ed 100644 --- a/vscode/src/models/sync.test.ts +++ b/vscode/src/models/sync.test.ts @@ -13,6 +13,7 @@ import { type ServerModelConfiguration, getDotComDefaultModels, graphqlClient, + mockAuthStatus, modelsService, } from '@sourcegraph/cody-shared' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' @@ -29,6 +30,7 @@ describe('syncModels', () => { beforeEach(() => { setModelsSpy.mockClear() + mockAuthStatus(AUTH_STATUS_FIXTURE_AUTHED) // Mock the /.api/client-config for these tests so that modelsAPIEnabled == false vi.spyOn(ClientConfigSingleton.prototype, 'getConfig').mockResolvedValue({ diff --git a/vscode/src/models/sync.ts b/vscode/src/models/sync.ts index b9ec8c6dc4cf..3bbe2188aa9b 100644 --- a/vscode/src/models/sync.ts +++ b/vscode/src/models/sync.ts @@ -13,7 +13,11 @@ import { modelsService, telemetryRecorder, } from '@sourcegraph/cody-shared' -import type { ServerModel, ServerModelConfiguration } from '@sourcegraph/cody-shared/src/models' +import { + ModelsService, + type ServerModel, + type ServerModelConfiguration, +} from '@sourcegraph/cody-shared/src/models' import { ModelTag } from '@sourcegraph/cody-shared/src/models/tags' import * as vscode from 'vscode' import { getConfiguration } from '../configuration' @@ -31,7 +35,6 @@ import { getEnterpriseContextWindow } from './utils' */ export async function syncModels(authStatus: AuthStatus): Promise { // Offline mode only support Ollama models, which would be synced seperately. - modelsService.instance!.setAuthStatus(authStatus) if (authStatus.authenticated && authStatus.isOfflineMode) { modelsService.instance!.setModels([]) return @@ -52,7 +55,7 @@ export async function syncModels(authStatus: AuthStatus): Promise { const serverSideModels = await fetchServerSideModels(authStatus.endpoint || '') // If the request failed, fall back to using the default models if (serverSideModels) { - modelsService.instance!.setServerSentModels({ + await modelsService.instance!.setServerSentModels({ ...serverSideModels, models: maybeAdjustContextWindows(serverSideModels.models), }) @@ -121,6 +124,8 @@ export async function syncModels(authStatus: AuthStatus): Promise { } } +ModelsService.syncModels = syncModels + export async function joinModelWaitlist(authStatus: AuthStatus): Promise { localStorage.set(localStorage.keys.waitlist_o1, true) await syncModels(authStatus) diff --git a/vscode/src/services/AuthProvider.ts b/vscode/src/services/AuthProvider.ts index 160736c32eb1..a171b548b1f8 100644 --- a/vscode/src/services/AuthProvider.ts +++ b/vscode/src/services/AuthProvider.ts @@ -20,7 +20,6 @@ import { formatURL } from '../auth/auth' import { newAuthStatus } from '../chat/utils' import { getFullConfig } from '../configuration' import { logDebug } from '../log' -import { syncModels } from '../models/sync' import { maybeStartInteractiveTutorial } from '../tutorial/helpers' import { localStorage } from './LocalStorageProvider' import { secretStorage } from './SecretStorageProvider' @@ -232,7 +231,6 @@ export class AuthProvider implements vscode.Disposable { // because many listeners rely on these graphqlClient.setConfig(await getFullConfig()) await ClientConfigSingleton.getInstance().setAuthStatus(authStatus) - await syncModels(authStatus) } catch (error) { logDebug('AuthProvider', 'updateAuthStatus error', error) } finally {