Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Commit

Permalink
Update client.js
Browse files Browse the repository at this point in the history
  • Loading branch information
mjdaoudi committed Jan 6, 2024
1 parent f7049e5 commit f3e2485
Showing 1 changed file with 46 additions and 39 deletions.
85 changes: 46 additions & 39 deletions src/client.js
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
let isNode = false;

/**
 * Ensure a global `fetch` implementation is available.
 * In browsers and Node >= 18, `globalThis.fetch` already exists and nothing
 * is done; otherwise node-fetch is loaded dynamically, installed on
 * globalThis, and `isNode` is flipped so the rest of the module can detect
 * the polyfilled environment.
 * @return {Promise<void>}
 */
async function initializeFetch() {
  if (typeof globalThis.fetch === "undefined") {
    const nodeFetch = await import("node-fetch");
    // Assign via globalThis: a bare `fetch = ...` would throw a
    // ReferenceError in strict (ES module) code, since `fetch` is never
    // declared in this module.
    globalThis.fetch = nodeFetch.default;
    isNode = true;
  }
}

// Top-level await guarantees fetch is installed before any request runs;
// a fire-and-forget call here would race against the first API request.
await initializeFetch();

// HTTP statuses that indicate a transient failure worth retrying.
const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
// Base URL of the Mistral REST API.
const ENDPOINT = "https://api.mistral.ai";

/**
 * MistralAPIError
 * Error thrown for non-retryable API failures (and when retries are
 * exhausted); carries the HTTP status and response text in its message.
 */
class MistralAPIError extends Error {
  /**
   * @param {string} message - description of the API failure
   */
  constructor(message) {
    super(message);
    // Distinguish API errors from generic Errors in catch blocks
    // (callers check `error.name === "MistralAPIError"`).
    this.name = "MistralAPIError";
  }
}

/**
* MistralClient
Expand All @@ -41,7 +48,7 @@ class MistralClient {
apiKey = process.env.MISTRAL_API_KEY,
endpoint = ENDPOINT,
maxRetries = 5,
timeout = 120,
timeout = 120
) {
this.endpoint = endpoint;
this.apiKey = apiKey;
Expand All @@ -57,16 +64,16 @@ class MistralClient {
* @param {*} request
* @return {Promise<*>}
*/
_request = async function(method, path, request) {
_request = async function (method, path, request) {
const url = `${this.endpoint}/${path}`;
const options = {
method: method,
headers: {
'Accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': `Bearer ${this.apiKey}`,
Accept: "application/json",
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
},
body: method !== 'get' ? JSON.stringify(request) : null,
body: method !== "get" ? JSON.stringify(request) : null,
timeout: this.timeout * 1000,
};

Expand All @@ -86,11 +93,11 @@ class MistralClient {
const decoder = new TextDecoder();
while (true) {
// Read from the stream
const {done, value} = await reader.read();
const { done, value } = await reader.read();
// Exit if we're done
if (done) return;
// Else yield the chunk
yield decoder.decode(value, {stream: true});
yield decoder.decode(value, { stream: true });
}
} finally {
reader.releaseLock();
Expand All @@ -105,31 +112,31 @@ class MistralClient {
console.debug(
`Retrying request on response status: ${response.status}`,
`Response: ${await response.text()}`,
`Attempt: ${attempts + 1}`,
`Attempt: ${attempts + 1}`
);
// eslint-disable-next-line max-len
await new Promise((resolve) =>
setTimeout(resolve, Math.pow(2, (attempts + 1)) * 500),
setTimeout(resolve, Math.pow(2, attempts + 1) * 500)
);
} else {
throw new MistralAPIError(
`HTTP error! status: ${response.status} ` +
`Response: \n${await response.text()}`,
`Response: \n${await response.text()}`
);
}
} catch (error) {
console.error(`Request failed: ${error.message}`);
if (error.name === 'MistralAPIError') {
if (error.name === "MistralAPIError") {
throw error;
}
if (attempts === this.maxRetries - 1) throw error;
// eslint-disable-next-line max-len
await new Promise((resolve) =>
setTimeout(resolve, Math.pow(2, (attempts + 1)) * 500),
setTimeout(resolve, Math.pow(2, attempts + 1) * 500)
);
}
}
throw new Error('Max retries reached');
throw new Error("Max retries reached");
};

/**
Expand All @@ -144,15 +151,15 @@ class MistralClient {
* @param {*} safeMode
* @return {Promise<Object>}
*/
_makeChatCompletionRequest = function(
_makeChatCompletionRequest = function (
model,
messages,
temperature,
maxTokens,
topP,
randomSeed,
stream,
safeMode,
safeMode
) {
return {
model: model,
Expand All @@ -170,8 +177,8 @@ class MistralClient {
* Returns a list of the available models
* @return {Promise<Object>}
*/
listModels = async function () {
  // GET v1/models; _request adds auth headers, timeout, and retry handling.
  const response = await this._request("get", "v1/models");
  return response;
};

Expand All @@ -187,7 +194,7 @@ class MistralClient {
* @param {*} safeMode whether to use safe mode, e.g. true
* @return {Promise<Object>}
*/
chat = async function({
chat = async function ({
model,
messages,
temperature,
Expand All @@ -204,12 +211,12 @@ class MistralClient {
topP,
randomSeed,
false,
safeMode,
safeMode
);
const response = await this._request(
'post',
'v1/chat/completions',
request,
"post",
"v1/chat/completions",
request
);
return response;
};
Expand Down Expand Up @@ -243,25 +250,25 @@ class MistralClient {
topP,
randomSeed,
true,
safeMode,
safeMode
);
const response = await this._request(
'post',
'v1/chat/completions',
request,
"post",
"v1/chat/completions",
request
);

let buffer = '';
let buffer = "";

for await (const chunk of response) {
buffer += chunk;
let firstNewline;
while ((firstNewline = buffer.indexOf('\n')) !== -1) {
while ((firstNewline = buffer.indexOf("\n")) !== -1) {
const chunkLine = buffer.substring(0, firstNewline);
buffer = buffer.substring(firstNewline + 1);
if (chunkLine.startsWith('data:')) {
if (chunkLine.startsWith("data:")) {
const json = chunkLine.substring(6).trim();
if (json !== '[DONE]') {
if (json !== "[DONE]") {
yield JSON.parse(json);
}
}
Expand All @@ -277,12 +284,12 @@ class MistralClient {
* e.g. ['What is the best French cheese?']
* @return {Promise<Object>}
*/
embeddings = async function ({ model, input }) {
  // Build the POST payload; `input` may be a string or an array of strings
  // per the JSDoc above.
  const request = {
    model: model,
    input: input,
  };
  const response = await this._request("post", "v1/embeddings", request);
  return response;
};
}
Expand Down

0 comments on commit f3e2485

Please sign in to comment.