Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Dec 24, 2024
2 parents 7180e6f + 61495a4 commit be0c240
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 75 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ build
logs
web/dist
.env
one-api
4 changes: 3 additions & 1 deletion common/str.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
return nil
return map[string]interface{}{
"result": str,
}
}
return m
}
Expand Down
183 changes: 109 additions & 74 deletions relay/channel/gemini/relay-gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
}
}

tool_call_ids := make(map[string]string)
//shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {

Expand All @@ -108,6 +108,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
continue
} else if message.Role == "tool" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "user",
})
}
var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
name := ""
if message.Name != nil {
name = *message.Name
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
name = val
}
functionResp := &FunctionResponse{
Name: name,
Response: common.StrToMap(message.StringContent()),
}
*parts = append(*parts, GeminiPart{
FunctionResponse: functionResp,
})
continue
}
var parts []GeminiPart
content := GeminiChatContent{
Expand All @@ -125,62 +146,49 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
parts = append(parts, toolCall)
tool_call_ids[call.ID] = call.Function.Name
}
}
if !isToolCall {
if message.Role == "tool" {
content.Role = "user"
name := ""
if message.Name != nil {
name = *message.Name
}
functionResp := &FunctionResponse{
Name: name,
Response: common.StrToMap(message.StringContent()),
}
parts = append(parts, GeminiPart{
FunctionResponse: functionResp,
})
} else {
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1

if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url,获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
Text: part.Text,
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1

if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url,获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
}
}
}
}

content.Parts = parts

// there's no assistant role in gemini and API shall vomit if Role is not user or model
Expand Down Expand Up @@ -242,30 +250,46 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}

func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
var toolCalls []dto.ToolCall

item := candidate.Content.Parts[0]
if item.FunctionCall == nil {
return toolCalls
}
func getToolCall(item *GeminiPart) *dto.ToolCall {
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
//common.SysError("getToolCalls failed: " + err.Error())
return toolCalls
//common.SysError("getToolCall failed: " + err.Error())
return nil
}
toolCall := dto.ToolCall{
return &dto.ToolCall{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionCall{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}

// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
// var toolCalls []dto.ToolCall

// item := candidate.Content.Parts[index]
// if item.FunctionCall == nil {
// return toolCalls
// }
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
// if err != nil {
// //common.SysError("getToolCalls failed: " + err.Error())
// return toolCalls
// }
// toolCall := dto.ToolCall{
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
// Type: "function",
// Function: dto.FunctionCall{
// Arguments: string(argsBytes),
// Name: item.FunctionCall.FunctionName,
// },
// }
// toolCalls = append(toolCalls, toolCall)
// return toolCalls
// }

func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Expand All @@ -275,6 +299,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
// jsonData, _ := json.MarshalIndent(candidate, "", " ")
// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Expand All @@ -284,16 +310,20 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
FinishReason: constant.FinishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
choice.Message.SetToolCalls(getToolCalls(&candidate))
} else {
var texts []string
for _, part := range candidate.Content.Parts {
var texts []string
var tool_calls []dto.ToolCall
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
}
} else {
texts = append(texts, part.Text)
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
choice.Message.SetToolCalls(tool_calls)
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
Expand All @@ -304,18 +334,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
var choice dto.ChatCompletionsStreamResponseChoice
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
respFirstParts := geminiResponse.Candidates[0].Content.Parts
if respFirstParts[0].FunctionCall != nil {
// function response
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
} else {
// text response
var texts []string
for _, part := range respFirstParts {
var texts []string
var tool_calls []dto.ToolCall
for _, part := range geminiResponse.Candidates[0].Content.Parts {
if part.FunctionCall != nil {
if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
}
} else {
texts = append(texts, part.Text)
}
}
if len(texts) > 0 {
choice.Delta.SetContentString(strings.Join(texts, "\n"))
}
if len(tool_calls) > 0 {
choice.Delta.ToolCalls = tool_calls
}
}
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
Expand Down

0 comments on commit be0c240

Please sign in to comment.