From 0c6d7b1b66b6e3b8b4fab5adf320698f98e17b68 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:18:09 -0800 Subject: [PATCH] Add JWT verification --- .mockery.yaml | 3 + contracts/lib/forge-std | 1 + contracts/lib/openzeppelin-contracts | 1 + pkg/constants/constants.go | 2 + pkg/interceptors/server/auth.go | 117 +++++++++++++++ pkg/interceptors/server/auth_test.go | 206 +++++++++++++++++++++++++++ pkg/mocks/authn/mock_JWTVerifier.go | 78 ++++++++++ 7 files changed, 408 insertions(+) create mode 160000 contracts/lib/forge-std create mode 160000 contracts/lib/openzeppelin-contracts create mode 100644 pkg/interceptors/server/auth.go create mode 100644 pkg/interceptors/server/auth_test.go create mode 100644 pkg/mocks/authn/mock_JWTVerifier.go diff --git a/.mockery.yaml b/.mockery.yaml index 1312ccd6..ef28dc75 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -4,6 +4,9 @@ outpkg: "{{.PackageName}}" dir: "pkg/mocks/{{.PackageName}}" filename: "mock_{{.InterfaceName}}.go" packages: + github.com/xmtp/xmtpd/pkg/authn: + interfaces: + JWTVerifier: github.com/xmtp/xmtpd/pkg/mlsvalidate: interfaces: MLSValidationService: diff --git a/contracts/lib/forge-std b/contracts/lib/forge-std new file mode 160000 index 00000000..035de35f --- /dev/null +++ b/contracts/lib/forge-std @@ -0,0 +1 @@ +Subproject commit 035de35f5e366c8d6ed142aec4ccb57fe2dd87d4 diff --git a/contracts/lib/openzeppelin-contracts b/contracts/lib/openzeppelin-contracts new file mode 160000 index 00000000..8b591bae --- /dev/null +++ b/contracts/lib/openzeppelin-contracts @@ -0,0 +1 @@ +Subproject commit 8b591baef460523e5ca1c53712c464bcc1a1c467 diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 922d6bde..a94d6338 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -7,3 +7,5 @@ const ( NODE_AUTHORIZATION_HEADER_NAME = "node-authorization" MAX_BLOCKCHAIN_ORIGINATOR_ID = 100 ) + +type VERIFIED_NODE_REQUEST_CTX_KEY = struct{} diff --git a/pkg/interceptors/server/auth.go b/pkg/interceptors/server/auth.go new file mode 100644 index 00000000..274023dc --- /dev/null +++ b/pkg/interceptors/server/auth.go @@ -0,0 +1,117 @@ +package server + +import ( + "context" + + "github.com/xmtp/xmtpd/pkg/authn" + "github.com/xmtp/xmtpd/pkg/constants" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// wrappedServerStream allows us to modify the context of the stream +type wrappedServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (w *wrappedServerStream) Context() context.Context { + return w.ctx +} + +// AuthInterceptor validates JWT tokens from other nodes +type AuthInterceptor struct { + verifier authn.JWTVerifier + logger *zap.Logger +} + +func NewAuthInterceptor(verifier authn.JWTVerifier, logger *zap.Logger) *AuthInterceptor { + return &AuthInterceptor{ + verifier: verifier, + logger: logger, + } +} + +// extractToken gets the JWT token from the request metadata +func extractToken(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", status.Error(codes.Unauthenticated, "missing metadata") + } + + values := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME) + if len(values) == 0 { + return "", status.Error(codes.Unauthenticated, "missing auth token") + } + + if len(values) > 1 { + return "", status.Error(codes.Unauthenticated, "multiple auth tokens provided") + } + + return values[0], nil +} + +// Unary returns a grpc.UnaryServerInterceptor that validates JWT tokens +func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor { + return func( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, + ) (interface{}, error) { + token, err := extractToken(ctx) + if err != nil { + i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err)) + return handler(ctx, req) + } + + if err := i.verifier.Verify(token); err != nil { + return nil, status.Errorf( + codes.Unauthenticated, + "invalid auth token: %v", + err, + ) + } + ctx = context.WithValue(ctx, constants.VERIFIED_NODE_REQUEST_CTX_KEY{}, true) + + return handler(ctx, req) + } +} + +// Stream returns a grpc.StreamServerInterceptor that validates JWT tokens +func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor { + return func( + srv interface{}, + stream grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, + ) error { + token, err := extractToken(stream.Context()) + if err != nil { + i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err)) + return handler(srv, stream) + } + + if err := i.verifier.Verify(token); err != nil { + return status.Errorf( + codes.Unauthenticated, + "invalid auth token: %v", + err, + ) + } + + stream = &wrappedServerStream{ + ServerStream: stream, + ctx: context.WithValue( + stream.Context(), + constants.VERIFIED_NODE_REQUEST_CTX_KEY{}, + true, + ), + } + + return handler(srv, stream) + } +} diff --git a/pkg/interceptors/server/auth_test.go b/pkg/interceptors/server/auth_test.go new file mode 100644 index 00000000..5c43e68c --- /dev/null +++ b/pkg/interceptors/server/auth_test.go @@ -0,0 +1,206 @@ +package server + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/constants" + "github.com/xmtp/xmtpd/pkg/mocks/authn" + "go.uber.org/zap/zaptest" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +func TestUnaryInterceptor(t *testing.T) { + mockVerifier := authn.NewMockJWTVerifier(t) + logger := zaptest.NewLogger(t) + interceptor := NewAuthInterceptor(mockVerifier, logger) + + tests := []struct { + name string + setupContext func() context.Context + setupVerifier func() + wantError error + wantVerifiedNode bool + }{ + { + name: "valid token", + setupContext: func() context.Context { + md := metadata.New(map[string]string{ + constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token", + }) + return metadata.NewIncomingContext(context.Background(), md) + }, + setupVerifier: func() { + mockVerifier.EXPECT().Verify("valid_token").Return(nil) + }, + wantError: nil, + wantVerifiedNode: true, + }, + { + name: "missing metadata", + setupContext: func() context.Context { + return context.Background() + }, + setupVerifier: func() {}, + wantError: nil, + wantVerifiedNode: false, + }, + { + name: "missing token", + setupContext: func() context.Context { + md := metadata.New(map[string]string{}) + return metadata.NewIncomingContext(context.Background(), md) + }, + setupVerifier: func() {}, + wantError: nil, + wantVerifiedNode: false, + }, + { + name: "invalid token", + setupContext: func() context.Context { + md := metadata.New(map[string]string{ + constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token", + }) + return metadata.NewIncomingContext(context.Background(), md) + }, + setupVerifier: func() { + mockVerifier.EXPECT(). + Verify("invalid_token"). + Return(errors.New("invalid signature")) + }, + wantError: status.Error( + codes.Unauthenticated, + "invalid auth token: invalid signature", + ), + wantVerifiedNode: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupVerifier() + + ctx := tt.setupContext() + var handlerCtx context.Context + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + handlerCtx = ctx + return "ok", nil + } + + _, err := interceptor.Unary()(ctx, nil, &grpc.UnaryServerInfo{}, handler) + + if tt.wantError != nil { + require.Error(t, err) + require.Equal(t, tt.wantError.Error(), err.Error()) + } else { + require.NoError(t, err) + isVerified, hasContextValue := handlerCtx.Value(constants.VERIFIED_NODE_REQUEST_CTX_KEY{}).(bool) + if tt.wantVerifiedNode { + require.True(t, isVerified) + } else { + require.False(t, hasContextValue) + } + } + }) + } +} + +func TestStreamInterceptor(t *testing.T) { + mockVerifier := authn.NewMockJWTVerifier(t) + logger := zaptest.NewLogger(t) + interceptor := NewAuthInterceptor(mockVerifier, logger) + + tests := []struct { + name string + setupContext func() context.Context + setupVerifier func() + wantError error + wantVerifiedNode bool + }{ + { + name: "valid token", + setupContext: func() context.Context { + md := metadata.New(map[string]string{ + constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token", + }) + return metadata.NewIncomingContext(context.Background(), md) + }, + setupVerifier: func() { + mockVerifier.EXPECT().Verify("valid_token").Return(nil) + }, + wantError: nil, + wantVerifiedNode: true, + }, + { + name: "missing metadata", + setupContext: func() context.Context { + return context.Background() + }, + setupVerifier: func() {}, + wantError: nil, + wantVerifiedNode: false, + }, + { + name: "invalid token", + setupContext: func() context.Context { + md := metadata.New(map[string]string{ + constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token", + }) + return metadata.NewIncomingContext(context.Background(), md) + }, + setupVerifier: func() { + mockVerifier.EXPECT(). + Verify("invalid_token"). + Return(errors.New("invalid signature")) + }, + wantError: status.Error( + codes.Unauthenticated, + "invalid auth token: invalid signature", + ), + wantVerifiedNode: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupVerifier() + + ctx := tt.setupContext() + var handlerStream grpc.ServerStream + stream := &mockServerStream{ctx: ctx} + handler := func(srv interface{}, stream grpc.ServerStream) error { + handlerStream = stream + return nil + } + + err := interceptor.Stream()(nil, stream, &grpc.StreamServerInfo{}, handler) + + if tt.wantError != nil { + require.Error(t, err) + require.Equal(t, tt.wantError.Error(), err.Error()) + } else { + require.NoError(t, err) + isVerified, hasContextValue := handlerStream.Context().Value(constants.VERIFIED_NODE_REQUEST_CTX_KEY{}).(bool) + if tt.wantVerifiedNode { + require.True(t, isVerified) + } else { + require.False(t, hasContextValue) + } + } + }) + } +} + +type mockServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (s *mockServerStream) Context() context.Context { + return s.ctx +} diff --git a/pkg/mocks/authn/mock_JWTVerifier.go b/pkg/mocks/authn/mock_JWTVerifier.go new file mode 100644 index 00000000..009928a2 --- /dev/null +++ b/pkg/mocks/authn/mock_JWTVerifier.go @@ -0,0 +1,78 @@ +// Code generated by mockery v2.44.1. DO NOT EDIT. + +package authn + +import mock "github.com/stretchr/testify/mock" + +// MockJWTVerifier is an autogenerated mock type for the JWTVerifier type +type MockJWTVerifier struct { + mock.Mock +} + +type MockJWTVerifier_Expecter struct { + mock *mock.Mock +} + +func (_m *MockJWTVerifier) EXPECT() *MockJWTVerifier_Expecter { + return &MockJWTVerifier_Expecter{mock: &_m.Mock} +} + +// Verify provides a mock function with given fields: tokenString +func (_m *MockJWTVerifier) Verify(tokenString string) error { + ret := _m.Called(tokenString) + + if len(ret) == 0 { + panic("no return value specified for Verify") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(tokenString) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockJWTVerifier_Verify_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Verify' +type MockJWTVerifier_Verify_Call struct { + *mock.Call +} + +// Verify is a helper method to define mock.On call +// - tokenString string +func (_e *MockJWTVerifier_Expecter) Verify(tokenString interface{}) *MockJWTVerifier_Verify_Call { + return &MockJWTVerifier_Verify_Call{Call: _e.mock.On("Verify", tokenString)} +} + +func (_c *MockJWTVerifier_Verify_Call) Run(run func(tokenString string)) *MockJWTVerifier_Verify_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockJWTVerifier_Verify_Call) Return(_a0 error) *MockJWTVerifier_Verify_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockJWTVerifier_Verify_Call) RunAndReturn(run func(string) error) *MockJWTVerifier_Verify_Call { + _c.Call.Return(run) + return _c +} + +// NewMockJWTVerifier creates a new instance of MockJWTVerifier. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockJWTVerifier(t interface { + mock.TestingT + Cleanup(func()) +}) *MockJWTVerifier { + mock := &MockJWTVerifier{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}