
Commit eb26065

✨ feat: Support stream_options
MartialBE committed May 26, 2024
1 parent fa54ca7 commit eb26065
Showing 11 changed files with 188 additions and 31 deletions.
14 changes: 14 additions & 0 deletions providers/baichuan/chat.go
@@ -40,12 +40,26 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
 }
 
 func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the channel supports usage in streaming responses, turn it on:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Clear it to avoid sending an unsupported parameter upstream
+        request.StreamOptions = nil
+    }
+
     req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
     if errWithCode != nil {
         return nil, errWithCode
     }
     defer req.Body.Close()
 
+    // Restore the caller's original setting
+    request.StreamOptions = streamOptions
+
     // Send the request
     resp, errWithCode := p.Requester.SendRequestRaw(req)
     if errWithCode != nil {
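
Note: the save/override/restore pattern above is repeated in the groq and openai providers below: stash the caller's stream_options, force include_usage on channels that support it (or strip the field on channels that don't), build the upstream request body, then restore the original value so the relay layer still sees what the client asked for. A self-contained sketch of the pattern, using local stand-ins for the project's types and a hypothetical withStreamOptions helper (the commit itself inlines this logic in each provider):

package main

import "fmt"

// Local stand-ins for types.StreamOptions and types.ChatCompletionRequest.
type StreamOptions struct {
    IncludeUsage bool `json:"include_usage,omitempty"`
}

type ChatCompletionRequest struct {
    Stream        bool
    StreamOptions *StreamOptions
}

// withStreamOptions is a hypothetical helper showing the pattern the commit
// inlines in each provider: override StreamOptions while the upstream request
// body is built, then restore the caller's original value.
func withStreamOptions(req *ChatCompletionRequest, supported bool, send func() error) error {
    saved := req.StreamOptions
    if supported {
        req.StreamOptions = &StreamOptions{IncludeUsage: true}
    } else {
        req.StreamOptions = nil // avoid sending an unsupported parameter upstream
    }
    defer func() { req.StreamOptions = saved }()
    return send()
}

func main() {
    req := &ChatCompletionRequest{Stream: true}
    _ = withStreamOptions(req, true, func() error {
        fmt.Println(req.StreamOptions.IncludeUsage) // true while the body is serialized
        return nil
    })
    fmt.Println(req.StreamOptions) // <nil> again afterwards
}
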
13 changes: 13 additions & 0 deletions providers/groq/chat.go
@@ -40,13 +40,26 @@ func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest
 }
 
 func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the channel supports usage in streaming responses, turn it on:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Clear it to avoid sending an unsupported parameter upstream
+        request.StreamOptions = nil
+    }
     p.getChatRequestBody(request)
     req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
     if errWithCode != nil {
         return nil, errWithCode
     }
     defer req.Body.Close()
 
+    // Restore the caller's original setting
+    request.StreamOptions = streamOptions
+
     // Send the request
     resp, errWithCode := p.Requester.SendRequestRaw(req)
     if errWithCode != nil {
13 changes: 10 additions & 3 deletions providers/openai/base.go
@@ -17,8 +17,9 @@ type OpenAIProviderFactory struct{}

 type OpenAIProvider struct {
     base.BaseProvider
-    IsAzure       bool
-    BalanceAction bool
+    IsAzure              bool
+    BalanceAction        bool
+    SupportStreamOptions bool
 }
 
 // Create an OpenAIProvider
@@ -33,7 +34,7 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter
 func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
     config := getOpenAIConfig(baseURL)
 
-    return &OpenAIProvider{
+    OpenAIProvider := &OpenAIProvider{
         BaseProvider: base.BaseProvider{
             Config:  config,
             Channel: channel,
@@ -42,6 +43,12 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide
         IsAzure:       false,
         BalanceAction: true,
     }
+
+    if channel.Type == common.ChannelTypeOpenAI {
+        OpenAIProvider.SupportStreamOptions = true
+    }
+
+    return OpenAIProvider
 }
 
 func getOpenAIConfig(baseURL string) base.ProviderConfig {
33 changes: 25 additions & 8 deletions providers/openai/chat.go
@@ -8,7 +8,6 @@ import (
"one-api/common/requester"
"one-api/types"
"strings"
"time"
)

type OpenAIStreamHandler struct {
@@ -58,12 +57,25 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
 }
 
 func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the channel supports usage in streaming responses, turn it on:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Clear it to avoid sending an unsupported parameter upstream
+        request.StreamOptions = nil
+    }
     req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
     if errWithCode != nil {
         return nil, errWithCode
     }
     defer req.Body.Close()
 
+    // Restore the caller's original setting
+    request.StreamOptions = streamOptions
+
     // Send the request
     resp, errWithCode := p.Requester.SendRequestRaw(req)
     if errWithCode != nil {
@@ -110,18 +122,23 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s
     }
 
     if len(openaiResponse.Choices) == 0 {
+        if openaiResponse.Usage != nil {
+            *h.Usage = *openaiResponse.Usage
+        }
         *rawLine = nil
         return
     }
 
     dataChan <- string(*rawLine)
 
-    if h.isAzure {
-        // Block for 20ms
-        time.Sleep(20 * time.Millisecond)
+    if len(openaiResponse.Choices) > 0 && openaiResponse.Choices[0].Usage != nil {
+        *h.Usage = *openaiResponse.Choices[0].Usage
+    } else {
+        if h.Usage.TotalTokens == 0 {
+            h.Usage.TotalTokens = h.Usage.PromptTokens
+        }
+        countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
+        h.Usage.CompletionTokens += countTokenText
+        h.Usage.TotalTokens += countTokenText
     }
-
-    countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
-    h.Usage.CompletionTokens += countTokenText
-    h.Usage.TotalTokens += countTokenText
 }
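
Note: with include_usage forced on, the upstream ends the stream with one final chunk whose choices array is empty and which carries the aggregate usage; HandlerChatStream stores that, accepts usage attached to the first choice (which some compatible upstreams emit), and otherwise falls back to counting completion tokens locally. A runnable sketch of that branching, using local stand-ins for the project's types and a crude token counter in place of common.CountTokenText:

package main

import "fmt"

// Local stand-ins for the project's usage and stream-chunk types.
type Usage struct {
    PromptTokens     int
    CompletionTokens int
    TotalTokens      int
}

type Choice struct {
    Text  string
    Usage *Usage
}

type chunk struct {
    Choices []Choice
    Usage   *Usage
}

// applyUsage mirrors the handler's branching: prefer usage reported by the
// upstream (either on the final empty-choices chunk or on the first choice),
// otherwise fall back to counting streamed tokens locally.
func applyUsage(c chunk, u *Usage, countTokens func(string) int) {
    if len(c.Choices) == 0 {
        if c.Usage != nil {
            *u = *c.Usage // final usage-only chunk
        }
        return
    }
    if c.Choices[0].Usage != nil {
        *u = *c.Choices[0].Usage
        return
    }
    if u.TotalTokens == 0 {
        u.TotalTokens = u.PromptTokens
    }
    n := countTokens(c.Choices[0].Text)
    u.CompletionTokens += n
    u.TotalTokens += n
}

func main() {
    count := func(s string) int { return len(s) / 4 } // crude stand-in for CountTokenText
    u := Usage{PromptTokens: 12}
    applyUsage(chunk{Choices: []Choice{{Text: "hello world!"}}}, &u, count)
    applyUsage(chunk{Usage: &Usage{PromptTokens: 12, CompletionTokens: 7, TotalTokens: 19}}, &u, count)
    fmt.Printf("%+v\n", u) // {PromptTokens:12 CompletionTokens:7 TotalTokens:19}
}
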
24 changes: 24 additions & 0 deletions providers/openai/completion.go
@@ -40,12 +40,25 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
 }
 
 func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the channel supports usage in streaming responses, turn it on:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Clear it to avoid sending an unsupported parameter upstream
+        request.StreamOptions = nil
+    }
     req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
     if errWithCode != nil {
        return nil, errWithCode
     }
     defer req.Body.Close()
 
+    // Restore the caller's original setting
+    request.StreamOptions = streamOptions
+
     // Send the request
     resp, errWithCode := p.Requester.SendRequestRaw(req)
     if errWithCode != nil {
@@ -90,8 +103,19 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan
         return
     }
 
+    if len(openaiResponse.Choices) == 0 {
+        if openaiResponse.Usage != nil {
+            *h.Usage = *openaiResponse.Usage
+        }
+        *rawLine = nil
+        return
+    }
+
     dataChan <- string(*rawLine)
 
+    if h.Usage.TotalTokens == 0 {
+        h.Usage.TotalTokens = h.Usage.PromptTokens
+    }
     countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
     h.Usage.CompletionTokens += countTokenText
     h.Usage.TotalTokens += countTokenText
34 changes: 33 additions & 1 deletion relay/chat.go
@@ -1,7 +1,9 @@
 package relay
 
 import (
+    "encoding/json"
     "errors"
+    "fmt"
     "math"
     "net/http"
     "one-api/common"
@@ -36,6 +38,10 @@ func (r *relayChat) setRequest() error {
         r.c.Set("skip_only_chat", true)
     }
 
+    if !r.chatRequest.Stream && r.chatRequest.StreamOptions != nil {
+        return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+    }
+
     r.originalModel = r.chatRequest.Model
 
     return nil
@@ -66,7 +72,11 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
             return
         }
 
-        err = responseStreamClient(r.c, response, r.cache)
+        doneStr := func() string {
+            return r.getUsageResponse()
+        }
+
+        err = responseStreamClient(r.c, response, r.cache, doneStr)
     } else {
         var response *types.ChatCompletionResponse
         response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
@@ -86,3 +96,25 @@

     return
 }
+
+func (r *relayChat) getUsageResponse() string {
+    if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage {
+        usageResponse := types.ChatCompletionStreamResponse{
+            ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+            Object:  "chat.completion.chunk",
+            Created: common.GetTimestamp(),
+            Model:   r.chatRequest.Model,
+            Choices: []types.ChatCompletionStreamChoice{},
+            Usage:   r.provider.GetUsage(),
+        }
+
+        responseBody, err := json.Marshal(usageResponse)
+        if err != nil {
+            return ""
+        }
+
+        return string(responseBody)
+    }
+
+    return ""
+}
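
Note: getUsageResponse builds the extra chunk the gateway appends for clients that asked for include_usage: empty choices plus the provider's accumulated usage, written just before data: [DONE]. A standalone sketch of what that chunk marshals to, with trimmed-down stand-in types and illustrative values:

package main

import (
    "encoding/json"
    "fmt"
)

// Trimmed-down stand-ins for types.Usage and types.ChatCompletionStreamResponse.
type Usage struct {
    PromptTokens     int `json:"prompt_tokens"`
    CompletionTokens int `json:"completion_tokens"`
    TotalTokens      int `json:"total_tokens"`
}

type streamResponse struct {
    ID      string `json:"id"`
    Object  string `json:"object"`
    Created int64  `json:"created"`
    Model   string `json:"model"`
    Choices []any  `json:"choices"`
    Usage   *Usage `json:"usage,omitempty"`
}

func main() {
    body, _ := json.Marshal(streamResponse{
        ID:      "chatcmpl-example", // the real code generates a fresh UUID
        Object:  "chat.completion.chunk",
        Created: 1716681600,
        Model:   "gpt-4o",
        Choices: []any{}, // deliberately empty in the usage chunk
        Usage:   &Usage{PromptTokens: 12, CompletionTokens: 7, TotalTokens: 19},
    })
    fmt.Println("data: " + string(body))
    fmt.Println()
    fmt.Println("data: [DONE]")
}
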
14 changes: 12 additions & 2 deletions relay/common.go
@@ -140,7 +140,9 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
     return nil
 }
 
-func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
+type StreamEndHandler func() string
+
+func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
     requester.SetEventStreamHeaders(c)
     dataChan, errChan := stream.Recv()
 
@@ -160,14 +162,22 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
                 cache.NoCache()
             }
 
+            if errWithOP == nil && endHandler != nil {
+                streamData := endHandler()
+                if streamData != "" {
+                    fmt.Fprint(w, "data: "+streamData+"\n\n")
+                    cache.SetResponse(streamData)
+                }
+            }
+
             streamData := "data: [DONE]\n"
             fmt.Fprint(w, streamData)
             cache.SetResponse(streamData)
             return false
         }
     })
 
-    return errWithOP
+    return nil
 }
 
 func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
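
Note: callers that do not need a trailing chunk can pass nil for endHandler; relay/chat.go and relay/completions.go pass a closure instead, so the usage chunk is only built after the stream has drained and the provider's totals are final. A minimal sketch of that tail-end control flow, with the gin and requester plumbing stubbed out:

package main

import "fmt"

type StreamEndHandler func() string

// drain mimics the tail of responseStreamClient: forward chunks until the
// data channel closes, let an optional endHandler contribute one last chunk,
// then terminate the SSE stream with [DONE].
func drain(data <-chan string, endHandler StreamEndHandler) {
    for d := range data {
        fmt.Println("data: " + d)
    }
    if endHandler != nil {
        if s := endHandler(); s != "" {
            fmt.Println("data: " + s)
        }
    }
    fmt.Println("data: [DONE]")
}

func main() {
    ch := make(chan string, 1)
    ch <- `{"choices":[{"delta":{"content":"hi"}}]}`
    close(ch)
    // The closure runs only after the stream is drained, when usage is final.
    drain(ch, func() string { return `{"choices":[],"usage":{"total_tokens":19}}` })
}
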
34 changes: 33 additions & 1 deletion relay/completions.go
@@ -1,7 +1,9 @@
 package relay
 
 import (
+    "encoding/json"
     "errors"
+    "fmt"
     "math"
     "net/http"
     "one-api/common"
@@ -32,6 +34,10 @@ func (r *relayCompletions) setRequest() error {
         return errors.New("max_tokens is invalid")
     }
 
+    if !r.request.Stream && r.request.StreamOptions != nil {
+        return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+    }
+
     r.originalModel = r.request.Model
 
     return nil
@@ -62,7 +68,11 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
             return
         }
 
-        err = responseStreamClient(r.c, response, r.cache)
+        doneStr := func() string {
+            return r.getUsageResponse()
+        }
+
+        err = responseStreamClient(r.c, response, r.cache, doneStr)
     } else {
         var response *types.CompletionResponse
         response, err = provider.CreateCompletion(&r.request)
@@ -79,3 +89,25 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo

     return
 }
+
+func (r *relayCompletions) getUsageResponse() string {
+    if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage {
+        usageResponse := types.CompletionResponse{
+            ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+            Object:  "chat.completion.chunk",
+            Created: common.GetTimestamp(),
+            Model:   r.request.Model,
+            Choices: []types.CompletionChoice{},
+            Usage:   r.provider.GetUsage(),
+        }
+
+        responseBody, err := json.Marshal(usageResponse)
+        if err != nil {
+            return ""
+        }
+
+        return string(responseBody)
+    }
+
+    return ""
+}
3 changes: 3 additions & 0 deletions types/chat.go
@@ -165,6 +165,7 @@ type ChatCompletionRequest struct {
     TopP            float64                       `json:"top_p,omitempty"`
     N               int                           `json:"n,omitempty"`
     Stream          bool                          `json:"stream,omitempty"`
+    StreamOptions   *StreamOptions                `json:"stream_options,omitempty"`
     Stop            []string                      `json:"stop,omitempty"`
     PresencePenalty float64                       `json:"presence_penalty,omitempty"`
     ResponseFormat  *ChatCompletionResponseFormat `json:"response_format,omitempty"`
Expand Down Expand Up @@ -356,6 +357,7 @@ type ChatCompletionStreamChoice struct {
     Delta                ChatCompletionStreamChoiceDelta `json:"delta"`
     FinishReason         any                             `json:"finish_reason"`
     ContentFilterResults any                             `json:"content_filter_results,omitempty"`
+    Usage                *Usage                          `json:"usage,omitempty"`
 }
 
 func (c *ChatCompletionStreamChoice) CheckChoice(request *ChatCompletionRequest) {
@@ -372,4 +374,5 @@ type ChatCompletionStreamResponse struct {
     Model             string                       `json:"model"`
     Choices           []ChatCompletionStreamChoice `json:"choices"`
     PromptAnnotations any                          `json:"prompt_annotations,omitempty"`
+    Usage             *Usage                       `json:"usage,omitempty"`
 }
4 changes: 4 additions & 0 deletions types/common.go
@@ -34,3 +34,7 @@
 type OpenAIErrorResponse struct {
     Error OpenAIError `json:"error,omitempty"`
 }
+
+type StreamOptions struct {
+    IncludeUsage bool `json:"include_usage,omitempty"`
+}
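
Note: on the wire, a client opts in by sending stream_options alongside stream: true (setRequest now rejects stream_options on non-streaming requests). A small sketch of the resulting request JSON, using a local copy of the new type; the model name and values are illustrative:

package main

import (
    "encoding/json"
    "fmt"
)

// Copy of the new types.StreamOptions.
type StreamOptions struct {
    IncludeUsage bool `json:"include_usage,omitempty"`
}

// Minimal request shape; the real ChatCompletionRequest has many more fields.
type chatRequest struct {
    Model         string         `json:"model"`
    Stream        bool           `json:"stream,omitempty"`
    StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func main() {
    body, _ := json.Marshal(chatRequest{
        Model:         "gpt-4o",
        Stream:        true,
        StreamOptions: &StreamOptions{IncludeUsage: true},
    })
    fmt.Println(string(body))
    // {"model":"gpt-4o","stream":true,"stream_options":{"include_usage":true}}
}
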