diff --git a/test/test_metrics.py b/test/test_metrics.py index 12acdfd..7a91616 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -1,6 +1,9 @@ import json import random import queue +from typing import Type +from enum import IntFlag, IntEnum, auto, Enum + from parameterized import parameterized_class from amaranth import * @@ -138,6 +141,85 @@ def test_process(): sim.add_sync_process(test_process) +class OneHotEnum(IntFlag): + ADD = auto() + XOR = auto() + OR = auto() + + +class PlainIntEnum(IntEnum): + TEST_1 = auto() + TEST_2 = auto() + TEST_3 = auto() + + +class TaggedCounterCircuit(Elaboratable): + def __init__(self, tags: range | Type[Enum] | list[int]): + self.counter = TaggedCounter("counter", "", tags=tags) + + self.cond = Signal() + self.tag = Signal(self.counter.tag_width) + + def elaborate(self, platform): + m = TModule() + + m.submodules.counter = self.counter + + with Transaction().body(m): + self.counter.incr(m, self.tag, cond=self.cond) + + return m + + +class TestTaggedCounter(TestCaseWithSimulator): + def setUp(self) -> None: + random.seed(42) + + def do_test_enum(self, tags: range | Type[Enum] | list[int], tag_values: list[int]): + m = TaggedCounterCircuit(tags) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + counts: dict[int, int] = {} + for i in tag_values: + counts[i] = 0 + + def test_process(): + for _ in range(200): + for i in tag_values: + self.assertEqual(counts[i], (yield m.counter.counters[i].value)) + + tag = random.choice(list(tag_values)) + + yield m.cond.eq(1) + yield m.tag.eq(tag) + yield + yield m.cond.eq(0) + yield + + counts[tag] += 1 + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + def test_one_hot_enum(self): + self.do_test_enum(OneHotEnum, [e.value for e in OneHotEnum]) + + def test_plain_int_enum(self): + self.do_test_enum(PlainIntEnum, [e.value for e in PlainIntEnum]) + + def test_negative_range(self): + r = range(-10, 15, 3) + self.do_test_enum(r, list(r)) + + def test_positive_range(self): + r = range(0, 30, 2) + self.do_test_enum(r, list(r)) + + def test_value_list(self): + values = [-2137, 2, 4, 8, 42] + self.do_test_enum(values, values) + + class ExpHistogramCircuit(Elaboratable): def __init__(self, bucket_cnt: int, sample_width: int): self.sample_width = sample_width diff --git a/transactron/lib/metrics.py b/transactron/lib/metrics.py index 2e706e0..f3d5b9e 100644 --- a/transactron/lib/metrics.py +++ b/transactron/lib/metrics.py @@ -1,14 +1,14 @@ from dataclasses import dataclass, field from dataclasses_json import dataclass_json -from typing import Optional +from typing import Optional, Type from abc import ABC +from enum import Enum from amaranth import * -from amaranth.utils import bits_for +from amaranth.utils import bits_for, ceil_log2, exact_log2 -from transactron.utils import ValueLike +from transactron.utils import ValueLike, OneHotSwitchDynamic, SignalBundle from transactron import Method, def_method, TModule -from transactron.utils import SignalBundle from transactron.lib import FIFO from transactron.utils.dependencies import ListKey, DependencyContext, SimpleKey @@ -17,6 +17,7 @@ "MetricModel", "HwMetric", "HwCounter", + "TaggedCounter", "HwExpHistogram", "LatencyMeasurer", "HardwareMetricsManager", @@ -230,6 +231,127 @@ def incr(self, m: TModule, *, cond: ValueLike = C(1)): self._incr(m) +class TaggedCounter(Elaboratable, HwMetric): + """Hardware Tagged Counter + + Like HwCounter, but contains multiple counters, each with its own tag. + At a time a single counter can be increased and the value of the tag + can be provided dynamically. The type of the tag can be either an int + enum, a range or a list of integers (negative numbers are ok). + + Internally, it detects if tag values can be one-hot encoded and if so, + it generates more optimized circuit. + + Attributes + ---------- + tag_width: int + The length of the signal holding a tag value. + one_hot: bool + Whether tag values can be one-hot encoded. + counters: dict[int, HwMetricRegisters] + Mapping from a tag value to a register holding a counter for that tag. + """ + + def __init__( + self, + fully_qualified_name: str, + description: str = "", + *, + tags: range | Type[Enum] | list[int], + registers_width: int = 32, + ): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. + description: str + A human-readable description of the metric's functionality. + tags: range | Type[Enum] | list[int] + Tag values. + registers_width: int + Width of the underlying registers. Defaults to 32 bits. + """ + + super().__init__(fully_qualified_name, description) + + if isinstance(tags, range) or isinstance(tags, list): + counters_meta = [(i, f"{i}") for i in tags] + else: + counters_meta = [(i.value, i.name) for i in tags] + + values = [value for value, _ in counters_meta] + self.tag_width = max(bits_for(max(values)), bits_for(min(values))) + + self.one_hot = True + negative_values = False + for value in values: + if value < 0: + self.one_hot = False + negative_values = True + break + + log = ceil_log2(value) + if 2**log != value: + self.one_hot = False + + self._incr = Method(i=[("tag", Shape(self.tag_width, signed=negative_values))]) + + self.counters: dict[int, HwMetricRegister] = {} + for tag_value, name in counters_meta: + value_str = ("1<<" + str(exact_log2(tag_value))) if self.one_hot else str(tag_value) + description = f"the counter for tag {name} (value={value_str})" + + self.counters[tag_value] = HwMetricRegister( + name, + registers_width, + description, + ) + + self.add_registers(list(self.counters.values())) + + def elaborate(self, platform): + if not self.metrics_enabled(): + return TModule() + + m = TModule() + + @def_method(m, self._incr) + def _(tag): + if self.one_hot: + sorted_tags = sorted(list(self.counters.keys())) + for i in OneHotSwitchDynamic(m, tag): + counter = self.counters[sorted_tags[i]] + m.d.sync += counter.value.eq(counter.value + 1) + else: + for tag_value, counter in self.counters.items(): + with m.If(tag == tag_value): + m.d.sync += counter.value.eq(counter.value + 1) + + return m + + def incr(self, m: TModule, tag: ValueLike, *, cond: ValueLike = C(1)): + """ + Increases the counter of a given tag by 1. + + Should be called in the body of either a transaction or a method. + + Parameters + ---------- + m: TModule + Transactron module + tag: ValueLike + The tag of the counter. + cond: ValueLike + When set to high, the counter will be increased. By default set to high. + """ + if not self.metrics_enabled(): + return + + with m.If(cond): + self._incr(m, tag) + + class HwExpHistogram(Elaboratable, HwMetric): """Hardware Exponential Histogram