-
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
453 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
package pblite | ||
|
||
import ( | ||
"encoding/base64" | ||
"encoding/json" | ||
"fmt" | ||
"strconv" | ||
|
||
"google.golang.org/protobuf/proto" | ||
"google.golang.org/protobuf/reflect/protoreflect" | ||
"google.golang.org/protobuf/types/descriptorpb" | ||
) | ||
|
||
func Unmarshal(data []byte, m proto.Message) error { | ||
var anyData any | ||
if err := json.Unmarshal(data, &anyData); err != nil { | ||
return err | ||
} | ||
anyDataArr, ok := anyData.([]any) | ||
if !ok { | ||
return fmt.Errorf("expected array in JSON, got %T", anyData) | ||
} | ||
return deserializeFromSlice(anyDataArr, m.ProtoReflect()) | ||
} | ||
|
||
func isPbliteBinary(descriptor protoreflect.FieldDescriptor) bool { | ||
opts := descriptor.Options().(*descriptorpb.FieldOptions) | ||
pbliteBinary, ok := proto.GetExtension(opts, E_PbliteBinary).(bool) | ||
return ok && pbliteBinary | ||
} | ||
|
||
func deserializeOne(val any, index int, ref protoreflect.Message, insideList protoreflect.List, fieldDescriptor protoreflect.FieldDescriptor) (protoreflect.Value, error) { | ||
var num float64 | ||
var expectedKind, str string | ||
var boolean, ok bool | ||
var outputVal protoreflect.Value | ||
if fieldDescriptor.IsList() && insideList == nil { | ||
nestedData, ok := val.([]any) | ||
if !ok { | ||
return outputVal, fmt.Errorf("expected untyped array at index %d for repeated field %s, got %T", index, fieldDescriptor.FullName(), val) | ||
} | ||
list := ref.NewField(fieldDescriptor).List() | ||
list.NewElement() | ||
for i, nestedVal := range nestedData { | ||
nestedParsed, err := deserializeOne(nestedVal, i, ref, list, fieldDescriptor) | ||
if err != nil { | ||
return outputVal, err | ||
} | ||
list.Append(nestedParsed) | ||
} | ||
return protoreflect.ValueOfList(list), nil | ||
} | ||
switch fieldDescriptor.Kind() { | ||
case protoreflect.MessageKind: | ||
ok = true | ||
var nestedMessage protoreflect.Message | ||
if insideList != nil { | ||
nestedMessage = insideList.NewElement().Message() | ||
} else { | ||
nestedMessage = ref.NewField(fieldDescriptor).Message() | ||
} | ||
if isPbliteBinary(fieldDescriptor) { | ||
bytesBase64, ok := val.(string) | ||
if !ok { | ||
return outputVal, fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) | ||
} | ||
bytes, err := base64.StdEncoding.DecodeString(bytesBase64) | ||
if err != nil { | ||
return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) | ||
} | ||
err = proto.Unmarshal(bytes, nestedMessage.Interface()) | ||
if err != nil { | ||
return outputVal, fmt.Errorf("failed to unmarshal binary protobuf at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) | ||
} | ||
} else { | ||
nestedData, ok := val.([]any) | ||
if !ok { | ||
return outputVal, fmt.Errorf("expected untyped array at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) | ||
} | ||
if err := deserializeFromSlice(nestedData, nestedMessage); err != nil { | ||
return outputVal, err | ||
} | ||
} | ||
outputVal = protoreflect.ValueOfMessage(nestedMessage) | ||
case protoreflect.BytesKind: | ||
ok = true | ||
bytesBase64, ok := val.(string) | ||
if !ok { | ||
return outputVal, fmt.Errorf("expected string at index %d for field %s, got %T", index, fieldDescriptor.FullName(), val) | ||
} | ||
bytes, err := base64.StdEncoding.DecodeString(bytesBase64) | ||
if err != nil { | ||
return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) | ||
} | ||
|
||
outputVal = protoreflect.ValueOfBytes(bytes) | ||
case protoreflect.EnumKind: | ||
num, ok = val.(float64) | ||
expectedKind = "float64" | ||
outputVal = protoreflect.ValueOfEnum(protoreflect.EnumNumber(int32(num))) | ||
case protoreflect.Int32Kind: | ||
if str, ok = val.(string); ok { | ||
parsedVal, err := strconv.ParseInt(str, 10, 32) | ||
if err != nil { | ||
return outputVal, fmt.Errorf("failed to parse int32 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) | ||
} | ||
outputVal = protoreflect.ValueOfInt32(int32(parsedVal)) | ||
} else { | ||
num, ok = val.(float64) | ||
expectedKind = "float64" | ||
outputVal = protoreflect.ValueOfInt32(int32(num)) | ||
} | ||
case protoreflect.Int64Kind: | ||
if str, ok = val.(string); ok { | ||
parsedVal, err := strconv.ParseInt(str, 10, 64) | ||
if err != nil { | ||
return outputVal, fmt.Errorf("failed to parse int64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) | ||
} | ||
outputVal = protoreflect.ValueOfInt64(parsedVal) | ||
} else { | ||
num, ok = val.(float64) | ||
expectedKind = "float64" | ||
outputVal = protoreflect.ValueOfInt64(int64(num)) | ||
} | ||
case protoreflect.Uint32Kind: | ||
if str, ok = val.(string); ok { | ||
parsedVal, err := strconv.ParseUint(str, 10, 32) | ||
if err != nil { | ||
return outputVal, fmt.Errorf("failed to parse uint32 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) | ||
} | ||
outputVal = protoreflect.ValueOfUint32(uint32(parsedVal)) | ||
} else { | ||
num, ok = val.(float64) | ||
expectedKind = "float64" | ||
outputVal = protoreflect.ValueOfUint32(uint32(num)) | ||
} | ||
case protoreflect.Uint64Kind: | ||
if str, ok = val.(string); ok { | ||
parsedVal, err := strconv.ParseUint(str, 10, 64) | ||
if err != nil { | ||
return outputVal, fmt.Errorf("failed to parse uint64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) | ||
} | ||
outputVal = protoreflect.ValueOfUint64(parsedVal) | ||
} else { | ||
num, ok = val.(float64) | ||
expectedKind = "float64" | ||
outputVal = protoreflect.ValueOfUint64(uint64(num)) | ||
} | ||
case protoreflect.FloatKind: | ||
num, ok = val.(float64) | ||
expectedKind = "float64" | ||
outputVal = protoreflect.ValueOfFloat32(float32(num)) | ||
case protoreflect.DoubleKind: | ||
num, ok = val.(float64) | ||
expectedKind = "float64" | ||
outputVal = protoreflect.ValueOfFloat64(num) | ||
case protoreflect.StringKind: | ||
str, ok = val.(string) | ||
if ok && isPbliteBinary(fieldDescriptor) { | ||
bytes, err := base64.StdEncoding.DecodeString(str) | ||
if err != nil { | ||
return outputVal, fmt.Errorf("failed to decode base64 at index %d for field %s: %w", index, fieldDescriptor.FullName(), err) | ||
} | ||
str = string(bytes) | ||
} | ||
expectedKind = "string" | ||
outputVal = protoreflect.ValueOfString(str) | ||
case protoreflect.BoolKind: | ||
boolean, ok = val.(bool) | ||
expectedKind = "bool" | ||
outputVal = protoreflect.ValueOfBool(boolean) | ||
default: | ||
return outputVal, fmt.Errorf("unsupported field type %s in %s", fieldDescriptor.Kind(), fieldDescriptor.FullName()) | ||
} | ||
if !ok { | ||
return outputVal, fmt.Errorf("expected %s at index %d for field %s, got %T", expectedKind, index, fieldDescriptor.FullName(), val) | ||
} | ||
return outputVal, nil | ||
} | ||
|
||
func deserializeFromSlice(data []any, ref protoreflect.Message) error { | ||
for i := 0; i < ref.Descriptor().Fields().Len(); i++ { | ||
fieldDescriptor := ref.Descriptor().Fields().Get(i) | ||
index := int(fieldDescriptor.Number()) - 1 | ||
if index < 0 || index >= len(data) || data[index] == nil { | ||
continue | ||
} | ||
|
||
val := data[index] | ||
outputVal, err := deserializeOne(val, index, ref, nil, fieldDescriptor) | ||
if err != nil { | ||
return err | ||
} | ||
ref.Set(fieldDescriptor, outputVal) | ||
} | ||
return nil | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
syntax = "proto3"; | ||
package pblite; | ||
|
||
option go_package = "../pblite"; | ||
|
||
import "google/protobuf/descriptor.proto"; | ||
|
||
extend google.protobuf.FieldOptions { | ||
optional bool pblite_binary = 50000; | ||
} |
Oops, something went wrong.