diff --git a/cmd/web/restdoc/pkg/testdir/testdir.go b/cmd/web/restdoc/pkg/testdir/testdir.go index dae1ebdd..ed910010 100644 --- a/cmd/web/restdoc/pkg/testdir/testdir.go +++ b/cmd/web/restdoc/pkg/testdir/testdir.go @@ -23,7 +23,7 @@ type S struct { // 此行不会作为 S 的 comment T time.Time } F2 []Int // F2 Doc - // F3 *S + F3 []*S // 引用自身 } // S2 Alias @@ -33,6 +33,7 @@ type S2 = S type G[T any] struct { F1 T // F1 Doc F2 int // F2 Doc + F3 *G[T] } type GInt G[int] diff --git a/cmd/web/restdoc/pkg/types.go b/cmd/web/restdoc/pkg/types.go index a16847ca..95e4cc7d 100644 --- a/cmd/web/restdoc/pkg/types.go +++ b/cmd/web/restdoc/pkg/types.go @@ -16,7 +16,7 @@ import ( type ( // Struct 这是对 [types.Struct] 的包装 Struct struct { - st *ast.StructType + name string fields []*types.Var docs []*ast.CommentGroup tags []string @@ -46,7 +46,7 @@ type ( defaultTypeList []types.Type ) -func (s *Struct) String() string { return "struct with " + strconv.Itoa(s.NumFields()) + " fields" } +func (s *Struct) String() string { return s.name } func (s *Struct) Underlying() types.Type { return s } @@ -81,57 +81,48 @@ func (tl defaultTypeList) Len() int { return len(tl) } // newTypeList 声明 [typeList] 接口对象 func newTypeList(t ...types.Type) typeList { return defaultTypeList(t) } -func (pkgs *Packages) newStruct(ctx context.Context, pkg *types.Package, st *ast.StructType, file *ast.File, tl typeList, tps *types.TypeParamList) (*Struct, error) { +// name 为结构体名称,可以为空; +func (pkgs *Packages) newStruct(ctx context.Context, pkg *types.Package, name string, st *ast.StructType, file *ast.File, tl typeList, tps *types.TypeParamList) (*Struct, error) { size := st.Fields.NumFields() s := &Struct{ - st: st, + name: name, fields: make([]*types.Var, 0, size), docs: make([]*ast.CommentGroup, 0, size), tags: make([]string, 0, size), } - // BUG 结构体中的字段如果引用自身,会造成死循环 - - if err := pkgs.addField(ctx, pkg, s, st, file, tl, tps); err != nil { - return nil, err - } - return s, nil -} - -// 将 ts 的所有有字段加入 s 之中 -func (pkgs *Packages) addField(ctx context.Context, pkg *types.Package, s *Struct, st *ast.StructType, file *ast.File, tl typeList, tps *types.TypeParamList) error { for _, f := range st.Fields.List { doc := getDoc(f.Doc, f.Comment) - t, err := pkgs.typeOfExpr(ctx, pkg, file, f.Type, nil, tl, tps) - if err != nil { - return err - } - var tag string if f.Tag != nil { tag = f.Tag.Value } + typ, err := pkgs.typeOfExpr(ctx, pkg, file, f.Type, s, nil, tl, tps) + if err != nil { + return nil, err + } + switch len(f.Names) { case 0: - s.fields = append(s.fields, types.NewField(f.Pos(), pkg, "", t, true)) + s.fields = append(s.fields, types.NewField(f.Pos(), pkg, "", typ, true)) s.docs = append(s.docs, doc) s.tags = append(s.tags, tag) case 1: - s.fields = append(s.fields, types.NewField(f.Pos(), pkg, f.Names[0].Name, t, false)) + s.fields = append(s.fields, types.NewField(f.Pos(), pkg, f.Names[0].Name, typ, false)) s.docs = append(s.docs, doc) s.tags = append(s.tags, tag) default: for _, n := range f.Names { - s.fields = append(s.fields, types.NewField(f.Pos(), pkg, n.Name, t, false)) + s.fields = append(s.fields, types.NewField(f.Pos(), pkg, n.Name, typ, false)) s.docs = append(s.docs, doc) s.tags = append(s.tags, tag) } } } - return nil + return s, nil } // tl 表示范型参数列表,可以为空 @@ -201,12 +192,17 @@ func (pkgs *Packages) TypeOf(ctx context.Context, path string) (types.Type, erro return wrap(typ), nil } -// doc 可以为空,参考 typeOfPath -func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast.File, expr ast.Expr, doc *ast.CommentGroup, tl typeList, tps *types.TypeParamList) (types.Type, error) { +// doc 可以为空,参考 typeOfPath; +// self 防止对象引用自身引起的死循环,该值用于判断 expr 是否为与 parent 相同。可以为空; +func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast.File, expr ast.Expr, self *Struct, doc *ast.CommentGroup, tl typeList, tps *types.TypeParamList) (types.Type, error) { switch e := expr.(type) { case *ast.SelectorExpr: // type x path.struct return pkgs.typeOfPath(ctx, pkgs.getPathFromSelectorExpr(e, f), "", doc, tl, tps) case *ast.Ident: // type x y,或是 struct{ f1 T } 中的 T + if self != nil && e.Name == self.String() { + return self, nil + } + basic := e.Name name := pkg.Path() + "." + basic @@ -232,9 +228,9 @@ func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast return pkgs.typeOfPath(ctx, name, basic, doc, tl, tps) case *ast.StructType: - return pkgs.newStruct(ctx, pkg, e, f, tl, tps) + return pkgs.newStruct(ctx, pkg, "", e, f, tl, tps) case *ast.ArrayType: // type x []y - typ, err := pkgs.typeOfExpr(ctx, pkg, f, e.Elt, doc, tl, tps) + typ, err := pkgs.typeOfExpr(ctx, pkg, f, e.Elt, self, doc, tl, tps) if err != nil { return nil, err } @@ -249,29 +245,29 @@ func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast return types.NewArray(typ, l), nil } case *ast.StarExpr: // type x *y - typ, err := pkgs.typeOfExpr(ctx, pkg, f, e.X, doc, tl, tps) + typ, err := pkgs.typeOfExpr(ctx, pkg, f, e.X, self, doc, tl, tps) if err != nil { return nil, err } return types.NewPointer(typ), nil case *ast.IndexExpr: // type x y[int] 等实例化的范型 - idxType, err := pkgs.typeOfExpr(ctx, pkg, f, e.Index, nil, tl, tps) + idxType, err := pkgs.typeOfExpr(ctx, pkg, f, e.Index, nil, nil, tl, tps) if err != nil { return nil, err } - return pkgs.typeOfExpr(ctx, pkg, f, e.X, doc, newTypeList(idxType), tps) + return pkgs.typeOfExpr(ctx, pkg, f, e.X, self, doc, newTypeList(idxType), tps) case *ast.IndexListExpr: idxTypes := make([]types.Type, 0, len(e.Indices)) for _, idx := range e.Indices { - idxType, err := pkgs.typeOfExpr(ctx, pkg, f, idx, nil, tl, tps) + idxType, err := pkgs.typeOfExpr(ctx, pkg, f, idx, nil, nil, tl, tps) if err != nil { return nil, err } idxTypes = append(idxTypes, idxType) } - return pkgs.typeOfExpr(ctx, pkg, f, e.X, doc, newTypeList(idxTypes...), tps) + return pkgs.typeOfExpr(ctx, pkg, f, e.X, self, doc, newTypeList(idxTypes...), tps) case *ast.InterfaceType: return nil, web.NewLocaleError("ast.InterfaceType can not covert to openapi schema", expr) default: @@ -284,7 +280,7 @@ func (pkgs *Packages) typeOfExpr(ctx context.Context, pkg *types.Package, f *ast // 如果其类型为 [types.Struct],会被包装为 [Struct]。 // 如果存在类型为 [types.Named],会被包装为 [Named]。 // 可能存在 type uint string 之类的定义,basicType 表示 path 找不到时是否需要按 basicType 查找基本的内置类型。 -// doc 自定义的文档信息,可以为空,表示根据指定的类型信息确定文档。如果是类型字段可以自己指定此值; +// doc 自定义的文档信息,可以为空,表示根据指定的类型信息确定文档。如果是字段类型可以自己指定此值; func (pkgs *Packages) typeOfPath(ctx context.Context, path, basicType string, doc *ast.CommentGroup, tl typeList, tps *types.TypeParamList) (typ types.Type, err error) { obj, spec, f, found := pkgs.lookup(ctx, path) if !found { @@ -306,12 +302,12 @@ func (pkgs *Packages) typeOfPath(ctx context.Context, path, basicType string, do } if st, ok := spec.Type.(*ast.StructType); ok { - typ, err = pkgs.newStruct(ctx, tn.Pkg(), st, f, tl, obj.Type().(*types.Named).TypeParams()) + typ, err = pkgs.newStruct(ctx, tn.Pkg(), spec.Name.Name, st, f, tl, obj.Type().(*types.Named).TypeParams()) } else { if doc == nil { doc = getDoc(spec.Doc, spec.Comment) } - typ, err = pkgs.typeOfExpr(ctx, tn.Pkg(), f, spec.Type, doc, tl, tps) + typ, err = pkgs.typeOfExpr(ctx, tn.Pkg(), f, spec.Type, nil, doc, tl, tps) } if err != nil { return nil, err diff --git a/cmd/web/restdoc/schema/search.go b/cmd/web/restdoc/schema/search.go index d89d40a7..1c09a293 100644 --- a/cmd/web/restdoc/schema/search.go +++ b/cmd/web/restdoc/schema/search.go @@ -140,26 +140,24 @@ func (s *Schema) fromType(t *openapi.OpenAPI, xmlName string, typ types.Type, ta // xmlName 结构体名称,同时也会被当作 XML 根元素名称(会被 XMLName 字段改写); // 将 *pkg.Struct 解析为 schema 对象 func (s *Schema) fromStruct(schema *openapi3.Schema, t *openapi.OpenAPI, xmlName string, st *pkg.Struct, tag string) error { - // BUG 结构体与嵌入字段重名的处理 - for i := range st.NumFields() { field := st.Field(i) ft := field.Type() if field.Embedded() { - named, ok := ft.(*pkg.Named) - for ok { - ft = named.Next() - // BUG ft 可能是 NotFound 之前就可用了 - named, ok = ft.(*pkg.Named) + fieldRef, _, err := s.fromType(t, "", ft, tag) + if err != nil { + return err } - if ps, ok := ft.(*pkg.Struct); ok { - if err := s.fromStruct(schema, t, xmlName, ps, tag); err != nil { - return err + if fieldRef.Value.Type == openapi3.TypeObject { + for k, v := range fieldRef.Value.Properties { + if _, found := schema.Properties[k]; found { // 防止与现有的重名 + continue + } + schema.WithPropertyRef(k, v) } - continue - } // 非 struct 类型的嵌入字段,当作普通字段处理 + } } if !field.Exported() { // 嵌入对象名小写是可以的,所以要在 filed.Embedded 判断之后。