From 80aeafc235e0631b0019a12e4cef7c7c9723e4f3 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Sun, 11 Aug 2024 14:38:09 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=96=20chore:=20support=20glm-4-alltool?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/zhipu/chat.go | 72 ++++++++- providers/zhipu/type.go | 202 ++++++++++++++++++++++++- web/src/views/Channel/type/Plugin.json | 36 +++++ 3 files changed, 297 insertions(+), 13 deletions(-) diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index 04dee8ca4..7e9738866 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -8,6 +8,7 @@ import ( "one-api/common/config" "one-api/common/requester" "one-api/common/utils" + "one-api/model" "one-api/types" "strings" ) @@ -15,9 +16,14 @@ import ( type zhipuStreamHandler struct { Usage *types.Usage Request *types.ChatCompletionRequest + IsCode bool } func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + if request.Model == "glm-4-alltools" { + return nil, common.ErrorWrapper(nil, "glm-4-alltools 只能stream模式下请求", http.StatusBadRequest) + } + req, errWithCode := p.getChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -188,6 +194,15 @@ func (p *ZhipuProvider) pluginHandle(request *ZhipuRequest) { plugin := p.Channel.Plugin.Data() + if request.Model == "glm-4-alltools" { + glm4AlltoolsPlugin(request, plugin) + return + } + + generalPlugin(request, plugin) +} + +func generalPlugin(request *ZhipuRequest, plugin model.PluginType) { // 检测是否开启了 retrieval 插件 if pRetrieval, ok := plugin["retrieval"]; ok { if knowledgeId, ok := pRetrieval["knowledge_id"].(string); ok && knowledgeId != "" { @@ -203,8 +218,6 @@ func (p *ZhipuProvider) pluginHandle(request *ZhipuRequest) { } request.Tools = append(request.Tools, retrieval) - - // 如果开启了 retrieval 插件,web_search 无效 return } } @@ -222,6 +235,43 @@ func (p *ZhipuProvider) pluginHandle(request *ZhipuRequest) { } } +func glm4AlltoolsPlugin(request *ZhipuRequest, plugin model.PluginType) { + if pWeb, ok := plugin["web_browser"]; ok { + if enable, ok := pWeb["enable"].(bool); ok && enable { + request.Tools = append(request.Tools, ZhipuTool{ + Type: "web_browser", + WebBrowser: &map[string]bool{ + "enable": true, + }, + }) + } + } + + if pDW, ok := plugin["drawing_tool"]; ok { + if enable, ok := pDW["enable"].(bool); ok && enable { + request.Tools = append(request.Tools, ZhipuTool{ + Type: "drawing_tool", + DrawingTool: &map[string]bool{ + "enable": true, + }, + }) + } + } + + if pCode, ok := plugin["code_interpreter"]; ok { + if sandbox, ok := pCode["sandbox"].(string); ok && sandbox != "" { + codeInterpreter := ZhipuTool{ + Type: "code_interpreter", + CodeInterpreter: &map[string]string{ + "sandbox": sandbox, + }, + } + + request.Tools = append(request.Tools, codeInterpreter) + } + } +} + // 转换为OpenAI聊天流式请求体 func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data: 或者 meta:,则直接返回 @@ -262,18 +312,30 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes Model: h.Request.Model, } - if zhipuResponse.Choices[0].Delta.ToolCalls != nil { - choice := zhipuResponse.Choices[0] + if zhipuResponse.IsFunction() { + choice := zhipuResponse.Choices[0].ToOpenAIChoice() choice.CheckChoice(h.Request) choices := choice.ConvertOpenaiStream() for _, choice := range choices { chatCompletionCopy := streamResponse chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} responseBody, _ := json.Marshal(chatCompletionCopy) + dataChan <- string(responseBody) } } else { - streamResponse.Choices = zhipuResponse.Choices + streamResponse.Choices = zhipuResponse.ToOpenAIChoices() + + if !h.IsCode && zhipuResponse.IsCodeInterpreter() { + h.IsCode = true + streamResponse.Choices[0].Delta.Content = "```python\n\n" + streamResponse.Choices[0].Delta.Content + } + + if h.IsCode && !zhipuResponse.IsCodeInterpreter() { + h.IsCode = false + streamResponse.Choices[0].Delta.Content = "\n```\n\n" + streamResponse.Choices[0].Delta.Content + } + responseBody, _ := json.Marshal(streamResponse) dataChan <- string(responseBody) } diff --git a/providers/zhipu/type.go b/providers/zhipu/type.go index 31d5beb4f..17f1d2e6a 100644 --- a/providers/zhipu/type.go +++ b/providers/zhipu/type.go @@ -16,10 +16,13 @@ type ZhipuRetrieval struct { } type ZhipuTool struct { - Type string `json:"type"` - Function *types.ChatCompletionFunction `json:"function,omitempty"` - WebSearch *ZhipuWebSearch `json:"web_search,omitempty"` - Retrieval *ZhipuRetrieval `json:"retrieval,omitempty"` + Type string `json:"type"` + Function *types.ChatCompletionFunction `json:"function,omitempty"` + WebSearch *ZhipuWebSearch `json:"web_search,omitempty"` + Retrieval *ZhipuRetrieval `json:"retrieval,omitempty"` + WebBrowser any `json:"web_browser,omitempty"` + DrawingTool any `json:"drawing_tool,omitempty"` + CodeInterpreter any `json:"code_interpreter,omitempty"` } type ZhipuRequest struct { Model string `json:"model"` @@ -50,13 +53,196 @@ type ZhipuResponse struct { } type ZhipuStreamResponse struct { - ID string `json:"id"` - Created int64 `json:"created"` - Choices []types.ChatCompletionStreamChoice `json:"choices"` - Usage *types.Usage `json:"usage,omitempty"` + ID string `json:"id"` + Created int64 `json:"created"` + Choices []ZhipuChoice `json:"choices"` + Usage *types.Usage `json:"usage,omitempty"` ZhipuResponseError } +func (z *ZhipuStreamResponse) ToOpenAIChoices() []types.ChatCompletionStreamChoice { + choices := make([]types.ChatCompletionStreamChoice, 0, len(z.Choices)) + + for _, choice := range z.Choices { + choices = append(choices, choice.ToOpenAIChoice()) + } + + return choices +} + +func (z *ZhipuStreamResponse) IsFunction() bool { + if z.Choices == nil { + return false + } + + choice := z.Choices[0] + + return choice.IsFunction() +} + +func (z *ZhipuStreamResponse) IsCodeInterpreter() bool { + if z.Choices == nil { + return false + } + + choice := z.Choices[0] + + if choice.Delta.ToolCalls == nil { + return false + } + + toolCall := choice.Delta.ToolCalls[0] + + return toolCall.Type == "code_interpreter" && toolCall.CodeInterpreter.Outputs == nil +} + +type ZhipuChoice struct { + Index int `json:"index"` + Delta ZhipuDelta `json:"delta"` + FinishReason string `json:"finish_reason"` + ContentFilterResults any `json:"content_filter_results,omitempty"` + Usage *types.Usage `json:"usage,omitempty"` +} + +func (z *ZhipuChoice) ToOpenAIChoice() types.ChatCompletionStreamChoice { + choice := types.ChatCompletionStreamChoice{ + Index: z.Index, + Delta: z.Delta.ToOpenAIDelta(), + ContentFilterResults: z.ContentFilterResults, + Usage: z.Usage, + } + + if z.IsFunction() || z.FinishReason != "tool_calls" { + choice.FinishReason = z.FinishReason + } + + return choice +} + +func (z *ZhipuChoice) IsFunction() bool { + if z.Delta.ToolCalls == nil { + return false + } + toolCall := z.Delta.ToolCalls[0] + + return toolCall.Type == "function" +} + +type ZhipuDelta struct { + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + ToolCalls []*ZhipuToolCalls `json:"tool_calls,omitempty"` +} + +func (z *ZhipuDelta) ToOpenAIDelta() types.ChatCompletionStreamChoiceDelta { + delta := types.ChatCompletionStreamChoiceDelta{ + Role: z.Role, + } + + content := z.Content + changeRole := false + if z.ToolCalls != nil { + toolCalls := make([]*types.ChatCompletionToolCalls, 0) + for _, toolCall := range z.ToolCalls { + switch toolCall.Type { + case "web_browser": + content += toolCall.WebBrowser.ToMarkdown() + changeRole = true + case "drawing_tool": + content += toolCall.DrawingTool.ToMarkdown() + changeRole = true + case "code_interpreter": + content += toolCall.CodeInterpreter.ToMarkdown() + changeRole = true + default: + toolCalls = append(toolCalls, &toolCall.ChatCompletionToolCalls) + delta.ToolCalls = toolCalls + } + } + } + + if changeRole { + delta.Role = types.ChatMessageRoleAssistant + } + + delta.Content = content + return delta +} + +type ZhipuToolCalls struct { + types.ChatCompletionToolCalls + WebBrowser *ZhipuPlugin[WebBrowserPlugin] `json:"web_browser,omitempty"` + DrawingTool *ZhipuPlugin[DrawingToolPlugin] `json:"drawing_tool,omitempty"` + CodeInterpreter *ZhipuPlugin[CodeInterpreterPlugin] `json:"code_interpreter,omitempty"` +} + +type PluginMD interface { + ToMarkdown() string +} + +type ZhipuPlugin[T PluginMD] struct { + Input string `json:"input,omitempty"` + Outputs []T `json:"outputs,omitempty"` +} + +func (z *ZhipuPlugin[T]) ToMarkdown() string { + if z.Outputs == nil { + return z.Input + } + + markdown := "\n" + + for _, output := range z.Outputs { + markdown += output.ToMarkdown() + } + + return markdown +} + +type WebBrowserPlugin struct { + Title string `json:"title,omitempty"` + Link string `json:"link,omitempty"` + Content string `json:"content,omitempty"` +} + +func (z WebBrowserPlugin) ToMarkdown() string { + markdown := "" + markdown += "[" + z.Title + "](" + z.Link + ")\n" + markdown += z.Content + "\n\n" + + return markdown +} + +type DrawingToolPlugin struct { + Image string `json:"image,omitempty"` +} + +func (z DrawingToolPlugin) ToMarkdown() string { + markdown := "" + markdown += "![" + z.Image + "](" + z.Image + ")\n\n" + + return markdown +} + +type CodeInterpreterPlugin struct { + Type string `json:"type,omitempty"` + File string `json:"file,omitempty"` + Logs string `json:"logs,omitempty"` +} + +func (z CodeInterpreterPlugin) ToMarkdown() string { + markdown := "" + + switch z.Type { + case "file": + markdown += "[结果文件](" + z.File + ")\n" + case "logs": + markdown += "```\n" + z.Logs + "\n```\n" + } + + return markdown +} + func (z *ZhipuStreamResponse) GetResponseText() (responseText string) { for _, choice := range z.Choices { responseText += choice.Delta.Content diff --git a/web/src/views/Channel/type/Plugin.json b/web/src/views/Channel/type/Plugin.json index b28eab46a..5a57a994c 100644 --- a/web/src/views/Channel/type/Plugin.json +++ b/web/src/views/Channel/type/Plugin.json @@ -29,6 +29,42 @@ "required": true } } + }, + "web_browser": { + "name": "搜索工具", + "description": "使用搜索工具功能,仅glm-4-alltools有效", + "params": { + "enable": { + "name": "启用", + "description": "是否启用搜索工具", + "type": "bool", + "required": true + } + } + }, + "drawing_tool": { + "name": "绘图工具", + "description": "使用绘图工具功能,仅glm-4-alltools有效", + "params": { + "enable": { + "name": "启用", + "description": "是否启用绘图工具", + "type": "bool", + "required": true + } + } + }, + "code_interpreter": { + "name": "代码工具 ", + "description": "使用代码工具 功能,仅glm-4-alltools有效", + "params": { + "sandbox": { + "name": "沙盒模式", + "description": "auto 自动调用沙盒环境执行代码,none 不启用沙盒环境,必填", + "type": "string", + "required": false + } + } } }, "17": {