Skip to content

Commit

Permalink
Add methods to instantiate a StatusError without formatting (#13987)
Browse files Browse the repository at this point in the history
Adds two new methods to instantiate an `api.StatusError`:
- `NewStatusError(status int, msg string)`.
- `NewGenericStatusError(status int)`.

These can be used in the following contexts:
1. When there is an error to wrap but we want to add a status code to
it, use `api.StatusErrorf`.
2. When there is no error to wrap but we want to return an error with a
status code, use `api.NewStatusError`.
3. When we want to mask an error message to prevent any information
leakage (e.g. for auth), use `api.NewGenericStatusError`.
  • Loading branch information
tomponline authored Aug 27, 2024
2 parents a894583 + ba484a8 commit 9a5dda1
Show file tree
Hide file tree
Showing 48 changed files with 3,340 additions and 3,355 deletions.
35 changes: 18 additions & 17 deletions lxc/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -156,7 +157,7 @@ func (c *cmdRemoteAdd) runToken(server string, token string, rawToken *api.Certi
conf := c.global.conf

if !conf.HasClientCertificate() {
fmt.Fprintf(os.Stderr, i18n.G("Generating a client certificate. This may take a minute...")+"\n")
fmt.Fprint(os.Stderr, i18n.G("Generating a client certificate. This may take a minute...")+"\n")
err := conf.GenerateClientCertificate()
if err != nil {
return err
Expand All @@ -179,15 +180,15 @@ func (c *cmdRemoteAdd) runToken(server string, token string, rawToken *api.Certi
}

fmt.Println(i18n.G("All server addresses are unavailable"))
fmt.Printf(i18n.G("Please provide an alternate server address (empty to abort):") + " ")
fmt.Print(i18n.G("Please provide an alternate server address (empty to abort):") + " ")

line, err := shared.ReadStdin()
if err != nil {
return err
}

if len(line) == 0 {
return fmt.Errorf(i18n.G("Failed to add remote"))
return errors.New(i18n.G("Failed to add remote"))
}

err = c.addRemoteFromToken(string(line), server, token, rawToken.Fingerprint)
Expand All @@ -210,7 +211,7 @@ func (c *cmdRemoteAdd) addRemoteFromToken(addr string, server string, token stri
if err != nil {
certificate, err = shared.GetRemoteCertificate(addr, c.global.conf.UserAgent)
if err != nil {
return api.StatusErrorf(http.StatusServiceUnavailable, i18n.G("Unavailable remote server")+": %w", err)
return api.StatusErrorf(http.StatusServiceUnavailable, "%s: %w", i18n.G("Unavailable remote server"), err)
}

certDigest := shared.CertFingerprint(certificate)
Expand All @@ -221,7 +222,7 @@ func (c *cmdRemoteAdd) addRemoteFromToken(addr string, server string, token stri
dnam := conf.ConfigPath("servercerts")
err := os.MkdirAll(dnam, 0750)
if err != nil {
return fmt.Errorf(i18n.G("Could not create server cert dir"))
return errors.New(i18n.G("Could not create server cert dir"))
}

certf := conf.ServerCertPath(server)
Expand All @@ -244,7 +245,7 @@ func (c *cmdRemoteAdd) addRemoteFromToken(addr string, server string, token stri

d, err := conf.GetInstanceServer(server)
if err != nil {
return api.StatusErrorf(http.StatusServiceUnavailable, i18n.G("Unavailable remote server")+": %w", err)
return api.StatusErrorf(http.StatusServiceUnavailable, "%s: %w", i18n.G("Unavailable remote server"), err)
}

req := api.CertificatesPost{}
Expand Down Expand Up @@ -290,12 +291,12 @@ func (c *cmdRemoteAdd) run(cmd *cobra.Command, args []string) error {
}

if len(addr) == 0 {
return fmt.Errorf(i18n.G("Remote address must not be empty"))
return errors.New(i18n.G("Remote address must not be empty"))
}

// Validate the server name.
if strings.Contains(server, ":") {
return fmt.Errorf(i18n.G("Remote names may not contain colons"))
return errors.New(i18n.G("Remote names may not contain colons"))
}

// Check for existing remote
Expand Down Expand Up @@ -332,7 +333,7 @@ func (c *cmdRemoteAdd) run(cmd *cobra.Command, args []string) error {
// Fast track simplestreams
if c.flagProtocol == "simplestreams" {
if remoteURL.Scheme != "https" {
return fmt.Errorf(i18n.G("Only https URLs are supported for simplestreams"))
return errors.New(i18n.G("Only https URLs are supported for simplestreams"))
}

conf.Remotes[server] = config.Remote{Addr: addr, Public: true, Protocol: c.flagProtocol}
Expand Down Expand Up @@ -397,7 +398,7 @@ func (c *cmdRemoteAdd) run(cmd *cobra.Command, args []string) error {
// adding the remote server.
if rScheme != "unix" && !c.flagPublic && (c.flagAuthType == api.AuthenticationMethodTLS || c.flagAuthType == "") {
if !conf.HasClientCertificate() {
fmt.Fprintf(os.Stderr, i18n.G("Generating a client certificate. This may take a minute...")+"\n")
fmt.Fprint(os.Stderr, i18n.G("Generating a client certificate. This may take a minute...")+"\n")
err = conf.GenerateClientCertificate()
if err != nil {
return err
Expand Down Expand Up @@ -451,26 +452,26 @@ func (c *cmdRemoteAdd) run(cmd *cobra.Command, args []string) error {
if !c.flagAcceptCert {
digest := shared.CertFingerprint(certificate)

fmt.Printf(i18n.G("Certificate fingerprint: %s")+"\n", digest)
fmt.Printf(i18n.G("ok (y/n/[fingerprint])?") + " ")
fmt.Printf("%s: %s\n", i18n.G("Certificate fingerprint"), digest)
fmt.Print(i18n.G("ok (y/n/[fingerprint])?") + " ")
line, err := shared.ReadStdin()
if err != nil {
return err
}

if string(line) != digest {
if len(line) < 1 || strings.ToLower(string(line[0])) == i18n.G("n") {
return fmt.Errorf(i18n.G("Server certificate NACKed by user"))
return errors.New(i18n.G("Server certificate NACKed by user"))
} else if strings.ToLower(string(line[0])) != i18n.G("y") {
return fmt.Errorf(i18n.G("Please type 'y', 'n' or the fingerprint:"))
return errors.New(i18n.G("Please type 'y', 'n' or the fingerprint:"))
}
}
}

dnam := conf.ConfigPath("servercerts")
err := os.MkdirAll(dnam, 0750)
if err != nil {
return fmt.Errorf(i18n.G("Could not create server cert dir"))
return errors.New(i18n.G("Could not create server cert dir"))
}

certf := conf.ServerCertPath(server)
Expand Down Expand Up @@ -612,7 +613,7 @@ func (c *cmdRemoteAdd) run(cmd *cobra.Command, args []string) error {
}

if srv.Auth != "trusted" {
return fmt.Errorf(i18n.G("Server doesn't trust us after authentication"))
return errors.New(i18n.G("Server doesn't trust us after authentication"))
}

if c.flagAuthType == api.AuthenticationMethodTLS {
Expand Down Expand Up @@ -874,7 +875,7 @@ func (c *cmdRemoteRemove) run(cmd *cobra.Command, args []string) error {
}

if conf.DefaultRemote == args[0] {
return fmt.Errorf(i18n.G("Can't remove the default remote"))
return errors.New(i18n.G("Can't remove the default remote"))
}

delete(conf.Remotes, args[0])
Expand Down
4 changes: 2 additions & 2 deletions lxd/auth/drivers/openfga.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (e *embeddedOpenFGA) CheckPermission(ctx context.Context, entityURL *api.UR

// Untrusted requests are denied.
if !auth.IsTrusted(ctx) {
return api.StatusErrorf(http.StatusForbidden, "%s", http.StatusText(http.StatusForbidden))
return api.NewGenericStatusError(http.StatusForbidden)
}

isRoot, err := auth.IsServerAdmin(ctx, e.identityCache)
Expand Down Expand Up @@ -284,7 +284,7 @@ func (e *embeddedOpenFGA) CheckPermission(ctx context.Context, entityURL *api.UR
l.Info("Access denied", logger.Ctx{"http_code": responseCode})
}

return api.StatusErrorf(responseCode, "%s", http.StatusText(responseCode))
return api.NewGenericStatusError(responseCode)
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion lxd/auth/drivers/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (t *tls) load(ctx context.Context, identityCache *identity.Cache, opts Opts
func (t *tls) CheckPermission(ctx context.Context, entityURL *api.URL, entitlement auth.Entitlement) error {
// Untrusted requests are denied.
if !auth.IsTrusted(ctx) {
return api.StatusErrorf(http.StatusForbidden, "%s", http.StatusText(http.StatusForbidden))
return api.NewGenericStatusError(http.StatusForbidden)
}

isRoot, err := auth.IsServerAdmin(ctx, t.identities)
Expand Down
6 changes: 3 additions & 3 deletions lxd/devlxd.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func devlxdAPIHandlerFunc(d *Daemon, c instance.Instance, w http.ResponseWriter,

err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusBadRequest, err.Error()), c.Type() == instancetype.VM)
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusBadRequest, "Invalid request body: %w", err), c.Type() == instancetype.VM)
}

state := api.StatusCodeFromString(req.State)
Expand All @@ -260,7 +260,7 @@ func devlxdAPIHandlerFunc(d *Daemon, c instance.Instance, w http.ResponseWriter,

err = c.VolatileSet(map[string]string{"volatile.last_state.ready": strconv.FormatBool(state == api.Ready)})
if err != nil {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusInternalServerError, err.Error()), c.Type() == instancetype.VM)
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusInternalServerError, "Failed to set instance state: %w", err), c.Type() == instancetype.VM)
}

if state == api.Ready {
Expand All @@ -270,7 +270,7 @@ func devlxdAPIHandlerFunc(d *Daemon, c instance.Instance, w http.ResponseWriter,
return response.DevLxdResponse(http.StatusOK, "", "raw", c.Type() == instancetype.VM)
}

return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusMethodNotAllowed, fmt.Sprintf("method %q not allowed", r.Method)), c.Type() == instancetype.VM)
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusMethodNotAllowed, "method %q not allowed", r.Method), c.Type() == instancetype.VM)
}

var devlxdDevicesGet = devLxdHandler{
Expand Down
2 changes: 1 addition & 1 deletion lxd/instance/drivers/qmp/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ func (m *Monitor) RemoveBlockDevice(blockDevName string) error {
err := m.run("blockdev-del", blockDevName, nil)
if err != nil {
if strings.Contains(err.Error(), "is in use") {
return api.StatusErrorf(http.StatusLocked, err.Error())
return api.StatusErrorf(http.StatusLocked, "%w", err)
}

if strings.Contains(err.Error(), "Failed to find") {
Expand Down
2 changes: 1 addition & 1 deletion lxd/instance_sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (r *sftpServeResponse) Render(w http.ResponseWriter) error {

err = response.Upgrade(remoteConn, "sftp")
if err != nil {
return api.StatusErrorf(http.StatusInternalServerError, err.Error())
return api.StatusErrorf(http.StatusInternalServerError, "Failed to upgrade SFTP connection: %w", err)
}

ctx, cancel := context.WithCancel(r.req.Context())
Expand Down
4 changes: 2 additions & 2 deletions lxd/project/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,7 @@ func CheckTargetMember(p *api.Project, targetMemberName string, allMembers []db.
// If restricted groups are specified then check member is in at least one of them.
err := AllowClusterMember(p, &potentialMember)
if err != nil {
return nil, api.StatusErrorf(http.StatusForbidden, err.Error())
return nil, api.StatusErrorf(http.StatusForbidden, "%w", err)
}

return &potentialMember, nil
Expand All @@ -1647,7 +1647,7 @@ func CheckTargetGroup(ctx context.Context, tx *db.ClusterTx, p *api.Project, gro
// If restricted groups are specified then check the requested group is in the list.
err := AllowClusterGroup(p, groupName)
if err != nil {
return api.StatusErrorf(http.StatusForbidden, err.Error())
return api.StatusErrorf(http.StatusForbidden, "%w", err)
}

// Check if the target group exists.
Expand Down
Loading

0 comments on commit 9a5dda1

Please sign in to comment.