diff --git a/go.mod b/go.mod index 969b6be..dff13dc 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23 require ( github.com/iancoleman/strcase v0.3.0 github.com/spf13/cobra v1.8.1 + golang.org/x/tools v0.24.0 ) require ( diff --git a/go.sum b/go.sum index b2c1df5..9b49a7d 100644 --- a/go.sum +++ b/go.sum @@ -8,5 +8,7 @@ github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/generate.go b/internal/generate.go index 36c4426..a1416f3 100644 --- a/internal/generate.go +++ b/internal/generate.go @@ -11,6 +11,7 @@ import ( "go/ast" "go/parser" "go/token" + "golang.org/x/tools/go/ast/astutil" "os" "reflect" "slices" @@ -90,66 +91,64 @@ func Generate(ctx context.Context, sourceFileName string) error { Structs: make([]Struct, 0), } - for _, node := range nodes.Decls { - switch node.(type) { - case *ast.GenDecl: - genDecl := node.(*ast.GenDecl) - for _, spec := range genDecl.Specs { - switch spec.(type) { - case *ast.TypeSpec: - typeSpec := spec.(*ast.TypeSpec) - switch concreteType := typeSpec.Type.(type) { - case *ast.StructType: - strct := Struct{ - Name: typeSpec.Name.Name, - Fields: make([]Field, 0, len(concreteType.Fields.List)), - } - for _, field := range concreteType.Fields.List { - switch ident := field.Type.(type) { - case *ast.Ident: - fieldType := ident.Name - switch fieldType { - case "string": - output.Imports = addImport(output.Imports, "strings", "") - case "error": - output.Imports = addImport(output.Imports, "errors", "") - } - for _, name := range field.Names { - tag := field.Tag - var tagStr string - if tag != nil { - tagStr = tag.Value - } - strct.Fields = append(strct.Fields, newField(ctx, name.Name, "", fieldType, tagStr)) - } - case *ast.SelectorExpr: - pkg := ident.X.(*ast.Ident).Name - fieldType := ident.Sel.Name + astutil.Apply(nodes, nil, func(c *astutil.Cursor) bool { + typeSpec, ok := c.Node().(*ast.TypeSpec) + if !ok { + return true + } + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + return true + } + + strct := Struct{ + Name: typeSpec.Name.Name, + Fields: make([]Field, 0, len(structType.Fields.List)), + } + for _, field := range structType.Fields.List { + switch ident := field.Type.(type) { + case *ast.Ident: + fieldType := ident.Name + switch fieldType { + case "string": + output.Imports = addImport(output.Imports, "strings", "") + case "error": + output.Imports = addImport(output.Imports, "errors", "") + } + for _, name := range field.Names { + tag := field.Tag + var tagStr string + if tag != nil { + tagStr = tag.Value + } + strct.Fields = append(strct.Fields, newField(ctx, name.Name, "", fieldType, tagStr)) + } + case *ast.SelectorExpr: + pkg := ident.X.(*ast.Ident).Name + fieldType := ident.Sel.Name - if impPath := getImportPath(ctx, imports, pkg); impPath != "" { - var alias string - if !strings.HasSuffix(impPath, pkg) { - alias = pkg - } - output.Imports = addImport(output.Imports, impPath, alias) - } + if impPath := getImportPath(ctx, imports, pkg); impPath != "" { + var alias string + if !strings.HasSuffix(impPath, pkg) { + alias = pkg + } + output.Imports = addImport(output.Imports, impPath, alias) + } - for _, name := range field.Names { - tag := field.Tag - var tagStr string - if tag != nil { - tagStr = tag.Value - } - strct.Fields = append(strct.Fields, newField(ctx, name.Name, pkg, fieldType, tagStr)) - } - } - } - output.Structs = append(output.Structs, strct) + for _, name := range field.Names { + tag := field.Tag + var tagStr string + if tag != nil { + tagStr = tag.Value } + strct.Fields = append(strct.Fields, newField(ctx, name.Name, pkg, fieldType, tagStr)) } } } - } + output.Structs = append(output.Structs, strct) + return true + }) + f, err := os.Create(fileName) err = tpl.Execute(f, output) if err != nil {