diff --git a/pkg/tests/testdata/TestToolRefAll/call1.golden b/pkg/tests/testdata/TestToolRefAll/call1.golden index 4957014d..ef36e3fb 100644 --- a/pkg/tests/testdata/TestToolRefAll/call1.golden +++ b/pkg/tests/testdata/TestToolRefAll/call1.golden @@ -52,7 +52,7 @@ "role": "system", "content": [ { - "text": "\nContext Body\nMain tool" + "text": "\nShared context\n\nContext Body\nMain tool" } ], "usage": {} diff --git a/pkg/tests/testdata/TestToolRefAll/test.gpt b/pkg/tests/testdata/TestToolRefAll/test.gpt index 93c4ea05..423cf766 100644 --- a/pkg/tests/testdata/TestToolRefAll/test.gpt +++ b/pkg/tests/testdata/TestToolRefAll/test.gpt @@ -11,11 +11,19 @@ Agent body --- name: context type: context +share context: sharedcontext #!sys.echo Context Body +--- +name: sharedcontext + +#!sys.echo + +Shared context + --- name: none param: noneArg: stuff diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 789215b6..b59a1953 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -546,7 +546,17 @@ func (t Tool) getExportedTools(prg Program) ([]ToolReference, error) { func (t Tool) GetContextTools(prg Program) ([]ToolReference, error) { result := &toolRefSet{} result.AddAll(t.getDirectContextToolRefs(prg)) - result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeContext)) + + contextRefs, err := t.getCompletionToolRefs(prg, nil, ToolTypeContext) + if err != nil { + return nil, err + } + + for _, contextRef := range contextRefs { + result.AddAll(prg.ToolSet[contextRef.ToolID].getExportedContext(prg)) + result.Add(contextRef) + } + return result.List() }