From 0c6d7b1b66b6e3b8b4fab5adf320698f98e17b68 Mon Sep 17 00:00:00 2001
From: Nicholas Molnar <65710+neekolas@users.noreply.github.com>
Date: Mon, 9 Dec 2024 16:18:09 -0800
Subject: [PATCH] Add JWT verification

---
 .mockery.yaml                        |   3 +
 contracts/lib/forge-std              |   1 +
 contracts/lib/openzeppelin-contracts |   1 +
 pkg/constants/constants.go           |   2 +
 pkg/interceptors/server/auth.go      | 117 +++++++++++++++
 pkg/interceptors/server/auth_test.go | 206 +++++++++++++++++++++++++++
 pkg/mocks/authn/mock_JWTVerifier.go  |  78 ++++++++++
 7 files changed, 408 insertions(+)
 create mode 160000 contracts/lib/forge-std
 create mode 160000 contracts/lib/openzeppelin-contracts
 create mode 100644 pkg/interceptors/server/auth.go
 create mode 100644 pkg/interceptors/server/auth_test.go
 create mode 100644 pkg/mocks/authn/mock_JWTVerifier.go

diff --git a/.mockery.yaml b/.mockery.yaml
index 1312ccd6..ef28dc75 100644
--- a/.mockery.yaml
+++ b/.mockery.yaml
@@ -4,6 +4,9 @@ outpkg: "{{.PackageName}}"
 dir: "pkg/mocks/{{.PackageName}}"
 filename: "mock_{{.InterfaceName}}.go"
 packages:
+  github.com/xmtp/xmtpd/pkg/authn:
+    interfaces:
+      JWTVerifier:
   github.com/xmtp/xmtpd/pkg/mlsvalidate:
     interfaces:
       MLSValidationService:
diff --git a/contracts/lib/forge-std b/contracts/lib/forge-std
new file mode 160000
index 00000000..035de35f
--- /dev/null
+++ b/contracts/lib/forge-std
@@ -0,0 +1 @@
+Subproject commit 035de35f5e366c8d6ed142aec4ccb57fe2dd87d4
diff --git a/contracts/lib/openzeppelin-contracts b/contracts/lib/openzeppelin-contracts
new file mode 160000
index 00000000..8b591bae
--- /dev/null
+++ b/contracts/lib/openzeppelin-contracts
@@ -0,0 +1 @@
+Subproject commit 8b591baef460523e5ca1c53712c464bcc1a1c467
diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go
index 922d6bde..a94d6338 100644
--- a/pkg/constants/constants.go
+++ b/pkg/constants/constants.go
@@ -7,3 +7,5 @@ const (
 	NODE_AUTHORIZATION_HEADER_NAME     = "node-authorization"
 	MAX_BLOCKCHAIN_ORIGINATOR_ID       = 100
 )
+
+type VERIFIED_NODE_REQUEST_CTX_KEY = struct{}
diff --git a/pkg/interceptors/server/auth.go b/pkg/interceptors/server/auth.go
new file mode 100644
index 00000000..274023dc
--- /dev/null
+++ b/pkg/interceptors/server/auth.go
@@ -0,0 +1,117 @@
+package server
+
+import (
+	"context"
+
+	"github.com/xmtp/xmtpd/pkg/authn"
+	"github.com/xmtp/xmtpd/pkg/constants"
+	"go.uber.org/zap"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
+)
+
+// wrappedServerStream allows us to modify the context of the stream
+type wrappedServerStream struct {
+	grpc.ServerStream
+	ctx context.Context
+}
+
+func (w *wrappedServerStream) Context() context.Context {
+	return w.ctx
+}
+
+// AuthInterceptor validates JWT tokens from other nodes
+type AuthInterceptor struct {
+	verifier authn.JWTVerifier
+	logger   *zap.Logger
+}
+
+func NewAuthInterceptor(verifier authn.JWTVerifier, logger *zap.Logger) *AuthInterceptor {
+	return &AuthInterceptor{
+		verifier: verifier,
+		logger:   logger,
+	}
+}
+
+// extractToken gets the JWT token from the request metadata
+func extractToken(ctx context.Context) (string, error) {
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		return "", status.Error(codes.Unauthenticated, "missing metadata")
+	}
+
+	values := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME)
+	if len(values) == 0 {
+		return "", status.Error(codes.Unauthenticated, "missing auth token")
+	}
+
+	if len(values) > 1 {
+		return "", status.Error(codes.Unauthenticated, "multiple auth tokens provided")
+	}
+
+	return values[0], nil
+}
+
+// Unary returns a grpc.UnaryServerInterceptor that validates JWT tokens
+func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor {
+	return func(
+		ctx context.Context,
+		req interface{},
+		info *grpc.UnaryServerInfo,
+		handler grpc.UnaryHandler,
+	) (interface{}, error) {
+		token, err := extractToken(ctx)
+		if err != nil {
+			i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err))
+			return handler(ctx, req)
+		}
+
+		if err := i.verifier.Verify(token); err != nil {
+			return nil, status.Errorf(
+				codes.Unauthenticated,
+				"invalid auth token: %v",
+				err,
+			)
+		}
+		ctx = context.WithValue(ctx, constants.VERIFIED_NODE_REQUEST_CTX_KEY{}, true)
+
+		return handler(ctx, req)
+	}
+}
+
+// Stream returns a grpc.StreamServerInterceptor that validates JWT tokens
+func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor {
+	return func(
+		srv interface{},
+		stream grpc.ServerStream,
+		info *grpc.StreamServerInfo,
+		handler grpc.StreamHandler,
+	) error {
+		token, err := extractToken(stream.Context())
+		if err != nil {
+			i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err))
+			return handler(srv, stream)
+		}
+
+		if err := i.verifier.Verify(token); err != nil {
+			return status.Errorf(
+				codes.Unauthenticated,
+				"invalid auth token: %v",
+				err,
+			)
+		}
+
+		stream = &wrappedServerStream{
+			ServerStream: stream,
+			ctx: context.WithValue(
+				stream.Context(),
+				constants.VERIFIED_NODE_REQUEST_CTX_KEY{},
+				true,
+			),
+		}
+
+		return handler(srv, stream)
+	}
+}
diff --git a/pkg/interceptors/server/auth_test.go b/pkg/interceptors/server/auth_test.go
new file mode 100644
index 00000000..5c43e68c
--- /dev/null
+++ b/pkg/interceptors/server/auth_test.go
@@ -0,0 +1,206 @@
+package server
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"github.com/xmtp/xmtpd/pkg/constants"
+	"github.com/xmtp/xmtpd/pkg/mocks/authn"
+	"go.uber.org/zap/zaptest"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
+)
+
+func TestUnaryInterceptor(t *testing.T) {
+	mockVerifier := authn.NewMockJWTVerifier(t)
+	logger := zaptest.NewLogger(t)
+	interceptor := NewAuthInterceptor(mockVerifier, logger)
+
+	tests := []struct {
+		name             string
+		setupContext     func() context.Context
+		setupVerifier    func()
+		wantError        error
+		wantVerifiedNode bool
+	}{
+		{
+			name: "valid token",
+			setupContext: func() context.Context {
+				md := metadata.New(map[string]string{
+					constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token",
+				})
+				return metadata.NewIncomingContext(context.Background(), md)
+			},
+			setupVerifier: func() {
+				mockVerifier.EXPECT().Verify("valid_token").Return(nil)
+			},
+			wantError:        nil,
+			wantVerifiedNode: true,
+		},
+		{
+			name: "missing metadata",
+			setupContext: func() context.Context {
+				return context.Background()
+			},
+			setupVerifier:    func() {},
+			wantError:        nil,
+			wantVerifiedNode: false,
+		},
+		{
+			name: "missing token",
+			setupContext: func() context.Context {
+				md := metadata.New(map[string]string{})
+				return metadata.NewIncomingContext(context.Background(), md)
+			},
+			setupVerifier:    func() {},
+			wantError:        nil,
+			wantVerifiedNode: false,
+		},
+		{
+			name: "invalid token",
+			setupContext: func() context.Context {
+				md := metadata.New(map[string]string{
+					constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token",
+				})
+				return metadata.NewIncomingContext(context.Background(), md)
+			},
+			setupVerifier: func() {
+				mockVerifier.EXPECT().
+					Verify("invalid_token").
+					Return(errors.New("invalid signature"))
+			},
+			wantError: status.Error(
+				codes.Unauthenticated,
+				"invalid auth token: invalid signature",
+			),
+			wantVerifiedNode: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tt.setupVerifier()
+
+			ctx := tt.setupContext()
+			var handlerCtx context.Context
+			handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+				handlerCtx = ctx
+				return "ok", nil
+			}
+
+			_, err := interceptor.Unary()(ctx, nil, &grpc.UnaryServerInfo{}, handler)
+
+			if tt.wantError != nil {
+				require.Error(t, err)
+				require.Equal(t, tt.wantError.Error(), err.Error())
+			} else {
+				require.NoError(t, err)
+				isVerified, hasContextValue := handlerCtx.Value(constants.VERIFIED_NODE_REQUEST_CTX_KEY{}).(bool)
+				if tt.wantVerifiedNode {
+					require.True(t, isVerified)
+				} else {
+					require.False(t, hasContextValue)
+				}
+			}
+		})
+	}
+}
+
+func TestStreamInterceptor(t *testing.T) {
+	mockVerifier := authn.NewMockJWTVerifier(t)
+	logger := zaptest.NewLogger(t)
+	interceptor := NewAuthInterceptor(mockVerifier, logger)
+
+	tests := []struct {
+		name             string
+		setupContext     func() context.Context
+		setupVerifier    func()
+		wantError        error
+		wantVerifiedNode bool
+	}{
+		{
+			name: "valid token",
+			setupContext: func() context.Context {
+				md := metadata.New(map[string]string{
+					constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token",
+				})
+				return metadata.NewIncomingContext(context.Background(), md)
+			},
+			setupVerifier: func() {
+				mockVerifier.EXPECT().Verify("valid_token").Return(nil)
+			},
+			wantError:        nil,
+			wantVerifiedNode: true,
+		},
+		{
+			name: "missing metadata",
+			setupContext: func() context.Context {
+				return context.Background()
+			},
+			setupVerifier:    func() {},
+			wantError:        nil,
+			wantVerifiedNode: false,
+		},
+		{
+			name: "invalid token",
+			setupContext: func() context.Context {
+				md := metadata.New(map[string]string{
+					constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token",
+				})
+				return metadata.NewIncomingContext(context.Background(), md)
+			},
+			setupVerifier: func() {
+				mockVerifier.EXPECT().
+					Verify("invalid_token").
+					Return(errors.New("invalid signature"))
+			},
+			wantError: status.Error(
+				codes.Unauthenticated,
+				"invalid auth token: invalid signature",
+			),
+			wantVerifiedNode: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tt.setupVerifier()
+
+			ctx := tt.setupContext()
+			var handlerStream grpc.ServerStream
+			stream := &mockServerStream{ctx: ctx}
+			handler := func(srv interface{}, stream grpc.ServerStream) error {
+				handlerStream = stream
+				return nil
+			}
+
+			err := interceptor.Stream()(nil, stream, &grpc.StreamServerInfo{}, handler)
+
+			if tt.wantError != nil {
+				require.Error(t, err)
+				require.Equal(t, tt.wantError.Error(), err.Error())
+			} else {
+				require.NoError(t, err)
+				isVerified, hasContextValue := handlerStream.Context().Value(constants.VERIFIED_NODE_REQUEST_CTX_KEY{}).(bool)
+				if tt.wantVerifiedNode {
+					require.True(t, isVerified)
+				} else {
+					require.False(t, hasContextValue)
+				}
+			}
+		})
+	}
+}
+
+type mockServerStream struct {
+	grpc.ServerStream
+	ctx context.Context
+}
+
+func (s *mockServerStream) Context() context.Context {
+	return s.ctx
+}
diff --git a/pkg/mocks/authn/mock_JWTVerifier.go b/pkg/mocks/authn/mock_JWTVerifier.go
new file mode 100644
index 00000000..009928a2
--- /dev/null
+++ b/pkg/mocks/authn/mock_JWTVerifier.go
@@ -0,0 +1,78 @@
+// Code generated by mockery v2.44.1. DO NOT EDIT.
+
+package authn
+
+import mock "github.com/stretchr/testify/mock"
+
+// MockJWTVerifier is an autogenerated mock type for the JWTVerifier type
+type MockJWTVerifier struct {
+	mock.Mock
+}
+
+type MockJWTVerifier_Expecter struct {
+	mock *mock.Mock
+}
+
+func (_m *MockJWTVerifier) EXPECT() *MockJWTVerifier_Expecter {
+	return &MockJWTVerifier_Expecter{mock: &_m.Mock}
+}
+
+// Verify provides a mock function with given fields: tokenString
+func (_m *MockJWTVerifier) Verify(tokenString string) error {
+	ret := _m.Called(tokenString)
+
+	if len(ret) == 0 {
+		panic("no return value specified for Verify")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(string) error); ok {
+		r0 = rf(tokenString)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// MockJWTVerifier_Verify_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Verify'
+type MockJWTVerifier_Verify_Call struct {
+	*mock.Call
+}
+
+// Verify is a helper method to define mock.On call
+//   - tokenString string
+func (_e *MockJWTVerifier_Expecter) Verify(tokenString interface{}) *MockJWTVerifier_Verify_Call {
+	return &MockJWTVerifier_Verify_Call{Call: _e.mock.On("Verify", tokenString)}
+}
+
+func (_c *MockJWTVerifier_Verify_Call) Run(run func(tokenString string)) *MockJWTVerifier_Verify_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(string))
+	})
+	return _c
+}
+
+func (_c *MockJWTVerifier_Verify_Call) Return(_a0 error) *MockJWTVerifier_Verify_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockJWTVerifier_Verify_Call) RunAndReturn(run func(string) error) *MockJWTVerifier_Verify_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// NewMockJWTVerifier creates a new instance of MockJWTVerifier. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewMockJWTVerifier(t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *MockJWTVerifier {
+	mock := &MockJWTVerifier{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}