Skip to content

Commit

Permalink
feat: add ability to pass request-specific env vars to chat completion
Browse files Browse the repository at this point in the history
This will allow authentication per-request in model providers.

Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Nov 4, 2024
1 parent 50489f2 commit 953f19a
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 41 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.6.0
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA=
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131 h1:y2FcmT4X8U606gUS0teX5+JWX9K/NclsLEhHiyrd+EU=
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e h1:WpNae0NBx+Ri8RB3SxF8DhadDKU7h+jfWPQterDpbJA=
Expand Down
11 changes: 0 additions & 11 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,3 @@ func GetLogger(ctx context.Context) mvl.Logger {

return l
}

// envKey is the unexported context key under which the environment slice is
// stored. Using a dedicated struct type means no other package can construct
// an equal key.
type envKey struct{}

// WithEnv returns a context derived from ctx that carries env.
// The slice is stored as given; it is not copied.
func WithEnv(ctx context.Context, env []string) context.Context {
	var key envKey
	return context.WithValue(ctx, key, env)
}

// GetEnv returns the environment slice attached to ctx by WithEnv.
// It returns nil when ctx carries no such value.
func GetEnv(ctx context.Context) []string {
	if env, ok := ctx.Value(envKey{}).([]string); ok {
		return env
	}
	return nil
}
4 changes: 2 additions & 2 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"sync"

"github.com/gptscript-ai/gptscript/pkg/config"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
Expand Down Expand Up @@ -389,7 +388,8 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
}
}()

resp, err := e.Model.Call(gcontext.WithEnv(ctx, e.Env), state.Completion, progress)
state.Completion.Env = e.Env
resp, err := e.Model.Call(ctx, state.Completion, progress)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/llm/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {

var (
model string
data = map[string]any{}
data map[string]any
)

if json.Unmarshal(inBytes, &data) == nil {
Expand All @@ -65,7 +65,7 @@ func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
model = builtin.GetDefaultModel()
}

c, err := r.getClient(req.Context(), model)
c, err := r.getClient(req.Context(), model, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
6 changes: 3 additions & 3 deletions pkg/llm/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (r *Registry) fastPath(modelName string) Client {
return r.clients[0]
}

func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) {
func (r *Registry) getClient(ctx context.Context, modelName string, env []string) (Client, error) {
if c := r.fastPath(modelName); c != nil {
return c, nil
}
Expand All @@ -101,7 +101,7 @@ func (r *Registry) getClient(ctx context.Context, modelName string) (Client, err

if len(errs) > 0 && oaiClient != nil {
// Prompt the user to enter their OpenAI API key and try again.
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
if err := oaiClient.RetrieveAPIKey(ctx, env); err != nil {
return nil, err
}
ok, err := oaiClient.Supports(ctx, modelName)
Expand Down Expand Up @@ -146,7 +146,7 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ

if len(errs) > 0 && oaiClient != nil {
// Prompt the user to enter their OpenAI API key and try again.
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
if err := oaiClient.RetrieveAPIKey(ctx, messageRequest.Env); err != nil {
return nil, err
}
ok, err := oaiClient.Supports(ctx, messageRequest.Model)
Expand Down
37 changes: 26 additions & 11 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

openai "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/cache"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/hash"
Expand Down Expand Up @@ -305,7 +304,7 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C

func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if err := c.ValidAuth(); err != nil {
if err := c.RetrieveAPIKey(ctx); err != nil {
if err := c.RetrieveAPIKey(ctx, messageRequest.Env); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -401,15 +400,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
if err != nil {
return nil, err
} else if !ok {
result, err = c.call(ctx, request, id, status)
result, err = c.call(ctx, request, id, messageRequest.Env, status)

// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
var apiError *openai.APIError
if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
// Decrease maxTokens by 10% to make garbage collection more aggressive.
// The retry loop will further decrease maxTokens if needed.
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
result, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
result, err = c.contextLimitRetryLoop(ctx, request, id, messageRequest.Env, maxTokens, status)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -443,7 +442,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return &result, nil
}

func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
var (
response types.CompletionMessage
err error
Expand All @@ -452,7 +451,7 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC
for range 10 { // maximum 10 tries
// Try to drop older messages again, with a decreased max tokens.
request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
response, err = c.call(ctx, request, id, status)
response, err = c.call(ctx, request, id, env, status)
if err == nil {
return response, nil
}
Expand Down Expand Up @@ -542,7 +541,7 @@ func override(left, right string) string {
return left
}

func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"

partial <- types.CompletionStatus{
Expand All @@ -553,11 +552,27 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
},
}

var (
headers map[string]string
modelProviderEnv []string
)
for _, e := range env {
if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
modelProviderEnv = append(modelProviderEnv, e)
}
}

if len(modelProviderEnv) > 0 {
headers = map[string]string{
"X-GPTScript-Env": strings.Join(modelProviderEnv, ","),
}
}

slog.Debug("calling openai", "message", request.Messages)

if !streamResponse {
request.StreamOptions = nil
resp, err := c.c.CreateChatCompletion(ctx, request)
resp, err := c.c.CreateChatCompletion(ctx, request, headers)
if err != nil {
return types.CompletionMessage{}, err
}
Expand All @@ -582,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
}), nil
}

stream, err := c.c.CreateChatCompletionStream(ctx, request)
stream, err := c.c.CreateChatCompletionStream(ctx, request, headers)
if err != nil {
return types.CompletionMessage{}, err
}
Expand Down Expand Up @@ -614,8 +629,8 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
}
}

func (c *Client) RetrieveAPIKey(ctx context.Context) error {
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", gcontext.GetEnv(ctx))
func (c *Client) RetrieveAPIKey(ctx context.Context, env []string) error {
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", env)
if err != nil {
return err
}
Expand Down
15 changes: 7 additions & 8 deletions pkg/remote/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"sync"

"github.com/gptscript-ai/gptscript/pkg/cache"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/engine"
env2 "github.com/gptscript-ai/gptscript/pkg/env"
Expand Down Expand Up @@ -48,7 +47,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
}

client, err := c.load(ctx, provider)
client, err := c.load(ctx, provider, messageRequest.Env...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -111,7 +110,7 @@ func isHTTPURL(toolName string) bool {
strings.HasPrefix(toolName, "https://")
}

func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Client, error) {
func (c *Client) clientFromURL(ctx context.Context, apiURL string, envs []string) (*openai.Client, error) {
parsed, err := url.Parse(apiURL)
if err != nil {
return nil, err
Expand All @@ -121,7 +120,7 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie

if key == "" && !isLocalhost(apiURL) {
var err error
key, err = c.retrieveAPIKey(ctx, env, apiURL)
key, err = c.retrieveAPIKey(ctx, env, apiURL, envs)
if err != nil {
return nil, err
}
Expand All @@ -134,7 +133,7 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
})
}

func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
func (c *Client) load(ctx context.Context, toolName string, env ...string) (*openai.Client, error) {
c.clientsLock.Lock()
defer c.clientsLock.Unlock()

Expand All @@ -144,7 +143,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
}

if isHTTPURL(toolName) {
remoteClient, err := c.clientFromURL(ctx, toolName)
remoteClient, err := c.clientFromURL(ctx, toolName, env)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -183,8 +182,8 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
return oClient, nil
}

func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) {
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(gcontext.GetEnv(ctx), c.envs...))
func (c *Client) retrieveAPIKey(ctx context.Context, env, url string, envs []string) (string, error) {
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(envs, c.envs...))
}

func isLocalhost(url string) bool {
Expand Down
2 changes: 1 addition & 1 deletion pkg/runner/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str
if err != nil {
return nil, fmt.Errorf("marshaling input for output filter: %w", err)
}
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, string(inputData), "", engine.OutputToolCategory)
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, inputData, "", engine.OutputToolCategory)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions pkg/types/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type CompletionRequest struct {
Temperature *float32 `json:"temperature,omitempty"`
JSONResponse bool `json:"jsonResponse,omitempty"`
Cache *bool `json:"cache,omitempty"`
Env []string `json:"env,omitempty"`
}

func (r *CompletionRequest) GetCache() bool {
Expand Down

0 comments on commit 953f19a

Please sign in to comment.