Skip to content

Commit

Permalink
feat: add support for credential auth for non-tool-references
Browse files Browse the repository at this point in the history
Support authentication for agent and workflow tools that use
non-tool-references in their definition. That is, a GitHub hosted tool
or a local tool.

During the implementation and testing for non-tool-references, I also
spotted a gap: nested tools. That is, if an agent uses a tool that uses
a tool that has a credential, then the authentication for that tool
would not be processed. After this change, it will.

Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Jan 7, 2025
1 parent 51275aa commit 0e66d6b
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 179 deletions.
98 changes: 69 additions & 29 deletions pkg/api/handlers/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/alias"
"github.com/obot-platform/obot/pkg/api"
"github.com/obot-platform/obot/pkg/controller/creds"
"github.com/obot-platform/obot/pkg/invoke"
"github.com/obot-platform/obot/pkg/render"
v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1"
Expand Down Expand Up @@ -60,7 +61,7 @@ func (a *AgentHandler) Authenticate(req api.Context) (err error) {
return err
}

resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, agent.DeepCopy(), tools)
resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, a.gptscript, agent.DeepCopy(), tools)
if err != nil {
return err
}
Expand Down Expand Up @@ -94,24 +95,7 @@ func (a *AgentHandler) DeAuthenticate(req api.Context) error {
return err
}

var (
errs []error
toolRef v1.ToolReference
)
for _, tool := range tools {
if err := req.Get(&toolRef, tool); err != nil {
errs = append(errs, err)
continue
}

if toolRef.Status.Tool != nil {
for _, cred := range toolRef.Status.Tool.CredentialNames {
if err := a.gptscript.DeleteCredential(req.Context(), id, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
errs = append(errs, err)
}
}
}
}
errs := removeToolCredentials(req.Context(), req.Storage, a.gptscript, id, agent.Namespace, tools)

if err := kickAgent(req.Context(), req.Storage, &agent); err != nil {
errs = append(errs, fmt.Errorf("failed to update agent status: %w", err))
Expand Down Expand Up @@ -929,23 +913,35 @@ func MetadataFrom(obj kclient.Object, linkKV ...string) types.Metadata {
return m
}

func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.Invoker, agent *v1.Agent, tools []string) (*invoke.Response, error) {
func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.Invoker, gClient *gptscript.GPTScript, agent *v1.Agent, tools []string) (*invoke.Response, error) {
credentials := make([]string, 0, len(tools))

var toolRef v1.ToolReference
for _, tool := range tools {
if err := c.Get(ctx, kclient.ObjectKey{Namespace: agent.Namespace, Name: tool}, &toolRef); err != nil {
return nil, err
}
if strings.ContainsAny(tool, "./") {
prg, err := gClient.LoadFile(ctx, tool)
if err != nil {
return nil, err
}

if toolRef.Status.Tool == nil {
return nil, types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("tool %q is not ready", tool))
}
credentails, _, err := creds.DetermineCredsAndCredNames(prg, prg.ToolSet[prg.EntryToolID], tool)
if err != nil {
return nil, err
}

credentials = append(credentials, toolRef.Status.Tool.Credentials...)
credentials = append(credentials, credentails...)
} else if err := c.Get(ctx, kclient.ObjectKey{Namespace: agent.Namespace, Name: tool}, &toolRef); err == nil {
if toolRef.Status.Tool == nil {
return nil, types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("tool %q is not ready", tool))
}

// Reset the fields we care about so that we can use the same variable for the whole loop.
toolRef.Status.Tool = nil
credentials = append(credentials, toolRef.Status.Tool.Credentials...)

// Reset the fields we care about so that we can use the same variable for the whole loop.
toolRef.Status.Tool = nil
} else {
return nil, err
}
}

agent.Spec.Manifest.Prompt = "#!sys.echo\nDONE"
Expand All @@ -962,6 +958,50 @@ func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.I
})
}

func removeToolCredentials(ctx context.Context, client kclient.Client, gClient *gptscript.GPTScript, credCtx, namespace string, tools []string) []error {
var (
errs []error
toolRef v1.ToolReference
credentialNames []string
)
for _, tool := range tools {
if strings.ContainsAny(tool, "./") {
prg, err := gClient.LoadFile(ctx, tool)
if err != nil {
errs = append(errs, err)
continue
}

_, names, err := creds.DetermineCredsAndCredNames(prg, prg.ToolSet[prg.EntryToolID], tool)
if err != nil {
errs = append(errs, err)
continue
}

credentialNames = append(credentialNames, names...)
} else if err := client.Get(ctx, kclient.ObjectKey{Namespace: namespace, Name: tool}, &toolRef); err == nil {
if toolRef.Status.Tool != nil {
credentialNames = append(credentialNames, toolRef.Status.Tool.CredentialNames...)
}
} else {
errs = append(errs, err)
continue
}

// Reset the value we care about so the same variable can be used.
// This ensures that the value we read on the next iteration is pulled from the database.
toolRef.Status.Tool = nil

for _, cred := range credentialNames {
if err := gClient.DeleteCredential(ctx, credCtx, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
errs = append(errs, err)
}
}
}

return errs
}

func kickAgent(ctx context.Context, c kclient.Client, agent *v1.Agent) error {
if agent.Annotations[v1.AgentSyncAnnotation] != "" {
delete(agent.Annotations, v1.AgentSyncAnnotation)
Expand Down
25 changes: 2 additions & 23 deletions pkg/api/handlers/workflows.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (a *WorkflowHandler) Authenticate(req api.Context) error {
return err
}

resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, agent, tools)
resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, a.gptscript, agent, tools)
if err != nil {
return err
}
Expand Down Expand Up @@ -92,28 +92,7 @@ func (a *WorkflowHandler) DeAuthenticate(req api.Context) error {
return err
}

var (
errs []error
toolRef v1.ToolReference
)
for _, tool := range tools {
if err := req.Get(&toolRef, tool); err != nil {
errs = append(errs, err)
continue
}

if toolRef.Status.Tool != nil {
for _, cred := range toolRef.Status.Tool.CredentialNames {
if err := a.gptscript.DeleteCredential(req.Context(), id, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
errs = append(errs, err)
}
}

// Reset the value we care about so the same variable can be used.
// This ensures that the value we read on the next iteration is pulled from the database.
toolRef.Status.Tool = nil
}
}
errs := removeToolCredentials(req.Context(), req.Storage, a.gptscript, id, wf.Namespace, tools)

if err := kickWorkflow(req.Context(), req.Storage, &wf); err != nil {
errs = append(errs, fmt.Errorf("failed to update workflow status: %w", err))
Expand Down
173 changes: 173 additions & 0 deletions pkg/controller/creds/creds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package creds

import (
"fmt"
"net/url"
"path"
"slices"
"strings"

"github.com/gptscript-ai/go-gptscript"
gtypes "github.com/gptscript-ai/gptscript/pkg/types"
"github.com/obot-platform/obot/pkg/system"
)

func DetermineCredsAndCredNames(prg *gptscript.Program, tool gptscript.Tool, name string) ([]string, []string, error) {
// The available tool references from this tool are the tool itself and any tool this tool exports.
toolRefs := make([]toolRef, 0, len(tool.Export)+len(tool.Tools)+1)
toolRefs = append(toolRefs, toolRef{
ToolReference: gptscript.ToolReference{
Reference: name,
ToolID: prg.EntryToolID,
},
name: name,
})
for _, t := range tool.Tools {
for _, ref := range tool.ToolMapping[t] {
toolRefs = append(toolRefs, toolRef{
ToolReference: ref,
name: t,
})
}
}

credentials := make([]string, 0, len(tool.Credentials)+len(tool.Export)+len(tool.Tools))
credentialNames := make([]string, 0, len(tool.Credentials)+len(tool.Export)+len(tool.Tools))
seen := make(map[string]struct{})
for len(toolRefs) > 0 {
ref := toolRefs[0]
toolRefs = toolRefs[1:]

if _, ok := seen[ref.ToolID]; ok {
continue
}
seen[ref.ToolID] = struct{}{}

t := prg.ToolSet[ref.ToolID]

// Add the tools that this tool exports if we haven't already seen them.
for _, e := range t.Export {
refs := t.ToolMapping[e]
for _, r := range refs {
if _, ok := seen[r.ToolID]; !ok {
toolRefs = append(toolRefs, toolRef{
ToolReference: r,
name: ref.name,
})
}
}
}

for _, cred := range t.Credentials {
parsedCred := cred
credToolName, credSubTool := gtypes.SplitToolRef(cred)
if strings.HasPrefix(credToolName, ".") {
toolName, _ := gtypes.SplitToolRef(ref.Reference)
if !path.IsAbs(toolName) {
if !strings.HasPrefix(toolName, ".") {
toolName, _ = gtypes.SplitToolRef(ref.name)
} else {
toolName = path.Join(ref.name, toolName)
}
}

refURL, err := url.Parse(toolName)
if err != nil {
continue
}

if strings.HasSuffix(refURL.Path, ".gpt") {
refURL.Path = path.Dir(refURL.Path)
}

refURL.Path = path.Join(refURL.Path, credToolName)
parsedCred = refURL.String()
if refURL.Host == "" {
// This is only a path, so url unescape it.
// No need to check the error here, we would have errored when parsing.
parsedCred, _ = url.PathUnescape(parsedCred)
}

if credSubTool != "" {
parsedCred = fmt.Sprintf("%s from %s", credSubTool, parsedCred)
}
}

if parsedCred != "" && !slices.Contains(credentials, parsedCred) {
credentials = append(credentials, parsedCred)
}

credNames, err := determineCredentialNames(prg, prg.ToolSet[ref.ToolID], cred)
if err != nil {
return credentials, credentialNames, err
}

for _, n := range credNames {
if !slices.Contains(credentialNames, n) {
credentialNames = append(credentialNames, n)
}
}
}
}

return credentials, credentialNames, nil
}

func determineCredentialNames(prg *gptscript.Program, tool gptscript.Tool, toolName string) ([]string, error) {
if toolName == system.ModelProviderCredential {
return []string{system.ModelProviderCredential}, nil
}

var subTool string
parsedToolName, alias, args, err := gtypes.ParseCredentialArgs(toolName, "")
if err != nil {
parsedToolName, subTool = gtypes.SplitToolRef(toolName)
parsedToolName, alias, args, err = gtypes.ParseCredentialArgs(parsedToolName, "")
if err != nil {
return nil, err
}
}

if alias != "" {
return []string{alias}, nil
}

if args == nil {
// This is a tool and not the credential format. Parse the tool from the program to determine the alias
toolNames := make([]string, 0, len(tool.Credentials))
if subTool == "" {
toolName = parsedToolName
}
for _, cred := range tool.Credentials {
if cred == toolName {
if len(tool.ToolMapping[cred]) == 0 {
return nil, fmt.Errorf("cannot find credential name for tool %q", toolName)
}

for _, ref := range tool.ToolMapping[cred] {
for _, c := range prg.ToolSet[ref.ToolID].ExportCredentials {
names, err := determineCredentialNames(prg, prg.ToolSet[ref.ToolID], c)
if err != nil {
return nil, err
}

toolNames = append(toolNames, names...)
}
}
}
}

if len(toolNames) > 0 {
return toolNames, nil
}

return nil, fmt.Errorf("tool %q not found in program", toolName)
}

return []string{toolName}, nil
}

type toolRef struct {
gptscript.ToolReference
name string
}
Loading

0 comments on commit 0e66d6b

Please sign in to comment.