From 6973ba509ad258fec2e2909fe0574069a2628e10 Mon Sep 17 00:00:00 2001 From: Steven Normore Date: Tue, 26 Sep 2023 08:15:09 -0400 Subject: [PATCH] Enable race tests (#301) * Enable and fix -race tests * Fix rebase issues --- .github/workflows/test.yml | 5 +++ dev/test | 6 ++++ pkg/api/message/v1/service.go | 2 +- pkg/api/server.go | 8 ++++- pkg/api/telemetry.go | 2 +- pkg/api/token_test.go | 31 +++++++++--------- pkg/authn/authn.pb.go | 2 +- pkg/authn/transport_authentication.go | 38 +++++++++------------- pkg/authn/transport_authentication_test.go | 33 +++++++++---------- pkg/authz/wallet_allow_lister.go | 10 ++++++ pkg/logging/context.go | 23 ------------- pkg/metrics/api-limits.go | 9 +++-- pkg/metrics/api.go | 16 ++++----- pkg/metrics/metrics.go | 11 ++++--- pkg/metrics/peers.go | 5 ++- pkg/ratelimiter/rate_limiter.go | 4 +-- pkg/server/server.go | 4 +-- pkg/store/store.go | 2 +- pkg/testing/node.go | 2 ++ pkg/tracing/tracing.go | 8 ++--- pkg/tracing/tracing_test.go | 14 +++++++- 21 files changed, 121 insertions(+), 114 deletions(-) delete mode 100644 pkg/logging/context.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index db348445..c57b7111 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,11 @@ jobs: export PATH="${PATH}:${GOPATH}/bin" go install github.com/jstemmer/go-junit-report/v2@latest go test -v ./... | go-junit-report -set-exit-code -iocopy -out report.xml + - name: Run Race Tests + run: | + export GOPATH="${HOME}/go/" + export PATH="${PATH}:${GOPATH}/bin" + go test -v ./... -race - uses: datadog/junit-upload-github-action@v1 with: api-key: ${{ secrets.DD_API_KEY }} diff --git a/dev/test b/dev/test index 644551a1..e7a13767 100755 --- a/dev/test +++ b/dev/test @@ -4,3 +4,9 @@ set -e ulimit -n 2048 go test ./... "$@" + +if [ -n "${RACE:-}" ]; then + echo + echo "Running race tests" + go test ./... "$@" -race +fi diff --git a/pkg/api/message/v1/service.go b/pkg/api/message/v1/service.go index 2f92d80b..16797ba5 100644 --- a/pkg/api/message/v1/service.go +++ b/pkg/api/message/v1/service.go @@ -164,7 +164,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot if err != nil { return nil, status.Errorf(codes.Internal, err.Error()) } - metrics.EmitPublishedEnvelope(ctx, env) + metrics.EmitPublishedEnvelope(ctx, log, env) } return &proto.PublishResponse{}, nil } diff --git a/pkg/api/server.go b/pkg/api/server.go index dc6a4837..7c6e5990 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -33,6 +33,10 @@ const ( authorizationMetadataKey = "authorization" ) +var ( + prometheusOnce sync.Once +) + type Server struct { *Config @@ -79,7 +83,9 @@ func (s *Server) startGRPC() error { return errors.Wrap(err, "creating grpc listener") } - prometheus.EnableHandlingTimeHistogram() + prometheusOnce.Do(func() { + prometheus.EnableHandlingTimeHistogram() + }) unary := []grpc.UnaryServerInterceptor{prometheus.UnaryServerInterceptor} stream := []grpc.StreamServerInterceptor{prometheus.StreamServerInterceptor} diff --git a/pkg/api/telemetry.go b/pkg/api/telemetry.go index 59aeab1d..bb344222 100644 --- a/pkg/api/telemetry.go +++ b/pkg/api/telemetry.go @@ -85,7 +85,7 @@ func (ti *TelemetryInterceptor) record(ctx context.Context, fullMethod string, e } logFn("api request", fields...) - metrics.EmitAPIRequest(ctx, fields) + metrics.EmitAPIRequest(ctx, ti.log, fields) } func splitMethodName(fullMethodName string) (serviceName string, methodName string) { diff --git a/pkg/api/token_test.go b/pkg/api/token_test.go index 23abd123..cdcf46bc 100644 --- a/pkg/api/token_test.go +++ b/pkg/api/token_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/stretchr/testify/require" - "github.com/xmtp/xmtp-node-go/pkg/logging" "go.uber.org/zap" ) @@ -18,12 +17,12 @@ func randomBytes(n int) []byte { } func Test_NominalV2(t *testing.T) { - logger, _ := zap.NewDevelopment() - ctx := logging.With(context.Background(), logger) + log, _ := zap.NewDevelopment() + ctx := context.Background() now := time.Now() token, data, err := generateV2AuthToken(now.Add(-time.Minute)) require.NoError(t, err) - walletAddr, err := validateToken(ctx, logger, token, now) + walletAddr, err := validateToken(ctx, log, token, now) require.NoError(t, err) require.Equal(t, data.WalletAddr, string(walletAddr), "wallet address mismatch") } @@ -34,31 +33,31 @@ func Test_XmtpjsToken(t *testing.T) { tokenAddress := "0x2D0e614e8Dc8Adf82c70090F905e6cFC1DE84900" tokenCreatedNs := int64(1660761120550000000) - logger, _ := zap.NewDevelopment() - ctx := logging.With(context.Background(), logger) + log, _ := zap.NewDevelopment() + ctx := context.Background() now := time.Unix(0, tokenCreatedNs).Add(10 * time.Minute) token, err := decodeAuthToken(tokenBytes) require.NoError(t, err) - walletAddr, err := validateToken(ctx, logger, token, now) + walletAddr, err := validateToken(ctx, log, token, now) require.NoError(t, err) require.Equal(t, tokenAddress, string(walletAddr)) } func Test_BadAuthSig(t *testing.T) { - logger, _ := zap.NewDevelopment() - ctx := logging.With(context.Background(), logger) + log, _ := zap.NewDevelopment() + ctx := context.Background() now := time.Now() token, _, err := generateV2AuthToken(now.Add(-time.Minute)) require.NoError(t, err) token.GetAuthDataSignature().GetEcdsaCompact().Bytes = randomBytes(64) - _, err = validateToken(ctx, logger, token, now) + _, err = validateToken(ctx, log, token, now) require.Error(t, err) require.Equal(t, err, ErrInvalidSignature) } func Test_SignatureMismatch(t *testing.T) { - logger, _ := zap.NewDevelopment() - ctx := logging.With(context.Background(), logger) + log, _ := zap.NewDevelopment() + ctx := context.Background() now := time.Now() token1, _, err := generateV2AuthToken(now.Add(-time.Minute)) require.NoError(t, err) @@ -66,18 +65,18 @@ func Test_SignatureMismatch(t *testing.T) { require.NoError(t, err) // Nominal Checks - _, err = validateToken(ctx, logger, token1, now) + _, err = validateToken(ctx, log, token1, now) require.NoError(t, err) - _, err = validateToken(ctx, logger, token2, now) + _, err = validateToken(ctx, log, token2, now) require.NoError(t, err) // Swap Signatures to check for valid but mismatched signatures token1.IdentityKey.Signature, token2.AuthDataSignature = token2.AuthDataSignature, token1.IdentityKey.Signature // Expect Errors as the derived walletAddr will not match the one supplied in AuthData - _, err = validateToken(ctx, logger, token1, now) + _, err = validateToken(ctx, log, token1, now) require.Error(t, err) - _, err = validateToken(ctx, logger, token1, now) + _, err = validateToken(ctx, log, token1, now) require.Error(t, err) } diff --git a/pkg/authn/authn.pb.go b/pkg/authn/authn.pb.go index d8df5a16..5a3d5733 100644 --- a/pkg/authn/authn.pb.go +++ b/pkg/authn/authn.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.1 -// protoc v4.24.2 +// protoc v4.24.3 // source: authn.proto package authn diff --git a/pkg/authn/transport_authentication.go b/pkg/authn/transport_authentication.go index f9ade4c6..b712113c 100644 --- a/pkg/authn/transport_authentication.go +++ b/pkg/authn/transport_authentication.go @@ -50,9 +50,7 @@ func (xmtpAuth *XmtpAuthentication) Start() { } func (xmtpAuth *XmtpAuthentication) onRequest(stream network.Stream) { - log := xmtpAuth.log.With(logging.HostID("peer", stream.Conn().RemotePeer())) - ctx := logging.With(xmtpAuth.ctx, log) log.Info("stream established") defer func() { @@ -61,7 +59,7 @@ func (xmtpAuth *XmtpAuthentication) onRequest(stream network.Stream) { } }() - authenticatedPeerId, authenticatedWalletAddr, err := handleRequest(ctx, stream) + authenticatedPeerId, authenticatedWalletAddr, err := handleRequest(xmtpAuth.ctx, log, stream) isAuthenticated := err == nil if isAuthenticated { @@ -85,12 +83,10 @@ func (xmtpAuth *XmtpAuthentication) onRequest(stream network.Stream) { } -func handleRequest(ctx context.Context, stream network.Stream) (peer types.PeerId, wallet types.WalletAddr, err error) { - logger := logging.From(ctx) - +func handleRequest(ctx context.Context, log *zap.Logger, stream network.Stream) (peer types.PeerId, wallet types.WalletAddr, err error) { authRequest, err := readAuthRequest(stream) if err != nil { - logger.Error("reading request", zap.Error(err)) + log.Error("reading request", zap.Error(err)) return peer, wallet, err } @@ -101,14 +97,14 @@ func handleRequest(ctx context.Context, stream network.Stream) (peer types.PeerI switch version := authRequest.Version.(type) { case *ClientAuthRequest_V1: - suppliedPeerId, walletAddr, err = validateRequest(ctx, authRequest.GetV1(), connectingPeerId) + suppliedPeerId, walletAddr, err = validateRequest(ctx, log, authRequest.GetV1(), connectingPeerId) default: - logger.Error("No handler for request", logging.ValueType("version", version)) + log.Error("No handler for request", logging.ValueType("version", version)) return peer, wallet, ErrNoHandler } if err != nil { - logger.Error("validating request", zap.Error(err)) + log.Error("validating request", zap.Error(err)) return peer, wallet, err } @@ -117,32 +113,30 @@ func handleRequest(ctx context.Context, stream network.Stream) (peer types.PeerI return suppliedPeerId, walletAddr, err } -func validateRequest(ctx context.Context, request *V1ClientAuthRequest, connectingPeerId types.PeerId) (peer types.PeerId, wallet types.WalletAddr, err error) { - logger := logging.From(ctx) - +func validateRequest(ctx context.Context, log *zap.Logger, request *V1ClientAuthRequest, connectingPeerId types.PeerId) (peer types.PeerId, wallet types.WalletAddr, err error) { // Validate WalletSignature recoveredWalletAddress, err := recoverWalletAddress(request.IdentityKeyBytes, request.WalletSignature.GetEcdsaCompact()) if err != nil { - logger.Error("verifying wallet signature", zap.Error(err)) + log.Error("verifying wallet signature", zap.Error(err)) return peer, wallet, err } // Validate AuthSignature - suppliedPeerId, suppliedWalletAddress, err := verifyAuthSignature(ctx, request.IdentityKeyBytes, request.AuthDataBytes, request.AuthSignature.GetEcdsaCompact()) + suppliedPeerId, suppliedWalletAddress, err := verifyAuthSignature(ctx, log, request.IdentityKeyBytes, request.AuthDataBytes, request.AuthSignature.GetEcdsaCompact()) if err != nil { - logger.Error("verifying authn signature", zap.Error(err)) + log.Error("verifying authn signature", zap.Error(err)) return peer, wallet, err } // To protect against spoofing, ensure the walletAddresses match in both signatures if recoveredWalletAddress != suppliedWalletAddress { - logger.Error("wallet address mismatch", zap.Error(err), logging.WalletAddressLabelled("recovered", recoveredWalletAddress), logging.WalletAddressLabelled("supplied", suppliedWalletAddress)) + log.Error("wallet address mismatch", zap.Error(err), logging.WalletAddressLabelled("recovered", recoveredWalletAddress), logging.WalletAddressLabelled("supplied", suppliedWalletAddress)) return peer, wallet, ErrWalletMismatch } // To protect against spoofing, ensure the AuthRequest originated from the same peerID that was authenticated. if connectingPeerId != suppliedPeerId { - logger.Error("peerId Mismatch", zap.Error(err), logging.HostID("supplied", suppliedPeerId.Raw())) + log.Error("peerId Mismatch", zap.Error(err), logging.HostID("supplied", suppliedPeerId.Raw())) return peer, wallet, ErrWrongPeerId } @@ -171,9 +165,7 @@ func recoverWalletAddress(identityKeyBytes []byte, signature *Signature_ECDSACom return crypto.RecoverWalletAddress(isrBytes, sig, uint8(signature.GetRecovery())) } -func verifyAuthSignature(ctx context.Context, identityKeyBytes []byte, authDataBytes []byte, authSig *Signature_ECDSACompact) (peer types.PeerId, wallet types.WalletAddr, err error) { - logger := logging.From(ctx) - +func verifyAuthSignature(ctx context.Context, log *zap.Logger, identityKeyBytes []byte, authDataBytes []byte, authSig *Signature_ECDSACompact) (peer types.PeerId, wallet types.WalletAddr, err error) { pubKey := &PublicKey{} err = proto.Unmarshal(identityKeyBytes, pubKey) if err != nil { @@ -187,7 +179,7 @@ func verifyAuthSignature(ctx context.Context, identityKeyBytes []byte, authDataB signature, err := crypto.SignatureFromBytes(authSig.GetBytes()) if err != nil { - logger.Error("signature decoding", zap.Error(err)) + log.Error("signature decoding", zap.Error(err)) return peer, wallet, err } @@ -198,7 +190,7 @@ func verifyAuthSignature(ctx context.Context, identityKeyBytes []byte, authDataB authData, err := unpackAuthData(authDataBytes) if err != nil { - logger.Error("unpacking authn data", zap.Error(err)) + log.Error("unpacking authn data", zap.Error(err)) return peer, wallet, err } diff --git a/pkg/authn/transport_authentication_test.go b/pkg/authn/transport_authentication_test.go index 6599b3b5..7f6788d1 100644 --- a/pkg/authn/transport_authentication_test.go +++ b/pkg/authn/transport_authentication_test.go @@ -16,7 +16,7 @@ import ( "github.com/stretchr/testify/require" "github.com/waku-org/go-waku/tests" "github.com/xmtp/go-msgio/protoio" - "github.com/xmtp/xmtp-node-go/pkg/logging" + test "github.com/xmtp/xmtp-node-go/pkg/testing" "github.com/xmtp/xmtp-node-go/pkg/types" "go.uber.org/zap" "google.golang.org/protobuf/proto" @@ -152,7 +152,6 @@ func ClientAuth(ctx context.Context, log *zap.Logger, h host.Host, peerId types. // generated from an oracle the peerIDs between the saved request and the connecting stream will not match, resulting in // a failed authentication. func TestRoundTrip(t *testing.T) { - log, _ := zap.NewDevelopment() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -185,8 +184,8 @@ func TestRoundTrip(t *testing.T) { } func TestV1_Nominal(t *testing.T) { - logger, _ := zap.NewDevelopment() - ctx := logging.With(context.Background(), logger) + ctx := context.Background() + log := test.NewLog(t) testCase := sampleAuthReq001 @@ -202,7 +201,7 @@ func TestV1_Nominal(t *testing.T) { require.Equal(t, expectedWalletAddr, types.WalletAddr(authData.WalletAddr), "bad deserialized wallet address") require.Equal(t, expectedPeerId, types.PeerId(authData.PeerId), "bad deserialized peerID") - peerId, walletAddr, err := validateRequest(ctx, req, expectedPeerId) + peerId, walletAddr, err := validateRequest(ctx, log, req, expectedPeerId) require.NoError(t, err) require.Equal(t, expectedWalletAddr, walletAddr, "wallet address mismatch") @@ -211,8 +210,8 @@ func TestV1_Nominal(t *testing.T) { } func TestV1_BadAuthSig(t *testing.T) { - logger, _ := zap.NewDevelopment() - ctx := logging.With(context.Background(), logger) + ctx := context.Background() + log := test.NewLog(t) connectingPeerId := types.PeerId("TestPeerID") req, err := LoadSerializedAuthReq(sampleAuthReq001.reqBytes) @@ -224,25 +223,25 @@ func TestV1_BadAuthSig(t *testing.T) { authData.WalletAddr = "0000000" req.AuthDataBytes, _ = proto.Marshal(authData) - _, _, err = validateRequest(ctx, req, connectingPeerId) + _, _, err = validateRequest(ctx, log, req, connectingPeerId) require.Error(t, err) } func TestV1_PeerIdSpoof(t *testing.T) { - logger, _ := zap.NewDevelopment() - ctx := logging.With(context.Background(), logger) + ctx := context.Background() + log := test.NewLog(t) req, err := LoadSerializedAuthReq(sampleAuthReq001.reqBytes) require.NoError(t, err) connectingPeerId := types.PeerId("InvalidPeerID") - _, _, err = validateRequest(ctx, req, connectingPeerId) + _, _, err = validateRequest(ctx, log, req, connectingPeerId) require.Error(t, err) } func TestV1_SignatureMismatch(t *testing.T) { - logger, _ := zap.NewDevelopment() - ctx := logging.With(context.Background(), logger) + ctx := context.Background() + log := test.NewLog(t) req1, err := LoadSerializedAuthReq(sampleAuthReq001.reqBytes) require.NoError(t, err) @@ -255,9 +254,9 @@ func TestV1_SignatureMismatch(t *testing.T) { require.NoError(t, err) // Nominal Checks - _, _, err = validateRequest(ctx, req1, types.PeerId(authData1.PeerId)) + _, _, err = validateRequest(ctx, log, req1, types.PeerId(authData1.PeerId)) require.NoError(t, err) - _, _, err = validateRequest(ctx, req2, types.PeerId(authData2.PeerId)) + _, _, err = validateRequest(ctx, log, req2, types.PeerId(authData2.PeerId)) require.NoError(t, err) // Swap Signatures to check for valid but mismatched signatures @@ -265,8 +264,8 @@ func TestV1_SignatureMismatch(t *testing.T) { req2.AuthSignature = req1.AuthSignature // Expect Errors as the derived walletAddr will not match the one supplied in AuthData - _, _, err = validateRequest(ctx, req1, types.PeerId(authData1.PeerId)) + _, _, err = validateRequest(ctx, log, req1, types.PeerId(authData1.PeerId)) require.Error(t, err) - _, _, err = validateRequest(ctx, req2, types.PeerId(authData2.PeerId)) + _, _, err = validateRequest(ctx, log, req2, types.PeerId(authData2.PeerId)) require.Error(t, err) } diff --git a/pkg/authz/wallet_allow_lister.go b/pkg/authz/wallet_allow_lister.go index 4476ff18..ea73d64e 100644 --- a/pkg/authz/wallet_allow_lister.go +++ b/pkg/authz/wallet_allow_lister.go @@ -51,6 +51,7 @@ type DatabaseWalletAllowLister struct { db *bun.DB log *zap.Logger permissions map[string]Permission + permissionsLock sync.RWMutex refreshInterval time.Duration ctx context.Context cancelFunc context.CancelFunc @@ -69,6 +70,9 @@ func NewDatabaseWalletAllowLister(db *bun.DB, log *zap.Logger) *DatabaseWalletAl // Get the permissions for a wallet address func (d *DatabaseWalletAllowLister) GetPermissions(walletAddress string) Permission { + d.permissionsLock.RLock() + defer d.permissionsLock.RUnlock() + permission, hasPermission := d.permissions[walletAddress] if !hasPermission { return Unspecified @@ -97,6 +101,9 @@ func (d *DatabaseWalletAllowLister) Allow(ctx context.Context, walletAddress str } func (d *DatabaseWalletAllowLister) Apply(ctx context.Context, walletAddress string, permission Permission) error { + d.permissionsLock.Lock() + defer d.permissionsLock.Unlock() + wallet := WalletAddress{ WalletAddress: walletAddress, Permission: unmapPermission(permission), @@ -132,6 +139,9 @@ func (d *DatabaseWalletAllowLister) Start(ctx context.Context) error { // Currently just loads absolutely everything, since n is going to be small enough for now. // Should be possible to do incrementally if we need to using the created_at field func (d *DatabaseWalletAllowLister) loadPermissions() error { + d.permissionsLock.Lock() + defer d.permissionsLock.Unlock() + var wallets []WalletAddress query := d.db.NewSelect().Model(&wallets).Where("deleted_at IS NULL") if err := query.Scan(d.ctx); err != nil { diff --git a/pkg/logging/context.go b/pkg/logging/context.go deleted file mode 100644 index 9ff02709..00000000 --- a/pkg/logging/context.go +++ /dev/null @@ -1,23 +0,0 @@ -package logging - -import ( - "context" - - "github.com/waku-org/go-waku/logging" - "github.com/waku-org/go-waku/waku/v2/utils" - "go.uber.org/zap" -) - -var ( - // Re-export the go-waku helpers. - With = logging.With -) - -// From returns a logger from the context or the default logger. -func From(ctx context.Context) *zap.Logger { - logger := logging.From(ctx) - if logger == nil { - logger = utils.Logger() - } - return logger -} diff --git a/pkg/metrics/api-limits.go b/pkg/metrics/api-limits.go index 37866a16..ca4db731 100644 --- a/pkg/metrics/api-limits.go +++ b/pkg/metrics/api-limits.go @@ -3,7 +3,6 @@ package metrics import ( "context" - "github.com/xmtp/xmtp-node-go/pkg/logging" "go.opencensus.io/stats" "go.opencensus.io/stats/view" "go.opencensus.io/tag" @@ -21,10 +20,10 @@ var ratelimiterBucketsGaugeView = &view.View{ TagKeys: []tag.Key{bucketsNameKey}, } -func EmitRatelimiterBucketsSize(ctx context.Context, name string, size int) { +func EmitRatelimiterBucketsSize(ctx context.Context, log *zap.Logger, name string, size int) { err := recordWithTags(ctx, []tag.Mutator{tag.Insert(bucketsNameKey, name)}, ratelimiterBucketsGaugeMeasure.M(int64(size))) if err != nil { - logging.From(ctx).Warn("recording metric", + log.Warn("recording metric", zap.String("metric", ratelimiterBucketsGaugeMeasure.Name()), zap.Error(err)) } @@ -39,10 +38,10 @@ var ratelimiterBucketsDeletedCounterView = &view.View{ TagKeys: []tag.Key{bucketsNameKey}, } -func EmitRatelimiterDeletedEntries(ctx context.Context, name string, count int) { +func EmitRatelimiterDeletedEntries(ctx context.Context, log *zap.Logger, name string, count int) { err := recordWithTags(ctx, []tag.Mutator{tag.Insert(bucketsNameKey, name)}, ratelimiterBucketsDeletedCounterMeasure.M(int64(count))) if err != nil { - logging.From(ctx).Warn("recording metric", + log.Warn("recording metric", zap.String("metric", ratelimiterBucketsDeletedCounterMeasure.Name()), zap.Error(err)) } diff --git a/pkg/metrics/api.go b/pkg/metrics/api.go index c04c6e23..78f680d3 100644 --- a/pkg/metrics/api.go +++ b/pkg/metrics/api.go @@ -41,7 +41,7 @@ var apiRequestsView = &view.View{ TagKeys: apiRequestTagKeys, } -func EmitAPIRequest(ctx context.Context, fields []zapcore.Field) { +func EmitAPIRequest(ctx context.Context, log *zap.Logger, fields []zapcore.Field) { mutators := make([]tag.Mutator, 0, len(fields)) for _, field := range fields { key, ok := apiRequestTagKeysByName[field.Key] @@ -52,7 +52,7 @@ func EmitAPIRequest(ctx context.Context, fields []zapcore.Field) { } err := recordWithTags(ctx, mutators, apiRequestsMeasure.M(1)) if err != nil { - logging.From(ctx).Error("recording metric", fields...) + log.Error("recording metric", fields...) } } @@ -75,14 +75,14 @@ var publishedEnvelopeCounterView = &view.View{ TagKeys: append([]tag.Key{topicCategoryTag}, appClientVersionTagKeys...), } -func EmitPublishedEnvelope(ctx context.Context, env *proto.Envelope) { +func EmitPublishedEnvelope(ctx context.Context, log *zap.Logger, env *proto.Envelope) { mutators := contextMutators(ctx) topicCategory := topic.Category(env.ContentTopic) mutators = append(mutators, tag.Insert(topicCategoryTag, topicCategory)) size := int64(len(env.Message)) err := recordWithTags(ctx, mutators, publishedEnvelopeMeasure.M(size)) if err != nil { - logging.From(ctx).Error("recording metric", + log.Error("recording metric", zap.Error(err), zap.String("metric", publishedEnvelopeView.Name), zap.Int64("size", size), @@ -91,7 +91,7 @@ func EmitPublishedEnvelope(ctx context.Context, env *proto.Envelope) { } err = recordWithTags(ctx, mutators, publishedEnvelopeCounterMeasure.M(1)) if err != nil { - logging.From(ctx).Error("recording metric", + log.Error("recording metric", zap.Error(err), zap.String("metric", publishedEnvelopeCounterView.Name), zap.Int64("size", size), @@ -147,7 +147,7 @@ var queryResultView = &view.View{ TagKeys: append([]tag.Key{topicCategoryTag, queryErrorTag, queryParametersTag}, appClientVersionTagKeys...), } -func EmitQuery(ctx context.Context, req *proto.QueryRequest, results int, err error, duration time.Duration) { +func EmitQuery(ctx context.Context, log *zap.Logger, req *proto.QueryRequest, results int, err error, duration time.Duration) { mutators := []tag.Mutator{} if len(req.ContentTopics) > 0 { topicCategory := topic.Category(req.ContentTopics[0]) @@ -160,7 +160,7 @@ func EmitQuery(ctx context.Context, req *proto.QueryRequest, results int, err er mutators = append(mutators, tag.Insert(queryParametersTag, parameters)) err = recordWithTags(ctx, mutators, queryDurationMeasure.M(duration.Milliseconds())) if err != nil { - logging.From(ctx).Error("recording metric", + log.Error("recording metric", zap.Error(err), zap.Duration("duration", duration), zap.String("parameters", parameters), @@ -169,7 +169,7 @@ func EmitQuery(ctx context.Context, req *proto.QueryRequest, results int, err er } err = recordWithTags(ctx, mutators, queryResultMeasure.M(int64(results))) if err != nil { - logging.From(ctx).Error("recording metric", + log.Error("recording metric", zap.Error(err), zap.Int("results", results), zap.String("parameters", parameters), diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 4e6a5892..fb5c9fd1 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -6,7 +6,6 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/waku-org/go-waku/waku/metrics" - "github.com/xmtp/xmtp-node-go/pkg/logging" "github.com/xmtp/xmtp-node-go/pkg/tracing" "go.opencensus.io/stats" "go.opencensus.io/stats/view" @@ -16,16 +15,20 @@ import ( // Server wraps go-waku metrics server, so that we don't need to reference the go-waku package anywhere type Server struct { + log *zap.Logger waku *metrics.Server http *http.Server } -func NewMetricsServer(address string, port int, logger *zap.Logger) *Server { - return &Server{waku: metrics.NewMetricsServer(address, port, logger)} +func NewMetricsServer(address string, port int, log *zap.Logger) *Server { + return &Server{ + log: log, + waku: metrics.NewMetricsServer(address, port, log), + } } func (s *Server) Start(ctx context.Context) { - log := logging.From(ctx).Named("metrics") + log := s.log.Named("metrics") go tracing.PanicWrap(ctx, "waku metrics server", func(_ context.Context) { s.waku.Start() }) s.http = &http.Server{Addr: ":8009", Handler: promhttp.Handler()} go tracing.PanicWrap(ctx, "metrics server", func(_ context.Context) { diff --git a/pkg/metrics/peers.go b/pkg/metrics/peers.go index 8729869b..6bb431d2 100644 --- a/pkg/metrics/peers.go +++ b/pkg/metrics/peers.go @@ -6,7 +6,6 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" - "github.com/xmtp/xmtp-node-go/pkg/logging" "go.opencensus.io/stats" "go.opencensus.io/stats/view" "go.opencensus.io/tag" @@ -32,7 +31,7 @@ var BootstrapPeersView = &view.View{ Aggregation: view.LastValue(), } -func EmitPeersByProtocol(ctx context.Context, host host.Host) { +func EmitPeersByProtocol(ctx context.Context, log *zap.Logger, host host.Host) { byProtocol := map[string]int64{} ps := host.Peerstore() for _, peer := range ps.Peers() { @@ -48,7 +47,7 @@ func EmitPeersByProtocol(ctx context.Context, host host.Host) { mutators := []tag.Mutator{tag.Insert(TagProto, proto)} err := recordWithTags(ctx, mutators, PeersByProto.M(count)) if err != nil { - logging.From(ctx).Warn("recording metric", zap.String("metric", PeersByProto.Name()), zap.String("proto", proto), zap.Error(err)) + log.Warn("recording metric", zap.String("metric", PeersByProto.Name()), zap.String("proto", proto), zap.Error(err)) } } } diff --git a/pkg/ratelimiter/rate_limiter.go b/pkg/ratelimiter/rate_limiter.go index 00c865d0..f5c9e306 100644 --- a/pkg/ratelimiter/rate_limiter.go +++ b/pkg/ratelimiter/rate_limiter.go @@ -156,8 +156,8 @@ func (rl *TokenBucketRateLimiter) Janitor(sweepInterval, expiresAfter time.Durat func (rl *TokenBucketRateLimiter) sweepAndSwap(expiresAfter time.Duration) (deletedEntries int) { // Only the janitor writes to oldBuckets (the swap below), so we shouldn't need to rlock it here. deletedEntries = rl.oldBuckets.deleteExpired(expiresAfter) - metrics.EmitRatelimiterDeletedEntries(rl.ctx, rl.oldBuckets.name, deletedEntries) - metrics.EmitRatelimiterBucketsSize(rl.ctx, rl.oldBuckets.name, len(rl.oldBuckets.buckets)) + metrics.EmitRatelimiterDeletedEntries(rl.ctx, rl.log, rl.oldBuckets.name, deletedEntries) + metrics.EmitRatelimiterBucketsSize(rl.ctx, rl.log, rl.oldBuckets.name, len(rl.oldBuckets.buckets)) rl.mutex.Lock() rl.newBuckets, rl.oldBuckets = rl.oldBuckets, rl.newBuckets rl.mutex.Unlock() diff --git a/pkg/server/server.go b/pkg/server/server.go index 1d1b1558..76240c04 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -83,7 +83,7 @@ func New(ctx context.Context, log *zap.Logger, options Options) (*Server, error) } s.log = s.log.With(logging.HostID("node", id)) - s.ctx, s.cancel = context.WithCancel(logging.With(ctx, s.log)) + s.ctx, s.cancel = context.WithCancel(ctx) if options.Metrics.Enable { s.metricsServer = metrics.NewMetricsServer(options.Metrics.Address, options.Metrics.Port, s.log) @@ -351,7 +351,7 @@ func (s *Server) statusMetricsLoop(options Options) { case <-s.ctx.Done(): return case <-ticker.C: - metrics.EmitPeersByProtocol(s.ctx, s.wakuNode.Host()) + metrics.EmitPeersByProtocol(s.ctx, s.log, s.wakuNode.Host()) if len(bootstrapPeers) > 0 { metrics.EmitBootstrapPeersConnected(s.ctx, s.wakuNode.Host(), bootstrapPeers) } diff --git a/pkg/store/store.go b/pkg/store/store.go index ae27abac..71e17d9a 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -104,7 +104,7 @@ func (s *Store) metricsLoop(ctx context.Context) { func (s *Store) InsertMessage(env *messagev1.Envelope) (bool, error) { var stored bool - err := tracing.Wrap(s.ctx, "storing message", func(ctx context.Context, span tracing.Span) error { + err := tracing.Wrap(s.ctx, s.log, "storing message", func(ctx context.Context, log *zap.Logger, span tracing.Span) error { tracing.SpanResource(span, "store") tracing.SpanType(span, "db") err := s.insertMessage(env, s.now().UnixNano()) diff --git a/pkg/testing/node.go b/pkg/testing/node.go index 9687d5f6..1d4e3c2e 100644 --- a/pkg/testing/node.go +++ b/pkg/testing/node.go @@ -69,7 +69,9 @@ func NewNode(t *testing.T, opts ...wakunode.WakuNodeOption) (*wakunode.WakuNode, hostAddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0") prvKey := NewPrivateKey(t) ctx := context.Background() + log := NewLog(t) opts = append([]wakunode.WakuNodeOption{ + wakunode.WithLogger(log), wakunode.WithPrivateKey(prvKey), wakunode.WithHostAddress(hostAddr), wakunode.WithWakuRelay(), diff --git a/pkg/tracing/tracing.go b/pkg/tracing/tracing.go index 88189e33..78d5ca3a 100644 --- a/pkg/tracing/tracing.go +++ b/pkg/tracing/tracing.go @@ -7,7 +7,6 @@ import ( "os" "sync" - "github.com/xmtp/xmtp-node-go/pkg/logging" "go.uber.org/zap" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" @@ -52,12 +51,11 @@ func Stop() { // Wrap executes action in the context of a span. // Tags the span with the error if action returns one. -func Wrap(ctx context.Context, operation string, action func(context.Context, Span) error) error { +func Wrap(ctx context.Context, log *zap.Logger, operation string, action func(context.Context, *zap.Logger, Span) error) error { span, ctx := tracer.StartSpanFromContext(ctx, operation) defer span.Finish() - log := logging.From(ctx).With(zap.String("span", operation)) - ctx = logging.With(ctx, Link(span, log)) - err := action(ctx, span) + log = Link(span, log.With(zap.String("span", operation))) + err := action(ctx, log, span) if err != nil { span.Finish(WithError(err)) } diff --git a/pkg/tracing/tracing_test.go b/pkg/tracing/tracing_test.go index 09ae1a21..60aa8939 100644 --- a/pkg/tracing/tracing_test.go +++ b/pkg/tracing/tracing_test.go @@ -13,16 +13,28 @@ func Test_GoPanicWrap_WaitGroup(t *testing.T) { var wg sync.WaitGroup ctx, cancel := context.WithCancel(context.Background()) finished := false + var finishedLock sync.RWMutex GoPanicWrap(ctx, &wg, "test", func(ctx context.Context) { <-ctx.Done() + finishedLock.Lock() + defer finishedLock.Unlock() finished = true }) done := false + var doneLock sync.RWMutex go func() { wg.Wait() + doneLock.Lock() + defer doneLock.Unlock() done = true }() go func() { time.Sleep(time.Millisecond); cancel() }() - assert.Eventually(t, func() bool { return finished && done }, time.Second, 10*time.Millisecond) + assert.Eventually(t, func() bool { + finishedLock.RLock() + defer finishedLock.RUnlock() + doneLock.RLock() + defer doneLock.RUnlock() + return finished && done + }, time.Second, 10*time.Millisecond) }