Skip to content

Commit

Permalink
feat: add multi-file type support for Gemini and Claude
Browse files Browse the repository at this point in the history
- Add file data DTO for structured file handling
- Implement file decoder service
- Update Claude and Gemini relay channels to handle various file types
- Reorganize worker service to cf_worker for clarity
- Update token counter and image service for new file types
  • Loading branch information
Calcium-Ion committed Dec 28, 2024
1 parent d75ecfc commit 2b38e8e
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 20 deletions.
1 change: 1 addition & 0 deletions README.en.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`

## Deployment
> [!TIP]
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`
## 部署
> [!TIP]
> 最新版Docker镜像:`calciumion/new-api:latest`
Expand Down
2 changes: 2 additions & 0 deletions constant/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)

var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)

// ForceStreamOption 覆盖请求参数,强制返回usage信息
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)

Expand Down
8 changes: 8 additions & 0 deletions dto/file_data.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dto

type LocalFileData struct {
MimeType string
Base64Data string
Url string
Size int64
}
9 changes: 6 additions & 3 deletions relay/channel/claude/relay-claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url,获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
claudeMediaMessage.Source.MediaType = fileData.MimeType
claudeMediaMessage.Source.Data = fileData.Base64Data
} else {
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
if err != nil {
Expand Down
9 changes: 6 additions & 3 deletions relay/channel/gemini/relay-gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url,获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
MimeType: fileData.MimeType,
Data: fileData.Base64Data,
},
})
} else {
Expand Down
2 changes: 1 addition & 1 deletion relay/relay-text.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
var err error
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
case relayconstant.RelayModeModerations:
Expand Down
9 changes: 6 additions & 3 deletions service/worker.go → service/cf_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
"strings"
)

func DoImageRequest(originUrl string) (resp *http.Response, err error) {
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
if !strings.HasPrefix(originUrl, "https") {
return nil, fmt.Errorf("only support https url")
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
Expand All @@ -20,7 +23,7 @@ func DoImageRequest(originUrl string) (resp *http.Response, err error) {
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
} else {
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)
}
}
39 changes: 39 additions & 0 deletions service/file_decoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package service

import (
"encoding/base64"
"fmt"
"io"
"one-api/constant"
"one-api/dto"
)

var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024

func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
resp, err := DoDownloadRequest(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()

// Always use LimitReader to prevent oversized downloads
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
if err != nil {
return nil, err
}

// Check actual size after reading
if len(fileBytes) > maxFileSize {
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
}

// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)

return &dto.LocalFileData{
Base64Data: base64Data,
MimeType: resp.Header.Get("Content-Type"),
Size: int64(len(fileBytes)),
}, nil
}
14 changes: 10 additions & 4 deletions service/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e

// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
resp, err := DoImageRequest(url)
resp, err := DoDownloadRequest(url)
if err != nil {
return
return "", "", err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return
return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
Expand All @@ -52,7 +52,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
}

func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := DoImageRequest(imageUrl)
response, err := DoDownloadRequest(imageUrl)
if err != nil {
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err
Expand All @@ -64,6 +64,12 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
return image.Config{}, "", err
}

mimeType := response.Header.Get("Content-Type")

if !strings.HasPrefix(mimeType, "image/") {
return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
}

var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
Expand Down
15 changes: 9 additions & 6 deletions service/token_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}

func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
baseTokens := 85
if model == "glm-4v" {
return 1047, nil
Expand All @@ -96,6 +96,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
if !constant.GetMediaToken {
return 256, nil
}
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
return 256, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high"
Expand Down Expand Up @@ -155,9 +158,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
return tiles*tileTokens + baseTokens, nil
}

func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
tkm := 0
msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
if err != nil {
return 0, err
}
Expand All @@ -179,7 +182,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err := CountTokenInput(countStr, model)
toolTokens, err := CountTokenInput(countStr, request.Model)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -256,7 +259,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
return textToken, audioToken, nil
}

func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
Expand Down Expand Up @@ -290,7 +293,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
Expand Down

0 comments on commit 2b38e8e

Please sign in to comment.