diff --git a/core/node/rpc/archiver_test.go b/core/node/rpc/archiver_test.go index c33d5ff8f..bdf42ae80 100644 --- a/core/node/rpc/archiver_test.go +++ b/core/node/rpc/archiver_test.go @@ -334,7 +334,7 @@ func TestArchiveOneStream(t *testing.T) { } func makeTestServerOpts(tester *serviceTester) *ServerStartOpts { - listener, _ := makeTestListener(tester.t) + listener, _ := tester.makeTestListener() return &ServerStartOpts{ RiverChain: tester.btc.NewWalletAndBlockchain(tester.ctx), Listener: listener, diff --git a/core/node/rpc/notification_test.go b/core/node/rpc/notification_test.go index 6324d80f7..47716d4e5 100644 --- a/core/node/rpc/notification_test.go +++ b/core/node/rpc/notification_test.go @@ -18,6 +18,9 @@ import ( "github.com/ethereum/go-ethereum/common" eth_crypto "github.com/ethereum/go-ethereum/crypto" "github.com/google/go-cmp/cmp" + payload2 "github.com/sideshow/apns2/payload" + "github.com/stretchr/testify/require" + "github.com/river-build/river/core/node/crypto" "github.com/river-build/river/core/node/events" "github.com/river-build/river/core/node/notifications/push" @@ -27,16 +30,13 @@ import ( . "github.com/river-build/river/core/node/shared" "github.com/river-build/river/core/node/testutils" "github.com/river-build/river/core/node/testutils/testcert" - payload2 "github.com/sideshow/apns2/payload" - "github.com/stretchr/testify/require" ) // TestSubscriptionExpired ensures that web/apn subscriptions for which the notification API // returns 410 - Gone /expired are automatically purged. func TestSubscriptionExpired(t *testing.T) { tester := newServiceTester(t, serviceTesterOpts{numNodes: 1, start: true}) - ctx, cancel := context.WithCancel(tester.ctx) - defer cancel() + ctx := tester.ctx var notifications notificationExpired @@ -49,7 +49,8 @@ func TestSubscriptionExpired(t *testing.T) { authClient := protocolconnect.NewAuthenticationServiceClient( httpClient, "https://"+notificationService.listener.Addr().String()) - t.Run("webpush", func(t *testing.T) { + tester.parallelSubtest("webpush", func(tester *serviceTester) { + ctx := tester.ctx test := setupDMNotificationTest(ctx, tester, notificationClient, authClient) test.subscribeWebPush(ctx, test.initiator) test.subscribeWebPush(ctx, test.member) @@ -69,7 +70,9 @@ func TestSubscriptionExpired(t *testing.T) { }, 15*time.Second, 100*time.Millisecond, "webpush subscription not deleted") }) - t.Run("APN", func(t *testing.T) { + tester.parallelSubtest("APN", func(tester *serviceTester) { + ctx := tester.ctx + test := setupDMNotificationTest(ctx, tester, notificationClient, authClient) test.subscribeApnPush(ctx, test.initiator) test.subscribeApnPush(ctx, test.member) @@ -94,9 +97,7 @@ func TestSubscriptionExpired(t *testing.T) { // and share the same set of nodes, notification service and client. func TestNotifications(t *testing.T) { tester := newServiceTester(t, serviceTesterOpts{numNodes: 1, start: true}) - ctx, cancel := context.WithCancel(tester.ctx) - defer cancel() - + ctx := tester.ctx notifications := ¬ificationCapture{ WebPushNotifications: make(map[common.Hash]map[common.Address]int), ApnPushNotifications: make(map[common.Hash]map[common.Address]int), @@ -112,28 +113,27 @@ func TestNotifications(t *testing.T) { authClient := protocolconnect.NewAuthenticationServiceClient( httpClient, "https://"+notificationService.listener.Addr().String()) - t.Run("DMNotifications", func(t *testing.T) { - testDMNotifications(t, ctx, tester, notificationClient, authClient, notifications) + tester.parallelSubtest("DMNotifications", func(tester *serviceTester) { + testDMNotifications(tester, notificationClient, authClient, notifications) }) - t.Run("GDMNotifications", func(t *testing.T) { - testGDMNotifications(t, ctx, tester, notificationClient, authClient, notifications) + tester.parallelSubtest("GDMNotifications", func(tester *serviceTester) { + testGDMNotifications(tester, notificationClient, authClient, notifications) }) - t.Run("SpaceChannelNotification", func(t *testing.T) { - SpaceChannelNotification(t, ctx, tester, notificationClient, authClient, notifications) + tester.parallelSubtest("SpaceChannelNotification", func(tester *serviceTester) { + testSpaceChannelNotifications(tester, notificationClient, authClient, notifications) }) } func testGDMNotifications( - t *testing.T, - ctx context.Context, tester *serviceTester, notificationClient protocolconnect.NotificationServiceClient, authClient protocolconnect.AuthenticationServiceClient, notifications *notificationCapture, ) { - t.Run("MessageWithNoMentionsRepliesAndReaction", func(t *testing.T) { + tester.sequentialSubtest("MessageWithNoMentionsRepliesAndReaction", func(tester *serviceTester) { + ctx := tester.ctx test := setupGDMNotificationTest(ctx, tester, notificationClient, authClient) testGDMMessageWithNoMentionsRepliesAndReaction(ctx, test, notifications) }) @@ -216,33 +216,35 @@ func testGDMMessageWithNoMentionsRepliesAndReaction( return !cmp.Equal(nc.WebPushNotifications[eventHash], expectedUsersToReceiveNotification) || !cmp.Equal(nc.ApnPushNotifications[eventHash], expectedUsersToReceiveNotification) - }, 5*time.Second, 100*time.Millisecond, "Received unexpected notifications") + }, time.Second, 100*time.Millisecond, "Received unexpected notifications") } func testDMNotifications( - t *testing.T, - ctx context.Context, tester *serviceTester, notificationClient protocolconnect.NotificationServiceClient, authClient protocolconnect.AuthenticationServiceClient, notifications *notificationCapture, ) { - t.Run("MessageWithDefaultUserNotificationsPreferences", func(t *testing.T) { + tester.sequentialSubtest("MessageWithDefaultUserNotificationsPreferences", func(tester *serviceTester) { + ctx := tester.ctx test := setupDMNotificationTest(ctx, tester, notificationClient, authClient) testDMMessageWithDefaultUserNotificationsPreferences(ctx, test, notifications) }) - t.Run("DMMessageWithNotificationsMutedOnDmChannel", func(t *testing.T) { + tester.sequentialSubtest("DMMessageWithNotificationsMutedOnDmChannel", func(tester *serviceTester) { + ctx := tester.ctx test := setupDMNotificationTest(ctx, tester, notificationClient, authClient) testDMMessageWithNotificationsMutedOnDmChannel(ctx, test, notifications) }) - t.Run("DMMessageWithNotificationsMutedGlobal", func(t *testing.T) { + tester.sequentialSubtest("DMMessageWithNotificationsMutedGlobal", func(tester *serviceTester) { + ctx := tester.ctx test := setupDMNotificationTest(ctx, tester, notificationClient, authClient) testDMMessageWithNotificationsMutedGlobal(ctx, test, notifications) }) - t.Run("MessageWithBlockedUser", func(t *testing.T) { + tester.sequentialSubtest("MessageWithBlockedUser", func(tester *serviceTester) { + ctx := tester.ctx test := setupDMNotificationTest(ctx, tester, notificationClient, authClient) testDMMessageWithBlockedUser(ctx, test, notifications) }) @@ -280,7 +282,7 @@ func testDMMessageWithNotificationsMutedOnDmChannel( nc.ApnPushNotificationsMu.Unlock() return webCount != expectedNotifications || apnCount != expectedNotifications - }, 5*time.Second, 100*time.Millisecond, "Received unexpected notifications") + }, time.Second, 100*time.Millisecond, "Received unexpected notifications") } func testDMMessageWithNotificationsMutedGlobal( @@ -314,7 +316,7 @@ func testDMMessageWithNotificationsMutedGlobal( nc.ApnPushNotificationsMu.Unlock() return webCount != expectedUsersToReceiveNotification || apnCount != expectedUsersToReceiveNotification - }, 5*time.Second, 100*time.Millisecond, "Received unexpected notifications") + }, time.Second, 100*time.Millisecond, "Received unexpected notifications") } func testDMMessageWithDefaultUserNotificationsPreferences( @@ -363,7 +365,7 @@ func testDMMessageWithDefaultUserNotificationsPreferences( return webCount != len(expectedUsersToReceiveNotification) || apnCount != len(expectedUsersToReceiveNotification) - }, 5*time.Second, 100*time.Millisecond, "Received unexpected notifications") + }, time.Second, 100*time.Millisecond, "Received unexpected notifications") } func testDMMessageWithBlockedUser( @@ -403,33 +405,35 @@ func testDMMessageWithBlockedUser( nc.ApnPushNotificationsMu.Unlock() return webCount != expectedNotifications || apnCount != expectedNotifications - }, 10*time.Second, 100*time.Millisecond, "Received unexpected notifications") + }, time.Second, 100*time.Millisecond, "Received unexpected notifications") } -func SpaceChannelNotification( - t *testing.T, - ctx context.Context, +func testSpaceChannelNotifications( tester *serviceTester, notificationClient protocolconnect.NotificationServiceClient, authClient protocolconnect.AuthenticationServiceClient, notifications *notificationCapture, ) { - t.Run("TestPlainMessage", func(t *testing.T) { + tester.sequentialSubtest("TestPlainMessage", func(tester *serviceTester) { + ctx := tester.ctx test := setupSpaceChannelNotificationTest(ctx, tester, notificationClient, authClient) testSpaceChannelPlainMessage(ctx, test, notifications) }) - t.Run("TestAtChannelTag", func(t *testing.T) { + tester.sequentialSubtest("TestAtChannelTag", func(tester *serviceTester) { + ctx := tester.ctx test := setupSpaceChannelNotificationTest(ctx, tester, notificationClient, authClient) testSpaceChannelAtChannelTag(ctx, test, notifications) }) - t.Run("TestMentionsTag", func(t *testing.T) { + tester.sequentialSubtest("TestMentionsTag", func(tester *serviceTester) { + ctx := tester.ctx test := setupSpaceChannelNotificationTest(ctx, tester, notificationClient, authClient) testSpaceChannelMentionTag(ctx, test, notifications) }) - t.Run("Settings", func(t *testing.T) { + tester.sequentialSubtest("Settings", func(tester *serviceTester) { + ctx := tester.ctx test := setupSpaceChannelNotificationTest(ctx, tester, notificationClient, authClient) spaceChannelSettings(ctx, test) }) @@ -495,7 +499,7 @@ func testSpaceChannelPlainMessage( return webCount != len(expectedUsersToReceiveNotification) || apnCount != len(expectedUsersToReceiveNotification) - }, 5*time.Second, 100*time.Millisecond, "Received unexpected notifications") + }, time.Second, 100*time.Millisecond, "Received unexpected notifications") } func testSpaceChannelAtChannelTag( @@ -574,7 +578,7 @@ func testSpaceChannelAtChannelTag( return webCount != len(expectedUsersToReceiveNotification) || apnCount != len(expectedUsersToReceiveNotification) - }, 5*time.Second, 100*time.Millisecond, "Received unexpected notifications") + }, time.Second, 100*time.Millisecond, "Received unexpected notifications") } func testSpaceChannelMentionTag( @@ -655,7 +659,7 @@ func testSpaceChannelMentionTag( return webCount != len(expectedUsersToReceiveNotification) || apnCount != len(expectedUsersToReceiveNotification) - }, 5*time.Second, 100*time.Millisecond, "Received too unexpected notifications") + }, time.Second, 100*time.Millisecond, "Received too unexpected notifications") } func initNotificationService( diff --git a/core/node/rpc/shutdown_test.go b/core/node/rpc/shutdown_test.go index 25690d5af..b4c17752c 100644 --- a/core/node/rpc/shutdown_test.go +++ b/core/node/rpc/shutdown_test.go @@ -22,7 +22,7 @@ func TestShutdown(t *testing.T) { exitStatus <- firstExit }() - listener, _ := makeTestListener(tester.t) + listener, _ := tester.makeTestListener() // Start the second node with same address require.NoError(tester.startSingle(0, startOpts{listeners: []net.Listener{listener}})) diff --git a/core/node/rpc/tester_test.go b/core/node/rpc/tester_test.go index dc932d3a4..962c971b6 100644 --- a/core/node/rpc/tester_test.go +++ b/core/node/rpc/tester_test.go @@ -3,6 +3,7 @@ package rpc import ( "context" "crypto/tls" + "fmt" "hash/fnv" "io" "log" @@ -69,14 +70,19 @@ type serviceTesterOpts struct { start bool } -func makeTestListener(t *testing.T) (net.Listener, string) { +func makeTestListenerNoCleanup(t *testing.T) (net.Listener, string) { listener, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) listener = tls.NewListener(listener, testcert.GetHttp2LocalhostTLSConfig()) - t.Cleanup(func() { _ = listener.Close() }) return listener, "https://" + listener.Addr().String() } +func makeTestListener(t *testing.T) (net.Listener, string) { + l, url := makeTestListenerNoCleanup(t) + t.Cleanup(func() { _ = l.Close() }) + return l, url +} + func newServiceTester(t *testing.T, opts serviceTesterOpts) *serviceTester { t.Parallel() @@ -89,8 +95,6 @@ func newServiceTester(t *testing.T, opts serviceTesterOpts) *serviceTester { } ctx, ctxCancel := test.NewTestContext() - t.Cleanup(ctxCancel) - require := require.New(t) st := &serviceTester{ @@ -103,17 +107,20 @@ func newServiceTester(t *testing.T, opts serviceTesterOpts) *serviceTester { opts: opts, } + // Cleanup context on test completion even if no other cleanups are registered. + st.cleanup(func() {}) + btc, err := crypto.NewBlockchainTestContext( st.ctx, crypto.TestParams{NumKeys: opts.numNodes, MineOnTx: true, AutoMine: true}, ) require.NoError(err) st.btc = btc - t.Cleanup(st.btc.Close) + st.cleanup(st.btc.Close) for i := 0; i < opts.numNodes; i++ { st.nodes[i] = &testNodeRecord{} - st.nodes[i].listener, st.nodes[i].url = makeTestListener(t) + st.nodes[i].listener, st.nodes[i].url = st.makeTestListener() } st.startAutoMining() @@ -133,7 +140,61 @@ func newServiceTester(t *testing.T, opts serviceTesterOpts) *serviceTester { return st } -func (st serviceTester) CloseNode(i int) { +// Returns a new serviceTester instance for a makeSubtest. +// +// The new instance shares nodes with the parent instance, +// if parallel tests are run, node restarts or other changes should not be performed. +func (st *serviceTester) makeSubtest(t *testing.T) *serviceTester { + var sub serviceTester = *st + sub.t = t + sub.ctx, sub.ctxCancel = context.WithCancel(st.ctx) + sub.require = require.New(t) + + // Cleanup context on subtest completion even if no other cleanups are registered. + sub.cleanup(func() {}) + + return &sub +} + +func (st *serviceTester) parallelSubtest(name string, test func(*serviceTester)) { + st.t.Run(name, func(t *testing.T) { + t.Parallel() + test(st.makeSubtest(t)) + }) +} + +func (st *serviceTester) sequentialSubtest(name string, test func(*serviceTester)) { + st.t.Run(name, func(t *testing.T) { + test(st.makeSubtest(t)) + }) +} + +func (st *serviceTester) cleanup(f any) { + st.t.Cleanup(func() { + st.t.Helper() + // On first cleanup call cancel context for the current test, so relevant shutdowns are started. + if st.ctxCancel != nil { + st.ctxCancel() + st.ctxCancel = nil + } + switch f := f.(type) { + case func(): + f() + case func() error: + _ = f() + default: + panic(fmt.Sprintf("unsupported cleanup type: %T", f)) + } + }) +} + +func (st *serviceTester) makeTestListener() (net.Listener, string) { + l, url := makeTestListenerNoCleanup(st.t) + st.cleanup(l.Close) + return l, url +} + +func (st *serviceTester) CloseNode(i int) { if st.nodes[i] != nil { st.nodes[i].Close(st.ctx, st.dbUrl) } @@ -278,13 +339,7 @@ func (st *serviceTester) startSingle(i int, opts ...startOpts) error { var nodeRecord testNodeRecord = *st.nodes[i] - st.t.Cleanup(func() { - // Cancel context here: t.Cleanup calls functions in reverse order, - // but it's better to cancel context first. - // Since it's ok to cancel context multiple times, it's safe to cancel it here. - st.ctxCancel() - nodeRecord.Close(st.ctx, st.dbUrl) - }) + st.cleanup(func() { nodeRecord.Close(st.ctx, st.dbUrl) }) return nil }