Skip to content

Commit

Permalink
fix(cmd/web/restdoc): 限制 pkg.Packages.TypeOf 的调用
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed May 31, 2024
1 parent c65ca0a commit 14e9780
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 25 deletions.
8 changes: 4 additions & 4 deletions cmd/web/restdoc/pkg/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,26 @@ import (
"github.com/issue9/web/cmd/web/restdoc/logger/loggertest"
)

func BenchmarkPackages_TypeOf_Int(b *testing.B) {
func BenchmarkPackages_typeOf_Int(b *testing.B) {
a := assert.New(b, false)
l := loggertest.New(a)
p := New(l.Logger)

p.ScanDir(context.Background(), "./testdir", true)
ctx := context.Background()
for range b.N {
p.TypeOf(ctx, "github.com/issue9/web/restdoc/pkg.Int")
p.typeOf(ctx, "github.com/issue9/web/restdoc/pkg.Int")
}
}

func BenchmarkPackages_TypeOf_S(b *testing.B) {
func BenchmarkPackages_typeOf_S(b *testing.B) {
a := assert.New(b, false)
l := loggertest.New(a)
p := New(l.Logger)

p.ScanDir(context.Background(), "./testdir", true)
ctx := context.Background()
for range b.N {
p.TypeOf(ctx, "github.com/issue9/web/restdoc/pkg.S")
p.typeOf(ctx, "github.com/issue9/web/restdoc/pkg.S")
}
}
4 changes: 2 additions & 2 deletions cmd/web/restdoc/pkg/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (pkgs *Packages) splitFieldTypes(ctx context.Context, path string) (p strin
panic(fmt.Sprintf("无效的语法 %s", path))
}

t, err := pkgs.TypeOf(ctx, ps[1])
t, err := pkgs.typeOf(ctx, ps[1])
if err != nil {
return err
}
Expand Down Expand Up @@ -148,7 +148,7 @@ func (pkgs *Packages) splitTypeParams(ctx context.Context, path string) (p strin
if len(tps) > 0 {
ts := make([]types.Type, 0, len(tps))
for _, p := range tps {
t, err := pkgs.TypeOf(ctx, p)
t, err := pkgs.typeOf(ctx, p)
if err != nil {
return "", nil, err
}
Expand Down
1 change: 1 addition & 0 deletions cmd/web/restdoc/pkg/path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"

"github.com/issue9/assert/v4"

"github.com/issue9/web/cmd/web/restdoc/logger/loggertest"
)

Expand Down
17 changes: 9 additions & 8 deletions cmd/web/restdoc/pkg/pkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ type Packages struct {

// 结构体可能存在相互引用的情况,保存每个结构体的数据,键名为 [Struct.String]。
structs map[string]*Struct
structsM sync.RWMutex
structsM sync.Mutex

// Packages.TypeOf 会中途加载包文件,比较耗时,
// 防止 Packages.TypeOf 在执行到一半时又调用此方法加载相同的类型。
typeOfM sync.Mutex
}

// 根据 st 生成一个空的 [Struct] 或是在已经存在的情况下返回该实例
Expand All @@ -46,10 +50,10 @@ func (pkgs *Packages) getStruct(st *ast.StructType, tps *types.TypeParamList, tl
id = id + "[" + ss + "]"
}

pkgs.structsM.RLock()
s = pkgs.structs[id]
pkgs.structsM.RUnlock()
if s != nil {
pkgs.structsM.Lock()
defer pkgs.structsM.Unlock()

if s, f := pkgs.structs[id]; f {
return s, false
}

Expand All @@ -60,10 +64,7 @@ func (pkgs *Packages) getStruct(st *ast.StructType, tps *types.TypeParamList, tl
docs: make([]*ast.CommentGroup, 0, size),
tags: make([]string, 0, size),
}

pkgs.structsM.Lock()
pkgs.structs[id] = s
pkgs.structsM.Unlock()

return s, true
}
Expand Down
1 change: 1 addition & 0 deletions cmd/web/restdoc/pkg/testdir/testdir2/testdir2.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type GS[T0 any, T1 any, T2 any] struct {
F3 T0
F4 T2
F5 S2 // 引用类型的字段
F6 pkg.Int
}

type GSNumber = GS[int, int, pkg.S]
6 changes: 6 additions & 0 deletions cmd/web/restdoc/pkg/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,12 @@ func getTypeParamsList(tpl *types.TypeParamList, tl typeList) string {
// - {} 表示空值,将返回 nil, true
// - map 或是 any 将返回 [types.InterfaceType]
func (pkgs *Packages) TypeOf(ctx context.Context, path string) (types.Type, error) {
pkgs.typeOfM.Lock()
defer pkgs.typeOfM.Unlock()
return pkgs.typeOf(ctx, path)
}

func (pkgs *Packages) typeOf(ctx context.Context, path string) (types.Type, error) {
path, fieldTypes, err := pkgs.splitFieldTypes(ctx, path)
if err != nil {
return nil, err
Expand Down
21 changes: 16 additions & 5 deletions cmd/web/restdoc/pkg/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestPackages_TypeOf(t *testing.T) {
p := New(l.Logger)
p.ScanDir(context.Background(), "./testdir", true)

eq := func(a *assert.Assertion, path string, docs ...string) {
eq := func(a *assert.Assertion, path string, docs ...string) *Struct {
a.TB().Helper()
typ, err := p.TypeOf(context.Background(), path)
a.NotError(err).NotNil(typ)
Expand All @@ -93,10 +93,13 @@ func TestPackages_TypeOf(t *testing.T) {
Equal(st.Field(1).Name(), "F1").
Equal(st.FieldDoc(3).Text(), "F2 Doc\n").
Equal(st.Field(3).Name(), "F2")

return st
}

t.Run("pkg.S", func(_ *testing.T) {
eq(a, "github.com/issue9/web/restdoc/pkg.S", "S Doc\n")
s := eq(a, "github.com/issue9/web/restdoc/pkg.S", "S Doc\n")
a.Equal(s.NumFields(), 5)
})

t.Run("pkg.S2", func(_ *testing.T) {
Expand Down Expand Up @@ -281,7 +284,11 @@ func TestPackages_TypeOf_generic(t *testing.T) {
eqG(a, "github.com/issue9/web/restdoc/pkg.GInt", "", "G Doc\n")
})

eqGS := func(a *assert.Assertion, path string, docs ...string) {
l = loggertest.New(a)
p = New(l.Logger)
p.ScanDir(context.Background(), "./testdir", true)

eqGS := func(a *assert.Assertion, path string, docs ...string) *Struct {
a.TB().Helper()
typ, err := p.TypeOf(context.Background(), path)
a.NotError(err).NotNil(typ)
Expand All @@ -301,14 +308,18 @@ func TestPackages_TypeOf_generic(t *testing.T) {
Equal(st.Field(2).Name(), "F4").
Equal(st.FieldDoc(3).Text(), "引用类型的字段\n").
Equal(st.Field(3).Name(), "F5")

return st
}

t.Run("pkg.GSNumber", func(_ *testing.T) {
eqGS(a, "github.com/issue9/web/restdoc/pkg.GSNumber", "GSNumber Doc\n", "GS Doc\n")
s := eqGS(a, "github.com/issue9/web/restdoc/pkg.GSNumber", "GSNumber Doc\n", "GS Doc\n")
a.Equal(s.NumFields(), 6)
})

t.Run("pkg/testdir2.GSNumber", func(_ *testing.T) {
eqGS(a, "github.com/issue9/web/restdoc/pkg/testdir2.GSNumber", "", "GS Doc\n")
s := eqGS(a, "github.com/issue9/web/restdoc/pkg/testdir2.GSNumber", "", "GS Doc\n")
a.Equal(s.NumFields(), 5)
})
}

Expand Down
10 changes: 5 additions & 5 deletions cmd/web/restdoc/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ var refReplacer = strings.NewReplacer(
type Schema struct {
pkg *pkg.Packages

structs map[string]string
structs map[string]*openapi3.SchemaRef
structsM sync.Mutex
}

func New(l *logger.Logger) *Schema {
return &Schema{
pkg: pkg.New(l),
structs: make(map[string]string, 10),
structs: make(map[string]*openapi3.SchemaRef, 10),
}
}

Expand All @@ -54,12 +54,12 @@ func (s *Schema) getStruct(ref string, t types.Type) *openapi3.SchemaRef {
s.structsM.Lock()
defer s.structsM.Unlock()

if id, found := s.structs[t.String()]; found {
return openapi.NewSchemaRef(id, nil)
if r, found := s.structs[t.String()]; found {
return r
}

if ref != "" {
s.structs[t.String()] = ref
s.structs[t.String()] = openapi.NewSchemaRef(ref, nil)
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion web.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

// Version 当前框架的版本
const Version = "0.96.2"
const Version = "0.96.3"

type (
Logger = logs.Logger
Expand Down

0 comments on commit 14e9780

Please sign in to comment.