Skip to content

Commit

Permalink
credentials/test: remove ALPN table driven tests
Browse files Browse the repository at this point in the history
now that ALPN is always enforced,
the table driven tests that support
this can be removed.

most of the changes are whitespace,
however an assertion has been modified
so that the tests will fail if an expected
error is not returned
  • Loading branch information
pnikonowicz committed Dec 10, 2024
1 parent 8bb8c54 commit 19bbbc1
Showing 1 changed file with 85 additions and 117 deletions.
202 changes: 85 additions & 117 deletions credentials/tls_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,136 +445,104 @@ func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
// connecting to a server that doesn't support ALPN.
func (s) TestTLS_DisabledALPNClient(t *testing.T) {
tests := []struct {
name string
alpnEnforced bool
wantErr bool
}{
{
name: "enforced",
alpnEnforced: true,
wantErr: true,
},
listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
})
if err != nil {
t.Fatalf("Error starting TLS server: %v", err)
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
})
if err != nil {
t.Fatalf("Error starting TLS server: %v", err)
}

errCh := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
} else {
// The first write to the TLS listener initiates the TLS handshake.
conn.Write([]byte("Hello, World!"))
conn.Close()
}
close(errCh)
}()

serverAddr := listener.Addr().String()
conn, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
}
defer conn.Close()
errCh := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
} else {
// The first write to the TLS listener initiates the TLS handshake.
conn.Write([]byte("Hello, World!"))
conn.Close()
}
close(errCh)
}()

serverAddr := listener.Addr().String()
conn, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
}
defer conn.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

clientCfg := tls.Config{
ServerName: serverName,
RootCAs: certPool,
NextProtos: []string{"h2"},
}
_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
clientCfg := tls.Config{
ServerName: serverName,
RootCAs: certPool,
NextProtos: []string{"h2"},
}
_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)

if gotErr := (err != nil); gotErr != tc.wantErr {
t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
}
if err == nil {
t.Errorf("expected ClientHandshake to return an error but did not")
}

select {
case err := <-errCh:
if err != nil {
t.Fatalf("Unexpected error received from server: %v", err)
}
case <-ctx.Done():
t.Fatalf("Timeout waiting for error from server")
}
})
select {
case err := <-errCh:
if err != nil {
t.Fatalf("Unexpected error received from server: %v", err)
}
case <-ctx.Done():
t.Fatalf("Timeout waiting for error from server")
}
}

// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
// accepting a request from a client that doesn't support ALPN.
func (s) TestTLS_DisabledALPNServer(t *testing.T) {
tests := []struct {
name string
alpnEnforced bool
wantErr bool
}{
{
name: "enforced",
alpnEnforced: true,
wantErr: true,
},
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error starting server: %v", err)
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error starting server: %v", err)
}

errCh := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
return
}
defer conn.Close()
serverCfg := tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{"h2"},
}
_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
if gotErr := (err != nil); gotErr != tc.wantErr {
t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
}
close(errCh)
}()

serverAddr := listener.Addr().String()
clientCfg := &tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
RootCAs: certPool,
ServerName: serverName,
}
conn, err := tls.Dial("tcp", serverAddr, clientCfg)
if err != nil {
t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
}
defer conn.Close()

select {
case <-time.After(defaultTestTimeout):
t.Fatal("Timed out waiting for completion")
case err := <-errCh:
if err != nil {
t.Fatalf("Unexpected server error: %v", err)
}
}
})
errCh := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
return
}
defer conn.Close()
serverCfg := tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{"h2"},
}
_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
if err == nil {
t.Errorf("expected ServerHandshake to return an error but")
}
close(errCh)
}()

serverAddr := listener.Addr().String()
clientCfg := &tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
RootCAs: certPool,
ServerName: serverName,
}
conn, err := tls.Dial("tcp", serverAddr, clientCfg)
if err != nil {
t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
}
defer conn.Close()

select {
case <-time.After(defaultTestTimeout):
t.Fatal("Timed out waiting for completion")
case err := <-errCh:
if err != nil {
t.Fatalf("Unexpected server error: %v", err)
}
}
}

0 comments on commit 19bbbc1

Please sign in to comment.