Skip to content

Commit

Permalink
feat: support gpt-4 with vision (#683, #714)
Browse files Browse the repository at this point in the history
  • Loading branch information
songquanpeng committed Nov 19, 2023
1 parent 76f9288 commit 495fc62
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 17 deletions.
2 changes: 1 addition & 1 deletion controller/relay-aiproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].Content
query = request.Messages[len(request.Messages)-1].StringContent()
}
return &AIProxyLibraryRequest{
Model: request.Model,
Expand Down
8 changes: 4 additions & 4 deletions controller/relay-ali.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.Content,
User: message.StringContent(),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.Content
prompt = message.StringContent()
break
}
messages = append(messages, AliMessage{
User: message.Content,
Bot: request.Messages[i+1].Content,
User: message.StringContent(),
Bot: request.Messages[i+1].StringContent(),
})
i++
}
Expand Down
4 changes: 2 additions & 2 deletions controller/relay-baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.Content,
Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Expand All @@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.Content,
Content: message.StringContent(),
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion controller/relay-openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += countTokenText(choice.Message.Content, model)
completionTokens += countTokenText(choice.Message.StringContent(), model)
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,
Expand Down
2 changes: 1 addition & 1 deletion controller/relay-palm.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.Content,
Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"
Expand Down
4 changes: 2 additions & 2 deletions controller/relay-tencent.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.Content,
Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
Expand All @@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
continue
}
messages = append(messages, TencentMessage{
Content: message.Content,
Content: message.StringContent(),
Role: message.Role,
})
}
Expand Down
2 changes: 1 addition & 1 deletion controller/relay-utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int {
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content)
tokenNum += getTokenNum(tokenEncoder, message.StringContent())
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
Expand Down
4 changes: 2 additions & 2 deletions controller/relay-xunfei.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.Content,
Content: message.StringContent(),
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
Expand All @@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
Content: message.Content,
Content: message.StringContent(),
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions controller/relay-zhipu.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
if message.Role == "system" {
messages = append(messages, ZhipuMessage{
Role: "system",
Content: message.Content,
Content: message.StringContent(),
})
messages = append(messages, ZhipuMessage{
Role: "user",
Expand All @@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
} else {
messages = append(messages, ZhipuMessage{
Role: message.Role,
Content: message.Content,
Content: message.StringContent(),
})
}
}
Expand Down
41 changes: 40 additions & 1 deletion controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,49 @@ import (

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
}

type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}

type TextContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
}

type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}

func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == "text" {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}

const (
RelayModeUnknown = iota
RelayModeChatCompletions
Expand Down

0 comments on commit 495fc62

Please sign in to comment.