From 1f8e1bc0c013bb33e337723a1553846a11f65ee3 Mon Sep 17 00:00:00 2001 From: caixw Date: Mon, 26 Feb 2024 21:48:08 +0800 Subject: [PATCH] =?UTF-8?q?perf(cmd/web/restdoc):=20=E6=8C=89=E7=9B=AE?= =?UTF-8?q?=E5=BD=95=E7=BC=93=E5=AD=98=E5=8A=A0=E8=BD=BD=E7=9A=84=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/web/restdoc/pkg/bench_test.go | 36 ++++++++++++++++++++++++++++++ cmd/web/restdoc/pkg/path.go | 7 +++--- cmd/web/restdoc/pkg/pkg.go | 37 ++++++++++++++++++------------- cmd/web/restdoc/pkg/pkg_test.go | 3 ++- 4 files changed, 64 insertions(+), 19 deletions(-) create mode 100644 cmd/web/restdoc/pkg/bench_test.go diff --git a/cmd/web/restdoc/pkg/bench_test.go b/cmd/web/restdoc/pkg/bench_test.go new file mode 100644 index 00000000..07c1f61e --- /dev/null +++ b/cmd/web/restdoc/pkg/bench_test.go @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT + +package pkg + +import ( + "context" + "testing" + + "github.com/issue9/assert/v3" + + "github.com/issue9/web/cmd/web/restdoc/logger/loggertest" +) + +func BenchmarkPackages_TypeOf_Int(b *testing.B) { + a := assert.New(b, false) + l := loggertest.New(a) + p := New(l.Logger) + + p.ScanDir(context.Background(), "./testdir", true) + ctx := context.Background() + for range b.N { + p.TypeOf(ctx, "github.com/issue9/web/restdoc/pkg.Int") + } +} + +func BenchmarkPackages_TypeOf_S(b *testing.B) { + a := assert.New(b, false) + l := loggertest.New(a) + p := New(l.Logger) + + p.ScanDir(context.Background(), "./testdir", true) + ctx := context.Background() + for range b.N { + p.TypeOf(ctx, "github.com/issue9/web/restdoc/pkg.S") + } +} diff --git a/cmd/web/restdoc/pkg/path.go b/cmd/web/restdoc/pkg/path.go index 965e06c1..13a7ee8b 100644 --- a/cmd/web/restdoc/pkg/path.go +++ b/cmd/web/restdoc/pkg/path.go @@ -140,19 +140,20 @@ func (pkgs *Packages) lookup(ctx context.Context, typePath string) (types.Object // 出于性能考虑并未加载依赖项,但是可能会依赖部分标准库的类型, // 此处对标准库作了特殊处理:未找到标准库中的对象时会加载相应的包。 if pkgPath != "" && strings.IndexByte(pkgPath, '.') < 0 { - ps, err := pkgs.load(ctx, path.Join(build.Default.GOROOT, "src", pkgPath)) + dir := path.Join(build.Default.GOROOT, "src", pkgPath) + p, err := pkgs.load(ctx, dir) if err != nil { pkgs.l.Error(err, "", 0) return nil, nil, nil, false } - return findInPkgs(ps, pkgPath, typeName) + return findInPkgs(map[string]*packages.Package{dir: p}, pkgPath, typeName) } return nil, nil, nil, false } -func findInPkgs(ps []*packages.Package, pkgPath, typeName string) (types.Object, *ast.TypeSpec, *ast.File, bool) { +func findInPkgs(ps map[string]*packages.Package, pkgPath, typeName string) (types.Object, *ast.TypeSpec, *ast.File, bool) { for _, p := range ps { if p.PkgPath != pkgPath { continue diff --git a/cmd/web/restdoc/pkg/pkg.go b/cmd/web/restdoc/pkg/pkg.go index 03446efd..34f40db7 100644 --- a/cmd/web/restdoc/pkg/pkg.go +++ b/cmd/web/restdoc/pkg/pkg.go @@ -5,12 +5,11 @@ package pkg import ( "context" + "fmt" "go/token" "path/filepath" - "slices" "sync" - "github.com/issue9/sliceutil" "github.com/issue9/web" "golang.org/x/tools/go/packages" @@ -25,14 +24,14 @@ const Cancelled = web.StringPhrase("cancelled") // Packages 管理加载的包 type Packages struct { pkgsM sync.Mutex - pkgs []*packages.Package + pkgs map[string]*packages.Package // 键名为对应的目录名 fset *token.FileSet l *logger.Logger } func New(l *logger.Logger) *Packages { return &Packages{ - pkgs: make([]*packages.Package, 0, 10), + pkgs: make(map[string]*packages.Package, 30), fset: token.NewFileSet(), l: l, } @@ -73,7 +72,15 @@ func (pkgs *Packages) ScanDir(ctx context.Context, root string, recursive bool) wg.Wait() } -func (pkgs *Packages) load(ctx context.Context, dir string) ([]*packages.Package, error) { +func (pkgs *Packages) load(ctx context.Context, dir string) (*packages.Package, error) { + dir = filepath.Clean(dir) + + pkgs.pkgsM.Lock() + defer pkgs.pkgsM.Unlock() + if p, found := pkgs.pkgs[dir]; found { + return p, nil + } + ps, err := packages.Load(&packages.Config{ Mode: mode, Context: ctx, @@ -84,23 +91,23 @@ func (pkgs *Packages) load(ctx context.Context, dir string) ([]*packages.Package return nil, err } - pkgs.pkgsM.Lock() - defer pkgs.pkgsM.Unlock() - for _, p := range ps { - if slices.IndexFunc(pkgs.pkgs, func(e *packages.Package) bool { return p.PkgPath == e.PkgPath }) < 0 { - pkgs.pkgs = append(pkgs.pkgs, p) - } + if len(ps) > 1 { + panic(fmt.Sprintf("目录 %s 中包的数量大于 1:%d", dir, len(ps))) } - - return ps, nil + pkgs.pkgs[dir] = ps[0] + return ps[0], nil } func (pkgs *Packages) FileSet() *token.FileSet { return pkgs.fset } // Package 返回指定路径的包对象 func (pkgs *Packages) Package(path string) *packages.Package { - pkg, _ := sliceutil.At(pkgs.pkgs, func(p *packages.Package, _ int) bool { return p.PkgPath == path }) - return pkg + for _, p := range pkgs.pkgs { + if p.PkgPath == path { + return p + } + } + return nil } // Range 依次访问已经加载的包 diff --git a/cmd/web/restdoc/pkg/pkg_test.go b/cmd/web/restdoc/pkg/pkg_test.go index def718ed..fdb5b845 100644 --- a/cmd/web/restdoc/pkg/pkg_test.go +++ b/cmd/web/restdoc/pkg/pkg_test.go @@ -4,6 +4,7 @@ package pkg import ( "context" + "path/filepath" "testing" "github.com/issue9/assert/v3" @@ -23,5 +24,5 @@ func TestPackages_ScanDir(t *testing.T) { p = New(l.Logger) p.ScanDir(context.Background(), "./testdir", false) a.Length(p.pkgs, 1). - Equal(p.pkgs[0].PkgPath, "github.com/issue9/web/restdoc/pkg") + Equal(p.pkgs[filepath.Clean("./testdir")].PkgPath, "github.com/issue9/web/restdoc/pkg") }