Skip to content

Commit

Permalink
Implement outgoing requests from Rust SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Jan 4, 2023
1 parent 0ebea8b commit 11304c7
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 11 deletions.
171 changes: 164 additions & 7 deletions spec/unit/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,21 @@ limitations under the License.

import "fake-indexeddb/auto";
import { IDBFactory } from "fake-indexeddb";
import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-js";
import {
KeysBackupRequest,
KeysClaimRequest,
KeysQueryRequest,
KeysUploadRequest,
SignatureUploadRequest,
} from "@matrix-org/matrix-sdk-crypto-js";
import { Mocked } from "jest-mock";
import MockHttpBackend from "matrix-mock-request";

import { RustCrypto } from "../../src/rust-crypto/rust-crypto";
import { initRustCrypto } from "../../src/rust-crypto";
import { IHttpOpts, MatrixHttpApi } from "../../src";
import { HttpApiEvent, HttpApiEventHandlerMap, IHttpOpts, MatrixHttpApi } from "../../src";
import { TypedEventEmitter } from "../../src/models/typed-event-emitter";

afterEach(() => {
// reset fake-indexeddb after each test, to make sure we don't leak connections
Expand All @@ -32,17 +43,163 @@ describe("RustCrypto", () => {
const TEST_USER = "@alice:example.com";
const TEST_DEVICE_ID = "TEST_DEVICE";

let rustCrypto: RustCrypto;
describe(".exportRoomKeys", () => {
let rustCrypto: RustCrypto;

beforeEach(async () => {
const mockHttpApi = {} as MatrixHttpApi<IHttpOpts>;
rustCrypto = (await initRustCrypto(mockHttpApi, TEST_USER, TEST_DEVICE_ID)) as RustCrypto;
});
beforeEach(async () => {
const mockHttpApi = {} as MatrixHttpApi<IHttpOpts>;
rustCrypto = (await initRustCrypto(mockHttpApi, TEST_USER, TEST_DEVICE_ID)) as RustCrypto;
});

describe(".exportRoomKeys", () => {
it("should return a list", async () => {
const keys = await rustCrypto.exportRoomKeys();
expect(Array.isArray(keys)).toBeTruthy();
});
});

describe("outgoing requests", () => {
/** the RustCrypto implementation under test */
let rustCrypto: RustCrypto;

/** A mock http backend which rustCrypto is connected to */
let httpBackend: MockHttpBackend;

/** a mocked-up OlmMachine which rustCrypto is connected to */
let olmMachine: Mocked<RustSdkCryptoJs.OlmMachine>;

/** A list of results to be returned from olmMachine.outgoingRequest. Each call will shift a result off
* the front of the queue, until it is empty. */
let outgoingRequestQueue: Array<Array<any>>;

/** wait for a call to olmMachine.markRequestAsSent */
function awaitCallToMarkAsSent(): Promise<void> {
return new Promise((resolve, _reject) => {
olmMachine.markRequestAsSent.mockImplementationOnce(async () => {
resolve(undefined);
});
});
}

beforeEach(async () => {
httpBackend = new MockHttpBackend();

await RustSdkCryptoJs.initAsync();

const dummyEventEmitter = new TypedEventEmitter<HttpApiEvent, HttpApiEventHandlerMap>();
const httpApi = new MatrixHttpApi(dummyEventEmitter, {
baseUrl: "https://example.com",
prefix: "/_matrix",
fetchFn: httpBackend.fetchFn as typeof global.fetch,
});

// for these tests we use a mock OlmMachine, with an implementation of outgoingRequests that
// returns objects from outgoingRequestQueue
outgoingRequestQueue = [];
olmMachine = {
outgoingRequests: jest.fn().mockImplementation(() => {
return Promise.resolve(outgoingRequestQueue.shift() ?? []);
}),
markRequestAsSent: jest.fn(),
close: jest.fn(),
} as unknown as Mocked<RustSdkCryptoJs.OlmMachine>;

rustCrypto = new RustCrypto(olmMachine, httpApi, TEST_USER, TEST_DEVICE_ID);
});

it("should poll for outgoing messages", () => {
rustCrypto.onSyncCompleted({});
expect(olmMachine.outgoingRequests).toHaveBeenCalled();
});

/* simple requests that map directly to the request body */
const tests: Array<[any, "POST" | "PUT", string]> = [
[KeysUploadRequest, "POST", "https://example.com/_matrix/client/v3/keys/upload"],
[KeysQueryRequest, "POST", "https://example.com/_matrix/client/v3/keys/query"],
[KeysClaimRequest, "POST", "https://example.com/_matrix/client/v3/keys/claim"],
[SignatureUploadRequest, "POST", "https://example.com/_matrix/client/v3/keys/signatures/upload"],
[KeysBackupRequest, "PUT", "https://example.com/_matrix/client/v3/room_keys/keys"],
];

for (const [RequestClass, expectedMethod, expectedPath] of tests) {
it(`should handle ${RequestClass.name}s`, async () => {
const testBody = '{ "foo": "bar" }';
const outgoingRequest = new RequestClass("1234", testBody);
outgoingRequestQueue.push([outgoingRequest]);

const testResponse = '{ "result": 1 }';
httpBackend
.when(expectedMethod, "/_matrix")
.check((req) => {
expect(req.path).toEqual(expectedPath);
expect(req.rawData).toEqual(testBody);
expect(req.headers["Accept"]).toEqual("application/json");
expect(req.headers["Content-Type"]).toEqual("application/json");
})
.respond(200, testResponse, true);

rustCrypto.onSyncCompleted({});

expect(olmMachine.outgoingRequests).toHaveBeenCalledTimes(1);

const markSentCallPromise = awaitCallToMarkAsSent();
await httpBackend.flushAllExpected();

await markSentCallPromise;
expect(olmMachine.markRequestAsSent).toHaveBeenCalledWith("1234", outgoingRequest.type, testResponse);
httpBackend.verifyNoOutstandingRequests();
});
}

it("does not explode with unknown requests", async () => {
const outgoingRequest = { id: "5678", type: 987 };
outgoingRequestQueue.push([outgoingRequest]);

rustCrypto.onSyncCompleted({});

await awaitCallToMarkAsSent();
expect(olmMachine.markRequestAsSent).toHaveBeenCalledWith("5678", 987, "");
});

it("stops looping when stop() is called", async () => {
const testResponse = '{ "result": 1 }';

for (let i = 0; i < 5; i++) {
outgoingRequestQueue.push([new KeysQueryRequest("1234", "{}")]);
httpBackend.when("POST", "/_matrix").respond(200, testResponse, true);
}

rustCrypto.onSyncCompleted({});

expect(rustCrypto["outgoingRequestLoopRunning"]).toBeTruthy();

// go a couple of times round the loop
await httpBackend.flush("/_matrix", 1);
await awaitCallToMarkAsSent();

await httpBackend.flush("/_matrix", 1);
await awaitCallToMarkAsSent();

// a second sync while this is going on shouldn't make any difference
rustCrypto.onSyncCompleted({});

await httpBackend.flush("/_matrix", 1);
await awaitCallToMarkAsSent();

// now stop...
rustCrypto.stop();

// which should (eventually) cause the loop to stop with no further calls to outgoingRequests
olmMachine.outgoingRequests.mockReset();

await new Promise((resolve) => {
setTimeout(resolve, 100);
});
expect(rustCrypto["outgoingRequestLoopRunning"]).toBeFalsy();
httpBackend.verifyNoOutstandingRequests();
expect(olmMachine.outgoingRequests).not.toHaveBeenCalled();

// we sent three, so there should be 2 left
expect(outgoingRequestQueue.length).toEqual(2);
});
});
});
110 changes: 106 additions & 4 deletions src/rust-crypto/rust-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,28 @@ limitations under the License.
*/

import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-js";
import {
KeysBackupRequest,
KeysClaimRequest,
KeysQueryRequest,
KeysUploadRequest,
SignatureUploadRequest,
} from "@matrix-org/matrix-sdk-crypto-js";

import type { IEventDecryptionResult, IMegolmSessionData } from "../@types/crypto";
import { MatrixEvent } from "../models/event";
import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend";
import { IHttpOpts, MatrixHttpApi } from "../http-api";
import { logger } from "../logger";
import { IHttpOpts, IRequestOpts, MatrixHttpApi, Method } from "../http-api";
import { QueryDict } from "../utils";

// import { logger } from "../logger";
/**
* Common interface for all the request types returned by `OlmMachine.outgoingRequests`.
*/
interface OutgoingRequest {
readonly id: string | undefined;
readonly type: number;
}

/**
* An implementation of {@link CryptoBackend} using the Rust matrix-sdk-crypto.
Expand All @@ -30,9 +45,12 @@ export class RustCrypto implements CryptoBackend {
public globalBlacklistUnverifiedDevices = false;
public globalErrorOnUnknownDevices = false;

/** whether stop() has been called */
/** whether {@link stop} has been called */
private stopped = false;

/** whether {@link outgoingRequestLoop} is currently running */
private outgoingRequestLoopRunning = false;

public constructor(
private readonly olmMachine: RustSdkCryptoJs.OlmMachine,
private readonly http: MatrixHttpApi<IHttpOpts>,
Expand Down Expand Up @@ -69,11 +87,95 @@ export class RustCrypto implements CryptoBackend {
return [];
}

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// SyncCryptoCallbacks implementation
//
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////

/** called by the sync loop after processing each sync.
*
* TODO: figure out something equivalent for sliding sync.
*
* @param syncState - information on the completed sync.
*/
public onSyncCompleted(syncState: OnSyncCompletedData): void {}
public onSyncCompleted(syncState: OnSyncCompletedData): void {
// Processing the /sync may have produced new outgoing requests which need sending, so kick off the outgoing
// request loop, if it's not already running.
this.outgoingRequestLoop();
}

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Outgoing requests
//
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////

private async outgoingRequestLoop(): Promise<void> {
if (this.outgoingRequestLoopRunning) {
return;
}
this.outgoingRequestLoopRunning = true;
try {
while (!this.stopped) {
const outgoingRequests: Object[] = await this.olmMachine.outgoingRequests();
if (outgoingRequests.length == 0 || this.stopped) {
// no more messages to send (or we have been told to stop): exit the loop
return;
}
for (const msg of outgoingRequests) {
await this.doOutgoingRequest(msg as OutgoingRequest);
}
}
} catch (e) {
logger.error("Error processing outgoing-message requests from rust crypto-sdk", e);
} finally {
this.outgoingRequestLoopRunning = false;
}
}

private async doOutgoingRequest(msg: OutgoingRequest): Promise<void> {
let resp: string;

/* refer https://docs.rs/matrix-sdk-crypto/0.6.0/matrix_sdk_crypto/requests/enum.OutgoingRequests.html
* for the complete list of request types
*/
if (msg instanceof KeysUploadRequest) {
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/upload", {}, msg.body);
} else if (msg instanceof KeysQueryRequest) {
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/query", {}, msg.body);
} else if (msg instanceof KeysClaimRequest) {
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/claim", {}, msg.body);
} else if (msg instanceof SignatureUploadRequest) {
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/signatures/upload", {}, msg.body);
} else if (msg instanceof KeysBackupRequest) {
resp = await this.rawJsonRequest(Method.Put, "/_matrix/client/v3/room_keys/keys", {}, msg.body);
} else {
// TODO: ToDeviceRequest, RoomMessageRequest
logger.warn("Unsupported outgoing message", Object.getPrototypeOf(msg));
resp = "";
}

if (msg.id) {
await this.olmMachine.markRequestAsSent(msg.id, msg.type, resp);
}
}

private async rawJsonRequest(
method: Method,
path: string,
queryParams: QueryDict,
body: string,
opts: IRequestOpts = {},
): Promise<string> {
// unbeknownst to HttpApi, we are sending JSON
opts.headers ??= {};
opts.headers["Content-Type"] = "application/json";

// we use the full prefix
opts.prefix ??= "";

const resp = await this.http.authedRequest(method, path, queryParams, body, opts);
return await resp.text();
}
}

0 comments on commit 11304c7

Please sign in to comment.