Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix protobuf wire encoding/decoding #258

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fixtures/proto/encodedAnotherPersonNestedV2.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export default [0, 0, 0, 0, 3, 2, 2, 10, 8, 74, 111, 104, 110, 32, 68, 111, 101]
2 changes: 1 addition & 1 deletion fixtures/proto/encodedAnotherPersonV2.ts
Original file line number Diff line number Diff line change
@@ -1 +1 @@
export default [0, 0, 0, 0, 3, 10, 8, 74, 111, 104, 110, 32, 68, 111, 101]
export default [0, 0, 0, 0, 3, 0, 10, 8, 74, 111, 104, 110, 32, 68, 111, 101]
2 changes: 1 addition & 1 deletion fixtures/proto/encodedNestedV2.ts
Original file line number Diff line number Diff line change
@@ -1 +1 @@
export default [0, 0, 0, 0, 32, 10, 10, 100, 97, 116, 97, 45, 118, 97, 108, 117, 101, 18, 17, 10, 15, 115, 111, 109, 101, 70, 105, 101, 108, 100, 45, 118, 97, 108, 117, 101, 26, 22, 10, 20, 115, 111, 109, 101, 79, 116, 104, 101, 114, 70, 105, 101, 108, 100, 45, 118, 97, 108, 117, 101 ]
export default [0, 0, 0, 0, 32, 0, 10, 10, 100, 97, 116, 97, 45, 118, 97, 108, 117, 101, 18, 17, 10, 15, 115, 111, 109, 101, 70, 105, 101, 108, 100, 45, 118, 97, 108, 117, 101, 26, 22, 10, 20, 115, 111, 109, 101, 79, 116, 104, 101, 114, 70, 105, 101, 108, 100, 45, 118, 97, 108, 117, 101 ]
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
"dependencies": {
"ajv": "^7.1.0",
"avsc": ">= 5.4.13 < 6",
"long": "^5.2.3",
"mappersmith": ">= 2.30.1 < 3",
"protobufjs": "^6.11.4"
"protobufjs": "github:davidgrisham/protobuf.js#ordered-nested-objects"
},
"devDependencies": {
"@types/execa": "^2.0.0",
Expand Down
224 changes: 192 additions & 32 deletions src/ProtoSchema.ts
Original file line number Diff line number Diff line change
@@ -1,74 +1,234 @@
import Long from 'long'
import { Schema, ProtoOptions, ProtoConfluentSchema } from './@types'
import protobuf from 'protobufjs'
import { IParserResult, ReflectionObject, Namespace, Type } from 'protobufjs/light'
import {
ConfluentSchemaRegistryArgumentError,
ConfluentSchemaRegistryValidationError,
} from './errors'
import { Root, Namespace, Type } from 'protobufjs/light'
import { ConfluentSchemaRegistryValidationError } from './errors'

const MAX_VARINT_LEN_64 = 10

export default class ProtoSchema implements Schema {
private message: Type
private namespace: Namespace

constructor(schema: ProtoConfluentSchema, opts?: ProtoOptions) {
const parsedMessage = protobuf.parse(schema.schema)
const root = parsedMessage.root
const referencedSchemas = opts?.referencedSchemas
this.namespace = this.getNestedNamespace(parsedMessage.root, parsedMessage.package || '')

const referencedSchemas = opts?.referencedSchemas
// handle all schema references independently of nested references
if (referencedSchemas) {
referencedSchemas.forEach(rawSchema => protobuf.parse(rawSchema.schema as string, root))
}
}

this.message = root.lookupType(this.getTypeName(parsedMessage, opts))
// getNestedNamespace traverses from the root down into the innermost namespace specified by the
// dotted package name and returns it. That namespace is expected to hold the message types of
// this schema.
//
// root: the root object produced by parsing the schema
// pkg:  the dotted package name, or '' when the schema declares no package
//
// Throws when a segment of the package path is missing or is not a Namespace instance.
private getNestedNamespace(root: Root, pkg: string): Namespace {
  // ''.split('.') yields [''], which would be looked up as a nested object named '' and always
  // fail. With no package there is nothing to descend into, so the root itself is the answer.
  if (!pkg) {
    return root
  }

  let ns: Namespace = root
  for (const name of pkg.split('.')) {
    if (!ns.nested) {
      throw new Error(`Unable to retrieve nested namespace '${pkg}' from root object`)
    }
    const child = ns.nested[name]
    if (!(child instanceof Namespace)) {
      throw new Error(
        `Failed to retrieve namespace '${pkg}' from root object, because nested object '${name}' is not a Namespace instance`,
      )
    }
    ns = child
  }
  return ns
}

private getNestedTypeName(parent: { [k: string]: ReflectionObject } | undefined): string {
if (!parent) throw new ConfluentSchemaRegistryArgumentError('no nested fields')
const keys = Object.keys(parent)
const reflection = parent[keys[0]]
// this encodes a payload against the specified schema with the proper message index bytes. if typeName is empty,
// we default to the first schema in the namespace. if typeName is provided, we split it on '.' and access the
// nested schema accordingly. for example, if typeName is 'Task', then the payload will be encoded with the
// top level Task type in the namespace; if typeName is 'Task.TaskId', then payload will be encoded with the
// TaskId message type nested inside of the top level Task message.
//
// for more information on this see https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
public toBufferFromNestedType(payload: object, typeName = ''): Buffer {
if (!typeName) {
return this.toBuffer(payload)
}

// Traverse down the nested Namespaces until we find a message Type instance (which extends Namespace)
if (reflection instanceof Namespace && !(reflection instanceof Type) && reflection.nested)
return this.getNestedTypeName(reflection.nested)
return keys[0]
}
const typeArray = typeName.split('.')
const msgIndexes = new Array<number>(typeArray.length)
let currentNamespace: Namespace = this.namespace
// find the nested type (and store the message indexes along the way)
for (const [i, name] of typeName.split('.').entries()) {
const nestedMessageIndex = currentNamespace.orderedNestedMessages.findIndex(
msg => msg.name === name,
)
if (nestedMessageIndex === -1) {
throw new Error(
`Unable to retrieve nested type '${typeName}' from namespace (failed at '${name}')`,
)
}
const nestedMessage = currentNamespace.orderedNestedMessages[nestedMessageIndex]
if (!(nestedMessage && nestedMessage instanceof Type)) {
throw new Error(
`Unable to retrieve nested type '${typeName}' from namespace (failed at '${name}')`,
)
}
msgIndexes[i] = nestedMessageIndex
currentNamespace = nestedMessage as Namespace
}
const schema = currentNamespace as Type

private getTypeName(parsedMessage: IParserResult, opts?: ProtoOptions): string {
const root = parsedMessage.root
const pkg = parsedMessage.package
const name = opts && opts.messageName ? opts.messageName : this.getNestedTypeName(root.nested)
return `${pkg ? pkg + '.' : ''}.${name}`
}
const encodedMessageIndexes = this.encodeMessageIndexes(msgIndexes)
const msgPayload = schema.create(payload)

private trimStart(buffer: Buffer): Buffer {
const index = buffer.findIndex((value: number) => value != 0)
return buffer.slice(index)
return Buffer.concat([encodedMessageIndexes, Buffer.from(schema.encode(msgPayload).finish())])
}

// this handles the common case where we default to the first message in the payload. in this case we encode the
// message index bytes as just a single byte of 0. this function is partly here to conform to the Schema interface
// -- the more general cases are handled by toBufferFromNestedType above.
public toBuffer(payload: object): Buffer {
if (!(this.namespace.orderedNestedMessages[0] instanceof Type)) {
throw new Error(
'Failed to retrieve schema to serialize protobuf message: nested message is not an instance of Type',
)
}
const schema = this.namespace.orderedNestedMessages[0] as Type

const paths: string[][] = []
if (
!this.isValid(payload, {
!this.validatePayloadAgainstSchema(schema, payload, {
errorHook: (path: Array<string>) => paths.push(path),
})
) {
throw new ConfluentSchemaRegistryValidationError('invalid payload', paths)
}

const protoPayload = this.message.create(payload)
return Buffer.from(this.message.encode(protoPayload).finish())
return Buffer.from([0, ...schema.encode(schema.create(payload)).finish()])
}

// adapted from https://github.com/confluentinc/confluent-kafka-go/blob/af4a5f8b2018db6503f7e8097a25a24a6d6feb06/schemaregistry/serde/protobuf/protobuf.go#L295
//
// Serializes a message-index path: the number of indexes followed by each index, every value
// written with putVarint. Returns a buffer trimmed to exactly the bytes written.
private encodeMessageIndexes(msgIndexes: Array<number>): Buffer {
  // The path [0] (first top-level message) is written as a single 0 byte. This matches what
  // toBuffer emits for the default message and what readMessageIndexes decodes back to [0].
  if (msgIndexes.length === 1 && msgIndexes[0] === 0) {
    return Buffer.from([0])
  }

  // worst case: every value (the count plus each index) needs a full-width varint
  const encodedIndexes = Buffer.alloc((1 + msgIndexes.length) * MAX_VARINT_LEN_64)

  let totalLength = this.putVarint(encodedIndexes, msgIndexes.length, 0)
  for (const msgIndex of msgIndexes) {
    totalLength += this.putVarint(encodedIndexes, msgIndex, totalLength)
  }
  return encodedIndexes.slice(0, totalLength)
}

// adapted from https://go.dev/src/encoding/binary/varint.go
//
// Writes `value` into `buffer` starting at `offset` as a zig-zag encoded varint (7 payload bits
// per byte, high bit set on every byte except the last) and returns the number of bytes written.
private putVarint(buffer: Buffer, value: number, offset: number): number {
  // Build the zig-zag form from the *signed* 64-bit value: (n << 1) ^ (n >> 63).
  // Converting straight to an unsigned Long would clamp negative numbers to zero before the
  // shift, producing the wrong encoding for them.
  const signed = Long.fromNumber(value, false)
  let x = signed.shiftLeft(1).xor(signed.shiftRight(63)).toUnsigned()

  let i = 0
  while (x.gte(0x80)) {
    // emit the low 7 bits with the continuation bit set
    buffer.writeUInt8((x.getLowBits() & 0x7f) | 0x80, offset + i)
    x = x.shiftRightUnsigned(7)
    i += 1
  }
  // final byte: remaining bits, continuation bit clear
  buffer.writeUInt8(x.getLowBits() & 0x7f, offset + i)
  return i + 1
}

public fromBuffer(buffer: Buffer): any {
const newBuffer = this.trimStart(buffer)
return this.message.decode(newBuffer)
const [bytesRead, msgIndexes] = this.readMessageIndexes(buffer)
const message = this.lookupMessage(msgIndexes)

return message.decode(buffer.slice(bytesRead))
}

// Resolves a message-index path to a message Type: starting at the schema namespace, each index
// selects an entry of the current scope's orderedNestedMessages, which then becomes the scope
// for the next index. Throws if any selected entry is not a Type.
private lookupMessage(msgIndexes: Array<number>): Type {
  const resolved = msgIndexes.reduce<Namespace>((scope, position) => {
    const candidate = scope.orderedNestedMessages[position]
    if (candidate instanceof Type) {
      return candidate as Namespace
    }
    throw new Error(
      'Failed to retrieve nested message from namespace: nested message is not an instance of Type',
    )
  }, this.namespace)

  return resolved as Type
}

// Reads the message-index prefix from the start of `payload`.
//
// Returns a tuple of [number of bytes consumed by the prefix, message index path]. A count of
// zero yields the path [0]. Throws when a varint cannot be read or the count is negative.
private readMessageIndexes(payload: Buffer): [number, Array<number>] {
  const [count, headerSize] = this.parseVarint(payload)
  if (headerSize <= 0) {
    throw new Error('unable to read message indexes')
  }
  if (count.lt(0)) {
    throw new Error('parsed invalid message index count')
  }
  if (count.eq(0)) {
    return [headerSize, [0]]
  }

  const path = new Array<number>(count.toInt())
  let cursor = headerSize
  let slot = 0
  while (count.gt(slot)) {
    // each index is parsed from wherever the previous varint ended
    const [value, consumed] = this.parseVarint(payload.slice(cursor))
    if (consumed <= 0) {
      throw new Error('unable to read message indexes')
    }
    cursor += consumed
    path[slot] = value.toInt()
    slot += 1
  }

  return [cursor, path]
}

// adapted from https://go.dev/src/encoding/binary/varint.go
//
// Decodes a zig-zag encoded signed varint from the start of `buffer`.
// Returns [decoded value, n], where n is passed through unchanged from parseUvarint
// (positive: bytes read; zero or negative: the value could not be read).
private parseVarint(buffer: Buffer): [Long, number] {
  const [ux, n] = this.parseUvarint(buffer)
  // Undo zig-zag: drop the sign bit with a *logical* shift. An arithmetic shift would
  // sign-extend values whose top bit is set and corrupt the result.
  let x = ux.shiftRightUnsigned(1).toSigned()
  if (ux.and(1).neq(0)) {
    // lowest bit set means the original value was negative
    x = x.not()
  }
  return [x, n]
}

// Decodes an unsigned varint (7 payload bits per byte, least-significant group first, high bit
// set on every byte except the last) from the start of `buffer`.
//
// Returns [value, n]:
//   n > 0  -> value decoded, n bytes were read
//   n == 0 -> buffer ended before a terminating byte was found
//   n < 0  -> value overflows 64 bits; -n is the number of bytes examined
private parseUvarint(buffer: Buffer): [Long, number] {
  let x = Long.UZERO // new unsigned 64 bit integer
  let s = 0 // bit position of the next 7-bit group

  for (let i = 0; i < buffer.length; i++) {
    if (i === MAX_VARINT_LEN_64) {
      // overflow
      return [Long.UZERO, -(i + 1)]
    }
    const b = buffer.readUInt8(i)
    if (b < 0x80) {
      if (i === MAX_VARINT_LEN_64 - 1 && b > 1) {
        // overflow
        return [Long.UZERO, -(i + 1)]
      }
      return [x.or(Long.fromBits(b, 0, true).shiftLeft(s)), i + 1]
    }
    // continuation byte: place its 7 payload bits at the current bit position
    x = x.or(Long.fromBits(b & 0x7f, 0, true).shiftLeft(s))
    s += 7
  }
  return [Long.UZERO, 0]
}

// Part of the Schema interface. A protobuf schema namespace can hold several message types and
// this signature cannot say which one to validate against, so the payload is checked against
// the first message in the namespace -- the same default toBuffer uses. Use
// validatePayloadAgainstSchema to validate against a specific message type.
//
// Returns false when the namespace's first nested message is not a Type.
public isValid(
  payload: object,
  opts?: { errorHook: (path: Array<string>, value: any, type?: any) => void },
): boolean {
  const defaultSchema = this.namespace.orderedNestedMessages[0]
  if (!(defaultSchema instanceof Type)) {
    return false
  }
  return this.validatePayloadAgainstSchema(defaultSchema, payload, opts)
}

public validatePayloadAgainstSchema(
schema: Type,
payload: object,
opts?: { errorHook: (path: Array<string>, value: any, type?: any) => void },
): boolean {
const errMsg: null | string = this.message.verify(payload)
const errMsg: null | string = schema.verify(payload)
if (errMsg) {
if (opts?.errorHook) {
opts.errorHook([errMsg], payload)
Expand Down
8 changes: 5 additions & 3 deletions src/SchemaRegistry.newApi.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { COMPATIBILITY, DEFAULT_API_CLIENT_ID } from './constants'
import encodedAnotherPersonV2Avro from '../fixtures/avro/encodedAnotherPersonV2'
import encodedAnotherPersonV2Json from '../fixtures/json/encodedAnotherPersonV2'
import encodedAnotherPersonV2Proto from '../fixtures/proto/encodedAnotherPersonV2'
import encodedAnotherPersonNestedV2Proto from '../fixtures/proto/encodedAnotherPersonNestedV2'
import encodedNestedV2Proto from '../fixtures/proto/encodedNestedV2'
import wrongMagicByte from '../fixtures/wrongMagicByte'
import Ajv2020 from 'ajv8/dist/2020'
Expand Down Expand Up @@ -160,6 +161,7 @@ describe('SchemaRegistry - new Api', () => {
}
`,
encodedAnotherPersonV2: encodedAnotherPersonV2Proto,
encodedAnotherPersonNestedV2: encodedAnotherPersonNestedV2Proto,
},
}
const types = Object.keys(schemaStringsByType).map(str => SchemaType[str]) as KnownSchemaTypes[]
Expand Down Expand Up @@ -468,11 +470,11 @@ describe('SchemaRegistry - new Api', () => {
subject: `${type}_test3`,
})

const data = await schemaRegistry.encode(schema3.id, payload)
const data = await schemaRegistry.encode(schema3.id, payload, 'AnotherPerson')

expect(data).toMatchConfluentEncodedPayload({
registryId: schema3.id,
payload: Buffer.from(schemaStringsByType[type].encodedAnotherPersonV2),
payload: Buffer.from(schemaStringsByType[type].encodedAnotherPersonNestedV2),
})
})

Expand All @@ -486,7 +488,7 @@ describe('SchemaRegistry - new Api', () => {
subject: `${type}_test3`,
})

const buffer = Buffer.from(await schemaRegistry.encode(schema3.id, payload))
const buffer = Buffer.from(await schemaRegistry.encode(schema3.id, payload, 'AnotherPerson'))
const data = await schemaRegistry.decode(buffer)

expect(data).toEqual(payload)
Expand Down
6 changes: 3 additions & 3 deletions src/SchemaRegistry.protobuf.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ describe('SchemaRegistry', () => {
})

it('should return schema that match message', async () => {
expect(schema.message.name).toEqual('ThirdLevel')
expect(schema.namespace.orderedNestedMessages[0].name).toEqual('ThirdLevel')
})

it('should be able to encode/decode', async () => {
Expand All @@ -256,7 +256,7 @@ describe('SchemaRegistry', () => {
})

it('should return schema that match message', async () => {
expect(schema.message.name).toEqual('SecondLevelA')
expect(schema.namespace.orderedNestedMessages[0].name).toEqual('SecondLevelA')
})

it('should be able to encode/decode', async () => {
Expand Down Expand Up @@ -300,7 +300,7 @@ describe('SchemaRegistry', () => {
})

it('should return schema that match message', async () => {
expect(schema.message.name).toEqual('FirstLevel')
expect(schema.namespace.orderedNestedMessages[0].name).toEqual('FirstLevel')
})

it('should be able to encode/decode', async () => {
Expand Down
21 changes: 18 additions & 3 deletions src/SchemaRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
schemaTypeFromString,
schemaFromConfluentSchema,
} from './schemaTypeResolver'
import ProtoSchema from './ProtoSchema'

export interface RegisteredSchema {
id: number
Expand Down Expand Up @@ -262,16 +263,30 @@ export default class SchemaRegistry {
return await (await this._getSchema(registryId)).schema
}

public async encode(registryId: number, payload: any): Promise<Buffer> {
public async encode(
registryId: number,
payload: any,
typeName?: string /* for protobuf */,
): Promise<Buffer> {
if (!registryId) {
throw new ConfluentSchemaRegistryArgumentError(
`Invalid registryId: ${JSON.stringify(registryId)}`,
)
}

const { schema } = await this._getSchema(registryId)
const { type, schema } = await this._getSchema(registryId)
try {
const serializedPayload = schema.toBuffer(payload)
let serializedPayload
switch (type) {
// in the case of protobuf schemas we need a bit more information to specify which schema to encode with
// (see the implementation for details)
case SchemaType.PROTOBUF:
const protoSchema = schema as ProtoSchema
serializedPayload = protoSchema.toBufferFromNestedType(payload, typeName)
break
default:
serializedPayload = schema.toBuffer(payload)
}
return encode(registryId, serializedPayload)
} catch (error) {
if (error instanceof ConfluentSchemaRegistryValidationError) throw error
Expand Down