Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle outgoing requests from rust crypto SDK #3019

Merged
merged 2 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 164 additions & 5 deletions spec/unit/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,21 @@ limitations under the License.

import "fake-indexeddb/auto";
import { IDBFactory } from "fake-indexeddb";
import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-js";
import {
KeysBackupRequest,
KeysClaimRequest,
KeysQueryRequest,
KeysUploadRequest,
SignatureUploadRequest,
} from "@matrix-org/matrix-sdk-crypto-js";
import { Mocked } from "jest-mock";
import MockHttpBackend from "matrix-mock-request";

import { RustCrypto } from "../../src/rust-crypto/rust-crypto";
import { initRustCrypto } from "../../src/rust-crypto";
import { HttpApiEvent, HttpApiEventHandlerMap, IHttpOpts, MatrixHttpApi } from "../../src";
import { TypedEventEmitter } from "../../src/models/typed-event-emitter";

afterEach(() => {
// reset fake-indexeddb after each test, to make sure we don't leak connections
Expand All @@ -31,16 +43,163 @@ describe("RustCrypto", () => {
const TEST_USER = "@alice:example.com";
const TEST_DEVICE_ID = "TEST_DEVICE";

let rustCrypto: RustCrypto;
describe(".exportRoomKeys", () => {
let rustCrypto: RustCrypto;

beforeEach(async () => {
rustCrypto = (await initRustCrypto(TEST_USER, TEST_DEVICE_ID)) as RustCrypto;
});
beforeEach(async () => {
const mockHttpApi = {} as MatrixHttpApi<IHttpOpts>;
rustCrypto = (await initRustCrypto(mockHttpApi, TEST_USER, TEST_DEVICE_ID)) as RustCrypto;
});

describe(".exportRoomKeys", () => {
it("should return a list", async () => {
const keys = await rustCrypto.exportRoomKeys();
expect(Array.isArray(keys)).toBeTruthy();
});
});

describe("outgoing requests", () => {
/** the RustCrypto implementation under test */
let rustCrypto: RustCrypto;

/** A mock http backend which rustCrypto is connected to */
let httpBackend: MockHttpBackend;

/** a mocked-up OlmMachine which rustCrypto is connected to */
let olmMachine: Mocked<RustSdkCryptoJs.OlmMachine>;

/** A list of results to be returned from olmMachine.outgoingRequest. Each call will shift a result off
* the front of the queue, until it is empty. */
let outgoingRequestQueue: Array<Array<any>>;

/** wait for a call to olmMachine.markRequestAsSent */
function awaitCallToMarkAsSent(): Promise<void> {
return new Promise((resolve, _reject) => {
olmMachine.markRequestAsSent.mockImplementationOnce(async () => {
resolve(undefined);
});
});
}

beforeEach(async () => {
httpBackend = new MockHttpBackend();

await RustSdkCryptoJs.initAsync();

const dummyEventEmitter = new TypedEventEmitter<HttpApiEvent, HttpApiEventHandlerMap>();
const httpApi = new MatrixHttpApi(dummyEventEmitter, {
baseUrl: "https://example.com",
prefix: "/_matrix",
fetchFn: httpBackend.fetchFn as typeof global.fetch,
});

// for these tests we use a mock OlmMachine, with an implementation of outgoingRequests that
// returns objects from outgoingRequestQueue
outgoingRequestQueue = [];
olmMachine = {
outgoingRequests: jest.fn().mockImplementation(() => {
return Promise.resolve(outgoingRequestQueue.shift() ?? []);
}),
markRequestAsSent: jest.fn(),
close: jest.fn(),
} as unknown as Mocked<RustSdkCryptoJs.OlmMachine>;

rustCrypto = new RustCrypto(olmMachine, httpApi, TEST_USER, TEST_DEVICE_ID);
});

it("should poll for outgoing messages", () => {
rustCrypto.onSyncCompleted({});
expect(olmMachine.outgoingRequests).toHaveBeenCalled();
});

/* simple requests that map directly to the request body */
const tests: Array<[any, "POST" | "PUT", string]> = [
[KeysUploadRequest, "POST", "https://example.com/_matrix/client/v3/keys/upload"],
[KeysQueryRequest, "POST", "https://example.com/_matrix/client/v3/keys/query"],
[KeysClaimRequest, "POST", "https://example.com/_matrix/client/v3/keys/claim"],
[SignatureUploadRequest, "POST", "https://example.com/_matrix/client/v3/keys/signatures/upload"],
[KeysBackupRequest, "PUT", "https://example.com/_matrix/client/v3/room_keys/keys"],
];

for (const [RequestClass, expectedMethod, expectedPath] of tests) {
it(`should handle ${RequestClass.name}s`, async () => {
const testBody = '{ "foo": "bar" }';
const outgoingRequest = new RequestClass("1234", testBody);
outgoingRequestQueue.push([outgoingRequest]);

const testResponse = '{ "result": 1 }';
httpBackend
.when(expectedMethod, "/_matrix")
.check((req) => {
expect(req.path).toEqual(expectedPath);
expect(req.rawData).toEqual(testBody);
expect(req.headers["Accept"]).toEqual("application/json");
expect(req.headers["Content-Type"]).toEqual("application/json");
})
.respond(200, testResponse, true);

rustCrypto.onSyncCompleted({});

expect(olmMachine.outgoingRequests).toHaveBeenCalledTimes(1);

const markSentCallPromise = awaitCallToMarkAsSent();
await httpBackend.flushAllExpected();

await markSentCallPromise;
expect(olmMachine.markRequestAsSent).toHaveBeenCalledWith("1234", outgoingRequest.type, testResponse);
httpBackend.verifyNoOutstandingRequests();
});
}

it("does not explode with unknown requests", async () => {
const outgoingRequest = { id: "5678", type: 987 };
outgoingRequestQueue.push([outgoingRequest]);

rustCrypto.onSyncCompleted({});

await awaitCallToMarkAsSent();
expect(olmMachine.markRequestAsSent).toHaveBeenCalledWith("5678", 987, "");
});

it("stops looping when stop() is called", async () => {
const testResponse = '{ "result": 1 }';

for (let i = 0; i < 5; i++) {
outgoingRequestQueue.push([new KeysQueryRequest("1234", "{}")]);
httpBackend.when("POST", "/_matrix").respond(200, testResponse, true);
}

rustCrypto.onSyncCompleted({});

expect(rustCrypto["outgoingRequestLoopRunning"]).toBeTruthy();

// go a couple of times round the loop
await httpBackend.flush("/_matrix", 1);
await awaitCallToMarkAsSent();

await httpBackend.flush("/_matrix", 1);
await awaitCallToMarkAsSent();

// a second sync while this is going on shouldn't make any difference
rustCrypto.onSyncCompleted({});

await httpBackend.flush("/_matrix", 1);
await awaitCallToMarkAsSent();

// now stop...
rustCrypto.stop();

// which should (eventually) cause the loop to stop with no further calls to outgoingRequests
olmMachine.outgoingRequests.mockReset();

await new Promise((resolve) => {
setTimeout(resolve, 100);
});
expect(rustCrypto["outgoingRequestLoopRunning"]).toBeFalsy();
httpBackend.verifyNoOutstandingRequests();
expect(olmMachine.outgoingRequests).not.toHaveBeenCalled();

// we sent three, so there should be 2 left
expect(outgoingRequestQueue.length).toEqual(2);
});
});
});
2 changes: 1 addition & 1 deletion src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2148,7 +2148,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
// importing rust-crypto will download the webassembly, so we delay it until we know it will be
// needed.
const RustCrypto = await import("./rust-crypto");
this.cryptoBackend = await RustCrypto.initRustCrypto(userId, deviceId);
this.cryptoBackend = await RustCrypto.initRustCrypto(this.http, userId, deviceId);
}

/**
Expand Down
9 changes: 7 additions & 2 deletions src/rust-crypto/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ import { RustCrypto } from "./rust-crypto";
import { logger } from "../logger";
import { CryptoBackend } from "../common-crypto/CryptoBackend";
import { RUST_SDK_STORE_PREFIX } from "./constants";
import { IHttpOpts, MatrixHttpApi } from "../http-api";

export async function initRustCrypto(userId: string, deviceId: string): Promise<CryptoBackend> {
export async function initRustCrypto(
http: MatrixHttpApi<IHttpOpts>,
userId: string,
deviceId: string,
): Promise<CryptoBackend> {
// initialise the rust matrix-sdk-crypto-js, if it hasn't already been done
await RustSdkCryptoJs.initAsync();

Expand All @@ -34,7 +39,7 @@ export async function initRustCrypto(userId: string, deviceId: string): Promise<

// TODO: use the pickle key for the passphrase
const olmMachine = await RustSdkCryptoJs.OlmMachine.initialize(u, d, RUST_SDK_STORE_PREFIX, "test pass");
const rustCrypto = new RustCrypto(olmMachine, userId, deviceId);
const rustCrypto = new RustCrypto(olmMachine, http, userId, deviceId);

logger.info("Completed rust crypto-sdk setup");
return rustCrypto;
Expand Down
116 changes: 112 additions & 4 deletions src/rust-crypto/rust-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,28 @@ limitations under the License.
*/

import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-js";
import {
KeysBackupRequest,
KeysClaimRequest,
KeysQueryRequest,
KeysUploadRequest,
SignatureUploadRequest,
} from "@matrix-org/matrix-sdk-crypto-js";

import type { IEventDecryptionResult, IMegolmSessionData } from "../@types/crypto";
import { MatrixEvent } from "../models/event";
import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend";
import { logger } from "../logger";
import { IHttpOpts, IRequestOpts, MatrixHttpApi, Method } from "../http-api";
import { QueryDict } from "../utils";

// import { logger } from "../logger";
/**
* Common interface for all the request types returned by `OlmMachine.outgoingRequests`.
*/
interface OutgoingRequest {
readonly id: string | undefined;
readonly type: number;
}

/**
* An implementation of {@link CryptoBackend} using the Rust matrix-sdk-crypto.
Expand All @@ -29,10 +45,18 @@ export class RustCrypto implements CryptoBackend {
public globalBlacklistUnverifiedDevices = false;
public globalErrorOnUnknownDevices = false;

/** whether stop() has been called */
/** whether {@link stop} has been called */
private stopped = false;

public constructor(private readonly olmMachine: RustSdkCryptoJs.OlmMachine, _userId: string, _deviceId: string) {}
/** whether {@link outgoingRequestLoop} is currently running */
private outgoingRequestLoopRunning = false;

public constructor(
private readonly olmMachine: RustSdkCryptoJs.OlmMachine,
private readonly http: MatrixHttpApi<IHttpOpts>,
_userId: string,
_deviceId: string,
) {}

public stop(): void {
// stop() may be called multiple times, but attempting to close() the OlmMachine twice
Expand Down Expand Up @@ -63,11 +87,95 @@ export class RustCrypto implements CryptoBackend {
return [];
}

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// SyncCryptoCallbacks implementation
//
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////

/** called by the sync loop after processing each sync.
*
* TODO: figure out something equivalent for sliding sync.
*
* @param syncState - information on the completed sync.
*/
public onSyncCompleted(syncState: OnSyncCompletedData): void {}
public onSyncCompleted(syncState: OnSyncCompletedData): void {
// Processing the /sync may have produced new outgoing requests which need sending, so kick off the outgoing
// request loop, if it's not already running.
this.outgoingRequestLoop();
}

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Outgoing requests
//
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////

private async outgoingRequestLoop(): Promise<void> {
if (this.outgoingRequestLoopRunning) {
return;
}
this.outgoingRequestLoopRunning = true;
try {
while (!this.stopped) {
const outgoingRequests: Object[] = await this.olmMachine.outgoingRequests();
if (outgoingRequests.length == 0 || this.stopped) {
// no more messages to send (or we have been told to stop): exit the loop
return;
}
for (const msg of outgoingRequests) {
await this.doOutgoingRequest(msg as OutgoingRequest);
}
}
} catch (e) {
logger.error("Error processing outgoing-message requests from rust crypto-sdk", e);
} finally {
this.outgoingRequestLoopRunning = false;
}
}

private async doOutgoingRequest(msg: OutgoingRequest): Promise<void> {
let resp: string;

/* refer https://docs.rs/matrix-sdk-crypto/0.6.0/matrix_sdk_crypto/requests/enum.OutgoingRequests.html
* for the complete list of request types
*/
if (msg instanceof KeysUploadRequest) {
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/upload", {}, msg.body);
} else if (msg instanceof KeysQueryRequest) {
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/query", {}, msg.body);
} else if (msg instanceof KeysClaimRequest) {
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/claim", {}, msg.body);
} else if (msg instanceof SignatureUploadRequest) {
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/signatures/upload", {}, msg.body);
} else if (msg instanceof KeysBackupRequest) {
resp = await this.rawJsonRequest(Method.Put, "/_matrix/client/v3/room_keys/keys", {}, msg.body);
} else {
// TODO: ToDeviceRequest, RoomMessageRequest
logger.warn("Unsupported outgoing message", Object.getPrototypeOf(msg));
resp = "";
}

if (msg.id) {
await this.olmMachine.markRequestAsSent(msg.id, msg.type, resp);
}
}

private async rawJsonRequest(
method: Method,
path: string,
queryParams: QueryDict,
body: string,
opts: IRequestOpts = {},
): Promise<string> {
// unbeknownst to HttpApi, we are sending JSON
opts.headers ??= {};
opts.headers["Content-Type"] = "application/json";

// we use the full prefix
opts.prefix ??= "";

const resp = await this.http.authedRequest(method, path, queryParams, body, opts);
return await resp.text();
}
}