diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 90f36919..42ff373b 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -277,10 +277,15 @@ func ListTools() (result []types.Tool) { } func Builtin(name string) (types.Tool, bool) { + return BuiltinWithDefaultModel(name, "") +} + +func BuiltinWithDefaultModel(name, defaultModel string) (types.Tool, bool) { // Legacy syntax not used anymore name = strings.TrimSuffix(name, "?") t, ok := tools[name] t.Parameters.Name = name + t.Parameters.ModelName = defaultModel t.ID = name t.Instructions = "#!" + name return SetDefaults(t), ok diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index f2679c6f..5a907f5b 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -132,7 +132,7 @@ func loadLocal(base *source, name string) (*source, bool, error) { }, true, nil } -func loadProgram(data []byte, into *types.Program, targetToolName string) (types.Tool, error) { +func loadProgram(data []byte, into *types.Program, targetToolName, defaultModel string) (types.Tool, error) { var ext types.Program if err := json.Unmarshal(data[len(assemble.Header):], &ext); err != nil { @@ -141,7 +141,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types into.ToolSet = make(map[string]types.Tool, len(ext.ToolSet)) for k, v := range ext.ToolSet { - if builtinTool, ok := builtin.Builtin(k); ok { + if builtinTool, ok := builtin.BuiltinWithDefaultModel(k, defaultModel); ok { v = builtinTool } into.ToolSet[k] = v @@ -186,11 +186,11 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T { return openAPIDocument } -func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) ([]types.Tool, error) { +func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) { data := base.Content if bytes.HasPrefix(data, assemble.Header) { - tool, err := loadProgram(data, prg, targetToolName) + tool, err := loadProgram(data, prg, targetToolName, defaultModel) if err != nil { return nil, err } @@ -310,17 +310,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base localTools[strings.ToLower(tool.Parameters.Name)] = tool } - return linkAll(ctx, cache, prg, base, targetTools, localTools) + return linkAll(ctx, cache, prg, base, targetTools, localTools, defaultModel) } -func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet) (result []types.Tool, _ error) { +func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) { localToolsMapping := make(map[string]string, len(tools)) for _, localTool := range localTools { localToolsMapping[strings.ToLower(localTool.Parameters.Name)] = localTool.ID } for _, tool := range tools { - tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping) + tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping, defaultModel) if err != nil { return nil, err } @@ -329,7 +329,7 @@ func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base return } -func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string) (types.Tool, error) { +func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel string) (types.Tool, error) { if existing, ok := prg.ToolSet[tool.ID]; ok { return existing, nil } @@ -354,7 +354,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so linkedTool = existing } else { var err error - linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping) + linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err) } @@ -364,7 +364,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so toolNames[targetToolName] = struct{}{} } else { toolName, subTool := types.SplitToolRef(targetToolName) - resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool) + resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err) } @@ -376,6 +376,10 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so tool.LocalTools = localToolsMapping + if tool.ModelName == "" { + tool.ModelName = defaultModel + } + tool = builtin.SetDefaults(tool) prg.ToolSet[tool.ID] = tool @@ -405,7 +409,7 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. Path: locationPath, Name: locationName, Location: opt.Location, - }, subToolName) + }, subToolName, opt.DefaultModel) if err != nil { return types.Program{}, err } @@ -414,20 +418,26 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. } type Options struct { - Cache *cache.Client - Location string + Cache *cache.Client + Location string + DefaultModel string } func complete(opts ...Options) (result Options) { for _, opt := range opts { result.Cache = types.FirstSet(opt.Cache, result.Cache) result.Location = types.FirstSet(opt.Location, result.Location) + result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) } if result.Location == "" { result.Location = "inline" } + if result.DefaultModel == "" { + result.DefaultModel = builtin.GetDefaultModel() + } + return } @@ -451,7 +461,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty Name: name, ToolSet: types.ToolSet{}, } - tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName) + tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName, opt.DefaultModel) if err != nil { return types.Program{}, err } @@ -459,9 +469,9 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty return prg, nil } -func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) ([]types.Tool, error) { +func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { if subTool == "" { - t, ok := builtin.Builtin(name) + t, ok := builtin.BuiltinWithDefaultModel(name, defaultModel) if ok { prg.ToolSet[t.ID] = t return []types.Tool{t}, nil @@ -473,7 +483,7 @@ func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base return nil, err } - result, err := readTool(ctx, cache, prg, s, subTool) + result, err := readTool(ctx, cache, prg, s, subTool, defaultModel) if err != nil { return nil, err } diff --git a/pkg/loader/openapi_test.go b/pkg/loader/openapi_test.go index 1a7eaa76..423246d1 100644 --- a/pkg/loader/openapi_test.go +++ b/pkg/loader/openapi_test.go @@ -26,7 +26,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err, "failed to read openapi v3") require.Equal(t, 3, numOpenAPITools(prgv3.ToolSet), "expected 3 openapi tools") @@ -35,7 +35,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.json") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2json, &source{Content: datav2}, "") + _, err = readTool(context.Background(), nil, &prgv2json, &source{Content: datav2}, "", "") require.NoError(t, err, "failed to read openapi v2") require.Equal(t, 3, numOpenAPITools(prgv2json.ToolSet), "expected 3 openapi tools") @@ -44,7 +44,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav2, err = os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2yaml, &source{Content: datav2}, "") + _, err = readTool(context.Background(), nil, &prgv2yaml, &source{Content: datav2}, "", "") require.NoError(t, err, "failed to read openapi v2 (yaml)") require.Equal(t, 3, numOpenAPITools(prgv2yaml.ToolSet), "expected 3 openapi tools") @@ -57,7 +57,7 @@ func TestOpenAPIv3(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -69,7 +69,7 @@ func TestOpenAPIv3NoOperationIDs(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -81,7 +81,7 @@ func TestOpenAPIv2(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "") + _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) @@ -94,7 +94,7 @@ func TestOpenAPIv3Revamp(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -107,7 +107,7 @@ func TestOpenAPIv3NoOperationIDsRevamp(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -120,7 +120,7 @@ func TestOpenAPIv2Revamp(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "") + _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) diff --git a/pkg/sdkserver/run.go b/pkg/sdkserver/run.go index b6b5a049..1c0f7c4b 100644 --- a/pkg/sdkserver/run.go +++ b/pkg/sdkserver/run.go @@ -32,7 +32,11 @@ func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, lo } defer g.Close(false) - prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{Cache: g.Cache}) + defaultModel := opts.OpenAI.DefaultModel + if defaultModel == "" { + defaultModel = s.gptscriptOpts.OpenAI.DefaultModel + } + prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{Cache: g.Cache, DefaultModel: defaultModel}) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) return