From a85733a5fe0c7d6ee1b25f5d627513fbb2c268b3 Mon Sep 17 00:00:00 2001 From: Andrew Hare Date: Fri, 5 Apr 2024 18:41:27 -0600 Subject: [PATCH] fix: Misc TLS issues --- cmd/provider-services/cmd/grpc.go | 10 +- cmd/provider-services/cmd/manifest.go | 4 +- gateway/grpc/client.go | 1 + gateway/grpc/server.go | 2 +- gateway/grpc/server_test.go | 136 ++++++++++++++++++++++++++ gateway/utils/utils.go | 2 +- 6 files changed, 149 insertions(+), 6 deletions(-) diff --git a/cmd/provider-services/cmd/grpc.go b/cmd/provider-services/cmd/grpc.go index c3a62e8c..f88cc8e2 100644 --- a/cmd/provider-services/cmd/grpc.go +++ b/cmd/provider-services/cmd/grpc.go @@ -3,13 +3,19 @@ package cmd import ( "fmt" "net" + "net/url" ) func grpcURI(hostURI string) (string, error) { - host, _, err := net.SplitHostPort(hostURI) + u, err := url.Parse(hostURI) + if err != nil { + return "", fmt.Errorf("url parse: %w", err) + } + + h, _, err := net.SplitHostPort(u.Host) if err != nil { return "", fmt.Errorf("split host port: %w", err) } - return net.JoinHostPort(host, "8442"), nil + return net.JoinHostPort(h, "8444"), nil } diff --git a/cmd/provider-services/cmd/manifest.go b/cmd/provider-services/cmd/manifest.go index 495a4575..a0b41d73 100644 --- a/cmd/provider-services/cmd/manifest.go +++ b/cmd/provider-services/cmd/manifest.go @@ -103,12 +103,12 @@ func doSendManifest(cmd *cobra.Command, sdlpath string) error { ) for i, lid := range leases { - err := func() error { + err = func() error { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() provAddr, _ := sdk.AccAddressFromBech32(lid.Provider) - prov, err := cl.Provider(context.Background(), &ptypes.QueryProviderRequest{Owner: provAddr.String()}) + prov, err := cl.Provider(ctx, &ptypes.QueryProviderRequest{Owner: provAddr.String()}) if err != nil { return fmt.Errorf("query client provider: %w", err) } diff --git a/gateway/grpc/client.go b/gateway/grpc/client.go index 8da2ba21..af657f3d 100644 --- a/gateway/grpc/client.go +++ b/gateway/grpc/client.go @@ -31,6 +31,7 @@ func NewClient(ctx context.Context, addr string, cert tls.Certificate, cquery ct tlsConfig := tls.Config{ InsecureSkipVerify: true, Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS13, VerifyPeerCertificate: func(certificates [][]byte, _ [][]*x509.Certificate) error { if _, err := utils.VerifyOwnerCertBytes(ctx, certificates, "", x509.ExtKeyUsageClientAuth, cquery); err != nil { return err diff --git a/gateway/grpc/server.go b/gateway/grpc/server.go index 9bbfc5d7..e11f8181 100644 --- a/gateway/grpc/server.go +++ b/gateway/grpc/server.go @@ -114,7 +114,7 @@ func mtlsInterceptor(cquery ctypes.QueryClient) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, next grpc.UnaryHandler) (any, error) { if p, ok := peer.FromContext(ctx); ok { if mtls, ok := p.AuthInfo.(credentials.TLSInfo); ok { - owner, err := utils.VerifyOwnerCert(ctx, mtls.State.PeerCertificates, "", x509.ExtKeyUsageServerAuth, cquery) + owner, err := utils.VerifyOwnerCert(ctx, mtls.State.PeerCertificates, "", x509.ExtKeyUsageClientAuth, cquery) if err != nil { return nil, fmt.Errorf("verify cert chain: %w", err) } diff --git a/gateway/grpc/server_test.go b/gateway/grpc/server_test.go index 5bdbac5e..e516e2bb 100644 --- a/gateway/grpc/server_test.go +++ b/gateway/grpc/server_test.go @@ -2,10 +2,18 @@ package grpc import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" + "math/big" "net" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -209,3 +217,131 @@ func TestRPCs(t *testing.T) { }) } } + +func TestMTLS(t *testing.T) { + var ( + qclient = &qmock.QueryClient{} + com = testutil.CertificateOptionMocks(qclient) + cod = testutil.CertificateOptionDomains([]string{"localhost", "127.0.0.1"}) + ) + + crt := testutil.Certificate(t, testutil.AccAddress(t), com, cod) + + qclient.EXPECT().Certificates(mock.Anything, mock.Anything).Return(&types.QueryCertificatesResponse{ + Certificates: types.CertificatesResponse{ + types.CertificateResponse{ + Certificate: types.Certificate{ + State: types.CertificateValid, + Cert: crt.PEM.Cert, + Pubkey: crt.PEM.Pub, + }, + Serial: crt.Serial.String(), + }, + }, + }, nil) + + cases := []struct { + desc string + cert func(*testing.T) tls.Certificate + errContains string + }{ + { + desc: "good cert", + cert: func(*testing.T) tls.Certificate { + return testutil.Certificate(t, testutil.AccAddress(t), com, cod).Cert[0] + }, + }, + { + desc: "empty chain", + cert: func(*testing.T) tls.Certificate { + return tls.Certificate{} + }, + errContains: "empty chain", + }, + { + desc: "invalid subject", + cert: func(t *testing.T) tls.Certificate { + t.Helper() + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: new(big.Int).SetInt64(time.Now().UTC().UnixNano()), + Subject: pkix.Name{ + CommonName: "badcert", + }, + BasicConstraintsValid: true, + } + + certDer, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv) + require.NoError(t, err) + + keyDer, err := x509.MarshalPKCS8PrivateKey(priv) + require.NoError(t, err) + + certBytes := pem.EncodeToMemory(&pem.Block{ + Type: types.PemBlkTypeCertificate, + Bytes: certDer, + }) + privBytes := pem.EncodeToMemory(&pem.Block{ + Type: types.PemBlkTypeECPrivateKey, + Bytes: keyDer, + }) + + cert, err := tls.X509KeyPair(certBytes, privBytes) + require.NoError(t, err) + + return cert + }, + errContains: "invalid certificate's subject", + }, + } + + for _, c := range cases { + c := c + + t.Run(c.desc, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ctx = ContextWithQueryClient(ctx, qclient) + + var ( + m mocks.Client + mc mmocks.Client + ) + + mc.EXPECT().Submit(mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.EXPECT().Manifest().Return(&mc) + + s := newServer(ctx, crt.Cert, &m) + defer s.Stop() + + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + go func() { + require.NoError(t, s.Serve(l)) + }() + + tlsConfig := tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{c.cert(t)}, + } + + conn, err := grpc.DialContext(ctx, l.Addr().String(), + grpc.WithTransportCredentials(credentials.NewTLS(&tlsConfig))) + require.NoError(t, err) + + defer conn.Close() + + _, err = leasev1.NewLeaseRPCClient(conn).SendManifest(ctx, &leasev1.SendManifestRequest{}) + if c.errContains != "" { + assert.ErrorContains(t, err, c.errContains) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/gateway/utils/utils.go b/gateway/utils/utils.go index 7abe3831..596e548e 100644 --- a/gateway/utils/utils.go +++ b/gateway/utils/utils.go @@ -52,7 +52,7 @@ func VerifyOwnerCertBytes(ctx context.Context, chain [][]byte, dnsName string, u func VerifyOwnerCert(ctx context.Context, chain []*x509.Certificate, dnsName string, usage x509.ExtKeyUsage, cquery ctypes.QueryClient) (sdk.Address, error) { if len(chain) == 0 { - return nil, nil + return nil, errors.Errorf("tls: empty chain") } if len(chain) > 1 {