Skip to content

Commit

Permalink
fix: PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Hare committed Mar 21, 2024
1 parent db28180 commit 75e6c33
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 35 deletions.
44 changes: 22 additions & 22 deletions gateway/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,32 @@ import (
ctypes "github.com/akash-network/akash-api/go/node/cert/v1beta3"
leasev1 "github.com/akash-network/akash-api/go/provider/lease/v1"
providerv1 "github.com/akash-network/akash-api/go/provider/v1"
cmblog "github.com/tendermint/tendermint/libs/log"

"github.com/akash-network/provider"
"github.com/akash-network/provider/gateway/utils"
"github.com/akash-network/provider/tools/fromctx"
)

var (
_ providerv1.ProviderRPCServer = (*server)(nil)
_ leasev1.LeaseRPCServer = (*server)(nil)
)

type server struct {
*providerV1
*leaseV1
}

func Serve(ctx context.Context, endpoint string, certs []tls.Certificate, c provider.Client) error {
group, err := fromctx.ErrGroupFromCtx(ctx)
if err != nil {
return err
}

var (
grpcSrv = newServer(ctx, certs, c)
log = fromctx.LogcFromCtx(ctx)
)
grpcSrv := newServer(ctx, certs, c)

log := fromctx.LogcFromCtx(ctx)

group.Go(func() error {
grpcLis, err := net.Listen("tcp", endpoint)
Expand All @@ -56,16 +66,6 @@ func Serve(ctx context.Context, endpoint string, certs []tls.Certificate, c prov
return nil
}

var (
_ providerv1.ProviderRPCServer = (*server)(nil)
_ leasev1.LeaseRPCServer = (*server)(nil)
)

type server struct {
*providerV1
*leaseV1
}

func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client) *grpc.Server {
// InsecureSkipVerify is set to true due to inability to use normal TLS verification
// certificate validation and authentication performed later in mtlsHandler
Expand All @@ -88,7 +88,7 @@ func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client)
}),
grpc.ChainUnaryInterceptor(
mtlsInterceptor(cquery),
errorLogInterceptor(),
errorLogInterceptor(fromctx.LogcFromCtx(ctx)),
),
)

Expand All @@ -111,10 +111,10 @@ func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client)
}

func mtlsInterceptor(cquery ctypes.QueryClient) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, h grpc.UnaryHandler) (any, error) {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, next grpc.UnaryHandler) (any, error) {
if p, ok := peer.FromContext(ctx); ok {
if mtls, ok := p.AuthInfo.(credentials.TLSInfo); ok {
owner, err := utils.VerifyCertChain(ctx, mtls.State.PeerCertificates, "", x509.ExtKeyUsageServerAuth, cquery)
owner, err := utils.VerifyOwnerCert(ctx, mtls.State.PeerCertificates, "", x509.ExtKeyUsageServerAuth, cquery)
if err != nil {
return nil, fmt.Errorf("verify cert chain: %w", err)
}
Expand All @@ -125,18 +125,18 @@ func mtlsInterceptor(cquery ctypes.QueryClient) grpc.UnaryServerInterceptor {
}
}

return h(ctx, req)
return next(ctx, req)
}
}

// TODO(andrewhare): Possibly replace this with
// https://github.com/grpc-ecosystem/go-grpc-middleware/tree/main/interceptors/logging
// to get full request/response logging?
func errorLogInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, i *grpc.UnaryServerInfo, h grpc.UnaryHandler) (any, error) {
resp, err := h(ctx, req)
func errorLogInterceptor(l cmblog.Logger) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, i *grpc.UnaryServerInfo, next grpc.UnaryHandler) (any, error) {
resp, err := next(ctx, req)
if err != nil {
fromctx.LogcFromCtx(ctx).Error(i.FullMethod, "err", err)
l.Error(i.FullMethod, "err", err)
}

return resp, err
Expand Down
2 changes: 1 addition & 1 deletion gateway/rest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func (c *client) verifyPeerCertificate(certificates [][]byte, _ [][]*x509.Certif
return errors.Errorf("tls: invalid certificate chain")
}

prov, err := utils.VerifyCertChain(
prov, err := utils.VerifyOwnerCert(
context.Background(),
certificates,
c.host.Hostname(),
Expand Down
24 changes: 12 additions & 12 deletions gateway/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty
InsecureSkipVerify: true, // nolint: gosec
MinVersion: tls.VersionTLS13,
VerifyPeerCertificate: func(certificates [][]byte, _ [][]*x509.Certificate) error {
if _, err := VerifyCertChain(ctx, certificates, "", x509.ExtKeyUsageClientAuth, cquery); err != nil {
if _, err := VerifyOwnerCert(ctx, certificates, "", x509.ExtKeyUsageClientAuth, cquery); err != nil {
return err
}
return nil
Expand All @@ -33,11 +33,11 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty
return cfg, nil
}

type certChain interface {
type cert interface {
*x509.Certificate | []byte
}

func VerifyCertChain[T certChain](
func VerifyOwnerCert[T cert](
ctx context.Context,
chain []T,
dnsName string,
Expand All @@ -52,31 +52,31 @@ func VerifyCertChain[T certChain](
return nil, errors.Errorf("tls: invalid certificate chain")
}

var cert *x509.Certificate
var c *x509.Certificate

switch t := any(chain).(type) {
case []*x509.Certificate:
cert = t[0]
c = t[0]
case [][]byte:
var err error
if cert, err = x509.ParseCertificate(t[0]); err != nil {
if c, err = x509.ParseCertificate(t[0]); err != nil {
return nil, fmt.Errorf("tls: failed to parse certificate: %w", err)
}
}

// validation
owner, err := sdk.AccAddressFromBech32(cert.Subject.CommonName)
owner, err := sdk.AccAddressFromBech32(c.Subject.CommonName)
if err != nil {
return nil, fmt.Errorf("tls: invalid certificate's subject common name: %w", err)
}

// 1. CommonName in issuer and Subject must match and be as Bech32 format
if cert.Subject.CommonName != cert.Issuer.CommonName {
if c.Subject.CommonName != c.Issuer.CommonName {
return nil, fmt.Errorf("tls: invalid certificate's issuer common name: %w", err)
}

// 2. serial number must be in
if cert.SerialNumber == nil {
if c.SerialNumber == nil {
return nil, fmt.Errorf("tls: invalid certificate serial number: %w", err)
}

Expand All @@ -87,7 +87,7 @@ func VerifyCertChain[T certChain](
&ctypes.QueryCertificatesRequest{
Filter: ctypes.CertificateFilter{
Owner: owner.String(),
Serial: cert.SerialNumber.String(),
Serial: c.SerialNumber.String(),
State: "valid",
},
},
Expand All @@ -100,7 +100,7 @@ func VerifyCertChain[T certChain](
}

clientCertPool := x509.NewCertPool()
clientCertPool.AddCert(cert)
clientCertPool.AddCert(c)

opts := x509.VerifyOptions{
DNSName: dnsName,
Expand All @@ -110,7 +110,7 @@ func VerifyCertChain[T certChain](
MaxConstraintComparisions: 0,
}

if _, err = cert.Verify(opts); err != nil {
if _, err = c.Verify(opts); err != nil {
return nil, fmt.Errorf("tls: unable to verify certificate: %w", err)
}

Expand Down

0 comments on commit 75e6c33

Please sign in to comment.