From 14e9780ec906a8d2c0b0c9e66c29faea0744a22b Mon Sep 17 00:00:00 2001 From: caixw Date: Fri, 31 May 2024 09:07:06 +0800 Subject: [PATCH] =?UTF-8?q?fix(cmd/web/restdoc):=20=E9=99=90=E5=88=B6=20pk?= =?UTF-8?q?g.Packages.TypeOf=20=E7=9A=84=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/web/restdoc/pkg/bench_test.go | 8 +++---- cmd/web/restdoc/pkg/path.go | 4 ++-- cmd/web/restdoc/pkg/path_test.go | 1 + cmd/web/restdoc/pkg/pkg.go | 17 ++++++++------- .../restdoc/pkg/testdir/testdir2/testdir2.go | 1 + cmd/web/restdoc/pkg/types.go | 6 ++++++ cmd/web/restdoc/pkg/types_test.go | 21 ++++++++++++++----- cmd/web/restdoc/schema/schema.go | 10 ++++----- web.go | 2 +- 9 files changed, 45 insertions(+), 25 deletions(-) diff --git a/cmd/web/restdoc/pkg/bench_test.go b/cmd/web/restdoc/pkg/bench_test.go index dc89f504..b2c66e11 100644 --- a/cmd/web/restdoc/pkg/bench_test.go +++ b/cmd/web/restdoc/pkg/bench_test.go @@ -13,7 +13,7 @@ import ( "github.com/issue9/web/cmd/web/restdoc/logger/loggertest" ) -func BenchmarkPackages_TypeOf_Int(b *testing.B) { +func BenchmarkPackages_typeOf_Int(b *testing.B) { a := assert.New(b, false) l := loggertest.New(a) p := New(l.Logger) @@ -21,11 +21,11 @@ func BenchmarkPackages_TypeOf_Int(b *testing.B) { p.ScanDir(context.Background(), "./testdir", true) ctx := context.Background() for range b.N { - p.TypeOf(ctx, "github.com/issue9/web/restdoc/pkg.Int") + p.typeOf(ctx, "github.com/issue9/web/restdoc/pkg.Int") } } -func BenchmarkPackages_TypeOf_S(b *testing.B) { +func BenchmarkPackages_typeOf_S(b *testing.B) { a := assert.New(b, false) l := loggertest.New(a) p := New(l.Logger) @@ -33,6 +33,6 @@ func BenchmarkPackages_TypeOf_S(b *testing.B) { p.ScanDir(context.Background(), "./testdir", true) ctx := context.Background() for range b.N { - p.TypeOf(ctx, "github.com/issue9/web/restdoc/pkg.S") + p.typeOf(ctx, "github.com/issue9/web/restdoc/pkg.S") } } diff --git a/cmd/web/restdoc/pkg/path.go b/cmd/web/restdoc/pkg/path.go index 0419fcff..32eff1bf 100644 --- a/cmd/web/restdoc/pkg/path.go +++ b/cmd/web/restdoc/pkg/path.go @@ -45,7 +45,7 @@ func (pkgs *Packages) splitFieldTypes(ctx context.Context, path string) (p strin panic(fmt.Sprintf("无效的语法 %s", path)) } - t, err := pkgs.TypeOf(ctx, ps[1]) + t, err := pkgs.typeOf(ctx, ps[1]) if err != nil { return err } @@ -148,7 +148,7 @@ func (pkgs *Packages) splitTypeParams(ctx context.Context, path string) (p strin if len(tps) > 0 { ts := make([]types.Type, 0, len(tps)) for _, p := range tps { - t, err := pkgs.TypeOf(ctx, p) + t, err := pkgs.typeOf(ctx, p) if err != nil { return "", nil, err } diff --git a/cmd/web/restdoc/pkg/path_test.go b/cmd/web/restdoc/pkg/path_test.go index 01174ec0..229bb1f4 100644 --- a/cmd/web/restdoc/pkg/path_test.go +++ b/cmd/web/restdoc/pkg/path_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/issue9/assert/v4" + "github.com/issue9/web/cmd/web/restdoc/logger/loggertest" ) diff --git a/cmd/web/restdoc/pkg/pkg.go b/cmd/web/restdoc/pkg/pkg.go index b3cdfacf..9e7bffde 100644 --- a/cmd/web/restdoc/pkg/pkg.go +++ b/cmd/web/restdoc/pkg/pkg.go @@ -35,7 +35,11 @@ type Packages struct { // 结构体可能存在相互引用的情况,保存每个结构体的数据,键名为 [Struct.String]。 structs map[string]*Struct - structsM sync.RWMutex + structsM sync.Mutex + + // Packages.TypeOf 会中途加载包文件,比较耗时, + // 防止 Packages.TypeOf 在执行到一半时又调用此方法加载相同的类型。 + typeOfM sync.Mutex } // 根据 st 生成一个空的 [Struct] 或是在已经存在的情况下返回该实例 @@ -46,10 +50,10 @@ func (pkgs *Packages) getStruct(st *ast.StructType, tps *types.TypeParamList, tl id = id + "[" + ss + "]" } - pkgs.structsM.RLock() - s = pkgs.structs[id] - pkgs.structsM.RUnlock() - if s != nil { + pkgs.structsM.Lock() + defer pkgs.structsM.Unlock() + + if s, f := pkgs.structs[id]; f { return s, false } @@ -60,10 +64,7 @@ func (pkgs *Packages) getStruct(st *ast.StructType, tps *types.TypeParamList, tl docs: make([]*ast.CommentGroup, 0, size), tags: make([]string, 0, size), } - - pkgs.structsM.Lock() pkgs.structs[id] = s - pkgs.structsM.Unlock() return s, true } diff --git a/cmd/web/restdoc/pkg/testdir/testdir2/testdir2.go b/cmd/web/restdoc/pkg/testdir/testdir2/testdir2.go index d74533eb..d92ed640 100644 --- a/cmd/web/restdoc/pkg/testdir/testdir2/testdir2.go +++ b/cmd/web/restdoc/pkg/testdir/testdir2/testdir2.go @@ -36,6 +36,7 @@ type GS[T0 any, T1 any, T2 any] struct { F3 T0 F4 T2 F5 S2 // 引用类型的字段 + F6 pkg.Int } type GSNumber = GS[int, int, pkg.S] diff --git a/cmd/web/restdoc/pkg/types.go b/cmd/web/restdoc/pkg/types.go index 69018735..4c806028 100644 --- a/cmd/web/restdoc/pkg/types.go +++ b/cmd/web/restdoc/pkg/types.go @@ -200,6 +200,12 @@ func getTypeParamsList(tpl *types.TypeParamList, tl typeList) string { // - {} 表示空值,将返回 nil, true // - map 或是 any 将返回 [types.InterfaceType] func (pkgs *Packages) TypeOf(ctx context.Context, path string) (types.Type, error) { + pkgs.typeOfM.Lock() + defer pkgs.typeOfM.Unlock() + return pkgs.typeOf(ctx, path) +} + +func (pkgs *Packages) typeOf(ctx context.Context, path string) (types.Type, error) { path, fieldTypes, err := pkgs.splitFieldTypes(ctx, path) if err != nil { return nil, err diff --git a/cmd/web/restdoc/pkg/types_test.go b/cmd/web/restdoc/pkg/types_test.go index c3fb6c01..3921c111 100644 --- a/cmd/web/restdoc/pkg/types_test.go +++ b/cmd/web/restdoc/pkg/types_test.go @@ -72,7 +72,7 @@ func TestPackages_TypeOf(t *testing.T) { p := New(l.Logger) p.ScanDir(context.Background(), "./testdir", true) - eq := func(a *assert.Assertion, path string, docs ...string) { + eq := func(a *assert.Assertion, path string, docs ...string) *Struct { a.TB().Helper() typ, err := p.TypeOf(context.Background(), path) a.NotError(err).NotNil(typ) @@ -93,10 +93,13 @@ func TestPackages_TypeOf(t *testing.T) { Equal(st.Field(1).Name(), "F1"). Equal(st.FieldDoc(3).Text(), "F2 Doc\n"). Equal(st.Field(3).Name(), "F2") + + return st } t.Run("pkg.S", func(_ *testing.T) { - eq(a, "github.com/issue9/web/restdoc/pkg.S", "S Doc\n") + s := eq(a, "github.com/issue9/web/restdoc/pkg.S", "S Doc\n") + a.Equal(s.NumFields(), 5) }) t.Run("pkg.S2", func(_ *testing.T) { @@ -281,7 +284,11 @@ func TestPackages_TypeOf_generic(t *testing.T) { eqG(a, "github.com/issue9/web/restdoc/pkg.GInt", "", "G Doc\n") }) - eqGS := func(a *assert.Assertion, path string, docs ...string) { + l = loggertest.New(a) + p = New(l.Logger) + p.ScanDir(context.Background(), "./testdir", true) + + eqGS := func(a *assert.Assertion, path string, docs ...string) *Struct { a.TB().Helper() typ, err := p.TypeOf(context.Background(), path) a.NotError(err).NotNil(typ) @@ -301,14 +308,18 @@ func TestPackages_TypeOf_generic(t *testing.T) { Equal(st.Field(2).Name(), "F4"). Equal(st.FieldDoc(3).Text(), "引用类型的字段\n"). Equal(st.Field(3).Name(), "F5") + + return st } t.Run("pkg.GSNumber", func(_ *testing.T) { - eqGS(a, "github.com/issue9/web/restdoc/pkg.GSNumber", "GSNumber Doc\n", "GS Doc\n") + s := eqGS(a, "github.com/issue9/web/restdoc/pkg.GSNumber", "GSNumber Doc\n", "GS Doc\n") + a.Equal(s.NumFields(), 6) }) t.Run("pkg/testdir2.GSNumber", func(_ *testing.T) { - eqGS(a, "github.com/issue9/web/restdoc/pkg/testdir2.GSNumber", "", "GS Doc\n") + s := eqGS(a, "github.com/issue9/web/restdoc/pkg/testdir2.GSNumber", "", "GS Doc\n") + a.Equal(s.NumFields(), 5) }) } diff --git a/cmd/web/restdoc/schema/schema.go b/cmd/web/restdoc/schema/schema.go index c0852170..2fa3f34f 100644 --- a/cmd/web/restdoc/schema/schema.go +++ b/cmd/web/restdoc/schema/schema.go @@ -38,14 +38,14 @@ var refReplacer = strings.NewReplacer( type Schema struct { pkg *pkg.Packages - structs map[string]string + structs map[string]*openapi3.SchemaRef structsM sync.Mutex } func New(l *logger.Logger) *Schema { return &Schema{ pkg: pkg.New(l), - structs: make(map[string]string, 10), + structs: make(map[string]*openapi3.SchemaRef, 10), } } @@ -54,12 +54,12 @@ func (s *Schema) getStruct(ref string, t types.Type) *openapi3.SchemaRef { s.structsM.Lock() defer s.structsM.Unlock() - if id, found := s.structs[t.String()]; found { - return openapi.NewSchemaRef(id, nil) + if r, found := s.structs[t.String()]; found { + return r } if ref != "" { - s.structs[t.String()] = ref + s.structs[t.String()] = openapi.NewSchemaRef(ref, nil) } return nil diff --git a/web.go b/web.go index c9b92c2f..839e7c42 100644 --- a/web.go +++ b/web.go @@ -24,7 +24,7 @@ import ( ) // Version 当前框架的版本 -const Version = "0.96.2" +const Version = "0.96.3" type ( Logger = logs.Logger