Skip to content

Commit

Permalink
Add secret source support for git and huggingface
Browse files Browse the repository at this point in the history
  • Loading branch information
discordianfish committed Feb 10, 2024
1 parent 3891683 commit 9f6eb9d
Show file tree
Hide file tree
Showing 10 changed files with 323 additions and 25 deletions.
5 changes: 3 additions & 2 deletions pkg/cmd/agent/submit.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ import (

func NewSubmitCmd(logger *log.Logger) *cobra.Command {
dump := false
submissionConfig := diambra.NewSubmissionConfig(logger)
submissionConfig := diambra.SubmissionConfig{}
submissionConfig.RegisterCredentialsProviders()
c, err := diambra.NewConfig(logger)
if err != nil {
level.Error(logger).Log("msg", err.Error())
Expand All @@ -46,7 +47,7 @@ func NewSubmitCmd(logger *log.Logger) *cobra.Command {
level.Error(logger).Log("msg", err.Error())
os.Exit(1)
}
submission, err := submissionConfig.Submission(c.CredPath, args)
submission, err := submissionConfig.Submission(c, args)
if err != nil {
level.Error(logger).Log("msg", "failed to configure manifest", "err", err.Error())
os.Exit(1)
Expand Down
5 changes: 3 additions & 2 deletions pkg/cmd/agent/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ const (
)

func NewTestCmd(logger *log.Logger) *cobra.Command {
submissionConfig := diambra.NewSubmissionConfig(logger)
submissionConfig := diambra.SubmissionConfig{}
submissionConfig.RegisterCredentialsProviders()
c, err := diambra.NewConfig(logger)
if err != nil {
level.Error(logger).Log("msg", err.Error())
Expand All @@ -37,7 +38,7 @@ func NewTestCmd(logger *log.Logger) *cobra.Command {
Long: `This takes a docker image or submission manifest and runs it in the same way as it would be run when submitted
to DIAMBRA. This is useful for testing your agent before submitting it. Optionally, you can pass in commands to run instead of the configured entrypoint.`,
Run: func(cmd *cobra.Command, args []string) {
submission, err := submissionConfig.Submission(c.CredPath, args)
submission, err := submissionConfig.Submission(c, args)
if err != nil {
level.Error(logger).Log("msg", "failed to configure manifest", "err", err.Error())
os.Exit(1)
Expand Down
2 changes: 1 addition & 1 deletion pkg/container/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func NewDockerRunner(logger log.Logger, client *client.Client, autoRemove bool)
func (r *DockerRunner) Pull(c *Container, output *os.File) error {
reader, err := r.Client.ImagePull(context.TODO(), c.Image, types.ImagePullOptions{})
if err != nil {
return fmt.Errorf("couldn't pull image %s: %w:\nTo disable pulling the image on start, retry with --images.pull=false", c.Image, err)
return fmt.Errorf("couldn't pull image %s: %w:\nTo disable pulling the image on start, retry with --images.no-pull", c.Image, err)
}
defer reader.Close()

Expand Down
76 changes: 62 additions & 14 deletions pkg/diambra/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/diambra/cli/pkg/container"
"github.com/diambra/cli/pkg/diambra/client"
"github.com/diambra/cli/pkg/secretsources"
"github.com/diambra/init/initializer"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
Expand Down Expand Up @@ -234,22 +235,28 @@ const (
var ErrInvalidArgs = errors.New("either image, manifest path or submission id must be provided")

type SubmissionConfig struct {
logger log.Logger

Mode string
Difficulty string
EnvVars map[string]string
Sources map[string]string
Secrets map[string]string
SecretsFrom string
ArgsIsCommand bool
ManifestPath string
SubmissionID int

credentialsProvider map[string]secretsources.CredentialProvider
}

func NewSubmissionConfig(logger log.Logger) *SubmissionConfig {
return &SubmissionConfig{
logger: logger,
func (c *SubmissionConfig) RegisterCredentialsProvider(name string, provider secretsources.CredentialProvider) {
if c.credentialsProvider == nil {
c.credentialsProvider = make(map[string]secretsources.CredentialProvider)
}
c.credentialsProvider[name] = provider
}
func (c *SubmissionConfig) RegisterCredentialsProviders() {
c.RegisterCredentialsProvider("git", &secretsources.GitCredentials{})
c.RegisterCredentialsProvider("huggingface", &secretsources.HuggingfaceCredentials{})
}

func (c *SubmissionConfig) AddFlags(flags *pflag.FlagSet) {
Expand All @@ -258,20 +265,21 @@ func (c *SubmissionConfig) AddFlags(flags *pflag.FlagSet) {
flags.StringToStringVarP(&c.EnvVars, "submission.env", "e", nil, "Environment variables to pass to the agent")
flags.StringToStringVarP(&c.Sources, "submission.source", "u", nil, "Source urls to pass to the agent")
flags.StringToStringVar(&c.Secrets, "submission.secret", nil, "Secrets to pass to the agent")
flags.StringVar(&c.SecretsFrom, "submission.secrets-from", "", "Automatically add secrets. Supported values: git, huggingface")
flags.StringVar(&c.ManifestPath, "submission.manifest", "", "Path to manifest file.")
flags.IntVar(&c.SubmissionID, "submission.id", 0, "Submission ID to retrieve manifest from")
flags.BoolVar(&c.ArgsIsCommand, "submission.set-command", false, "Treat positional arguments are command instead of entrypoint")
}

func (c *SubmissionConfig) Submission(credPath string, args []string) (*client.Submission, error) {
func (c *SubmissionConfig) Submission(config *EnvConfig, args []string) (*client.Submission, error) {
var (
nargs = len(args)
manifest *client.Manifest
)

switch {
case c.SubmissionID != 0:
cl, err := client.NewClient(c.logger, credPath)
cl, err := client.NewClient(config.logger, config.CredPath)
if err != nil {
return nil, fmt.Errorf("failed to create client: %w", err)
}
Expand Down Expand Up @@ -320,22 +328,62 @@ func (c *SubmissionConfig) Submission(credPath string, args []string) (*client.S
}

if c.Sources != nil {
level.Debug(c.logger).Log("msg", "Using sources", "sources", c.Sources)
level.Debug(config.logger).Log("msg", "Using sources", "sources", c.Sources)
manifest.Sources = make(map[string]string)
for k, v := range c.Sources {
manifest.Sources[k] = v
}
}

if manifest.Sources != nil {
init, err := initializer.NewInitializer(c.logger, manifest.Sources, c.Secrets, map[string]string{}, "")
if err != nil {
return nil, err
if c.SecretsFrom != "" {
if c.Secrets == nil {
c.Secrets = make(map[string]string)
}
}

if err := init.Validate(); err != nil {
return nil, err
if c.SecretsFrom != "" {
ss, ok := c.credentialsProvider[c.SecretsFrom]
if !ok {
return nil, fmt.Errorf("invalid value for --submission.secrets-from: %s", c.SecretsFrom)
}
switch c.SecretsFrom {
case "git":
secrets, err := secretsources.CredentialsFill(ss, manifest.Sources)
if err != nil {
return nil, err
}
if manifest.Sources == nil {
return nil, fmt.Errorf("sources are required to use --submission.secrets-from=git")
}
level.Debug(config.logger).Log("msg", "Adding git secrets")
for k, v := range secrets {
level.Info(config.logger).Log("msg", "Adding git secret", "key", k)
c.Secrets[k] = v
}
case "huggingface":
level.Debug(config.logger).Log("msg", "Adding huggingface secrets")
secrets, err := ss.Credentials("")
if err != nil {
return nil, err
}
c.Secrets["HF_TOKEN"] = secrets["HF_TOKEN"]
if manifest.Env == nil {
manifest.Env = make(map[string]string)
}
manifest.Env["HF_TOKEN"] = "{{ .Secrets.HF_TOKEN }}"
case "":
default:
return nil, fmt.Errorf("invalid value for --submission.secrets-from: %s", c.SecretsFrom)
}
}

init, err := initializer.NewInitializer(config.logger, manifest.Sources, c.Secrets, map[string]string{}, "")
if err != nil {
return nil, err
}

if err := init.Validate(); err != nil {
return nil, err
}

return &client.Submission{
Expand Down
63 changes: 57 additions & 6 deletions pkg/diambra/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
package diambra

import (
"os"
"path/filepath"
"testing"

"github.com/diambra/cli/pkg/diambra/client"
"github.com/diambra/cli/pkg/secretsources"
"github.com/go-kit/log"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -56,6 +60,13 @@ func TestAppArgs(t *testing.T) {
}

func TestSubmissionConfig(t *testing.T) {
envConfig := &EnvConfig{
logger: log.NewNopLogger(),
CredPath: "",
}
cwd, err := os.Getwd()
assert.NoError(t, err)

for _, tc := range []struct {
name string
config SubmissionConfig
Expand Down Expand Up @@ -113,20 +124,60 @@ func TestSubmissionConfig(t *testing.T) {
nil,
},
{
"from args, with secrets",
SubmissionConfig{},
[]string{"diambra/agent-random-1:main", "--gameId", "doapp"},
"from args with sources and secrets",
SubmissionConfig{
ManifestPath: "testdata/manifest.yaml",
ArgsIsCommand: true,
Sources: map[string]string{"model.zip": "https://user:{{ .Secrets.foo }}@example.com/model.zip"},
Secrets: map[string]string{
"foo": "bar",
},
},
[]string{"python", "agent.py"},
&client.Submission{
Manifest: client.Manifest{
Image: "diambra/agent-random-1:main",
Args: []string{"--gameId", "doapp"},
Image: "diambra/agent-random-1:main",
Command: []string{"python", "agent.py"},
Args: []string{"--gameId", "doapp"},
Sources: map[string]string{
"model.zip": "https://user:{{ .Secrets.foo }}@example.com/model.zip",
},
},
Secrets: map[string]string{
"foo": "bar",
},
},
nil,
},
{
"from args with sources and secrets from git",
SubmissionConfig{
ManifestPath: "testdata/manifest.yaml",
ArgsIsCommand: true,
Sources: map[string]string{"model.zip": "https://example.com/mode.zip"},
SecretsFrom: "git",
},
[]string{"python", "agent.py"},
&client.Submission{
Manifest: client.Manifest{
Image: "diambra/agent-random-1:main",
Command: []string{"python", "agent.py"},
Args: []string{"--gameId", "doapp"},
Sources: map[string]string{
"model.zip": "https://{{ .Secrets.git_username_1 }}:{{ .Secrets.git_password_1 }}@example.com/mode.zip",
},
},
Secrets: map[string]string{
"git_username_1": "user1",
"git_password_1": "pass1",
},
},
nil,
},
} {
t.Run(tc.name, func(t *testing.T) {
submission, err := tc.config.Submission("", tc.args)
tc.config.RegisterCredentialsProvider("git", &secretsources.GitCredentials{Helper: filepath.Join(cwd, "../../test/mock-credential-helper.sh")})
submission, err := tc.config.Submission(envConfig, tc.args)
assert.Equal(t, tc.expectedErr, err)
assert.Equal(t, tc.expected, submission)
})
Expand Down
80 changes: 80 additions & 0 deletions pkg/secretsources/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package secretsources

import (
"bytes"
"fmt"
"net/url"
"os/exec"
"strings"
)

type CredentialProvider interface {
Credentials(url string) (map[string]string, error)
}

type GitCredentials struct {
Helper string
}

func (c *GitCredentials) Credentials(url string) (map[string]string, error) {
args := []string{}
if c.Helper != "" {
args = append(args, "-c", fmt.Sprintf("credential.helper=%s", c.Helper))
}
args = append(args, "credential", "fill")
cmd := exec.Command("git", args...)
cmd.Stdin = strings.NewReader("url=" + url + "\n")

var stdout bytes.Buffer
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("failed to run %v: %w", cmd, err)
}

credentials := make(map[string]string)
lines := strings.Split(stdout.String(), "\n")
for _, line := range lines {
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
credentials[parts[0]] = parts[1]
}
}

return credentials, nil
}

// CredentialsFill calls the CredentialsProvider for each source and returns
// a new source map with templating as well as a map of credentials for the templated values.
func CredentialsFill(provider CredentialProvider, sources map[string]string) (map[string]string, error) {
secrets := make(map[string]string)
i := 0
for k, v := range sources {
i++
u, err := url.Parse(v)
if err != nil {
return nil, fmt.Errorf("failed to parse url %s: %w", v, err)
}
credentials, err := provider.Credentials(v)
if err != nil {
return nil, err
}
if credentials["password"] == "" {
continue
}

if credentials["host"] != u.Host {
return nil, fmt.Errorf("host %s does not match %s (this should never happend)", credentials["host"], u.Host)
}

var (
uservar = fmt.Sprintf("git_username_%d", i)
passvar = fmt.Sprintf("git_password_%d", i)
)

u.User = url.UserPassword(fmt.Sprintf("{{ %s }}", uservar), fmt.Sprintf("{{ %s }}", passvar))
secrets[uservar] = credentials["username"]
secrets[passvar] = credentials["password"]
sources[k] = fmt.Sprintf("%s://{{ .Secrets.%s }}:{{ .Secrets.%s }}@%s%s", u.Scheme, uservar, passvar, u.Host, u.Path)
}
return secrets, nil
}
Loading

0 comments on commit 9f6eb9d

Please sign in to comment.