Skip to content

Commit

Permalink
Use massive allocators for subtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
outofforest committed Sep 4, 2024
1 parent 563491b commit 5c7b483
Show file tree
Hide file tree
Showing 42 changed files with 413 additions and 111 deletions.
2 changes: 1 addition & 1 deletion build/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ require (
github.com/outofforest/ioc/v2 v2.5.2 // indirect
github.com/outofforest/libexec v0.3.9 // indirect
github.com/outofforest/logger v0.5.4 // indirect
github.com/outofforest/mass v0.1.2 // indirect
github.com/outofforest/mass v0.2.1 // indirect
github.com/outofforest/parallel v0.2.3 // indirect
github.com/outofforest/run v0.6.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions build/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ github.com/outofforest/logger v0.3.3/go.mod h1:+M5sO17Va9V33t28Qs9VqRQ8bFV501Uhq
github.com/outofforest/logger v0.3.4/go.mod h1:wOsyVEu2nnueGK+IZuD1tOWYx6tXGV48earpJsDPT3Y=
github.com/outofforest/logger v0.5.4 h1:mRxOxRrm1ppUueQiv+ektljGAGhUM9wOLmkmVpQuMqo=
github.com/outofforest/logger v0.5.4/go.mod h1:czsrxU2w6KlZ31gbJt164H2k0ckoygXbYqCx7V2DAHM=
github.com/outofforest/mass v0.1.2 h1:dS1qRqE+LHbHbc5JFTjrGCCbpHeQLp37Z1tpchDpK2A=
github.com/outofforest/mass v0.1.2/go.mod h1:rqr19KwYSKncmsmZCmMatTsg8pI+ElxerH9v1SGU1CQ=
github.com/outofforest/mass v0.2.1 h1:oIzOnoTJqN8eVXo5jxk1htOhW7bL7hy2JHrvnTsfvtU=
github.com/outofforest/mass v0.2.1/go.mod h1:rqr19KwYSKncmsmZCmMatTsg8pI+ElxerH9v1SGU1CQ=
github.com/outofforest/parallel v0.2.3 h1:DRIgHr7XTL4LLgsTqrj041kulv4ajtbCkRbkOG5psWY=
github.com/outofforest/parallel v0.2.3/go.mod h1:cu210xIjJtOMXR2ERzEcNA2kr0Z0xfZjSKw2jTxAQ2E=
github.com/outofforest/run v0.6.0 h1:t/3vAodvU5L5vJ3BB0qRgfviX+T3JJmLgPN6G2WQs3U=
Expand Down
116 changes: 80 additions & 36 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func Generate(filePath string, msgs ...any) error {

msgTypes := make([]reflect.Type, 0, len(msgs))
processed := map[reflect.Type]bool{}
roots := map[reflect.Type]bool{}
stack := make([]reflect.Type, 0, len(msgs))
for _, msg := range msgs {
msgType := reflect.TypeOf(msg)
Expand All @@ -40,29 +41,38 @@ func Generate(filePath string, msgs ...any) error {
msgTypes = append(msgTypes, msgType)
}
processed[msgType] = true
roots[msgType] = true
}

tm := types.NewTypeMap()
tm.Import("github.com/pkg/errors")
tm.Import("github.com/outofforest/proton")
tm.Import("github.com/outofforest/mass")

msgInfos := make([]msgInfo, 0, len(msgs))
for len(stack) > 0 {
msgType := stack[len(stack)-1]
stack = stack[:len(stack)-1]

code, dependencies, err := generateMsg(msgType, tm)

code, dependencies, allocators, err := generateMsg(msgType, tm)
if err != nil {
return err
}

switch {
case pkg == "":
pkg = msgType.PkgPath()
case pkg != msgType.PkgPath():
return errors.New("all the msgTypes must belong to the same package")
}

if roots[msgType] {
msgInfos = append(msgInfos, msgInfo{
Type: msgType,
Allocators: allocators,
})
}

for dep := range dependencies {
if !processed[dep] {
stack = append(stack, dep)
Expand Down Expand Up @@ -119,11 +129,11 @@ func Generate(filePath string, msgs ...any) error {
}
}

if err := writeMsgConsts(out, msgTypes); err != nil {
if err := writeMsgConsts(out, msgTypes, tm); err != nil {
return errors.WithStack(err)
}

if err := writeMarshaller(out, msgTypes); err != nil {
if err := writeMarshaller(out, msgInfos, tm); err != nil {
return errors.WithStack(err)
}

Expand All @@ -134,17 +144,17 @@ func Generate(filePath string, msgs ...any) error {
return nil
}

func generateMsg(msgType reflect.Type, tm types.TypeMap) ([]byte, map[reflect.Type]bool, error) {
func generateMsg(msgType reflect.Type, tm types.TypeMap) ([]byte, map[reflect.Type]struct{}, []reflect.Type, error) {
pkg := msgType.PkgPath()
if msgType.Kind() != reflect.Struct {
return nil, nil, errors.Errorf("type %s is not a struct", msgType)
return nil, nil, nil, errors.Errorf("type %s is not a struct", msgType)
}

cfg := methods.Config{
Type: msgType,
}

dependencies := map[reflect.Type]bool{}
dependencies := map[reflect.Type]struct{}{}

err := helpers.ForEachField(msgType, func(field reflect.StructField) error {
builder, err := factory.Get(msgType, field.Type, tm)
Expand All @@ -156,7 +166,7 @@ func generateMsg(msgType reflect.Type, tm types.TypeMap) ([]byte, map[reflect.Ty
if d.PkgPath() != pkg {
continue
}
dependencies[d] = true
dependencies[d] = struct{}{}
}

if field.Type.Kind() == reflect.Bool {
Expand All @@ -165,7 +175,7 @@ func generateMsg(msgType reflect.Type, tm types.TypeMap) ([]byte, map[reflect.Ty
return nil
})
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

b := &bytes.Buffer{}
Expand All @@ -174,9 +184,10 @@ func generateMsg(msgType reflect.Type, tm types.TypeMap) ([]byte, map[reflect.Ty
b.WriteString("\n\n")
b.Write(marshal.Build(cfg, tm))
b.WriteString("\n\n")
b.Write(unmarshal.Build(cfg, tm))
unmarshalBytes, allocators := unmarshal.Build(cfg, tm)
b.Write(unmarshalBytes)

return b.Bytes(), dependencies, nil
return b.Bytes(), dependencies, allocators, nil
}

var sdkPkgs = map[string]bool{
Expand Down Expand Up @@ -216,7 +227,7 @@ func writeImports(out io.StringWriter, pkgs []string, aliases map[string]string)
return nil
}

func writeMsgConsts(out io.StringWriter, msgTypes []reflect.Type) error {
func writeMsgConsts(out io.StringWriter, msgTypes []reflect.Type, tm types.TypeMap) error {
const header = `
const (
`
Expand All @@ -226,15 +237,17 @@ const (
}

for i, msgType := range msgTypes {
if _, err := out.WriteString(fmt.Sprintf("\t// ID%[1]s is the ID of %[1]s message.\n", msgType.Name())); err != nil {
varName := tm.VarName(msgType, msgType, "ID")
if _, err := out.WriteString(fmt.Sprintf("\t// %[1]s is the ID of %[2]s message.\n",
varName, msgType.Name())); err != nil {
return errors.WithStack(err)
}
if i == 0 {
if _, err := out.WriteString(fmt.Sprintf("\tID%s uint64 = iota + 1\n", msgType.Name())); err != nil {
if _, err := out.WriteString(fmt.Sprintf("\t%s uint64 = iota + 1\n", varName)); err != nil {
return errors.WithStack(err)
}
} else {
if _, err := out.WriteString(fmt.Sprintf("\tID%s\n", msgType.Name())); err != nil {
if _, err := out.WriteString(fmt.Sprintf("\t%s\n", varName)); err != nil {
return errors.WithStack(err)
}
}
Expand All @@ -247,12 +260,22 @@ const (
return nil
}

func writeMarshaller(out io.StringWriter, msgTypes []reflect.Type) error {
func writeMarshaller(out io.StringWriter, msgInfos []msgInfo, tm types.TypeMap) error {
anyType := msgInfos[0].Type

allocators := []reflect.Type{}
for _, msgInfo := range msgInfos {
allocators = append(allocators, msgInfo.Type)
}
for _, msgInfo := range msgInfos {
allocators = types.MergeTypes(allocators, msgInfo.Allocators)
}

const constructorHeader = `
var _ proton.Marshaller = Marshaller{}
// NewMarshaller creates marshaller.
func NewMarshaller(capacity int) Marshaller {
func NewMarshaller(capacity uint64) Marshaller {
return Marshaller{
`
const typeHeader = `
Expand All @@ -278,7 +301,7 @@ func (m Marshaller) Marshal(msg proton.Marshallable, buf []byte) (retID, retSize
`

const marshalTemplate = ` case *%[1]s:
return ID%[1]s, msg2.Marshal(buf), nil
return %[2]s, msg2.Marshal(buf), nil
`

const unmarshalHeader = `
Expand All @@ -298,25 +321,23 @@ func (m Marshaller) Unmarshal(id uint64, buf []byte) (retMsg any, retSize uint64
}
`

const unmarshalTemplate = ` case ID%[1]s:
msg := m.mass%[1]s.New()
return msg, msg.Unmarshal(buf), nil
`

var longestName int
for _, msgType := range msgTypes {
if len(msgType.Name()) > longestName {
longestName = len(msgType.Name())
for _, t := range allocators {
varName := tm.VarName(anyType, t, "mass")
if len(varName) > longestName {
longestName = len(varName)
}
}

if _, err := out.WriteString(constructorHeader); err != nil {
return errors.WithStack(err)
}

for _, msgType := range msgTypes {
if _, err := out.WriteString(fmt.Sprintf(" mass%[1]s:%[2]s mass.New[%[1]s](capacity),\n",
msgType.Name(), types.Align(msgType.Name(), longestName))); err != nil {
for _, t := range allocators {
varName := tm.VarName(anyType, t, "mass")
if _, err := out.WriteString(fmt.Sprintf(" %[1]s:%[2]s mass.New[%[3]s](capacity),\n",
varName, types.Align(varName, longestName),
tm.TypeName(anyType, t))); err != nil {
return errors.WithStack(err)
}
}
Expand All @@ -329,9 +350,11 @@ func (m Marshaller) Unmarshal(id uint64, buf []byte) (retMsg any, retSize uint64
return errors.WithStack(err)
}

for _, msgType := range msgTypes {
if _, err := out.WriteString(fmt.Sprintf(" mass%[1]s%[2]s *mass.Mass[%[1]s]\n",
msgType.Name(), types.Align(msgType.Name(), longestName))); err != nil {
for _, t := range allocators {
varName := tm.VarName(anyType, t, "mass")
if _, err := out.WriteString(fmt.Sprintf(" %[1]s%[2]s *mass.Mass[%[3]s]\n",
varName, types.Align(varName, longestName),
tm.TypeName(anyType, t))); err != nil {
return errors.WithStack(err)
}
}
Expand All @@ -344,8 +367,9 @@ func (m Marshaller) Unmarshal(id uint64, buf []byte) (retMsg any, retSize uint64
return errors.WithStack(err)
}

for _, msgType := range msgTypes {
if _, err := out.WriteString(fmt.Sprintf(marshalTemplate, msgType.Name())); err != nil {
for _, msgInfo := range msgInfos {
if _, err := out.WriteString(fmt.Sprintf(marshalTemplate, msgInfo.Type.Name(),
tm.VarName(anyType, msgInfo.Type, "ID"))); err != nil {
return errors.WithStack(err)
}
}
Expand All @@ -358,8 +382,23 @@ func (m Marshaller) Unmarshal(id uint64, buf []byte) (retMsg any, retSize uint64
return errors.WithStack(err)
}

for _, msgType := range msgTypes {
if _, err := out.WriteString(fmt.Sprintf(unmarshalTemplate, msgType.Name())); err != nil {
for _, msgInfo := range msgInfos {
if _, err := out.WriteString(fmt.Sprintf(` case %[1]s:
msg := m.%[2]s.New()
return msg, msg.Unmarshal(
buf,
`, tm.VarName(anyType, msgInfo.Type, "ID"), tm.VarName(anyType, msgInfo.Type, "mass"))); err != nil {
return errors.WithStack(err)
}

for _, t := range msgInfo.Allocators {
if _, err := out.WriteString(fmt.Sprintf(" m.%[1]s,\n",
tm.VarName(anyType, t, "mass"))); err != nil {
return errors.WithStack(err)
}
}

if _, err := out.WriteString(" ), nil\n"); err != nil {
return errors.WithStack(err)
}
}
Expand All @@ -370,3 +409,8 @@ func (m Marshaller) Unmarshal(id uint64, buf []byte) (retMsg any, retSize uint64

return nil
}

type msgInfo struct {
Type reflect.Type
Allocators []reflect.Type
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ module github.com/outofforest/proton
go 1.22

require (
github.com/outofforest/mass v0.1.2
github.com/outofforest/mass v0.2.1
github.com/pkg/errors v0.9.1
github.com/samber/lo v1.47.0
github.com/stretchr/testify v1.9.0
golang.org/x/text v0.17.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/text v0.17.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/outofforest/mass v0.1.2 h1:dS1qRqE+LHbHbc5JFTjrGCCbpHeQLp37Z1tpchDpK2A=
github.com/outofforest/mass v0.1.2/go.mod h1:rqr19KwYSKncmsmZCmMatTsg8pI+ElxerH9v1SGU1CQ=
github.com/outofforest/mass v0.2.1 h1:oIzOnoTJqN8eVXo5jxk1htOhW7bL7hy2JHrvnTsfvtU=
github.com/outofforest/mass v0.2.1/go.mod h1:rqr19KwYSKncmsmZCmMatTsg8pI+ElxerH9v1SGU1CQ=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand Down
33 changes: 20 additions & 13 deletions methods/unmarshal/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,8 @@ import (
"github.com/outofforest/proton/types/factory"
)

const (
header = `// Unmarshal unmarshals the structure.
func (m *{{ .TypeName }}) Unmarshal(b []byte) uint64 {
`
)

// Build generates code of Unmarshal method.
func Build(cfg methods.Config, tm types.TypeMap) []byte {
func Build(cfg methods.Config, tm types.TypeMap) ([]byte, []reflect.Type) {
code := &bytes.Buffer{}

offset := methods.BitMapLength(cfg.NumOfBooleanFields)
Expand All @@ -31,6 +25,7 @@ func Build(cfg methods.Config, tm types.TypeMap) []byte {
}

var boolIndex uint64
allocators := []reflect.Type{}
lo.Must0(helpers.ForEachField(cfg.Type, func(field reflect.StructField) error {
if field.Type.Kind() == reflect.Bool {
byteIndex, bitIndex := methods.BitMapPosition(boolIndex)
Expand All @@ -50,6 +45,8 @@ func Build(cfg methods.Config, tm types.TypeMap) []byte {
return err
}

allocators = types.MergeTypes(allocators, builder.Allocators())

marshalCode := builder.UnmarshalCodeTemplate(new(uint64))

code.WriteString(" {\n // " + field.Name + "\n\n")
Expand All @@ -61,16 +58,26 @@ func Build(cfg methods.Config, tm types.TypeMap) []byte {
}))

b := &bytes.Buffer{}
helpers.Execute(b, header, struct {
TypeName string
}{
TypeName: cfg.Type.Name(),
})

b.WriteString(fmt.Sprintf(`// Unmarshal unmarshals the structure.
func (m *%[1]s) Unmarshal(
b []byte,
`, cfg.Type.Name()))

for _, allocator := range allocators {
b.WriteString(fmt.Sprintf(" %[1]s *mass.Mass[%[2]s],\n",
tm.VarName(cfg.Type, allocator, "mass"),
tm.TypeName(cfg.Type, allocator),
))
}

b.WriteString(`) uint64 {
`)

if code.Len() > 0 {
lo.Must(code.WriteTo(b))
}

b.WriteString("\n return o\n}")
return b.Bytes()
return b.Bytes(), allocators
}
Loading

0 comments on commit 5c7b483

Please sign in to comment.