Skip to content

Commit

Permalink
Enable race tests (#301)
Browse files Browse the repository at this point in the history
* Enable and fix -race tests

* Fix rebase issues
  • Loading branch information
snormore authored Sep 26, 2023
1 parent 6ca903a commit 6973ba5
Show file tree
Hide file tree
Showing 21 changed files with 121 additions and 114 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ jobs:
export PATH="${PATH}:${GOPATH}/bin"
go install github.com/jstemmer/go-junit-report/v2@latest
go test -v ./... | go-junit-report -set-exit-code -iocopy -out report.xml
- name: Run Race Tests
run: |
export GOPATH="${HOME}/go/"
export PATH="${PATH}:${GOPATH}/bin"
go test -v ./... -race
- uses: datadog/junit-upload-github-action@v1
with:
api-key: ${{ secrets.DD_API_KEY }}
Expand Down
6 changes: 6 additions & 0 deletions dev/test
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ set -e
ulimit -n 2048

go test ./... "$@"

if [ -n "${RACE:-}" ]; then
echo
echo "Running race tests"
go test ./... "$@" -race
fi
2 changes: 1 addition & 1 deletion pkg/api/message/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot
if err != nil {
return nil, status.Errorf(codes.Internal, err.Error())
}
metrics.EmitPublishedEnvelope(ctx, env)
metrics.EmitPublishedEnvelope(ctx, log, env)
}
return &proto.PublishResponse{}, nil
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ const (
authorizationMetadataKey = "authorization"
)

var (
prometheusOnce sync.Once
)

type Server struct {
*Config

Expand Down Expand Up @@ -79,7 +83,9 @@ func (s *Server) startGRPC() error {
return errors.Wrap(err, "creating grpc listener")
}

prometheus.EnableHandlingTimeHistogram()
prometheusOnce.Do(func() {
prometheus.EnableHandlingTimeHistogram()
})
unary := []grpc.UnaryServerInterceptor{prometheus.UnaryServerInterceptor}
stream := []grpc.StreamServerInterceptor{prometheus.StreamServerInterceptor}

Expand Down
2 changes: 1 addition & 1 deletion pkg/api/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (ti *TelemetryInterceptor) record(ctx context.Context, fullMethod string, e
}

logFn("api request", fields...)
metrics.EmitAPIRequest(ctx, fields)
metrics.EmitAPIRequest(ctx, ti.log, fields)
}

func splitMethodName(fullMethodName string) (serviceName string, methodName string) {
Expand Down
31 changes: 15 additions & 16 deletions pkg/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"time"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtp-node-go/pkg/logging"
"go.uber.org/zap"
)

Expand All @@ -18,12 +17,12 @@ func randomBytes(n int) []byte {
}

func Test_NominalV2(t *testing.T) {
logger, _ := zap.NewDevelopment()
ctx := logging.With(context.Background(), logger)
log, _ := zap.NewDevelopment()
ctx := context.Background()
now := time.Now()
token, data, err := generateV2AuthToken(now.Add(-time.Minute))
require.NoError(t, err)
walletAddr, err := validateToken(ctx, logger, token, now)
walletAddr, err := validateToken(ctx, log, token, now)
require.NoError(t, err)
require.Equal(t, data.WalletAddr, string(walletAddr), "wallet address mismatch")
}
Expand All @@ -34,50 +33,50 @@ func Test_XmtpjsToken(t *testing.T) {
tokenAddress := "0x2D0e614e8Dc8Adf82c70090F905e6cFC1DE84900"
tokenCreatedNs := int64(1660761120550000000)

logger, _ := zap.NewDevelopment()
ctx := logging.With(context.Background(), logger)
log, _ := zap.NewDevelopment()
ctx := context.Background()
now := time.Unix(0, tokenCreatedNs).Add(10 * time.Minute)
token, err := decodeAuthToken(tokenBytes)
require.NoError(t, err)
walletAddr, err := validateToken(ctx, logger, token, now)
walletAddr, err := validateToken(ctx, log, token, now)
require.NoError(t, err)
require.Equal(t, tokenAddress, string(walletAddr))
}

func Test_BadAuthSig(t *testing.T) {
logger, _ := zap.NewDevelopment()
ctx := logging.With(context.Background(), logger)
log, _ := zap.NewDevelopment()
ctx := context.Background()
now := time.Now()
token, _, err := generateV2AuthToken(now.Add(-time.Minute))
require.NoError(t, err)
token.GetAuthDataSignature().GetEcdsaCompact().Bytes = randomBytes(64)
_, err = validateToken(ctx, logger, token, now)
_, err = validateToken(ctx, log, token, now)
require.Error(t, err)
require.Equal(t, err, ErrInvalidSignature)
}

func Test_SignatureMismatch(t *testing.T) {
logger, _ := zap.NewDevelopment()
ctx := logging.With(context.Background(), logger)
log, _ := zap.NewDevelopment()
ctx := context.Background()
now := time.Now()
token1, _, err := generateV2AuthToken(now.Add(-time.Minute))
require.NoError(t, err)
token2, _, err := generateV2AuthToken(now.Add(-time.Minute))
require.NoError(t, err)

// Nominal Checks
_, err = validateToken(ctx, logger, token1, now)
_, err = validateToken(ctx, log, token1, now)
require.NoError(t, err)
_, err = validateToken(ctx, logger, token2, now)
_, err = validateToken(ctx, log, token2, now)
require.NoError(t, err)

// Swap Signatures to check for valid but mismatched signatures
token1.IdentityKey.Signature, token2.AuthDataSignature = token2.AuthDataSignature, token1.IdentityKey.Signature

// Expect Errors as the derived walletAddr will not match the one supplied in AuthData
_, err = validateToken(ctx, logger, token1, now)
_, err = validateToken(ctx, log, token1, now)
require.Error(t, err)
_, err = validateToken(ctx, logger, token1, now)
_, err = validateToken(ctx, log, token1, now)
require.Error(t, err)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/authn/authn.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 15 additions & 23 deletions pkg/authn/transport_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ func (xmtpAuth *XmtpAuthentication) Start() {
}

func (xmtpAuth *XmtpAuthentication) onRequest(stream network.Stream) {

log := xmtpAuth.log.With(logging.HostID("peer", stream.Conn().RemotePeer()))
ctx := logging.With(xmtpAuth.ctx, log)

log.Info("stream established")
defer func() {
Expand All @@ -61,7 +59,7 @@ func (xmtpAuth *XmtpAuthentication) onRequest(stream network.Stream) {
}
}()

authenticatedPeerId, authenticatedWalletAddr, err := handleRequest(ctx, stream)
authenticatedPeerId, authenticatedWalletAddr, err := handleRequest(xmtpAuth.ctx, log, stream)
isAuthenticated := err == nil

if isAuthenticated {
Expand All @@ -85,12 +83,10 @@ func (xmtpAuth *XmtpAuthentication) onRequest(stream network.Stream) {

}

func handleRequest(ctx context.Context, stream network.Stream) (peer types.PeerId, wallet types.WalletAddr, err error) {
logger := logging.From(ctx)

func handleRequest(ctx context.Context, log *zap.Logger, stream network.Stream) (peer types.PeerId, wallet types.WalletAddr, err error) {
authRequest, err := readAuthRequest(stream)
if err != nil {
logger.Error("reading request", zap.Error(err))
log.Error("reading request", zap.Error(err))
return peer, wallet, err
}

Expand All @@ -101,14 +97,14 @@ func handleRequest(ctx context.Context, stream network.Stream) (peer types.PeerI

switch version := authRequest.Version.(type) {
case *ClientAuthRequest_V1:
suppliedPeerId, walletAddr, err = validateRequest(ctx, authRequest.GetV1(), connectingPeerId)
suppliedPeerId, walletAddr, err = validateRequest(ctx, log, authRequest.GetV1(), connectingPeerId)
default:
logger.Error("No handler for request", logging.ValueType("version", version))
log.Error("No handler for request", logging.ValueType("version", version))
return peer, wallet, ErrNoHandler
}

if err != nil {
logger.Error("validating request", zap.Error(err))
log.Error("validating request", zap.Error(err))
return peer, wallet, err
}

Expand All @@ -117,32 +113,30 @@ func handleRequest(ctx context.Context, stream network.Stream) (peer types.PeerI
return suppliedPeerId, walletAddr, err
}

func validateRequest(ctx context.Context, request *V1ClientAuthRequest, connectingPeerId types.PeerId) (peer types.PeerId, wallet types.WalletAddr, err error) {
logger := logging.From(ctx)

func validateRequest(ctx context.Context, log *zap.Logger, request *V1ClientAuthRequest, connectingPeerId types.PeerId) (peer types.PeerId, wallet types.WalletAddr, err error) {
// Validate WalletSignature
recoveredWalletAddress, err := recoverWalletAddress(request.IdentityKeyBytes, request.WalletSignature.GetEcdsaCompact())
if err != nil {
logger.Error("verifying wallet signature", zap.Error(err))
log.Error("verifying wallet signature", zap.Error(err))
return peer, wallet, err
}

// Validate AuthSignature
suppliedPeerId, suppliedWalletAddress, err := verifyAuthSignature(ctx, request.IdentityKeyBytes, request.AuthDataBytes, request.AuthSignature.GetEcdsaCompact())
suppliedPeerId, suppliedWalletAddress, err := verifyAuthSignature(ctx, log, request.IdentityKeyBytes, request.AuthDataBytes, request.AuthSignature.GetEcdsaCompact())
if err != nil {
logger.Error("verifying authn signature", zap.Error(err))
log.Error("verifying authn signature", zap.Error(err))
return peer, wallet, err
}

// To protect against spoofing, ensure the walletAddresses match in both signatures
if recoveredWalletAddress != suppliedWalletAddress {
logger.Error("wallet address mismatch", zap.Error(err), logging.WalletAddressLabelled("recovered", recoveredWalletAddress), logging.WalletAddressLabelled("supplied", suppliedWalletAddress))
log.Error("wallet address mismatch", zap.Error(err), logging.WalletAddressLabelled("recovered", recoveredWalletAddress), logging.WalletAddressLabelled("supplied", suppliedWalletAddress))
return peer, wallet, ErrWalletMismatch
}

// To protect against spoofing, ensure the AuthRequest originated from the same peerID that was authenticated.
if connectingPeerId != suppliedPeerId {
logger.Error("peerId Mismatch", zap.Error(err), logging.HostID("supplied", suppliedPeerId.Raw()))
log.Error("peerId Mismatch", zap.Error(err), logging.HostID("supplied", suppliedPeerId.Raw()))
return peer, wallet, ErrWrongPeerId
}

Expand Down Expand Up @@ -171,9 +165,7 @@ func recoverWalletAddress(identityKeyBytes []byte, signature *Signature_ECDSACom
return crypto.RecoverWalletAddress(isrBytes, sig, uint8(signature.GetRecovery()))
}

func verifyAuthSignature(ctx context.Context, identityKeyBytes []byte, authDataBytes []byte, authSig *Signature_ECDSACompact) (peer types.PeerId, wallet types.WalletAddr, err error) {
logger := logging.From(ctx)

func verifyAuthSignature(ctx context.Context, log *zap.Logger, identityKeyBytes []byte, authDataBytes []byte, authSig *Signature_ECDSACompact) (peer types.PeerId, wallet types.WalletAddr, err error) {
pubKey := &PublicKey{}
err = proto.Unmarshal(identityKeyBytes, pubKey)
if err != nil {
Expand All @@ -187,7 +179,7 @@ func verifyAuthSignature(ctx context.Context, identityKeyBytes []byte, authDataB

signature, err := crypto.SignatureFromBytes(authSig.GetBytes())
if err != nil {
logger.Error("signature decoding", zap.Error(err))
log.Error("signature decoding", zap.Error(err))
return peer, wallet, err
}

Expand All @@ -198,7 +190,7 @@ func verifyAuthSignature(ctx context.Context, identityKeyBytes []byte, authDataB

authData, err := unpackAuthData(authDataBytes)
if err != nil {
logger.Error("unpacking authn data", zap.Error(err))
log.Error("unpacking authn data", zap.Error(err))
return peer, wallet, err
}

Expand Down
33 changes: 16 additions & 17 deletions pkg/authn/transport_authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/waku-org/go-waku/tests"
"github.com/xmtp/go-msgio/protoio"
"github.com/xmtp/xmtp-node-go/pkg/logging"
test "github.com/xmtp/xmtp-node-go/pkg/testing"
"github.com/xmtp/xmtp-node-go/pkg/types"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -152,7 +152,6 @@ func ClientAuth(ctx context.Context, log *zap.Logger, h host.Host, peerId types.
// generated from an oracle the peerIDs between the saved request and the connecting stream will not match, resulting in
// a failed authentication.
func TestRoundTrip(t *testing.T) {

log, _ := zap.NewDevelopment()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
Expand Down Expand Up @@ -185,8 +184,8 @@ func TestRoundTrip(t *testing.T) {
}

func TestV1_Nominal(t *testing.T) {
logger, _ := zap.NewDevelopment()
ctx := logging.With(context.Background(), logger)
ctx := context.Background()
log := test.NewLog(t)

testCase := sampleAuthReq001

Expand All @@ -202,7 +201,7 @@ func TestV1_Nominal(t *testing.T) {
require.Equal(t, expectedWalletAddr, types.WalletAddr(authData.WalletAddr), "bad deserialized wallet address")
require.Equal(t, expectedPeerId, types.PeerId(authData.PeerId), "bad deserialized peerID")

peerId, walletAddr, err := validateRequest(ctx, req, expectedPeerId)
peerId, walletAddr, err := validateRequest(ctx, log, req, expectedPeerId)
require.NoError(t, err)

require.Equal(t, expectedWalletAddr, walletAddr, "wallet address mismatch")
Expand All @@ -211,8 +210,8 @@ func TestV1_Nominal(t *testing.T) {
}

func TestV1_BadAuthSig(t *testing.T) {
logger, _ := zap.NewDevelopment()
ctx := logging.With(context.Background(), logger)
ctx := context.Background()
log := test.NewLog(t)

connectingPeerId := types.PeerId("TestPeerID")
req, err := LoadSerializedAuthReq(sampleAuthReq001.reqBytes)
Expand All @@ -224,25 +223,25 @@ func TestV1_BadAuthSig(t *testing.T) {
authData.WalletAddr = "0000000"
req.AuthDataBytes, _ = proto.Marshal(authData)

_, _, err = validateRequest(ctx, req, connectingPeerId)
_, _, err = validateRequest(ctx, log, req, connectingPeerId)
require.Error(t, err)
}

func TestV1_PeerIdSpoof(t *testing.T) {
logger, _ := zap.NewDevelopment()
ctx := logging.With(context.Background(), logger)
ctx := context.Background()
log := test.NewLog(t)

req, err := LoadSerializedAuthReq(sampleAuthReq001.reqBytes)
require.NoError(t, err)
connectingPeerId := types.PeerId("InvalidPeerID")

_, _, err = validateRequest(ctx, req, connectingPeerId)
_, _, err = validateRequest(ctx, log, req, connectingPeerId)
require.Error(t, err)
}

func TestV1_SignatureMismatch(t *testing.T) {
logger, _ := zap.NewDevelopment()
ctx := logging.With(context.Background(), logger)
ctx := context.Background()
log := test.NewLog(t)

req1, err := LoadSerializedAuthReq(sampleAuthReq001.reqBytes)
require.NoError(t, err)
Expand All @@ -255,18 +254,18 @@ func TestV1_SignatureMismatch(t *testing.T) {
require.NoError(t, err)

// Nominal Checks
_, _, err = validateRequest(ctx, req1, types.PeerId(authData1.PeerId))
_, _, err = validateRequest(ctx, log, req1, types.PeerId(authData1.PeerId))
require.NoError(t, err)
_, _, err = validateRequest(ctx, req2, types.PeerId(authData2.PeerId))
_, _, err = validateRequest(ctx, log, req2, types.PeerId(authData2.PeerId))
require.NoError(t, err)

// Swap Signatures to check for valid but mismatched signatures
req1.WalletSignature = req2.WalletSignature
req2.AuthSignature = req1.AuthSignature

// Expect Errors as the derived walletAddr will not match the one supplied in AuthData
_, _, err = validateRequest(ctx, req1, types.PeerId(authData1.PeerId))
_, _, err = validateRequest(ctx, log, req1, types.PeerId(authData1.PeerId))
require.Error(t, err)
_, _, err = validateRequest(ctx, req2, types.PeerId(authData2.PeerId))
_, _, err = validateRequest(ctx, log, req2, types.PeerId(authData2.PeerId))
require.Error(t, err)
}
Loading

0 comments on commit 6973ba5

Please sign in to comment.