diff --git a/examples/json_format.js b/examples/json_format.js index 7803c56..0e70e4b 100644 --- a/examples/json_format.js +++ b/examples/json_format.js @@ -5,7 +5,7 @@ const apiKey = process.env.MISTRAL_API_KEY; const client = new MistralClient(apiKey); const chatResponse = await client.chat({ - model: 'mistral-large', + model: 'mistral-large-latest', messages: [{role: 'user', content: 'What is the best French cheese?'}], responseFormat: {type: 'json_object'}, }); diff --git a/examples/package-lock.json b/examples/package-lock.json index be7d412..7a2df17 100644 --- a/examples/package-lock.json +++ b/examples/package-lock.json @@ -13,7 +13,7 @@ }, "..": { "name": "@mistralai/mistralai", - "version": "0.0.1", + "version": "0.4.0", "license": "ISC", "dependencies": { "node-fetch": "^2.6.7" diff --git a/package.json b/package.json index a4038a6..b1583ce 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@mistralai/mistralai", - "version": "0.3.0", + "version": "0.4.0", "description": "", "author": "bam4d@mistral.ai", "license": "ISC", diff --git a/src/client.d.ts b/src/client.d.ts index fe820ef..411a0b3 100644 --- a/src/client.d.ts +++ b/src/client.d.ts @@ -141,6 +141,17 @@ declare module "@mistralai/mistralai" { responseFormat?: ResponseFormat; } + export interface CompletionRequest { + model: string; + prompt: string; + suffix?: string; + temperature?: number; + maxTokens?: number; + topP?: number; + randomSeed?: number; + stop?: string | string[]; + } + export interface ChatRequestOptions { signal?: AbortSignal; } @@ -170,6 +181,17 @@ declare module "@mistralai/mistralai" { options?: ChatRequestOptions ): AsyncGenerator; + completion( + request: CompletionRequest, + options?: ChatRequestOptions + ): Promise; + + completionStream( + request: CompletionRequest, + options?: ChatRequestOptions + ): AsyncGenerator; + + embeddings(options: { model: string; input: string | string[]; diff --git a/src/client.js b/src/client.js index 836295d..00cf199 100644 --- a/src/client.js +++ b/src/client.js @@ -161,7 +161,7 @@ class MistralClient { } else { throw new MistralAPIError( `HTTP error! status: ${response.status} ` + - `Response: \n${await response.text()}`, + `Response: \n${await response.text()}`, ); } } catch (error) { @@ -228,6 +228,47 @@ class MistralClient { }; }; + /** + * Creates a completion request + * @param {*} model + * @param {*} prompt + * @param {*} suffix + * @param {*} temperature + * @param {*} maxTokens + * @param {*} topP + * @param {*} randomSeed + * @param {*} stop + * @param {*} stream + * @return {Promise} + */ + _makeCompletionRequest = function( + model, + prompt, + suffix, + temperature, + maxTokens, + topP, + randomSeed, + stop, + stream, + ) { + // if modelDefault and model are undefined, throw an error + if (!model && !this.modelDefault) { + throw new MistralAPIError('You must provide a model name'); + } + return { + model: model ?? this.modelDefault, + prompt: prompt, + suffix: suffix ?? undefined, + temperature: temperature ?? undefined, + max_tokens: maxTokens ?? undefined, + top_p: topP ?? undefined, + random_seed: randomSeed ?? undefined, + stop: stop ?? undefined, + stream: stream ?? undefined, + }; + }; + /** * Returns a list of the available models * @return {Promise} @@ -401,6 +442,134 @@ class MistralClient { const response = await this._request('post', 'v1/embeddings', request); return response; }; + + /** + * A completion endpoint without streaming. + * + * @param {Object} data - The main completion configuration. + * @param {*} data.model - the name of the model to chat with, + * e.g. mistral-tiny + * @param {*} data.prompt - the prompt to complete, + * e.g. 'def fibonacci(n: int):' + * @param {*} data.temperature - the temperature to use for sampling, e.g. 0.5 + * @param {*} data.maxTokens - the maximum number of tokens to generate, + * e.g. 100 + * @param {*} data.topP - the cumulative probability of tokens to generate, + * e.g. 0.9 + * @param {*} data.randomSeed - the random seed to use for sampling, e.g. 42 + * @param {*} data.stop - the stop sequence to use, e.g. ['\n'] + * @param {*} data.suffix - the suffix to append to the prompt, + * e.g. 'n = int(input(\'Enter a number: \'))' + * @param {Object} options - Additional operational options. + * @param {*} [options.signal] - optional AbortSignal instance to control + * request The signal will be combined with + * default timeout signal + * @return {Promise} + */ + completion = async function( + { + model, + prompt, + suffix, + temperature, + maxTokens, + topP, + randomSeed, + stop, + }, + {signal} = {}, + ) { + const request = this._makeCompletionRequest( + model, + prompt, + suffix, + temperature, + maxTokens, + topP, + randomSeed, + stop, + false, + ); + const response = await this._request( + 'post', + 'v1/fim/completions', + request, + signal, + ); + return response; + }; + + /** + * A completion endpoint that streams responses. + * + * @param {Object} data - The main completion configuration. + * @param {*} data.model - the name of the model to chat with, + * e.g. mistral-tiny + * @param {*} data.prompt - the prompt to complete, + * e.g. 'def fibonacci(n: int):' + * @param {*} data.temperature - the temperature to use for sampling, e.g. 0.5 + * @param {*} data.maxTokens - the maximum number of tokens to generate, + * e.g. 100 + * @param {*} data.topP - the cumulative probability of tokens to generate, + * e.g. 0.9 + * @param {*} data.randomSeed - the random seed to use for sampling, e.g. 42 + * @param {*} data.stop - the stop sequence to use, e.g. ['\n'] + * @param {*} data.suffix - the suffix to append to the prompt, + * e.g. 'n = int(input(\'Enter a number: \'))' + * @param {Object} options - Additional operational options. + * @param {*} [options.signal] - optional AbortSignal instance to control + * request The signal will be combined with + * default timeout signal + * @return {Promise} + */ + completionStream = async function* ( + { + model, + prompt, + suffix, + temperature, + maxTokens, + topP, + randomSeed, + stop, + }, + {signal} = {}, + ) { + const request = this._makeCompletionRequest( + model, + prompt, + suffix, + temperature, + maxTokens, + topP, + randomSeed, + stop, + true, + ); + const response = await this._request( + 'post', + 'v1/fim/completions', + request, + signal, + ); + + let buffer = ''; + const decoder = new TextDecoder(); + for await (const chunk of response) { + buffer += decoder.decode(chunk, {stream: true}); + let firstNewline; + while ((firstNewline = buffer.indexOf('\n')) !== -1) { + const chunkLine = buffer.substring(0, firstNewline); + buffer = buffer.substring(firstNewline + 1); + if (chunkLine.startsWith('data:')) { + const json = chunkLine.substring(6).trim(); + if (json !== '[DONE]') { + yield JSON.parse(json); + } + } + } + } + }; } export default MistralClient; diff --git a/tests/client.test.js b/tests/client.test.js index 54b52b6..7113a86 100644 --- a/tests/client.test.js +++ b/tests/client.test.js @@ -23,7 +23,7 @@ describe('Mistral Client', () => { client._fetch = mockFetch(200, mockResponse); const response = await client.chat({ - model: 'mistral-small', + model: 'mistral-small-latest', messages: [ { role: 'user', @@ -40,7 +40,7 @@ describe('Mistral Client', () => { client._fetch = mockFetch(200, mockResponse); const response = await client.chat({ - model: 'mistral-small', + model: 'mistral-small-latest', messages: [ { role: 'user', @@ -58,7 +58,7 @@ describe('Mistral Client', () => { client._fetch = mockFetch(200, mockResponse); const response = await client.chat({ - model: 'mistral-small', + model: 'mistral-small-latest', messages: [ { role: 'user', @@ -78,7 +78,7 @@ describe('Mistral Client', () => { client._fetch = mockFetchStream(200, mockResponse); const response = await client.chatStream({ - model: 'mistral-small', + model: 'mistral-small-latest', messages: [ { role: 'user', @@ -101,7 +101,7 @@ describe('Mistral Client', () => { client._fetch = mockFetchStream(200, mockResponse); const response = await client.chatStream({ - model: 'mistral-small', + model: 'mistral-small-latest', messages: [ { role: 'user', @@ -125,7 +125,7 @@ describe('Mistral Client', () => { client._fetch = mockFetchStream(200, mockResponse); const response = await client.chatStream({ - model: 'mistral-small', + model: 'mistral-small-latest', messages: [ { role: 'user', @@ -176,4 +176,18 @@ describe('Mistral Client', () => { expect(response).toEqual(mockResponse); }); }); + + describe('completion()', () => { + it('should return a chat response object', async() => { + // Mock the fetch function + const mockResponse = mockChatResponsePayload(); + client._fetch = mockFetch(200, mockResponse); + + const response = await client.completion({ + model: 'mistral-small-latest', + prompt: '# this is a', + }); + expect(response).toEqual(mockResponse); + }); + }); }); diff --git a/tests/utils.js b/tests/utils.js index 3370228..b49f1d0 100644 --- a/tests/utils.js +++ b/tests/utils.js @@ -78,7 +78,7 @@ export function mockListModels() { ], }, { - id: 'mistral-small', + id: 'mistral-small-latest', object: 'model', created: 1703186988, owned_by: 'mistralai', @@ -172,7 +172,7 @@ export function mockChatResponsePayload() { index: 0, }, ], - model: 'mistral-small', + model: 'mistral-small-latest', usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0}, }; } @@ -187,7 +187,7 @@ export function mockChatResponseStreamingPayload() { [encoder.encode('data: ' + JSON.stringify({ id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e', - model: 'mistral-small', + model: 'mistral-small-latest', choices: [ { index: 0, @@ -207,7 +207,7 @@ export function mockChatResponseStreamingPayload() { id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e', object: 'chat.completion.chunk', created: 1703168544, - model: 'mistral-small', + model: 'mistral-small-latest', choices: [ { index: i,