diff --git a/cpu.go b/cpu.go index a1912ad..195cecd 100644 --- a/cpu.go +++ b/cpu.go @@ -50,6 +50,7 @@ func TimeFunc(time func() int64) CPUProfilerOption { type cpuTimeFrame struct { start int64 + sub int64 trace stackTrace } @@ -99,11 +100,6 @@ func (p *CPUProfiler) StopProfile(sampleRate float64, symbols Symbolizer) *profi for k, sample := range samples { if sample.stack.host() { delete(samples, k) - for _, other := range samples { - if sample.stack.contains(other.stack) { - other.subtract(sample.total()) - } - } } } } @@ -220,9 +216,14 @@ func (p cpuProfiler) After(ctx context.Context, mod api.Module, def api.Function p.frames = p.frames[:i] if f.start != 0 { + duration := p.time() - f.start + if i := len(p.frames); i > 0 { + p.frames[i-1].sub += duration + } + duration -= f.sub p.mutex.Lock() if p.counts != nil { - p.counts.observe(f.trace, p.time()-f.start) + p.counts.observe(f.trace, duration) } p.mutex.Unlock() p.traces = append(p.traces, f.trace) diff --git a/cpu_test.go b/cpu_test.go index 684a9cb..5c2c35d 100644 --- a/cpu_test.go +++ b/cpu_test.go @@ -1,7 +1,12 @@ package wzprof import ( + "context" "testing" + + "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental" + "github.com/tetratelabs/wazero/experimental/wazerotest" ) func BenchmarkCPUProfilerOn(b *testing.B) { @@ -14,3 +19,100 @@ func BenchmarkCPUProfilerOff(b *testing.B) { p := NewCPUProfiler() benchmarkFunctionListener(b, p) } + +func TestCPUProfilerTime(t *testing.T) { + currentTime := int64(0) + + p := NewCPUProfiler( + TimeFunc(func() int64 { return currentTime }), + ) + + module := wazerotest.NewModule(nil, + wazerotest.NewFunction(func(context.Context, api.Module) {}), + wazerotest.NewFunction(func(context.Context, api.Module) {}), + wazerotest.NewFunction(func(context.Context, api.Module) {}), + ) + + f0 := p.NewFunctionListener(module.Function(0).Definition()) + f1 := p.NewFunctionListener(module.Function(1).Definition()) + f2 := p.NewFunctionListener(module.Function(2).Definition()) + + stack0 := []experimental.StackFrame{ + {Function: module.Function(0)}, + } + + stack1 := []experimental.StackFrame{ + {Function: module.Function(0)}, + {Function: module.Function(1)}, + } + + stack2 := []experimental.StackFrame{ + {Function: module.Function(0)}, + {Function: module.Function(1)}, + {Function: module.Function(2)}, + } + + def0 := stack0[0].Function.Definition() + def1 := stack1[1].Function.Definition() + def2 := stack2[2].Function.Definition() + + ctx := context.Background() + + const ( + t0 int64 = 1 + t1 int64 = 10 + t2 int64 = 42 + t3 int64 = 100 + t4 int64 = 101 + t5 int64 = 102 + ) + + p.StartProfile() + + currentTime = t0 + f0.Before(ctx, module, def0, nil, experimental.NewStackIterator(stack0...)) + + currentTime = t1 + f1.Before(ctx, module, def1, nil, experimental.NewStackIterator(stack1...)) + + currentTime = t2 + f2.Before(ctx, module, def2, nil, experimental.NewStackIterator(stack2...)) + + currentTime = t3 + f2.After(ctx, module, def2, nil) + + currentTime = t4 + f1.After(ctx, module, def1, nil) + + currentTime = t5 + f0.After(ctx, module, def0, nil) + + trace0 := makeStackTraceFromFrames(stack0) + trace1 := makeStackTraceFromFrames(stack1) + trace2 := makeStackTraceFromFrames(stack2) + + d2 := t3 - t2 + d1 := t4 - (t1 + d2) + d0 := t5 - (t0 + d1 + d2) + + assertStackCount(t, p.counts, trace0, 1, d0) + assertStackCount(t, p.counts, trace1, 1, d1) + assertStackCount(t, p.counts, trace2, 1, d2) +} + +func assertStackCount(t *testing.T, counts stackCounterMap, trace stackTrace, count, total int64) { + t.Helper() + c := counts.lookup(trace) + + if c.count() != count { + t.Errorf("%sstack count mismatch: want=%d got=%d", trace, count, c.count()) + } + + if c.total() != total { + t.Errorf("%sstack total mismatch: want=%d got=%d", trace, total, c.total()) + } +} + +func makeStackTraceFromFrames(stackFrames []experimental.StackFrame) stackTrace { + return makeStackTrace(stackTrace{}, experimental.NewStackIterator(stackFrames...)) +} diff --git a/wzprof.go b/wzprof.go index 71c4cc9..49a0a9a 100644 --- a/wzprof.go +++ b/wzprof.go @@ -1,7 +1,6 @@ package wzprof import ( - "bytes" "fmt" "hash/maphash" "net/http" @@ -186,14 +185,6 @@ func (sc *stackCounter) total() int64 { return sc.value[1] } -func (sc *stackCounter) subtract(value int64) { - if total := sc.total(); total < value { - sc.value[1] = 0 - } else { - sc.value[1] -= value - } -} - func (sc *stackCounter) sampleLocation() stackTrace { return sc.stack } @@ -202,6 +193,10 @@ func (sc *stackCounter) sampleValue() []int64 { return sc.value[:] } +func (sc *stackCounter) String() string { + return fmt.Sprintf("{count:%d,total:%d}", sc.count(), sc.total()) +} + // Compile-time check that program counters are uint64 values. var _ = assertTypeIsUint64[experimental.ProgramCounter]() @@ -235,21 +230,6 @@ func (st stackTrace) host() bool { return len(st.fns) > 0 && st.fns[0].Definition().GoFunction() != nil } -func (st stackTrace) contains(sx stackTrace) bool { - if len(st.fns) < len(sx.fns) { - return false - } - n := len(st.fns) - len(sx.fns) - st.fns = st.fns[n:] - st.pcs = st.pcs[n:] - if st.fns[0] != sx.fns[0] { - return false - } - st.pcs = st.pcs[1:] - sx.pcs = sx.pcs[1:] - return bytes.Equal(st.bytes(), sx.bytes()) -} - func (st stackTrace) len() int { return len(st.pcs) }