Skip to content

Commit

Permalink
updated code generation
Browse files Browse the repository at this point in the history
  • Loading branch information
adranwit committed Nov 25, 2023
1 parent 9d9a0e7 commit 139e7c4
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 41 deletions.
14 changes: 13 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ type (
snippetAfter string
packageTypes []*Type
importModule map[string]string
buildTypes map[string]bool
rewriteDoc bool
withEmbed bool
embedURI string
embedFormatter func(string) string
content map[string]string
withVelty *bool
Expand Down Expand Up @@ -99,6 +101,7 @@ func (o *options) initGen() {
if o.Package == "" {
o.Package = "generated"
}
o.generateOption.buildTypes = map[string]bool{}
}

func (o *options) formatEmbed(name string) string {
Expand Down Expand Up @@ -140,6 +143,14 @@ func WithPackage(pkg string) Option {
}
}

// WithModulePath creates with module path option
func WithModulePath(aPath string) Option {
return func(o *options) {
o.ModulePath = aPath

}
}

// WithImports creates import option
func WithImports(imports []string) Option {
return func(o *options) {
Expand Down Expand Up @@ -253,9 +264,10 @@ func WithSQLTag() Option {
}

// WithSQLRewrite return withEmbed rewrite option, it rewrites SQL to sql:"uri:xxxx" tag
func WithSQLRewrite(content map[string]string) Option {
func WithSQLRewrite(embedURI string, content map[string]string) Option {
return func(o *options) {
o.withEmbed = true
o.embedURI = embedURI
o.content = content
}
}
Expand Down
30 changes: 30 additions & 0 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"go/ast"
"go/parser"
"go/token"
"golang.org/x/mod/modfile"
"os"
"path"
"reflect"
"strconv"
Expand All @@ -23,9 +25,30 @@ func ParseTypes(path string, options ...Option) (*DirTypes, error) {
if err = dirTypes.indexPackages(packageFiles); err != nil {
return nil, err
}

dirTypes.ModulePath = detectModulePath(path)
return dirTypes, nil
}

func detectModulePath(aPath string) string {
parts := strings.Split(aPath, "/")
var index int
var aFile *modfile.File
for i := len(parts) - 1; i >= 0; i-- {
aPath = strings.Join(parts[:i], "/")
if isFileExists(path.Join(aPath, "go.mod")) {
index = i
data, _ := os.ReadFile(path.Join(aPath, "go.mod"))
aFile, _ = modfile.Parse("", data, nil)
break
}
}
if aFile == nil || aFile.Module == nil {
return ""
}
return path.Join(aFile.Module.Mod.Path, strings.Join(parts[index:], "/"))
}

func (t *DirTypes) indexPackages(packages map[string]*ast.Package) error {
for _, aPackage := range packages {
if err := t.indexPackage(aPackage); err != nil {
Expand Down Expand Up @@ -341,3 +364,10 @@ func asIdent(x ast.Expr) (*ast.Ident, bool) {
ident, ok := x.(*ast.Ident)
return ident, ok
}

func isFileExists(filename string) bool {
if _, err := os.Stat(filename); err != nil {
return false
}
return true
}
5 changes: 4 additions & 1 deletion stringify.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ func hasInterface(aType reflect.Type) bool {

func removeTag(tag string, tagName string) (string, string) {
tag = strings.TrimSpace(tag)
tag = trim(tag, '`')
tag = " " + tag
fragment := ""
tagName = ` ` + tagName
tagName += ":"
if index := strings.Index(tag, tagName); index != -1 {
matched := tag[index:]
Expand All @@ -223,7 +226,7 @@ func removeTag(tag string, tagName string) (string, string) {
tag = strings.Replace(tag, matched, "", 1)
}
}
tag = trim(tag, '`')
tag = strings.TrimSpace(tag)
if tag == "" {
return "", ""
}
Expand Down
73 changes: 66 additions & 7 deletions struct.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package xreflect

import (
"fmt"
"go/format"
"reflect"
"strconv"
Expand Down Expand Up @@ -30,9 +31,13 @@ func GenerateStruct(name string, structType reflect.Type, opts ...Option) string
}

dependencyTypes := buildGoType(typeBuilder, importsBuilder, structType, map[string]bool{}, true, genOptions)

additionalTypeBuilder := strings.Builder{}
for _, aType := range genOptions.withTypes {
if genOptions.buildTypes[aType.TypeName()] {
continue
}
genOptions.buildTypes[aType.TypeName()] = true

additionalTypeBuilder.WriteString("\n\n")
aTypeBuilder := newTypeBuilder(aType.Name)
dep := buildGoType(aTypeBuilder, importsBuilder, aType.Type, map[string]bool{}, true, genOptions)
Expand Down Expand Up @@ -107,16 +112,21 @@ func buildGoType(mainBuilder *strings.Builder, importsBuilder *strings.Builder,
aField := structType.Field(i)
fieldTag, typeName := removeTag(string(aField.Tag), TagTypeName)

if aField.Type.Name() == "" && typeName == "" {
aType := resolveType(aField.Type, opts.Registry)
updateType(aType, &aField, opts, importsBuilder, imports, isMain)

}
if opts.withEmbed {
SQL := ""
if fieldTag, SQL = removeTag(string(aField.Tag), "sql"); SQL != "" {
if fieldTag, SQL = removeTag(string(fieldTag), "sql"); SQL != "" {
name := typeName
if name == "" {
name = aField.Name
}
key := opts.formatEmbed(name) + ".sql"
opts.content[key] = SQL
fieldTag += ` sql:"uri=sql/` + key + `" `
fieldTag += fmt.Sprintf(` sql:"uri=%v/`+key+`" `, opts.embedURI)
}
} else if opts.removeSQLTag() {
fieldTag, _ = removeTag(fieldTag, "sql")
Expand Down Expand Up @@ -156,10 +166,16 @@ func buildGoType(mainBuilder *strings.Builder, importsBuilder *strings.Builder,
mainBuilder.WriteString(typeName)
nestedStruct := &strings.Builder{}
structBuilders = append(structBuilders, nestedStruct)
nestedStruct.WriteString("type ")
nestedStruct.WriteString(typeName)
nestedStruct.WriteByte(' ')
structBuilders = append(structBuilders, buildGoType(nestedStruct, importsBuilder, actualType, imports, false, opts)...)
if !strings.Contains(typeName, ".") {
if opts.generateOption.buildTypes[typeName] {
continue
}
opts.generateOption.buildTypes[typeName] = true
nestedStruct.WriteString("type ")
nestedStruct.WriteString(typeName)
nestedStruct.WriteByte(' ')
structBuilders = append(structBuilders, buildGoType(nestedStruct, importsBuilder, actualType, imports, false, opts)...)
}
}
} else {
mainBuilder.WriteString(actualType.String())
Expand Down Expand Up @@ -190,6 +206,49 @@ func buildGoType(mainBuilder *strings.Builder, importsBuilder *strings.Builder,
return structBuilders
}

func updateType(aType *Type, aField *reflect.StructField, opts *options, importsBuilder *strings.Builder, imports map[string]bool, isMain bool) {
if aType == nil {
return
}
typeName := aType.TypeName()
if opts.Package == aType.Package {
typeName = aType.SimpleTypeName()
}
if typeName != "" {
aField.Tag += reflect.StructTag(" " + TagTypeName + `:"` + typeName + `"`)
}
typePkg, _ := splitPackage(typeName)
if typePkg != "" && typePkg != opts.Package && aType.ModulePath != "" {
appendImportIfNeeded(importsBuilder, aType.ModulePath, imports, isMain, opts)
}
}

func splitPackage(name string) (string, string) {
index := strings.LastIndex(name, ".")
if index != -1 {
return name[:index], name[index+1:]
}
return "", name
}

func resolveType(aType reflect.Type, types *Types) *Type {
if types == nil || len(types.info) == 0 {
return nil
}
rawType := aType
if rawType.Kind() == reflect.Slice {
rawType = rawType.Elem()
}
if rawType.Kind() == reflect.Ptr {
rawType = rawType.Elem()
}
info, ok := types.info[rawType]
if !ok {
return nil
}
return info
}

func appendImportIfNeeded(importsBuilder *strings.Builder, pkgPath string, imports map[string]bool, isMain bool, opts *options) {
if isMain {
return
Expand Down
61 changes: 31 additions & 30 deletions struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,36 +110,7 @@ type Bar struct {
}
`,
},
{
description: "tags",
rType: reflect.StructOf([]reflect.StructField{
{
Name: "Id",
Type: IntType,
Tag: "json:\",omitempty\"",
},
{
Name: "Name",
Type: StringType,
Tag: "json:\",omitempty\"",
},
{
Name: "Bar",
Type: reflect.StructOf([]reflect.StructField{
{
Name: "BarId",
Type: IntType,
},
{
Name: "Price",
Type: Float64Type,
},
}),
},
}),
name: "Foo",
expected: "package generated\n\ntype Foo struct {\n\tId int `json:\",omitempty\"`\n\tName string `json:\",omitempty\"`\n\tBar Bar\n}\n\ntype Bar struct {\n\tBarId int\n\tPrice float64\n}\n",
},

{
description: "golang types",
rType: reflect.StructOf([]reflect.StructField{
Expand Down Expand Up @@ -209,6 +180,36 @@ type Foo struct {
}
`,
},
{
description: "tags",
rType: reflect.StructOf([]reflect.StructField{
{
Name: "Id",
Type: IntType,
Tag: "json:\",omitempty\"",
},
{
Name: "Name",
Type: StringType,
Tag: "json:\",omitempty\"",
},
{
Name: "Bar",
Type: reflect.StructOf([]reflect.StructField{
{
Name: "BarId",
Type: IntType,
},
{
Name: "Price",
Type: Float64Type,
},
}),
},
}),
name: "Foo",
expected: "package generated\n\ntype Foo struct {\n\tId int `json:\",omitempty\"`\n\tName string `json:\",omitempty\"`\n\tBar Bar\n}\n\ntype Bar struct {\n\tBarId int\n\tPrice float64\n}\n",
},
}

//for _, testCase := range testcases[len(testcases)-1:] {
Expand Down
9 changes: 7 additions & 2 deletions type.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const customPackageName = "PackageName"

type Type struct {
PackagePath string
ModulePath string
Package string
Name string
Definition string
Expand All @@ -20,6 +21,11 @@ type Type struct {
IsPtr bool
}

// TypeName package qualified type name
func (t *Type) SimpleTypeName() string {
return t.Name
}

// TypeName package qualified type name
func (t *Type) TypeName() string {
if t.Package == "" {
Expand Down Expand Up @@ -54,6 +60,7 @@ func (t *Type) LoadType(registry *Types) (reflect.Type, error) {
if err != nil {
return nil, err
}
t.ModulePath = dirType.ModulePath
packageName := dirType.PackagePath(t.PackagePath) //ensure location package matches actual package
if value, err := dirType.Value(customPackageName); err == nil {
if literal, ok := value.(*ast.BasicLit); ok {
Expand All @@ -62,7 +69,6 @@ func (t *Type) LoadType(registry *Types) (reflect.Type, error) {
}
}
}

if packageName != "" && packageName != pkg.Name { //otherwise correct it
pkg.packagePaths[t.PackagePath] = packageName
pkg.Path = ""
Expand All @@ -83,7 +89,6 @@ func (t *Type) LoadType(registry *Types) (reflect.Type, error) {
t.Methods = append(t.Methods, method)
}
}

return rType, nil
}
if t.Definition != "" {
Expand Down
9 changes: 9 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ func (t *Types) registerType(aType *Type) error {
t.mux.RLock()
prev, ok := t.info[aType.Type]
t.mux.RUnlock()

if t.parent != nil {
if candidate, _ := t.parent.Lookup(aType.Name, WithPackage(aType.Package)); candidate != nil {
if candidate.Name() != "" { //use named registed type instead
aType.Type = candidate
}
}
}

//if previous type is a named type, it should not be overridden by inlined type i.e struct{X ...}
if ok && prev.Type.Name() != "" && aType.Type.Name() == "" {
return nil
Expand Down

0 comments on commit 139e7c4

Please sign in to comment.