diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 3ba5e8c7..8bce488f 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -22,6 +22,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/gptscript" "github.com/gptscript-ai/gptscript/pkg/input" "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/loader/github" "github.com/gptscript-ai/gptscript/pkg/monitor" "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/openai" @@ -53,24 +54,25 @@ type GPTScript struct { Output string `usage:"Save output to a file, or - for stdout" short:"o"` EventsStreamTo string `usage:"Stream events to this location, could be a file descriptor/handle (e.g. fd://2), filename, or named pipe (e.g. \\\\.\\pipe\\my-pipe)" name:"events-stream-to"` // Input should not be using GPTSCRIPT_INPUT env var because that is the same value that is set in tool executions - Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"` - SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"` - Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"` - ListModels bool `usage:"List the models available and exit" local:"true"` - ListTools bool `usage:"List built-in tools and exit" local:"true"` - ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"` - Chdir string `usage:"Change current working directory" short:"C"` - Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"` - Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"` - CredentialContext string `usage:"Context name in which to store credentials" default:"default"` - CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"` - ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"` - ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"` - ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"` - Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"` - UI bool `usage:"Launch the UI" local:"true" name:"ui"` - DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"` - SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"` + Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"` + SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"` + Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"` + ListModels bool `usage:"List the models available and exit" local:"true"` + ListTools bool `usage:"List built-in tools and exit" local:"true"` + ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"` + Chdir string `usage:"Change current working directory" short:"C"` + Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"` + Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"` + CredentialContext string `usage:"Context name in which to store credentials" default:"default"` + CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"` + ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"` + ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"` + ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"` + Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"` + UI bool `usage:"Launch the UI" local:"true" name:"ui"` + DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"` + SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"` + EnableGithubEnterprise string `usage:"The host name for a Github Enterprise instance to enable for remote loading" local:"true"` readData []byte } @@ -328,6 +330,10 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { return err } + if r.EnableGithubEnterprise != "" { + loader.AddVSC(github.LoaderForPrefix(r.EnableGithubEnterprise)) + } + // If the user is trying to launch the chat-builder UI, then set up the tool and options here. if r.UI { args = append([]string{uiTool()}, args...) diff --git a/pkg/loader/github/github.go b/pkg/loader/github/github.go index 2fb01c3d..8f71ca19 100644 --- a/pkg/loader/github/github.go +++ b/pkg/loader/github/github.go @@ -2,6 +2,7 @@ package github import ( "context" + "crypto/tls" "encoding/json" "fmt" "io" @@ -18,52 +19,63 @@ import ( "github.com/gptscript-ai/gptscript/pkg/types" ) -const ( - GithubPrefix = "github.com/" - githubRepoURL = "https://github.com/%s/%s.git" - githubDownloadURL = "https://raw.githubusercontent.com/%s/%s/%s/%s" - githubCommitURL = "https://api.github.com/repos/%s/%s/commits/%s" -) +type GithubConfig struct { + Prefix string + RepoURL string + DownloadURL string + CommitURL string + AuthToken string +} var ( - githubAuthToken = os.Getenv("GITHUB_AUTH_TOKEN") - log = mvl.Package() + log = mvl.Package() + DEFAULT_GITHUB_CONFIG = &GithubConfig{ + Prefix: "github.com/", + RepoURL: "https://github.com/%s/%s.git", + DownloadURL: "https://raw.githubusercontent.com/%s/%s/%s/%s", + CommitURL: "https://api.github.com/repos/%s/%s/commits/%s", + AuthToken: os.Getenv("GITHUB_AUTH_TOKEN"), + } ) func init() { loader.AddVSC(Load) } -func getCommitLsRemote(ctx context.Context, account, repo, ref string) (string, error) { - url := fmt.Sprintf(githubRepoURL, account, repo) +func getCommitLsRemote(ctx context.Context, account, repo, ref string, config *GithubConfig) (string, error) { + url := fmt.Sprintf(config.RepoURL, account, repo) return git.LsRemote(ctx, url, ref) } // regexp to match a git commit id var commitRegexp = regexp.MustCompile("^[a-f0-9]{40}$") -func getCommit(ctx context.Context, account, repo, ref string) (string, error) { +func getCommit(ctx context.Context, account, repo, ref string, config *GithubConfig) (string, error) { if commitRegexp.MatchString(ref) { return ref, nil } - url := fmt.Sprintf(githubCommitURL, account, repo, ref) + url := fmt.Sprintf(config.CommitURL, account, repo, ref) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return "", fmt.Errorf("failed to create request of %s/%s at %s: %w", account, repo, url, err) } - if githubAuthToken != "" { - req.Header.Add("Authorization", "Bearer "+githubAuthToken) + if config.AuthToken != "" { + req.Header.Add("Authorization", "Bearer "+config.AuthToken) } - resp, err := http.DefaultClient.Do(req) + client := http.DefaultClient + if req.Host == config.Prefix && strings.ToLower(os.Getenv("GH_ENTERPRISE_SKIP_VERIFY")) == "true" { + client = &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}} + } + resp, err := client.Do(req) if err != nil { return "", err } else if resp.StatusCode != http.StatusOK { c, _ := io.ReadAll(resp.Body) resp.Body.Close() - commit, fallBackErr := getCommitLsRemote(ctx, account, repo, ref) + commit, fallBackErr := getCommitLsRemote(ctx, account, repo, ref, config) if fallBackErr == nil { return commit, nil } @@ -88,8 +100,28 @@ func getCommit(ctx context.Context, account, repo, ref string) (string, error) { return commit.SHA, nil } -func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string, *types.Repo, bool, error) { - if !strings.HasPrefix(urlName, GithubPrefix) { +func LoaderForPrefix(prefix string) func(context.Context, *cache.Client, string) (string, string, *types.Repo, bool, error) { + return func(ctx context.Context, c *cache.Client, urlName string) (string, string, *types.Repo, bool, error) { + return LoadWithConfig(ctx, c, urlName, NewGithubEnterpriseConfig(prefix)) + } +} + +func Load(ctx context.Context, c *cache.Client, urlName string) (string, string, *types.Repo, bool, error) { + return LoadWithConfig(ctx, c, urlName, DEFAULT_GITHUB_CONFIG) +} + +func NewGithubEnterpriseConfig(prefix string) *GithubConfig { + return &GithubConfig{ + Prefix: prefix, + RepoURL: fmt.Sprintf("https://%s/%%s/%%s.git", prefix), + DownloadURL: fmt.Sprintf("https://raw.%s/%%s/%%s/%%s/%%s", prefix), + CommitURL: fmt.Sprintf("https://%s/api/v3/repos/%%s/%%s/commits/%%s", prefix), + AuthToken: os.Getenv("GH_ENTERPRISE_TOKEN"), + } +} + +func LoadWithConfig(ctx context.Context, _ *cache.Client, urlName string, config *GithubConfig) (string, string, *types.Repo, bool, error) { + if !strings.HasPrefix(urlName, config.Prefix) { return "", "", nil, false, nil } @@ -107,12 +139,12 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string, account, repo := parts[1], parts[2] path := strings.Join(parts[3:], "/") - ref, err := getCommit(ctx, account, repo, ref) + ref, err := getCommit(ctx, account, repo, ref, config) if err != nil { return "", "", nil, false, err } - downloadURL := fmt.Sprintf(githubDownloadURL, account, repo, ref, path) + downloadURL := fmt.Sprintf(config.DownloadURL, account, repo, ref, path) if path == "" || path == "/" || !strings.Contains(parts[len(parts)-1], ".") { var ( testPath string @@ -124,13 +156,20 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string, } else { testPath = path + "/" + ext } - testURL = fmt.Sprintf(githubDownloadURL, account, repo, ref, testPath) + testURL = fmt.Sprintf(config.DownloadURL, account, repo, ref, testPath) if i == len(types.DefaultFiles)-1 { // no reason to test the last one, we are just going to use it. Being that the default list is only // two elements this loop could have been one check, but hey over-engineered code ftw. break } - if resp, err := http.Head(testURL); err == nil { + headReq, err := http.NewRequest("HEAD", testURL, nil) + if err != nil { + break + } + if config.AuthToken != "" { + headReq.Header.Add("Authorization", "Bearer "+config.AuthToken) + } + if resp, err := http.DefaultClient.Do(headReq); err == nil { _ = resp.Body.Close() if resp.StatusCode == 200 { break @@ -141,9 +180,9 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string, path = testPath } - return downloadURL, githubAuthToken, &types.Repo{ + return downloadURL, config.AuthToken, &types.Repo{ VCS: "git", - Root: fmt.Sprintf(githubRepoURL, account, repo), + Root: fmt.Sprintf(config.RepoURL, account, repo), Path: gpath.Dir(path), Name: gpath.Base(path), Revision: ref, diff --git a/pkg/loader/github/github_test.go b/pkg/loader/github/github_test.go index d627ee5e..e735ed8c 100644 --- a/pkg/loader/github/github_test.go +++ b/pkg/loader/github/github_test.go @@ -2,6 +2,10 @@ package github import ( "context" + "fmt" + "net/http" + "net/http/httptest" + "os" "testing" "github.com/gptscript-ai/gptscript/pkg/types" @@ -44,3 +48,57 @@ func TestLoad(t *testing.T) { Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91", }).Equal(t, repo) } + +func TestLoad_GithubEnterprise(t *testing.T) { + gheToken := "mytoken" + os.Setenv("GH_ENTERPRISE_SKIP_VERIFY", "true") + os.Setenv("GH_ENTERPRISE_TOKEN", gheToken) + s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("Request for %s\n", r.URL.Path) + switch r.URL.Path { + case "/api/v3/repos/gptscript-ai/gptscript/commits/172dfb0": + w.Write([]byte(`{"sha": "172dfb00b48c6adbbaa7e99270933f95887d1b91"}`)) + default: + w.WriteHeader(404) + } + })) + defer s.Close() + + serverAddr := s.Listener.Addr().String() + + url, token, repo, ok, err := LoadWithConfig(context.Background(), nil, fmt.Sprintf("%s/gptscript-ai/gptscript/pkg/loader/testdata/tool@172dfb0", serverAddr), NewGithubEnterpriseConfig(serverAddr)) + require.NoError(t, err) + assert.True(t, ok) + autogold.Expect(fmt.Sprintf("https://raw.%s/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/tool/tool.gpt", serverAddr)).Equal(t, url) + autogold.Expect(&types.Repo{ + VCS: "git", Root: fmt.Sprintf("https://%s/gptscript-ai/gptscript.git", serverAddr), + Path: "pkg/loader/testdata/tool", + Name: "tool.gpt", + Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91", + }).Equal(t, repo) + autogold.Expect(gheToken).Equal(t, token) + + url, token, repo, ok, err = Load(context.Background(), nil, "github.com/gptscript-ai/gptscript/pkg/loader/testdata/agent@172dfb0") + require.NoError(t, err) + assert.True(t, ok) + autogold.Expect("https://raw.githubusercontent.com/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/agent/agent.gpt").Equal(t, url) + autogold.Expect(&types.Repo{ + VCS: "git", Root: "https://github.com/gptscript-ai/gptscript.git", + Path: "pkg/loader/testdata/agent", + Name: "agent.gpt", + Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91", + }).Equal(t, repo) + autogold.Expect("").Equal(t, token) + + url, token, repo, ok, err = Load(context.Background(), nil, "github.com/gptscript-ai/gptscript/pkg/loader/testdata/bothtoolagent@172dfb0") + require.NoError(t, err) + assert.True(t, ok) + autogold.Expect("https://raw.githubusercontent.com/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/bothtoolagent/agent.gpt").Equal(t, url) + autogold.Expect(&types.Repo{ + VCS: "git", Root: "https://github.com/gptscript-ai/gptscript.git", + Path: "pkg/loader/testdata/bothtoolagent", + Name: "agent.gpt", + Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91", + }).Equal(t, repo) + autogold.Expect("").Equal(t, token) +} diff --git a/pkg/types/tool.go b/pkg/types/tool.go index b0af5183..e9d3e484 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -721,6 +721,9 @@ type Repo struct { Name string // The revision of this source Revision string + + // Additional headers to pass when making requests for this repo + Headers map[string]string } type ToolSource struct {