-
Notifications
You must be signed in to change notification settings - Fork 297
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Autocomplete: Add fireworks provider (#441)
Add a new autocomplete provider aimed at the Fireworks API. Specifically, this runs the StarCoder model config with a fill-in-the-middle prompt. We're still working on various latency improvements for the backend, but this already works well enough to dogfood. If someone wants to test it, I'll be posting API keys in Slack (sorry to the open source community, but we'll be getting there 🤗). Let's get this out as a dogfooding build soon. ## Test plan Configure the new provider like this: ``` "cody.autocomplete.advanced.provider": "unstable-fireworks", "cody.autocomplete.advanced.accessToken": "TOKEN", "cody.autocomplete.advanced.serverEndpoint": "https://api.fireworks.ai/inference/v1/completions", ``` <img width="1180" alt="Screenshot 2023-07-28 at 11 58 13" src="https://github.com/sourcegraph/cody/assets/458591/8a57b7a3-db26-45ca-8c9b-65d359151e8a"> <!-- Required. See https://docs.sourcegraph.com/dev/background-information/testing_principles. -->
- Loading branch information
1 parent
98d2cc9
commit 47c309e
Showing
6 changed files
with
420 additions
and
10 deletions.
There are no files selected for viewing
243 changes: 243 additions & 0 deletions
243
completions-review-tool/data/starcoder-fireworks-1690539528108.json
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import fetch from 'isomorphic-fetch' | ||
|
||
import { Completion } from '..' | ||
import { logger } from '../../log' | ||
import { ReferenceSnippet } from '../context' | ||
import { getLanguageConfig } from '../language' | ||
import { isAbortError } from '../utils' | ||
|
||
import { Provider, ProviderConfig, ProviderOptions } from './provider' | ||
|
||
/**
 * Connection settings handed to the Fireworks autocomplete provider.
 */
interface UnstableFireworksOptions {
    /** URL that completion requests are POSTed to. */
    serverEndpoint: string
    /** Token placed in the `Authorization: Bearer …` header; may be `null`. */
    accessToken: string | null
}
|
||
/** Identifier reported to the completion logger and exposed on the provider config. */
const PROVIDER_IDENTIFIER = 'fireworks'

/** End-of-text marker that is stripped from returned completion text. */
const STOP_WORD = '<|endoftext|>'

/** Character budget for the context window (~ 1280 token limit). */
const CONTEXT_WINDOW_CHARS = 3500
|
||
/**
 * Autocomplete provider that sends a StarCoder fill-in-the-middle prompt to the
 * configured completions endpoint and converts the returned choices into
 * `Completion` objects.
 */
export class UnstableFireworksProvider extends Provider {
    private readonly serverEndpoint: string
    private readonly accessToken: null | string

    /**
     * @param options Generic provider options (prefix, suffix, language, etc.).
     * @param unstableFireworksOptions Endpoint URL and access token for the backend.
     */
    constructor(options: ProviderOptions, unstableFireworksOptions: UnstableFireworksOptions) {
        super(options)
        this.serverEndpoint = unstableFireworksOptions.serverEndpoint
        this.accessToken = unstableFireworksOptions.accessToken
    }

    /**
     * Builds the fill-in-the-middle prompt, greedily adding reference snippets
     * (as commented-out intro lines) for as long as the prompt stays below the
     * character budget reserved for the prompt.
     *
     * @param snippets Candidate reference snippets, tried in the given order.
     * @returns The longest prompt that fits the budget.
     */
    private createPrompt(snippets: ReferenceSnippet[]): string {
        // Part of the context window is reserved for the response.
        const maxPromptChars = CONTEXT_WINDOW_CHARS - CONTEXT_WINDOW_CHARS * this.options.responsePercentage

        const intro: string[] = []
        let prompt = ''

        const languageConfig = getLanguageConfig(this.options.languageId)
        if (languageConfig) {
            intro.push(`Path: ${this.options.fileName}`)
        }

        // Iteration 0 builds the prompt without snippets; iteration i adds the i-th snippet.
        for (let snippetsToInclude = 0; snippetsToInclude < snippets.length + 1; snippetsToInclude++) {
            if (snippetsToInclude > 0) {
                const snippet = snippets[snippetsToInclude - 1]
                intro.push(`Here is a reference snippet of code from ${snippet.fileName}:\n\n${snippet.content}`)
            }

            // Every intro line is turned into a comment of the current language.
            // NOTE(review): without a language config each intro line becomes an empty
            // line, so snippets only add blank lines to the prompt - confirm this is intended.
            const introString =
                intro
                    .join('\n\n')
                    .split('\n')
                    .map(line => (languageConfig ? languageConfig.commentStart + line : ''))
                    .join('\n') + '\n'

            // Prompt format is taken from https://starcoder.co/bigcode/starcoder#fill-in-the-middle
            const nextPrompt = `<fim_prefix>${introString}${this.options.prefix}<fim_suffix>${this.options.suffix}<fim_middle>`

            // Keep the last prompt that still fit.
            // NOTE(review): if even the snippet-free prompt is over budget this returns
            // an empty string - verify that callers bound prefix/suffix accordingly.
            if (nextPrompt.length >= maxPromptChars) {
                return prompt
            }

            prompt = nextPrompt
        }

        return prompt
    }

    /**
     * Requests completions for the current prefix/suffix from the endpoint.
     *
     * @param abortSignal Signal used to cancel the in-flight HTTP request.
     * @param snippets Reference snippets that may be embedded into the prompt.
     * @returns One completion per choice returned by the backend.
     * @throws Rethrows any network, parsing or backend-reported error. All errors
     * except aborts are reported to the completion logger first.
     */
    public async generateCompletions(abortSignal: AbortSignal, snippets: ReferenceSnippet[]): Promise<Completion[]> {
        const prompt = this.createPrompt(snippets)

        const request = {
            prompt,
            // To speed up sample generation in single-line case, we request a lower token limit
            // since we can't terminate on the first `\n`.
            max_tokens: this.options.multiline ? 256 : 30,
            temperature: 0.4,
            top_p: 0.95,
            min_tokens: 1,
            n: this.options.n,
            echo: false,
            model: 'fireworks-starcoder-16b-w8a16',
        }

        const log = logger.startCompletion({
            request,
            provider: PROVIDER_IDENTIFIER,
            serverEndpoint: this.serverEndpoint,
        })

        // Only send credentials when a token is configured, so a missing token
        // is not transmitted as the literal string "Bearer null".
        const headers: Record<string, string> = { 'Content-Type': 'application/json' }
        if (this.accessToken) {
            headers.Authorization = `Bearer ${this.accessToken}`
        }

        try {
            // The request itself is inside the try block so that network failures
            // are reported to the logger like any other error.
            const response = await fetch(this.serverEndpoint, {
                method: 'POST',
                body: JSON.stringify(request),
                headers,
                signal: abortSignal,
            })

            const data = (await response.json()) as
                | { choices: { text: string; finish_reason: string }[] }
                | { error: string }

            if ('error' in data) {
                throw new Error(data.error)
            }

            const completions = data.choices.map(c => ({
                content: postProcess(c.text, this.options.multiline),
                stopReason: c.finish_reason,
            }))
            log?.onComplete(completions.map(c => c.content))

            return completions.map(c => ({
                prefix: this.options.prefix,
                content: c.content,
                stopReason: c.stopReason,
            }))
        } catch (error: any) {
            // Aborted requests are expected (e.g. cancellation) and not logged as failures.
            if (!isAbortError(error)) {
                log?.onError(error)
            }

            throw error
        }
    }
}
|
||
/**
 * Cleans up raw completion text returned by the model.
 *
 * Removes the first occurrence of the stop word, cuts single-line completions
 * at the first line break, and trims surrounding whitespace.
 *
 * @param content Raw text of a single completion choice.
 * @param multiline Whether a multi-line completion was requested.
 * @returns The cleaned-up completion text.
 */
function postProcess(content: string, multiline: boolean): string {
    const withoutStopWord = content.replace(STOP_WORD, '')

    if (multiline) {
        return withoutStopWord.trim()
    }

    // Only a token limit can be specified, so a single-line request may still
    // come back with several lines; keep just the first one.
    const lineBreakIndex = withoutStopWord.indexOf('\n')
    const firstLine = lineBreakIndex === -1 ? withoutStopWord : withoutStopWord.slice(0, lineBreakIndex)

    return firstLine.trim()
}
|
||
/**
 * Creates the provider configuration for the Fireworks autocomplete provider.
 *
 * @param unstableFireworksOptions Endpoint and access token passed to every
 * provider instance created through the returned config.
 * @returns A config whose `create` factory instantiates `UnstableFireworksProvider`.
 */
export function createProviderConfig(unstableFireworksOptions: UnstableFireworksOptions): ProviderConfig {
    const providerConfig: ProviderConfig = {
        create: (options: ProviderOptions) => new UnstableFireworksProvider(options, unstableFireworksOptions),
        maximumContextCharacters: CONTEXT_WINDOW_CHARS,
        enableExtendedMultilineTriggers: true,
        identifier: PROVIDER_IDENTIFIER,
    }

    return providerConfig
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters