Skip to content

Commit

Permalink
Add JWT verification (#317)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added JWT validation for gRPC server streams, enhancing security for
authenticated requests.
- Introduced a new mock implementation for the JWTVerifier interface to
facilitate testing.

- **Bug Fixes**
- Improved error handling for missing or invalid JWT tokens in both
unary and stream interceptors.

- **Tests**
- Added comprehensive unit tests for the AuthInterceptor functionality,
covering various scenarios and edge cases.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
neekolas authored Dec 20, 2024
1 parent f223c36 commit c48f822
Show file tree
Hide file tree
Showing 5 changed files with 406 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ outpkg: "{{.PackageName}}"
dir: "pkg/mocks/{{.PackageName}}"
filename: "mock_{{.InterfaceName}}.go"
packages:
github.com/xmtp/xmtpd/pkg/authn:
interfaces:
JWTVerifier:
github.com/xmtp/xmtpd/pkg/mlsvalidate:
interfaces:
MLSValidationService:
Expand Down
2 changes: 2 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ const (
NODE_AUTHORIZATION_HEADER_NAME = "node-authorization"
MAX_BLOCKCHAIN_ORIGINATOR_ID = 100
)

type VerifiedNodeRequestCtxKey struct{}
117 changes: 117 additions & 0 deletions pkg/interceptors/server/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package server

import (
"context"

"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/constants"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// wrappedServerStream allows us to modify the context of the stream
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}

func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

// AuthInterceptor validates JWT tokens from other nodes
type AuthInterceptor struct {
verifier authn.JWTVerifier
logger *zap.Logger
}

func NewAuthInterceptor(verifier authn.JWTVerifier, logger *zap.Logger) *AuthInterceptor {
return &AuthInterceptor{
verifier: verifier,
logger: logger,
}
}

// extractToken gets the JWT token from the request metadata
func extractToken(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", status.Error(codes.Unauthenticated, "missing metadata")
}

values := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME)
if len(values) == 0 {
return "", status.Error(codes.Unauthenticated, "missing auth token")
}

if len(values) > 1 {
return "", status.Error(codes.Unauthenticated, "multiple auth tokens provided")
}

return values[0], nil
}

// Unary returns a grpc.UnaryServerInterceptor that validates JWT tokens
func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
token, err := extractToken(ctx)
if err != nil {
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err))
return handler(ctx, req)
}

if err := i.verifier.Verify(token); err != nil {
return nil, status.Errorf(
codes.Unauthenticated,
"invalid auth token: %v",
err,
)
}
ctx = context.WithValue(ctx, constants.VerifiedNodeRequestCtxKey{}, true)

return handler(ctx, req)
}
}

// Stream returns a grpc.StreamServerInterceptor that validates JWT tokens
func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor {
return func(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
token, err := extractToken(stream.Context())
if err != nil {
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err))
return handler(srv, stream)
}

if err := i.verifier.Verify(token); err != nil {
return status.Errorf(
codes.Unauthenticated,
"invalid auth token: %v",
err,
)
}

stream = &wrappedServerStream{
ServerStream: stream,
ctx: context.WithValue(
stream.Context(),
constants.VerifiedNodeRequestCtxKey{},
true,
),
}

return handler(srv, stream)
}
}
206 changes: 206 additions & 0 deletions pkg/interceptors/server/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package server

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/constants"
"github.com/xmtp/xmtpd/pkg/mocks/authn"
"go.uber.org/zap/zaptest"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

func TestUnaryInterceptor(t *testing.T) {
mockVerifier := authn.NewMockJWTVerifier(t)
logger := zaptest.NewLogger(t)
interceptor := NewAuthInterceptor(mockVerifier, logger)

tests := []struct {
name string
setupContext func() context.Context
setupVerifier func()
wantError error
wantVerifiedNode bool
}{
{
name: "valid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().Verify("valid_token").Return(nil)
},
wantError: nil,
wantVerifiedNode: true,
},
{
name: "missing metadata",
setupContext: func() context.Context {
return context.Background()
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "missing token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "invalid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().
Verify("invalid_token").
Return(errors.New("invalid signature"))
},
wantError: status.Error(
codes.Unauthenticated,
"invalid auth token: invalid signature",
),
wantVerifiedNode: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupVerifier()

ctx := tt.setupContext()
var handlerCtx context.Context
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCtx = ctx
return "ok", nil
}

_, err := interceptor.Unary()(ctx, nil, &grpc.UnaryServerInfo{}, handler)

if tt.wantError != nil {
require.Error(t, err)
require.Equal(t, tt.wantError.Error(), err.Error())
} else {
require.NoError(t, err)
isVerified, hasContextValue := handlerCtx.Value(constants.VerifiedNodeRequestCtxKey{}).(bool)
if tt.wantVerifiedNode {
require.True(t, isVerified)
} else {
require.False(t, hasContextValue)
}
}
})
}
}

func TestStreamInterceptor(t *testing.T) {
mockVerifier := authn.NewMockJWTVerifier(t)
logger := zaptest.NewLogger(t)
interceptor := NewAuthInterceptor(mockVerifier, logger)

tests := []struct {
name string
setupContext func() context.Context
setupVerifier func()
wantError error
wantVerifiedNode bool
}{
{
name: "valid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().Verify("valid_token").Return(nil)
},
wantError: nil,
wantVerifiedNode: true,
},
{
name: "missing metadata",
setupContext: func() context.Context {
return context.Background()
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "invalid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().
Verify("invalid_token").
Return(errors.New("invalid signature"))
},
wantError: status.Error(
codes.Unauthenticated,
"invalid auth token: invalid signature",
),
wantVerifiedNode: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupVerifier()

ctx := tt.setupContext()
var handlerStream grpc.ServerStream
stream := &mockServerStreamWithCtx{ctx: ctx}
handler := func(srv interface{}, stream grpc.ServerStream) error {
handlerStream = stream
return nil
}

err := interceptor.Stream()(nil, stream, &grpc.StreamServerInfo{}, handler)

if tt.wantError != nil {
require.Error(t, err)
require.Equal(t, tt.wantError.Error(), err.Error())
} else {
require.NoError(t, err)
isVerified, hasContextValue := handlerStream.Context().Value(constants.VerifiedNodeRequestCtxKey{}).(bool)
if tt.wantVerifiedNode {
require.True(t, isVerified)
} else {
require.False(t, hasContextValue)
}
}
})
}
}

type mockServerStreamWithCtx struct {
grpc.ServerStream
ctx context.Context
}

func (s *mockServerStreamWithCtx) Context() context.Context {
return s.ctx
}
Loading

0 comments on commit c48f822

Please sign in to comment.