diff --git a/gateway/grpc/server.go b/gateway/grpc/server.go index 038e0ee9..9bbfc5d7 100644 --- a/gateway/grpc/server.go +++ b/gateway/grpc/server.go @@ -17,22 +17,32 @@ import ( ctypes "github.com/akash-network/akash-api/go/node/cert/v1beta3" leasev1 "github.com/akash-network/akash-api/go/provider/lease/v1" providerv1 "github.com/akash-network/akash-api/go/provider/v1" + cmblog "github.com/tendermint/tendermint/libs/log" "github.com/akash-network/provider" "github.com/akash-network/provider/gateway/utils" "github.com/akash-network/provider/tools/fromctx" ) +var ( + _ providerv1.ProviderRPCServer = (*server)(nil) + _ leasev1.LeaseRPCServer = (*server)(nil) +) + +type server struct { + *providerV1 + *leaseV1 +} + func Serve(ctx context.Context, endpoint string, certs []tls.Certificate, c provider.Client) error { group, err := fromctx.ErrGroupFromCtx(ctx) if err != nil { return err } - var ( - grpcSrv = newServer(ctx, certs, c) - log = fromctx.LogcFromCtx(ctx) - ) + grpcSrv := newServer(ctx, certs, c) + + log := fromctx.LogcFromCtx(ctx) group.Go(func() error { grpcLis, err := net.Listen("tcp", endpoint) @@ -56,16 +66,6 @@ func Serve(ctx context.Context, endpoint string, certs []tls.Certificate, c prov return nil } -var ( - _ providerv1.ProviderRPCServer = (*server)(nil) - _ leasev1.LeaseRPCServer = (*server)(nil) -) - -type server struct { - *providerV1 - *leaseV1 -} - func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client) *grpc.Server { // InsecureSkipVerify is set to true due to inability to use normal TLS verification // certificate validation and authentication performed later in mtlsHandler @@ -88,7 +88,7 @@ func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client) }), grpc.ChainUnaryInterceptor( mtlsInterceptor(cquery), - errorLogInterceptor(), + errorLogInterceptor(fromctx.LogcFromCtx(ctx)), ), ) @@ -111,10 +111,10 @@ func newServer(ctx context.Context, certs []tls.Certificate, c provider.Client) } func mtlsInterceptor(cquery ctypes.QueryClient) grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, h grpc.UnaryHandler) (any, error) { + return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, next grpc.UnaryHandler) (any, error) { if p, ok := peer.FromContext(ctx); ok { if mtls, ok := p.AuthInfo.(credentials.TLSInfo); ok { - owner, err := utils.VerifyCertChain(ctx, mtls.State.PeerCertificates, "", x509.ExtKeyUsageServerAuth, cquery) + owner, err := utils.VerifyOwnerCert(ctx, mtls.State.PeerCertificates, "", x509.ExtKeyUsageServerAuth, cquery) if err != nil { return nil, fmt.Errorf("verify cert chain: %w", err) } @@ -125,18 +125,18 @@ func mtlsInterceptor(cquery ctypes.QueryClient) grpc.UnaryServerInterceptor { } } - return h(ctx, req) + return next(ctx, req) } } // TODO(andrewhare): Possibly replace this with // https://github.com/grpc-ecosystem/go-grpc-middleware/tree/main/interceptors/logging // to get full request/response logging? -func errorLogInterceptor() grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, i *grpc.UnaryServerInfo, h grpc.UnaryHandler) (any, error) { - resp, err := h(ctx, req) +func errorLogInterceptor(l cmblog.Logger) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, i *grpc.UnaryServerInfo, next grpc.UnaryHandler) (any, error) { + resp, err := next(ctx, req) if err != nil { - fromctx.LogcFromCtx(ctx).Error(i.FullMethod, "err", err) + l.Error(i.FullMethod, "err", err) } return resp, err diff --git a/gateway/rest/client.go b/gateway/rest/client.go index c3ec388f..25c46e74 100644 --- a/gateway/rest/client.go +++ b/gateway/rest/client.go @@ -296,7 +296,7 @@ func (c *client) verifyPeerCertificate(certificates [][]byte, _ [][]*x509.Certif return errors.Errorf("tls: invalid certificate chain") } - prov, err := utils.VerifyCertChain( + prov, err := utils.VerifyOwnerCert( context.Background(), certificates, c.host.Hostname(), diff --git a/gateway/utils/utils.go b/gateway/utils/utils.go index bd852ee3..479c9bf8 100644 --- a/gateway/utils/utils.go +++ b/gateway/utils/utils.go @@ -23,7 +23,7 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty InsecureSkipVerify: true, // nolint: gosec MinVersion: tls.VersionTLS13, VerifyPeerCertificate: func(certificates [][]byte, _ [][]*x509.Certificate) error { - if _, err := VerifyCertChain(ctx, certificates, "", x509.ExtKeyUsageClientAuth, cquery); err != nil { + if _, err := VerifyOwnerCert(ctx, certificates, "", x509.ExtKeyUsageClientAuth, cquery); err != nil { return err } return nil @@ -33,11 +33,11 @@ func NewServerTLSConfig(ctx context.Context, certs []tls.Certificate, cquery cty return cfg, nil } -type certChain interface { +type cert interface { *x509.Certificate | []byte } -func VerifyCertChain[T certChain]( +func VerifyOwnerCert[T cert]( ctx context.Context, chain []T, dnsName string, @@ -52,31 +52,31 @@ func VerifyCertChain[T certChain]( return nil, errors.Errorf("tls: invalid certificate chain") } - var cert *x509.Certificate + var c *x509.Certificate switch t := any(chain).(type) { case []*x509.Certificate: - cert = t[0] + c = t[0] case [][]byte: var err error - if cert, err = x509.ParseCertificate(t[0]); err != nil { + if c, err = x509.ParseCertificate(t[0]); err != nil { return nil, fmt.Errorf("tls: failed to parse certificate: %w", err) } } // validation - owner, err := sdk.AccAddressFromBech32(cert.Subject.CommonName) + owner, err := sdk.AccAddressFromBech32(c.Subject.CommonName) if err != nil { return nil, fmt.Errorf("tls: invalid certificate's subject common name: %w", err) } // 1. CommonName in issuer and Subject must match and be as Bech32 format - if cert.Subject.CommonName != cert.Issuer.CommonName { + if c.Subject.CommonName != c.Issuer.CommonName { return nil, fmt.Errorf("tls: invalid certificate's issuer common name: %w", err) } // 2. serial number must be in - if cert.SerialNumber == nil { + if c.SerialNumber == nil { return nil, fmt.Errorf("tls: invalid certificate serial number: %w", err) } @@ -87,7 +87,7 @@ func VerifyCertChain[T certChain]( &ctypes.QueryCertificatesRequest{ Filter: ctypes.CertificateFilter{ Owner: owner.String(), - Serial: cert.SerialNumber.String(), + Serial: c.SerialNumber.String(), State: "valid", }, }, @@ -100,7 +100,7 @@ func VerifyCertChain[T certChain]( } clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(cert) + clientCertPool.AddCert(c) opts := x509.VerifyOptions{ DNSName: dnsName, @@ -110,7 +110,7 @@ func VerifyCertChain[T certChain]( MaxConstraintComparisions: 0, } - if _, err = cert.Verify(opts); err != nil { + if _, err = c.Verify(opts); err != nil { return nil, fmt.Errorf("tls: unable to verify certificate: %w", err) }