diff --git a/lib/shared/src/sourcegraph-api/completions/browserClient.ts b/lib/shared/src/sourcegraph-api/completions/browserClient.ts
index 07a4b8ff49c0..b4d23828e37e 100644
--- a/lib/shared/src/sourcegraph-api/completions/browserClient.ts
+++ b/lib/shared/src/sourcegraph-api/completions/browserClient.ts
@@ -8,7 +8,7 @@ import { addClientInfoParams } from '../client-name-version'
 import { CompletionsResponseBuilder } from './CompletionsResponseBuilder'
 import { type CompletionRequestParameters, SourcegraphCompletionsClient } from './client'
 import { parseCompletionJSON } from './parse'
-import type { CompletionCallbacks, CompletionParameters, Event } from './types'
+import type { CompletionCallbacks, CompletionParameters, CompletionResponse, Event } from './types'
 import { getSerializedParams } from './utils'
 
 declare const WorkerGlobalScope: never
@@ -115,6 +115,53 @@ export class SourcegraphBrowserCompletionsClient extends SourcegraphCompletionsC
             console.error(error)
         })
     }
+
+    protected async _fetchWithCallbacks(
+        params: CompletionParameters,
+        requestParams: CompletionRequestParameters,
+        cb: CompletionCallbacks,
+        signal?: AbortSignal
+    ): Promise<void> {
+        const { url, serializedParams } = await this.prepareRequest(params, requestParams)
+        const headersInstance = new Headers({
+            'Content-Type': 'application/json; charset=utf-8',
+            ...this.config.customHeaders,
+            ...requestParams.customHeaders,
+        })
+        addCustomUserAgent(headersInstance)
+        if (this.config.accessToken) {
+            headersInstance.set('Authorization', `token ${this.config.accessToken}`)
+        }
+        if (new URLSearchParams(globalThis.location.search).get('trace')) {
+            headersInstance.set('X-Sourcegraph-Should-Trace', 'true')
+        }
+        try {
+            const response = await fetch(url.toString(), {
+                method: 'POST',
+                headers: headersInstance,
+                body: JSON.stringify(serializedParams),
+                signal,
+            })
+            if (!response.ok) {
+                const errorMessage = await response.text()
+                throw new Error(
+                    errorMessage.length === 0
+                        ? `Request failed with status code ${response.status}`
+                        : errorMessage
+                )
+            }
+            const data = (await response.json()) as CompletionResponse
+            if (data?.completion) {
+                cb.onChange(data.completion)
+                cb.onComplete()
+            } else {
+                throw new Error('Unexpected response format')
+            }
+        } catch (error) {
+            console.error(error)
+            cb.onError(error instanceof Error ? error : new Error(`${error}`))
+        }
+    }
 }
 
 if (isRunningInWebWorker) {
diff --git a/lib/shared/src/sourcegraph-api/completions/client.ts b/lib/shared/src/sourcegraph-api/completions/client.ts
index e9a8f393669f..bab946e9a4eb 100644
--- a/lib/shared/src/sourcegraph-api/completions/client.ts
+++ b/lib/shared/src/sourcegraph-api/completions/client.ts
@@ -1,6 +1,6 @@
 import type { Span } from '@opentelemetry/api'
+import { addClientInfoParams, getSerializedParams } from '../..'
 import type { ClientConfigurationWithAccessToken } from '../../configuration'
-
 import { useCustomChatClient } from '../../llm-providers'
 import { recordErrorToSpan } from '../../tracing'
 import type {
@@ -9,6 +9,7 @@
     CompletionParameters,
     CompletionResponse,
     Event,
+    SerializedCompletionParameters,
 } from './types'
 
 export interface CompletionLogger {
@@ -45,6 +46,8 @@ export type CompletionsClientConfig = Pick<
 export abstract class SourcegraphCompletionsClient {
     private errorEncountered = false
 
+    protected readonly isTemperatureZero = process.env.CODY_TEMPERATURE_ZERO === 'true'
+
     constructor(
         protected config: CompletionsClientConfig,
         protected logger?: CompletionLogger
@@ -88,6 +91,27 @@ export abstract class SourcegraphCompletionsClient {
         }
     }
 
+    protected async prepareRequest(
+        params: CompletionParameters,
+        requestParams: CompletionRequestParameters
+    ): Promise<{ url: URL; serializedParams: SerializedCompletionParameters }> {
+        const { apiVersion } = requestParams
+        const serializedParams = await getSerializedParams(params)
+        const url = new URL(this.completionsEndpoint)
+        if (apiVersion >= 1) {
+            url.searchParams.append('api-version', '' + apiVersion)
+        }
+        addClientInfoParams(url.searchParams)
+        return { url, serializedParams }
+    }
+
+    protected abstract _fetchWithCallbacks(
+        params: CompletionParameters,
+        requestParams: CompletionRequestParameters,
+        cb: CompletionCallbacks,
+        signal?: AbortSignal
+    ): Promise<void>
+
     protected abstract _streamWithCallbacks(
         params: CompletionParameters,
         requestParams: CompletionRequestParameters,
@@ -144,7 +168,11 @@
         })
 
         if (!isNonSourcegraphProvider) {
-            await this._streamWithCallbacks(params, requestParams, callbacks, signal)
+            if (params.stream === false) {
+                await this._fetchWithCallbacks(params, requestParams, callbacks, signal)
+            } else {
+                await this._streamWithCallbacks(params, requestParams, callbacks, signal)
+            }
         }
 
         for (let i = 0; ; i++) {
diff --git a/vscode/CHANGELOG.md b/vscode/CHANGELOG.md
index ed0ede5ad0f9..f72de1b40b9e 100644
--- a/vscode/CHANGELOG.md
+++ b/vscode/CHANGELOG.md
@@ -8,6 +8,7 @@ This is a log of all notable changes to Cody for VS Code. [Unreleased] changes a
 
 - The [new OpenAI models (OpenAI O1 & OpenAI O1-mini)](https://sourcegraph.com/blog/openai-o1-for-cody) are now available to selected Cody Pro users for early access. [pull/5508](https://github.com/sourcegraph/cody/pull/5508)
 - Cody Pro users can join the waitlist for the new models by clicking the "Join Waitlist" button. [pull/5508](https://github.com/sourcegraph/cody/pull/5508)
+- Chat: Support non-streaming requests. [pull/5565](https://github.com/sourcegraph/cody/pull/5565)
 
 ### Fixed
 
diff --git a/vscode/src/completions/nodeClient.ts b/vscode/src/completions/nodeClient.ts
index 0cf5556232fc..b8607ec2c5fa 100644
--- a/vscode/src/completions/nodeClient.ts
+++ b/vscode/src/completions/nodeClient.ts
@@ -10,6 +10,7 @@
     type CompletionCallbacks,
     type CompletionParameters,
     type CompletionRequestParameters,
+    type CompletionResponse,
     NetworkError,
     RateLimitError,
     SourcegraphCompletionsClient,
@@ -29,8 +30,6 @@
 } from '@sourcegraph/cody-shared'
 import { CompletionsResponseBuilder } from '@sourcegraph/cody-shared/src/sourcegraph-api/completions/CompletionsResponseBuilder'
 
-const isTemperatureZero = process.env.CODY_TEMPERATURE_ZERO === 'true'
-
 export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClient {
     protected _streamWithCallbacks(
         params: CompletionParameters,
@@ -56,7 +55,7 @@
             model: params.model,
         })
 
-        if (isTemperatureZero) {
+        if (this.isTemperatureZero) {
             params = {
                 ...params,
                 temperature: 0,
@@ -203,17 +202,6 @@
                     bufferText += str
                     bufferBin = buf
 
-                    // HACK: Handles non-stream request.
-                    // TODO: Implement a function to make and process non-stream requests.
-                    if (params.stream === false) {
-                        const json = JSON.parse(bufferText)
-                        if (json?.completion) {
-                            cb.onChange(json.completion)
-                            cb.onComplete()
-                            return
-                        }
-                    }
-
                     const parseResult = parseEvents(builder, bufferText)
                     if (isError(parseResult)) {
                         logError(
@@ -286,6 +274,68 @@
             onAbort(signal, () => request.destroy())
         })
     }
+
+    protected async _fetchWithCallbacks(
+        params: CompletionParameters,
+        requestParams: CompletionRequestParameters,
+        cb: CompletionCallbacks,
+        signal?: AbortSignal
+    ): Promise<void> {
+        const { url, serializedParams } = await this.prepareRequest(params, requestParams)
+        const log = this.logger?.startCompletion(params, url.toString())
+        return tracer.startActiveSpan(`POST ${url.toString()}`, async span => {
+            span.setAttributes({
+                fast: params.fast,
+                maxTokensToSample: params.maxTokensToSample,
+                temperature: this.isTemperatureZero ? 0 : params.temperature,
+                topK: params.topK,
+                topP: params.topP,
+                model: params.model,
+            })
+            try {
+                const response = await fetch(url.toString(), {
+                    method: 'POST',
+                    headers: {
+                        'Content-Type': 'application/json',
+                        'Accept-Encoding': 'gzip;q=0',
+                        ...(this.config.accessToken
+                            ? { Authorization: `token ${this.config.accessToken}` }
+                            : null),
+                        ...(customUserAgent ? { 'User-Agent': customUserAgent } : null),
+                        ...this.config.customHeaders,
+                        ...requestParams.customHeaders,
+                        ...getTraceparentHeaders(),
+                    },
+                    body: JSON.stringify(serializedParams),
+                    signal,
+                })
+                if (!response.ok) {
+                    const errorMessage = await response.text()
+                    throw new NetworkError(
+                        {
+                            url: url.toString(),
+                            status: response.status,
+                            statusText: response.statusText,
+                        },
+                        errorMessage,
+                        getActiveTraceAndSpanId()?.traceId
+                    )
+                }
+                const json = (await response.json()) as CompletionResponse
+                if (typeof json?.completion === 'string') {
+                    cb.onChange(json.completion)
+                    cb.onComplete()
+                    return
+                }
+                throw new Error('Unexpected response format')
+            } catch (error) {
+                const errorObject = error instanceof Error ? error : new Error(`${error}`)
+                log?.onError(errorObject.message, error)
+                recordErrorToSpan(span, errorObject)
+                cb.onError(errorObject)
+            }
+        })
+    }
 }
 
 function getHeader(value: string | undefined | string[]): string | undefined {