Skip to content

Commit

Permalink
Log incoming metadata of JWT verified nodes (#402)
Browse files Browse the repository at this point in the history
We want to see what nodes are connected to us, so we can reason about
the distributed state of the system.

We only collect metadata for nodes, not clients, to preserve their
anonymity. IPs in general are considered personally identifiable
information, so we do not want to collect those.

Fixes #391 

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
	- Enhanced JWT verification to return node ID alongside error status
- Added logging of incoming client address and DNS name during
authentication

- **Bug Fixes**
	- Improved token verification process with more detailed error handling

- **Refactor**
- Updated authentication-related method signatures to provide more
context during verification
- Modified mock and test implementations to support new verification
approach

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
mkysel authored Jan 10, 2025
1 parent 235c2be commit ee518bc
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 22 deletions.
4 changes: 2 additions & 2 deletions pkg/authn/claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestClaimsVerifierNoVersion(t *testing.T) {

token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)
verificationError := verifier.Verify(token.SignedString)
_, verificationError := verifier.Verify(token.SignedString)
if tt.wantErr {
require.Error(t, verificationError)
} else {
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestClaimsVerifier(t *testing.T) {

token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)
verificationError := verifier.Verify(token.SignedString)
_, verificationError := verifier.Verify(token.SignedString)
if tt.wantErr {
require.Error(t, verificationError)
} else {
Expand Down
2 changes: 1 addition & 1 deletion pkg/authn/interface.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package authn

type JWTVerifier interface {
Verify(tokenString string) error
Verify(tokenString string) (uint32, error)
}
12 changes: 6 additions & 6 deletions pkg/authn/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func NewRegistryVerifier(registry registry.NodeRegistry, myNodeID uint32) *Regis
return &RegistryVerifier{registry: registry, myNodeID: myNodeID}
}

func (v *RegistryVerifier) Verify(tokenString string) error {
func (v *RegistryVerifier) Verify(tokenString string) (uint32, error) {
var token *jwt.Token
var err error

Expand All @@ -36,21 +36,21 @@ func (v *RegistryVerifier) Verify(tokenString string) error {
&XmtpdClaims{},
v.getMatchingPublicKey,
); err != nil {
return err
return 0, err
}
if err = v.validateAudience(token); err != nil {
return err
return 0, err
}

if err = validateExpiry(token); err != nil {
return err
return 0, err
}

if err = v.validateClaims(token); err != nil {
return err
return 0, err
}

return nil
return getSubjectNodeId(token)
}

func (v *RegistryVerifier) getMatchingPublicKey(token *jwt.Token) (interface{}, error) {
Expand Down
18 changes: 9 additions & 9 deletions pkg/authn/verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ func TestVerifier(t *testing.T) {
token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)
// This should verify correctly
verificationError := verifier.Verify(token.SignedString)
_, verificationError := verifier.Verify(token.SignedString)
require.NoError(t, verificationError)

// Create a token targeting a different node as the audience
tokenForWrongNode, err := tokenFactory.CreateToken(uint32(300))
require.NoError(t, err)
// This should not verify correctly
verificationError = verifier.Verify(tokenForWrongNode.SignedString)
_, verificationError = verifier.Verify(tokenForWrongNode.SignedString)
require.Error(t, verificationError)
}

Expand All @@ -90,7 +90,7 @@ func TestWrongAudience(t *testing.T) {
tokenForWrongNode, err := tokenFactory.CreateToken(uint32(300))
require.NoError(t, err)
// This should not verify correctly
verificationError := verifier.Verify(tokenForWrongNode.SignedString)
_, verificationError := verifier.Verify(tokenForWrongNode.SignedString)
require.Error(t, verificationError)
}

Expand All @@ -105,7 +105,7 @@ func TestUnknownNode(t *testing.T) {
token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)

verificationError := verifier.Verify(token.SignedString)
_, verificationError := verifier.Verify(token.SignedString)
require.Error(t, verificationError)
}

Expand All @@ -125,7 +125,7 @@ func TestWrongPublicKey(t *testing.T) {
token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)

verificationError := verifier.Verify(token.SignedString)
_, verificationError := verifier.Verify(token.SignedString)
require.Error(t, verificationError)
}

Expand All @@ -147,7 +147,7 @@ func TestExpiredToken(t *testing.T) {
time.Now().Add(-time.Hour),
)

verificationError := verifier.Verify(signedString)
_, verificationError := verifier.Verify(signedString)
require.Error(t, verificationError)
}

Expand All @@ -169,7 +169,7 @@ func TestTokenDurationTooLong(t *testing.T) {
time.Now().Add(5*time.Hour),
)

verificationError := verifier.Verify(signedString)
_, verificationError := verifier.Verify(signedString)
require.Error(t, verificationError)
}

Expand All @@ -192,7 +192,7 @@ func TestTokenClockSkew(t *testing.T) {
time.Now().Add(1*time.Hour),
)

verificationError := verifier.Verify(validToken)
_, verificationError := verifier.Verify(validToken)
require.NoError(t, verificationError)

invalidToken := buildJwt(
Expand All @@ -204,6 +204,6 @@ func TestTokenClockSkew(t *testing.T) {
time.Now().Add(1*time.Hour),
)

verificationError = verifier.Verify(invalidToken)
_, verificationError = verifier.Verify(invalidToken)
require.Error(t, verificationError)
}
39 changes: 37 additions & 2 deletions pkg/interceptors/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package server

import (
"context"
"google.golang.org/grpc/peer"
"net"

"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/constants"
Expand Down Expand Up @@ -54,6 +56,31 @@ func extractToken(ctx context.Context) (string, error) {
return values[0], nil
}

func (i *AuthInterceptor) logIncomingAddressIfAvailable(ctx context.Context, nodeId uint32) {
if i.logger.Core().Enabled(zap.DebugLevel) {
if p, ok := peer.FromContext(ctx); ok {
clientAddr := p.Addr.String()
var dnsName []string
// Attempt to resolve the DNS name
host, _, err := net.SplitHostPort(clientAddr)
if err == nil {
dnsName, err = net.LookupAddr(host)
if err != nil || len(dnsName) == 0 {
dnsName = []string{"Unknown"}
}
} else {
dnsName = []string{"Unknown"}
}
i.logger.Debug(
"Incoming connection",
zap.String("client_addr", clientAddr),
zap.String("dns_name", dnsName[0]),
zap.Uint32("node_id", nodeId),
)
}
}
}

// Unary returns a grpc.UnaryServerInterceptor that validates JWT tokens
func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor {
return func(
Expand All @@ -68,13 +95,17 @@ func (i *AuthInterceptor) Unary() grpc.UnaryServerInterceptor {
return handler(ctx, req)
}

if err := i.verifier.Verify(token); err != nil {
nodeId, err := i.verifier.Verify(token)
if err != nil {
return nil, status.Errorf(
codes.Unauthenticated,
"invalid auth token: %v",
err,
)
}

i.logIncomingAddressIfAvailable(ctx, nodeId)

ctx = context.WithValue(ctx, constants.VerifiedNodeRequestCtxKey{}, true)

return handler(ctx, req)
Expand All @@ -95,14 +126,18 @@ func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor {
return handler(srv, stream)
}

if err := i.verifier.Verify(token); err != nil {
nodeId, err := i.verifier.Verify(token)

if err != nil {
return status.Errorf(
codes.Unauthenticated,
"invalid auth token: %v",
err,
)
}

i.logIncomingAddressIfAvailable(stream.Context(), nodeId)

stream = &wrappedServerStream{
ServerStream: stream,
ctx: context.WithValue(
Expand Down
4 changes: 2 additions & 2 deletions pkg/mocks/authn/mock_JWTVerifier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ee518bc

Please sign in to comment.