Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expose ModelsService.changes and listen to that in ChatController #5578

Merged
merged 1 commit into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions lib/shared/src/models/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { beforeEach, describe, expect, it } from 'vitest'
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
import { mockAuthStatus } from '../auth/authStatus'
import { AUTH_STATUS_FIXTURE_AUTHED, type AuthenticatedAuthStatus } from '../auth/types'
import {
Model,
Expand Down Expand Up @@ -39,6 +40,9 @@ describe('Model Provider', () => {
beforeEach(() => {
modelsService = new ModelsService()
})
afterEach(() => {
modelsService.dispose()
})

describe('getContextWindowByID', () => {
it('returns default token limit for unknown model', () => {
Expand Down Expand Up @@ -135,7 +139,7 @@ describe('Model Provider', () => {
})

beforeEach(() => {
modelsService.setAuthStatus(codyProAuthStatus)
mockAuthStatus(codyProAuthStatus)
modelsService.setModels([model1chat, model2chat, model3all, model4edit])
})

Expand Down Expand Up @@ -254,7 +258,7 @@ describe('Model Provider', () => {
beforeEach(async () => {
storage = new TestStorage()
modelsService.setStorage(storage)
modelsService.setAuthStatus(enterpriseAuthStatus)
mockAuthStatus(enterpriseAuthStatus)
await modelsService.setServerSentModels(SERVER_MODELS)
})

Expand Down Expand Up @@ -348,37 +352,37 @@ describe('Model Provider', () => {
})

it('returns false for unknown model', () => {
modelsService.setAuthStatus(codyProAuthStatus)
mockAuthStatus(codyProAuthStatus)
expect(modelsService.isModelAvailable('unknown-model')).toBe(false)
})

it('allows enterprise user to use any model', () => {
modelsService.setAuthStatus(enterpriseAuthStatus)
mockAuthStatus(enterpriseAuthStatus)
expect(modelsService.isModelAvailable(enterpriseModel)).toBe(true)
expect(modelsService.isModelAvailable(proModel)).toBe(true)
expect(modelsService.isModelAvailable(freeModel)).toBe(true)
})

it('allows Cody Pro user to use Pro and Free models', () => {
modelsService.setAuthStatus(codyProAuthStatus)
mockAuthStatus(codyProAuthStatus)
expect(modelsService.isModelAvailable(enterpriseModel)).toBe(false)
expect(modelsService.isModelAvailable(proModel)).toBe(true)
expect(modelsService.isModelAvailable(freeModel)).toBe(true)
})

it('allows free user to use only Free models', () => {
modelsService.setAuthStatus(freeUserAuthStatus)
mockAuthStatus(freeUserAuthStatus)
expect(modelsService.isModelAvailable(enterpriseModel)).toBe(false)
expect(modelsService.isModelAvailable(proModel)).toBe(false)
expect(modelsService.isModelAvailable(freeModel)).toBe(true)
})

it('handles model passed as string', () => {
modelsService.setAuthStatus(freeUserAuthStatus)
mockAuthStatus(freeUserAuthStatus)
expect(modelsService.isModelAvailable(freeModel.id)).toBe(true)
expect(modelsService.isModelAvailable(proModel.id)).toBe(false)

modelsService.setAuthStatus(codyProAuthStatus)
mockAuthStatus(codyProAuthStatus)
expect(modelsService.isModelAvailable(proModel.id)).toBe(true)
})
})
Expand Down
82 changes: 55 additions & 27 deletions lib/shared/src/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import { type Observable, Subject } from 'observable-fns'
import { authStatus, currentAuthStatus } from '../auth/authStatus'
import { type AuthStatus, isCodyProUser, isEnterpriseUser } from '../auth/types'
import { type ClientConfiguration, CodyIDE } from '../configuration'
import { CodyIDE } from '../configuration'
import { resolvedConfig } from '../configuration/resolver'
import { fetchLocalOllamaModels } from '../llm-providers/ollama/utils'
import { logDebug, logError } from '../logger'
import { type Unsubscribable, combineLatest } from '../misc/observable'
import { setSingleton, singletonNotYetSet } from '../singletons'
import { CHAT_INPUT_TOKEN_BUDGET, CHAT_OUTPUT_TOKEN_BUDGET } from '../token/constants'
import { ModelTag } from './tags'
Expand Down Expand Up @@ -341,14 +344,57 @@ export class ModelsService {
/** persistent storage to save user preferences and server defaults */
private storage: Storage | undefined

/** current system auth status */
private authStatus: AuthStatus | undefined

/** Cache of users preferences and defaults across each endpoint they have used */
private _preferences: PerSitePreferences | undefined

private static STORAGE_KEY = 'model-preferences'

/**
* Needs to be set at static initialization time by the `vscode/` codebase.
*/
public static syncModels: ((authStatus: AuthStatus) => Promise<void>) | undefined

private configSubscription: Unsubscribable

constructor() {
this.configSubscription = combineLatest([resolvedConfig, authStatus]).subscribe(
async ([{ configuration }, authStatus]) => {
try {
if (!ModelsService.syncModels) {
throw new Error(
'ModelsService.syncModels must be set at static initialization time'
)
}
await ModelsService.syncModels(authStatus)
} catch (error) {
logError('ModelsService', 'Failed to sync models', error)
}

try {
const isCodyWeb = configuration.agentIDE === CodyIDE.Web

// Disable Ollama local models for cody web
this.localModels = !isCodyWeb ? await fetchLocalOllamaModels() : []
} catch {
this.localModels = []
} finally {
this.changeNotifications.next()
}
}
)
}

public dispose(): void {
this.configSubscription.unsubscribe()
}

private changeNotifications = new Subject<void>()

/**
* An observable that emits whenever the list of models or any model in the list changes.
*/
public readonly changes: Observable<void> = this.changeNotifications

// Get all the providers currently available to the user
private get models(): Model[] {
return this.primaryModels.concat(this.localModels)
Expand All @@ -361,7 +407,7 @@ export class ModelsService {
defaults: {},
selected: {},
}
const endpoint = this.authStatus?.endpoint
const endpoint = currentAuthStatus().endpoint
if (!endpoint) {
if (!process.env.VITEST) {
logError('ModelsService::preferences', 'No auth status set')
Expand All @@ -385,21 +431,6 @@ export class ModelsService {
return empty
}

public async onConfigChange(config: ClientConfiguration): Promise<void> {
try {
const isCodyWeb = config.agentIDE === CodyIDE.Web

// Disable Ollama local models for cody web
this.localModels = !isCodyWeb ? await fetchLocalOllamaModels() : []
} catch {
this.localModels = []
}
}

public async setAuthStatus(authStatus: AuthStatus) {
this.authStatus = authStatus
}

public setStorage(storage: Storage): void {
this.storage = storage
}
Expand All @@ -410,6 +441,7 @@ export class ModelsService {
public setModels(models: Model[]): void {
logDebug('ModelsService', `Setting primary models: ${JSON.stringify(models.map(m => m.id))}`)
this.primaryModels = models
this.changeNotifications.next()
}

/**
Expand Down Expand Up @@ -440,14 +472,9 @@ export class ModelsService {
await this.flush()
}

private readonly _selectedOrDefaultModelChanges = new Subject<void>()
public get selectedOrDefaultModelChanges(): Observable<void> {
return this._selectedOrDefaultModelChanges
}

private async flush(): Promise<void> {
await this.storage?.set(ModelsService.STORAGE_KEY, JSON.stringify(this._preferences))
this._selectedOrDefaultModelChanges.next()
this.changeNotifications.next()
}

/**
Expand All @@ -463,6 +490,7 @@ export class ModelsService {
...this.primaryModels,
...models.filter(model => !existingIds.has(model.id)),
]
this.changeNotifications.next()
}

private getModelsByType(usage: ModelUsage): Model[] {
Expand Down Expand Up @@ -524,7 +552,7 @@ export class ModelsService {
}

public isModelAvailable(model: string | Model): boolean {
const status = this.authStatus
const status = currentAuthStatus()
if (!status) {
return false
}
Expand Down
44 changes: 30 additions & 14 deletions vscode/src/chat/chat-view/ChatController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,35 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv
)
)
.subscribe({})
),
subscriptionDisposable(
authStatus.subscribe(() => {
// Run this async because this method may be called during initialization
// and awaiting on this.postMessage may result in a deadlock
void this.sendConfig()
})
),
subscriptionDisposable(
combineLatest([
modelsService.instance!.changes.pipe(startWith(undefined)),
authStatus,
]).subscribe(([, authStatus]) => {
// Get the latest model list available to the current user to update the ChatModel.
logError(
'ChatController',
'updated authStatus',
JSON.stringify({
authStatus,
defaultModelID: getDefaultModelID(),
currentModelID: this.chatModel.modelID,
})
)
// TODO!(sqs): here, we need to make sure syncModels has already run after *it*
// reacted to the authStatus change
if (authStatus.authenticated) {
this.chatModel.updateModel(getDefaultModelID())
}
})
)
)

Expand Down Expand Up @@ -590,17 +619,6 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv
// #region top-level view action handlers
// =======================================================================

public setAuthStatus(status: AuthStatus): void {
// Run this async because this method may be called during initialization
// and awaiting on this.postMessage may result in a deadlock
void this.sendConfig()

// Get the latest model list available to the current user to update the ChatModel.
if (status.authenticated) {
this.chatModel.updateModel(getDefaultModelID())
}
}

// When the webview sends the 'ready' message, respond by posting the view config
private async handleReady(): Promise<void> {
await this.sendConfig()
Expand Down Expand Up @@ -1668,9 +1686,7 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv
models: () =>
combineLatest([
resolvedConfig,
modelsService.instance!.selectedOrDefaultModelChanges.pipe(
startWith(undefined)
),
modelsService.instance!.changes.pipe(startWith(undefined)),
]).pipe(map(() => modelsService.instance!.getModels(ModelUsage.Chat))),
highlights: parameters =>
promiseFactoryToObservable(() =>
Expand Down
27 changes: 13 additions & 14 deletions vscode/src/chat/chat-view/ChatsController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import * as uuid from 'uuid'
import * as vscode from 'vscode'

import {
type AuthStatus,
type AuthenticatedAuthStatus,
CODY_PASSTHROUGH_VSCODE_OPEN_COMMAND_ID,
type ChatClient,
Expand Down Expand Up @@ -83,20 +82,20 @@ export class ChatsController implements vscode.Disposable {
this.panel = this.createChatController()

this.disposables.push(
subscriptionDisposable(authStatus.subscribe(authStatus => this.setAuthStatus(authStatus)))
)
}

private async setAuthStatus(authStatus: AuthStatus): Promise<void> {
const hasLoggedOut = !authStatus.authenticated
const hasSwitchedAccount =
this.currentAuthAccount && this.currentAuthAccount.endpoint !== authStatus.endpoint
if (hasLoggedOut || hasSwitchedAccount) {
this.disposeAllChats()
}
subscriptionDisposable(
authStatus.subscribe(authStatus => {
const hasLoggedOut = !authStatus.authenticated
const hasSwitchedAccount =
this.currentAuthAccount &&
this.currentAuthAccount.endpoint !== authStatus.endpoint
if (hasLoggedOut || hasSwitchedAccount) {
this.disposeAllChats()
}

this.currentAuthAccount = authStatus.authenticated ? { ...authStatus } : undefined
this.panel.setAuthStatus(authStatus)
this.currentAuthAccount = authStatus.authenticated ? { ...authStatus } : undefined
})
)
)
}

public async restoreToPanel(panel: vscode.WebviewPanel, chatID: string): Promise<void> {
Expand Down
18 changes: 10 additions & 8 deletions vscode/src/commands/scm/source-control.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import {
type AuthStatus,
type ChatClient,
type ChatMessage,
type CompletionGeneratorValue,
Expand All @@ -13,6 +12,8 @@ import {
getSimplePreamble,
modelsService,
pluralize,
startWith,
subscriptionDisposable,
telemetryRecorder,
} from '@sourcegraph/cody-shared'
import * as vscode from 'vscode'
Expand All @@ -33,7 +34,14 @@ export class CodySourceControl implements vscode.Disposable {
// Register commands
this.disposables.push(
vscode.commands.registerCommand('cody.command.generate-commit', scm => this.generate(scm)),
vscode.commands.registerCommand('cody.command.abort-commit', () => this.statusUpdate())
vscode.commands.registerCommand('cody.command.abort-commit', () => this.statusUpdate()),
subscriptionDisposable(
modelsService.instance!.changes.pipe(startWith(undefined)).subscribe(() => {
const models = modelsService.instance!.getModels(ModelUsage.Chat)
const preferredModel = models.find(p => p.id.includes('claude-3-haiku'))
this.model = preferredModel ?? models[0]
})
)
)
this.initializeGitAPI()
}
Expand Down Expand Up @@ -236,12 +244,6 @@ export class CodySourceControl implements vscode.Disposable {
this.abortController = abortController
}

public setAuthStatus(_: AuthStatus): void {
const models = modelsService.instance!.getModels(ModelUsage.Chat)
const preferredModel = models.find(p => p.id.includes('claude-3-haiku'))
this.model = preferredModel ?? models[0]
}

public dispose(): void {
for (const disposable of this.disposables) {
if (disposable) {
Expand Down
2 changes: 0 additions & 2 deletions vscode/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ const register = async (
subscriptionDisposable(
authStatus.subscribe({
next: authStatus => {
sourceControl.setAuthStatus(authStatus)
statusBar.setAuthStatus(authStatus)
},
})
Expand Down Expand Up @@ -315,7 +314,6 @@ async function initializeSingletons(
void localStorage.setConfig(config)
graphqlClient.setConfig(config)
void featureFlagProvider.refresh()
void modelsService.instance!.onConfigChange(config)
upstreamHealthProvider.instance!.onConfigurationChange(config)
defaultCodeCompletionsClient.instance!.onConfigurationChange(config)
},
Expand Down
Loading
Loading