Skip to content

Commit

Permalink
🔖 chore: support glm-4-alltools
Browse files Browse the repository at this point in the history
  • Loading branch information
MartialBE committed Aug 11, 2024
1 parent 1bd449e commit 80aeafc
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 13 deletions.
72 changes: 67 additions & 5 deletions providers/zhipu/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,22 @@ import (
"one-api/common/config"
"one-api/common/requester"
"one-api/common/utils"
"one-api/model"
"one-api/types"
"strings"
)

type zhipuStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
IsCode bool
}

func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
if request.Model == "glm-4-alltools" {
return nil, common.ErrorWrapper(nil, "glm-4-alltools 只能stream模式下请求", http.StatusBadRequest)
}

req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
Expand Down Expand Up @@ -188,6 +194,15 @@ func (p *ZhipuProvider) pluginHandle(request *ZhipuRequest) {

plugin := p.Channel.Plugin.Data()

if request.Model == "glm-4-alltools" {
glm4AlltoolsPlugin(request, plugin)
return
}

generalPlugin(request, plugin)
}

func generalPlugin(request *ZhipuRequest, plugin model.PluginType) {
// 检测是否开启了 retrieval 插件
if pRetrieval, ok := plugin["retrieval"]; ok {
if knowledgeId, ok := pRetrieval["knowledge_id"].(string); ok && knowledgeId != "" {
Expand All @@ -203,8 +218,6 @@ func (p *ZhipuProvider) pluginHandle(request *ZhipuRequest) {
}

request.Tools = append(request.Tools, retrieval)

// 如果开启了 retrieval 插件,web_search 无效
return
}
}
Expand All @@ -222,6 +235,43 @@ func (p *ZhipuProvider) pluginHandle(request *ZhipuRequest) {
}
}

func glm4AlltoolsPlugin(request *ZhipuRequest, plugin model.PluginType) {
if pWeb, ok := plugin["web_browser"]; ok {
if enable, ok := pWeb["enable"].(bool); ok && enable {
request.Tools = append(request.Tools, ZhipuTool{
Type: "web_browser",
WebBrowser: &map[string]bool{
"enable": true,
},
})
}
}

if pDW, ok := plugin["drawing_tool"]; ok {
if enable, ok := pDW["enable"].(bool); ok && enable {
request.Tools = append(request.Tools, ZhipuTool{
Type: "drawing_tool",
DrawingTool: &map[string]bool{
"enable": true,
},
})
}
}

if pCode, ok := plugin["code_interpreter"]; ok {
if sandbox, ok := pCode["sandbox"].(string); ok && sandbox != "" {
codeInterpreter := ZhipuTool{
Type: "code_interpreter",
CodeInterpreter: &map[string]string{
"sandbox": sandbox,
},
}

request.Tools = append(request.Tools, codeInterpreter)
}
}
}

// 转换为OpenAI聊天流式请求体
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
Expand Down Expand Up @@ -262,18 +312,30 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes
Model: h.Request.Model,
}

if zhipuResponse.Choices[0].Delta.ToolCalls != nil {
choice := zhipuResponse.Choices[0]
if zhipuResponse.IsFunction() {
choice := zhipuResponse.Choices[0].ToOpenAIChoice()
choice.CheckChoice(h.Request)
choices := choice.ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := streamResponse
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
responseBody, _ := json.Marshal(chatCompletionCopy)

dataChan <- string(responseBody)
}
} else {
streamResponse.Choices = zhipuResponse.Choices
streamResponse.Choices = zhipuResponse.ToOpenAIChoices()

if !h.IsCode && zhipuResponse.IsCodeInterpreter() {
h.IsCode = true
streamResponse.Choices[0].Delta.Content = "```python\n\n" + streamResponse.Choices[0].Delta.Content
}

if h.IsCode && !zhipuResponse.IsCodeInterpreter() {
h.IsCode = false
streamResponse.Choices[0].Delta.Content = "\n```\n\n" + streamResponse.Choices[0].Delta.Content
}

responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
}
Expand Down
202 changes: 194 additions & 8 deletions providers/zhipu/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ type ZhipuRetrieval struct {
}

type ZhipuTool struct {
Type string `json:"type"`
Function *types.ChatCompletionFunction `json:"function,omitempty"`
WebSearch *ZhipuWebSearch `json:"web_search,omitempty"`
Retrieval *ZhipuRetrieval `json:"retrieval,omitempty"`
Type string `json:"type"`
Function *types.ChatCompletionFunction `json:"function,omitempty"`
WebSearch *ZhipuWebSearch `json:"web_search,omitempty"`
Retrieval *ZhipuRetrieval `json:"retrieval,omitempty"`
WebBrowser any `json:"web_browser,omitempty"`
DrawingTool any `json:"drawing_tool,omitempty"`
CodeInterpreter any `json:"code_interpreter,omitempty"`
}
type ZhipuRequest struct {
Model string `json:"model"`
Expand Down Expand Up @@ -50,13 +53,196 @@ type ZhipuResponse struct {
}

type ZhipuStreamResponse struct {
ID string `json:"id"`
Created int64 `json:"created"`
Choices []types.ChatCompletionStreamChoice `json:"choices"`
Usage *types.Usage `json:"usage,omitempty"`
ID string `json:"id"`
Created int64 `json:"created"`
Choices []ZhipuChoice `json:"choices"`
Usage *types.Usage `json:"usage,omitempty"`
ZhipuResponseError
}

func (z *ZhipuStreamResponse) ToOpenAIChoices() []types.ChatCompletionStreamChoice {
choices := make([]types.ChatCompletionStreamChoice, 0, len(z.Choices))

for _, choice := range z.Choices {
choices = append(choices, choice.ToOpenAIChoice())
}

return choices
}

func (z *ZhipuStreamResponse) IsFunction() bool {
if z.Choices == nil {
return false
}

choice := z.Choices[0]

return choice.IsFunction()
}

func (z *ZhipuStreamResponse) IsCodeInterpreter() bool {
if z.Choices == nil {
return false
}

choice := z.Choices[0]

if choice.Delta.ToolCalls == nil {
return false
}

toolCall := choice.Delta.ToolCalls[0]

return toolCall.Type == "code_interpreter" && toolCall.CodeInterpreter.Outputs == nil
}

type ZhipuChoice struct {
Index int `json:"index"`
Delta ZhipuDelta `json:"delta"`
FinishReason string `json:"finish_reason"`
ContentFilterResults any `json:"content_filter_results,omitempty"`
Usage *types.Usage `json:"usage,omitempty"`
}

func (z *ZhipuChoice) ToOpenAIChoice() types.ChatCompletionStreamChoice {
choice := types.ChatCompletionStreamChoice{
Index: z.Index,
Delta: z.Delta.ToOpenAIDelta(),
ContentFilterResults: z.ContentFilterResults,
Usage: z.Usage,
}

if z.IsFunction() || z.FinishReason != "tool_calls" {
choice.FinishReason = z.FinishReason
}

return choice
}

func (z *ZhipuChoice) IsFunction() bool {
if z.Delta.ToolCalls == nil {
return false
}
toolCall := z.Delta.ToolCalls[0]

return toolCall.Type == "function"
}

type ZhipuDelta struct {
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []*ZhipuToolCalls `json:"tool_calls,omitempty"`
}

func (z *ZhipuDelta) ToOpenAIDelta() types.ChatCompletionStreamChoiceDelta {
delta := types.ChatCompletionStreamChoiceDelta{
Role: z.Role,
}

content := z.Content
changeRole := false
if z.ToolCalls != nil {
toolCalls := make([]*types.ChatCompletionToolCalls, 0)
for _, toolCall := range z.ToolCalls {
switch toolCall.Type {
case "web_browser":
content += toolCall.WebBrowser.ToMarkdown()
changeRole = true
case "drawing_tool":
content += toolCall.DrawingTool.ToMarkdown()
changeRole = true
case "code_interpreter":
content += toolCall.CodeInterpreter.ToMarkdown()
changeRole = true
default:
toolCalls = append(toolCalls, &toolCall.ChatCompletionToolCalls)
delta.ToolCalls = toolCalls
}
}
}

if changeRole {
delta.Role = types.ChatMessageRoleAssistant
}

delta.Content = content
return delta
}

type ZhipuToolCalls struct {
types.ChatCompletionToolCalls
WebBrowser *ZhipuPlugin[WebBrowserPlugin] `json:"web_browser,omitempty"`
DrawingTool *ZhipuPlugin[DrawingToolPlugin] `json:"drawing_tool,omitempty"`
CodeInterpreter *ZhipuPlugin[CodeInterpreterPlugin] `json:"code_interpreter,omitempty"`
}

type PluginMD interface {
ToMarkdown() string
}

type ZhipuPlugin[T PluginMD] struct {
Input string `json:"input,omitempty"`
Outputs []T `json:"outputs,omitempty"`
}

func (z *ZhipuPlugin[T]) ToMarkdown() string {
if z.Outputs == nil {
return z.Input
}

markdown := "\n"

for _, output := range z.Outputs {
markdown += output.ToMarkdown()
}

return markdown
}

type WebBrowserPlugin struct {
Title string `json:"title,omitempty"`
Link string `json:"link,omitempty"`
Content string `json:"content,omitempty"`
}

func (z WebBrowserPlugin) ToMarkdown() string {
markdown := ""
markdown += "[" + z.Title + "](" + z.Link + ")\n"
markdown += z.Content + "\n\n"

return markdown
}

type DrawingToolPlugin struct {
Image string `json:"image,omitempty"`
}

func (z DrawingToolPlugin) ToMarkdown() string {
markdown := ""
markdown += "![" + z.Image + "](" + z.Image + ")\n\n"

return markdown
}

type CodeInterpreterPlugin struct {
Type string `json:"type,omitempty"`
File string `json:"file,omitempty"`
Logs string `json:"logs,omitempty"`
}

func (z CodeInterpreterPlugin) ToMarkdown() string {
markdown := ""

switch z.Type {
case "file":
markdown += "[结果文件](" + z.File + ")\n"
case "logs":
markdown += "```\n" + z.Logs + "\n```\n"
}

return markdown
}

func (z *ZhipuStreamResponse) GetResponseText() (responseText string) {
for _, choice := range z.Choices {
responseText += choice.Delta.Content
Expand Down
36 changes: 36 additions & 0 deletions web/src/views/Channel/type/Plugin.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,42 @@
"required": true
}
}
},
"web_browser": {
"name": "搜索工具",
"description": "使用搜索工具功能,仅glm-4-alltools有效",
"params": {
"enable": {
"name": "启用",
"description": "是否启用搜索工具",
"type": "bool",
"required": true
}
}
},
"drawing_tool": {
"name": "绘图工具",
"description": "使用绘图工具功能,仅glm-4-alltools有效",
"params": {
"enable": {
"name": "启用",
"description": "是否启用绘图工具",
"type": "bool",
"required": true
}
}
},
"code_interpreter": {
"name": "代码工具 ",
"description": "使用代码工具 功能,仅glm-4-alltools有效",
"params": {
"sandbox": {
"name": "沙盒模式",
"description": "auto 自动调用沙盒环境执行代码,none 不启用沙盒环境,必填",
"type": "string",
"required": false
}
}
}
},
"17": {
Expand Down

0 comments on commit 80aeafc

Please sign in to comment.