Skip to content

Commit

Permalink
Merge pull request #706 from ibuildthecloud/main
Browse files Browse the repository at this point in the history
chore: add type field to tools
  • Loading branch information
ibuildthecloud authored Aug 5, 2024
2 parents a7509b0 + 8931247 commit 039a685
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 41 deletions.
2 changes: 2 additions & 0 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
tool.Parameters.Credentials = append(tool.Parameters.Credentials, value)
case "sharecredentials", "sharecreds", "sharecredential", "sharecred", "sharedcredentials", "sharedcreds", "sharedcredential", "sharedcred":
tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value)
case "type":
tool.Type = types.ToolType(strings.ToLower(value))
default:
return false, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
}

func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ *State, _ error) {
toolRefs, err := callCtx.Program.GetContextToolRefs(callCtx.Tool.ID)
toolRefs, err := callCtx.Tool.GetContextTools(*callCtx.Program)
if err != nil {
return nil, nil, err
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,8 @@ func TestMissingTool(t *testing.T) {
r.AssertResponded(t)
autogold.Expect("TEST RESULT CALL: 2").Equal(t, resp)
}

func TestToolRefAll(t *testing.T) {
r := tester.NewRunner(t)
r.RunDefault()
}
9 changes: 9 additions & 0 deletions pkg/tests/testdata/TestToolRefAll/call1-resp.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
`{
"role": "assistant",
"content": [
{
"text": "TEST RESULT CALL: 1"
}
],
"usage": {}
}`
61 changes: 61 additions & 0 deletions pkg/tests/testdata/TestToolRefAll/call1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
`{
"model": "gpt-4o",
"tools": [
{
"function": {
"toolID": "testdata/TestToolRefAll/test.gpt:tool",
"name": "tool",
"parameters": {
"properties": {
"toolArg": {
"description": "stuff",
"type": "string"
}
},
"type": "object"
}
}
},
{
"function": {
"toolID": "testdata/TestToolRefAll/test.gpt:none",
"name": "none",
"parameters": {
"properties": {
"noneArg": {
"description": "stuff",
"type": "string"
}
},
"type": "object"
}
}
},
{
"function": {
"toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant",
"name": "agent",
"parameters": {
"properties": {
"defaultPromptParameter": {
"description": "Prompt to send to the tool. This may be an instruction or question.",
"type": "string"
}
},
"type": "object"
}
}
}
],
"messages": [
{
"role": "system",
"content": [
{
"text": "\nContext Body\nMain tool"
}
],
"usage": {}
}
]
}`
30 changes: 30 additions & 0 deletions pkg/tests/testdata/TestToolRefAll/test.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
tools: tool, agentAssistant, context, none

Main tool

---
name: agentAssistant
type: agent

Agent body

---
name: context
type: context

#!sys.echo

Context Body

---
name: none
param: noneArg: stuff

Default type

---
name: tool
type: Tool
param: toolArg: stuff

Typed tool
118 changes: 78 additions & 40 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ var (
DefaultFiles = []string{"agent.gpt", "tool.gpt"}
)

type ToolType string

const (
ToolTypeContext = ToolType("context")
ToolTypeAgent = ToolType("agent")
ToolTypeOutput = ToolType("output")
ToolTypeInput = ToolType("input")
ToolTypeAssistant = ToolType("assistant")
ToolTypeTool = ToolType("tool")
ToolTypeCredential = ToolType("credential")
ToolTypeProvider = ToolType("provider")
ToolTypeDefault = ToolType("")
)

type ErrToolNotFound struct {
ToolName string
}
Expand Down Expand Up @@ -77,28 +91,6 @@ type ToolReference struct {
ToolID string `json:"toolID,omitempty"`
}

func (p Program) GetContextToolRefs(toolID string) ([]ToolReference, error) {
return p.ToolSet[toolID].GetContextTools(p)
}

func (p Program) GetCompletionTools() (result []CompletionTool, err error) {
return Tool{
ToolDef: ToolDef{
Parameters: Parameters{
Tools: []string{"main"},
},
},
ToolMapping: map[string][]ToolReference{
"main": {
{
Reference: "main",
ToolID: p.EntryToolID,
},
},
},
}.GetCompletionTools(p)
}

func (p Program) TopLevelTools() (result []Tool) {
for _, tool := range p.ToolSet[p.EntryToolID].LocalTools {
if target, ok := p.ToolSet[tool]; ok {
Expand Down Expand Up @@ -145,6 +137,7 @@ type Parameters struct {
OutputFilters []string `json:"outputFilters,omitempty"`
ExportOutputFilters []string `json:"exportOutputFilters,omitempty"`
Blocking bool `json:"-"`
Type ToolType `json:"type,omitempty"`
}

func (p Parameters) ToolRefNames() []string {
Expand Down Expand Up @@ -347,6 +340,13 @@ func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) {
return nil, err
}

genericToolRefs, err := t.getCompletionToolRefs(prg, nil, ToolTypeAgent)
if err != nil {
return nil, err
}

toolRefs = append(toolRefs, genericToolRefs...)

// Agent Tool refs must be named
for i, toolRef := range toolRefs {
if toolRef.Named != "" {
Expand All @@ -358,7 +358,9 @@ func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) {
name = toolRef.Reference
}
normed := ToolNormalizer(name)
normed = strings.TrimSuffix(strings.TrimSuffix(normed, "Agent"), "Assistant")
if trimmed := strings.TrimSuffix(strings.TrimSuffix(normed, "Agent"), "Assistant"); trimmed != "" {
normed = trimmed
}
toolRefs[i].Named = normed
}

Expand Down Expand Up @@ -404,6 +406,9 @@ func (t ToolDef) String() string {
if t.Parameters.Description != "" {
_, _ = fmt.Fprintf(buf, "Description: %s\n", t.Parameters.Description)
}
if t.Parameters.Type != ToolTypeDefault {
_, _ = fmt.Fprintf(buf, "Type: %s\n", strings.ToUpper(string(t.Type[0]))+string(t.Type[1:]))
}
if len(t.Parameters.Agents) != 0 {
_, _ = fmt.Fprintf(buf, "Agents: %s\n", strings.Join(t.Parameters.Agents, ", "))
}
Expand Down Expand Up @@ -486,7 +491,7 @@ func (t ToolDef) String() string {
return buf.String()
}

func (t Tool) GetExportedContext(prg Program) ([]ToolReference, error) {
func (t Tool) getExportedContext(prg Program) ([]ToolReference, error) {
result := &toolRefSet{}

exportRefs, err := t.GetToolRefsFromNames(t.ExportContext)
Expand All @@ -498,13 +503,13 @@ func (t Tool) GetExportedContext(prg Program) ([]ToolReference, error) {
result.Add(exportRef)

tool := prg.ToolSet[exportRef.ToolID]
result.AddAll(tool.GetExportedContext(prg))
result.AddAll(tool.getExportedContext(prg))
}

return result.List()
}

func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) {
func (t Tool) getExportedTools(prg Program) ([]ToolReference, error) {
result := &toolRefSet{}

exportRefs, err := t.GetToolRefsFromNames(t.Export)
Expand All @@ -514,7 +519,7 @@ func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) {

for _, exportRef := range exportRefs {
result.Add(exportRef)
result.AddAll(prg.ToolSet[exportRef.ToolID].GetExportedTools(prg))
result.AddAll(prg.ToolSet[exportRef.ToolID].getExportedTools(prg))
}

return result.List()
Expand All @@ -524,14 +529,23 @@ func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) {
// contexts that are exported by the context tools. This will recurse all exports.
func (t Tool) GetContextTools(prg Program) ([]ToolReference, error) {
result := &toolRefSet{}
result.AddAll(t.getDirectContextToolRefs(prg))
result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeContext))
return result.List()
}

// GetContextTools returns all tools that are in the context of the tool including all the
// contexts that are exported by the context tools. This will recurse all exports.
func (t Tool) getDirectContextToolRefs(prg Program) ([]ToolReference, error) {
result := &toolRefSet{}

contextRefs, err := t.GetToolRefsFromNames(t.Context)
if err != nil {
return nil, err
}

for _, contextRef := range contextRefs {
result.AddAll(prg.ToolSet[contextRef.ToolID].GetExportedContext(prg))
result.AddAll(prg.ToolSet[contextRef.ToolID].getExportedContext(prg))
result.Add(contextRef)
}

Expand All @@ -550,7 +564,9 @@ func (t Tool) GetOutputFilterTools(program Program) ([]ToolReference, error) {
result.Add(outputFilterRef)
}

contextRefs, err := t.GetContextTools(program)
result.AddAll(t.getCompletionToolRefs(program, nil, ToolTypeOutput))

contextRefs, err := t.getDirectContextToolRefs(program)
if err != nil {
return nil, err
}
Expand All @@ -575,7 +591,9 @@ func (t Tool) GetInputFilterTools(program Program) ([]ToolReference, error) {
result.Add(inputFilterRef)
}

contextRefs, err := t.GetContextTools(program)
result.AddAll(t.getCompletionToolRefs(program, nil, ToolTypeInput))

contextRefs, err := t.getDirectContextToolRefs(program)
if err != nil {
return nil, err
}
Expand All @@ -602,11 +620,28 @@ func (t Tool) GetNextAgentGroup(prg Program, agentGroup []ToolReference, toolID
return agentGroup, nil
}

func filterRefs(prg Program, refs []ToolReference, types ...ToolType) (result []ToolReference) {
for _, ref := range refs {
if slices.Contains(types, prg.ToolSet[ref.ToolID].Type) {
result = append(result, ref)
}
}
return
}

func (t Tool) GetCompletionTools(prg Program, agentGroup ...ToolReference) (result []CompletionTool, err error) {
refs, err := t.getCompletionToolRefs(prg, agentGroup)
toolSet := &toolRefSet{}
toolSet.AddAll(t.getCompletionToolRefs(prg, agentGroup, ToolTypeDefault, ToolTypeTool))

if err := t.addAgents(prg, toolSet); err != nil {
return nil, err
}

refs, err := toolSet.List()
if err != nil {
return nil, err
}

return toolRefsToCompletionTools(refs, prg), nil
}

Expand Down Expand Up @@ -638,26 +673,30 @@ func (t Tool) addReferencedTools(prg Program, result *toolRefSet) error {
result.Add(subToolRef)

// Get all tools exports
result.AddAll(prg.ToolSet[subToolRef.ToolID].GetExportedTools(prg))
result.AddAll(prg.ToolSet[subToolRef.ToolID].getExportedTools(prg))
}

return nil
}

func (t Tool) addContextExportedTools(prg Program, result *toolRefSet) error {
contextTools, err := t.GetContextTools(prg)
contextTools, err := t.getDirectContextToolRefs(prg)
if err != nil {
return err
}

for _, contextTool := range contextTools {
result.AddAll(prg.ToolSet[contextTool.ToolID].GetExportedTools(prg))
result.AddAll(prg.ToolSet[contextTool.ToolID].getExportedTools(prg))
}

return nil
}

func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]ToolReference, error) {
func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference, types ...ToolType) ([]ToolReference, error) {
if len(types) == 0 {
types = []ToolType{ToolTypeDefault, ToolTypeTool}
}

result := toolRefSet{}

if t.Chat {
Expand All @@ -677,18 +716,17 @@ func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]
return nil, err
}

if err := t.addAgents(prg, &result); err != nil {
return nil, err
}

return result.List()
refs, err := result.List()
return filterRefs(prg, refs, types...), err
}

func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]ToolReference, error) {
result := toolRefSet{}

result.AddAll(t.GetToolRefsFromNames(t.Credentials))

result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeCredential))

toolRefs, err := t.getCompletionToolRefs(prg, agentGroup)
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 039a685

Please sign in to comment.