diff --git a/internal/api/client.go b/internal/api/client.go index 0f90013..bdb3ced 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -2,7 +2,6 @@ package api import ( "fmt" - "testing" "time" "github.com/matrix-org/complement/client" @@ -27,34 +26,35 @@ type Client interface { // Specifically, we need to shut off existing browsers and any FFI bindings. // If we get callbacks/events after this point, tests may panic if the callbacks // log messages. - Close(t *testing.T) - - Login(t *testing.T, opts ClientCreationOpts) error + Close(t Test) + // Remove any persistent storage, if it was enabled. + DeletePersistentStorage(t Test) + Login(t Test, opts ClientCreationOpts) error // StartSyncing to begin syncing from sync v2 / sliding sync. // Tests should call stopSyncing() at the end of the test. // MUST BLOCK until the initial sync is complete. - StartSyncing(t *testing.T) (stopSyncing func()) + StartSyncing(t Test) (stopSyncing func()) // IsRoomEncrypted returns true if the room is encrypted. May return an error e.g if you // provide a bogus room ID. - IsRoomEncrypted(t *testing.T, roomID string) (bool, error) + IsRoomEncrypted(t Test, roomID string) (bool, error) // SendMessage sends the given text as an m.room.message with msgtype:m.text into the given // room. Returns the event ID of the sent event, so MUST BLOCK until the event has been sent. - SendMessage(t *testing.T, roomID, text string) (eventID string) + SendMessage(t Test, roomID, text string) (eventID string) // TrySendMessage tries to send the message, but can fail. - TrySendMessage(t *testing.T, roomID, text string) (eventID string, err error) + TrySendMessage(t Test, roomID, text string) (eventID string, err error) // Wait until an event with the given body is seen. Not all impls expose event IDs // hence needing to use body as a proxy. - WaitUntilEventInRoom(t *testing.T, roomID string, checker func(e Event) bool) Waiter + WaitUntilEventInRoom(t Test, roomID string, checker func(e Event) bool) Waiter // Backpaginate in this room by `count` events. - MustBackpaginate(t *testing.T, roomID string, count int) + MustBackpaginate(t Test, roomID string, count int) // MustGetEvent will return the client's view of this event, or fail the test if the event cannot be found. - MustGetEvent(t *testing.T, roomID, eventID string) Event + MustGetEvent(t Test, roomID, eventID string) Event // MustBackupKeys will backup E2EE keys, else fail the test. - MustBackupKeys(t *testing.T) (recoveryKey string) + MustBackupKeys(t Test) (recoveryKey string) // MustLoadBackup will recover E2EE keys from the latest backup, else fail the test. - MustLoadBackup(t *testing.T, recoveryKey string) + MustLoadBackup(t Test, recoveryKey string) // Log something to stdout and the underlying client log file - Logf(t *testing.T, format string, args ...interface{}) + Logf(t Test, format string, args ...interface{}) // The user for this client UserID() string Type() ClientTypeLang @@ -64,13 +64,19 @@ type LoggedClient struct { Client } -func (c *LoggedClient) Close(t *testing.T) { +func (c *LoggedClient) Login(t Test, opts ClientCreationOpts) error { + t.Helper() + c.Logf(t, "%s Login %+v", c.logPrefix(), opts) + return c.Client.Login(t, opts) +} + +func (c *LoggedClient) Close(t Test) { t.Helper() c.Logf(t, "%s Close", c.logPrefix()) c.Client.Close(t) } -func (c *LoggedClient) StartSyncing(t *testing.T) (stopSyncing func()) { +func (c *LoggedClient) StartSyncing(t Test) (stopSyncing func()) { t.Helper() c.Logf(t, "%s StartSyncing starting to sync", c.logPrefix()) stopSyncing = c.Client.StartSyncing(t) @@ -78,13 +84,13 @@ func (c *LoggedClient) StartSyncing(t *testing.T) (stopSyncing func()) { return } -func (c *LoggedClient) IsRoomEncrypted(t *testing.T, roomID string) (bool, error) { +func (c *LoggedClient) IsRoomEncrypted(t Test, roomID string) (bool, error) { t.Helper() c.Logf(t, "%s IsRoomEncrypted %s", c.logPrefix(), roomID) return c.Client.IsRoomEncrypted(t, roomID) } -func (c *LoggedClient) TrySendMessage(t *testing.T, roomID, text string) (eventID string, err error) { +func (c *LoggedClient) TrySendMessage(t Test, roomID, text string) (eventID string, err error) { t.Helper() c.Logf(t, "%s TrySendMessage %s => %s", c.logPrefix(), roomID, text) eventID, err = c.Client.TrySendMessage(t, roomID, text) @@ -92,7 +98,7 @@ func (c *LoggedClient) TrySendMessage(t *testing.T, roomID, text string) (eventI return } -func (c *LoggedClient) SendMessage(t *testing.T, roomID, text string) (eventID string) { +func (c *LoggedClient) SendMessage(t Test, roomID, text string) (eventID string) { t.Helper() c.Logf(t, "%s SendMessage %s => %s", c.logPrefix(), roomID, text) eventID = c.Client.SendMessage(t, roomID, text) @@ -100,19 +106,19 @@ func (c *LoggedClient) SendMessage(t *testing.T, roomID, text string) (eventID s return } -func (c *LoggedClient) WaitUntilEventInRoom(t *testing.T, roomID string, checker func(e Event) bool) Waiter { +func (c *LoggedClient) WaitUntilEventInRoom(t Test, roomID string, checker func(e Event) bool) Waiter { t.Helper() c.Logf(t, "%s WaitUntilEventInRoom %s", c.logPrefix(), roomID) return c.Client.WaitUntilEventInRoom(t, roomID, checker) } -func (c *LoggedClient) MustBackpaginate(t *testing.T, roomID string, count int) { +func (c *LoggedClient) MustBackpaginate(t Test, roomID string, count int) { t.Helper() c.Logf(t, "%s MustBackpaginate %d %s", c.logPrefix(), count, roomID) c.Client.MustBackpaginate(t, roomID, count) } -func (c *LoggedClient) MustBackupKeys(t *testing.T) (recoveryKey string) { +func (c *LoggedClient) MustBackupKeys(t Test) (recoveryKey string) { t.Helper() c.Logf(t, "%s MustBackupKeys", c.logPrefix()) recoveryKey = c.Client.MustBackupKeys(t) @@ -120,12 +126,18 @@ func (c *LoggedClient) MustBackupKeys(t *testing.T) (recoveryKey string) { return recoveryKey } -func (c *LoggedClient) MustLoadBackup(t *testing.T, recoveryKey string) { +func (c *LoggedClient) MustLoadBackup(t Test, recoveryKey string) { t.Helper() c.Logf(t, "%s MustLoadBackup key=%s", c.logPrefix(), recoveryKey) c.Client.MustLoadBackup(t, recoveryKey) } +func (c *LoggedClient) DeletePersistentStorage(t Test) { + t.Helper() + c.Logf(t, "%s DeletePersistentStorage", c.logPrefix()) + c.Client.DeletePersistentStorage(t) +} + func (c *LoggedClient) logPrefix() string { return fmt.Sprintf("[%s](%s)", c.UserID(), c.Type()) } @@ -142,6 +154,9 @@ type ClientCreationOpts struct { // Required. The password for this account. Password string + // Optional. If true, persistent storage will be used for the same user|device ID. + PersistentStorage bool + // Optional. Set this to login with this device ID. DeviceID string } @@ -167,7 +182,7 @@ type Event struct { } type Waiter interface { - Wait(t *testing.T, s time.Duration) + Wait(t Test, s time.Duration) } func CheckEventHasBody(body string) func(e Event) bool { @@ -186,15 +201,42 @@ const ansiRedForeground = "\x1b[31m" const ansiResetForeground = "\x1b[39m" // Errorf is a wrapper around t.Errorf which prints the failing error message in red. -func Errorf(t *testing.T, format string, args ...any) { +func Errorf(t Test, format string, args ...any) { t.Helper() format = ansiRedForeground + format + ansiResetForeground t.Errorf(format, args...) } // Fatalf is a wrapper around t.Fatalf which prints the failing error message in red. -func Fatalf(t *testing.T, format string, args ...any) { +func Fatalf(t Test, format string, args ...any) { t.Helper() format = ansiRedForeground + format + ansiResetForeground t.Fatalf(format, args...) } + +type Test interface { + Logf(f string, args ...any) + Errorf(f string, args ...any) + Fatalf(f string, args ...any) + Helper() + Name() string +} + +// TODO move to must package when it accepts an interface + +// NotError will ensure `err` is nil else terminate the test with `msg`. +func MustNotError(t Test, msg string, err error) { + t.Helper() + if err != nil { + Fatalf(t, "must.NotError: %s -> %s", msg, err) + } +} + +// NotEqual ensures that got!=want else logs an error. +// The 'msg' is displayed with the error to provide extra context. +func MustNotEqual[V comparable](t Test, got, want V, msg string) { + t.Helper() + if got == want { + Errorf(t, "NotEqual %s: got '%v', want '%v'", msg, got, want) + } +} diff --git a/internal/api/js/chrome/chrome.go b/internal/api/js/chrome/chrome.go index d318ac0..cd284dd 100644 --- a/internal/api/js/chrome/chrome.go +++ b/internal/api/js/chrome/chrome.go @@ -13,7 +13,6 @@ import ( "net/http" "strconv" "sync" - "testing" "time" "github.com/chromedp/cdproto/runtime" @@ -35,7 +34,7 @@ type Void *runtime.RemoteObject // // result, err := RunAsyncFn[string](t, ctx, "return await getSomeString()") // void, err := RunAsyncFn[chrome.Void](t, ctx, "doSomething(); await doSomethingElse();") -func RunAsyncFn[T any](t *testing.T, ctx context.Context, js string) (*T, error) { +func RunAsyncFn[T any](t api.Test, ctx context.Context, js string) (*T, error) { t.Helper() out := new(T) err := chromedp.Run(ctx, @@ -54,7 +53,7 @@ func RunAsyncFn[T any](t *testing.T, ctx context.Context, js string) (*T, error) // Run an anonymous async iffe in the browser. Set the type parameter to a basic data type // which can be returned as JSON e.g string, map[string]any, []string. If you do not want // to return anything, use chrome.Void -func MustRunAsyncFn[T any](t *testing.T, ctx context.Context, js string) *T { +func MustRunAsyncFn[T any](t api.Test, ctx context.Context, js string) *T { t.Helper() result, err := RunAsyncFn[T](t, ctx, js) if err != nil { diff --git a/internal/api/js/js.go b/internal/api/js/js.go index 227717d..8e1bb45 100644 --- a/internal/api/js/js.go +++ b/internal/api/js/js.go @@ -6,12 +6,10 @@ import ( "os" "strings" "sync/atomic" - "testing" "time" "github.com/matrix-org/complement-crypto/internal/api" "github.com/matrix-org/complement-crypto/internal/api/js/chrome" - "github.com/matrix-org/complement/must" "github.com/tidwall/gjson" ) @@ -47,7 +45,7 @@ type JSClient struct { userID string } -func NewJSClient(t *testing.T, opts api.ClientCreationOpts) (api.Client, error) { +func NewJSClient(t api.Test, opts api.ClientCreationOpts) (api.Client, error) { jsc := &JSClient{ listeners: make(map[int32]func(roomID string, ev api.Event)), userID: opts.UserID, @@ -116,7 +114,7 @@ func NewJSClient(t *testing.T, opts api.ClientCreationOpts) (api.Client, error) return &api.LoggedClient{Client: jsc}, nil } -func (c *JSClient) Login(t *testing.T, opts api.ClientCreationOpts) error { +func (c *JSClient) Login(t api.Test, opts api.ClientCreationOpts) error { deviceID := "undefined" if opts.DeviceID != "" { deviceID = `"` + opts.DeviceID + `"` @@ -142,11 +140,16 @@ func (c *JSClient) Login(t *testing.T, opts api.ClientCreationOpts) error { return nil } +func (c *JSClient) DeletePersistentStorage(t api.Test) { + t.Helper() + // TODO +} + // Close is called to clean up resources. // Specifically, we need to shut off existing browsers and any FFI bindings. // If we get callbacks/events after this point, tests may panic if the callbacks // log messages. -func (c *JSClient) Close(t *testing.T) { +func (c *JSClient) Close(t api.Test) { c.browser.Cancel() c.listeners = make(map[int32]func(roomID string, ev api.Event)) } @@ -155,7 +158,7 @@ func (c *JSClient) UserID() string { return c.userID } -func (c *JSClient) MustGetEvent(t *testing.T, roomID, eventID string) api.Event { +func (c *JSClient) MustGetEvent(t api.Test, roomID, eventID string) api.Event { t.Helper() // serialised output (if encrypted): // { @@ -197,7 +200,7 @@ func (c *JSClient) MustGetEvent(t *testing.T, roomID, eventID string) api.Event // StartSyncing to begin syncing from sync v2 / sliding sync. // Tests should call stopSyncing() at the end of the test. -func (c *JSClient) StartSyncing(t *testing.T) (stopSyncing func()) { +func (c *JSClient) StartSyncing(t api.Test) (stopSyncing func()) { t.Helper() chrome.MustRunAsyncFn[chrome.Void](t, c.browser.Ctx, fmt.Sprintf(` var fn; @@ -234,7 +237,7 @@ func (c *JSClient) StartSyncing(t *testing.T) (stopSyncing func()) { // IsRoomEncrypted returns true if the room is encrypted. May return an error e.g if you // provide a bogus room ID. -func (c *JSClient) IsRoomEncrypted(t *testing.T, roomID string) (bool, error) { +func (c *JSClient) IsRoomEncrypted(t api.Test, roomID string) (bool, error) { t.Helper() isEncrypted, err := chrome.RunAsyncFn[bool]( t, c.browser.Ctx, fmt.Sprintf(`return window.__client.isRoomEncrypted("%s")`, roomID), @@ -247,14 +250,14 @@ func (c *JSClient) IsRoomEncrypted(t *testing.T, roomID string) (bool, error) { // SendMessage sends the given text as an m.room.message with msgtype:m.text into the given // room. -func (c *JSClient) SendMessage(t *testing.T, roomID, text string) (eventID string) { +func (c *JSClient) SendMessage(t api.Test, roomID, text string) (eventID string) { t.Helper() eventID, err := c.TrySendMessage(t, roomID, text) - must.NotError(t, "failed to sendMessage", err) + api.MustNotError(t, "failed to sendMessage", err) return eventID } -func (c *JSClient) TrySendMessage(t *testing.T, roomID, text string) (eventID string, err error) { +func (c *JSClient) TrySendMessage(t api.Test, roomID, text string) (eventID string, err error) { t.Helper() res, err := chrome.RunAsyncFn[map[string]interface{}](t, c.browser.Ctx, fmt.Sprintf(` return await window.__client.sendMessage("%s", { @@ -267,14 +270,14 @@ func (c *JSClient) TrySendMessage(t *testing.T, roomID, text string) (eventID st return (*res)["event_id"].(string), nil } -func (c *JSClient) MustBackpaginate(t *testing.T, roomID string, count int) { +func (c *JSClient) MustBackpaginate(t api.Test, roomID string, count int) { t.Helper() chrome.MustRunAsyncFn[chrome.Void](t, c.browser.Ctx, fmt.Sprintf( `await window.__client.scrollback(window.__client.getRoom("%s"), %d);`, roomID, count, )) } -func (c *JSClient) MustBackupKeys(t *testing.T) (recoveryKey string) { +func (c *JSClient) MustBackupKeys(t api.Test) (recoveryKey string) { t.Helper() key := chrome.MustRunAsyncFn[string](t, c.browser.Ctx, ` // we need to ensure that we have a recovery key first, though we don't actually care about it..? @@ -295,7 +298,7 @@ func (c *JSClient) MustBackupKeys(t *testing.T) (recoveryKey string) { return *key } -func (c *JSClient) MustLoadBackup(t *testing.T, recoveryKey string) { +func (c *JSClient) MustLoadBackup(t api.Test, recoveryKey string) { chrome.MustRunAsyncFn[chrome.Void](t, c.browser.Ctx, fmt.Sprintf(` // we assume the recovery key is the private key for the default key id so // figure out what that key id is. @@ -312,7 +315,7 @@ func (c *JSClient) MustLoadBackup(t *testing.T, recoveryKey string) { recoveryKey)) } -func (c *JSClient) WaitUntilEventInRoom(t *testing.T, roomID string, checker func(e api.Event) bool) api.Waiter { +func (c *JSClient) WaitUntilEventInRoom(t api.Test, roomID string, checker func(e api.Event) bool) api.Waiter { t.Helper() return &jsTimelineWaiter{ roomID: roomID, @@ -321,7 +324,7 @@ func (c *JSClient) WaitUntilEventInRoom(t *testing.T, roomID string, checker fun } } -func (c *JSClient) Logf(t *testing.T, format string, args ...interface{}) { +func (c *JSClient) Logf(t api.Test, format string, args ...interface{}) { t.Helper() formatted := fmt.Sprintf(t.Name()+": "+format, args...) chrome.MustRunAsyncFn[chrome.Void](t, c.browser.Ctx, fmt.Sprintf(`console.log("%s");`, formatted)) @@ -346,7 +349,7 @@ type jsTimelineWaiter struct { client *JSClient } -func (w *jsTimelineWaiter) Wait(t *testing.T, s time.Duration) { +func (w *jsTimelineWaiter) Wait(t api.Test, s time.Duration) { t.Helper() updates := make(chan bool, 3) cancel := w.client.listenForUpdates(func(roomID string, ev api.Event) { diff --git a/internal/api/rust/rust.go b/internal/api/rust/rust.go index 8e180c1..b1b2e63 100644 --- a/internal/api/rust/rust.go +++ b/internal/api/rust/rust.go @@ -2,14 +2,14 @@ package rust import ( "fmt" + "os" + "strings" "sync" "sync/atomic" - "testing" "time" "github.com/matrix-org/complement-crypto/internal/api" "github.com/matrix-org/complement-crypto/rust/matrix_sdk_ffi" - "github.com/matrix-org/complement/must" "golang.org/x/exp/slices" ) @@ -33,18 +33,25 @@ type RustRoomInfo struct { } type RustClient struct { - FFIClient *matrix_sdk_ffi.Client - listeners map[int32]func(roomID string) - listenerID atomic.Int32 - allRooms *matrix_sdk_ffi.RoomList - rooms map[string]*RustRoomInfo - roomsMu *sync.RWMutex - userID string + FFIClient *matrix_sdk_ffi.Client + listeners map[int32]func(roomID string) + listenerID atomic.Int32 + allRooms *matrix_sdk_ffi.RoomList + rooms map[string]*RustRoomInfo + roomsMu *sync.RWMutex + userID string + persistentStoragePath string } -func NewRustClient(t *testing.T, opts api.ClientCreationOpts, ssURL string) (api.Client, error) { +func NewRustClient(t api.Test, opts api.ClientCreationOpts, ssURL string) (api.Client, error) { t.Logf("NewRustClient[%s] creating...", opts.UserID) ab := matrix_sdk_ffi.NewClientBuilder().HomeserverUrl(opts.BaseURL).SlidingSyncProxy(&ssURL) + var username string + if opts.PersistentStorage { + // @alice:hs1, FOOBAR => alice_hs1_FOOBAR + username = strings.Replace(opts.UserID[1:], ":", "_", -1) + "_" + opts.DeviceID + ab.BasePath("rust_storage").Username(username) + } client, err := ab.Build() if err != nil { return nil, fmt.Errorf("ClientBuilder.Build failed: %s", err) @@ -56,11 +63,14 @@ func NewRustClient(t *testing.T, opts api.ClientCreationOpts, ssURL string) (api listeners: make(map[int32]func(roomID string)), roomsMu: &sync.RWMutex{}, } - c.Logf(t, "NewRustClient[%s] created client", opts.UserID) + if opts.PersistentStorage { + c.persistentStoragePath = "./rust_storage/" + username + } + c.Logf(t, "NewRustClient[%s] created client storage=%v", opts.UserID, c.persistentStoragePath) return &api.LoggedClient{Client: c}, nil } -func (c *RustClient) Login(t *testing.T, opts api.ClientCreationOpts) error { +func (c *RustClient) Login(t api.Test, opts api.ClientCreationOpts) error { var deviceID *string if opts.DeviceID != "" { deviceID = &opts.DeviceID @@ -69,10 +79,21 @@ func (c *RustClient) Login(t *testing.T, opts api.ClientCreationOpts) error { if err != nil { return fmt.Errorf("Client.Login failed: %s", err) } + c.FFIClient.Destroy() return nil } -func (c *RustClient) Close(t *testing.T) { +func (c *RustClient) DeletePersistentStorage(t api.Test) { + t.Helper() + if c.persistentStoragePath != "" { + err := os.RemoveAll(c.persistentStoragePath) + if err != nil { + api.Fatalf(t, "DeletePersistentStorage: %s", err) + } + } +} + +func (c *RustClient) Close(t api.Test) { t.Helper() c.roomsMu.Lock() for _, rri := range c.rooms { @@ -85,7 +106,7 @@ func (c *RustClient) Close(t *testing.T) { c.FFIClient.Destroy() } -func (c *RustClient) MustGetEvent(t *testing.T, roomID, eventID string) api.Event { +func (c *RustClient) MustGetEvent(t api.Test, roomID, eventID string) api.Event { t.Helper() room := c.findRoom(t, roomID) timelineItem, err := room.Timeline().GetEventTimelineItemByEventId(eventID) @@ -101,15 +122,15 @@ func (c *RustClient) MustGetEvent(t *testing.T, roomID, eventID string) api.Even // StartSyncing to begin syncing from sync v2 / sliding sync. // Tests should call stopSyncing() at the end of the test. -func (c *RustClient) StartSyncing(t *testing.T) (stopSyncing func()) { +func (c *RustClient) StartSyncing(t api.Test) (stopSyncing func()) { t.Helper() syncService, err := c.FFIClient.SyncService().Finish() - must.NotError(t, fmt.Sprintf("[%s]failed to make sync service", c.userID), err) + api.MustNotError(t, fmt.Sprintf("[%s]failed to make sync service", c.userID), err) roomList, err := syncService.RoomListService().AllRooms() - must.NotError(t, "failed to call SyncService.RoomListService.AllRooms", err) + api.MustNotError(t, "failed to call SyncService.RoomListService.AllRooms", err) genericListener := newGenericStateListener[matrix_sdk_ffi.RoomListLoadingState]() result, err := roomList.LoadingState(genericListener) - must.NotError(t, "failed to call RoomList.LoadingState", err) + api.MustNotError(t, "failed to call RoomList.LoadingState", err) go syncService.Start() c.allRooms = roomList @@ -140,7 +161,7 @@ func (c *RustClient) StartSyncing(t *testing.T) (stopSyncing func()) { // IsRoomEncrypted returns true if the room is encrypted. May return an error e.g if you // provide a bogus room ID. -func (c *RustClient) IsRoomEncrypted(t *testing.T, roomID string) (bool, error) { +func (c *RustClient) IsRoomEncrypted(t api.Test, roomID string) (bool, error) { t.Helper() r := c.findRoom(t, roomID) if r == nil { @@ -150,7 +171,7 @@ func (c *RustClient) IsRoomEncrypted(t *testing.T, roomID string) (bool, error) return r.IsEncrypted() } -func (c *RustClient) MustBackupKeys(t *testing.T) (recoveryKey string) { +func (c *RustClient) MustBackupKeys(t api.Test) (recoveryKey string) { t.Helper() genericListener := newGenericStateListener[matrix_sdk_ffi.EnableRecoveryProgress]() var listener matrix_sdk_ffi.EnableRecoveryProgressListener = genericListener @@ -168,16 +189,16 @@ func (c *RustClient) MustBackupKeys(t *testing.T) (recoveryKey string) { genericListener.Close() // break the loop } } - must.NotError(t, "Encryption.EnableRecovery", err) + api.MustNotError(t, "Encryption.EnableRecovery", err) return recoveryKey } -func (c *RustClient) MustLoadBackup(t *testing.T, recoveryKey string) { +func (c *RustClient) MustLoadBackup(t api.Test, recoveryKey string) { t.Helper() - must.NotError(t, "Recover", c.FFIClient.Encryption().Recover(recoveryKey)) + api.MustNotError(t, "Recover", c.FFIClient.Encryption().Recover(recoveryKey)) } -func (c *RustClient) WaitUntilEventInRoom(t *testing.T, roomID string, checker func(api.Event) bool) api.Waiter { +func (c *RustClient) WaitUntilEventInRoom(t api.Test, roomID string, checker func(api.Event) bool) api.Waiter { t.Helper() c.ensureListening(t, roomID) return &timelineWaiter{ @@ -193,7 +214,7 @@ func (c *RustClient) Type() api.ClientTypeLang { // SendMessage sends the given text as an m.room.message with msgtype:m.text into the given // room. Returns the event ID of the sent event. -func (c *RustClient) SendMessage(t *testing.T, roomID, text string) (eventID string) { +func (c *RustClient) SendMessage(t api.Test, roomID, text string) (eventID string) { t.Helper() eventID, err := c.TrySendMessage(t, roomID, text) if err != nil { @@ -202,7 +223,7 @@ func (c *RustClient) SendMessage(t *testing.T, roomID, text string) (eventID str return eventID } -func (c *RustClient) TrySendMessage(t *testing.T, roomID, text string) (eventID string, err error) { +func (c *RustClient) TrySendMessage(t api.Test, roomID, text string) (eventID string, err error) { t.Helper() ch := make(chan bool) // we need a timeline listener before we can send messages, AND that listener must be attached to the @@ -234,11 +255,11 @@ func (c *RustClient) TrySendMessage(t *testing.T, roomID, text string) (eventID } } -func (c *RustClient) MustBackpaginate(t *testing.T, roomID string, count int) { +func (c *RustClient) MustBackpaginate(t api.Test, roomID string, count int) { t.Helper() r := c.findRoom(t, roomID) - must.NotEqual(t, r, nil, "unknown room") - must.NotError(t, "failed to backpaginate", r.Timeline().PaginateBackwards(matrix_sdk_ffi.PaginationOptionsSimpleRequest{ + api.MustNotEqual(t, r, nil, "unknown room") + api.MustNotError(t, "failed to backpaginate", r.Timeline().PaginateBackwards(matrix_sdk_ffi.PaginationOptionsSimpleRequest{ EventLimit: uint16(count), })) } @@ -259,7 +280,7 @@ func (c *RustClient) findRoomInMap(roomID string) *matrix_sdk_ffi.Room { } // findRoom returns the room, waiting up to 5s for it to appear -func (c *RustClient) findRoom(t *testing.T, roomID string) *matrix_sdk_ffi.Room { +func (c *RustClient) findRoom(t api.Test, roomID string) *matrix_sdk_ffi.Room { t.Helper() room := c.findRoomInMap(roomID) if room != nil { @@ -300,20 +321,20 @@ func (c *RustClient) findRoom(t *testing.T, roomID string) *matrix_sdk_ffi.Room return nil } -func (c *RustClient) Logf(t *testing.T, format string, args ...interface{}) { +func (c *RustClient) Logf(t api.Test, format string, args ...interface{}) { t.Helper() c.logToFile(t, format, args...) t.Logf(format, args...) } -func (c *RustClient) logToFile(t *testing.T, format string, args ...interface{}) { +func (c *RustClient) logToFile(t api.Test, format string, args ...interface{}) { matrix_sdk_ffi.LogEvent("rust.go", &zero, matrix_sdk_ffi.LogLevelInfo, t.Name(), fmt.Sprintf(format, args...)) } -func (c *RustClient) ensureListening(t *testing.T, roomID string) *matrix_sdk_ffi.Room { +func (c *RustClient) ensureListening(t api.Test, roomID string) *matrix_sdk_ffi.Room { t.Helper() r := c.findRoom(t, roomID) - must.NotEqual(t, r, nil, fmt.Sprintf("room %s does not exist", roomID)) + api.MustNotEqual(t, r, nil, fmt.Sprintf("room %s does not exist", roomID)) info := c.rooms[roomID] if info.stream != nil { @@ -415,7 +436,7 @@ type timelineWaiter struct { client *RustClient } -func (w *timelineWaiter) Wait(t *testing.T, s time.Duration) { +func (w *timelineWaiter) Wait(t api.Test, s time.Duration) { t.Helper() checkForEvent := func() bool { diff --git a/tests/client_connectivity_test.go b/tests/client_connectivity_test.go index 1378b56..e2ce9fe 100644 --- a/tests/client_connectivity_test.go +++ b/tests/client_connectivity_test.go @@ -1,13 +1,171 @@ package tests import ( + "encoding/json" "net/http" + "os" + "os/exec" + "sync" + "sync/atomic" "testing" + "text/template" "time" "github.com/matrix-org/complement-crypto/internal/api" + "github.com/matrix-org/complement-crypto/internal/api/js" + "github.com/matrix-org/complement-crypto/internal/api/rust" + "github.com/matrix-org/complement/must" ) +// TODO: move to internal? or addons?! +type CallbackData struct { + Method string `json:"method"` + URL string `json:"url"` + AccessToken string `json:"access_token"` + ResponseCode int `json:"response_code"` +} + +// TODO: move internally +func RunGoProcess(t *testing.T, templateFilename string, templateData any) (*exec.Cmd, func()) { + tmpl, err := template.New(templateFilename).ParseFiles("./templates/" + templateFilename) + if err != nil { + api.Fatalf(t, "failed to parse template %s : %s", templateFilename, err) + } + scriptFile, err := os.CreateTemp("./templates", "script_*.go") // os.CreateTemp(".", "script_*.go") + if err != nil { + api.Fatalf(t, "failed to open temporary file: %s", err) + } + defer scriptFile.Close() + if err = tmpl.ExecuteTemplate(scriptFile, templateFilename, templateData); err != nil { + api.Fatalf(t, "failed to execute template to file: %s", err) + } + // TODO: should we build output to the random number? + // e.g go build -o ./templates/script ./templates/script_3523965439.go + cmd := exec.Command("go", "build", "-o", "./templates/script", scriptFile.Name()) + t.Logf(cmd.String()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Fatalf("Failed to build script %s: %s", scriptFile.Name(), err) + } + return exec.Command("./templates/script"), func() { + os.Remove(scriptFile.Name()) + os.Remove("./templates/script") + } +} + +// Test that if the client is restarted BEFORE getting the /keys/upload response but +// AFTER the server has processed the request, the keys are not regenerated (which would +// cause duplicate key IDs with different keys). Requires persistent storage. +func TestSigkillBeforeKeysUploadResponse(t *testing.T) { + for _, clientType := range []api.ClientType{{Lang: api.ClientTypeRust, HS: "hs1"}} { // {Lang: api.ClientTypeJS} + t.Run(string(clientType.Lang), func(t *testing.T) { + var mu sync.Mutex + var terminated atomic.Bool + var terminateClient func() + // TODO: factor out to helper + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + var data CallbackData + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Logf("error decoding json: %s", err) + w.WriteHeader(500) + return + } + t.Logf("%v %+v", time.Now(), data) + if terminated.Load() { + // make sure the 2nd upload 200 OKs + if data.ResponseCode != 200 { + // TODO: Errorf + t.Logf("2nd /keys/upload did not 200 OK => got %v", data.ResponseCode) + } + w.WriteHeader(200) + return // 2nd /keys/upload should go through + } + // destroy the client + mu.Lock() + terminateClient() + mu.Unlock() + w.WriteHeader(200) + }) + srv := http.Server{ + Addr: "127.0.0.1:6879", + Handler: mux, + } + defer srv.Close() + go srv.ListenAndServe() + + tc := CreateTestContext(t, clientType, clientType) + tc.Deployment.WithMITMOptions(t, map[string]interface{}{ + "callback": map[string]interface{}{ + "callback_url": "http://host.docker.internal:6879", + "filter": "~u .*\\/keys\\/upload.*", + }, + }, func() { + cfg := api.FromComplementClient(tc.Alice, "complement-crypto-password") + // run some code in a separate process so we can kill it later + cmd, close := RunGoProcess(t, "sigkill_before_keys_upload_response.go", + struct { + UserID string + DeviceID string + Password string + BaseURL string + SSURL string + }{ + UserID: cfg.UserID, + Password: cfg.Password, + DeviceID: cfg.DeviceID, + BaseURL: tc.Deployment.ReverseProxyURLForHS(clientType.HS), + SSURL: tc.Deployment.SlidingSyncURL(t), + }) + + defer close() + var wg sync.WaitGroup + wg.Add(1) + + terminateClient = func() { + terminated.Store(true) + t.Logf("got keys/upload: terminating process") + if err := cmd.Process.Kill(); err != nil { + t.Fatalf("failed to kill process: %s", err) + } + cmd.Wait() + t.Logf("terminated process") + wg.Done() + } + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Start() + wg.Wait() + t.Logf("terminated process, making new client") + // now make the same client + cfg.BaseURL = tc.Deployment.ReverseProxyURLForHS(clientType.HS) + cfg.PersistentStorage = true + alice := mustCreateClient(t, clientType, tc, cfg) + alice.Login(t, cfg) // login should work + alice.Close(t) + alice.DeletePersistentStorage(t) + }) + }) + } +} + +func mustCreateClient(t *testing.T, clientType api.ClientType, tc *TestContext, cfg api.ClientCreationOpts) api.Client { + switch clientType.Lang { + case api.ClientTypeRust: + client, err := rust.NewRustClient(t, cfg, tc.Deployment.SlidingSyncURL(t)) + must.NotError(t, "NewRustClient: %s", err) + return client + case api.ClientTypeJS: + client, err := js.NewJSClient(t, cfg) + must.NotError(t, "NewJSClient: %s", err) + return client + default: + t.Fatalf("unknown client type %v", clientType) + } + panic("unreachable") +} + // Test that if a client is unable to call /sendToDevice, it retries. func TestClientRetriesSendToDevice(t *testing.T) { ClientTypeMatrix(t, func(t *testing.T, clientTypeA, clientTypeB api.ClientType) { diff --git a/tests/templates/sigkill_before_keys_upload_response.go b/tests/templates/sigkill_before_keys_upload_response.go new file mode 100644 index 0000000..9f5bc0e --- /dev/null +++ b/tests/templates/sigkill_before_keys_upload_response.go @@ -0,0 +1,45 @@ +package main + +import ( + "fmt" + "os" + "time" + + "github.com/matrix-org/complement-crypto/internal/api" + "github.com/matrix-org/complement-crypto/internal/api/rust" +) + +type MockT struct{} + +func (t *MockT) Helper() {} +func (t *MockT) Logf(f string, args ...any) { + fmt.Printf(f, args...) +} +func (t *MockT) Errorf(f string, args ...any) { + fmt.Printf(f, args...) +} +func (t *MockT) Fatalf(f string, args ...any) { + fmt.Printf(f, args...) + os.Exit(1) +} +func (t *MockT) Name() string { return "inline_script" } + +func main() { + time.Sleep(time.Second) + t := &MockT{} + cfg := api.ClientCreationOpts{ + BaseURL: "{{.BaseURL}}", + UserID: "{{.UserID}}", + DeviceID: "{{.DeviceID}}", + Password: "{{.Password}}", + PersistentStorage: true, + } + client, err := rust.NewRustClient(t, cfg, "{{.SSURL}}") + if err != nil { + panic(err) + } + fmt.Println(time.Now(), "script about to login, expect /keys/upload") + client.Login(t, cfg) + fmt.Println("exiting.. you should not see this as it should have been sigkilled by now!") + +}