From 80937a99d53cf90edd42db604068cd6030bf0398 Mon Sep 17 00:00:00 2001
From: Arjan Singh Bal <46515553+arjan-bal@users.noreply.github.com>
Date: Tue, 22 Oct 2024 22:58:16 +0530
Subject: [PATCH] credentials: Apply defaults to TLS configs provided through
 GetConfigForClient (#7754)

---
 credentials/tls.go          |  29 ++-
 credentials/tls_ext_test.go | 394 +++++++++++++++++++++++++++---------
 2 files changed, 324 insertions(+), 99 deletions(-)

diff --git a/credentials/tls.go b/credentials/tls.go
index 4114358545ef..e163a473df93 100644
--- a/credentials/tls.go
+++ b/credentials/tls.go
@@ -200,25 +200,40 @@ var tls12ForbiddenCipherSuites = map[uint16]struct{}{
 
 // NewTLS uses c to construct a TransportCredentials based on TLS.
 func NewTLS(c *tls.Config) TransportCredentials {
-	tc := &tlsCreds{credinternal.CloneTLSConfig(c)}
-	tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
+	config := applyDefaults(c)
+	if config.GetConfigForClient != nil {
+		oldFn := config.GetConfigForClient
+		config.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
+			cfgForClient, err := oldFn(hello)
+			if err != nil || cfgForClient == nil {
+				return cfgForClient, err
+			}
+			return applyDefaults(cfgForClient), nil
+		}
+	}
+	return &tlsCreds{config: config}
+}
+
+func applyDefaults(c *tls.Config) *tls.Config {
+	config := credinternal.CloneTLSConfig(c)
+	config.NextProtos = credinternal.AppendH2ToNextProtos(config.NextProtos)
 	// If the user did not configure a MinVersion and did not configure a
 	// MaxVersion < 1.2, use MinVersion=1.2, which is required by
 	// https://datatracker.ietf.org/doc/html/rfc7540#section-9.2
-	if tc.config.MinVersion == 0 && (tc.config.MaxVersion == 0 || tc.config.MaxVersion >= tls.VersionTLS12) {
-		tc.config.MinVersion = tls.VersionTLS12
+	if config.MinVersion == 0 && (config.MaxVersion == 0 || config.MaxVersion >= tls.VersionTLS12) {
+		config.MinVersion = tls.VersionTLS12
 	}
 	// If the user did not configure CipherSuites, use all "secure" cipher
 	// suites reported by the TLS package, but remove some explicitly forbidden
 	// by https://datatracker.ietf.org/doc/html/rfc7540#appendix-A
-	if tc.config.CipherSuites == nil {
+	if config.CipherSuites == nil {
 		for _, cs := range tls.CipherSuites() {
 			if _, ok := tls12ForbiddenCipherSuites[cs.ID]; !ok {
-				tc.config.CipherSuites = append(tc.config.CipherSuites, cs.ID)
+				config.CipherSuites = append(config.CipherSuites, cs.ID)
 			}
 		}
 	}
-	return tc
+	return config
 }
 
 // NewClientTLSFromCert constructs TLS credentials from the provided root
diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go
index c817777b2f89..22881a6f497a 100644
--- a/credentials/tls_ext_test.go
+++ b/credentials/tls_ext_test.go
@@ -79,43 +79,86 @@ func (s) TestTLS_MinVersion12(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
 
-	// Create server creds without a minimum version.
-	serverCreds := credentials.NewTLS(&tls.Config{
-		// MinVersion should be set to 1.2 by gRPC by default.
-		Certificates: []tls.Certificate{serverCert},
-	})
-	ss := stubserver.StubServer{
-		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
-			return &testpb.Empty{}, nil
+	testCases := []struct {
+		name      string
+		serverTLS func() *tls.Config
+	}{
+		{
+			name: "base_case",
+			serverTLS: func() *tls.Config {
+				return &tls.Config{
+					// MinVersion should be set to 1.2 by gRPC by default.
+					Certificates: []tls.Certificate{serverCert},
+				}
+			},
+		},
+		{
+			name: "fallback_to_base",
+			serverTLS: func() *tls.Config {
+				config := &tls.Config{
+					// MinVersion should be set to 1.2 by gRPC by default.
+					Certificates: []tls.Certificate{serverCert},
+				}
+				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
+					return nil, nil
+				}
+				return config
+			},
+		},
+		{
+			name: "dynamic_using_get_config_for_client",
+			serverTLS: func() *tls.Config {
+				return &tls.Config{
+					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
+						return &tls.Config{
+							// MinVersion should be set to 1.2 by gRPC by default.
+							Certificates: []tls.Certificate{serverCert},
+						}, nil
+					},
+				}
+			},
 		},
 	}
 
-	// Create client creds that supports V1.0-V1.1.
-	clientCreds := credentials.NewTLS(&tls.Config{
-		ServerName: serverName,
-		RootCAs:    certPool,
-		MinVersion: tls.VersionTLS10,
-		MaxVersion: tls.VersionTLS11,
-	})
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Create server creds without a minimum version.
+			serverCreds := credentials.NewTLS(tc.serverTLS())
+			ss := stubserver.StubServer{
+				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
+					return &testpb.Empty{}, nil
+				},
+			}
 
-	// Start server and client separately, because Start() blocks on a
-	// successful connection, which we will not get.
-	if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
-		t.Fatalf("Error starting server: %v", err)
-	}
-	defer ss.Stop()
+			// Create client creds that supports V1.0-V1.1.
+			clientCreds := credentials.NewTLS(&tls.Config{
+				ServerName: serverName,
+				RootCAs:    certPool,
+				MinVersion: tls.VersionTLS10,
+				MaxVersion: tls.VersionTLS11,
+			})
 
-	cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds))
-	if err != nil {
-		t.Fatalf("grpc.NewClient error: %v", err)
-	}
-	defer cc.Close()
+			// Start server and client separately, because Start() blocks on a
+			// successful connection, which we will not get.
+			if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
+				t.Fatalf("Error starting server: %v", err)
+			}
+			defer ss.Stop()
 
-	client := testgrpc.NewTestServiceClient(cc)
+			cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds))
+			if err != nil {
+				t.Fatalf("grpc.NewClient error: %v", err)
+			}
+			defer cc.Close()
 
-	const wantStr = "authentication handshake failed"
-	if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
-		t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
+			client := testgrpc.NewTestServiceClient(cc)
+
+			const wantStr = "authentication handshake failed"
+			if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
+				t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
+			}
+
+		})
 	}
 }
 
@@ -129,35 +172,78 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) {
 	for _, cs := range tls.CipherSuites() {
 		allCipherSuites = append(allCipherSuites, cs.ID)
 	}
-
-	// Create server creds that allow v1.0.
-	serverCreds := credentials.NewTLS(&tls.Config{
-		MinVersion:   tls.VersionTLS10,
-		Certificates: []tls.Certificate{serverCert},
-		CipherSuites: allCipherSuites,
-	})
-	ss := stubserver.StubServer{
-		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
-			return &testpb.Empty{}, nil
+	testCases := []struct {
+		name      string
+		serverTLS func() *tls.Config
+	}{
+		{
+			name: "base_case",
+			serverTLS: func() *tls.Config {
+				return &tls.Config{
+					MinVersion:   tls.VersionTLS10,
+					Certificates: []tls.Certificate{serverCert},
+					CipherSuites: allCipherSuites,
+				}
+			},
+		},
+		{
+			name: "fallback_to_base",
+			serverTLS: func() *tls.Config {
+				config := &tls.Config{
+					MinVersion:   tls.VersionTLS10,
+					Certificates: []tls.Certificate{serverCert},
+					CipherSuites: allCipherSuites,
+				}
+				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
+					return nil, nil
+				}
+				return config
+			},
+		},
+		{
+			name: "dynamic_using_get_config_for_client",
+			serverTLS: func() *tls.Config {
+				return &tls.Config{
+					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
+						return &tls.Config{
+							MinVersion:   tls.VersionTLS10,
+							Certificates: []tls.Certificate{serverCert},
+							CipherSuites: allCipherSuites,
+						}, nil
+					},
+				}
+			},
 		},
 	}
 
-	// Create client creds that supports V1.0-V1.1.
-	clientCreds := credentials.NewTLS(&tls.Config{
-		ServerName:   serverName,
-		RootCAs:      certPool,
-		CipherSuites: allCipherSuites,
-		MinVersion:   tls.VersionTLS10,
-		MaxVersion:   tls.VersionTLS11,
-	})
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Create server creds that allow v1.0.
+			serverCreds := credentials.NewTLS(tc.serverTLS())
+			ss := stubserver.StubServer{
+				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
+					return &testpb.Empty{}, nil
+				},
+			}
 
-	if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
-		t.Fatalf("Error starting stub server: %v", err)
-	}
-	defer ss.Stop()
+			// Create client creds that supports V1.0-V1.1.
+			clientCreds := credentials.NewTLS(&tls.Config{
+				ServerName:   serverName,
+				RootCAs:      certPool,
+				CipherSuites: allCipherSuites,
+				MinVersion:   tls.VersionTLS10,
+				MaxVersion:   tls.VersionTLS11,
+			})
 
-	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
-		t.Fatalf("EmptyCall err = %v; want <nil>", err)
+			if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
+				t.Fatalf("Error starting stub server: %v", err)
+			}
+			defer ss.Stop()
+
+			if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
+				t.Fatalf("EmptyCall err = %v; want <nil>", err)
+			}
+		})
 	}
 }
 
@@ -165,43 +251,82 @@ func (s) TestTLS_MinVersionOverridable(t *testing.T) {
 func (s) TestTLS_CipherSuites(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
-
-	// Create server creds without cipher suites.
-	serverCreds := credentials.NewTLS(&tls.Config{
-		Certificates: []tls.Certificate{serverCert},
-	})
-	ss := stubserver.StubServer{
-		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
-			return &testpb.Empty{}, nil
+	testCases := []struct {
+		name      string
+		serverTLS func() *tls.Config
+	}{
+		{
+			name: "base_case",
+			serverTLS: func() *tls.Config {
+				return &tls.Config{
+					Certificates: []tls.Certificate{serverCert},
+				}
+			},
+		},
+		{
+			name: "fallback_to_base",
+			serverTLS: func() *tls.Config {
+				config := &tls.Config{
+					Certificates: []tls.Certificate{serverCert},
+				}
+				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
+					return nil, nil
+				}
+				return config
+			},
+		},
+		{
+			name: "dynamic_using_get_config_for_client",
+			serverTLS: func() *tls.Config {
+				return &tls.Config{
+					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
+						return &tls.Config{
+							Certificates: []tls.Certificate{serverCert},
+						}, nil
+					},
+				}
+			},
 		},
 	}
 
-	// Create client creds that use a forbidden suite only.
-	clientCreds := credentials.NewTLS(&tls.Config{
-		ServerName:   serverName,
-		RootCAs:      certPool,
-		CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
-		MaxVersion:   tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
-	})
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Create server creds without cipher suites.
+			serverCreds := credentials.NewTLS(tc.serverTLS())
+			ss := stubserver.StubServer{
+				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
+					return &testpb.Empty{}, nil
+				},
+			}
 
-	// Start server and client separately, because Start() blocks on a
-	// successful connection, which we will not get.
-	if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
-		t.Fatalf("Error starting server: %v", err)
-	}
-	defer ss.Stop()
+			// Create client creds that use a forbidden suite only.
+			clientCreds := credentials.NewTLS(&tls.Config{
+				ServerName:   serverName,
+				RootCAs:      certPool,
+				CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
+				MaxVersion:   tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
+			})
 
-	cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds))
-	if err != nil {
-		t.Fatalf("grpc.NewClient error: %v", err)
-	}
-	defer cc.Close()
+			// Start server and client separately, because Start() blocks on a
+			// successful connection, which we will not get.
+			if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
+				t.Fatalf("Error starting server: %v", err)
+			}
+			defer ss.Stop()
 
-	client := testgrpc.NewTestServiceClient(cc)
+			cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds))
+			if err != nil {
+				t.Fatalf("grpc.NewClient error: %v", err)
+			}
+			defer cc.Close()
+
+			client := testgrpc.NewTestServiceClient(cc)
 
-	const wantStr = "authentication handshake failed"
-	if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
-		t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
+			const wantStr = "authentication handshake failed"
+			if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
+				t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
+			}
+		})
 	}
 }
 
@@ -210,23 +335,108 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
 
-	// Create server that allows only a forbidden cipher suite.
+	testCases := []struct {
+		name      string
+		serverTLS func() *tls.Config
+	}{
+		{
+			name: "base_case",
+			serverTLS: func() *tls.Config {
+				return &tls.Config{
+					Certificates: []tls.Certificate{serverCert},
+					CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
+				}
+			},
+		},
+		{
+			name: "fallback_to_base",
+			serverTLS: func() *tls.Config {
+				config := &tls.Config{
+					Certificates: []tls.Certificate{serverCert},
+					CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
+				}
+				config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
+					return nil, nil
+				}
+				return config
+			},
+		},
+		{
+			name: "dynamic_using_get_config_for_client",
+			serverTLS: func() *tls.Config {
+				return &tls.Config{
+					GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
+						return &tls.Config{
+							Certificates: []tls.Certificate{serverCert},
+							CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
+						}, nil
+					},
+				}
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Create server that allows only a forbidden cipher suite.
+			serverCreds := credentials.NewTLS(tc.serverTLS())
+			ss := stubserver.StubServer{
+				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
+					return &testpb.Empty{}, nil
+				},
+			}
+
+			// Create server that allows only a forbidden cipher suite.
+			clientCreds := credentials.NewTLS(&tls.Config{
+				ServerName:   serverName,
+				RootCAs:      certPool,
+				CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
+				MaxVersion:   tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
+			})
+
+			if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
+				t.Fatalf("Error starting stub server: %v", err)
+			}
+			defer ss.Stop()
+
+			if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
+				t.Fatalf("EmptyCall err = %v; want <nil>", err)
+			}
+		})
+	}
+}
+
+// TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configured
+// correctly for a server that doesn't specify the NextProtos field and uses
+// GetConfigForClient to provide the TLS config during the handshake.
+func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
+	initialVal := envconfig.EnforceALPNEnabled
+	defer func() {
+		envconfig.EnforceALPNEnabled = initialVal
+	}()
+	envconfig.EnforceALPNEnabled = true
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	// Create a server that doesn't set the NextProtos field.
 	serverCreds := credentials.NewTLS(&tls.Config{
-		Certificates: []tls.Certificate{serverCert},
-		CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
+		GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
+			return &tls.Config{
+				Certificates: []tls.Certificate{serverCert},
+			}, nil
+		},
 	})
+
 	ss := stubserver.StubServer{
 		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
 			return &testpb.Empty{}, nil
 		},
 	}
 
-	// Create server that allows only a forbidden cipher suite.
 	clientCreds := credentials.NewTLS(&tls.Config{
-		ServerName:   serverName,
-		RootCAs:      certPool,
-		CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
-		MaxVersion:   tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
+		ServerName: serverName,
+		RootCAs:    certPool,
 	})
 
 	if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {