Skip to content

Commit

Permalink
chore: allow setting of dataset tool in SDK server config
Browse files Browse the repository at this point in the history
Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Nov 6, 2024
1 parent 2a9f664 commit d21c001
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 22 deletions.
2 changes: 2 additions & 0 deletions pkg/cli/sdk_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

type SDKServer struct {
*GPTScript
DatasetTool string `usage:"Tool to use for datasets"`
WorkspaceTool string `usage:"Tool to use for workspace"`
}

Expand Down Expand Up @@ -38,6 +39,7 @@ func (c *SDKServer) Run(cmd *cobra.Command, _ []string) error {
Options: opts,
ListenAddress: c.ListenAddress,
Debug: c.Debug,
DatasetTool: c.DatasetTool,
WorkspaceTool: c.WorkspaceTool,
})
}
27 changes: 14 additions & 13 deletions pkg/sdkserver/datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ import (
"github.com/gptscript-ai/gptscript/pkg/loader"
)

func (s *server) getDatasetTool(req datasetRequest) string {
if req.DatasetToolRepo != "" {
return req.DatasetToolRepo
}

return s.datasetTool
}

type datasetRequest struct {
Input string `json:"input"`
WorkspaceID string `json:"workspaceID"`
Expand Down Expand Up @@ -38,13 +46,6 @@ func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
return opts
}

func (r datasetRequest) getToolRepo() string {
if r.DatasetToolRepo != "" {
return r.DatasetToolRepo
}
return "github.com/otto8-ai/datasets"
}

func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())

Expand All @@ -65,7 +66,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "List Datasets", loader.Options{
Cache: g.Cache,
})

Expand Down Expand Up @@ -126,7 +127,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "Create Dataset", loader.Options{
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Create Dataset", loader.Options{
Cache: g.Cache,
})

Expand Down Expand Up @@ -195,7 +196,7 @@ func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Element", loader.Options{
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Element", loader.Options{
Cache: g.Cache,
})
if err != nil {
Expand Down Expand Up @@ -262,7 +263,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Elements", loader.Options{
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Elements", loader.Options{
Cache: g.Cache,
})
if err != nil {
Expand Down Expand Up @@ -327,7 +328,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Elements", loader.Options{
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "List Elements", loader.Options{
Cache: g.Cache,
})
if err != nil {
Expand Down Expand Up @@ -390,7 +391,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
return
}

prg, err := loader.Program(r.Context(), req.getToolRepo(), "Get Element SDK", loader.Options{
prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element SDK", loader.Options{
Cache: g.Cache,
})
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ import (
)

type server struct {
gptscriptOpts gptscript.Options
address, token string
workspaceTool string
client *gptscript.GPTScript
events *broadcaster.Broadcaster[event]
gptscriptOpts gptscript.Options
address, token string
datasetTool, workspaceTool string
client *gptscript.GPTScript
events *broadcaster.Broadcaster[event]

runtimeManager engine.RuntimeManager

Expand Down
13 changes: 9 additions & 4 deletions pkg/sdkserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ import (
type Options struct {
gptscript.Options

ListenAddress string
WorkspaceTool string
Debug bool
DisableServerErrorLogging bool
ListenAddress string
DatasetTool, WorkspaceTool string
Debug bool
DisableServerErrorLogging bool
}

// Run will start the server and block until the server is shut down.
Expand Down Expand Up @@ -108,6 +108,7 @@ func run(ctx context.Context, listener net.Listener, opts Options) error {
gptscriptOpts: opts.Options,
address: listener.Addr().String(),
token: token,
datasetTool: opts.DatasetTool,
workspaceTool: opts.WorkspaceTool,
client: g,
events: events,
Expand Down Expand Up @@ -159,6 +160,7 @@ func complete(opts ...Options) Options {
for _, opt := range opts {
result.Options = gptscript.Complete(result.Options, opt.Options)
result.ListenAddress = types.FirstSet(opt.ListenAddress, result.ListenAddress)
result.DatasetTool = types.FirstSet(opt.DatasetTool, result.DatasetTool)
result.WorkspaceTool = types.FirstSet(opt.WorkspaceTool, result.WorkspaceTool)
result.Debug = types.FirstSet(opt.Debug, result.Debug)
result.DisableServerErrorLogging = types.FirstSet(opt.DisableServerErrorLogging, result.DisableServerErrorLogging)
Expand All @@ -171,6 +173,9 @@ func complete(opts ...Options) Options {
if result.WorkspaceTool == "" {
result.WorkspaceTool = "github.com/gptscript-ai/workspace-provider"
}
if result.DatasetTool == "" {
result.DatasetTool = "github.com/otto8-ai/datasets"
}

return result
}

0 comments on commit d21c001

Please sign in to comment.