Merge pull request #95 from YuJian920/azure-patch
Add support for Azure OpenAI services
ddiu8081 authored Sep 6, 2023
2 parents 79a1c72 + 2ba34f8 commit 4094602
Showing 5 changed files with 269 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/providers/azure/api.ts
@@ -0,0 +1,28 @@
// Shared payload for both Azure OpenAI helpers below.
export interface AzureFetchPayload {
  apiKey: string
  baseUrl: string
  body: Record<string, any>
  model?: string
  signal?: AbortSignal
}

// POST a chat completion request to an Azure OpenAI deployment.
// `baseUrl` is the full resource endpoint (e.g. https://<resource>.openai.azure.com)
// and `model` is the Azure deployment name, not an OpenAI model id.
export const fetchChatCompletion = async(payload: AzureFetchPayload) => {
  const { baseUrl, apiKey, body, model, signal } = payload || {}
  const initOptions = {
    headers: { 'Content-Type': 'application/json', 'api-key': apiKey },
    method: 'POST',
    body: JSON.stringify({ ...body }),
    signal,
  }
  return fetch(`${baseUrl}/openai/deployments/${model}/chat/completions?api-version=2023-08-01-preview`, initOptions)
}

// POST an image generation request against the same resource endpoint.
export const fetchImageGeneration = async(payload: AzureFetchPayload) => {
  const { baseUrl, apiKey, body } = payload || {}
  const initOptions = {
    headers: { 'Content-Type': 'application/json', 'api-key': apiKey },
    method: 'POST',
    body: JSON.stringify(body),
  }
  return fetch(`${baseUrl}/openai/images/generations:submit?api-version=2023-08-01-preview`, initOptions)
}
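
For orientation, a minimal usage sketch of the chat helper above (hypothetical, not part of the commit; the API key, endpoint, and deployment name are placeholders, and top-level await assumes an ES module context):

// Hypothetical example, not part of this commit.
import { fetchChatCompletion } from './api'

const response = await fetchChatCompletion({
  apiKey: 'YOUR_AZURE_API_KEY', // placeholder
  baseUrl: 'https://my-resource.openai.azure.com', // placeholder resource endpoint
  model: 'my-gpt-35-deployment', // placeholder deployment name
  body: {
    messages: [{ role: 'user', content: 'Hello!' }],
    max_tokens: 256,
    stream: false,
  },
})
const json = await response.json()
console.log(json.choices[0].message.content)
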
100 changes: 100 additions & 0 deletions src/providers/azure/handler.ts
@@ -0,0 +1,100 @@
import { fetchChatCompletion, fetchImageGeneration } from './api'
import { parseStream } from './parser'
import type { Message } from '@/types/message'
import type { HandlerPayload, Provider } from '@/types/provider'

// Route the prompt to the matching handler based on the requesting bot
export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => {
  if (payload.botId === 'chat_continuous' || payload.botId === 'chat_single')
    return handleChatCompletion(payload, signal)
  if (payload.botId === 'image_generation')
    return handleImageGeneration(payload)
}

// Handle one-off prompts with a fixed, conservative sampling configuration
// and no streaming.
export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings) => {
  const rapidPromptPayload = {
    conversationId: 'temp',
    conversationType: 'chat_single',
    botId: 'temp',
    globalSettings: {
      ...globalSettings,
      temperature: 0.4,
      maxTokens: 2048,
      topP: 1,
      stream: false,
    },
    botSettings: {},
    prompt,
    messages: [{ role: 'user', content: prompt }],
  } as HandlerPayload
  const result = await handleChatCompletion(rapidPromptPayload)
  if (typeof result === 'string')
    return result
  return ''
}

const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal) => {
  // An array to store the chat messages sent to the API
  const messages: Message[] = []

  let maxTokens = payload.globalSettings.maxTokens as number
  let messageHistorySize = payload.globalSettings.messageHistorySize as number

  // Walk the history backwards (newest message first) until either the
  // history-size limit or the token budget runs out. Content length in
  // characters is used as a crude proxy for token count.
  while (messageHistorySize > 0) {
    messageHistorySize--
    // Take the most recent remaining message from the payload
    const m = payload.messages.pop()
    if (m === undefined)
      break

    if (maxTokens - m.content.length < 0)
      break

    maxTokens -= m.content.length
    messages.unshift(m)
  }

  // Whatever remains of the token budget caps the completion length
  const response = await fetchChatCompletion({
    apiKey: payload.globalSettings.apiKey as string,
    baseUrl: (payload.globalSettings.baseUrl as string).trim().replace(/\/$/, ''),
    body: {
      messages,
      max_tokens: maxTokens,
      temperature: payload.globalSettings.temperature as number,
      top_p: payload.globalSettings.topP as number,
      stream: payload.globalSettings.stream as boolean ?? true,
    },
    model: payload.globalSettings.model as string,
    signal,
  })
  if (!response.ok) {
    const responseJson = await response.json()
    console.log('responseJson', responseJson)
    const errMessage = responseJson.error?.message || response.statusText || 'Unknown error'
    throw new Error(errMessage, { cause: responseJson.error })
  }
  // Streaming responses are handed to the SSE parser; otherwise the full
  // completion is extracted from the JSON body.
  const isStream = response.headers.get('content-type')?.includes('text/event-stream')
  if (isStream) {
    return parseStream(response)
  } else {
    const resJson = await response.json()
    return resJson.choices[0].message.content as string
  }
}

const handleImageGeneration = async(payload: HandlerPayload) => {
  const prompt = payload.prompt
  const response = await fetchImageGeneration({
    apiKey: payload.globalSettings.apiKey as string,
    baseUrl: (payload.globalSettings.baseUrl as string).trim().replace(/\/$/, ''),
    body: { prompt, n: 1, size: '512x512' },
  })
  if (!response.ok) {
    const responseJson = await response.json()
    const errMessage = responseJson.error?.message || response.statusText || 'Unknown error'
    throw new Error(errMessage)
  }
  // Return the URL of the first generated image
  const resJson = await response.json()
  return resJson.data[0].url
}
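
Callers of handlePrompt receive either a plain string (non-streaming) or a ReadableStream of UTF-8 text chunks. A rough consumption sketch (hypothetical, not part of the commit; `payload` is assumed to be an already-built HandlerPayload):

// Hypothetical example, not part of this commit.
const result = await handlePrompt(payload, new AbortController().signal)
if (typeof result === 'string') {
  console.log(result)
} else if (result instanceof ReadableStream) {
  const reader = result.getReader()
  const decoder = new TextDecoder()
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    console.log(decoder.decode(value)) // one text delta per chunk
  }
}
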
97 changes: 97 additions & 0 deletions src/providers/azure/index.ts
@@ -0,0 +1,97 @@
import {
handlePrompt,
handleRapidPrompt,
} from './handler'
import type { Provider } from '@/types/provider'

const providerAzure = () => {
  const provider: Provider = {
id: 'provider-azure',
icon: 'i-simple-icons:microsoftazure', // @unocss-include
name: 'Azure OpenAI',
globalSettings: [
{
key: 'apiKey',
name: 'API Key',
type: 'api-key',
},
{
key: 'baseUrl',
name: 'Endpoint',
      description: 'Azure OpenAI resource endpoint, e.g. https://<resource>.openai.azure.com',
type: 'input',
},
{
key: 'model',
name: 'Azure deployment name',
description: 'Custom model name for Azure OpenAI.',
type: 'input',
},
{
key: 'maxTokens',
name: 'Max Tokens',
description: 'The maximum number of tokens to generate in the completion.',
type: 'slider',
min: 0,
max: 32768,
default: 2048,
step: 1,
},
{
key: 'messageHistorySize',
name: 'Max History Message Size',
      description: 'The maximum number of historical messages to retain. Older messages are dropped once their combined length exceeds the Max Tokens parameter.',
type: 'slider',
min: 1,
max: 24,
default: 5,
step: 1,
},
{
key: 'temperature',
name: 'Temperature',
type: 'slider',
description: 'What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.',
min: 0,
max: 2,
default: 0.7,
step: 0.01,
},
{
      key: 'topP',
name: 'Top P',
description: 'An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.',
type: 'slider',
min: 0,
max: 1,
default: 1,
step: 0.01,
},
],
bots: [
{
id: 'chat_continuous',
type: 'chat_continuous',
name: 'Continuous Chat',
settings: [],
},
{
id: 'chat_single',
type: 'chat_single',
name: 'Single Chat',
settings: [],
},
{
id: 'image_generation',
type: 'image_generation',
name: 'DALL·E',
settings: [],
},
],
handlePrompt,
handleRapidPrompt,
}
return provider
}

export default providerAzure
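
A minimal sketch of instantiating the factory, mirroring how the provider store consumes it (hypothetical usage, not part of the commit):

// Hypothetical example, not part of this commit.
import providerAzure from '@/providers/azure'

const azure = providerAzure()
console.log(azure.id) // 'provider-azure'
console.log(azure.bots.map(b => b.name)) // ['Continuous Chat', 'Single Chat', 'DALL·E']
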
42 changes: 42 additions & 0 deletions src/providers/azure/parser.ts
@@ -0,0 +1,42 @@
import { createParser } from 'eventsource-parser'
import type { ParsedEvent, ReconnectInterval } from 'eventsource-parser'

// Convert a raw SSE (text/event-stream) response into a ReadableStream
// that emits only the text deltas of the completion.
export const parseStream = (rawResponse: Response) => {
  const encoder = new TextEncoder()
  const decoder = new TextDecoder()
  const rb = rawResponse.body as ReadableStream

  return new ReadableStream({
    async start(controller) {
      // Guard so the controller is closed at most once, whether the stream
      // ends via a [DONE] event or via the reader draining.
      let closed = false
      const close = () => {
        if (!closed) {
          closed = true
          controller.close()
        }
      }
      const streamParser = (event: ParsedEvent | ReconnectInterval) => {
        if (event.type === 'event') {
          const data = event.data
          // The service signals the end of the stream with a literal [DONE]
          if (data === '[DONE]') {
            close()
            return
          }
          try {
            const json = JSON.parse(data)
            const text = (json.choices?.[0]?.delta?.content) || ''
            const queue = encoder.encode(text)
            controller.enqueue(queue)
          } catch (e) {
            controller.error(e)
          }
        }
      }
      const reader = rb.getReader()
      const parser = createParser(streamParser)
      let done = false
      while (!done) {
        const { done: isDone, value } = await reader.read()
        if (isDone) {
          done = true
          close()
          return
        }
        parser.feed(decoder.decode(value))
      }
    },
  })
}
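
For reference, each server-sent event consumed by parseStream carries a chat-completion delta, roughly of this shape (inferred from the fields the parser reads; the actual payload may include additional fields):

data: {"choices":[{"delta":{"content":"Hello"},"index":0,"finish_reason":null}]}

data: [DONE]
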
2 changes: 2 additions & 0 deletions src/stores/provider.ts
@@ -1,10 +1,12 @@
import providerOpenAI from '@/providers/openai'
import providerAzure from '@/providers/azure'
import providerReplicate from '@/providers/replicate'
import { allConversationTypes } from '@/types/conversation'
import type { BotMeta } from '@/types/app'

export const providerList = [
providerOpenAI(),
providerAzure(),
providerReplicate(),
]

