diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go index cd3b9348bab0..bdcc754166b7 100644 --- a/credentials/tls_ext_test.go +++ b/credentials/tls_ext_test.go @@ -445,136 +445,104 @@ func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) { // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when // connecting to a server that doesn't support ALPN. func (s) TestTLS_DisabledALPNClient(t *testing.T) { - tests := []struct { - name string - alpnEnforced bool - wantErr bool - }{ - { - name: "enforced", - alpnEnforced: true, - wantErr: true, - }, + listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + NextProtos: []string{}, // Empty list indicates ALPN is disabled. + }) + if err != nil { + t.Fatalf("Error starting TLS server: %v", err) } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - NextProtos: []string{}, // Empty list indicates ALPN is disabled. - }) - if err != nil { - t.Fatalf("Error starting TLS server: %v", err) - } - - errCh := make(chan error, 1) - go func() { - conn, err := listener.Accept() - if err != nil { - errCh <- fmt.Errorf("listener.Accept returned error: %v", err) - } else { - // The first write to the TLS listener initiates the TLS handshake. - conn.Write([]byte("Hello, World!")) - conn.Close() - } - close(errCh) - }() - - serverAddr := listener.Addr().String() - conn, err := net.Dial("tcp", serverAddr) - if err != nil { - t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err) - } - defer conn.Close() + errCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + errCh <- fmt.Errorf("listener.Accept returned error: %v", err) + } else { + // The first write to the TLS listener initiates the TLS handshake. + conn.Write([]byte("Hello, World!")) + conn.Close() + } + close(errCh) + }() + + serverAddr := listener.Addr().String() + conn, err := net.Dial("tcp", serverAddr) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err) + } + defer conn.Close() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() - clientCfg := tls.Config{ - ServerName: serverName, - RootCAs: certPool, - NextProtos: []string{"h2"}, - } - _, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn) + clientCfg := tls.Config{ + ServerName: serverName, + RootCAs: certPool, + NextProtos: []string{"h2"}, + } + _, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn) - if gotErr := (err != nil); gotErr != tc.wantErr { - t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) - } + if err == nil { + t.Errorf("expected ClientHandshake to return an error but did not") + } - select { - case err := <-errCh: - if err != nil { - t.Fatalf("Unexpected error received from server: %v", err) - } - case <-ctx.Done(): - t.Fatalf("Timeout waiting for error from server") - } - }) + select { + case err := <-errCh: + if err != nil { + t.Fatalf("Unexpected error received from server: %v", err) + } + case <-ctx.Done(): + t.Fatalf("Timeout waiting for error from server") } } // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when // accepting a request from a client that doesn't support ALPN. func (s) TestTLS_DisabledALPNServer(t *testing.T) { - tests := []struct { - name string - alpnEnforced bool - wantErr bool - }{ - { - name: "enforced", - alpnEnforced: true, - wantErr: true, - }, + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error starting server: %v", err) } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - listener, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Error starting server: %v", err) - } - - errCh := make(chan error, 1) - go func() { - conn, err := listener.Accept() - if err != nil { - errCh <- fmt.Errorf("listener.Accept returned error: %v", err) - return - } - defer conn.Close() - serverCfg := tls.Config{ - Certificates: []tls.Certificate{serverCert}, - NextProtos: []string{"h2"}, - } - _, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn) - if gotErr := (err != nil); gotErr != tc.wantErr { - t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) - } - close(errCh) - }() - - serverAddr := listener.Addr().String() - clientCfg := &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - NextProtos: []string{}, // Empty list indicates ALPN is disabled. - RootCAs: certPool, - ServerName: serverName, - } - conn, err := tls.Dial("tcp", serverAddr, clientCfg) - if err != nil { - t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err) - } - defer conn.Close() - - select { - case <-time.After(defaultTestTimeout): - t.Fatal("Timed out waiting for completion") - case err := <-errCh: - if err != nil { - t.Fatalf("Unexpected server error: %v", err) - } - } - }) + errCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + errCh <- fmt.Errorf("listener.Accept returned error: %v", err) + return + } + defer conn.Close() + serverCfg := tls.Config{ + Certificates: []tls.Certificate{serverCert}, + NextProtos: []string{"h2"}, + } + _, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn) + if err == nil { + t.Errorf("expected ServerHandshake to return an error but") + } + close(errCh) + }() + + serverAddr := listener.Addr().String() + clientCfg := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + NextProtos: []string{}, // Empty list indicates ALPN is disabled. + RootCAs: certPool, + ServerName: serverName, + } + conn, err := tls.Dial("tcp", serverAddr, clientCfg) + if err != nil { + t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err) + } + defer conn.Close() + + select { + case <-time.After(defaultTestTimeout): + t.Fatal("Timed out waiting for completion") + case err := <-errCh: + if err != nil { + t.Fatalf("Unexpected server error: %v", err) + } } }