Skip to content

Commit

Permalink
feat: support xunfei's v2 api (#442, close #440)
Browse files Browse the repository at this point in the history
* 兼容讯飞v2接口

* Revert "兼容讯飞v2接口"

This reverts commit 21f05d1.

* fix: fix implementation

---------

Co-authored-by: JustSong <[email protected]>
Co-authored-by: JustSong <[email protected]>
  • Loading branch information
3 people authored Aug 19, 2023
1 parent dfaa018 commit 7e058bf
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 5 deletions.
21 changes: 17 additions & 4 deletions controller/relay-xunfei.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type XunfeiChatResponse struct {
} `json:"payload"`
}

func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
Expand All @@ -96,7 +96,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *Xun
}
xunfeiRequest := XunfeiChatRequest{}
xunfeiRequest.Header.AppId = xunfeiAppId
xunfeiRequest.Parameter.Chat.Domain = "general"
xunfeiRequest.Parameter.Chat.Domain = domain
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
Expand Down Expand Up @@ -178,15 +178,28 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {

func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion == "v2.1" {
domain = "generalv2"
}
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
hostUrl := "wss://aichat.xf-yun.com/v1/chat"
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
if err != nil || resp.StatusCode != 101 {
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
}
data := requestOpenAI2Xunfei(textRequest, appId)
data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data)
if err != nil {
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
Expand Down
2 changes: 2 additions & 0 deletions i18n/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -521,5 +521,7 @@
"此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
"取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
"按照如下格式输入:": "Enter in the following format:",
"模型版本": "Model version",
"请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
"点击查看": "click to view"
}
2 changes: 1 addition & 1 deletion middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func Distribute() func(c *gin.Context) {
c.Set("model_mapping", channel.ModelMapping)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.BaseURL)
if channel.Type == common.ChannelTypeAzure {
if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
c.Set("api_version", channel.Other)
}
c.Next()
Expand Down
17 changes: 17 additions & 0 deletions web/src/pages/Channel/EditChannel.js
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ const EditChannel = () => {
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2023-06-01-preview';
}
if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1';
}
if (localInputs.model_mapping === '') {
localInputs.model_mapping = '{}';
}
Expand Down Expand Up @@ -275,6 +278,20 @@ const EditChannel = () => {
options={groupOptions}
/>
</Form.Field>
{
inputs.type === 18 && (
<Form.Field>
<Form.Input
label='模型版本'
name='other'
placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'}
onChange={handleInputChange}
value={inputs.other}
autoComplete='new-password'
/>
</Form.Field>
)
}
<Form.Field>
<Form.Dropdown
label='模型'
Expand Down

0 comments on commit 7e058bf

Please sign in to comment.