Skip to content

Commit

Permalink
extended handler with file imports info
Browse files Browse the repository at this point in the history
  • Loading branch information
adranwit committed Sep 11, 2024
1 parent 1353d1a commit a6bbb2f
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 33 deletions.
56 changes: 43 additions & 13 deletions dir.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type (
methods map[string]*Methods
scopes map[string]*ast.Scope
packages map[string]string
imports map[string]goImports
imports map[string]GoImports
typesOccurrences map[string][]string
}

Expand All @@ -37,18 +37,44 @@ type (
*DirTypes
}

goImport struct {
GoImport struct {
Name string
Module string
}
goImports []*goImport
GoImports []*GoImport
)

func (i *goImport) depPath(basePath string, module *modfile.Module) string {
func (i GoImports) Lookup(pkgName string) string {
for _, candidate := range i {
if candidate.Name == pkgName {
return candidate.Module
}
}
for _, candidate := range i {
if strings.HasSuffix(candidate.Module, "/"+pkgName) {
return candidate.Module
}
}
return ""
}

func (i GoImports) OwnertPkgPath(pkg string) string {
pkgPath := i.Lookup(pkg)
if pkgPath == "" {
return pkgPath
}
parts := strings.Split(pkgPath, "/")
if len(parts) > 1 {
return path.Join(parts[len(parts)-2:]...)
}
return ""
}

func (i *GoImport) depPath(basePath string, module *modfile.Module) string {
return path.Join(basePath, i.folder(module))
}

func (i *goImport) folder(module *modfile.Module) string {
func (i *GoImport) folder(module *modfile.Module) string {
if mod := module; mod != nil && strings.Contains(i.Module, module.Mod.Path) {
if index := strings.Index(i.Module, mod.Mod.Path); index != -1 {
return strings.Trim(i.Module[index+len(mod.Mod.Path):], "/")
Expand All @@ -57,7 +83,7 @@ func (i *goImport) folder(module *modfile.Module) string {
return ""
}

func (i goImports) lookup(packageAlias string) *goImport {
func (i GoImports) lookup(packageAlias string) *GoImport {
for _, cadndidate := range i {
if cadndidate.Name == packageAlias {
return cadndidate
Expand All @@ -72,11 +98,11 @@ func (i goImports) lookup(packageAlias string) *goImport {
return nil
}

func newGoImports(file *ast.File) goImports {
var imports []*goImport
func newGoImports(file *ast.File) GoImports {
var imports []*GoImport
for _, spec := range file.Imports {
value, _ := strconv.Unquote(spec.Path.Value)
imp := &goImport{Module: value}
imp := &GoImport{Module: value}
if spec.Name != nil {
imp.Name = spec.Name.Name
}
Expand Down Expand Up @@ -108,7 +134,7 @@ func NewDirTypes(path string) *DirTypes {
specs: map[string]*TypeSpec{},
values: map[string]interface{}{},
methods: map[string]*Methods{},
imports: map[string]goImports{},
imports: map[string]GoImports{},
scopes: map[string]*ast.Scope{},
packages: map[string]string{},
typesOccurrences: map[string][]string{},
Expand All @@ -133,7 +159,7 @@ func (t *TypeSpec) lookupType(packagePath string, packageIdentifier string, type
lookup = t.options.Registry.Lookup
}
if lookup != nil {
rType, err := lookup(typeName, WithPackagePath(packagePath), WithPackage(packageIdentifier))
rType, err := lookup(typeName, WithPackagePath(packagePath), WithPackage(packageIdentifier), WithGoImports(t.GoImports))
if err == nil {
return rType, nil
}
Expand Down Expand Up @@ -195,14 +221,18 @@ func (t *DirTypes) Type(name string) (reflect.Type, error) {
if rType, ok := t.types[name]; ok {
return rType, nil
}
goImports := t.GoImports
spec, ok := t.specs[name]
if !ok {
return nil, fmt.Errorf("not found type %v", name)
}
spec.DirTypes = t
matched, err := spec.matchType(spec.pkg, spec.spec, spec.spec.Type)
if len(goImports) > 0 {
spec.GoImports = goImports
}
pkgPath := ""
matched, err := spec.matchType(spec.pkg, &pkgPath, spec.spec, spec.spec.Type, t.GoImports)
if err != nil {

return nil, err
}

Expand Down
18 changes: 14 additions & 4 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ type (
parseMode parser.Mode
module *modfile.Module
moduleLocation string
onField func(typeName string, field *ast.Field) error
onStruct func(spec *ast.TypeSpec, aStruct *ast.StructType) error
onField func(typeName string, field *ast.Field, imports GoImports) error
onStruct func(spec *ast.TypeSpec, aStruct *ast.StructType, imports GoImports) error
onLookup func(packagePath, pkg, typeName string, rType reflect.Type)
GoImports GoImports
}

generateOption struct {
Expand Down Expand Up @@ -97,7 +98,7 @@ func WithParserMode(mode parser.Mode) Option {
}

// WithOnField returns on field function
func WithOnField(fn func(typeName string, field *ast.Field) error) Option {
func WithOnField(fn func(typeName string, field *ast.Field, imports GoImports) error) Option {
return func(o *options) {
o.onField = fn
}
Expand Down Expand Up @@ -217,7 +218,7 @@ func WithOnLookup(fn func(packagePath, pkg, typeName string, rType reflect.Type)
}

// WithOnStruct return on lookup notifier option
func WithOnStruct(fn func(spec *ast.TypeSpec, aStruct *ast.StructType) error) Option {
func WithOnStruct(fn func(spec *ast.TypeSpec, aStruct *ast.StructType, imports GoImports) error) Option {
return func(o *options) {
o.onStruct = fn
}
Expand All @@ -235,3 +236,12 @@ func withOptions(opt *options) Option {
*o = *opt
}
}

func WithGoImports(imports GoImports) Option {
return func(o *options) {
if len(imports) == 0 {
return
}
o.GoImports = imports
}
}
31 changes: 19 additions & 12 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,30 +144,32 @@ func Parse(dataType string, opts ...Option) (reflect.Type, error) {
types := NewDirTypes("")
types.Apply(WithTypeLookup(lookup), WithPackage(o.Package), WithRegistry(o.Registry), WithModule(o.module, o.moduleLocation))
typeSpec := &TypeSpec{DirTypes: types}
rType, err := typeSpec.matchType(types.Package, nil, expr)
pkgPath := ""
rType, err := typeSpec.matchType(types.Package, &pkgPath, nil, expr, o.GoImports)
if err != nil {
return nil, err
}
return rType, nil
}

func (t *TypeSpec) matchType(pkg string, spec *ast.TypeSpec, expr ast.Node) (reflect.Type, error) {
func (t *TypeSpec) matchType(pkg string, pkgPath *string, spec *ast.TypeSpec, expr ast.Node, imps GoImports) (reflect.Type, error) {
switch actual := expr.(type) {
case *ast.StarExpr:
rType, err := t.matchType(pkg, spec, actual.X)
rType, err := t.matchType(pkg, pkgPath, spec, actual.X, imps)
if err != nil {
return nil, err
}
return reflect.PtrTo(rType), nil
case *ast.StructType:
if t.options.onStruct != nil {
t.options.onStruct(spec, actual)
t.options.onStruct(spec, actual, nil)
}
imps = t.DirTypes.imports[t.path]
rFields := make([]reflect.StructField, 0, len(actual.Fields.List))
for _, field := range actual.Fields.List {

if t.onField != nil {
if err := t.onField(spec.Name.Name, field); err != nil {
if err := t.onField(spec.Name.Name, field, imps); err != nil {
return nil, err
}
}
Expand All @@ -182,7 +184,7 @@ func (t *TypeSpec) matchType(pkg string, spec *ast.TypeSpec, expr ast.Node) (ref
tag, prevTypeName = RemoveTag(tag, TagTypeName)
}

fieldType, err := t.matchType(pkg, spec, field.Type)
fieldType, err := t.matchType(pkg, pkgPath, spec, field.Type, imps)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -234,7 +236,12 @@ func (t *TypeSpec) matchType(pkg string, spec *ast.TypeSpec, expr ast.Node) (ref
}
rType, err := t.lookup("", packageIdent.Name, actual.Sel.Name)
if err != nil {
return nil, err
if pkgPath := imps.OwnertPkgPath(packageIdent.Name); pkgPath != "" {
rType, err = t.lookup("", pkgPath, actual.Sel.Name)
}
if err != nil {
return nil, err
}
}
return rType, nil
} else {
Expand All @@ -246,25 +253,25 @@ func (t *TypeSpec) matchType(pkg string, spec *ast.TypeSpec, expr ast.Node) (ref
}

case *ast.ArrayType:
rType, err := t.matchType(pkg, spec, actual.Elt)
rType, err := t.matchType(pkg, pkgPath, spec, actual.Elt, imps)
if err != nil {
return nil, err
}
return reflect.SliceOf(rType), nil
case *ast.MapType:
keyType, err := t.matchType(pkg, spec, actual.Key)
keyType, err := t.matchType(pkg, pkgPath, spec, actual.Key, imps)
if err != nil {
return nil, err
}
valueType, err := t.matchType(pkg, spec, actual.Value)
valueType, err := t.matchType(pkg, pkgPath, spec, actual.Value, imps)
if err != nil {
return nil, err
}
return reflect.MapOf(keyType, valueType), nil
case *ast.InterfaceType:
return InterfaceType, nil
case *ast.TypeSpec:
return t.matchType(pkg, actual, actual.Type)
return t.matchType(pkg, pkgPath, actual, actual.Type, imps)
case *ast.Ident:
switch actual.Name {
case "int":
Expand Down Expand Up @@ -335,7 +342,7 @@ func (t *TypeSpec) tryResolveStandardTypes(packageIdent *ast.Ident, actual *ast.
return nil, false
}

func sourceLocation(t *TypeSpec, imp *goImport) (string, string) {
func sourceLocation(t *TypeSpec, imp *GoImport) (string, string) {
module := t.options.module
if module == nil {
return "", ""
Expand Down
2 changes: 1 addition & 1 deletion parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func TestParseTypes(t *testing.T) {
},
{
location: "./internal/testdata",
options: []Option{WithParserMode(parser.ParseComments), WithOnField(func(typeName string, field *ast.Field) error {
options: []Option{WithParserMode(parser.ParseComments), WithOnField(func(typeName string, field *ast.Field, imps GoImports) error {
if field.Doc != nil {
comments := CommentGroup(*field.Doc).Stringify()
comments = strings.Trim(comments, "\"/**/")
Expand Down
2 changes: 2 additions & 0 deletions type.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Type struct {
Methods []reflect.Method
Registry *Types
IsPtr bool
Imports GoImports
}

// TypeName package qualified type name
Expand Down Expand Up @@ -141,6 +142,7 @@ func NewType(name string, opts ...Option) *Type {
strings.Contains(o.Type.Name, "*")) {
o.Definition = name
}
o.Type.Imports = o.parseOption.GoImports
return &o.Type
}

Expand Down
20 changes: 17 additions & 3 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (t *Types) Lookup(name string, opts ...Option) (reflect.Type, error) {
}

func (t *Types) LookupType(aType *Type) (reflect.Type, error) {
ret, err := t.lookupType(aType)
ret, err := t.lookupType(aType, aType.Imports)
if err != nil && t.parent != nil {
if ret, _ = t.parent.LookupType(aType); ret != nil {
return ret, nil
Expand All @@ -140,13 +140,21 @@ func (t *Types) lookupMethods(aType *Type) ([]reflect.Method, error) {
return pkg.Methods(aType.Name)
}

func (t *Types) lookupType(aType *Type) (reflect.Type, error) {
func (t *Types) lookupType(aType *Type, imps GoImports) (reflect.Type, error) {
t.mux.RLock()
pkg := t.packages[aType.Package]
t.mux.RUnlock()
if pkg == nil {
if !aType.IsLoadable() {
return nil, fmt.Errorf("unable locate: %s unknown package: '%s'", aType.Name, aType.Package)
if imps != nil {
if pkgPath := imps.OwnertPkgPath(aType.Package); pkgPath != "" {
pkg = t.packages[pkgPath]
aType.PackagePath = imps.Lookup(aType.Package)
}
}
if pkg == nil {
return nil, fmt.Errorf("unable locate: %s unknown package: '%s'", aType.Name, aType.Package)
}
}
pkg = t.ensurePackage(aType.Package, aType.PackagePath)
}
Expand Down Expand Up @@ -203,6 +211,12 @@ func (t *Types) registerType(aType *Type) error {
return fmt.Errorf("failed to register %v reflect.Type was nil", aType.TypeName())
}
if aType.Type, err = aType.LoadType(t); err != nil {
if pkg, ok := t.packages[aType.Package]; ok {
if aType.Type, err = pkg.Lookup(aType.Name); err == nil {
return nil
}
}

return err
}
}
Expand Down

0 comments on commit a6bbb2f

Please sign in to comment.