Skip to content

Commit

Permalink
Add JWT verification
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Dec 10, 2024
1 parent b9e0ffc commit 0c6d7b1
Show file tree
Hide file tree
Showing 7 changed files with 408 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ outpkg: "{{.PackageName}}"
dir: "pkg/mocks/{{.PackageName}}"
filename: "mock_{{.InterfaceName}}.go"
packages:
github.com/xmtp/xmtpd/pkg/authn:
interfaces:
JWTVerifier:
github.com/xmtp/xmtpd/pkg/mlsvalidate:
interfaces:
MLSValidationService:
Expand Down
1 change: 1 addition & 0 deletions contracts/lib/forge-std
Submodule forge-std added at 035de3
1 change: 1 addition & 0 deletions contracts/lib/openzeppelin-contracts
Submodule openzeppelin-contracts added at 8b591b
2 changes: 2 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ const (
NODE_AUTHORIZATION_HEADER_NAME = "node-authorization"
MAX_BLOCKCHAIN_ORIGINATOR_ID = 100
)

type VERIFIED_NODE_REQUEST_CTX_KEY = struct{}
117 changes: 117 additions & 0 deletions pkg/interceptors/server/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package server

import (
"context"

"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/constants"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// wrappedServerStream allows us to modify the context of the stream
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}

func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

// AuthInterceptor validates JWT tokens from other nodes
type AuthInterceptor struct {
verifier authn.JWTVerifier
logger *zap.Logger
}

func NewAuthInterceptor(verifier authn.JWTVerifier, logger *zap.Logger) *AuthInterceptor {
return &AuthInterceptor{
verifier: verifier,
logger: logger,
}
}

// extractToken gets the JWT token from the request metadata
func extractToken(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", status.Error(codes.Unauthenticated, "missing metadata")
}

values := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME)
if len(values) == 0 {
return "", status.Error(codes.Unauthenticated, "missing auth token")
}

if len(values) > 1 {
return "", status.Error(codes.Unauthenticated, "multiple auth tokens provided")
}

return values[0], nil
}

// Unary returns a grpc.UnaryServerInterceptor that validates JWT tokens
func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
token, err := extractToken(ctx)
if err != nil {
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err))
return handler(ctx, req)
}

if err := i.verifier.Verify(token); err != nil {
return nil, status.Errorf(
codes.Unauthenticated,
"invalid auth token: %v",
err,
)
}
ctx = context.WithValue(ctx, constants.VERIFIED_NODE_REQUEST_CTX_KEY{}, true)

Check failure on line 78 in pkg/interceptors/server/auth.go

View workflow job for this annotation

GitHub Actions / Lint

SA1029: should not use empty anonymous struct as key for value; define your own type to avoid collisions (staticcheck)

return handler(ctx, req)
}
}

// Stream returns a grpc.StreamServerInterceptor that validates JWT tokens
func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor {
return func(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
token, err := extractToken(stream.Context())
if err != nil {
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err))
return handler(srv, stream)
}

if err := i.verifier.Verify(token); err != nil {
return status.Errorf(
codes.Unauthenticated,
"invalid auth token: %v",
err,
)
}

stream = &wrappedServerStream{
ServerStream: stream,
ctx: context.WithValue(
stream.Context(),
constants.VERIFIED_NODE_REQUEST_CTX_KEY{},

Check failure on line 110 in pkg/interceptors/server/auth.go

View workflow job for this annotation

GitHub Actions / Lint

SA1029: should not use empty anonymous struct as key for value; define your own type to avoid collisions (staticcheck)
true,
),
}

return handler(srv, stream)
}
}
206 changes: 206 additions & 0 deletions pkg/interceptors/server/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package server

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/constants"
"github.com/xmtp/xmtpd/pkg/mocks/authn"
"go.uber.org/zap/zaptest"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

func TestUnaryInterceptor(t *testing.T) {
mockVerifier := authn.NewMockJWTVerifier(t)
logger := zaptest.NewLogger(t)
interceptor := NewAuthInterceptor(mockVerifier, logger)

tests := []struct {
name string
setupContext func() context.Context
setupVerifier func()
wantError error
wantVerifiedNode bool
}{
{
name: "valid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().Verify("valid_token").Return(nil)
},
wantError: nil,
wantVerifiedNode: true,
},
{
name: "missing metadata",
setupContext: func() context.Context {
return context.Background()
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "missing token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "invalid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().
Verify("invalid_token").
Return(errors.New("invalid signature"))
},
wantError: status.Error(
codes.Unauthenticated,
"invalid auth token: invalid signature",
),
wantVerifiedNode: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupVerifier()

ctx := tt.setupContext()
var handlerCtx context.Context
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCtx = ctx
return "ok", nil
}

_, err := interceptor.Unary()(ctx, nil, &grpc.UnaryServerInfo{}, handler)

if tt.wantError != nil {
require.Error(t, err)
require.Equal(t, tt.wantError.Error(), err.Error())
} else {
require.NoError(t, err)
isVerified, hasContextValue := handlerCtx.Value(constants.VERIFIED_NODE_REQUEST_CTX_KEY{}).(bool)
if tt.wantVerifiedNode {
require.True(t, isVerified)
} else {
require.False(t, hasContextValue)
}
}
})
}
}

func TestStreamInterceptor(t *testing.T) {
mockVerifier := authn.NewMockJWTVerifier(t)
logger := zaptest.NewLogger(t)
interceptor := NewAuthInterceptor(mockVerifier, logger)

tests := []struct {
name string
setupContext func() context.Context
setupVerifier func()
wantError error
wantVerifiedNode bool
}{
{
name: "valid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().Verify("valid_token").Return(nil)
},
wantError: nil,
wantVerifiedNode: true,
},
{
name: "missing metadata",
setupContext: func() context.Context {
return context.Background()
},
setupVerifier: func() {},
wantError: nil,
wantVerifiedNode: false,
},
{
name: "invalid token",
setupContext: func() context.Context {
md := metadata.New(map[string]string{
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token",
})
return metadata.NewIncomingContext(context.Background(), md)
},
setupVerifier: func() {
mockVerifier.EXPECT().
Verify("invalid_token").
Return(errors.New("invalid signature"))
},
wantError: status.Error(
codes.Unauthenticated,
"invalid auth token: invalid signature",
),
wantVerifiedNode: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupVerifier()

ctx := tt.setupContext()
var handlerStream grpc.ServerStream
stream := &mockServerStream{ctx: ctx}
handler := func(srv interface{}, stream grpc.ServerStream) error {
handlerStream = stream
return nil
}

err := interceptor.Stream()(nil, stream, &grpc.StreamServerInfo{}, handler)

if tt.wantError != nil {
require.Error(t, err)
require.Equal(t, tt.wantError.Error(), err.Error())
} else {
require.NoError(t, err)
isVerified, hasContextValue := handlerStream.Context().Value(constants.VERIFIED_NODE_REQUEST_CTX_KEY{}).(bool)
if tt.wantVerifiedNode {
require.True(t, isVerified)
} else {
require.False(t, hasContextValue)
}
}
})
}
}

type mockServerStream struct {
grpc.ServerStream
ctx context.Context
}

func (s *mockServerStream) Context() context.Context {
return s.ctx
}
Loading

0 comments on commit 0c6d7b1

Please sign in to comment.