diff --git a/fixtures/proto/encodedAnotherPersonNestedV2.ts b/fixtures/proto/encodedAnotherPersonNestedV2.ts
new file mode 100644
index 0000000..99480d1
--- /dev/null
+++ b/fixtures/proto/encodedAnotherPersonNestedV2.ts
@@ -0,0 +1 @@
+export default [0, 0, 0, 0, 3, 2, 2, 10, 8, 74, 111, 104, 110, 32, 68, 111, 101]
diff --git a/fixtures/proto/encodedAnotherPersonV2.ts b/fixtures/proto/encodedAnotherPersonV2.ts
index b826938..4415d96 100644
--- a/fixtures/proto/encodedAnotherPersonV2.ts
+++ b/fixtures/proto/encodedAnotherPersonV2.ts
@@ -1 +1 @@
-export default [0, 0, 0, 0, 3, 10, 8, 74, 111, 104, 110, 32, 68, 111, 101]
+export default [0, 0, 0, 0, 3, 0, 10, 8, 74, 111, 104, 110, 32, 68, 111, 101]
diff --git a/fixtures/proto/encodedNestedV2.ts b/fixtures/proto/encodedNestedV2.ts
index 327e482..7c71068 100644
--- a/fixtures/proto/encodedNestedV2.ts
+++ b/fixtures/proto/encodedNestedV2.ts
@@ -1 +1 @@
-export default [0, 0, 0, 0, 32, 10, 10, 100, 97, 116, 97, 45, 118, 97, 108, 117, 101, 18, 17, 10, 15, 115, 111, 109, 101, 70, 105, 101, 108, 100, 45, 118, 97, 108, 117, 101, 26, 22, 10, 20, 115, 111, 109, 101, 79, 116, 104, 101, 114, 70, 105, 101, 108, 100, 45, 118, 97, 108, 117, 101 ]
\ No newline at end of file
+export default [0, 0, 0, 0, 32, 0, 10, 10, 100, 97, 116, 97, 45, 118, 97, 108, 117, 101, 18, 17, 10, 15, 115, 111, 109, 101, 70, 105, 101, 108, 100, 45, 118, 97, 108, 117, 101, 26, 22, 10, 20, 115, 111, 109, 101, 79, 116, 104, 101, 114, 70, 105, 101, 108, 100, 45, 118, 97, 108, 117, 101 ]
diff --git a/package.json b/package.json
index a786a1f..f2a2203 100644
--- a/package.json
+++ b/package.json
@@ -25,8 +25,9 @@
   "dependencies": {
     "ajv": "^7.1.0",
     "avsc": ">= 5.4.13 < 6",
+    "long": "^5.2.3",
     "mappersmith": ">= 2.30.1 < 3",
-    "protobufjs": "^6.11.4"
+    "protobufjs": "github:davidgrisham/protobuf.js#ordered-nested-objects"
   },
   "devDependencies": {
     "@types/execa": "^2.0.0",
diff --git a/src/ProtoSchema.ts b/src/ProtoSchema.ts
index bd358fe..19eca05 100644
--- a/src/ProtoSchema.ts
+++ b/src/ProtoSchema.ts
@@ -1,74 +1,234 @@
+import Long from 'long'
 import { Schema, ProtoOptions, ProtoConfluentSchema } from './@types'
 import protobuf from 'protobufjs'
-import { IParserResult, ReflectionObject, Namespace, Type } from 'protobufjs/light'
-import {
-  ConfluentSchemaRegistryArgumentError,
-  ConfluentSchemaRegistryValidationError,
-} from './errors'
+import { Root, Namespace, Type } from 'protobufjs/light'
+import { ConfluentSchemaRegistryValidationError } from './errors'
+
+const MAX_VARINT_LEN_64 = 10

 export default class ProtoSchema implements Schema {
-  private message: Type
+  private namespace: Namespace

   constructor(schema: ProtoConfluentSchema, opts?: ProtoOptions) {
     const parsedMessage = protobuf.parse(schema.schema)
     const root = parsedMessage.root
-    const referencedSchemas = opts?.referencedSchemas
+    this.namespace = this.getNestedNamespace(parsedMessage.root, parsedMessage.package || '')

+    const referencedSchemas = opts?.referencedSchemas
     // handle all schema references independent on nested references
     if (referencedSchemas) {
       referencedSchemas.forEach(rawSchema => protobuf.parse(rawSchema.schema as string, root))
     }
+  }

-    this.message = root.lookupType(this.getTypeName(parsedMessage, opts))
+  // getNestedNamespace traverses from the root down into the innermost namespace specified by the package name.
+  // this should return the Namespace that encapsulates all of the message types for this schema.
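+  // for illustration, given a hypothetical package name 'com.example' this walks
+  // root.nested['com'].nested['example'] and returns that innermost namespace.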
+  private getNestedNamespace(root: Root, pkg: string): Namespace {
+    let ns: Namespace = root
+    for (const name of pkg.split('.')) {
+      if (!ns.nested) {
+        throw new Error(`Unable to retrieve nested namespace '${pkg}' from root object`)
+      }
+      ns = ns.nested[name] as Namespace
+      if (!(ns instanceof Namespace)) {
+        throw new Error(
+          `Failed to retrieve namespace '${pkg}' from root object, because nested object '${name}' is not a Namespace instance`,
+        )
+      }
+    }
+    return ns
   }

-  private getNestedTypeName(parent: { [k: string]: ReflectionObject } | undefined): string {
-    if (!parent) throw new ConfluentSchemaRegistryArgumentError('no nested fields')
-    const keys = Object.keys(parent)
-    const reflection = parent[keys[0]]
+  // this encodes a payload against the specified schema with the proper message index bytes. if typeName is empty,
+  // we default to the first schema in the namespace. if typeName is provided, we split it on '.' and access the
+  // nested schema accordingly. for example, if typeName is 'Task', then the payload will be encoded with the
+  // top level Task type in the namespace; if typeName is 'Task.TaskId', then the payload will be encoded with the
+  // TaskId message type nested inside of the top level Task message.
+  //
+  // for more information on this see https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
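+  //
+  // for example, a type that resolves to message index 1 (as AnotherPerson does in the updated test
+  // fixtures) gets the prefix bytes [2, 2]: the zigzag varints for a count of 1 followed by the index 1.
+  // the default first message type is instead prefixed with the single byte 0 (see toBuffer below).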
+  public toBufferFromNestedType(payload: object, typeName = ''): Buffer {
+    if (!typeName) {
+      return this.toBuffer(payload)
+    }

-    // Traverse down the nested Namespaces until we find a message Type instance (which extends Namespace)
-    if (reflection instanceof Namespace && !(reflection instanceof Type) && reflection.nested)
-      return this.getNestedTypeName(reflection.nested)
-    return keys[0]
-  }
+    const typeArray = typeName.split('.')
+    const msgIndexes = new Array(typeArray.length)
+    let currentNamespace: Namespace = this.namespace
+    // find the nested type (and store the message indexes along the way)
+    for (const [i, name] of typeName.split('.').entries()) {
+      const nestedMessageIndex = currentNamespace.orderedNestedMessages.findIndex(
+        msg => msg.name === name,
+      )
+      if (nestedMessageIndex === -1) {
+        throw new Error(
+          `Unable to retrieve nested type '${typeName}' from namespace (failed at '${name}')`,
+        )
+      }
+      const nestedMessage = currentNamespace.orderedNestedMessages[nestedMessageIndex]
+      if (!(nestedMessage && nestedMessage instanceof Type)) {
+        throw new Error(
+          `Unable to retrieve nested type '${typeName}' from namespace (failed at '${name}')`,
+        )
+      }
+      msgIndexes[i] = nestedMessageIndex
+      currentNamespace = nestedMessage as Namespace
+    }
+    const schema = currentNamespace as Type

-  private getTypeName(parsedMessage: IParserResult, opts?: ProtoOptions): string {
-    const root = parsedMessage.root
-    const pkg = parsedMessage.package
-    const name = opts && opts.messageName ? opts.messageName : this.getNestedTypeName(root.nested)
-    return `${pkg ? pkg + '.' : ''}.${name}`
-  }
+    const encodedMessageIndexes = this.encodeMessageIndexes(msgIndexes)
+    const msgPayload = schema.create(payload)

-  private trimStart(buffer: Buffer): Buffer {
-    const index = buffer.findIndex((value: number) => value != 0)
-    return buffer.slice(index)
+    return Buffer.concat([encodedMessageIndexes, Buffer.from(schema.encode(msgPayload).finish())])
   }

+  // this handles the common case where we default to the first message in the payload. in this case we encode the
+  // message index bytes as just a single byte of 0. this function is partly here to conform to the Schema interface
+  // -- the more general cases are handled by toBufferFromNestedType above.
   public toBuffer(payload: object): Buffer {
+    if (!(this.namespace.orderedNestedMessages[0] instanceof Type)) {
+      throw new Error(
+        'Failed to retrieve schema to serialize protobuf message: nested message is not an instance of Type',
+      )
+    }
+    const schema = this.namespace.orderedNestedMessages[0] as Type
+
     const paths: string[][] = []
     if (
-      !this.isValid(payload, {
+      !this.validatePayloadAgainstSchema(schema, payload, {
         errorHook: (path: Array<string>) => paths.push(path),
       })
     ) {
       throw new ConfluentSchemaRegistryValidationError('invalid payload', paths)
     }

-    const protoPayload = this.message.create(payload)
-    return Buffer.from(this.message.encode(protoPayload).finish())
+    return Buffer.from([0, ...schema.encode(schema.create(payload)).finish()])
+  }
+
+  // adapted from https://github.com/confluentinc/confluent-kafka-go/blob/af4a5f8b2018db6503f7e8097a25a24a6d6feb06/schemaregistry/serde/protobuf/protobuf.go#L295
+  private encodeMessageIndexes(msgIndexes: Array<number>): Buffer {
+    const encodedIndexes = Buffer.alloc((1 + msgIndexes.length) * MAX_VARINT_LEN_64)
+
+    let totalLength = this.putVarint(encodedIndexes, msgIndexes.length, 0)
+
+    for (const msgIndex of msgIndexes) {
+      const length = this.putVarint(encodedIndexes, msgIndex, totalLength)
+      totalLength += length
+    }
+    return encodedIndexes.slice(0, totalLength)
+  }
+
+  // adapted from https://go.dev/src/encoding/binary/varint.go
+  private putVarint(buffer: Buffer, value: number, offset: number): number {
+    let x = Long.fromNumber(value, true).shiftLeft(1) // unsigned 64 bit integer
+    if (value < 0) {
+      x = x.not()
+    }
+    let i = 0
+    while (x.gte(0x80)) {
+      buffer.writeUInt8((x.getLowBits() & 0x000000ff) | 0x80, offset + i)
+      x = x.shiftRightUnsigned(7)
+      i += 1
+    }
+    buffer.writeUInt8(x.getLowBits() & 0x000000ff, offset + i)
+    return i + 1
   }

   public fromBuffer(buffer: Buffer): any {
-    const newBuffer = this.trimStart(buffer)
-    return this.message.decode(newBuffer)
+    const [bytesRead, msgIndexes] = this.readMessageIndexes(buffer)
+    const message = this.lookupMessage(msgIndexes)
+
+    return message.decode(buffer.slice(bytesRead))
   }

+  private lookupMessage(msgIndexes: Array<number>): Type {
+    let currentNamespace: Namespace = this.namespace
+    for (const idx of msgIndexes) {
+      if (!(currentNamespace.orderedNestedMessages[idx] instanceof Type)) {
+        throw new Error(
+          'Failed to retrieve nested message from namespace: nested message is not an instance of Type',
+        )
+      }
+      currentNamespace = currentNamespace.orderedNestedMessages[idx] as Namespace
+    }
+    return currentNamespace as Type
+  }
+
+  private readMessageIndexes(payload: Buffer): [number, Array<number>] {
+    const result = this.parseVarint(payload)
+    const arrayLen = result[0]
+    let bytesRead = result[1]
+    if (bytesRead <= 0) {
+      throw new Error('unable to read message indexes')
+    }
+    if (arrayLen.lt(0)) {
+      throw new Error('parsed invalid message index count')
+    }
+    if (arrayLen.eq(0)) {
+      return [bytesRead, [0]]
+    }
+
+    const msgIndexes = new Array(arrayLen.toInt())
+    for (let i = 0; arrayLen.gt(i); i++) {
+      const [idx, read] = this.parseVarint(payload.slice(bytesRead))
+      if (read <= 0) {
+        throw new Error('unable to read message indexes')
+      }
+      bytesRead += read
+      msgIndexes[i] = idx.toInt()
+    }
+
+    return [bytesRead, msgIndexes]
+  }
+
+  // adapted from https://go.dev/src/encoding/binary/varint.go
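+  // note: putVarint above writes zigzag-encoded values, so parseVarint reverses that mapping;
+  // e.g. the byte 2 decodes back to 1 and the byte 0 decodes back to 0, matching the message
+  // index prefixes produced by encodeMessageIndexes.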
+  private parseVarint(buffer: Buffer): [Long, number] {
+    const [ux, n] = this.parseUvarint(buffer)
+    let x = ux.shiftRight(1).toSigned()
+    if (ux.and(1).neq(0)) {
+      x = x.not()
+    }
+    return [x, n]
+  }
+
+  private parseUvarint(buffer: Buffer): [Long, number] {
+    let x = Long.UZERO // new unsigned 64 bit integer
+    let s = 0
+
+    for (let i = 0; i < buffer.length; i++) {
+      if (i == MAX_VARINT_LEN_64) {
+        // overflow
+        return [Long.UZERO, -(i + 1)]
+      }
+      const b = buffer.readUInt8(i)
+      if (b < 0x80) {
+        if (i == MAX_VARINT_LEN_64 - 1 && b > 1) {
+          // overflow
+          return [Long.UZERO, -(i + 1)]
+        }
+        return [x.or(Long.fromBits(b, 0, true).shiftLeft(s)), i + 1]
+      }
+      x = x.or(new Long(b & 0x7f, 0, true).shiftLeft(s))
+      s += 7
+    }
+    return [Long.UZERO, 0]
+  }
+
+  // unimplemented -- this is part of the Schema interface, but because the protobuf schema namespace can
+  // store multiple schemas we need something that can specify which schema we're validating against.
+  // the validatePayloadAgainstSchema function below achieves this
   public isValid(
+    _payload: object,
+    _opts?: { errorHook: (path: Array<string>, value: any, type?: any) => void },
+  ): boolean {
+    return false
+  }
+
+  public validatePayloadAgainstSchema(
+    schema: Type,
     payload: object,
     opts?: { errorHook: (path: Array<string>, value: any, type?: any) => void },
   ): boolean {
-    const errMsg: null | string = this.message.verify(payload)
+    const errMsg: null | string = schema.verify(payload)
     if (errMsg) {
       if (opts?.errorHook) {
         opts.errorHook([errMsg], payload)
diff --git a/src/SchemaRegistry.newApi.spec.ts b/src/SchemaRegistry.newApi.spec.ts
index 0138844..2b2dfdc 100644
--- a/src/SchemaRegistry.newApi.spec.ts
+++ b/src/SchemaRegistry.newApi.spec.ts
@@ -8,6 +8,7 @@ import { COMPATIBILITY, DEFAULT_API_CLIENT_ID } from './constants'
 import encodedAnotherPersonV2Avro from '../fixtures/avro/encodedAnotherPersonV2'
 import encodedAnotherPersonV2Json from '../fixtures/json/encodedAnotherPersonV2'
 import encodedAnotherPersonV2Proto from '../fixtures/proto/encodedAnotherPersonV2'
+import encodedAnotherPersonNestedV2Proto from '../fixtures/proto/encodedAnotherPersonNestedV2'
 import encodedNestedV2Proto from '../fixtures/proto/encodedNestedV2'
 import wrongMagicByte from '../fixtures/wrongMagicByte'
 import Ajv2020 from 'ajv8/dist/2020'
@@ -160,6 +161,7 @@ describe('SchemaRegistry - new Api', () => {
         }
       `,
       encodedAnotherPersonV2: encodedAnotherPersonV2Proto,
+      encodedAnotherPersonNestedV2: encodedAnotherPersonNestedV2Proto,
     },
   }
 const types = Object.keys(schemaStringsByType).map(str => SchemaType[str]) as KnownSchemaTypes[]
@@ -468,11 +470,11 @@ describe('SchemaRegistry - new Api', () => {
         subject: `${type}_test3`,
       })

-      const data = await schemaRegistry.encode(schema3.id, payload)
+      const data = await schemaRegistry.encode(schema3.id, payload, 'AnotherPerson')

       expect(data).toMatchConfluentEncodedPayload({
         registryId: schema3.id,
-        payload: Buffer.from(schemaStringsByType[type].encodedAnotherPersonV2),
+        payload: Buffer.from(schemaStringsByType[type].encodedAnotherPersonNestedV2),
       })
     })

@@ -486,7 +488,7 @@ describe('SchemaRegistry - new Api', () => {
         subject: `${type}_test3`,
       })

-      const buffer = Buffer.from(await schemaRegistry.encode(schema3.id, payload))
+      const buffer = Buffer.from(await schemaRegistry.encode(schema3.id, payload, 'AnotherPerson'))
       const data = await schemaRegistry.decode(buffer)

       expect(data).toEqual(payload)
diff --git a/src/SchemaRegistry.protobuf.spec.ts b/src/SchemaRegistry.protobuf.spec.ts
index 7477a9f..a5144ee 100644
--- a/src/SchemaRegistry.protobuf.spec.ts
+++ b/src/SchemaRegistry.protobuf.spec.ts
@@ -230,7 +230,7 @@ describe('SchemaRegistry', () => {
       })

       it('should return schema that match message', async () => {
-        expect(schema.message.name).toEqual('ThirdLevel')
+        expect(schema.namespace.orderedNestedMessages[0].name).toEqual('ThirdLevel')
       })

       it('should be able to encode/decode', async () => {
@@ -256,7 +256,7 @@ describe('SchemaRegistry', () => {
       })

       it('should return schema that match message', async () => {
-        expect(schema.message.name).toEqual('SecondLevelA')
+        expect(schema.namespace.orderedNestedMessages[0].name).toEqual('SecondLevelA')
       })

       it('should be able to encode/decode', async () => {
@@ -300,7 +300,7 @@ describe('SchemaRegistry', () => {
       })

       it('should return schema that match message', async () => {
-        expect(schema.message.name).toEqual('FirstLevel')
+        expect(schema.namespace.orderedNestedMessages[0].name).toEqual('FirstLevel')
       })

       it('should be able to encode/decode', async () => {
diff --git a/src/SchemaRegistry.ts b/src/SchemaRegistry.ts
index cb8ce32..6bd475e 100644
--- a/src/SchemaRegistry.ts
+++ b/src/SchemaRegistry.ts
@@ -32,6 +32,7 @@ import {
   schemaTypeFromString,
   schemaFromConfluentSchema,
 } from './schemaTypeResolver'
+import ProtoSchema from './ProtoSchema'

 export interface RegisteredSchema {
   id: number
@@ -262,16 +263,30 @@ export default class SchemaRegistry {
     return await (await this._getSchema(registryId)).schema
   }

-  public async encode(registryId: number, payload: any): Promise<Buffer> {
+  public async encode(
+    registryId: number,
+    payload: any,
+    typeName?: string /* for protobuf */,
+  ): Promise<Buffer> {
     if (!registryId) {
       throw new ConfluentSchemaRegistryArgumentError(
         `Invalid registryId: ${JSON.stringify(registryId)}`,
       )
     }

-    const { schema } = await this._getSchema(registryId)
+    const { type, schema } = await this._getSchema(registryId)

     try {
-      const serializedPayload = schema.toBuffer(payload)
+      let serializedPayload
+      switch (type) {
+        // in the case of protobuf schemas we need a bit more information to specify which schema to encode with
+        // (see the implementation for details)
+        case SchemaType.PROTOBUF:
+          const protoSchema = schema as ProtoSchema
+          serializedPayload = protoSchema.toBufferFromNestedType(payload, typeName)
+          break
+        default:
+          serializedPayload = schema.toBuffer(payload)
+      }
       return encode(registryId, serializedPayload)
     } catch (error) {
       if (error instanceof ConfluentSchemaRegistryValidationError) throw error