diff --git a/spec/unit/rust-crypto.spec.ts b/spec/unit/rust-crypto.spec.ts index 822c08bf6c9..0bcc88d5d26 100644 --- a/spec/unit/rust-crypto.spec.ts +++ b/spec/unit/rust-crypto.spec.ts @@ -22,6 +22,7 @@ import { KeysClaimRequest, KeysQueryRequest, KeysUploadRequest, + OlmMachine, SignatureUploadRequest, } from "@matrix-org/matrix-sdk-crypto-js"; import { Mocked } from "jest-mock"; @@ -29,7 +30,7 @@ import MockHttpBackend from "matrix-mock-request"; import { RustCrypto } from "../../src/rust-crypto/rust-crypto"; import { initRustCrypto } from "../../src/rust-crypto"; -import { HttpApiEvent, HttpApiEventHandlerMap, IHttpOpts, MatrixHttpApi } from "../../src"; +import { HttpApiEvent, HttpApiEventHandlerMap, IHttpOpts, IToDeviceEvent, MatrixHttpApi } from "../../src"; import { TypedEventEmitter } from "../../src/models/typed-event-emitter"; afterEach(() => { @@ -57,6 +58,47 @@ describe("RustCrypto", () => { }); }); + describe("to-device messages", () => { + let rustCrypto: RustCrypto; + + beforeEach(async () => { + const mockHttpApi = {} as MatrixHttpApi; + rustCrypto = (await initRustCrypto(mockHttpApi, TEST_USER, TEST_DEVICE_ID)) as RustCrypto; + }); + + it("should pass through unencrypted to-device messages", async () => { + const inputs: IToDeviceEvent[] = [ + { content: { key: "value" }, type: "org.matrix.test", sender: "@alice:example.com" }, + ]; + const res = await rustCrypto.preprocessToDeviceMessages(inputs); + expect(res).toEqual(inputs); + }); + + it("should pass through bad encrypted messages", async () => { + const olmMachine: OlmMachine = rustCrypto["olmMachine"]; + const keys = olmMachine.identityKeys; + const inputs: IToDeviceEvent[] = [ + { + type: "m.room.encrypted", + content: { + algorithm: "m.olm.v1.curve25519-aes-sha2", + sender_key: "IlRMeOPX2e0MurIyfWEucYBRVOEEUMrOHqn/8mLqMjA", + ciphertext: { + [keys.curve25519.toBase64()]: { + type: 0, + body: "ajyjlghi", + }, + }, + }, + sender: "@alice:example.com", + }, + ]; + + const res = await rustCrypto.preprocessToDeviceMessages(inputs); + expect(res).toEqual(inputs); + }); + }); + describe("outgoing requests", () => { /** the RustCrypto implementation under test */ let rustCrypto: RustCrypto; diff --git a/src/common-crypto/CryptoBackend.ts b/src/common-crypto/CryptoBackend.ts index 82db3f28a9c..a90afdf3a82 100644 --- a/src/common-crypto/CryptoBackend.ts +++ b/src/common-crypto/CryptoBackend.ts @@ -15,6 +15,7 @@ limitations under the License. */ import type { IEventDecryptionResult, IMegolmSessionData } from "../@types/crypto"; +import type { IToDeviceEvent } from "../sync-accumulator"; import { MatrixEvent } from "../models/event"; /** @@ -74,6 +75,20 @@ export interface CryptoBackend extends SyncCryptoCallbacks { /** The methods which crypto implementations should expose to the Sync api */ export interface SyncCryptoCallbacks { + /** + * Called by the /sync loop whenever there are incoming to-device messages. + * + * The implementation may preprocess the received messages (eg, decrypt them) and return an + * updated list of messages for dispatch to the rest of the system. + * + * Note that, unlike {@link ClientEvent.ToDeviceEvent} events, this is called on the raw to-device + * messages, rather than the results of any decryption attempts. + * + * @param events - the received to-device messages + * @returns A list of preprocessed to-device messages. + */ + preprocessToDeviceMessages(events: IToDeviceEvent[]): Promise; + /** * Called by the /sync loop after each /sync response is processed. * diff --git a/src/crypto/index.ts b/src/crypto/index.ts index 24b36b08143..458132e75bb 100644 --- a/src/crypto/index.ts +++ b/src/crypto/index.ts @@ -85,7 +85,7 @@ import { CryptoStore } from "./store/base"; import { IVerificationChannel } from "./verification/request/Channel"; import { TypedEventEmitter } from "../models/typed-event-emitter"; import { IContent } from "../models/event"; -import { ISyncResponse } from "../sync-accumulator"; +import { ISyncResponse, IToDeviceEvent } from "../sync-accumulator"; import { ISignatures } from "../@types/signed"; import { IMessage } from "./algorithms/olm"; import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend"; @@ -3198,6 +3198,21 @@ export class Crypto extends TypedEventEmitter { + // all we do here is filter out encrypted to-device messages with the wrong algorithm. Decryption + // happens later in decryptEvent, via the EventMapper + return events.filter((toDevice) => { + if ( + toDevice.type === EventType.RoomMessageEncrypted && + !["m.olm.v1.curve25519-aes-sha2"].includes(toDevice.content?.algorithm) + ) { + logger.log("Ignoring invalid encrypted to-device event from " + toDevice.sender); + return false; + } + return true; + }); + } + private onToDeviceEvent = (event: MatrixEvent): void => { try { logger.log( diff --git a/src/event-mapper.ts b/src/event-mapper.ts index 81d3d772a59..87db88d6407 100644 --- a/src/event-mapper.ts +++ b/src/event-mapper.ts @@ -60,6 +60,9 @@ export function eventMapperFor(client: MatrixClient, options: MapperOpts): Event event.setThread(thread); } + // TODO: once we get rid of the old libolm-backed crypto, we can restrict this to room events (rather than + // to-device events), because the rust implementation decrypts to-device messages at a higher level. + // Generally we probably want to use a different eventMapper implementation for to-device events because if (event.isEncrypted()) { if (!preventReEmit) { client.reEmitter.reEmit(event, [MatrixEventEvent.Decrypted]); diff --git a/src/rust-crypto/rust-crypto.ts b/src/rust-crypto/rust-crypto.ts index c6ff569a816..baa4bda4b16 100644 --- a/src/rust-crypto/rust-crypto.ts +++ b/src/rust-crypto/rust-crypto.ts @@ -24,6 +24,7 @@ import { } from "@matrix-org/matrix-sdk-crypto-js"; import type { IEventDecryptionResult, IMegolmSessionData } from "../@types/crypto"; +import type { IToDeviceEvent } from "../sync-accumulator"; import { MatrixEvent } from "../models/event"; import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend"; import { logger } from "../logger"; @@ -93,6 +94,25 @@ export class RustCrypto implements CryptoBackend { // /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /** called by the sync loop to preprocess incoming to-device messages + * + * @param events - the received to-device messages + * @returns A list of preprocessed to-device messages. + */ + public async preprocessToDeviceMessages(events: IToDeviceEvent[]): Promise { + // send the received to-device messages into receiveSyncChanges. We have no info on device-list changes, + // one-time-keys, or fallback keys, so just pass empty data. + const result = await this.olmMachine.receiveSyncChanges( + JSON.stringify(events), + new RustSdkCryptoJs.DeviceLists(), + new Map(), + new Set(), + ); + + // receiveSyncChanges returns a JSON-encoded list of decrypted to-device messages. + return JSON.parse(result); + } + /** called by the sync loop after processing each sync. * * TODO: figure out something equivalent for sliding sync. diff --git a/src/sliding-sync-sdk.ts b/src/sliding-sync-sdk.ts index 18c94c16836..91ff9d7a75e 100644 --- a/src/sliding-sync-sdk.ts +++ b/src/sliding-sync-sdk.ts @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +import type { SyncCryptoCallbacks } from "./common-crypto/CryptoBackend"; import { NotificationCountType, Room, RoomEvent } from "./models/room"; import { logger } from "./logger"; import * as utils from "./utils"; @@ -127,7 +128,7 @@ type ExtensionToDeviceResponse = { class ExtensionToDevice implements Extension { private nextBatch: string | null = null; - public constructor(private readonly client: MatrixClient) {} + public constructor(private readonly client: MatrixClient, private readonly cryptoCallbacks?: SyncCryptoCallbacks) {} public name(): string { return "to_device"; @@ -150,8 +151,12 @@ class ExtensionToDevice implements Extension { const cancelledKeyVerificationTxns: string[] = []; - data.events - ?.map(this.client.getEventMapper()) + let events = data["events"] || []; + if (events.length > 0 && this.cryptoCallbacks) { + events = await this.cryptoCallbacks.preprocessToDeviceMessages(events); + } + events + .map(this.client.getEventMapper()) .map((toDeviceEvent) => { // map is a cheap inline forEach // We want to flag m.key.verification.start events as cancelled @@ -373,7 +378,7 @@ export class SlidingSyncSdk { this.slidingSync.on(SlidingSyncEvent.Lifecycle, this.onLifecycle.bind(this)); this.slidingSync.on(SlidingSyncEvent.RoomData, this.onRoomData.bind(this)); const extensions: Extension[] = [ - new ExtensionToDevice(this.client), + new ExtensionToDevice(this.client, this.syncOpts.cryptoCallbacks), new ExtensionAccountData(this.client), new ExtensionTyping(this.client), new ExtensionReceipts(this.client), diff --git a/src/sync.ts b/src/sync.ts index 0ba52ba8c0c..11fe3cc102b 100644 --- a/src/sync.ts +++ b/src/sync.ts @@ -48,6 +48,7 @@ import { IStrippedState, ISyncResponse, ITimeline, + IToDeviceEvent, } from "./sync-accumulator"; import { MatrixEvent } from "./models/event"; import { MatrixError, Method } from "./http-api"; @@ -1170,19 +1171,15 @@ export class SyncApi { } // handle to-device events - if (Array.isArray(data.to_device?.events) && data.to_device!.events.length > 0) { - const cancelledKeyVerificationTxns: string[] = []; - data.to_device!.events.filter((eventJSON) => { - if ( - eventJSON.type === EventType.RoomMessageEncrypted && - !["m.olm.v1.curve25519-aes-sha2"].includes(eventJSON.content?.algorithm) - ) { - logger.log("Ignoring invalid encrypted to-device event from " + eventJSON.sender); - return false; - } + if (data.to_device && Array.isArray(data.to_device.events) && data.to_device.events.length > 0) { + let toDeviceMessages: IToDeviceEvent[] = data.to_device.events; - return true; - }) + if (this.syncOpts.cryptoCallbacks) { + toDeviceMessages = await this.syncOpts.cryptoCallbacks.preprocessToDeviceMessages(toDeviceMessages); + } + + const cancelledKeyVerificationTxns: string[] = []; + toDeviceMessages .map(client.getEventMapper({ toDevice: true })) .map((toDeviceEvent) => { // map is a cheap inline forEach