diff --git a/internal/api/rust/dynamic_slice_test.go b/internal/api/rust/dynamic_slice_test.go new file mode 100644 index 0000000..1fc5b76 --- /dev/null +++ b/internal/api/rust/dynamic_slice_test.go @@ -0,0 +1,45 @@ +package rust + +import ( + "slices" + "testing" +) + +func mustEqual(t *testing.T, got, want []int, msg string) { + t.Helper() + if !slices.Equal(got, want) { + t.Errorf("%s, got %v want %v", msg, got, want) + } +} + +func TestDynamicSlice(t *testing.T) { + var slice DynamicSlice[int] + slice.Append(5, 6, 7) + mustEqual(t, slice.Slice, []int{5, 6, 7}, "Append") + slice.Insert(0, 42) + mustEqual(t, slice.Slice, []int{42, 5, 6, 7}, "Insert") + slice.Insert(2, 43) + mustEqual(t, slice.Slice, []int{42, 5, 43, 6, 7}, "Insert") + slice.Insert(5, 44) + mustEqual(t, slice.Slice, []int{42, 5, 43, 6, 7, 44}, "Insert") + slice.PopBack() + mustEqual(t, slice.Slice, []int{42, 5, 43, 6, 7}, "PopBack") + slice.PopFront() + mustEqual(t, slice.Slice, []int{5, 43, 6, 7}, "PopFront") + slice.Remove(1) + mustEqual(t, slice.Slice, []int{5, 6, 7}, "Remove") + slice.Set(1, 77) + mustEqual(t, slice.Slice, []int{5, 77, 7}, "Set") + slice.Truncate(2) + mustEqual(t, slice.Slice, []int{5, 77}, "Truncate") + slice.Reset([]int{2, 3, 5, 7, 11}) + mustEqual(t, slice.Slice, []int{2, 3, 5, 7, 11}, "Reset") + slice.Append(13) + mustEqual(t, slice.Slice, []int{2, 3, 5, 7, 11, 13}, "Append") + slice.PushBack(17) + mustEqual(t, slice.Slice, []int{2, 3, 5, 7, 11, 13, 17}, "PushBack") + slice.PushFront(1) + mustEqual(t, slice.Slice, []int{1, 2, 3, 5, 7, 11, 13, 17}, "PushFront") + slice.Clear() + mustEqual(t, slice.Slice, []int{}, "Clear") +} diff --git a/internal/api/rust/generic_state_listener.go b/internal/api/rust/generic_state_listener.go index 79d9c87..49b7cf6 100644 --- a/internal/api/rust/generic_state_listener.go +++ b/internal/api/rust/generic_state_listener.go @@ -1,12 +1,14 @@ package rust +import "sync/atomic" + // This is a recurring pattern in the ffi bindings, where a param is an interface // which has a single OnUpdate(T) method. Rather than having many specific impls, // just have a generic one which can be listened to via a channel. Using a channel // means we can time out more easily using select {}. type genericStateListener[T any] struct { ch chan T - isClosed bool + isClosed atomic.Bool } func newGenericStateListener[T any]() *genericStateListener[T] { @@ -16,12 +18,13 @@ func newGenericStateListener[T any]() *genericStateListener[T] { } func (l *genericStateListener[T]) Close() { - l.isClosed = true - close(l.ch) + if l.isClosed.CompareAndSwap(false, true) { + close(l.ch) + } } func (l *genericStateListener[T]) OnUpdate(state T) { - if l.isClosed { + if l.isClosed.Load() { return } l.ch <- state diff --git a/internal/api/rust/generic_state_listener_test.go b/internal/api/rust/generic_state_listener_test.go new file mode 100644 index 0000000..8ffc8c0 --- /dev/null +++ b/internal/api/rust/generic_state_listener_test.go @@ -0,0 +1,36 @@ +package rust + +import ( + "testing" + "time" + + "github.com/matrix-org/complement/must" +) + +func receiveFromChannel(t *testing.T, ch <-chan string) string { + t.Helper() + select { + case val := <-ch: + return val + case <-time.After(time.Second): + t.Fatalf("failed to receive from channel") + } + return "" +} + +func TestGenericStateListener(t *testing.T) { + l := newGenericStateListener[string]() + go l.OnUpdate("foo") + must.Equal(t, receiveFromChannel(t, l.ch), "foo", "OnUpdate") + go l.OnUpdate("bar") + must.Equal(t, receiveFromChannel(t, l.ch), "bar", "OnUpdate") + + // can close and then no more updates get sent + l.Close() + l.OnUpdate("baz") // this should not block due to not sending on the channel + must.Equal(t, receiveFromChannel(t, l.ch), "", "Closed") // recv on a closed channel is the zero value + + // can close repeatedly without panicking + l.Close() + l.Close() +} diff --git a/internal/api/rust/room_listener_test.go b/internal/api/rust/room_listener_test.go new file mode 100644 index 0000000..726da8f --- /dev/null +++ b/internal/api/rust/room_listener_test.go @@ -0,0 +1,64 @@ +package rust + +import ( + "testing" + + "github.com/matrix-org/complement/must" +) + +func TestRoomListener(t *testing.T) { + rl := NewRoomsListener() + + // basic functionality + recv := make(chan string, 2) + cancel := rl.AddListener(func(broadcastRoomID string) (cancel bool) { + recv <- broadcastRoomID + return false + }) + rl.BroadcastUpdateForRoom("foo") + must.Equal(t, <-recv, "foo", "basic usage") + + // multiple broadcasts + rl.BroadcastUpdateForRoom("bar") + rl.BroadcastUpdateForRoom("baz") + must.Equal(t, <-recv, "bar", "multiple broadcasts") + must.Equal(t, <-recv, "baz", "multiple broadcasts") + + // multiple listeners + recv2 := make(chan string, 2) + shouldCancel := false + cancel2 := rl.AddListener(func(broadcastRoomID string) (cancel bool) { + recv2 <- broadcastRoomID + return shouldCancel + }) + rl.BroadcastUpdateForRoom("ping") + must.Equal(t, <-recv, "ping", "multiple listeners") + must.Equal(t, <-recv2, "ping", "multiple listeners") + + // once cancelled, no more data + cancel() + rl.BroadcastUpdateForRoom("quuz") + select { + case <-recv: + t.Fatalf("received room id after cancel()") + default: + // we expect to hit this + } + // but the 2nd listener gets it + must.Equal(t, <-recv2, "quuz", "uncancelled listener") + + // returning true from the listener automatically cancels + shouldCancel = true + rl.BroadcastUpdateForRoom("final message") + must.Equal(t, <-recv2, "final message", "cancel bool") + rl.BroadcastUpdateForRoom("no one is listening") + select { + case <-recv2: + t.Fatalf("received room id after returning true") + default: + // we expect to hit this + } + + // calling the cancel() function in addition to returning true no-ops + cancel2() +} diff --git a/internal/api/rust/rust.go b/internal/api/rust/rust.go index 2008cc1..f4e0f0d 100644 --- a/internal/api/rust/rust.go +++ b/internal/api/rust/rust.go @@ -170,7 +170,7 @@ func (c *RustClient) StartSyncing(t ct.TestLike) (stopSyncing func(), err error) allRoomsListener := newGenericStateListener[[]matrix_sdk_ffi.RoomListEntriesUpdate]() go func() { var allRoomIds DynamicSlice[matrix_sdk_ffi.RoomListEntry] - for !allRoomsListener.isClosed { + for !allRoomsListener.isClosed.Load() { updates := <-allRoomsListener.ch var newEntries []matrix_sdk_ffi.RoomListEntry for _, update := range updates { @@ -268,7 +268,7 @@ func (c *RustClient) MustBackupKeys(t ct.TestLike) (recoveryKey string) { var listener matrix_sdk_ffi.EnableRecoveryProgressListener = genericListener recoveryKey, err := c.FFIClient.Encryption().EnableRecovery(true, listener) must.NotError(t, "Encryption.EnableRecovery", err) - for !genericListener.isClosed { + for !genericListener.isClosed.Load() { select { case s := <-genericListener.ch: switch x := s.(type) {