From 24bb4b1823b13e8e283185b3eb7df05bf9a3b50c Mon Sep 17 00:00:00 2001 From: Achille Date: Wed, 10 May 2023 23:17:08 -0700 Subject: [PATCH] integrate latest wazero experimental updates (#66) Pulls the changes from https://github.com/tetratelabs/wazero/pull/1453 --------- Signed-off-by: Achille Roussel --- cpu.go | 9 ++++--- go.mod | 2 +- go.sum | 4 ++-- mem.go | 62 ++++++++++++++++++++++++++----------------------- sampler.go | 15 ++++++++---- sampler_test.go | 5 ++-- 6 files changed, 55 insertions(+), 42 deletions(-) diff --git a/cpu.go b/cpu.go index 23e3de9..a1912ad 100644 --- a/cpu.go +++ b/cpu.go @@ -190,7 +190,7 @@ func (p *CPUProfiler) NewFunctionListener(def api.FunctionDefinition) experiment type cpuProfiler struct{ *CPUProfiler } -func (p cpuProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) context.Context { +func (p cpuProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ []uint64, si experimental.StackIterator) { var frame cpuTimeFrame p.mutex.Lock() @@ -212,10 +212,9 @@ func (p cpuProfiler) Before(ctx context.Context, mod api.Module, def api.Functio p.mutex.Unlock() p.frames = append(p.frames, frame) - return ctx } -func (p cpuProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, results []uint64) { +func (p cpuProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ []uint64) { i := len(p.frames) - 1 f := p.frames[i] p.frames = p.frames[:i] @@ -229,3 +228,7 @@ func (p cpuProfiler) After(ctx context.Context, mod api.Module, def api.Function p.traces = append(p.traces, f.trace) } } + +func (p cpuProfiler) Abort(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ error) { + p.After(ctx, mod, def, nil) +} diff --git a/go.mod b/go.mod index e8237aa..d765f68 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,6 @@ go 1.20 require ( github.com/google/pprof v0.0.0-20230406165453-00490a63f317 - github.com/tetratelabs/wazero v1.1.1-0.20230509203401-ef3d67119550 + github.com/tetratelabs/wazero v1.1.1-0.20230511035210-78c35acd6e1c golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 ) diff --git a/go.sum b/go.sum index b2f392f..e90953c 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ github.com/google/pprof v0.0.0-20230406165453-00490a63f317 h1:hFhpt7CTmR3DX+b4R19ydQFtofxT0Sv3QsKNMVQYTMQ= github.com/google/pprof v0.0.0-20230406165453-00490a63f317/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= -github.com/tetratelabs/wazero v1.1.1-0.20230509203401-ef3d67119550 h1:t2+P0h4kAojvOntox6816oFO76+N+lLYq2b3xKMzNnI= -github.com/tetratelabs/wazero v1.1.1-0.20230509203401-ef3d67119550/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= +github.com/tetratelabs/wazero v1.1.1-0.20230511035210-78c35acd6e1c h1:3vKw5AyUv+16qm9tIVfvm205aDIpDmrjxPydJ87GBKY= +github.com/tetratelabs/wazero v1.1.1-0.20230511035210-78c35acd6e1c/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= diff --git a/mem.go b/mem.go index 7aa18b1..135a57b 100644 --- a/mem.go +++ b/mem.go @@ -221,16 +221,16 @@ type mallocProfiler struct { stack stackTrace } -func (p *mallocProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) context.Context { +func (p *mallocProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) { p.size = api.DecodeU32(params[0]) p.stack = makeStackTrace(p.stack, si) - return ctx } -func (p *mallocProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, results []uint64) { - if err == nil { - p.memory.observeAlloc(api.DecodeU32(results[0]), p.size, p.stack) - } +func (p *mallocProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, results []uint64) { + p.memory.observeAlloc(api.DecodeU32(results[0]), p.size, p.stack) +} + +func (p *mallocProfiler) Abort(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ error) { } type callocProfiler struct { @@ -240,17 +240,17 @@ type callocProfiler struct { stack stackTrace } -func (p *callocProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) context.Context { +func (p *callocProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) { p.count = api.DecodeU32(params[0]) p.size = api.DecodeU32(params[1]) p.stack = makeStackTrace(p.stack, si) - return ctx } -func (p *callocProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, results []uint64) { - if err == nil { - p.memory.observeAlloc(api.DecodeU32(results[0]), p.count*p.size, p.stack) - } +func (p *callocProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, results []uint64) { + p.memory.observeAlloc(api.DecodeU32(results[0]), p.count*p.size, p.stack) +} + +func (p *callocProfiler) Abort(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ error) { } type reallocProfiler struct { @@ -260,18 +260,18 @@ type reallocProfiler struct { stack stackTrace } -func (p *reallocProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) context.Context { +func (p *reallocProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) { p.addr = api.DecodeU32(params[0]) p.size = api.DecodeU32(params[1]) p.stack = makeStackTrace(p.stack, si) - return ctx } -func (p *reallocProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, results []uint64) { - if err == nil { - p.memory.observeFree(p.addr) - p.memory.observeAlloc(api.DecodeU32(results[0]), p.size, p.stack) - } +func (p *reallocProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, results []uint64) { + p.memory.observeFree(p.addr) + p.memory.observeAlloc(api.DecodeU32(results[0]), p.size, p.stack) +} + +func (p *reallocProfiler) Abort(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ error) { } type freeProfiler struct { @@ -279,15 +279,16 @@ type freeProfiler struct { addr uint32 } -func (p *freeProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) context.Context { +func (p *freeProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) { p.addr = api.DecodeU32(params[0]) - return ctx } -func (p *freeProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, results []uint64) { - if err == nil { - p.memory.observeFree(p.addr) - } +func (p *freeProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ []uint64) { + p.memory.observeFree(p.addr) +} + +func (p *freeProfiler) Abort(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ error) { + p.After(ctx, mod, def, nil) } type goRuntimeMallocgcProfiler struct { @@ -296,7 +297,7 @@ type goRuntimeMallocgcProfiler struct { stack stackTrace } -func (p *goRuntimeMallocgcProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) context.Context { +func (p *goRuntimeMallocgcProfiler) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, si experimental.StackIterator) { imod := mod.(experimental.InternalModule) mem := imod.Memory() @@ -309,13 +310,16 @@ func (p *goRuntimeMallocgcProfiler) Before(ctx context.Context, mod api.Module, } else { p.size = 0 } - return ctx } -func (p *goRuntimeMallocgcProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, results []uint64) { - if err == nil && p.size != 0 { +func (p *goRuntimeMallocgcProfiler) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ []uint64) { + if p.size != 0 { // TODO: get the returned pointer addr := uint32(0) p.memory.observeAlloc(addr, p.size, p.stack) } } + +func (p *goRuntimeMallocgcProfiler) Abort(ctx context.Context, mod api.Module, def api.FunctionDefinition, _ error) { + p.After(ctx, mod, def, nil) +} diff --git a/sampler.go b/sampler.go index cc8c206..24af81d 100644 --- a/sampler.go +++ b/sampler.go @@ -55,22 +55,27 @@ type sampledFunctionListener struct { lstn experimental.FunctionListener } -func (s *sampledFunctionListener) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, stack experimental.StackIterator) context.Context { +func (s *sampledFunctionListener) Before(ctx context.Context, mod api.Module, def api.FunctionDefinition, params []uint64, stack experimental.StackIterator) { bit := uint(0) if s.count--; s.count == 0 { s.count = s.cycle + s.lstn.Before(ctx, mod, def, params, stack) bit = 1 - ctx = s.lstn.Before(ctx, mod, def, params, stack) } s.stack.push(bit) - return ctx } -func (s *sampledFunctionListener) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error, results []uint64) { +func (s *sampledFunctionListener) After(ctx context.Context, mod api.Module, def api.FunctionDefinition, results []uint64) { if s.stack.pop() != 0 { - s.lstn.After(ctx, mod, def, err, results) + s.lstn.After(ctx, mod, def, results) + } +} + +func (s *sampledFunctionListener) Abort(ctx context.Context, mod api.Module, def api.FunctionDefinition, err error) { + if s.stack.pop() != 0 { + s.lstn.Abort(ctx, mod, def, err) } } diff --git a/sampler_test.go b/sampler_test.go index fd1cbfc..342420e 100644 --- a/sampler_test.go +++ b/sampler_test.go @@ -25,10 +25,11 @@ func TestSampledFunctionListener(t *testing.T) { function := module.Function(0).Definition() listener := factory.NewFunctionListener(function) + ctx := context.Background() for i := 0; i < 20; i++ { - ctx := listener.Before(context.Background(), module, function, nil, nil) - listener.After(ctx, module, function, nil, nil) + listener.Before(ctx, module, function, nil, nil) + listener.After(ctx, module, function, nil) } if n != 2 {