Skip to content

Commit

Permalink
Merge pull request #388 from zong-zhe/refactor-oci-auth
Browse files Browse the repository at this point in the history
feat: add cache for credential to reduce the probability that kpm would be considered a threat
  • Loading branch information
Peefy authored Jul 18, 2024
2 parents b65d5ed + 74b0e81 commit 635551d
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 40 deletions.
142 changes: 135 additions & 7 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -19,7 +20,9 @@ import (
"github.com/otiai10/copy"
"golang.org/x/mod/module"
"kcl-lang.io/kcl-go/pkg/kcl"
"oras.land/oras-go/pkg/auth"
"oras.land/oras-go/v2"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"

"kcl-lang.io/kpm/pkg/constants"
"kcl-lang.io/kpm/pkg/downloader"
Expand All @@ -41,6 +44,8 @@ type KpmClient struct {
logWriter io.Writer
// The downloader of the dependencies.
DepDownloader *downloader.DepDownloader
// credential store
credsClient *downloader.CredClient
// The home path of kpm for global configuration file and kcl package storage path.
homePath string
// The settings of kpm loaded from the global configuration file.
Expand Down Expand Up @@ -75,6 +80,33 @@ func (c *KpmClient) SetNoSumCheck(noSumCheck bool) {
c.noSumCheck = noSumCheck
}

// GetCredsClient will return the credential client.
func (c *KpmClient) GetCredsClient() (*downloader.CredClient, error) {
if c.credsClient == nil {
credCli, err := downloader.LoadCredentialFile(c.settings.CredentialsFile)
if err != nil {
return nil, err
}
c.credsClient = credCli
}
return c.credsClient, nil
}

// GetCredentials will return the credentials of the host.
func (c *KpmClient) GetCredentials(hostName string) (*remoteauth.Credential, error) {
credCli, err := c.GetCredsClient()
if err != nil {
return nil, err
}

creds, err := credCli.Credential(hostName)
if err != nil {
return nil, err
}

return creds, nil
}

// GetNoSumCheck will return the 'noSumCheck' flag.
func (c *KpmClient) GetNoSumCheck() bool {
return c.noSumCheck
Expand Down Expand Up @@ -953,7 +985,18 @@ func (c *KpmClient) FillDependenciesInfo(modFile *pkg.ModFile) error {

// AcquireTheLatestOciVersion will acquire the latest version of the OCI reference.
func (c *KpmClient) AcquireTheLatestOciVersion(ociSource downloader.Oci) (string, error) {
ociClient, err := oci.NewOciClient(ociSource.Reg, ociSource.Repo, &c.settings)
repoPath := utils.JoinPath(ociSource.Reg, ociSource.Repo)
cred, err := c.GetCredentials(ociSource.Reg)
if err != nil {
return "", err
}

ociClient, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return "", err
}
Expand Down Expand Up @@ -1098,11 +1141,16 @@ func (c *KpmClient) Download(dep *pkg.Dependency, homePath, localPath string) (*
// clean the temp dir.
defer os.RemoveAll(tmpDir)

credCli, err := c.GetCredsClient()
if err != nil {
return nil, err
}
err = c.DepDownloader.Download(*downloader.NewDownloadOptions(
downloader.WithLocalPath(tmpDir),
downloader.WithSource(dep.Source),
downloader.WithLogWriter(c.logWriter),
downloader.WithSettings(c.settings),
downloader.WithCredsClient(credCli),
))
if err != nil {
return nil, err
Expand Down Expand Up @@ -1276,10 +1324,22 @@ func (c *KpmClient) ParseKclModFile(kclPkg *pkg.KclPkg) (map[string]map[string]s

// LoadPkgFromOci will download the kcl package from the oci repository and return an `KclPkg`.
func (c *KpmClient) DownloadPkgFromOci(dep *downloader.Oci, localPath string) (*pkg.KclPkg, error) {
ociClient, err := oci.NewOciClient(dep.Reg, dep.Repo, &c.settings)
repoPath := utils.JoinPath(dep.Reg, dep.Repo)
cred, err := c.GetCredentials(dep.Reg)
if err != nil {
return nil, err
}

ociClient, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return nil, err
}

ociClient.SetLogWriter(c.logWriter)
// Select the latest tag, if the tag, the user inputed, is empty.
var tagSelected string
Expand Down Expand Up @@ -1478,7 +1538,18 @@ func (c *KpmClient) PullFromOci(localPath, source, tag string) error {

// PushToOci will push a kcl package to oci registry.
func (c *KpmClient) PushToOci(localPath string, ociOpts *opt.OciOptions) error {
ociCli, err := oci.NewOciClient(ociOpts.Reg, ociOpts.Repo, &c.settings)
repoPath := utils.JoinPath(ociOpts.Reg, ociOpts.Repo)
cred, err := c.GetCredentials(ociOpts.Reg)
if err != nil {
return err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return err
}
Expand All @@ -1504,12 +1575,46 @@ func (c *KpmClient) PushToOci(localPath string, ociOpts *opt.OciOptions) error {

// LoginOci will login to the oci registry.
func (c *KpmClient) LoginOci(hostname, username, password string) error {
return oci.Login(hostname, username, password, &c.settings)

credCli, err := c.GetCredsClient()
if err != nil {
return err
}

err = credCli.GetAuthClient().LoginWithOpts(
[]auth.LoginOption{
auth.WithLoginHostname(hostname),
auth.WithLoginUsername(username),
auth.WithLoginSecret(password),
}...,
)

if err != nil {
return reporter.NewErrorEvent(
reporter.FailedLogin,
err,
fmt.Sprintf("failed to login '%s', please check registry, username and password is valid", hostname),
)
}

return nil
}

// LogoutOci will logout from the oci registry.
func (c *KpmClient) LogoutOci(hostname string) error {
return oci.Logout(hostname, &c.settings)

credCli, err := c.GetCredsClient()
if err != nil {
return err
}

err = credCli.GetAuthClient().Logout(context.Background(), hostname)

if err != nil {
return reporter.NewErrorEvent(reporter.FailedLogout, err, fmt.Sprintf("failed to logout '%s'", hostname))
}

return nil
}

// ParseOciRef will parser '<repo_name>:<repo_tag>' into an 'OciOptions'.
Expand Down Expand Up @@ -1753,7 +1858,18 @@ func (c *KpmClient) pullTarFromOci(localPath string, ociOpts *opt.OciOptions) er
return reporter.NewErrorEvent(reporter.Bug, err)
}

ociCli, err := oci.NewOciClient(ociOpts.Reg, ociOpts.Repo, &c.settings)
repoPath := utils.JoinPath(ociOpts.Reg, ociOpts.Repo)
cred, err := c.GetCredentials(ociOpts.Reg)
if err != nil {
return err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return err
}
Expand Down Expand Up @@ -1790,7 +1906,19 @@ func (c *KpmClient) pullTarFromOci(localPath string, ociOpts *opt.OciOptions) er

// FetchOciManifestConfIntoJsonStr will fetch the oci manifest config of the kcl package from the oci registry and return it into json string.
func (c *KpmClient) FetchOciManifestIntoJsonStr(opts opt.OciFetchOptions) (string, error) {
ociCli, err := oci.NewOciClient(opts.Reg, opts.Repo, &c.settings)

repoPath := utils.JoinPath(opts.Reg, opts.Repo)
cred, err := c.GetCredentials(opts.Reg)
if err != nil {
return "", err
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(c.GetSettings().DefaultOciPlainHttp()),
)

if err != nil {
return "", err
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/client/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,18 @@ func (rv *RemoteVisitor) Visit(s *downloader.Source, v visitFunc) error {
tmpDir = filepath.Join(tmpDir, constants.GitScheme)
}

credCli, err := rv.kpmcli.GetCredsClient()
if err != nil {
return err
}

defer os.RemoveAll(tmpDir)
err = rv.kpmcli.DepDownloader.Download(*downloader.NewDownloadOptions(
downloader.WithLocalPath(tmpDir),
downloader.WithSource(*s),
downloader.WithLogWriter(rv.kpmcli.GetLogWriter()),
downloader.WithSettings(*rv.kpmcli.GetSettings()),
downloader.WithCredsClient(credCli),
))

if err != nil {
Expand Down
50 changes: 50 additions & 0 deletions pkg/downloader/credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package downloader

import (
"fmt"

dockerauth "oras.land/oras-go/pkg/auth/docker"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
)

// CredClient is the client to get the credentials.
type CredClient struct {
credsClient *dockerauth.Client
}

// LoadCredentialFile loads the credential file and return the CredClient.
func LoadCredentialFile(filepath string) (*CredClient, error) {
authClient, err := dockerauth.NewClientWithDockerFallback(filepath)
if err != nil {
return nil, err
}
dockerAuthClient, ok := authClient.(*dockerauth.Client)
if !ok {
return nil, fmt.Errorf("authClient is not *docker.Client type")
}

return &CredClient{
credsClient: dockerAuthClient,
}, nil
}

// GetAuthClient returns the auth client.
func (cred *CredClient) GetAuthClient() *dockerauth.Client {
return cred.credsClient
}

// Credential will reture the credential info cache in CredClient
func (cred *CredClient) Credential(hostName string) (*remoteauth.Credential, error) {
if len(hostName) == 0 {
return nil, fmt.Errorf("hostName is empty")
}
username, password, err := cred.credsClient.Credential(hostName)
if err != nil {
return nil, err
}

return &remoteauth.Credential{
Username: username,
Password: password,
}, nil
}
29 changes: 28 additions & 1 deletion pkg/downloader/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"kcl-lang.io/kpm/pkg/reporter"
"kcl-lang.io/kpm/pkg/settings"
"kcl-lang.io/kpm/pkg/utils"
remoteauth "oras.land/oras-go/v2/registry/remote/auth"
)

// DownloadOptions is the options for downloading a package.
Expand All @@ -25,10 +26,18 @@ type DownloadOptions struct {
Settings settings.Settings
// LogWriter is the writer to write the log.
LogWriter io.Writer
// credsClient is the client to get the credentials.
credsClient *CredClient
}

type Option func(*DownloadOptions)

func WithCredsClient(credsClient *CredClient) Option {
return func(do *DownloadOptions) {
do.credsClient = credsClient
}
}

func WithLogWriter(logWriter io.Writer) Option {
return func(do *DownloadOptions) {
do.LogWriter = logWriter
Expand Down Expand Up @@ -125,7 +134,25 @@ func (d *OciDownloader) Download(opts DownloadOptions) error {

localPath := opts.LocalPath

ociCli, err := oci.NewOciClient(ociSource.Reg, ociSource.Repo, &opts.Settings)
repoPath := utils.JoinPath(ociSource.Reg, ociSource.Repo)

var cred *remoteauth.Credential
var err error
if opts.credsClient != nil {
cred, err = opts.credsClient.Credential(ociSource.Reg)
if err != nil {
return err
}
} else {
cred = &remoteauth.Credential{}
}

ociCli, err := oci.NewOciClientWithOpts(
oci.WithCredential(cred),
oci.WithRepoPath(repoPath),
oci.WithPlainHttp(opts.Settings.DefaultOciPlainHttp()),
)

if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 635551d

Please sign in to comment.