-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added JWT validation for gRPC server streams, enhancing security for authenticated requests. - Introduced a new mock implementation for the JWTVerifier interface to facilitate testing. - **Bug Fixes** - Improved error handling for missing or invalid JWT tokens in both unary and stream interceptors. - **Tests** - Added comprehensive unit tests for the AuthInterceptor functionality, covering various scenarios and edge cases. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
- Loading branch information
Showing
5 changed files
with
406 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package server | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/xmtp/xmtpd/pkg/authn" | ||
"github.com/xmtp/xmtpd/pkg/constants" | ||
"go.uber.org/zap" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/metadata" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
// wrappedServerStream allows us to modify the context of the stream | ||
type wrappedServerStream struct { | ||
grpc.ServerStream | ||
ctx context.Context | ||
} | ||
|
||
func (w *wrappedServerStream) Context() context.Context { | ||
return w.ctx | ||
} | ||
|
||
// AuthInterceptor validates JWT tokens from other nodes | ||
type AuthInterceptor struct { | ||
verifier authn.JWTVerifier | ||
logger *zap.Logger | ||
} | ||
|
||
func NewAuthInterceptor(verifier authn.JWTVerifier, logger *zap.Logger) *AuthInterceptor { | ||
return &AuthInterceptor{ | ||
verifier: verifier, | ||
logger: logger, | ||
} | ||
} | ||
|
||
// extractToken gets the JWT token from the request metadata | ||
func extractToken(ctx context.Context) (string, error) { | ||
md, ok := metadata.FromIncomingContext(ctx) | ||
if !ok { | ||
return "", status.Error(codes.Unauthenticated, "missing metadata") | ||
} | ||
|
||
values := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME) | ||
if len(values) == 0 { | ||
return "", status.Error(codes.Unauthenticated, "missing auth token") | ||
} | ||
|
||
if len(values) > 1 { | ||
return "", status.Error(codes.Unauthenticated, "multiple auth tokens provided") | ||
} | ||
|
||
return values[0], nil | ||
} | ||
|
||
// Unary returns a grpc.UnaryServerInterceptor that validates JWT tokens | ||
func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor { | ||
return func( | ||
ctx context.Context, | ||
req interface{}, | ||
info *grpc.UnaryServerInfo, | ||
handler grpc.UnaryHandler, | ||
) (interface{}, error) { | ||
token, err := extractToken(ctx) | ||
if err != nil { | ||
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err)) | ||
return handler(ctx, req) | ||
} | ||
|
||
if err := i.verifier.Verify(token); err != nil { | ||
return nil, status.Errorf( | ||
codes.Unauthenticated, | ||
"invalid auth token: %v", | ||
err, | ||
) | ||
} | ||
ctx = context.WithValue(ctx, constants.VerifiedNodeRequestCtxKey{}, true) | ||
|
||
return handler(ctx, req) | ||
} | ||
} | ||
|
||
// Stream returns a grpc.StreamServerInterceptor that validates JWT tokens | ||
func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor { | ||
return func( | ||
srv interface{}, | ||
stream grpc.ServerStream, | ||
info *grpc.StreamServerInfo, | ||
handler grpc.StreamHandler, | ||
) error { | ||
token, err := extractToken(stream.Context()) | ||
if err != nil { | ||
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err)) | ||
return handler(srv, stream) | ||
} | ||
|
||
if err := i.verifier.Verify(token); err != nil { | ||
return status.Errorf( | ||
codes.Unauthenticated, | ||
"invalid auth token: %v", | ||
err, | ||
) | ||
} | ||
|
||
stream = &wrappedServerStream{ | ||
ServerStream: stream, | ||
ctx: context.WithValue( | ||
stream.Context(), | ||
constants.VerifiedNodeRequestCtxKey{}, | ||
true, | ||
), | ||
} | ||
|
||
return handler(srv, stream) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
package server | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/xmtp/xmtpd/pkg/constants" | ||
"github.com/xmtp/xmtpd/pkg/mocks/authn" | ||
"go.uber.org/zap/zaptest" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/metadata" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
func TestUnaryInterceptor(t *testing.T) { | ||
mockVerifier := authn.NewMockJWTVerifier(t) | ||
logger := zaptest.NewLogger(t) | ||
interceptor := NewAuthInterceptor(mockVerifier, logger) | ||
|
||
tests := []struct { | ||
name string | ||
setupContext func() context.Context | ||
setupVerifier func() | ||
wantError error | ||
wantVerifiedNode bool | ||
}{ | ||
{ | ||
name: "valid token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{ | ||
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token", | ||
}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() { | ||
mockVerifier.EXPECT().Verify("valid_token").Return(nil) | ||
}, | ||
wantError: nil, | ||
wantVerifiedNode: true, | ||
}, | ||
{ | ||
name: "missing metadata", | ||
setupContext: func() context.Context { | ||
return context.Background() | ||
}, | ||
setupVerifier: func() {}, | ||
wantError: nil, | ||
wantVerifiedNode: false, | ||
}, | ||
{ | ||
name: "missing token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() {}, | ||
wantError: nil, | ||
wantVerifiedNode: false, | ||
}, | ||
{ | ||
name: "invalid token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{ | ||
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token", | ||
}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() { | ||
mockVerifier.EXPECT(). | ||
Verify("invalid_token"). | ||
Return(errors.New("invalid signature")) | ||
}, | ||
wantError: status.Error( | ||
codes.Unauthenticated, | ||
"invalid auth token: invalid signature", | ||
), | ||
wantVerifiedNode: false, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
tt.setupVerifier() | ||
|
||
ctx := tt.setupContext() | ||
var handlerCtx context.Context | ||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||
handlerCtx = ctx | ||
return "ok", nil | ||
} | ||
|
||
_, err := interceptor.Unary()(ctx, nil, &grpc.UnaryServerInfo{}, handler) | ||
|
||
if tt.wantError != nil { | ||
require.Error(t, err) | ||
require.Equal(t, tt.wantError.Error(), err.Error()) | ||
} else { | ||
require.NoError(t, err) | ||
isVerified, hasContextValue := handlerCtx.Value(constants.VerifiedNodeRequestCtxKey{}).(bool) | ||
if tt.wantVerifiedNode { | ||
require.True(t, isVerified) | ||
} else { | ||
require.False(t, hasContextValue) | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestStreamInterceptor(t *testing.T) { | ||
mockVerifier := authn.NewMockJWTVerifier(t) | ||
logger := zaptest.NewLogger(t) | ||
interceptor := NewAuthInterceptor(mockVerifier, logger) | ||
|
||
tests := []struct { | ||
name string | ||
setupContext func() context.Context | ||
setupVerifier func() | ||
wantError error | ||
wantVerifiedNode bool | ||
}{ | ||
{ | ||
name: "valid token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{ | ||
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token", | ||
}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() { | ||
mockVerifier.EXPECT().Verify("valid_token").Return(nil) | ||
}, | ||
wantError: nil, | ||
wantVerifiedNode: true, | ||
}, | ||
{ | ||
name: "missing metadata", | ||
setupContext: func() context.Context { | ||
return context.Background() | ||
}, | ||
setupVerifier: func() {}, | ||
wantError: nil, | ||
wantVerifiedNode: false, | ||
}, | ||
{ | ||
name: "invalid token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{ | ||
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token", | ||
}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() { | ||
mockVerifier.EXPECT(). | ||
Verify("invalid_token"). | ||
Return(errors.New("invalid signature")) | ||
}, | ||
wantError: status.Error( | ||
codes.Unauthenticated, | ||
"invalid auth token: invalid signature", | ||
), | ||
wantVerifiedNode: false, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
tt.setupVerifier() | ||
|
||
ctx := tt.setupContext() | ||
var handlerStream grpc.ServerStream | ||
stream := &mockServerStreamWithCtx{ctx: ctx} | ||
handler := func(srv interface{}, stream grpc.ServerStream) error { | ||
handlerStream = stream | ||
return nil | ||
} | ||
|
||
err := interceptor.Stream()(nil, stream, &grpc.StreamServerInfo{}, handler) | ||
|
||
if tt.wantError != nil { | ||
require.Error(t, err) | ||
require.Equal(t, tt.wantError.Error(), err.Error()) | ||
} else { | ||
require.NoError(t, err) | ||
isVerified, hasContextValue := handlerStream.Context().Value(constants.VerifiedNodeRequestCtxKey{}).(bool) | ||
if tt.wantVerifiedNode { | ||
require.True(t, isVerified) | ||
} else { | ||
require.False(t, hasContextValue) | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
|
||
type mockServerStreamWithCtx struct { | ||
grpc.ServerStream | ||
ctx context.Context | ||
} | ||
|
||
func (s *mockServerStreamWithCtx) Context() context.Context { | ||
return s.ctx | ||
} |
Oops, something went wrong.