From 04a91040628ae725916119bb133e80131cd62c33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9F=A9=E8=B4=A4=E6=B6=9B?= <601803023@qq.com>
Date: Fri, 9 Aug 2024 09:55:40 +0800
Subject: [PATCH] feat: support gemini ai model (#1173)

---
 plugins/wasm-go/extensions/ai-proxy/README.md |  66 ++
 .../extensions/ai-proxy/provider/baidu.go     |   8 +-
 .../extensions/ai-proxy/provider/gemini.go    | 606 ++++++++++++++++++
 .../extensions/ai-proxy/provider/model.go     |  19 +
 .../extensions/ai-proxy/provider/provider.go  |  11 +
 5 files changed, 706 insertions(+), 4 deletions(-)
 create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/gemini.go

diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md
index 709d72bc2e..af0d56fbe9 100644
--- a/plugins/wasm-go/extensions/ai-proxy/README.md
+++ b/plugins/wasm-go/extensions/ai-proxy/README.md
@@ -157,6 +157,13 @@ Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置

讯飞星火认知大模型的`apiTokens`字段值为`APIKey:APISecret`。即填入自己的APIKey与APISecret,并以`:`分隔。

#### Gemini

Gemini 所对应的 `type` 为 `gemini`。它特有的配置字段如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| --------------------- | ------------- | -------- | ------ | ------------------------------------------------------------------------------------------------- |
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |

## 用法示例

### 使用 OpenAI 协议代理 Gemini 服务

**配置信息**

```yaml
provider:
  type: gemini
  apiTokens:
    - "YOUR_GEMINI_API_TOKEN"
  modelMapping:
    "*": "gemini-pro"
  geminiSafetySetting:
    "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE"
    "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE"
    "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE"
    "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"
```

**请求示例**

```json
{
  "model": "gpt-3.5",
  "messages": [
    {
      "role": "user",
      "content": "Who are you?"
    }
  ],
  "stream": false
}
```

**响应示例**

```json
{
  "id": "chatcmpl-b010867c-0d3f-40ba-95fd-4e8030551aeb",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "I am a large multi-modal model, trained by Google. I am designed to provide information and answer questions to the best of my abilities."
+ }, + "finish_reason": "stop" + } + ], + "created": 1722756984, + "model": "gemini-pro", + "object": "chat.completion", + "usage": { + "prompt_tokens": 5, + "completion_tokens": 29, + "total_tokens": 34 + } +} +``` + ## 完整配置示例 ### Kubernetes 示例 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index 3ab1626ed3..cf96ab5ede 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -83,7 +83,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, return types.ActionContinue, errors.New("request model is empty") } // 根据模型重写requestPath - path := b.GetRequestPath(request.Model) + path := b.getRequestPath(request.Model) _ = util.OverwriteRequestPath(path) if b.config.context == nil { @@ -126,7 +126,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } request.Model = mappedModel ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - path := b.GetRequestPath(mappedModel) + path := b.getRequestPath(mappedModel) _ = util.OverwriteRequestPath(path) if b.config.context == nil { @@ -226,7 +226,7 @@ type baiduTextGenRequest struct { UserId string `json:"user_id,omitempty"` } -func (b *baiduProvider) GetRequestPath(baiduModel string) string { +func (b *baiduProvider) getRequestPath(baiduModel string) string { // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t suffix, ok := baiduModelToPathSuffixMap[baiduModel] if !ok { @@ -326,7 +326,7 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp Created: time.Now().UnixMilli() / 1000, Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), SystemFingerprint: "", - Object: objectChatCompletion, + Object: objectChatCompletionChunk, Choices: []chatCompletionChoice{choice}, Usage: usage{ PromptTokens: response.Usage.PromptTokens, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go new file mode 100644 index 0000000000..2da0ec1e1f --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -0,0 +1,606 @@ +package provider + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/google/uuid" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "strings" + "time" +) + +// geminiProvider is the provider for google gemini/gemini flash service. 
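// It accepts OpenAI-style chat completion and embeddings requests, rewrites them
// to Gemini's generateContent / streamGenerateContent / batchEmbedContents
// endpoints, and converts Gemini responses back to the OpenAI schema.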
+ +const ( + geminiApiKeyHeader = "x-goog-api-key" + geminiDomain = "generativelanguage.googleapis.com" +) + +type geminiProviderInitializer struct { +} + +func (g *geminiProviderInitializer) ValidateConfig(config ProviderConfig) error { + if config.apiTokens == nil || len(config.apiTokens) == 0 { + return errors.New("no apiToken found in provider config") + } + return nil +} + +func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &geminiProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil +} + +type geminiProvider struct { + config ProviderConfig + contextCache *contextCache +} + +func (g *geminiProvider) GetProviderType() string { + return providerTypeGemini +} + +func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + return types.ActionContinue, errUnsupportedApiName + } + + _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetRandomToken()) + _ = util.OverwriteRequestHost(geminiDomain) + + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + + // Delay the header processing to allow changing streaming mode in OnRequestBody + return types.HeaderStopIteration, nil +} + +func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName == ApiNameChatCompletion { + return g.onChatCompletionRequestBody(ctx, body, log) + } else if apiName == ApiNameEmbeddings { + return g.onEmbeddingsRequestBody(ctx, body, log) + } + return types.ActionContinue, errUnsupportedApiName +} + +func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { + // 使用gemini接口协议 + if g.config.protocol == protocolOriginal { + request := &geminiChatRequest{} + if err := json.Unmarshal(body, request); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + } + if request.Model == "" { + return types.ActionContinue, errors.New("request model is empty") + } + // 根据模型重写requestPath + path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + _ = util.OverwriteRequestPath(path) + + // 移除多余的model和stream字段 + request = &geminiChatRequest{ + Contents: request.Contents, + SafetySettings: request.SafetySettings, + GenerationConfig: request.GenerationConfig, + Tools: request.Tools, + } + if g.config.context == nil { + return types.ActionContinue, replaceJsonRequestBody(request, log) + } + + err := g.contextCache.GetContent(func(content string, err error) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + g.setSystemContent(request, content) + if err := replaceJsonRequestBody(request, log); err != nil { + _ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } + }, log) + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err + } + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return types.ActionContinue, err + } + + 
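	// From here on the request uses the OpenAI protocol: the model name is mapped,
	// the request path is rewritten for the mapped model, and the body is converted
	// to the Gemini schema by buildGeminiChatRequest.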
// 映射模型重写requestPath + model := request.Model + if model == "" { + return types.ActionContinue, errors.New("missing model in chat completion request") + } + ctx.SetContext(ctxKeyOriginalRequestModel, model) + mappedModel := getMappedModel(model, g.config.modelMapping, log) + if mappedModel == "" { + return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + } + request.Model = mappedModel + ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + path := g.getRequestPath(ApiNameChatCompletion, mappedModel, request.Stream) + _ = util.OverwriteRequestPath(path) + + if g.config.context == nil { + geminiRequest := g.buildGeminiChatRequest(request) + return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log) + } + + err := g.contextCache.GetContent(func(content string, err error) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + insertContextMessage(request, content) + geminiRequest := g.buildGeminiChatRequest(request) + if err := replaceJsonRequestBody(geminiRequest, log); err != nil { + _ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } + }, log) + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err +} + +func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { + // 使用gemini接口协议 + if g.config.protocol == protocolOriginal { + request := &geminiBatchEmbeddingRequest{} + if err := json.Unmarshal(body, request); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + } + if request.Model == "" { + return types.ActionContinue, errors.New("request model is empty") + } + // 根据模型重写requestPath + path := g.getRequestPath(ApiNameEmbeddings, request.Model, false) + _ = util.OverwriteRequestPath(path) + + // 移除多余的model字段 + request = &geminiBatchEmbeddingRequest{ + Requests: request.Requests, + } + return types.ActionContinue, replaceJsonRequestBody(request, log) + } + request := &embeddingsRequest{} + if err := json.Unmarshal(body, request); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + } + + // 映射模型重写requestPath + model := request.Model + if model == "" { + return types.ActionContinue, errors.New("missing model in embeddings request") + } + ctx.SetContext(ctxKeyOriginalRequestModel, model) + mappedModel := getMappedModel(model, g.config.modelMapping, log) + if mappedModel == "" { + return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + } + request.Model = mappedModel + ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + path := g.getRequestPath(ApiNameEmbeddings, mappedModel, false) + _ = util.OverwriteRequestPath(path) + + geminiRequest := g.buildBatchEmbeddingRequest(request) + return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log) +} + +func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + if g.config.protocol == protocolOriginal { + ctx.DontReadResponseBody() + return types.ActionContinue, nil + } + + _ = proxywasm.RemoveHttpResponseHeader("Content-Length") + return types.ActionContinue, 
nil +} + +func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { + log.Infof("chunk body:%s", string(chunk)) + if isLastChunk || len(chunk) == 0 { + return nil, nil + } + // sample end event response: + // data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini,一个大型多模态模型,由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}} + responseBuilder := &strings.Builder{} + lines := strings.Split(string(chunk), "\n") + for _, data := range lines { + if len(data) < 6 { + // ignore blank line or wrong format + continue + } + data = data[6:] + var geminiResp geminiChatResponse + if err := json.Unmarshal([]byte(data), &geminiResp); err != nil { + log.Errorf("unable to unmarshal gemini response: %v", err) + continue + } + response := g.buildChatCompletionStreamResponse(ctx, &geminiResp) + responseBody, err := json.Marshal(response) + if err != nil { + log.Errorf("unable to marshal response: %v", err) + return nil, err + } + g.appendResponse(responseBuilder, string(responseBody)) + } + modifiedResponseChunk := responseBuilder.String() + log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) + return []byte(modifiedResponseChunk), nil +} + +func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName == ApiNameChatCompletion { + return g.onChatCompletionResponseBody(ctx, body, log) + } else if apiName == ApiNameEmbeddings { + return g.onEmbeddingsResponseBody(ctx, body, log) + } + return types.ActionContinue, errUnsupportedApiName +} + +func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { + geminiResponse := &geminiChatResponse{} + if err := json.Unmarshal(body, geminiResponse); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err) + } + if geminiResponse.Error != nil { + return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s", + geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message) + } + response := g.buildChatCompletionResponse(ctx, geminiResponse) + return types.ActionContinue, replaceJsonResponseBody(response, log) +} + +func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { + geminiResponse := &geminiEmbeddingResponse{} + if err := json.Unmarshal(body, geminiResponse); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err) + } + if geminiResponse.Error != nil { + return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s", + geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message) + } + response := g.buildEmbeddingsResponse(ctx, geminiResponse) + return types.ActionContinue, 
replaceJsonResponseBody(response, log)
}

func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
	action := ""
	if apiName == ApiNameEmbeddings {
		action = "batchEmbedContents"
	} else if stream {
		action = "streamGenerateContent?alt=sse"
	} else {
		action = "generateContent"
	}
	return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action)
}
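
// For example, a request for model "gemini-pro" is routed to:
//   chat (non-streaming): /v1/models/gemini-pro:generateContent
//   chat (streaming):     /v1/models/gemini-pro:streamGenerateContent?alt=sse
//   embeddings:           /v1/models/gemini-pro:batchEmbedContents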

type geminiChatRequest struct {
	// Model and Stream are only used when using the gemini original protocol
	Model            string                     `json:"model,omitempty"`
	Stream           bool                       `json:"stream,omitempty"`
	Contents         []geminiChatContent        `json:"contents"`
	SafetySettings   []geminiChatSafetySetting  `json:"safety_settings,omitempty"`
	GenerationConfig geminiChatGenerationConfig `json:"generation_config,omitempty"`
	Tools            []geminiChatTools          `json:"tools,omitempty"`
}

type geminiChatContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []geminiPart `json:"parts"`
}

type geminiChatSafetySetting struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

type geminiChatGenerationConfig struct {
	Temperature     float64  `json:"temperature,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            float64  `json:"topK,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
	CandidateCount  int      `json:"candidateCount,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
}

type geminiChatTools struct {
	FunctionDeclarations any `json:"function_declarations,omitempty"`
}

type geminiPart struct {
	Text         string              `json:"text,omitempty"`
	InlineData   *geminiInlineData   `json:"inlineData,omitempty"`
	FunctionCall *geminiFunctionCall `json:"functionCall,omitempty"`
}

type geminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

type geminiFunctionCall struct {
	FunctionName string `json:"name"`
	Arguments    any    `json:"args"`
}

func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) *geminiChatRequest {
	var safetySettings []geminiChatSafetySetting
	for category, threshold := range g.config.geminiSafetySetting {
		safetySettings = append(safetySettings, geminiChatSafetySetting{
			Category:  category,
			Threshold: threshold,
		})
	}
	geminiRequest := geminiChatRequest{
		Contents:       make([]geminiChatContent, 0, len(request.Messages)),
		SafetySettings: safetySettings,
		GenerationConfig: geminiChatGenerationConfig{
			Temperature:     request.Temperature,
			TopP:            request.TopP,
			MaxOutputTokens: request.MaxTokens,
		},
	}
	if request.Tools != nil {
		functions := make([]function, 0, len(request.Tools))
		for _, tool := range request.Tools {
			functions = append(functions, tool.Function)
		}
		geminiRequest.Tools = []geminiChatTools{
			{
				FunctionDeclarations: functions,
			},
		}
	}
	shouldAddDummyModelMessage := false
	for _, message := range request.Messages {
		content := geminiChatContent{
			Role: message.Role,
			Parts: []geminiPart{
				{
					Text: message.Content,
				},
			},
		}

		// Gemini has no assistant role, and the API rejects any role other than user or model
		if content.Role == roleAssistant {
			content.Role = "model"
		} else if content.Role == roleSystem { // system prompts are converted to user prompts for the same reason
			content.Role = roleUser
			shouldAddDummyModelMessage = true
		}
		geminiRequest.Contents = append(geminiRequest.Contents, content)

		// after a system message converted to a user message, add a dummy model message so that user and model turns keep alternating
		if shouldAddDummyModelMessage {
			geminiRequest.Contents = append(geminiRequest.Contents, geminiChatContent{
				Role: "model",
				Parts: []geminiPart{
					{
						Text: "Okay",
					},
				},
			})
			shouldAddDummyModelMessage = false
		}
	}

	return &geminiRequest
}
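
// For example, an OpenAI conversation [system, user, assistant] is sent to Gemini
// as [user (system text), model ("Okay"), user, model].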

func (g *geminiProvider) setSystemContent(request *geminiChatRequest, content string) {
	systemContents := []geminiChatContent{{
		Role: roleUser,
		Parts: []geminiPart{
			{
				Text: content,
			},
		},
	}}
	request.Contents = append(systemContents, request.Contents...)
}

type geminiBatchEmbeddingRequest struct {
	// Model is only used when using the gemini original protocol
	Model    string                   `json:"model,omitempty"`
	Requests []geminiEmbeddingRequest `json:"requests"`
}

type geminiEmbeddingRequest struct {
	Model                string            `json:"model"`
	Content              geminiChatContent `json:"content"`
	TaskType             string            `json:"taskType,omitempty"`
	Title                string            `json:"title,omitempty"`
	OutputDimensionality int               `json:"outputDimensionality,omitempty"`
}

func (g *geminiProvider) buildBatchEmbeddingRequest(request *embeddingsRequest) *geminiBatchEmbeddingRequest {
	inputs := request.ParseInput()
	requests := make([]geminiEmbeddingRequest, len(inputs))
	model := fmt.Sprintf("models/%s", request.Model)

	for i, input := range inputs {
		requests[i] = geminiEmbeddingRequest{
			Model: model,
			Content: geminiChatContent{
				Parts: []geminiPart{
					{
						Text: input,
					},
				},
			},
		}
	}

	return &geminiBatchEmbeddingRequest{
		Requests: requests,
	}
}

type geminiChatResponse struct {
	Candidates     []geminiChatCandidate    `json:"candidates"`
	PromptFeedback geminiChatPromptFeedback `json:"promptFeedback"`
	UsageMetadata  geminiUsageMetadata      `json:"usageMetadata"`
	Error          *geminiResponseError     `json:"error,omitempty"`
}

type geminiChatCandidate struct {
	Content       geminiChatContent        `json:"content"`
	FinishReason  string                   `json:"finishReason"`
	Index         int64                    `json:"index"`
	SafetyRatings []geminiChatSafetyRating `json:"safetyRatings"`
}

type geminiChatPromptFeedback struct {
	SafetyRatings []geminiChatSafetyRating `json:"safetyRatings"`
}

type geminiUsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount,omitempty"`
	CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
	TotalTokenCount      int `json:"totalTokenCount,omitempty"`
}

type geminiResponseError struct {
	Code    int    `json:"code,omitempty"`
	Message string `json:"message,omitempty"`
	Status  string `json:"status,omitempty"`
}

type geminiChatSafetyRating struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}

func (g *geminiProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, response *geminiChatResponse) *chatCompletionResponse {
	fullTextResponse := chatCompletionResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", uuid.New().String()),
		Object:  objectChatCompletion,
		Created: time.Now().UnixMilli() / 1000,
		Model:   ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		Choices: make([]chatCompletionChoice, 0, len(response.Candidates)),
		Usage: usage{
			PromptTokens:     response.UsageMetadata.PromptTokenCount,
			CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      response.UsageMetadata.TotalTokenCount,
		},
	}
	for i, candidate := range response.Candidates {
		choice := chatCompletionChoice{
			Index: i,
			Message: &chatMessage{
				Role: roleAssistant,
			},
			FinishReason: finishReasonStop,
		}
		if len(candidate.Content.Parts) > 0 {
			if candidate.Content.Parts[0].FunctionCall != nil {
				choice.Message.ToolCalls = g.buildToolCalls(&candidate)
			} else {
				choice.Message.Content = candidate.Content.Parts[0].Text
			}
		} else {
			choice.Message.Content = ""
			choice.FinishReason = candidate.FinishReason
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}

func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCall {
	var toolCalls []toolCall

	item := candidate.Content.Parts[0]
	if item.FunctionCall == nil {
		return toolCalls
	}
	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
	if err != nil {
		proxywasm.LogErrorf("get toolCalls from gemini response failed: %v", err)
		return toolCalls
	}
	toolCall := toolCall{
		Id:   fmt.Sprintf("call_%s", uuid.New().String()),
		Type: "function",
		Function: functionCall{
			Arguments: string(argsBytes),
			Name:      item.FunctionCall.FunctionName,
		},
	}
	toolCalls = append(toolCalls, toolCall)
	return toolCalls
}

func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, geminiResp *geminiChatResponse) *chatCompletionResponse {
	var choice chatCompletionChoice
	if len(geminiResp.Candidates) > 0 && len(geminiResp.Candidates[0].Content.Parts) > 0 {
		choice.Delta = &chatMessage{Content: geminiResp.Candidates[0].Content.Parts[0].Text}
	}
	streamResponse := chatCompletionResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", uuid.New().String()),
		Object:  objectChatCompletionChunk,
		Created: time.Now().UnixMilli() / 1000,
		Model:   ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		Choices: []chatCompletionChoice{choice},
		Usage: usage{
			PromptTokens:     geminiResp.UsageMetadata.PromptTokenCount,
			CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      geminiResp.UsageMetadata.TotalTokenCount,
		},
	}
	return &streamResponse
}

type geminiEmbeddingResponse struct {
	Embeddings []geminiEmbeddingData `json:"embeddings"`
	Error      *geminiResponseError  `json:"error,omitempty"`
}

type geminiEmbeddingData struct {
	Values []float64 `json:"values"`
}

func (g *geminiProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, geminiResp *geminiEmbeddingResponse) *embeddingsResponse {
	response := embeddingsResponse{
		Object: "list",
		Data:   make([]embedding, 0, len(geminiResp.Embeddings)),
		Model:  ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
		Usage: usage{
			TotalTokens: 0,
		},
	}
	for i, item := range geminiResp.Embeddings {
		response.Data = append(response.Data, embedding{
			Object:    "embedding",
			Index:     i,
			Embedding: item.Values,
		})
	}
	return &response
}

func (g *geminiProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
	responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go
index efe9fdf948..1366716599 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go
@@ -161,3 +161,22 @@ type embedding struct {
 	Index     int       `json:"index"`
 	Embedding []float64 `json:"embedding"`
 }
+
+func (r embeddingsRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch v := r.Input.(type) {
+	case string:
+		input = []string{v}
+	case []any:
+		input = make([]string, 0, len(v))
+		for _, item := range v {
+			if str, ok := item.(string); ok {
+				
input = append(input, str) + } + } + } + return input +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index ec3d2caf55..419b18d6f4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -34,6 +34,7 @@ const ( providerTypeMinimax = "minimax" providerTypeCloudflare = "cloudflare" providerTypeSpark = "spark" + providerTypeGemini = "gemini" protocolOpenAI = "openai" protocolOriginal = "original" @@ -86,6 +87,7 @@ var ( providerTypeMinimax: &minimaxProviderInitializer{}, providerTypeCloudflare: &cloudflareProviderInitializer{}, providerTypeSpark: &sparkProviderInitializer{}, + providerTypeGemini: &geminiProviderInitializer{}, } ) @@ -168,6 +170,9 @@ type ProviderConfig struct { // @Title zh-CN Cloudflare Account ID // @Description zh-CN 仅适用于 Cloudflare Workers AI 服务。参考:https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"` + // @Title zh-CN Gemini AI内容过滤和安全级别设定 + // @Description zh-CN 仅适用于 Gemini AI 服务。参考:https://ai.google.dev/gemini-api/docs/safety-settings + geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"` } func (c *ProviderConfig) FromJson(json gjson.Result) { @@ -208,6 +213,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String() c.minimaxGroupId = json.Get("minimaxGroupId").String() c.cloudflareAccountId = json.Get("cloudflareAccountId").String() + if c.typ == providerTypeGemini { + c.geminiSafetySetting = make(map[string]string) + for k, v := range json.Get("geminiSafetySetting").Map() { + c.geminiSafetySetting[k] = v.String() + } + } } func (c *ProviderConfig) Validate() error {