diff --git a/src/Message.ts b/src/Message.ts index c500ab79f..c778a02fa 100644 --- a/src/Message.ts +++ b/src/Message.ts @@ -11,10 +11,17 @@ import { NoMatchingPreKeyError } from './crypto/errors' import { bytesToHex } from './crypto/utils' import { sha256 } from './crypto/encryption' +const extractV1Message = (msg: proto.Message): proto.V1Message => { + if (!msg.v1) { + throw new Error('Message is not of type v1') + } + return msg.v1 +} + // Message is basic unit of communication on the network. // Message header carries the sender and recipient keys used to protect message. // Message timestamp is set by the sender. -export default class Message implements proto.Message { +export default class Message implements proto.V1Message { header: proto.MessageHeader | undefined // eslint-disable-line camelcase headerBytes: Uint8Array // encoded header bytes ciphertext: Ciphertext | undefined @@ -34,12 +41,13 @@ export default class Message implements proto.Message { obj: proto.Message, header: proto.MessageHeader ) { + const msg = extractV1Message(obj) this.id = id this.bytes = bytes - this.headerBytes = obj.headerBytes + this.headerBytes = msg.headerBytes this.header = header - if (obj.ciphertext) { - this.ciphertext = new Ciphertext(obj.ciphertext) + if (msg.ciphertext) { + this.ciphertext = new Ciphertext(msg.ciphertext) } } @@ -58,7 +66,8 @@ export default class Message implements proto.Message { static async fromBytes(bytes: Uint8Array): Promise { const msg = proto.Message.decode(bytes) - const header = proto.MessageHeader.decode(msg.headerBytes) + const innerMessage = extractV1Message(msg) + const header = proto.MessageHeader.decode(innerMessage.headerBytes) return Message.create(msg, header, bytes) } @@ -112,7 +121,7 @@ export default class Message implements proto.Message { } const headerBytes = proto.MessageHeader.encode(header).finish() const ciphertext = await encrypt(msgBytes, secret, headerBytes) - const protoMsg = { headerBytes: headerBytes, ciphertext } + const protoMsg = { v1: { headerBytes: headerBytes, ciphertext } } const bytes = proto.Message.encode(protoMsg).finish() const msg = await Message.create(protoMsg, header, bytes) msg.decrypted = message @@ -127,7 +136,8 @@ export default class Message implements proto.Message { bytes: Uint8Array ): Promise { const message = proto.Message.decode(bytes) - const header = proto.MessageHeader.decode(message.headerBytes) + const v1Message = extractV1Message(message) + const header = proto.MessageHeader.decode(v1Message.headerBytes) if (!header) { throw new Error('missing message header') } @@ -157,10 +167,10 @@ export default class Message implements proto.Message { new PublicKey(header.sender.identityKey), new PublicKey(header.sender.preKey) ) - if (!message.ciphertext?.aes256GcmHkdfSha256) { + if (!v1Message.ciphertext?.aes256GcmHkdfSha256) { throw new Error('missing message ciphertext') } - const ciphertext = new Ciphertext(message.ciphertext) + const ciphertext = new Ciphertext(v1Message.ciphertext) const msg = await Message.create(message, header, bytes) let secret: Uint8Array try { @@ -178,7 +188,7 @@ export default class Message implements proto.Message { msg.error = e return msg } - bytes = await decrypt(ciphertext, secret, message.headerBytes) + bytes = await decrypt(ciphertext, secret, v1Message.headerBytes) msg.decrypted = new TextDecoder().decode(bytes) return msg } diff --git a/src/proto/messaging.proto b/src/proto/messaging.proto index 9e2cae3c9..69cba5e18 100644 --- a/src/proto/messaging.proto +++ b/src/proto/messaging.proto @@ -54,11 +54,17 @@ message MessageHeader { uint64 timestamp = 3; } -message Message { +message V1Message { bytes headerBytes = 1; // encapsulates the encoded MessageHeader Ciphertext ciphertext = 2; } +message Message { + oneof version { + V1Message v1 = 1; + } +} + // Private Key Storage message PrivateKeyBundle { diff --git a/src/proto/messaging.ts b/src/proto/messaging.ts index adc7da288..56da05ab7 100644 --- a/src/proto/messaging.ts +++ b/src/proto/messaging.ts @@ -58,12 +58,16 @@ export interface MessageHeader { timestamp: number } -export interface Message { +export interface V1Message { /** encapsulates the encoded MessageHeader */ headerBytes: Uint8Array ciphertext: Ciphertext | undefined } +export interface Message { + v1: V1Message | undefined +} + export interface PrivateKeyBundle { identityKey: PrivateKey | undefined preKeys: PrivateKey[] @@ -869,13 +873,13 @@ export const MessageHeader = { }, } -function createBaseMessage(): Message { +function createBaseV1Message(): V1Message { return { headerBytes: new Uint8Array(), ciphertext: undefined } } -export const Message = { +export const V1Message = { encode( - message: Message, + message: V1Message, writer: _m0.Writer = _m0.Writer.create() ): _m0.Writer { if (message.headerBytes.length !== 0) { @@ -887,10 +891,10 @@ export const Message = { return writer }, - decode(input: _m0.Reader | Uint8Array, length?: number): Message { + decode(input: _m0.Reader | Uint8Array, length?: number): V1Message { const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input) let end = length === undefined ? reader.len : reader.pos + length - const message = createBaseMessage() + const message = createBaseV1Message() while (reader.pos < end) { const tag = reader.uint32() switch (tag >>> 3) { @@ -908,7 +912,7 @@ export const Message = { return message }, - fromJSON(object: any): Message { + fromJSON(object: any): V1Message { return { headerBytes: isSet(object.headerBytes) ? bytesFromBase64(object.headerBytes) @@ -919,7 +923,7 @@ export const Message = { } }, - toJSON(message: Message): unknown { + toJSON(message: V1Message): unknown { const obj: any = {} message.headerBytes !== undefined && (obj.headerBytes = base64FromBytes( @@ -934,8 +938,10 @@ export const Message = { return obj }, - fromPartial, I>>(object: I): Message { - const message = createBaseMessage() + fromPartial, I>>( + object: I + ): V1Message { + const message = createBaseV1Message() message.headerBytes = object.headerBytes ?? new Uint8Array() message.ciphertext = object.ciphertext !== undefined && object.ciphertext !== null @@ -945,6 +951,62 @@ export const Message = { }, } +function createBaseMessage(): Message { + return { v1: undefined } +} + +export const Message = { + encode( + message: Message, + writer: _m0.Writer = _m0.Writer.create() + ): _m0.Writer { + if (message.v1 !== undefined) { + V1Message.encode(message.v1, writer.uint32(10).fork()).ldelim() + } + return writer + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): Message { + const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input) + let end = length === undefined ? reader.len : reader.pos + length + const message = createBaseMessage() + while (reader.pos < end) { + const tag = reader.uint32() + switch (tag >>> 3) { + case 1: + message.v1 = V1Message.decode(reader, reader.uint32()) + break + default: + reader.skipType(tag & 7) + break + } + } + return message + }, + + fromJSON(object: any): Message { + return { + v1: isSet(object.v1) ? V1Message.fromJSON(object.v1) : undefined, + } + }, + + toJSON(message: Message): unknown { + const obj: any = {} + message.v1 !== undefined && + (obj.v1 = message.v1 ? V1Message.toJSON(message.v1) : undefined) + return obj + }, + + fromPartial, I>>(object: I): Message { + const message = createBaseMessage() + message.v1 = + object.v1 !== undefined && object.v1 !== null + ? V1Message.fromPartial(object.v1) + : undefined + return message + }, +} + function createBasePrivateKeyBundle(): PrivateKeyBundle { return { identityKey: undefined, preKeys: [] } }