Skip to content

Commit

Permalink
fix: restart model providers when its credential changes (#761)
Browse files Browse the repository at this point in the history
Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams authored Dec 4, 2024
1 parent 7145fd9 commit a8a93e2
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 65 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ require (
github.com/gptscript-ai/chat-completion-client v0.0.0-20241127005108-02b41e1cd02e
github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc
github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858
github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94
github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de
github.com/mhale/smtpd v0.8.3
github.com/oauth2-proxy/oauth2-proxy/v7 v7.0.0-00010101000000-000000000000
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a h1:LX7AOcbBoTnUk/
github.com/gptscript-ai/cmd v0.0.0-20240907001148-ffd49061124a/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc h1:D0N65peenVgPor9Ph0LiMLfgI4qCE9VGxgw8KxlvLgE=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241115201052-7efb3409cfcc/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858 h1:FCHDABLrqOgxppU0NqLW4PlyWeF/K8WD969PQso3mVY=
github.com/gptscript-ai/gptscript v0.9.6-0.20241121180135-e5fe428c6858/go.mod h1:1ECuES7S+IjL4oua0nJzytsWw45tCCv750T44+/JczQ=
github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94 h1:Vkaujp51uMhix6Mag2LMHBCR8XEBXbnvxR9klUiG5iE=
github.com/gptscript-ai/gptscript v0.9.6-0.20241204172147-c39a0693ee94/go.mod h1:1ECuES7S+IjL4oua0nJzytsWw45tCCv750T44+/JczQ=
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 h1:vkgNZVWQgbE33VD3z9WKDwuu7B/eJVVMMPM62ixfCR8=
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6/go.mod h1:frrl/B+ZH3VSs3Tqk2qxEIIWTONExX3tuUa4JsVnqx4=
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7 h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM=
Expand Down
11 changes: 8 additions & 3 deletions pkg/api/handlers/modelprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@ import (
"github.com/gptscript-ai/go-gptscript"
"github.com/otto8-ai/otto8/apiclient/types"
"github.com/otto8-ai/otto8/pkg/api"
"github.com/otto8-ai/otto8/pkg/gateway/server/dispatcher"
v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1"
"k8s.io/apimachinery/pkg/fields"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)

type ModelProviderHandler struct {
gptscript *gptscript.GPTScript
gptscript *gptscript.GPTScript
dispatcher *dispatcher.Dispatcher
}

func NewModelProviderHandler(gClient *gptscript.GPTScript) *ModelProviderHandler {
func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler {
return &ModelProviderHandler{
gptscript: gClient,
gptscript: gClient,
dispatcher: dispatcher,
}
}

Expand Down Expand Up @@ -93,6 +96,8 @@ func (mp *ModelProviderHandler) Configure(req api.Context) error {
return fmt.Errorf("failed to create credential: %w", err)
}

mp.dispatcher.StopModelProvider(ref.Namespace, ref.Name)

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func Router(services *services.Services) (http.Handler, error) {
cronJobs := handlers.NewCronJobHandler()
models := handlers.NewModelHandler(services.GPTClient)
availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher)
prompt := handlers.NewPromptHandler(services.GPTClient)
emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName)
defaultModelAliases := handlers.NewDefaultModelAliasHandler()
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *Controller) PostStart(ctx context.Context) error {
return fmt.Errorf("failed to apply data: %w", err)
}
go c.toolRefHandler.PollRegistry(ctx, c.services.Router.Backend())
return nil
return c.toolRefHandler.EnsureOpenAIEnvCredential(ctx, c.services.Router.Backend())
}

func (c *Controller) Start(ctx context.Context) error {
Expand Down
95 changes: 55 additions & 40 deletions pkg/controller/handlers/toolreference/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/otto8-ai/otto8/logger"
v1 "github.com/otto8-ai/otto8/pkg/storage/apis/otto.otto8.ai/v1"
"github.com/otto8-ai/otto8/pkg/system"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/yaml"
Expand Down Expand Up @@ -193,53 +194,13 @@ func (h *Handler) PollRegistry(ctx context.Context, c client.Client) {
break
}

var openAICredentialSet bool
t := time.NewTicker(time.Hour)
defer t.Stop()
for {
if err := h.readFromRegistry(ctx, c); err != nil {
log.Errorf("Failed to read from registry: %v", err)
}

// If the openai-model-provider exists and the OPENAI_API_KEY environment variable is set, then ensure the credential exists.
if !openAICredentialSet && os.Getenv("OPENAI_API_KEY") != "" {
var openAIModelProvider v1.ToolReference

// Not reporting errors here, best-effort only.
if err := c.Get(ctx, client.ObjectKey{Namespace: system.DefaultNamespace, Name: "openai-model-provider"}, &openAIModelProvider); err == nil {
if cred, err := h.gptClient.RevealCredential(ctx, []string{string(openAIModelProvider.UID)}, "openai-model-provider"); err != nil && strings.HasSuffix(err.Error(), "credential not found") {
// The credential doesn't exist, so create it.
err = h.gptClient.CreateCredential(ctx, gptscript.Credential{
Context: string(openAIModelProvider.UID),
ToolName: "openai-model-provider",
Type: gptscript.CredentialTypeModelProvider,
Env: map[string]string{
"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
},
})

openAICredentialSet = err == nil
} else if err == nil && cred.Env["OTTO8_OPENAI_MODEL_PROVIDER_API_KEY"] != os.Getenv("OPENAI_API_KEY") {
// If the credential exists, but has a different value, then update it.
// The only way to update it is to delete the existing credential and recreate it.
if err = h.gptClient.DeleteCredential(ctx, string(openAIModelProvider.UID), "openai-model-provider"); err == nil {
err = h.gptClient.CreateCredential(ctx, gptscript.Credential{
Context: string(openAIModelProvider.UID),
ToolName: "openai-model-provider",
Type: gptscript.CredentialTypeModelProvider,
Env: map[string]string{
"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
},
})

openAICredentialSet = err == nil
}
} else {
openAICredentialSet = true
}
}
}

select {
case <-t.C:
case <-ctx.Done():
Expand Down Expand Up @@ -297,6 +258,60 @@ func (h *Handler) Populate(req router.Request, resp router.Response) error {
return nil
}

func (h *Handler) EnsureOpenAIEnvCredential(ctx context.Context, c client.Client) error {
if os.Getenv("OPENAI_API_KEY") == "" {
return nil
}

for {
select {
case <-time.After(2 * time.Second):
case <-ctx.Done():
return ctx.Err()
}

// If the openai-model-provider exists and the OPENAI_API_KEY environment variable is set, then ensure the credential exists.
var openAIModelProvider v1.ToolReference
if err := c.Get(ctx, client.ObjectKey{Namespace: system.DefaultNamespace, Name: "openai-model-provider"}, &openAIModelProvider); apierrors.IsNotFound(err) {
continue
} else if err != nil {
return err
}

if cred, err := h.gptClient.RevealCredential(ctx, []string{string(openAIModelProvider.UID)}, "openai-model-provider"); err != nil {
if strings.HasSuffix(err.Error(), "credential not found") {
// The credential doesn't exist, so create it.
return h.gptClient.CreateCredential(ctx, gptscript.Credential{
Context: string(openAIModelProvider.UID),
ToolName: "openai-model-provider",
Type: gptscript.CredentialTypeModelProvider,
Env: map[string]string{
"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
},
})
}

return fmt.Errorf("failed to check OpenAI credential: %w", err)
} else if cred.Env["OTTO8_OPENAI_MODEL_PROVIDER_API_KEY"] != os.Getenv("OPENAI_API_KEY") {
// If the credential exists, but has a different value, then update it.
// The only way to update it is to delete the existing credential and recreate it.
if err = h.gptClient.DeleteCredential(ctx, string(openAIModelProvider.UID), "openai-model-provider"); err != nil {
return fmt.Errorf("failed to delete credential: %w", err)
}
return h.gptClient.CreateCredential(ctx, gptscript.Credential{
Context: string(openAIModelProvider.UID),
ToolName: "openai-model-provider",
Type: gptscript.CredentialTypeModelProvider,
Env: map[string]string{
"OTTO8_OPENAI_MODEL_PROVIDER_API_KEY": os.Getenv("OPENAI_API_KEY"),
},
})
}

return nil
}
}

func (h *Handler) RemoveModelProviderCredential(req router.Request, _ router.Response) error {
toolRef := req.Object.(*v1.ToolReference)
if toolRef.Spec.Type != types.ToolReferenceTypeModelProvider || toolRef.Status.Tool == nil || toolRef.Status.Tool.Metadata["envVars"] == "" {
Expand Down
24 changes: 19 additions & 5 deletions pkg/gateway/server/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ func New(invoker *invoke.Invoker, c kclient.Client, gClient *gptscript.GPTScript
}

func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelProviderName string) (*url.URL, error) {
key := namespace + "/" + modelProviderName
// Check the map with the read lock.
d.lock.RLock()
u, ok := d.urls[modelProviderName]
u, ok := d.urls[key]
d.lock.RUnlock()
if ok && (u.Scheme == "https" || engine.IsDaemonRunning(u.String())) {
if ok && (u.Hostname() == "127.0.0.1" || engine.IsDaemonRunning(u.String())) {
return u, nil
}

Expand All @@ -56,8 +57,8 @@ func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelPr

// If we didn't find anything with the read lock, check with the write lock.
// It could be that another thread beat us to the write lock and added the model provider we desire.
u, ok = d.urls[modelProviderName]
if ok && (u.Scheme == "https" || engine.IsDaemonRunning(u.String())) {
u, ok = d.urls[key]
if ok && (u.Hostname() != "127.0.0.1" || engine.IsDaemonRunning(u.String())) {
return u, nil
}

Expand All @@ -67,10 +68,23 @@ func (d *Dispatcher) URLForModelProvider(ctx context.Context, namespace, modelPr
return nil, err
}

d.urls[modelProviderName] = u
d.urls[key] = u
return u, nil
}

func (d *Dispatcher) StopModelProvider(namespace, modelProviderName string) {
key := namespace + "/" + modelProviderName
d.lock.Lock()
defer d.lock.Unlock()

u := d.urls[key]
if u != nil && u.Hostname() == "127.0.0.1" && engine.IsDaemonRunning(u.String()) {
engine.StopDaemon(u.String())
}

delete(d.urls, key)
}

func (d *Dispatcher) TransformRequest(req *http.Request, namespace string) error {
body, err := readBody(req)
if err != nil {
Expand Down
22 changes: 10 additions & 12 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,16 @@ func Run(ctx context.Context, c services.Config) error {
return err
}

go func() {
c, err := controller.New(svcs)
if err != nil {
log.Fatalf("Failed to start controller: %v", err)
}
if err := c.Start(ctx); err != nil {
log.Fatalf("Failed to start controller: %v", err)
}
if err := c.PostStart(ctx); err != nil {
log.Fatalf("Failed to post start controller: %v", err)
}
}()
ctrl, err := controller.New(svcs)
if err != nil {
log.Fatalf("Failed to start controller: %v", err)
}
if err = ctrl.Start(ctx); err != nil {
log.Fatalf("Failed to start controller: %v", err)
}
if err = ctrl.PostStart(ctx); err != nil {
log.Fatalf("Failed to post start controller: %v", err)
}

handler, err := router.Router(svcs)
if err != nil {
Expand Down

0 comments on commit a8a93e2

Please sign in to comment.