From 2b38e8ed8d16cf414452ca8bd35120cccdcadbaf Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sun, 29 Dec 2024 00:00:24 +0800 Subject: [PATCH] feat: add multi-file type support for Gemini and Claude - Add file data DTO for structured file handling - Implement file decoder service - Update Claude and Gemini relay channels to handle various file types - Reorganize worker service to cf_worker for clarity - Update token counter and image service for new file types --- README.en.md | 1 + README.md | 1 + constant/env.go | 2 ++ dto/file_data.go | 8 ++++++ relay/channel/claude/relay-claude.go | 9 ++++--- relay/channel/gemini/relay-gemini.go | 9 ++++--- relay/relay-text.go | 2 +- service/{worker.go => cf_worker.go} | 9 ++++--- service/file_decoder.go | 39 ++++++++++++++++++++++++++++ service/image.go | 14 +++++++--- service/token_counter.go | 15 ++++++----- 11 files changed, 89 insertions(+), 20 deletions(-) create mode 100644 dto/file_data.go rename service/{worker.go => cf_worker.go} (59%) create mode 100644 service/file_decoder.go diff --git a/README.en.md b/README.en.md index c45fff94a..dac89af07 100644 --- a/README.en.md +++ b/README.en.md @@ -82,6 +82,7 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model - `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated - `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE` - `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable +- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20` ## Deployment > [!TIP] diff --git a/README.md b/README.md index ddf3cd9e3..7c6417c9f 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ - `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta) - `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`,`STRICT`,默认为 `NONE`。 - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。 +- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。 ## 部署 > [!TIP] > 最新版Docker镜像:`calciumion/new-api:latest` diff --git a/constant/env.go b/constant/env.go index b9a6801dc..cd2d71b2a 100644 --- a/constant/env.go +++ b/constant/env.go @@ -10,6 +10,8 @@ import ( var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60) var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true) +var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) + // ForceStreamOption 覆盖请求参数,强制返回usage信息 var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) diff --git a/dto/file_data.go b/dto/file_data.go new file mode 100644 index 000000000..d5cf0f684 --- /dev/null +++ b/dto/file_data.go @@ -0,0 +1,8 @@ +package dto + +type LocalFileData struct { + MimeType string + Base64Data string + Url string + Size int64 +} diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 0cddf8a66..317bf6047 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -225,9 +225,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR // 判断是否是url if strings.HasPrefix(imageUrl.Url, "http") { // 是url,获取图片的类型和base64编码的数据 - mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url) - claudeMediaMessage.Source.MediaType = mimeType - claudeMediaMessage.Source.Data = data + fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) + if err != nil { + return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) + } + claudeMediaMessage.Source.MediaType = fileData.MimeType + claudeMediaMessage.Source.Data = fileData.Base64Data } else { _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url) if err != nil { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 01f21b312..ebdd1dd3f 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -192,11 +192,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque // 判断是否是url if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { // 是url,获取图片的类型和base64编码的数据 - mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) + fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) + if err != nil { + return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) + } parts = append(parts, GeminiPart{ InlineData: &GeminiInlineData{ - MimeType: mimeType, - Data: data, + MimeType: fileData.MimeType, + Data: fileData.Base64Data, }, }) } else { diff --git a/relay/relay-text.go b/relay/relay-text.go index 86fed5f27..c3e449be6 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -230,7 +230,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re var err error switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model) + promptTokens, err = service.CountTokenChatRequest(info, *textRequest) case relayconstant.RelayModeCompletions: promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model) case relayconstant.RelayModeModerations: diff --git a/service/worker.go b/service/cf_worker.go similarity index 59% rename from service/worker.go rename to service/cf_worker.go index 254681826..afe65411b 100644 --- a/service/worker.go +++ b/service/cf_worker.go @@ -9,9 +9,12 @@ import ( "strings" ) -func DoImageRequest(originUrl string) (resp *http.Response, err error) { +func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { if setting.EnableWorker() { - common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) + if !strings.HasPrefix(originUrl, "https") { + return nil, fmt.Errorf("only support https url") + } workerUrl := setting.WorkerUrl if !strings.HasSuffix(workerUrl, "/") { workerUrl += "/" @@ -20,7 +23,7 @@ func DoImageRequest(originUrl string) (resp *http.Response, err error) { data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`) return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data)) } else { - common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) return http.Get(originUrl) } } diff --git a/service/file_decoder.go b/service/file_decoder.go new file mode 100644 index 000000000..ac9f00f34 --- /dev/null +++ b/service/file_decoder.go @@ -0,0 +1,39 @@ +package service + +import ( + "encoding/base64" + "fmt" + "io" + "one-api/constant" + "one-api/dto" +) + +var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024 + +func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { + resp, err := DoDownloadRequest(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // Always use LimitReader to prevent oversized downloads + fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1))) + if err != nil { + return nil, err + } + + // Check actual size after reading + if len(fileBytes) > maxFileSize { + return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB) + } + + // Convert to base64 + base64Data := base64.StdEncoding.EncodeToString(fileBytes) + + return &dto.LocalFileData{ + Base64Data: base64Data, + MimeType: resp.Header.Get("Content-Type"), + Size: int64(len(fileBytes)), + }, nil +} diff --git a/service/image.go b/service/image.go index f3eddff43..61f5364f5 100644 --- a/service/image.go +++ b/service/image.go @@ -33,12 +33,12 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e // GetImageFromUrl 获取图片的类型和base64编码的数据 func GetImageFromUrl(url string) (mimeType string, data string, err error) { - resp, err := DoImageRequest(url) + resp, err := DoDownloadRequest(url) if err != nil { - return + return "", "", err } if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { - return + return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type")) } defer resp.Body.Close() buffer := bytes.NewBuffer(nil) @@ -52,7 +52,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) { } func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { - response, err := DoImageRequest(imageUrl) + response, err := DoDownloadRequest(imageUrl) if err != nil { common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) return image.Config{}, "", err @@ -64,6 +64,12 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { return image.Config{}, "", err } + mimeType := response.Header.Get("Content-Type") + + if !strings.HasPrefix(mimeType, "image/") { + return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType) + } + var readData []byte for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) diff --git a/service/token_counter.go b/service/token_counter.go index e82da5cc2..cde88de30 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -80,7 +80,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } -func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { +func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { baseTokens := 85 if model == "glm-4v" { return 1047, nil @@ -96,6 +96,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in if !constant.GetMediaToken { return 256, nil } + if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic { + return 256, nil + } // 同步One API的图片计费逻辑 if imageUrl.Detail == "auto" || imageUrl.Detail == "" { imageUrl.Detail = "high" @@ -155,9 +158,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in return tiles*tileTokens + baseTokens, nil } -func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) { +func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { tkm := 0 - msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream) + msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) if err != nil { return 0, err } @@ -179,7 +182,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, countStr += fmt.Sprintf("%v", tool.Function.Parameters) } } - toolTokens, err := CountTokenInput(countStr, model) + toolTokens, err := CountTokenInput(countStr, request.Model) if err != nil { return 0, err } @@ -256,7 +259,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, return textToken, audioToken, nil } -func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) { +func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: @@ -290,7 +293,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, for _, m := range arrayContent { if m.Type == dto.ContentTypeImageURL { imageUrl := m.ImageUrl.(dto.MessageImageUrl) - imageTokenNum, err := getImageToken(&imageUrl, model, stream) + imageTokenNum, err := getImageToken(info, &imageUrl, model, stream) if err != nil { return 0, err }