Skip to content

Commit

Permalink
Optimise MatrixRTC sender key distribution for case of single member
Browse files Browse the repository at this point in the history
The "doesn't re-send key immediately" test is superseded by the other test cases
  • Loading branch information
hughns committed Oct 16, 2024
1 parent 80b6424 commit 52a5c96
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 105 deletions.
222 changes: 123 additions & 99 deletions spec/unit/matrixrtc/MatrixRTCSession.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -763,50 +763,99 @@ describe("MatrixRTCSession", () => {
expect(client.cancelPendingEvent).toHaveBeenCalledWith(eventSentinel);
});

it("rotates key if a new member joins", async () => {
jest.useFakeTimers();
try {
const mockRoom = makeMockRoom([membershipTemplate]);
sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
async function setupParticipantChangeTest(initial: CallMembershipData[]) {
const mockRoom = makeMockRoom(initial);
sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
const onMyEncryptionKeyChanged = jest.fn();
sess.on(
MatrixRTCSessionEvent.EncryptionKeyChanged,
(_key: Uint8Array, _idx: number, participantId: string) => {
if (participantId === `${client.getUserId()}:${client.getDeviceId()}`) {
onMyEncryptionKeyChanged();
}
},
);

const keysSentPromise = new Promise<EncryptionKeysEventContent>((resolve) => {
sendEventMock.mockImplementation((_roomId, _evType, payload) => resolve(payload));
});
sess.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });
const initialKeysPayload = await keysSentPromise;

const keysSentPromise1 = new Promise<EncryptionKeysEventContent>((resolve) => {
const changeMembers = (changed: CallMembershipData[]) => {
const mock = new Promise<EncryptionKeysEventContent>((resolve) => {
sendEventMock.mockImplementation((_roomId, _evType, payload) => resolve(payload));
});

sess.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });
const firstKeysPayload = await keysSentPromise1;
expect(firstKeysPayload.keys).toHaveLength(1);
expect(firstKeysPayload.keys[0].index).toEqual(0);
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(1);
mockRoom.getLiveTimeline().getState = jest
.fn()
.mockReturnValue(makeMockRoomState(changed, mockRoom.roomId));
sess!.onMembershipUpdate();

sendEventMock.mockClear();
jest.advanceTimersByTime(10000);
return mock;
};

const keysSentPromise2 = new Promise<EncryptionKeysEventContent>((resolve) => {
sendEventMock.mockImplementation((_roomId, _evType, payload) => resolve(payload));
return { onMyEncryptionKeyChanged, session: sess, initialKeysPayload, changeMembers };
}

it("rotates key and emits immediately when second member joins", async () => {
jest.useFakeTimers();
try {
const member2 = Object.assign({}, membershipTemplate, {
device_id: "BBBBBBB",
});

const onMembershipsChanged = jest.fn();
sess.on(MatrixRTCSessionEvent.MembershipsChanged, onMembershipsChanged);
const { session, onMyEncryptionKeyChanged, changeMembers } = await setupParticipantChangeTest([
membershipTemplate,
]);

jest.advanceTimersByTime(10000);

jest.clearAllMocks();
await changeMembers([membershipTemplate, member2]);

expect(session.statistics.counters.roomEventEncryptionKeysSent).toEqual(2);
expect(onMyEncryptionKeyChanged).toHaveBeenCalledTimes(1);
} finally {
jest.useRealTimers();
}
});

it("rotates key after delay when additional members join", async () => {
jest.useFakeTimers();
try {
const member2 = Object.assign({}, membershipTemplate, {
device_id: "BBBBBBB",
});

mockRoom.getLiveTimeline().getState = jest
.fn()
.mockReturnValue(makeMockRoomState([membershipTemplate, member2], mockRoom.roomId));
sess.onMembershipUpdate();
const member3 = Object.assign({}, membershipTemplate, {
device_id: "CCCCCCC",
});

const { session, onMyEncryptionKeyChanged, changeMembers } = await setupParticipantChangeTest([
membershipTemplate,
member2,
]);

jest.advanceTimersByTime(10000);

const secondKeysPayload = await keysSentPromise2;
const keysSentPromise = changeMembers([membershipTemplate, member2, member3]);

expect(sendEventMock).toHaveBeenCalled();
expect(secondKeysPayload.keys).toHaveLength(1);
expect(secondKeysPayload.keys[0].index).toEqual(1);
expect(secondKeysPayload.keys[0].key).not.toEqual(firstKeysPayload.keys[0].key);
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(2);
// key is only generated after a delay
jest.clearAllMocks();
jest.advanceTimersByTime(2500); // the key should not yet have been generated
expect(sendEventMock).not.toHaveBeenCalled();
expect(session.statistics.counters.roomEventEncryptionKeysSent).toEqual(1);
jest.advanceTimersByTime(500); // the key should now have been generated

await keysSentPromise;
expect(session.statistics.counters.roomEventEncryptionKeysSent).toEqual(2);
expect(onMyEncryptionKeyChanged).not.toHaveBeenCalled();

// the emit comes after the key usage delay
jest.clearAllMocks();
jest.advanceTimersByTime(5000);
expect(onMyEncryptionKeyChanged).toHaveBeenCalledTimes(1);
} finally {
jest.useRealTimers();
}
Expand Down Expand Up @@ -987,6 +1036,7 @@ describe("MatrixRTCSession", () => {
sent_ts: Date.now(),
},
);
const firstKey = sendEventMock.mock.calls[0][2].keys[0].key;
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(1);

sendEventMock.mockClear();
Expand Down Expand Up @@ -1019,7 +1069,7 @@ describe("MatrixRTCSession", () => {
keys: [
{
index: 0,
key: expect.stringMatching(".*"),
key: firstKey,
},
],
sent_ts: Date.now(),
Expand All @@ -1031,54 +1081,68 @@ describe("MatrixRTCSession", () => {
}
});

it("rotates key if a member leaves", async () => {
it("rotates key and emits immediately when second member leaves", async () => {
jest.useFakeTimers();
try {
const member2 = Object.assign({}, membershipTemplate, {
device_id: "BBBBBBB",
});
const mockRoom = makeMockRoom([membershipTemplate, member2]);
sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);

const onMyEncryptionKeyChanged = jest.fn();
sess.on(
MatrixRTCSessionEvent.EncryptionKeyChanged,
(_key: Uint8Array, _idx: number, participantId: string) => {
if (participantId === `${client.getUserId()}:${client.getDeviceId()}`) {
onMyEncryptionKeyChanged();
}
},
);
const { onMyEncryptionKeyChanged, changeMembers } = await setupParticipantChangeTest([
membershipTemplate,
member2,
]);

const keysSentPromise1 = new Promise<EncryptionKeysEventContent>((resolve) => {
sendEventMock.mockImplementation((_roomId, _evType, payload) => resolve(payload));
});
jest.advanceTimersByTime(10000);

sess.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });
const firstKeysPayload = await keysSentPromise1;
expect(firstKeysPayload.keys).toHaveLength(1);
expect(firstKeysPayload.keys[0].index).toEqual(0);
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(1);
jest.clearAllMocks();
await changeMembers([membershipTemplate]);

sendEventMock.mockClear();
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(2);
expect(onMyEncryptionKeyChanged).toHaveBeenCalledTimes(1);
} finally {
jest.useRealTimers();
}
});

const keysSentPromise2 = new Promise<EncryptionKeysEventContent>((resolve) => {
sendEventMock.mockImplementation((_roomId, _evType, payload) => resolve(payload));
it("rotates key after delay when additional member leaves", async () => {
jest.useFakeTimers();
try {
const member2 = Object.assign({}, membershipTemplate, {
device_id: "BBBBBBB",
});

mockRoom.getLiveTimeline().getState = jest
.fn()
.mockReturnValue(makeMockRoomState([membershipTemplate], mockRoom.roomId));
sess.onMembershipUpdate();
const member3 = Object.assign({}, membershipTemplate, {
device_id: "CCCCCCC",
});

const { onMyEncryptionKeyChanged, changeMembers } = await setupParticipantChangeTest([
membershipTemplate,
member2,
member3,
]);

jest.advanceTimersByTime(10000);

const secondKeysPayload = await keysSentPromise2;
jest.clearAllMocks();
const keysSentPromise = changeMembers([membershipTemplate, member2]);

// key is only generated after a delay
jest.clearAllMocks();
jest.advanceTimersByTime(2500); // the key should not yet have been generated
expect(sendEventMock).not.toHaveBeenCalled();
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(1);
jest.advanceTimersByTime(1000); // the key should now have been generated

await keysSentPromise;

expect(secondKeysPayload.keys).toHaveLength(1);
expect(secondKeysPayload.keys[0].index).toEqual(1);
expect(onMyEncryptionKeyChanged).toHaveBeenCalledTimes(2);
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(2);
expect(onMyEncryptionKeyChanged).not.toHaveBeenCalled();

// the emit comes after the key usage delay
jest.clearAllMocks();
jest.advanceTimersByTime(5000);
expect(onMyEncryptionKeyChanged).toHaveBeenCalledTimes(1);
} finally {
jest.useRealTimers();
}
Expand Down Expand Up @@ -1125,46 +1189,6 @@ describe("MatrixRTCSession", () => {
jest.useRealTimers();
}
});

it("doesn't re-send key immediately", async () => {
const realSetTimeout = setTimeout;
jest.useFakeTimers();
try {
const mockRoom = makeMockRoom([membershipTemplate]);
sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);

const keysSentPromise1 = new Promise((resolve) => {
sendEventMock.mockImplementation(resolve);
});

sess.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });
await keysSentPromise1;

sendEventMock.mockClear();
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(1);

const onMembershipsChanged = jest.fn();
sess.on(MatrixRTCSessionEvent.MembershipsChanged, onMembershipsChanged);

const member2 = Object.assign({}, membershipTemplate, {
device_id: "BBBBBBB",
});

mockRoom.getLiveTimeline().getState = jest
.fn()
.mockReturnValue(makeMockRoomState([membershipTemplate, member2], mockRoom.roomId));
sess.onMembershipUpdate();

await new Promise((resolve) => {
realSetTimeout(resolve);
});

expect(sendEventMock).not.toHaveBeenCalled();
expect(sess!.statistics.counters.roomEventEncryptionKeysSent).toEqual(1);
} finally {
jest.useRealTimers();
}
});
});

describe("receiving", () => {
Expand Down
22 changes: 16 additions & 6 deletions src/matrixrtc/MatrixRTCSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -792,11 +792,21 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
this.storeLastMembershipFingerprints();

if (anyLeft) {
logger.debug(`Member(s) have left: queueing sender key rotation`);
this.makeNewKeyTimeout = setTimeout(this.onRotateKeyTimeout, MAKE_KEY_DELAY);
if (newMembershipIds.size === 1) {
logger.debug(`New member(s) have left: doing immediate sender key rotation`);
this.doRotateKey(false);
} else {
logger.debug(`New member(s) have left: queueing sender key rotation`);
this.makeNewKeyTimeout = setTimeout(() => this.doRotateKey(true), MAKE_KEY_DELAY);
}
} else if (anyJoined) {
logger.debug(`New member(s) have joined: queueing sender key rotation`);
this.makeNewKeyTimeout = setTimeout(this.onRotateKeyTimeout, MAKE_KEY_DELAY);
if (newMembershipIds.size === 2) {
logger.debug(`New member(s) have joined: doing immediate sender key rotation`);
this.doRotateKey(false);
} else {
logger.debug(`New member(s) have joined: queueing sender key rotation`);
this.makeNewKeyTimeout = setTimeout(() => this.doRotateKey(true), MAKE_KEY_DELAY);
}
} else if (oldFingerprints) {
// does it look like any of the members have updated their memberships?
const newFingerprints = this.lastMembershipFingerprints!;
Expand Down Expand Up @@ -1094,12 +1104,12 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
}
}

private onRotateKeyTimeout = (): void => {
private doRotateKey = (delayOnUse: boolean): void => {
if (!this.manageMediaKeys) return;

this.makeNewKeyTimeout = undefined;
logger.info("Making new sender key for key rotation");
const newKeyIndex = this.makeNewSenderKey(true);
const newKeyIndex = this.makeNewSenderKey(delayOnUse);
// send immediately: if we're about to start sending with a new key, it's
// important we get it out to others as soon as we can.
this.sendEncryptionKeysEvent(newKeyIndex);
Expand Down

0 comments on commit 52a5c96

Please sign in to comment.