-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
408 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule openzeppelin-contracts
added at
8b591b
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package server | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/xmtp/xmtpd/pkg/authn" | ||
"github.com/xmtp/xmtpd/pkg/constants" | ||
"go.uber.org/zap" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/metadata" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
// wrappedServerStream allows us to modify the context of the stream | ||
type wrappedServerStream struct { | ||
grpc.ServerStream | ||
ctx context.Context | ||
} | ||
|
||
func (w *wrappedServerStream) Context() context.Context { | ||
return w.ctx | ||
} | ||
|
||
// AuthInterceptor validates JWT tokens from other nodes | ||
type AuthInterceptor struct { | ||
verifier authn.JWTVerifier | ||
logger *zap.Logger | ||
} | ||
|
||
func NewAuthInterceptor(verifier authn.JWTVerifier, logger *zap.Logger) *AuthInterceptor { | ||
return &AuthInterceptor{ | ||
verifier: verifier, | ||
logger: logger, | ||
} | ||
} | ||
|
||
// extractToken gets the JWT token from the request metadata | ||
func extractToken(ctx context.Context) (string, error) { | ||
md, ok := metadata.FromIncomingContext(ctx) | ||
if !ok { | ||
return "", status.Error(codes.Unauthenticated, "missing metadata") | ||
} | ||
|
||
values := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME) | ||
if len(values) == 0 { | ||
return "", status.Error(codes.Unauthenticated, "missing auth token") | ||
} | ||
|
||
if len(values) > 1 { | ||
return "", status.Error(codes.Unauthenticated, "multiple auth tokens provided") | ||
} | ||
|
||
return values[0], nil | ||
} | ||
|
||
// Unary returns a grpc.UnaryServerInterceptor that validates JWT tokens | ||
func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor { | ||
return func( | ||
ctx context.Context, | ||
req interface{}, | ||
info *grpc.UnaryServerInfo, | ||
handler grpc.UnaryHandler, | ||
) (interface{}, error) { | ||
token, err := extractToken(ctx) | ||
if err != nil { | ||
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err)) | ||
return handler(ctx, req) | ||
} | ||
|
||
if err := i.verifier.Verify(token); err != nil { | ||
return nil, status.Errorf( | ||
codes.Unauthenticated, | ||
"invalid auth token: %v", | ||
err, | ||
) | ||
} | ||
ctx = context.WithValue(ctx, constants.VERIFIED_NODE_REQUEST_CTX_KEY{}, true) | ||
|
||
return handler(ctx, req) | ||
} | ||
} | ||
|
||
// Stream returns a grpc.StreamServerInterceptor that validates JWT tokens | ||
func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor { | ||
return func( | ||
srv interface{}, | ||
stream grpc.ServerStream, | ||
info *grpc.StreamServerInfo, | ||
handler grpc.StreamHandler, | ||
) error { | ||
token, err := extractToken(stream.Context()) | ||
if err != nil { | ||
i.logger.Debug("failed to find auth token. Allowing request to proceed", zap.Error(err)) | ||
return handler(srv, stream) | ||
} | ||
|
||
if err := i.verifier.Verify(token); err != nil { | ||
return status.Errorf( | ||
codes.Unauthenticated, | ||
"invalid auth token: %v", | ||
err, | ||
) | ||
} | ||
|
||
stream = &wrappedServerStream{ | ||
ServerStream: stream, | ||
ctx: context.WithValue( | ||
stream.Context(), | ||
constants.VERIFIED_NODE_REQUEST_CTX_KEY{}, | ||
true, | ||
), | ||
} | ||
|
||
return handler(srv, stream) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
package server | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/xmtp/xmtpd/pkg/constants" | ||
"github.com/xmtp/xmtpd/pkg/mocks/authn" | ||
"go.uber.org/zap/zaptest" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/metadata" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
func TestUnaryInterceptor(t *testing.T) { | ||
mockVerifier := authn.NewMockJWTVerifier(t) | ||
logger := zaptest.NewLogger(t) | ||
interceptor := NewAuthInterceptor(mockVerifier, logger) | ||
|
||
tests := []struct { | ||
name string | ||
setupContext func() context.Context | ||
setupVerifier func() | ||
wantError error | ||
wantVerifiedNode bool | ||
}{ | ||
{ | ||
name: "valid token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{ | ||
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token", | ||
}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() { | ||
mockVerifier.EXPECT().Verify("valid_token").Return(nil) | ||
}, | ||
wantError: nil, | ||
wantVerifiedNode: true, | ||
}, | ||
{ | ||
name: "missing metadata", | ||
setupContext: func() context.Context { | ||
return context.Background() | ||
}, | ||
setupVerifier: func() {}, | ||
wantError: nil, | ||
wantVerifiedNode: false, | ||
}, | ||
{ | ||
name: "missing token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() {}, | ||
wantError: nil, | ||
wantVerifiedNode: false, | ||
}, | ||
{ | ||
name: "invalid token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{ | ||
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token", | ||
}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() { | ||
mockVerifier.EXPECT(). | ||
Verify("invalid_token"). | ||
Return(errors.New("invalid signature")) | ||
}, | ||
wantError: status.Error( | ||
codes.Unauthenticated, | ||
"invalid auth token: invalid signature", | ||
), | ||
wantVerifiedNode: false, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
tt.setupVerifier() | ||
|
||
ctx := tt.setupContext() | ||
var handlerCtx context.Context | ||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||
handlerCtx = ctx | ||
return "ok", nil | ||
} | ||
|
||
_, err := interceptor.Unary()(ctx, nil, &grpc.UnaryServerInfo{}, handler) | ||
|
||
if tt.wantError != nil { | ||
require.Error(t, err) | ||
require.Equal(t, tt.wantError.Error(), err.Error()) | ||
} else { | ||
require.NoError(t, err) | ||
isVerified, hasContextValue := handlerCtx.Value(constants.VERIFIED_NODE_REQUEST_CTX_KEY{}).(bool) | ||
if tt.wantVerifiedNode { | ||
require.True(t, isVerified) | ||
} else { | ||
require.False(t, hasContextValue) | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestStreamInterceptor(t *testing.T) { | ||
mockVerifier := authn.NewMockJWTVerifier(t) | ||
logger := zaptest.NewLogger(t) | ||
interceptor := NewAuthInterceptor(mockVerifier, logger) | ||
|
||
tests := []struct { | ||
name string | ||
setupContext func() context.Context | ||
setupVerifier func() | ||
wantError error | ||
wantVerifiedNode bool | ||
}{ | ||
{ | ||
name: "valid token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{ | ||
constants.NODE_AUTHORIZATION_HEADER_NAME: "valid_token", | ||
}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() { | ||
mockVerifier.EXPECT().Verify("valid_token").Return(nil) | ||
}, | ||
wantError: nil, | ||
wantVerifiedNode: true, | ||
}, | ||
{ | ||
name: "missing metadata", | ||
setupContext: func() context.Context { | ||
return context.Background() | ||
}, | ||
setupVerifier: func() {}, | ||
wantError: nil, | ||
wantVerifiedNode: false, | ||
}, | ||
{ | ||
name: "invalid token", | ||
setupContext: func() context.Context { | ||
md := metadata.New(map[string]string{ | ||
constants.NODE_AUTHORIZATION_HEADER_NAME: "invalid_token", | ||
}) | ||
return metadata.NewIncomingContext(context.Background(), md) | ||
}, | ||
setupVerifier: func() { | ||
mockVerifier.EXPECT(). | ||
Verify("invalid_token"). | ||
Return(errors.New("invalid signature")) | ||
}, | ||
wantError: status.Error( | ||
codes.Unauthenticated, | ||
"invalid auth token: invalid signature", | ||
), | ||
wantVerifiedNode: false, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
tt.setupVerifier() | ||
|
||
ctx := tt.setupContext() | ||
var handlerStream grpc.ServerStream | ||
stream := &mockServerStream{ctx: ctx} | ||
handler := func(srv interface{}, stream grpc.ServerStream) error { | ||
handlerStream = stream | ||
return nil | ||
} | ||
|
||
err := interceptor.Stream()(nil, stream, &grpc.StreamServerInfo{}, handler) | ||
|
||
if tt.wantError != nil { | ||
require.Error(t, err) | ||
require.Equal(t, tt.wantError.Error(), err.Error()) | ||
} else { | ||
require.NoError(t, err) | ||
isVerified, hasContextValue := handlerStream.Context().Value(constants.VERIFIED_NODE_REQUEST_CTX_KEY{}).(bool) | ||
if tt.wantVerifiedNode { | ||
require.True(t, isVerified) | ||
} else { | ||
require.False(t, hasContextValue) | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
|
||
type mockServerStream struct { | ||
grpc.ServerStream | ||
ctx context.Context | ||
} | ||
|
||
func (s *mockServerStream) Context() context.Context { | ||
return s.ctx | ||
} |
Oops, something went wrong.