Skip to content

Commit

Permalink
fix: 豆包支持embeddings
Browse files Browse the repository at this point in the history
Fixes #1594
  • Loading branch information
igophper committed Jul 17, 2024
1 parent 6209ff9 commit 0d81c37
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 26 deletions.
6 changes: 5 additions & 1 deletion relay/adaptor/doubao/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ import (
)

func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
switch meta.Mode {
case relaymode.ChatCompletions:

Check warning on line 11 in relay/adaptor/doubao/main.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/doubao/main.go#L10-L11

Added lines #L10 - L11 were not covered by tests
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
default:

Check warning on line 15 in relay/adaptor/doubao/main.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/doubao/main.go#L13-L15

Added lines #L13 - L15 were not covered by tests
}
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
}
51 changes: 26 additions & 25 deletions relay/controller/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
Expand All @@ -31,9 +32,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream

// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)

Check warning on line 36 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L36

Added line #L36 was not covered by tests
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
Expand All @@ -56,29 +56,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {

// get request body
var requestBody io.Reader
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
requestBody, err = getRequestBody(c, meta, textRequest, adaptor)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)

Check warning on line 61 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L59-L61

Added lines #L59 - L61 were not covered by tests
}

// do request
Expand All @@ -103,3 +83,24 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil
}

func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {

Check warning on line 88 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L87-L88

Added lines #L87 - L88 were not covered by tests
// no need to convert request for openai
return c.Request.Body, nil

Check warning on line 90 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L90

Added line #L90 was not covered by tests
}

// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
return nil, err

Check warning on line 97 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L94-L97

Added lines #L94 - L97 were not covered by tests
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return nil, err

Check warning on line 101 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L99-L101

Added lines #L99 - L101 were not covered by tests
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil

Check warning on line 105 in relay/controller/text.go

View check run for this annotation

Codecov / codecov/patch

relay/controller/text.go#L103-L105

Added lines #L103 - L105 were not covered by tests
}

0 comments on commit 0d81c37

Please sign in to comment.