diff --git a/spec/unit/rust-crypto.spec.ts b/spec/unit/rust-crypto.spec.ts index fe98dedbb8a..822c08bf6c9 100644 --- a/spec/unit/rust-crypto.spec.ts +++ b/spec/unit/rust-crypto.spec.ts @@ -16,10 +16,21 @@ limitations under the License. import "fake-indexeddb/auto"; import { IDBFactory } from "fake-indexeddb"; +import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-js"; +import { + KeysBackupRequest, + KeysClaimRequest, + KeysQueryRequest, + KeysUploadRequest, + SignatureUploadRequest, +} from "@matrix-org/matrix-sdk-crypto-js"; +import { Mocked } from "jest-mock"; +import MockHttpBackend from "matrix-mock-request"; import { RustCrypto } from "../../src/rust-crypto/rust-crypto"; import { initRustCrypto } from "../../src/rust-crypto"; -import { IHttpOpts, MatrixHttpApi } from "../../src"; +import { HttpApiEvent, HttpApiEventHandlerMap, IHttpOpts, MatrixHttpApi } from "../../src"; +import { TypedEventEmitter } from "../../src/models/typed-event-emitter"; afterEach(() => { // reset fake-indexeddb after each test, to make sure we don't leak connections @@ -32,17 +43,163 @@ describe("RustCrypto", () => { const TEST_USER = "@alice:example.com"; const TEST_DEVICE_ID = "TEST_DEVICE"; - let rustCrypto: RustCrypto; + describe(".exportRoomKeys", () => { + let rustCrypto: RustCrypto; - beforeEach(async () => { - const mockHttpApi = {} as MatrixHttpApi; - rustCrypto = (await initRustCrypto(mockHttpApi, TEST_USER, TEST_DEVICE_ID)) as RustCrypto; - }); + beforeEach(async () => { + const mockHttpApi = {} as MatrixHttpApi; + rustCrypto = (await initRustCrypto(mockHttpApi, TEST_USER, TEST_DEVICE_ID)) as RustCrypto; + }); - describe(".exportRoomKeys", () => { it("should return a list", async () => { const keys = await rustCrypto.exportRoomKeys(); expect(Array.isArray(keys)).toBeTruthy(); }); }); + + describe("outgoing requests", () => { + /** the RustCrypto implementation under test */ + let rustCrypto: RustCrypto; + + /** A mock http backend which rustCrypto is connected to */ + let httpBackend: MockHttpBackend; + + /** a mocked-up OlmMachine which rustCrypto is connected to */ + let olmMachine: Mocked; + + /** A list of results to be returned from olmMachine.outgoingRequest. Each call will shift a result off + * the front of the queue, until it is empty. */ + let outgoingRequestQueue: Array>; + + /** wait for a call to olmMachine.markRequestAsSent */ + function awaitCallToMarkAsSent(): Promise { + return new Promise((resolve, _reject) => { + olmMachine.markRequestAsSent.mockImplementationOnce(async () => { + resolve(undefined); + }); + }); + } + + beforeEach(async () => { + httpBackend = new MockHttpBackend(); + + await RustSdkCryptoJs.initAsync(); + + const dummyEventEmitter = new TypedEventEmitter(); + const httpApi = new MatrixHttpApi(dummyEventEmitter, { + baseUrl: "https://example.com", + prefix: "/_matrix", + fetchFn: httpBackend.fetchFn as typeof global.fetch, + }); + + // for these tests we use a mock OlmMachine, with an implementation of outgoingRequests that + // returns objects from outgoingRequestQueue + outgoingRequestQueue = []; + olmMachine = { + outgoingRequests: jest.fn().mockImplementation(() => { + return Promise.resolve(outgoingRequestQueue.shift() ?? []); + }), + markRequestAsSent: jest.fn(), + close: jest.fn(), + } as unknown as Mocked; + + rustCrypto = new RustCrypto(olmMachine, httpApi, TEST_USER, TEST_DEVICE_ID); + }); + + it("should poll for outgoing messages", () => { + rustCrypto.onSyncCompleted({}); + expect(olmMachine.outgoingRequests).toHaveBeenCalled(); + }); + + /* simple requests that map directly to the request body */ + const tests: Array<[any, "POST" | "PUT", string]> = [ + [KeysUploadRequest, "POST", "https://example.com/_matrix/client/v3/keys/upload"], + [KeysQueryRequest, "POST", "https://example.com/_matrix/client/v3/keys/query"], + [KeysClaimRequest, "POST", "https://example.com/_matrix/client/v3/keys/claim"], + [SignatureUploadRequest, "POST", "https://example.com/_matrix/client/v3/keys/signatures/upload"], + [KeysBackupRequest, "PUT", "https://example.com/_matrix/client/v3/room_keys/keys"], + ]; + + for (const [RequestClass, expectedMethod, expectedPath] of tests) { + it(`should handle ${RequestClass.name}s`, async () => { + const testBody = '{ "foo": "bar" }'; + const outgoingRequest = new RequestClass("1234", testBody); + outgoingRequestQueue.push([outgoingRequest]); + + const testResponse = '{ "result": 1 }'; + httpBackend + .when(expectedMethod, "/_matrix") + .check((req) => { + expect(req.path).toEqual(expectedPath); + expect(req.rawData).toEqual(testBody); + expect(req.headers["Accept"]).toEqual("application/json"); + expect(req.headers["Content-Type"]).toEqual("application/json"); + }) + .respond(200, testResponse, true); + + rustCrypto.onSyncCompleted({}); + + expect(olmMachine.outgoingRequests).toHaveBeenCalledTimes(1); + + const markSentCallPromise = awaitCallToMarkAsSent(); + await httpBackend.flushAllExpected(); + + await markSentCallPromise; + expect(olmMachine.markRequestAsSent).toHaveBeenCalledWith("1234", outgoingRequest.type, testResponse); + httpBackend.verifyNoOutstandingRequests(); + }); + } + + it("does not explode with unknown requests", async () => { + const outgoingRequest = { id: "5678", type: 987 }; + outgoingRequestQueue.push([outgoingRequest]); + + rustCrypto.onSyncCompleted({}); + + await awaitCallToMarkAsSent(); + expect(olmMachine.markRequestAsSent).toHaveBeenCalledWith("5678", 987, ""); + }); + + it("stops looping when stop() is called", async () => { + const testResponse = '{ "result": 1 }'; + + for (let i = 0; i < 5; i++) { + outgoingRequestQueue.push([new KeysQueryRequest("1234", "{}")]); + httpBackend.when("POST", "/_matrix").respond(200, testResponse, true); + } + + rustCrypto.onSyncCompleted({}); + + expect(rustCrypto["outgoingRequestLoopRunning"]).toBeTruthy(); + + // go a couple of times round the loop + await httpBackend.flush("/_matrix", 1); + await awaitCallToMarkAsSent(); + + await httpBackend.flush("/_matrix", 1); + await awaitCallToMarkAsSent(); + + // a second sync while this is going on shouldn't make any difference + rustCrypto.onSyncCompleted({}); + + await httpBackend.flush("/_matrix", 1); + await awaitCallToMarkAsSent(); + + // now stop... + rustCrypto.stop(); + + // which should (eventually) cause the loop to stop with no further calls to outgoingRequests + olmMachine.outgoingRequests.mockReset(); + + await new Promise((resolve) => { + setTimeout(resolve, 100); + }); + expect(rustCrypto["outgoingRequestLoopRunning"]).toBeFalsy(); + httpBackend.verifyNoOutstandingRequests(); + expect(olmMachine.outgoingRequests).not.toHaveBeenCalled(); + + // we sent three, so there should be 2 left + expect(outgoingRequestQueue.length).toEqual(2); + }); + }); }); diff --git a/src/rust-crypto/rust-crypto.ts b/src/rust-crypto/rust-crypto.ts index 48b5e795108..187f6d6cd31 100644 --- a/src/rust-crypto/rust-crypto.ts +++ b/src/rust-crypto/rust-crypto.ts @@ -15,13 +15,29 @@ limitations under the License. */ import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-js"; +import { + DecryptedRoomEvent, + KeysBackupRequest, + KeysClaimRequest, + KeysQueryRequest, + KeysUploadRequest, + SignatureUploadRequest, +} from "@matrix-org/matrix-sdk-crypto-js"; import type { IEventDecryptionResult, IMegolmSessionData } from "../@types/crypto"; import { MatrixEvent } from "../models/event"; import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend"; -import { IHttpOpts, MatrixHttpApi } from "../http-api"; +import { logger } from "../logger"; +import { IHttpOpts, IRequestOpts, MatrixHttpApi, Method } from "../http-api"; +import { QueryDict } from "../utils"; -// import { logger } from "../logger"; +/** + * Common interface for all the request types returned by `OlmMachine.outgoingRequests`. + */ +interface OutgoingRequest { + readonly id: string | undefined; + readonly type: number; +} /** * An implementation of {@link CryptoBackend} using the Rust matrix-sdk-crypto. @@ -30,9 +46,12 @@ export class RustCrypto implements CryptoBackend { public globalBlacklistUnverifiedDevices = false; public globalErrorOnUnknownDevices = false; - /** whether stop() has been called */ + /** whether {@link stop} has been called */ private stopped = false; + /** whether {@link outgoingRequestLoop} is currently running */ + private outgoingRequestLoopRunning = false; + public constructor( private readonly olmMachine: RustSdkCryptoJs.OlmMachine, private readonly http: MatrixHttpApi, @@ -69,11 +88,95 @@ export class RustCrypto implements CryptoBackend { return []; } + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // + // SyncCryptoCallbacks implementation + // + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /** called by the sync loop after processing each sync. * * TODO: figure out something equivalent for sliding sync. * * @param syncState - information on the completed sync. */ - public onSyncCompleted(syncState: OnSyncCompletedData): void {} + public onSyncCompleted(syncState: OnSyncCompletedData): void { + // Processing the /sync may have produced new outgoing requests which need sending, so kick off the outgoing + // request loop, if it's not already running. + this.outgoingRequestLoop(); + } + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // + // Outgoing requests + // + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + private async outgoingRequestLoop(): Promise { + if (this.outgoingRequestLoopRunning) { + return; + } + this.outgoingRequestLoopRunning = true; + try { + while (!this.stopped) { + const outgoingRequests: Object[] = await this.olmMachine.outgoingRequests(); + if (outgoingRequests.length == 0 || this.stopped) { + // no more messages to send (or we have been told to stop): exit the loop + return; + } + for (const msg of outgoingRequests) { + await this.doOutgoingRequest(msg as OutgoingRequest); + } + } + } catch (e) { + logger.error("Error processing outgoing-message requests from rust crypto-sdk", e); + } finally { + this.outgoingRequestLoopRunning = false; + } + } + + private async doOutgoingRequest(msg: OutgoingRequest): Promise { + let resp: string; + + /* refer https://docs.rs/matrix-sdk-crypto/0.6.0/matrix_sdk_crypto/requests/enum.OutgoingRequests.html + * for the complete list of request types + */ + if (msg instanceof KeysUploadRequest) { + resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/upload", {}, msg.body); + } else if (msg instanceof KeysQueryRequest) { + resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/query", {}, msg.body); + } else if (msg instanceof KeysClaimRequest) { + resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/claim", {}, msg.body); + } else if (msg instanceof SignatureUploadRequest) { + resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/signatures/upload", {}, msg.body); + } else if (msg instanceof KeysBackupRequest) { + resp = await this.rawJsonRequest(Method.Put, "/_matrix/client/v3/room_keys/keys", {}, msg.body); + } else { + // TODO: ToDeviceRequest, RoomMessageRequest + logger.warn("Unsupported outgoing message", Object.getPrototypeOf(msg)); + resp = ""; + } + + if (msg.id) { + await this.olmMachine.markRequestAsSent(msg.id, msg.type, resp); + } + } + + private async rawJsonRequest( + method: Method, + path: string, + queryParams: QueryDict, + body: string, + opts: IRequestOpts = {}, + ): Promise { + // unbeknownst to HttpApi, we are sending JSON + opts.headers ??= {}; + opts.headers["Content-Type"] = "application/json"; + + // we use the full prefix + opts.prefix ??= ""; + + const resp = await this.http.authedRequest(method, path, queryParams, body, opts); + return await resp.text(); + } }