diff --git a/gateway/grpc/server.go b/gateway/grpc/server.go index 941f7cc8..acf9499e 100644 --- a/gateway/grpc/server.go +++ b/gateway/grpc/server.go @@ -3,6 +3,7 @@ package grpc import ( "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "net" @@ -164,8 +165,18 @@ func mtlsInterceptor() grpc.UnaryServerInterceptor { return nil, errors.New("tls: attempt to use non-existing or revoked certificate") // nolint: goerr113 } + block, rest := pem.Decode(resp.Certificates[0].Certificate.Cert) + if len(rest) > 0 { + return nil, fmt.Errorf("%w: tls: failed to decode onchain certificate", err) + } + + onchainCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("%w: tls: failed to parse onchain certificate", err) + } + clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(cert) + clientCertPool.AddCert(onchainCert) opts := x509.VerifyOptions{ Roots: clientCertPool, diff --git a/gateway/rest/client.go b/gateway/rest/client.go index 44b9f80d..d9063bd9 100644 --- a/gateway/rest/client.go +++ b/gateway/rest/client.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "encoding/pem" "fmt" "io" "net/http" @@ -337,8 +338,18 @@ func (c *client) verifyPeerCertificate(certificates [][]byte, _ [][]*x509.Certif return errors.New("tls: attempt to use non-existing or revoked certificate") } + block, rest := pem.Decode(resp.Certificates[0].Certificate.Cert) + if len(rest) > 0 { + return fmt.Errorf("%w: tls: failed to decode onchain certificate", err) + } + + onchainCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("%w: tls: failed to parse onchain certificate", err) + } + certPool := x509.NewCertPool() - certPool.AddCert(cert) + certPool.AddCert(onchainCert) opts := x509.VerifyOptions{ DNSName: c.host.Hostname(), diff --git a/gateway/utils/utils.go b/gateway/utils/utils.go index e030e4c6..34c7b73f 100644 --- a/gateway/utils/utils.go +++ b/gateway/utils/utils.go @@ -4,10 +4,10 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/pem" + "fmt" "time" - "github.com/pkg/errors" - sdk "github.com/cosmos/cosmos-sdk/types" ctypes "github.com/akash-network/akash-api/go/node/cert/v1beta3" @@ -24,28 +24,28 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty VerifyPeerCertificate: func(certificates [][]byte, _ [][]*x509.Certificate) error { if len(certificates) > 0 { if len(certificates) != 1 { - return errors.Errorf("tls: invalid certificate chain") + return fmt.Errorf("tls: invalid certificate chain") } cert, err := x509.ParseCertificate(certificates[0]) if err != nil { - return errors.Wrap(err, "tls: failed to parse certificate") + return fmt.Errorf("%w: tls: failed to parse certificate", err) } // validation var owner sdk.Address if owner, err = sdk.AccAddressFromBech32(cert.Subject.CommonName); err != nil { - return errors.Wrap(err, "tls: invalid certificate's subject common name") + return fmt.Errorf("%w: tls: invalid certificate's subject common name", err) } // 1. CommonName in issuer and Subject must match and be as Bech32 format if cert.Subject.CommonName != cert.Issuer.CommonName { - return errors.Wrap(err, "tls: invalid certificate's issuer common name") + return fmt.Errorf("%w: tls: invalid certificate's issuer common name", err) } // 2. serial number must be in if cert.SerialNumber == nil { - return errors.Wrap(err, "tls: invalid certificate serial number") + return fmt.Errorf("%w: tls: invalid certificate serial number", err) } // 3. look up certificate on chain @@ -61,14 +61,24 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty }, ) if err != nil { - return errors.Wrap(err, "tls: unable to fetch certificate from chain") + return fmt.Errorf("%w: tls: unable to fetch certificate from chain", err) } if (len(resp.Certificates) != 1) || !resp.Certificates[0].Certificate.IsState(ctypes.CertificateValid) { - return errors.New("tls: attempt to use non-existing or revoked certificate") + return fmt.Errorf("%w tls: attempt to use non-existing or revoked certificate", err) + } + + block, rest := pem.Decode(resp.Certificates[0].Certificate.Cert) + if len(rest) > 0 { + return fmt.Errorf("%w: tls: failed to decode onchain certificate", err) + } + + onchainCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("%w: tls: failed to parse onchain certificate", err) } clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(cert) + clientCertPool.AddCert(onchainCert) opts := x509.VerifyOptions{ Roots: clientCertPool, @@ -78,7 +88,7 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty } if _, err = cert.Verify(opts); err != nil { - return errors.Wrap(err, "tls: unable to verify certificate") + return fmt.Errorf("%w: tls: unable to verify certificate", err) } } return nil