Skip to content

Commit

Permalink
pblite: copy from mautrix-gmessages
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Aug 25, 2024
1 parent aa3f73c commit 33d03e7
Show file tree
Hide file tree
Showing 7 changed files with 453 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
golang.org/x/sys v0.24.0
golang.org/x/text v0.17.0
google.golang.org/protobuf v1.34.2
gopkg.in/yaml.v3 v3.0.1
)

Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
Expand Down Expand Up @@ -32,6 +34,8 @@ golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
197 changes: 197 additions & 0 deletions pblite/deserialize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package pblite

import (
"encoding/base64"
"encoding/json"
"fmt"
"strconv"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)

func Unmarshal(data []byte, m proto.Message) error {
var anyData any
if err := json.Unmarshal(data, &anyData); err != nil {
return err
}
anyDataArr, ok := anyData.([]any)
if !ok {
return fmt.Errorf("expected array in JSON, got %T", anyData)
}
return deserializeFromSlice(anyDataArr, m.ProtoReflect())
}

func isPbliteBinary(descriptor protoreflect.FieldDescriptor) bool {
opts := descriptor.Options().(*descriptorpb.FieldOptions)
pbliteBinary, ok := proto.GetExtension(opts, E_PbliteBinary).(bool)
return ok && pbliteBinary
}

func deserializeOne(val any, index int, ref protoreflect.Message, insideList protoreflect.List, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) {
var num float64
var expectedKind, str string
var boolean, ok bool
var outputVal protoreflect.Value
if fieldDescriptor.IsList() && insideList == nil {
nestedData, ok := val.([]any)
if !ok {
return outputVal, fmt.Errorf("expected untyped array at index %d for repeated field %s, got %T", index, fieldDescriptor.FullName(), val)
}
list := ref.NewField(fieldDescriptor).List()
list.NewElement()
for i, nestedVal := range nestedData {
nestedParsed, err := deserializeOne(nestedVal, i, ref, list, fieldDescriptor)
if err != nil {
return outputVal, err
}
list.Append(nestedParsed)
}
return protoreflect.ValueOfList(list), nil
}
switch fieldDescriptor.Kind() {
case protoreflect.MessageKind:
ok = true
var nestedMessage protoreflect.Message
if insideList != nil {
nestedMessage = insideList.NewElement().Message()
} else {
nestedMessage = ref.NewField(fieldDescriptor).Message()
}
if isPbliteBinary(fieldDescriptor) {
bytesBase64, ok := val.(string)
if !ok {
return outputVal, fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
}
bytes, err := base64.StdEncoding.DecodeString(bytesBase64)
if err != nil {
return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
err = proto.Unmarshal(bytes, nestedMessage.Interface())
if err != nil {
return outputVal, fmt.Errorf("failed to unmarshal binary protobuf at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
} else {
nestedData, ok := val.([]any)
if !ok {
return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
}
if err := deserializeFromSlice(nestedData, nestedMessage); err != nil {
return outputVal, err
}
}
outputVal = protoreflect.ValueOfMessage(nestedMessage)
case protoreflect.BytesKind:
ok = true
bytesBase64, ok := val.(string)
if !ok {
return outputVal, fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val)
}
bytes, err := base64.StdEncoding.DecodeString(bytesBase64)
if err != nil {
return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}

outputVal = protoreflect.ValueOfBytes(bytes)
case protoreflect.EnumKind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfEnum(protoreflect.EnumNumber(int32(num)))
case protoreflect.Int32Kind:
if str, ok = val.(string); ok {
parsedVal, err := strconv.ParseInt(str, 10, 32)
if err != nil {
return outputVal, fmt.Errorf("failed to parse int32 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
outputVal = protoreflect.ValueOfInt32(int32(parsedVal))
} else {
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfInt32(int32(num))
}
case protoreflect.Int64Kind:
if str, ok = val.(string); ok {
parsedVal, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return outputVal, fmt.Errorf("failed to parse int64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
outputVal = protoreflect.ValueOfInt64(parsedVal)
} else {
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfInt64(int64(num))
}
case protoreflect.Uint32Kind:
if str, ok = val.(string); ok {
parsedVal, err := strconv.ParseUint(str, 10, 32)
if err != nil {
return outputVal, fmt.Errorf("failed to parse uint32 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
outputVal = protoreflect.ValueOfUint32(uint32(parsedVal))
} else {
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfUint32(uint32(num))
}
case protoreflect.Uint64Kind:
if str, ok = val.(string); ok {
parsedVal, err := strconv.ParseUint(str, 10, 64)
if err != nil {
return outputVal, fmt.Errorf("failed to parse uint64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
outputVal = protoreflect.ValueOfUint64(parsedVal)
} else {
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfUint64(uint64(num))
}
case protoreflect.FloatKind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfFloat32(float32(num))
case protoreflect.DoubleKind:
num, ok = val.(float64)
expectedKind = "float64"
outputVal = protoreflect.ValueOfFloat64(num)
case protoreflect.StringKind:
str, ok = val.(string)
if ok && isPbliteBinary(fieldDescriptor) {
bytes, err := base64.StdEncoding.DecodeString(str)
if err != nil {
return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err)
}
str = string(bytes)
}
expectedKind = "string"
outputVal = protoreflect.ValueOfString(str)
case protoreflect.BoolKind:
boolean, ok = val.(bool)
expectedKind = "bool"
outputVal = protoreflect.ValueOfBool(boolean)
default:
return outputVal, fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName())
}
if !ok {
return outputVal, fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val)
}
return outputVal, nil
}

func deserializeFromSlice(data []any, ref protoreflect.Message) error {
for i := 0; i < ref.Descriptor().Fields().Len(); i++ {
fieldDescriptor := ref.Descriptor().Fields().Get(i)
index := int(fieldDescriptor.Number()) - 1
if index < 0 || index >= len(data) || data[index] == nil {
continue
}

val := data[index]
outputVal, err := deserializeOne(val, index, ref, nil, fieldDescriptor)
if err != nil {
return err
}
ref.Set(fieldDescriptor, outputVal)
}
return nil
}
83 changes: 83 additions & 0 deletions pblite/pblite.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pblite/pblite.pb.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

pblite.protopblite google/protobuf/descriptor.proto:G
pblite_binary.google.protobuf.FieldOptionsІ (R pbliteBinary�B Z ../pblitebproto3
Expand Down
10 changes: 10 additions & 0 deletions pblite/pblite.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
syntax = "proto3";
package pblite;

option go_package = "../pblite";

import "google/protobuf/descriptor.proto";

extend google.protobuf.FieldOptions {
optional bool pblite_binary = 50000;
}
Loading

0 comments on commit 33d03e7

Please sign in to comment.