Skip to content

Commit

Permalink
Add TaggedCounter (kuznia-rdzeni/coreblocks#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jakub Urbańczyk authored Apr 1, 2024
1 parent 6896604 commit 81c1f50
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 4 deletions.
82 changes: 82 additions & 0 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json
import random
import queue
from typing import Type
from enum import IntFlag, IntEnum, auto, Enum

from parameterized import parameterized_class

from amaranth import *
Expand Down Expand Up @@ -138,6 +141,85 @@ def test_process():
sim.add_sync_process(test_process)


class OneHotEnum(IntFlag):
ADD = auto()
XOR = auto()
OR = auto()


class PlainIntEnum(IntEnum):
TEST_1 = auto()
TEST_2 = auto()
TEST_3 = auto()


class TaggedCounterCircuit(Elaboratable):
def __init__(self, tags: range | Type[Enum] | list[int]):
self.counter = TaggedCounter("counter", "", tags=tags)

self.cond = Signal()
self.tag = Signal(self.counter.tag_width)

def elaborate(self, platform):
m = TModule()

m.submodules.counter = self.counter

with Transaction().body(m):
self.counter.incr(m, self.tag, cond=self.cond)

return m


class TestTaggedCounter(TestCaseWithSimulator):
def setUp(self) -> None:
random.seed(42)

def do_test_enum(self, tags: range | Type[Enum] | list[int], tag_values: list[int]):
m = TaggedCounterCircuit(tags)
DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True)

counts: dict[int, int] = {}
for i in tag_values:
counts[i] = 0

def test_process():
for _ in range(200):
for i in tag_values:
self.assertEqual(counts[i], (yield m.counter.counters[i].value))

tag = random.choice(list(tag_values))

yield m.cond.eq(1)
yield m.tag.eq(tag)
yield
yield m.cond.eq(0)
yield

counts[tag] += 1

with self.run_simulation(m) as sim:
sim.add_sync_process(test_process)

def test_one_hot_enum(self):
self.do_test_enum(OneHotEnum, [e.value for e in OneHotEnum])

def test_plain_int_enum(self):
self.do_test_enum(PlainIntEnum, [e.value for e in PlainIntEnum])

def test_negative_range(self):
r = range(-10, 15, 3)
self.do_test_enum(r, list(r))

def test_positive_range(self):
r = range(0, 30, 2)
self.do_test_enum(r, list(r))

def test_value_list(self):
values = [-2137, 2, 4, 8, 42]
self.do_test_enum(values, values)


class ExpHistogramCircuit(Elaboratable):
def __init__(self, bucket_cnt: int, sample_width: int):
self.sample_width = sample_width
Expand Down
130 changes: 126 additions & 4 deletions transactron/lib/metrics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json
from typing import Optional
from typing import Optional, Type
from abc import ABC
from enum import Enum

from amaranth import *
from amaranth.utils import bits_for
from amaranth.utils import bits_for, ceil_log2, exact_log2

from transactron.utils import ValueLike
from transactron.utils import ValueLike, OneHotSwitchDynamic, SignalBundle
from transactron import Method, def_method, TModule
from transactron.utils import SignalBundle
from transactron.lib import FIFO
from transactron.utils.dependencies import ListKey, DependencyContext, SimpleKey

Expand All @@ -17,6 +17,7 @@
"MetricModel",
"HwMetric",
"HwCounter",
"TaggedCounter",
"HwExpHistogram",
"LatencyMeasurer",
"HardwareMetricsManager",
Expand Down Expand Up @@ -230,6 +231,127 @@ def incr(self, m: TModule, *, cond: ValueLike = C(1)):
self._incr(m)


class TaggedCounter(Elaboratable, HwMetric):
"""Hardware Tagged Counter
Like HwCounter, but contains multiple counters, each with its own tag.
At a time a single counter can be increased and the value of the tag
can be provided dynamically. The type of the tag can be either an int
enum, a range or a list of integers (negative numbers are ok).
Internally, it detects if tag values can be one-hot encoded and if so,
it generates more optimized circuit.
Attributes
----------
tag_width: int
The length of the signal holding a tag value.
one_hot: bool
Whether tag values can be one-hot encoded.
counters: dict[int, HwMetricRegisters]
Mapping from a tag value to a register holding a counter for that tag.
"""

def __init__(
self,
fully_qualified_name: str,
description: str = "",
*,
tags: range | Type[Enum] | list[int],
registers_width: int = 32,
):
"""
Parameters
----------
fully_qualified_name: str
The fully qualified name of the metric.
description: str
A human-readable description of the metric's functionality.
tags: range | Type[Enum] | list[int]
Tag values.
registers_width: int
Width of the underlying registers. Defaults to 32 bits.
"""

super().__init__(fully_qualified_name, description)

if isinstance(tags, range) or isinstance(tags, list):
counters_meta = [(i, f"{i}") for i in tags]
else:
counters_meta = [(i.value, i.name) for i in tags]

values = [value for value, _ in counters_meta]
self.tag_width = max(bits_for(max(values)), bits_for(min(values)))

self.one_hot = True
negative_values = False
for value in values:
if value < 0:
self.one_hot = False
negative_values = True
break

log = ceil_log2(value)
if 2**log != value:
self.one_hot = False

self._incr = Method(i=[("tag", Shape(self.tag_width, signed=negative_values))])

self.counters: dict[int, HwMetricRegister] = {}
for tag_value, name in counters_meta:
value_str = ("1<<" + str(exact_log2(tag_value))) if self.one_hot else str(tag_value)
description = f"the counter for tag {name} (value={value_str})"

self.counters[tag_value] = HwMetricRegister(
name,
registers_width,
description,
)

self.add_registers(list(self.counters.values()))

def elaborate(self, platform):
if not self.metrics_enabled():
return TModule()

m = TModule()

@def_method(m, self._incr)
def _(tag):
if self.one_hot:
sorted_tags = sorted(list(self.counters.keys()))
for i in OneHotSwitchDynamic(m, tag):
counter = self.counters[sorted_tags[i]]
m.d.sync += counter.value.eq(counter.value + 1)
else:
for tag_value, counter in self.counters.items():
with m.If(tag == tag_value):
m.d.sync += counter.value.eq(counter.value + 1)

return m

def incr(self, m: TModule, tag: ValueLike, *, cond: ValueLike = C(1)):
"""
Increases the counter of a given tag by 1.
Should be called in the body of either a transaction or a method.
Parameters
----------
m: TModule
Transactron module
tag: ValueLike
The tag of the counter.
cond: ValueLike
When set to high, the counter will be increased. By default set to high.
"""
if not self.metrics_enabled():
return

with m.If(cond):
self._incr(m, tag)


class HwExpHistogram(Elaboratable, HwMetric):
"""Hardware Exponential Histogram
Expand Down

0 comments on commit 81c1f50

Please sign in to comment.