Skip to content

Commit

Permalink
LatencyMeasurer test refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk committed Mar 31, 2024
1 parent 2971e30 commit 463f29f
Showing 1 changed file with 24 additions and 29 deletions.
53 changes: 24 additions & 29 deletions test/transactron/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from parameterized import parameterized_class

from amaranth import *
from amaranth.sim import Passive, Settle
from amaranth.sim import Settle

from transactron.lib.metrics import *
from transactron import *
Expand Down Expand Up @@ -227,6 +227,21 @@ def test_process():
sim.add_sync_process(test_process)


class TestLatencyMeasurerBase(TestCaseWithSimulator):
def check_latencies(self, m: SimpleTestCircuit, latencies: list[int]):
self.assertEqual(min(latencies), (yield m._dut.histogram.min.value))
self.assertEqual(max(latencies), (yield m._dut.histogram.max.value))
self.assertEqual(sum(latencies), (yield m._dut.histogram.sum.value))
self.assertEqual(len(latencies), (yield m._dut.histogram.count.value))

for i in range(m._dut.histogram.bucket_count):
bucket_start = 0 if i == 0 else 2 ** (i - 1)
bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i

count = sum(1 for x in latencies if bucket_start <= x < bucket_end)
self.assertEqual(count, (yield m._dut.histogram.buckets[i].value))


@parameterized_class(
("slots_number", "expected_consumer_wait"),
[
Expand All @@ -238,7 +253,7 @@ def test_process():
(5, 5),
],
)
class TestFIFOLatencyMeasurer(TestCaseWithSimulator):
class TestFIFOLatencyMeasurer(TestLatencyMeasurerBase):
slots_number: int
expected_consumer_wait: float

Expand All @@ -262,7 +277,7 @@ def producer():

# Make sure that the time is updated first.
yield Settle()
time = (yield Now())
time = yield Now()
event_queue.put(time)
yield from self.random_wait_geom(0.8)

Expand All @@ -274,22 +289,12 @@ def consumer():

# Make sure that the time is updated first.
yield Settle()
time = (yield Now())
time = yield Now()
latencies.append(time - event_queue.get())

yield from self.random_wait_geom(1.0 / self.expected_consumer_wait)

self.assertEqual(min(latencies), (yield m._dut.histogram.min.value))
self.assertEqual(max(latencies), (yield m._dut.histogram.max.value))
self.assertEqual(sum(latencies), (yield m._dut.histogram.sum.value))
self.assertEqual(len(latencies), (yield m._dut.histogram.count.value))

for i in range(m._dut.histogram.bucket_count):
bucket_start = 0 if i == 0 else 2 ** (i - 1)
bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i

count = sum(1 for x in latencies if bucket_start <= x < bucket_end)
self.assertEqual(count, (yield m._dut.histogram.buckets[i].value))
self.check_latencies(m, latencies)

with self.run_simulation(m) as sim:
sim.add_sync_process(producer)
Expand All @@ -307,7 +312,7 @@ def consumer():
(5, 5),
],
)
class TestIndexedLatencyMeasurer(TestCaseWithSimulator):
class TestIndexedLatencyMeasurer(TestLatencyMeasurerBase):
slots_number: int
expected_consumer_wait: float

Expand Down Expand Up @@ -337,7 +342,7 @@ def producer():
slot_id = random.choice(free_slots)
yield from m._start.call(slot=slot_id)

time = (yield Now())
time = yield Now()

events[slot_id] = time
free_slots.remove(slot_id)
Expand All @@ -357,7 +362,7 @@ def consumer():

yield from m._stop.call(slot=slot_id)

time = (yield Now())
time = yield Now()

yield Settle()
yield Settle()
Expand All @@ -368,17 +373,7 @@ def consumer():

yield from self.random_wait_geom(1.0 / self.expected_consumer_wait)

self.assertEqual(min(latencies), (yield m._dut.histogram.min.value))
self.assertEqual(max(latencies), (yield m._dut.histogram.max.value))
self.assertEqual(sum(latencies), (yield m._dut.histogram.sum.value))
self.assertEqual(len(latencies), (yield m._dut.histogram.count.value))

for i in range(m._dut.histogram.bucket_count):
bucket_start = 0 if i == 0 else 2 ** (i - 1)
bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i

count = sum(1 for x in latencies if bucket_start <= x < bucket_end)
self.assertEqual(count, (yield m._dut.histogram.buckets[i].value))
self.check_latencies(m, latencies)

with self.run_simulation(m) as sim:
sim.add_sync_process(producer)
Expand Down

0 comments on commit 463f29f

Please sign in to comment.