-
Notifications
You must be signed in to change notification settings - Fork 430
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #95 from YuJian920/azure-patch
Add support for Azure OpenAI services
- Loading branch information
Showing
5 changed files
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/**
 * Arguments accepted by the Azure OpenAI request helpers in this module.
 */
export interface AzureFetchPayload {
  /** Sent verbatim as the `api-key` request header. */
  apiKey: string
  /** Start of the request URL; each helper appends its own suffix to it. */
  baseUrl: string
  /** Request body; it is JSON-serialised as-is, so values are opaque here. */
  body: Record<string, unknown>
  /** Deployment name interpolated into the chat-completions URL. */
  model?: string
  /** Optional abort signal handed to `fetch`. */
  signal?: AbortSignal
}
|
||
/**
 * Sends a chat-completions request to an Azure OpenAI deployment.
 *
 * @param payload - Endpoint, API key, deployment name (`model`), JSON body and optional abort signal.
 * @returns The raw `fetch` response; checking the status is left to the caller.
 * @throws Error when `payload.model` is missing or blank, because the deployment
 *   name is a mandatory path segment of the request URL.
 */
export const fetchChatCompletion = async(payload: AzureFetchPayload) => {
  const { baseUrl, apiKey, body, model, signal } = payload || {}

  // Fail early instead of requesting ".../deployments/undefined/...".
  const deployment = typeof model === 'string' ? model.trim() : ''
  if (!deployment)
    throw new Error('Azure deployment name is required for chat completions')

  const initOptions = {
    headers: { 'Content-Type': 'application/json', 'api-key': apiKey },
    method: 'POST',
    body: JSON.stringify({ ...body }),
    signal,
  }
  // The deployment name is user input, so encode it before using it as a path segment.
  const url = `${baseUrl}/openai/deployments/${encodeURIComponent(deployment)}/chat/completions?api-version=2023-08-01-preview`
  return fetch(url, initOptions)
}
|
||
/**
 * Submits an image-generation request to Azure OpenAI.
 *
 * `payload.baseUrl` may be either a full endpoint (for example
 * `https://my-resource.openai.azure.com`, the same form the chat helper expects)
 * or a bare resource prefix such as `https://my-resource`; in the latter case
 * `.openai.azure.com` is appended, which is what this helper always did before.
 *
 * @param payload - Endpoint, API key, JSON body and optional abort signal.
 * @returns The raw `fetch` response; checking the status is left to the caller.
 */
export const fetchImageGeneration = async(payload: AzureFetchPayload) => {
  const { baseUrl, apiKey, body, signal } = payload || {}

  // Look at the host alone (scheme and path removed): a dot means the caller
  // already supplied a complete host name, so nothing must be appended.
  const host = String(baseUrl).replace(/^[a-z][a-z\d+.-]*:\/\//i, '').split('/')[0]
  const endpoint = host.includes('.') ? baseUrl : `${baseUrl}.openai.azure.com`

  const initOptions = {
    headers: { 'Content-Type': 'application/json', 'api-key': apiKey },
    method: 'POST',
    body: JSON.stringify(body),
    signal,
  }
  return fetch(`${endpoint}/openai/images/generations:submit?api-version=2023-08-01-preview`, initOptions)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import { fetchChatCompletion, fetchImageGeneration } from './api' | ||
import { parseStream } from './parser' | ||
import type { Message } from '@/types/message' | ||
import type { HandlerPayload, Provider } from '@/types/provider' | ||
|
||
/**
 * Dispatches a prompt to the handler matching `payload.botId`.
 *
 * Both chat bots share the chat-completion handler; the image bot uses the
 * image-generation handler. Any other bot id resolves to `undefined`.
 *
 * @param payload - Handler payload carrying the bot id, settings and messages.
 * @param signal - Optional abort signal, forwarded to chat completions only.
 */
export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => {
  switch (payload.botId) {
    case 'chat_continuous':
    case 'chat_single':
      return handleChatCompletion(payload, signal)
    case 'image_generation':
      return handleImageGeneration(payload)
    default:
      return undefined
  }
}
|
||
/**
 * Runs a one-off, non-streaming chat completion for a single prompt.
 *
 * The caller's global settings are reused, with sampling parameters, the token
 * budget and streaming overridden by fixed values.
 *
 * @param prompt - Text sent as the only user message.
 * @param globalSettings - Provider-level settings (API key, endpoint, deployment, ...).
 * @returns The completion text, or an empty string when the handler does not return a string.
 */
export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings) => {
  const rapidPromptPayload = {
    conversationId: 'temp',
    conversationType: 'chat_single',
    botId: 'temp',
    globalSettings: {
      ...globalSettings,
      temperature: 0.4,
      maxTokens: 2048,
      top_p: 1,
      stream: false,
      // Exactly one message is sent, so do not depend on the caller's history
      // setting being present for that message to be kept.
      messageHistorySize: 1,
    },
    botSettings: {},
    prompt,
    messages: [{ role: 'user', content: prompt }],
  } as HandlerPayload
  const result = await handleChatCompletion(rapidPromptPayload)
  return typeof result === 'string' ? result : ''
}
|
||
/**
 * Builds a chat-completions request from the conversation payload and returns the reply.
 *
 * History is collected newest-first: at most `messageHistorySize` messages are
 * taken, stopping at the first one whose character count would overrun what is
 * left of `maxTokens`. The remaining budget is sent as `max_tokens`.
 * The caller's `payload.messages` array is only read, never modified.
 *
 * @param payload - Handler payload carrying global settings and the message history.
 * @param signal - Optional abort signal forwarded to the request.
 * @returns The result of `parseStream` for `text/event-stream` responses,
 *   otherwise the content of the first choice's message.
 * @throws Error when the response status is not OK, or when a non-streaming
 *   response contains no completion choice.
 */
const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal) => {
  const settings = payload.globalSettings
  // Messages selected for the request, oldest first.
  const messages: Message[] = []

  let maxTokens = settings.maxTokens as number
  const messageHistorySize = settings.messageHistorySize as number

  // Walk the history from the newest entry by index so the caller's array stays intact.
  for (let i = payload.messages.length - 1, taken = 0; i >= 0 && taken < messageHistorySize; i--, taken++) {
    const m = payload.messages[i]
    if (m === undefined)
      break

    // Character length stands in for token count when spending the budget.
    if (maxTokens - m.content.length < 0)
      break

    maxTokens -= m.content.length
    messages.unshift(m)
  }

  const response = await fetchChatCompletion({
    apiKey: settings.apiKey as string,
    baseUrl: (settings.baseUrl as string).trim().replace(/\/$/, ''),
    body: {
      messages,
      max_tokens: maxTokens,
      temperature: settings.temperature as number,
      // The provider declares this setting under the key `top_p`; `topP` is
      // still honoured for callers that supply the camel-cased name.
      top_p: (settings.top_p ?? settings.topP) as number,
      stream: settings.stream as boolean ?? true,
    },
    model: settings.model as string,
    signal,
  })
  if (!response.ok) {
    // Error bodies are not guaranteed to be JSON; fall back to the status text
    // rather than letting a parse failure hide the HTTP error.
    const responseJson = (await response.json().catch(() => null)) ?? {}
    console.log('responseJson', responseJson)
    const errMessage = responseJson.error?.message || response.statusText || 'Unknown error'
    throw new Error(errMessage, { cause: responseJson.error })
  }
  const isStream = response.headers.get('content-type')?.includes('text/event-stream')
  if (isStream)
    return parseStream(response)

  const resJson = await response.json()
  const firstChoice = resJson?.choices?.[0]
  if (!firstChoice?.message)
    throw new Error('Azure OpenAI response did not contain a completion choice')
  return firstChoice.message.content as string
}
|
||
/**
 * Requests one 512x512 image for `payload.prompt` and returns its URL.
 *
 * @param payload - Handler payload carrying the prompt and global settings.
 * @returns The `url` of the first entry in the response's `data` array.
 * @throws Error when the response status is not OK or no image URL is present.
 */
const handleImageGeneration = async(payload: HandlerPayload) => {
  const prompt = payload.prompt
  const response = await fetchImageGeneration({
    apiKey: payload.globalSettings.apiKey as string,
    baseUrl: (payload.globalSettings.baseUrl as string).trim().replace(/\/$/, ''),
    body: { prompt, n: 1, size: '512x512' },
  })
  if (!response.ok) {
    // Error bodies are not guaranteed to be JSON; fall back to the status text
    // rather than letting a parse failure hide the HTTP error.
    const responseJson = (await response.json().catch(() => null)) ?? {}
    const errMessage = responseJson.error?.message || response.statusText || 'Unknown error'
    throw new Error(errMessage)
  }
  const resJson = await response.json()
  // NOTE(review): the request helper posts to a `generations:submit` route, which
  // looks like an asynchronous job submission; confirm that the immediate
  // response really carries `data[0].url` rather than an operation to poll.
  const url = resJson?.data?.[0]?.url
  if (!url)
    throw new Error('Azure OpenAI response did not contain an image URL')
  return url
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import { | ||
handlePrompt, | ||
handleRapidPrompt, | ||
} from './handler' | ||
import type { Provider } from '@/types/provider' | ||
|
||
/**
 * Creates the provider definition for Azure OpenAI: its identity, the
 * user-editable global settings, the available bots and the prompt handlers.
 */
const providerOpenAI = () => {
  // Settings shown for the provider as a whole.
  const globalSettings: Provider['globalSettings'] = [
    {
      key: 'apiKey',
      name: 'API Key',
      type: 'api-key',
    },
    {
      key: 'baseUrl',
      name: 'Endpoint',
      description: 'OpenAI Endpoint',
      type: 'input',
    },
    {
      key: 'model',
      name: 'Azure deployment name',
      description: 'Custom model name for Azure OpenAI.',
      type: 'input',
    },
    {
      key: 'maxTokens',
      name: 'Max Tokens',
      description: 'The maximum number of tokens to generate in the completion.',
      type: 'slider',
      min: 0,
      max: 32768,
      default: 2048,
      step: 1,
    },
    {
      key: 'messageHistorySize',
      name: 'Max History Message Size',
      description: 'The number of retained historical messages will be truncated if the length of the message exceeds the MaxToken parameter.',
      type: 'slider',
      min: 1,
      max: 24,
      default: 5,
      step: 1,
    },
    {
      key: 'temperature',
      name: 'Temperature',
      type: 'slider',
      description: 'What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.',
      min: 0,
      max: 2,
      default: 0.7,
      step: 0.01,
    },
    {
      key: 'top_p',
      name: 'Top P',
      description: 'An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.',
      type: 'slider',
      min: 0,
      max: 1,
      default: 1,
      step: 0.01,
    },
  ]

  // Bots offered by this provider; none of them has bot-level settings.
  const bots: Provider['bots'] = [
    {
      id: 'chat_continuous',
      type: 'chat_continuous',
      name: 'Continuous Chat',
      settings: [],
    },
    {
      id: 'chat_single',
      type: 'chat_single',
      name: 'Single Chat',
      settings: [],
    },
    {
      id: 'image_generation',
      type: 'image_generation',
      name: 'DALL·E',
      settings: [],
    },
  ]

  const provider: Provider = {
    id: 'provider-azure',
    icon: 'i-simple-icons:microsoftazure', // @unocss-include
    name: 'Azure OpenAI',
    globalSettings,
    bots,
    handlePrompt,
    handleRapidPrompt,
  }
  return provider
}
|
||
export default providerOpenAI |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import { createParser } from 'eventsource-parser' | ||
import type { ParsedEvent, ReconnectInterval } from 'eventsource-parser' | ||
|
||
/**
 * Converts a server-sent-events chat-completion response into a stream of
 * UTF-8 encoded text chunks.
 *
 * Each event's `data` is parsed as JSON and `choices[0].delta.content` (or an
 * empty string when absent) is enqueued. A `[DONE]` event or the end of the
 * upstream body closes the output; a JSON parse failure errors it.
 *
 * @param rawResponse - Response whose body is an event stream.
 * @returns A `ReadableStream` emitting the encoded text deltas.
 */
export const parseStream = (rawResponse: Response) => {
  const encoder = new TextEncoder()
  const decoder = new TextDecoder()
  const rb = rawResponse.body as ReadableStream

  // Shared between `start` and `cancel`.
  let reader: ReadableStreamDefaultReader | undefined
  // Set once the output has been closed, errored or cancelled; the controller
  // must not be touched afterwards.
  let finished = false

  return new ReadableStream({
    async start(controller) {
      const closeOnce = () => {
        if (finished)
          return
        finished = true
        controller.close()
      }
      const failOnce = (reason: unknown) => {
        if (finished)
          return
        finished = true
        controller.error(reason)
      }

      const streamParser = (event: ParsedEvent | ReconnectInterval) => {
        // Ignore anything that arrives after the output is finished.
        if (finished || event.type !== 'event')
          return
        const data = event.data
        if (data === '[DONE]') {
          closeOnce()
          return
        }
        try {
          const json = JSON.parse(data)
          const text = (json.choices?.[0]?.delta?.content) || ''
          controller.enqueue(encoder.encode(text))
        } catch (e) {
          failOnce(e)
        }
      }

      const upstream = rb.getReader()
      reader = upstream
      const parser = createParser(streamParser)
      // Keep draining until the upstream body ends, even after `[DONE]`.
      while (true) {
        const { done, value } = await upstream.read()
        if (done) {
          // Flush bytes the decoder is still holding from an incomplete sequence.
          const tail = decoder.decode()
          if (tail)
            parser.feed(tail)
          closeOnce()
          return
        }
        // `stream: true` keeps multi-byte characters intact when they are split
        // across two network chunks.
        parser.feed(decoder.decode(value, { stream: true }))
      }
    },
    cancel(reason) {
      // The consumer went away: stop touching the controller and release upstream.
      finished = true
      return reader?.cancel(reason)
    },
  })
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters