Skip to content

Commit

Permalink
fix(cmd/web/restdoc): 修正各类已知的 bug
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed Feb 26, 2024
1 parent 68e7784 commit 7cdfc04
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 48 deletions.
3 changes: 2 additions & 1 deletion cmd/web/restdoc/pkg/testdir/testdir.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type S struct { // 此行不会作为 S 的 comment
T time.Time
}
F2 []Int // F2 Doc
// F3 *S
F3 []*S // 引用自身
}

// S2 Alias
Expand All @@ -33,6 +33,7 @@ type S2 = S
type G[T any] struct {
F1 T // F1 Doc
F2 int // F2 Doc
F3 *G[T]
}

type GInt G[int]
Expand Down
66 changes: 31 additions & 35 deletions cmd/web/restdoc/pkg/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
type (
// Struct 这是对 [types.Struct] 的包装
Struct struct {
st *ast.StructType
name string
fields []*types.Var
docs []*ast.CommentGroup
tags []string
Expand Down Expand Up @@ -46,7 +46,7 @@ type (
defaultTypeList []types.Type
)

func (s *Struct) String() string { return "struct with " + strconv.Itoa(s.NumFields()) + " fields" }
func (s *Struct) String() string { return s.name }

func (s *Struct) Underlying() types.Type { return s }

Expand Down Expand Up @@ -81,57 +81,48 @@ func (tl defaultTypeList) Len() int { return len(tl) }
// newTypeList 声明 [typeList] 接口对象
func newTypeList(t ...types.Type) typeList { return defaultTypeList(t) }

func (pkgs *Packages) newStruct(ctx context.Context, pkg *types.Package, st *ast.StructType, file *ast.File, tl typeList, tps *types.TypeParamList) (*Struct, error) {
// name 为结构体名称,可以为空;
func (pkgs *Packages) newStruct(ctx context.Context, pkg *types.Package, name string, st *ast.StructType, file *ast.File, tl typeList, tps *types.TypeParamList) (*Struct, error) {
size := st.Fields.NumFields()
s := &Struct{
st: st,
name: name,
fields: make([]*types.Var, 0, size),
docs: make([]*ast.CommentGroup, 0, size),
tags: make([]string, 0, size),
}

// BUG 结构体中的字段如果引用自身,会造成死循环

if err := pkgs.addField(ctx, pkg, s, st, file, tl, tps); err != nil {
return nil, err
}
return s, nil
}

// 将 ts 的所有有字段加入 s 之中
func (pkgs *Packages) addField(ctx context.Context, pkg *types.Package, s *Struct, st *ast.StructType, file *ast.File, tl typeList, tps *types.TypeParamList) error {
for _, f := range st.Fields.List {
doc := getDoc(f.Doc, f.Comment)

t, err := pkgs.typeOfExpr(ctx, pkg, file, f.Type, nil, tl, tps)
if err != nil {
return err
}

var tag string
if f.Tag != nil {
tag = f.Tag.Value
}

typ, err := pkgs.typeOfExpr(ctx, pkg, file, f.Type, s, nil, tl, tps)
if err != nil {
return nil, err
}

switch len(f.Names) {
case 0:
s.fields = append(s.fields, types.NewField(f.Pos(), pkg, "", t, true))
s.fields = append(s.fields, types.NewField(f.Pos(), pkg, "", typ, true))
s.docs = append(s.docs, doc)
s.tags = append(s.tags, tag)
case 1:
s.fields = append(s.fields, types.NewField(f.Pos(), pkg, f.Names[0].Name, t, false))
s.fields = append(s.fields, types.NewField(f.Pos(), pkg, f.Names[0].Name, typ, false))
s.docs = append(s.docs, doc)
s.tags = append(s.tags, tag)
default:
for _, n := range f.Names {
s.fields = append(s.fields, types.NewField(f.Pos(), pkg, n.Name, t, false))
s.fields = append(s.fields, types.NewField(f.Pos(), pkg, n.Name, typ, false))
s.docs = append(s.docs, doc)
s.tags = append(s.tags, tag)
}
}
}

return nil
return s, nil
}

// tl 表示范型参数列表,可以为空
Expand Down Expand Up @@ -201,12 +192,17 @@ func (pkgs *Packages) TypeOf(ctx context.Context, path string) (types.Type, erro
return wrap(typ), nil
}

// doc 可以为空,参考 typeOfPath
func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast.File, expr ast.Expr, doc *ast.CommentGroup, tl typeList, tps *types.TypeParamList) (types.Type, error) {
// doc 可以为空,参考 typeOfPath;
// self 防止对象引用自身引起的死循环,该值用于判断 expr 是否为与 parent 相同。可以为空;
func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast.File, expr ast.Expr, self *Struct, doc *ast.CommentGroup, tl typeList, tps *types.TypeParamList) (types.Type, error) {
switch e := expr.(type) {
case *ast.SelectorExpr: // type x path.struct
return pkgs.typeOfPath(ctx, pkgs.getPathFromSelectorExpr(e, f), "", doc, tl, tps)
case *ast.Ident: // type x y,或是 struct{ f1 T } 中的 T
if self != nil && e.Name == self.String() {
return self, nil
}

basic := e.Name
name := pkg.Path() + "." + basic

Expand All @@ -232,9 +228,9 @@ func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast

return pkgs.typeOfPath(ctx, name, basic, doc, tl, tps)
case *ast.StructType:
return pkgs.newStruct(ctx, pkg, e, f, tl, tps)
return pkgs.newStruct(ctx, pkg, "", e, f, tl, tps)
case *ast.ArrayType: // type x []y
typ, err := pkgs.typeOfExpr(ctx, pkg, f, e.Elt, doc, tl, tps)
typ, err := pkgs.typeOfExpr(ctx, pkg, f, e.Elt, self, doc, tl, tps)
if err != nil {
return nil, err
}
Expand All @@ -249,29 +245,29 @@ func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast
return types.NewArray(typ, l), nil
}
case *ast.StarExpr: // type x *y
typ, err := pkgs.typeOfExpr(ctx, pkg, f, e.X, doc, tl, tps)
typ, err := pkgs.typeOfExpr(ctx, pkg, f, e.X, self, doc, tl, tps)
if err != nil {
return nil, err
}
return types.NewPointer(typ), nil
case *ast.IndexExpr: // type x y[int] 等实例化的范型
idxType, err := pkgs.typeOfExpr(ctx, pkg, f, e.Index, nil, tl, tps)
idxType, err := pkgs.typeOfExpr(ctx, pkg, f, e.Index, nil, nil, tl, tps)
if err != nil {
return nil, err
}

return pkgs.typeOfExpr(ctx, pkg, f, e.X, doc, newTypeList(idxType), tps)
return pkgs.typeOfExpr(ctx, pkg, f, e.X, self, doc, newTypeList(idxType), tps)
case *ast.IndexListExpr:
idxTypes := make([]types.Type, 0, len(e.Indices))
for _, idx := range e.Indices {
idxType, err := pkgs.typeOfExpr(ctx, pkg, f, idx, nil, tl, tps)
idxType, err := pkgs.typeOfExpr(ctx, pkg, f, idx, nil, nil, tl, tps)
if err != nil {
return nil, err
}
idxTypes = append(idxTypes, idxType)
}

return pkgs.typeOfExpr(ctx, pkg, f, e.X, doc, newTypeList(idxTypes...), tps)
return pkgs.typeOfExpr(ctx, pkg, f, e.X, self, doc, newTypeList(idxTypes...), tps)
case *ast.InterfaceType:
return nil, web.NewLocaleError("ast.InterfaceType can not covert to openapi schema", expr)
default:
Expand All @@ -284,7 +280,7 @@ func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast
// 如果其类型为 [types.Struct],会被包装为 [Struct]。
// 如果存在类型为 [types.Named],会被包装为 [Named]。
// 可能存在 type uint string 之类的定义,basicType 表示 path 找不到时是否需要按 basicType 查找基本的内置类型。
// doc 自定义的文档信息,可以为空,表示根据指定的类型信息确定文档。如果是类型字段可以自己指定此值
// doc 自定义的文档信息,可以为空,表示根据指定的类型信息确定文档。如果是字段类型可以自己指定此值
func (pkgs *Packages) typeOfPath(ctx context.Context, path, basicType string, doc *ast.CommentGroup, tl typeList, tps *types.TypeParamList) (typ types.Type, err error) {
obj, spec, f, found := pkgs.lookup(ctx, path)
if !found {
Expand All @@ -306,12 +302,12 @@ func (pkgs *Packages) typeOfPath(ctx context.Context, path, basicType string, do
}

if st, ok := spec.Type.(*ast.StructType); ok {
typ, err = pkgs.newStruct(ctx, tn.Pkg(), st, f, tl, obj.Type().(*types.Named).TypeParams())
typ, err = pkgs.newStruct(ctx, tn.Pkg(), spec.Name.Name, st, f, tl, obj.Type().(*types.Named).TypeParams())
} else {
if doc == nil {
doc = getDoc(spec.Doc, spec.Comment)
}
typ, err = pkgs.typeOfExpr(ctx, tn.Pkg(), f, spec.Type, doc, tl, tps)
typ, err = pkgs.typeOfExpr(ctx, tn.Pkg(), f, spec.Type, nil, doc, tl, tps)
}
if err != nil {
return nil, err
Expand Down
22 changes: 10 additions & 12 deletions cmd/web/restdoc/schema/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,24 @@ func (s *Schema) fromType(t *openapi.OpenAPI, xmlName string, typ types.Type, ta
// xmlName 结构体名称,同时也会被当作 XML 根元素名称(会被 XMLName 字段改写);
// 将 *pkg.Struct 解析为 schema 对象
func (s *Schema) fromStruct(schema *openapi3.Schema, t *openapi.OpenAPI, xmlName string, st *pkg.Struct, tag string) error {
// BUG 结构体与嵌入字段重名的处理

for i := range st.NumFields() {
field := st.Field(i)
ft := field.Type()

if field.Embedded() {
named, ok := ft.(*pkg.Named)
for ok {
ft = named.Next()
// BUG ft 可能是 NotFound 之前就可用了
named, ok = ft.(*pkg.Named)
fieldRef, _, err := s.fromType(t, "", ft, tag)
if err != nil {
return err
}

if ps, ok := ft.(*pkg.Struct); ok {
if err := s.fromStruct(schema, t, xmlName, ps, tag); err != nil {
return err
if fieldRef.Value.Type == openapi3.TypeObject {
for k, v := range fieldRef.Value.Properties {
if _, found := schema.Properties[k]; found { // 防止与现有的重名
continue
}
schema.WithPropertyRef(k, v)
}
continue
} // 非 struct 类型的嵌入字段,当作普通字段处理
}
}

if !field.Exported() { // 嵌入对象名小写是可以的,所以要在 filed.Embedded 判断之后。
Expand Down

0 comments on commit 7cdfc04

Please sign in to comment.