Skip to content

Commit

Permalink
feat: Add support for tools from github enterprise.
Browse files Browse the repository at this point in the history
  • Loading branch information
johnrengelman committed Jul 1, 2024
1 parent a42c9e3 commit 374632f
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 42 deletions.
42 changes: 24 additions & 18 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/gptscript-ai/gptscript/pkg/gptscript"
"github.com/gptscript-ai/gptscript/pkg/input"
"github.com/gptscript-ai/gptscript/pkg/loader"
"github.com/gptscript-ai/gptscript/pkg/loader/github"
"github.com/gptscript-ai/gptscript/pkg/monitor"
"github.com/gptscript-ai/gptscript/pkg/mvl"
"github.com/gptscript-ai/gptscript/pkg/openai"
Expand Down Expand Up @@ -53,24 +54,25 @@ type GPTScript struct {
Output string `usage:"Save output to a file, or - for stdout" short:"o"`
EventsStreamTo string `usage:"Stream events to this location, could be a file descriptor/handle (e.g. fd://2), filename, or named pipe (e.g. \\\\.\\pipe\\my-pipe)" name:"events-stream-to"`
// Input should not be using GPTSCRIPT_INPUT env var because that is the same value that is set in tool executions
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
ListModels bool `usage:"List the models available and exit" local:"true"`
ListTools bool `usage:"List built-in tools and exit" local:"true"`
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
Chdir string `usage:"Change current working directory" short:"C"`
Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"`
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"`
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"`
ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"`
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
UI bool `usage:"Launch the UI" local:"true" name:"ui"`
DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"`
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
ListModels bool `usage:"List the models available and exit" local:"true"`
ListTools bool `usage:"List built-in tools and exit" local:"true"`
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
Chdir string `usage:"Change current working directory" short:"C"`
Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"`
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"`
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"`
ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"`
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
UI bool `usage:"Launch the UI" local:"true" name:"ui"`
DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"`
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
EnableGithubEnterprise string `usage:"The host name for a Github Enterprise instance to enable for remote loading" local:"true"`

readData []byte
}
Expand Down Expand Up @@ -328,6 +330,10 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
return err
}

if r.EnableGithubEnterprise != "" {
loader.AddVSC(github.LoaderForPrefix(r.EnableGithubEnterprise))
}

// If the user is trying to launch the chat-builder UI, then set up the tool and options here.
if r.UI {
args = append([]string{uiTool()}, args...)
Expand Down
87 changes: 63 additions & 24 deletions pkg/loader/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package github

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
Expand All @@ -18,52 +19,63 @@ import (
"github.com/gptscript-ai/gptscript/pkg/types"
)

const (
GithubPrefix = "github.com/"
githubRepoURL = "https://github.com/%s/%s.git"
githubDownloadURL = "https://raw.githubusercontent.com/%s/%s/%s/%s"
githubCommitURL = "https://api.github.com/repos/%s/%s/commits/%s"
)
type GithubConfig struct {
Prefix string
RepoURL string
DownloadURL string
CommitURL string
AuthToken string
}

var (
githubAuthToken = os.Getenv("GITHUB_AUTH_TOKEN")
log = mvl.Package()
log = mvl.Package()
DEFAULT_GITHUB_CONFIG = &GithubConfig{
Prefix: "github.com/",
RepoURL: "https://github.com/%s/%s.git",
DownloadURL: "https://raw.githubusercontent.com/%s/%s/%s/%s",
CommitURL: "https://api.github.com/repos/%s/%s/commits/%s",
AuthToken: os.Getenv("GITHUB_AUTH_TOKEN"),
}
)

func init() {
loader.AddVSC(Load)
}

func getCommitLsRemote(ctx context.Context, account, repo, ref string) (string, error) {
url := fmt.Sprintf(githubRepoURL, account, repo)
func getCommitLsRemote(ctx context.Context, account, repo, ref string, config *GithubConfig) (string, error) {
url := fmt.Sprintf(config.RepoURL, account, repo)
return git.LsRemote(ctx, url, ref)
}

// regexp to match a git commit id
var commitRegexp = regexp.MustCompile("^[a-f0-9]{40}$")

func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
func getCommit(ctx context.Context, account, repo, ref string, config *GithubConfig) (string, error) {
if commitRegexp.MatchString(ref) {
return ref, nil
}

url := fmt.Sprintf(githubCommitURL, account, repo, ref)
url := fmt.Sprintf(config.CommitURL, account, repo, ref)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("failed to create request of %s/%s at %s: %w", account, repo, url, err)
}

if githubAuthToken != "" {
req.Header.Add("Authorization", "Bearer "+githubAuthToken)
if config.AuthToken != "" {
req.Header.Add("Authorization", "Bearer "+config.AuthToken)
}

resp, err := http.DefaultClient.Do(req)
client := http.DefaultClient
if req.Host == config.Prefix && strings.ToLower(os.Getenv("GH_ENTERPRISE_SKIP_VERIFY")) == "true" {
client = &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}}
}
resp, err := client.Do(req)
if err != nil {
return "", err
} else if resp.StatusCode != http.StatusOK {
c, _ := io.ReadAll(resp.Body)
resp.Body.Close()
commit, fallBackErr := getCommitLsRemote(ctx, account, repo, ref)
commit, fallBackErr := getCommitLsRemote(ctx, account, repo, ref, config)
if fallBackErr == nil {
return commit, nil
}
Expand All @@ -88,8 +100,28 @@ func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
return commit.SHA, nil
}

func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string, *types.Repo, bool, error) {
if !strings.HasPrefix(urlName, GithubPrefix) {
func LoaderForPrefix(prefix string) func(context.Context, *cache.Client, string) (string, string, *types.Repo, bool, error) {
return func(ctx context.Context, c *cache.Client, urlName string) (string, string, *types.Repo, bool, error) {
return LoadWithConfig(ctx, c, urlName, NewGithubEnterpriseConfig(prefix))
}
}

func Load(ctx context.Context, c *cache.Client, urlName string) (string, string, *types.Repo, bool, error) {
return LoadWithConfig(ctx, c, urlName, DEFAULT_GITHUB_CONFIG)
}

func NewGithubEnterpriseConfig(prefix string) *GithubConfig {
return &GithubConfig{
Prefix: prefix,
RepoURL: fmt.Sprintf("https://%s/%%s/%%s.git", prefix),
DownloadURL: fmt.Sprintf("https://raw.%s/%%s/%%s/%%s/%%s", prefix),
CommitURL: fmt.Sprintf("https://%s/api/v3/repos/%%s/%%s/commits/%%s", prefix),
AuthToken: os.Getenv("GH_ENTERPRISE_TOKEN"),
}
}

func LoadWithConfig(ctx context.Context, _ *cache.Client, urlName string, config *GithubConfig) (string, string, *types.Repo, bool, error) {
if !strings.HasPrefix(urlName, config.Prefix) {
return "", "", nil, false, nil
}

Expand All @@ -107,12 +139,12 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
account, repo := parts[1], parts[2]
path := strings.Join(parts[3:], "/")

ref, err := getCommit(ctx, account, repo, ref)
ref, err := getCommit(ctx, account, repo, ref, config)
if err != nil {
return "", "", nil, false, err
}

downloadURL := fmt.Sprintf(githubDownloadURL, account, repo, ref, path)
downloadURL := fmt.Sprintf(config.DownloadURL, account, repo, ref, path)
if path == "" || path == "/" || !strings.Contains(parts[len(parts)-1], ".") {
var (
testPath string
Expand All @@ -124,13 +156,20 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
} else {
testPath = path + "/" + ext
}
testURL = fmt.Sprintf(githubDownloadURL, account, repo, ref, testPath)
testURL = fmt.Sprintf(config.DownloadURL, account, repo, ref, testPath)
if i == len(types.DefaultFiles)-1 {
// no reason to test the last one, we are just going to use it. Being that the default list is only
// two elements this loop could have been one check, but hey over-engineered code ftw.
break
}
if resp, err := http.Head(testURL); err == nil {
headReq, err := http.NewRequest("HEAD", testURL, nil)
if err != nil {
break
}
if config.AuthToken != "" {
headReq.Header.Add("Authorization", "Bearer "+config.AuthToken)
}
if resp, err := http.DefaultClient.Do(headReq); err == nil {
_ = resp.Body.Close()
if resp.StatusCode == 200 {
break
Expand All @@ -141,9 +180,9 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
path = testPath
}

return downloadURL, githubAuthToken, &types.Repo{
return downloadURL, config.AuthToken, &types.Repo{
VCS: "git",
Root: fmt.Sprintf(githubRepoURL, account, repo),
Root: fmt.Sprintf(config.RepoURL, account, repo),
Path: gpath.Dir(path),
Name: gpath.Base(path),
Revision: ref,
Expand Down
58 changes: 58 additions & 0 deletions pkg/loader/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package github

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/gptscript-ai/gptscript/pkg/types"
Expand Down Expand Up @@ -44,3 +48,57 @@ func TestLoad(t *testing.T) {
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
}).Equal(t, repo)
}

func TestLoad_GithubEnterprise(t *testing.T) {
gheToken := "mytoken"
os.Setenv("GH_ENTERPRISE_SKIP_VERIFY", "true")
os.Setenv("GH_ENTERPRISE_TOKEN", gheToken)
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("Request for %s\n", r.URL.Path)
switch r.URL.Path {
case "/api/v3/repos/gptscript-ai/gptscript/commits/172dfb0":
w.Write([]byte(`{"sha": "172dfb00b48c6adbbaa7e99270933f95887d1b91"}`))
default:
w.WriteHeader(404)
}
}))
defer s.Close()

serverAddr := s.Listener.Addr().String()

url, token, repo, ok, err := LoadWithConfig(context.Background(), nil, fmt.Sprintf("%s/gptscript-ai/gptscript/pkg/loader/testdata/tool@172dfb0", serverAddr), NewGithubEnterpriseConfig(serverAddr))
require.NoError(t, err)
assert.True(t, ok)
autogold.Expect(fmt.Sprintf("https://raw.%s/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/tool/tool.gpt", serverAddr)).Equal(t, url)
autogold.Expect(&types.Repo{
VCS: "git", Root: fmt.Sprintf("https://%s/gptscript-ai/gptscript.git", serverAddr),
Path: "pkg/loader/testdata/tool",
Name: "tool.gpt",
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
}).Equal(t, repo)
autogold.Expect(gheToken).Equal(t, token)

url, token, repo, ok, err = Load(context.Background(), nil, "github.com/gptscript-ai/gptscript/pkg/loader/testdata/agent@172dfb0")
require.NoError(t, err)
assert.True(t, ok)
autogold.Expect("https://raw.githubusercontent.com/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/agent/agent.gpt").Equal(t, url)
autogold.Expect(&types.Repo{
VCS: "git", Root: "https://github.com/gptscript-ai/gptscript.git",
Path: "pkg/loader/testdata/agent",
Name: "agent.gpt",
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
}).Equal(t, repo)
autogold.Expect("").Equal(t, token)

url, token, repo, ok, err = Load(context.Background(), nil, "github.com/gptscript-ai/gptscript/pkg/loader/testdata/bothtoolagent@172dfb0")
require.NoError(t, err)
assert.True(t, ok)
autogold.Expect("https://raw.githubusercontent.com/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/bothtoolagent/agent.gpt").Equal(t, url)
autogold.Expect(&types.Repo{
VCS: "git", Root: "https://github.com/gptscript-ai/gptscript.git",
Path: "pkg/loader/testdata/bothtoolagent",
Name: "agent.gpt",
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
}).Equal(t, repo)
autogold.Expect("").Equal(t, token)
}
3 changes: 3 additions & 0 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,9 @@ type Repo struct {
Name string
// The revision of this source
Revision string

// Additional headers to pass when making requests for this repo
Headers map[string]string
}

type ToolSource struct {
Expand Down

0 comments on commit 374632f

Please sign in to comment.