Skip to content

Commit

Permalink
expose ModelsService.changes and listen to that in ChatController
Browse files Browse the repository at this point in the history
Previously, the ChatController listened for auth changes and assumed that model changes would have been already applied. This required a specific ordering of operations in the AuthProvider, which was a brittle way of ensuring this behavior (and the `chat model selector` e2e test broke when that behavior was slightly changed). Now, the ModelsService exposes a `changes` observable and the ChatController listens to that to follow changes to models, which makes sense.
  • Loading branch information
sqs committed Sep 15, 2024
1 parent 0e9eaa7 commit c093512
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 79 deletions.
22 changes: 13 additions & 9 deletions lib/shared/src/models/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { beforeEach, describe, expect, it } from 'vitest'
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
import { mockAuthStatus } from '../auth/authStatus'
import { AUTH_STATUS_FIXTURE_AUTHED, type AuthenticatedAuthStatus } from '../auth/types'
import {
Model,
Expand Down Expand Up @@ -39,6 +40,9 @@ describe('Model Provider', () => {
beforeEach(() => {
modelsService = new ModelsService()
})
afterEach(() => {
modelsService.dispose()
})

describe('getContextWindowByID', () => {
it('returns default token limit for unknown model', () => {
Expand Down Expand Up @@ -135,7 +139,7 @@ describe('Model Provider', () => {
})

beforeEach(() => {
modelsService.setAuthStatus(codyProAuthStatus)
mockAuthStatus(codyProAuthStatus)
modelsService.setModels([model1chat, model2chat, model3all, model4edit])
})

Expand Down Expand Up @@ -254,7 +258,7 @@ describe('Model Provider', () => {
beforeEach(async () => {
storage = new TestStorage()
modelsService.setStorage(storage)
modelsService.setAuthStatus(enterpriseAuthStatus)
mockAuthStatus(enterpriseAuthStatus)
await modelsService.setServerSentModels(SERVER_MODELS)
})

Expand Down Expand Up @@ -348,37 +352,37 @@ describe('Model Provider', () => {
})

it('returns false for unknown model', () => {
modelsService.setAuthStatus(codyProAuthStatus)
mockAuthStatus(codyProAuthStatus)
expect(modelsService.isModelAvailable('unknown-model')).toBe(false)
})

it('allows enterprise user to use any model', () => {
modelsService.setAuthStatus(enterpriseAuthStatus)
mockAuthStatus(enterpriseAuthStatus)
expect(modelsService.isModelAvailable(enterpriseModel)).toBe(true)
expect(modelsService.isModelAvailable(proModel)).toBe(true)
expect(modelsService.isModelAvailable(freeModel)).toBe(true)
})

it('allows Cody Pro user to use Pro and Free models', () => {
modelsService.setAuthStatus(codyProAuthStatus)
mockAuthStatus(codyProAuthStatus)
expect(modelsService.isModelAvailable(enterpriseModel)).toBe(false)
expect(modelsService.isModelAvailable(proModel)).toBe(true)
expect(modelsService.isModelAvailable(freeModel)).toBe(true)
})

it('allows free user to use only Free models', () => {
modelsService.setAuthStatus(freeUserAuthStatus)
mockAuthStatus(freeUserAuthStatus)
expect(modelsService.isModelAvailable(enterpriseModel)).toBe(false)
expect(modelsService.isModelAvailable(proModel)).toBe(false)
expect(modelsService.isModelAvailable(freeModel)).toBe(true)
})

it('handles model passed as string', () => {
modelsService.setAuthStatus(freeUserAuthStatus)
mockAuthStatus(freeUserAuthStatus)
expect(modelsService.isModelAvailable(freeModel.id)).toBe(true)
expect(modelsService.isModelAvailable(proModel.id)).toBe(false)

modelsService.setAuthStatus(codyProAuthStatus)
mockAuthStatus(codyProAuthStatus)
expect(modelsService.isModelAvailable(proModel.id)).toBe(true)
})
})
Expand Down
82 changes: 55 additions & 27 deletions lib/shared/src/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import { type Observable, Subject } from 'observable-fns'
import { authStatus, currentAuthStatus } from '../auth/authStatus'
import { type AuthStatus, isCodyProUser, isEnterpriseUser } from '../auth/types'
import { type ClientConfiguration, CodyIDE } from '../configuration'
import { CodyIDE } from '../configuration'
import { resolvedConfig } from '../configuration/resolver'
import { fetchLocalOllamaModels } from '../llm-providers/ollama/utils'
import { logDebug, logError } from '../logger'
import { type Unsubscribable, combineLatest } from '../misc/observable'
import { setSingleton, singletonNotYetSet } from '../singletons'
import { CHAT_INPUT_TOKEN_BUDGET, CHAT_OUTPUT_TOKEN_BUDGET } from '../token/constants'
import { ModelTag } from './tags'
Expand Down Expand Up @@ -341,14 +344,57 @@ export class ModelsService {
/** persistent storage to save user preferences and server defaults */
private storage: Storage | undefined

/** current system auth status */
private authStatus: AuthStatus | undefined

/** Cache of users preferences and defaults across each endpoint they have used */
private _preferences: PerSitePreferences | undefined

private static STORAGE_KEY = 'model-preferences'

/**
* Needs to be set at static initialization time by the `vscode/` codebase.
*/
public static syncModels: ((authStatus: AuthStatus) => Promise<void>) | undefined

private configSubscription: Unsubscribable

constructor() {
this.configSubscription = combineLatest([resolvedConfig, authStatus]).subscribe(
async ([{ configuration }, authStatus]) => {
try {
if (!ModelsService.syncModels) {
throw new Error(
'ModelsService.syncModels must be set at static initialization time'
)
}
await ModelsService.syncModels(authStatus)
} catch (error) {
logError('ModelsService', 'Failed to sync models', error)
}

try {
const isCodyWeb = configuration.agentIDE === CodyIDE.Web

// Disable Ollama local models for cody web
this.localModels = !isCodyWeb ? await fetchLocalOllamaModels() : []
} catch {
this.localModels = []
} finally {
this.changeNotifications.next()
}
}
)
}

public dispose(): void {
this.configSubscription.unsubscribe()
}

private changeNotifications = new Subject<void>()

/**
* An observable that emits whenever the list of models or any model in the list changes.
*/
public readonly changes: Observable<void> = this.changeNotifications

// Get all the providers currently available to the user
private get models(): Model[] {
return this.primaryModels.concat(this.localModels)
Expand All @@ -361,7 +407,7 @@ export class ModelsService {
defaults: {},
selected: {},
}
const endpoint = this.authStatus?.endpoint
const endpoint = currentAuthStatus().endpoint
if (!endpoint) {
if (!process.env.VITEST) {
logError('ModelsService::preferences', 'No auth status set')
Expand All @@ -385,21 +431,6 @@ export class ModelsService {
return empty
}

public async onConfigChange(config: ClientConfiguration): Promise<void> {
try {
const isCodyWeb = config.agentIDE === CodyIDE.Web

// Disable Ollama local models for cody web
this.localModels = !isCodyWeb ? await fetchLocalOllamaModels() : []
} catch {
this.localModels = []
}
}

public async setAuthStatus(authStatus: AuthStatus) {
this.authStatus = authStatus
}

public setStorage(storage: Storage): void {
this.storage = storage
}
Expand All @@ -410,6 +441,7 @@ export class ModelsService {
public setModels(models: Model[]): void {
logDebug('ModelsService', `Setting primary models: ${JSON.stringify(models.map(m => m.id))}`)
this.primaryModels = models
this.changeNotifications.next()
}

/**
Expand Down Expand Up @@ -440,14 +472,9 @@ export class ModelsService {
await this.flush()
}

private readonly _selectedOrDefaultModelChanges = new Subject<void>()
public get selectedOrDefaultModelChanges(): Observable<void> {
return this._selectedOrDefaultModelChanges
}

private async flush(): Promise<void> {
await this.storage?.set(ModelsService.STORAGE_KEY, JSON.stringify(this._preferences))
this._selectedOrDefaultModelChanges.next()
this.changeNotifications.next()
}

/**
Expand All @@ -463,6 +490,7 @@ export class ModelsService {
...this.primaryModels,
...models.filter(model => !existingIds.has(model.id)),
]
this.changeNotifications.next()
}

private getModelsByType(usage: ModelUsage): Model[] {
Expand Down Expand Up @@ -524,7 +552,7 @@ export class ModelsService {
}

public isModelAvailable(model: string | Model): boolean {
const status = this.authStatus
const status = currentAuthStatus()
if (!status) {
return false
}
Expand Down
44 changes: 30 additions & 14 deletions vscode/src/chat/chat-view/ChatController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,35 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv
)
)
.subscribe({})
),
subscriptionDisposable(
authStatus.subscribe(() => {
// Run this async because this method may be called during initialization
// and awaiting on this.postMessage may result in a deadlock
void this.sendConfig()
})
),
subscriptionDisposable(
combineLatest([
modelsService.instance!.changes.pipe(startWith(undefined)),
authStatus,
]).subscribe(([, authStatus]) => {
// Get the latest model list available to the current user to update the ChatModel.
logError(
'ChatController',
'updated authStatus',
JSON.stringify({
authStatus,
defaultModelID: getDefaultModelID(),
currentModelID: this.chatModel.modelID,
})
)
// TODO!(sqs): here, we need to make sure syncModels has already run after *it*
// reacted to the authStatus change
if (authStatus.authenticated) {
this.chatModel.updateModel(getDefaultModelID())
}
})
)
)

Expand Down Expand Up @@ -590,17 +619,6 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv
// #region top-level view action handlers
// =======================================================================

public setAuthStatus(status: AuthStatus): void {
// Run this async because this method may be called during initialization
// and awaiting on this.postMessage may result in a deadlock
void this.sendConfig()

// Get the latest model list available to the current user to update the ChatModel.
if (status.authenticated) {
this.chatModel.updateModel(getDefaultModelID())
}
}

// When the webview sends the 'ready' message, respond by posting the view config
private async handleReady(): Promise<void> {
await this.sendConfig()
Expand Down Expand Up @@ -1668,9 +1686,7 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv
models: () =>
combineLatest([
resolvedConfig,
modelsService.instance!.selectedOrDefaultModelChanges.pipe(
startWith(undefined)
),
modelsService.instance!.changes.pipe(startWith(undefined)),
]).pipe(map(() => modelsService.instance!.getModels(ModelUsage.Chat))),
highlights: parameters =>
promiseFactoryToObservable(() =>
Expand Down
27 changes: 13 additions & 14 deletions vscode/src/chat/chat-view/ChatsController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import * as uuid from 'uuid'
import * as vscode from 'vscode'

import {
type AuthStatus,
type AuthenticatedAuthStatus,
CODY_PASSTHROUGH_VSCODE_OPEN_COMMAND_ID,
type ChatClient,
Expand Down Expand Up @@ -83,20 +82,20 @@ export class ChatsController implements vscode.Disposable {
this.panel = this.createChatController()

this.disposables.push(
subscriptionDisposable(authStatus.subscribe(authStatus => this.setAuthStatus(authStatus)))
)
}

private async setAuthStatus(authStatus: AuthStatus): Promise<void> {
const hasLoggedOut = !authStatus.authenticated
const hasSwitchedAccount =
this.currentAuthAccount && this.currentAuthAccount.endpoint !== authStatus.endpoint
if (hasLoggedOut || hasSwitchedAccount) {
this.disposeAllChats()
}
subscriptionDisposable(
authStatus.subscribe(authStatus => {
const hasLoggedOut = !authStatus.authenticated
const hasSwitchedAccount =
this.currentAuthAccount &&
this.currentAuthAccount.endpoint !== authStatus.endpoint
if (hasLoggedOut || hasSwitchedAccount) {
this.disposeAllChats()
}

this.currentAuthAccount = authStatus.authenticated ? { ...authStatus } : undefined
this.panel.setAuthStatus(authStatus)
this.currentAuthAccount = authStatus.authenticated ? { ...authStatus } : undefined
})
)
)
}

public async restoreToPanel(panel: vscode.WebviewPanel, chatID: string): Promise<void> {
Expand Down
18 changes: 10 additions & 8 deletions vscode/src/commands/scm/source-control.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import {
type AuthStatus,
type ChatClient,
type ChatMessage,
type CompletionGeneratorValue,
Expand All @@ -13,6 +12,8 @@ import {
getSimplePreamble,
modelsService,
pluralize,
startWith,
subscriptionDisposable,
telemetryRecorder,
} from '@sourcegraph/cody-shared'
import * as vscode from 'vscode'
Expand All @@ -33,7 +34,14 @@ export class CodySourceControl implements vscode.Disposable {
// Register commands
this.disposables.push(
vscode.commands.registerCommand('cody.command.generate-commit', scm => this.generate(scm)),
vscode.commands.registerCommand('cody.command.abort-commit', () => this.statusUpdate())
vscode.commands.registerCommand('cody.command.abort-commit', () => this.statusUpdate()),
subscriptionDisposable(
modelsService.instance!.changes.pipe(startWith(undefined)).subscribe(() => {
const models = modelsService.instance!.getModels(ModelUsage.Chat)
const preferredModel = models.find(p => p.id.includes('claude-3-haiku'))
this.model = preferredModel ?? models[0]
})
)
)
this.initializeGitAPI()
}
Expand Down Expand Up @@ -236,12 +244,6 @@ export class CodySourceControl implements vscode.Disposable {
this.abortController = abortController
}

public setAuthStatus(_: AuthStatus): void {
const models = modelsService.instance!.getModels(ModelUsage.Chat)
const preferredModel = models.find(p => p.id.includes('claude-3-haiku'))
this.model = preferredModel ?? models[0]
}

public dispose(): void {
for (const disposable of this.disposables) {
if (disposable) {
Expand Down
2 changes: 0 additions & 2 deletions vscode/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ const register = async (
subscriptionDisposable(
authStatus.subscribe({
next: authStatus => {
sourceControl.setAuthStatus(authStatus)
statusBar.setAuthStatus(authStatus)
},
})
Expand Down Expand Up @@ -315,7 +314,6 @@ async function initializeSingletons(
void localStorage.setConfig(config)
graphqlClient.setConfig(config)
void featureFlagProvider.refresh()
void modelsService.instance!.onConfigChange(config)
upstreamHealthProvider.instance!.onConfigurationChange(config)
defaultCodeCompletionsClient.instance!.onConfigurationChange(config)
},
Expand Down
Loading

0 comments on commit c093512

Please sign in to comment.