diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index d12f838e..ff5d1374 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -150,6 +150,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { tool.Parameters.Credentials = append(tool.Parameters.Credentials, value) case "sharecredentials", "sharecreds", "sharecredential", "sharecred", "sharedcredentials", "sharedcreds", "sharedcredential", "sharedcred": tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value) + case "type": + tool.Type = types.ToolType(strings.ToLower(value)) default: return false, nil } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 9e8695a7..3a33c720 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -332,7 +332,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) } func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ *State, _ error) { - toolRefs, err := callCtx.Program.GetContextToolRefs(callCtx.Tool.ID) + toolRefs, err := callCtx.Tool.GetContextTools(*callCtx.Program) if err != nil { return nil, nil, err } diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index a9a01510..12eff23a 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -995,3 +995,8 @@ func TestMissingTool(t *testing.T) { r.AssertResponded(t) autogold.Expect("TEST RESULT CALL: 2").Equal(t, resp) } + +func TestToolRefAll(t *testing.T) { + r := tester.NewRunner(t) + r.RunDefault() +} diff --git a/pkg/tests/testdata/TestToolRefAll/call1-resp.golden b/pkg/tests/testdata/TestToolRefAll/call1-resp.golden new file mode 100644 index 00000000..2861a036 --- /dev/null +++ b/pkg/tests/testdata/TestToolRefAll/call1-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 1" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestToolRefAll/call1.golden b/pkg/tests/testdata/TestToolRefAll/call1.golden new file mode 100644 index 00000000..4957014d --- /dev/null +++ b/pkg/tests/testdata/TestToolRefAll/call1.golden @@ -0,0 +1,61 @@ +`{ + "model": "gpt-4o", + "tools": [ + { + "function": { + "toolID": "testdata/TestToolRefAll/test.gpt:tool", + "name": "tool", + "parameters": { + "properties": { + "toolArg": { + "description": "stuff", + "type": "string" + } + }, + "type": "object" + } + } + }, + { + "function": { + "toolID": "testdata/TestToolRefAll/test.gpt:none", + "name": "none", + "parameters": { + "properties": { + "noneArg": { + "description": "stuff", + "type": "string" + } + }, + "type": "object" + } + } + }, + { + "function": { + "toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant", + "name": "agent", + "parameters": { + "properties": { + "defaultPromptParameter": { + "description": "Prompt to send to the tool. This may be an instruction or question.", + "type": "string" + } + }, + "type": "object" + } + } + } + ], + "messages": [ + { + "role": "system", + "content": [ + { + "text": "\nContext Body\nMain tool" + } + ], + "usage": {} + } + ] +}` diff --git a/pkg/tests/testdata/TestToolRefAll/test.gpt b/pkg/tests/testdata/TestToolRefAll/test.gpt new file mode 100644 index 00000000..93c4ea05 --- /dev/null +++ b/pkg/tests/testdata/TestToolRefAll/test.gpt @@ -0,0 +1,30 @@ +tools: tool, agentAssistant, context, none + +Main tool + +--- +name: agentAssistant +type: agent + +Agent body + +--- +name: context +type: context + +#!sys.echo + +Context Body + +--- +name: none +param: noneArg: stuff + +Default type + +--- +name: tool +type: Tool +param: toolArg: stuff + +Typed tool \ No newline at end of file diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 82effad4..54d5d817 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -26,6 +26,20 @@ var ( DefaultFiles = []string{"agent.gpt", "tool.gpt"} ) +type ToolType string + +const ( + ToolTypeContext = ToolType("context") + ToolTypeAgent = ToolType("agent") + ToolTypeOutput = ToolType("output") + ToolTypeInput = ToolType("input") + ToolTypeAssistant = ToolType("assistant") + ToolTypeTool = ToolType("tool") + ToolTypeCredential = ToolType("credential") + ToolTypeProvider = ToolType("provider") + ToolTypeDefault = ToolType("") +) + type ErrToolNotFound struct { ToolName string } @@ -77,28 +91,6 @@ type ToolReference struct { ToolID string `json:"toolID,omitempty"` } -func (p Program) GetContextToolRefs(toolID string) ([]ToolReference, error) { - return p.ToolSet[toolID].GetContextTools(p) -} - -func (p Program) GetCompletionTools() (result []CompletionTool, err error) { - return Tool{ - ToolDef: ToolDef{ - Parameters: Parameters{ - Tools: []string{"main"}, - }, - }, - ToolMapping: map[string][]ToolReference{ - "main": { - { - Reference: "main", - ToolID: p.EntryToolID, - }, - }, - }, - }.GetCompletionTools(p) -} - func (p Program) TopLevelTools() (result []Tool) { for _, tool := range p.ToolSet[p.EntryToolID].LocalTools { if target, ok := p.ToolSet[tool]; ok { @@ -145,6 +137,7 @@ type Parameters struct { OutputFilters []string `json:"outputFilters,omitempty"` ExportOutputFilters []string `json:"exportOutputFilters,omitempty"` Blocking bool `json:"-"` + Type ToolType `json:"type,omitempty"` } func (p Parameters) ToolRefNames() []string { @@ -347,6 +340,13 @@ func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) { return nil, err } + genericToolRefs, err := t.getCompletionToolRefs(prg, nil, ToolTypeAgent) + if err != nil { + return nil, err + } + + toolRefs = append(toolRefs, genericToolRefs...) + // Agent Tool refs must be named for i, toolRef := range toolRefs { if toolRef.Named != "" { @@ -358,7 +358,9 @@ func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) { name = toolRef.Reference } normed := ToolNormalizer(name) - normed = strings.TrimSuffix(strings.TrimSuffix(normed, "Agent"), "Assistant") + if trimmed := strings.TrimSuffix(strings.TrimSuffix(normed, "Agent"), "Assistant"); trimmed != "" { + normed = trimmed + } toolRefs[i].Named = normed } @@ -404,6 +406,9 @@ func (t ToolDef) String() string { if t.Parameters.Description != "" { _, _ = fmt.Fprintf(buf, "Description: %s\n", t.Parameters.Description) } + if t.Parameters.Type != ToolTypeDefault { + _, _ = fmt.Fprintf(buf, "Type: %s\n", strings.ToUpper(string(t.Type[0]))+string(t.Type[1:])) + } if len(t.Parameters.Agents) != 0 { _, _ = fmt.Fprintf(buf, "Agents: %s\n", strings.Join(t.Parameters.Agents, ", ")) } @@ -486,7 +491,7 @@ func (t ToolDef) String() string { return buf.String() } -func (t Tool) GetExportedContext(prg Program) ([]ToolReference, error) { +func (t Tool) getExportedContext(prg Program) ([]ToolReference, error) { result := &toolRefSet{} exportRefs, err := t.GetToolRefsFromNames(t.ExportContext) @@ -498,13 +503,13 @@ func (t Tool) GetExportedContext(prg Program) ([]ToolReference, error) { result.Add(exportRef) tool := prg.ToolSet[exportRef.ToolID] - result.AddAll(tool.GetExportedContext(prg)) + result.AddAll(tool.getExportedContext(prg)) } return result.List() } -func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) { +func (t Tool) getExportedTools(prg Program) ([]ToolReference, error) { result := &toolRefSet{} exportRefs, err := t.GetToolRefsFromNames(t.Export) @@ -514,7 +519,7 @@ func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) { for _, exportRef := range exportRefs { result.Add(exportRef) - result.AddAll(prg.ToolSet[exportRef.ToolID].GetExportedTools(prg)) + result.AddAll(prg.ToolSet[exportRef.ToolID].getExportedTools(prg)) } return result.List() @@ -524,6 +529,15 @@ func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) { // contexts that are exported by the context tools. This will recurse all exports. func (t Tool) GetContextTools(prg Program) ([]ToolReference, error) { result := &toolRefSet{} + result.AddAll(t.getDirectContextToolRefs(prg)) + result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeContext)) + return result.List() +} + +// GetContextTools returns all tools that are in the context of the tool including all the +// contexts that are exported by the context tools. This will recurse all exports. +func (t Tool) getDirectContextToolRefs(prg Program) ([]ToolReference, error) { + result := &toolRefSet{} contextRefs, err := t.GetToolRefsFromNames(t.Context) if err != nil { @@ -531,7 +545,7 @@ func (t Tool) GetContextTools(prg Program) ([]ToolReference, error) { } for _, contextRef := range contextRefs { - result.AddAll(prg.ToolSet[contextRef.ToolID].GetExportedContext(prg)) + result.AddAll(prg.ToolSet[contextRef.ToolID].getExportedContext(prg)) result.Add(contextRef) } @@ -550,7 +564,9 @@ func (t Tool) GetOutputFilterTools(program Program) ([]ToolReference, error) { result.Add(outputFilterRef) } - contextRefs, err := t.GetContextTools(program) + result.AddAll(t.getCompletionToolRefs(program, nil, ToolTypeOutput)) + + contextRefs, err := t.getDirectContextToolRefs(program) if err != nil { return nil, err } @@ -575,7 +591,9 @@ func (t Tool) GetInputFilterTools(program Program) ([]ToolReference, error) { result.Add(inputFilterRef) } - contextRefs, err := t.GetContextTools(program) + result.AddAll(t.getCompletionToolRefs(program, nil, ToolTypeInput)) + + contextRefs, err := t.getDirectContextToolRefs(program) if err != nil { return nil, err } @@ -602,11 +620,28 @@ func (t Tool) GetNextAgentGroup(prg Program, agentGroup []ToolReference, toolID return agentGroup, nil } +func filterRefs(prg Program, refs []ToolReference, types ...ToolType) (result []ToolReference) { + for _, ref := range refs { + if slices.Contains(types, prg.ToolSet[ref.ToolID].Type) { + result = append(result, ref) + } + } + return +} + func (t Tool) GetCompletionTools(prg Program, agentGroup ...ToolReference) (result []CompletionTool, err error) { - refs, err := t.getCompletionToolRefs(prg, agentGroup) + toolSet := &toolRefSet{} + toolSet.AddAll(t.getCompletionToolRefs(prg, agentGroup, ToolTypeDefault, ToolTypeTool)) + + if err := t.addAgents(prg, toolSet); err != nil { + return nil, err + } + + refs, err := toolSet.List() if err != nil { return nil, err } + return toolRefsToCompletionTools(refs, prg), nil } @@ -638,26 +673,30 @@ func (t Tool) addReferencedTools(prg Program, result *toolRefSet) error { result.Add(subToolRef) // Get all tools exports - result.AddAll(prg.ToolSet[subToolRef.ToolID].GetExportedTools(prg)) + result.AddAll(prg.ToolSet[subToolRef.ToolID].getExportedTools(prg)) } return nil } func (t Tool) addContextExportedTools(prg Program, result *toolRefSet) error { - contextTools, err := t.GetContextTools(prg) + contextTools, err := t.getDirectContextToolRefs(prg) if err != nil { return err } for _, contextTool := range contextTools { - result.AddAll(prg.ToolSet[contextTool.ToolID].GetExportedTools(prg)) + result.AddAll(prg.ToolSet[contextTool.ToolID].getExportedTools(prg)) } return nil } -func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]ToolReference, error) { +func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference, types ...ToolType) ([]ToolReference, error) { + if len(types) == 0 { + types = []ToolType{ToolTypeDefault, ToolTypeTool} + } + result := toolRefSet{} if t.Chat { @@ -677,11 +716,8 @@ func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([] return nil, err } - if err := t.addAgents(prg, &result); err != nil { - return nil, err - } - - return result.List() + refs, err := result.List() + return filterRefs(prg, refs, types...), err } func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]ToolReference, error) { @@ -689,6 +725,8 @@ func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]Too result.AddAll(t.GetToolRefsFromNames(t.Credentials)) + result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeCredential)) + toolRefs, err := t.getCompletionToolRefs(prg, agentGroup) if err != nil { return nil, err diff --git a/pkg/types/tool_test.go b/pkg/types/tool_test.go index 43af6cee..a47014a1 100644 --- a/pkg/types/tool_test.go +++ b/pkg/types/tool_test.go @@ -33,6 +33,8 @@ func TestToolDef_String(t *testing.T) { ExportInputFilters: []string{"SharedFilter1", "SharedFilter2"}, OutputFilters: []string{"Filter1", "Filter2"}, ExportOutputFilters: []string{"SharedFilter1", "SharedFilter2"}, + ExportCredentials: []string{"ExportCredential1", "ExportCredential2"}, + Type: ToolTypeContext, }, Instructions: "This is a sample instruction", } @@ -41,6 +43,7 @@ func TestToolDef_String(t *testing.T) { Global Tools: GlobalTool1, GlobalTool2 Name: Tool Sample Description: This is a sample tool +Type: Context Agents: Agent1, Agent2 Tools: Tool1, Tool2 Share Tools: Export1, Export2 @@ -60,6 +63,8 @@ Parameter: arg2: desc2 Internal Prompt: true Credential: Credential1 Credential: Credential2 +Share Credential: ExportCredential1 +Share Credential: ExportCredential2 Chat: true This is a sample instruction