From 263720cfe7c71896d94c574366a452380d9a4b6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Urba=C5=84czyk?= Date: Mon, 12 Feb 2024 12:08:21 +0000 Subject: [PATCH] Hardware Metrics (#580) --- test/transactron/test_metrics.py | 399 +++++++++++++++++++++ transactron/lib/__init__.py | 1 + transactron/lib/metrics.py | 555 ++++++++++++++++++++++++++++++ transactron/utils/dependencies.py | 10 + 4 files changed, 965 insertions(+) create mode 100644 test/transactron/test_metrics.py create mode 100644 transactron/lib/metrics.py diff --git a/test/transactron/test_metrics.py b/test/transactron/test_metrics.py new file mode 100644 index 000000000..12acdfd27 --- /dev/null +++ b/test/transactron/test_metrics.py @@ -0,0 +1,399 @@ +import json +import random +import queue +from parameterized import parameterized_class + +from amaranth import * +from amaranth.sim import Passive, Settle + +from transactron.lib.metrics import * +from transactron import * +from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit +from transactron.utils.dependencies import DependencyContext + + +class CounterInMethodCircuit(Elaboratable): + def __init__(self): + self.method = Method() + self.counter = HwCounter("in_method") + + def elaborate(self, platform): + m = TModule() + + m.submodules.counter = self.counter + + @def_method(m, self.method) + def _(): + self.counter.incr(m) + + return m + + +class CounterWithConditionInMethodCircuit(Elaboratable): + def __init__(self): + self.method = Method(i=[("cond", 1)]) + self.counter = HwCounter("with_condition_in_method") + + def elaborate(self, platform): + m = TModule() + + m.submodules.counter = self.counter + + @def_method(m, self.method) + def _(cond): + self.counter.incr(m, cond=cond) + + return m + + +class CounterWithoutMethodCircuit(Elaboratable): + def __init__(self): + self.cond = Signal() + self.counter = HwCounter("with_condition_without_method") + + def elaborate(self, platform): + m = TModule() + + m.submodules.counter = self.counter + + with Transaction().body(m): + self.counter.incr(m, cond=self.cond) + + return m + + +class TestHwCounter(TestCaseWithSimulator): + def setUp(self) -> None: + random.seed(42) + + def test_counter_in_method(self): + m = SimpleTestCircuit(CounterInMethodCircuit()) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + def test_process(): + called_cnt = 0 + for _ in range(200): + call_now = random.randint(0, 1) == 0 + + if call_now: + yield from m.method.call() + else: + yield + + # Note that it takes one cycle to update the register value, so here + # we are comparing the "previous" values. + self.assertEqual(called_cnt, (yield m._dut.counter.count.value)) + + if call_now: + called_cnt += 1 + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + def test_counter_with_condition_in_method(self): + m = SimpleTestCircuit(CounterWithConditionInMethodCircuit()) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + def test_process(): + called_cnt = 0 + for _ in range(200): + call_now = random.randint(0, 1) == 0 + condition = random.randint(0, 1) + + if call_now: + yield from m.method.call(cond=condition) + else: + yield + + # Note that it takes one cycle to update the register value, so here + # we are comparing the "previous" values. 
+ self.assertEqual(called_cnt, (yield m._dut.counter.count.value)) + + if call_now and condition == 1: + called_cnt += 1 + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + def test_counter_with_condition_without_method(self): + m = CounterWithoutMethodCircuit() + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + def test_process(): + called_cnt = 0 + for _ in range(200): + condition = random.randint(0, 1) + + yield m.cond.eq(condition) + yield + + # Note that it takes one cycle to update the register value, so here + # we are comparing the "previous" values. + self.assertEqual(called_cnt, (yield m.counter.count.value)) + + if condition == 1: + called_cnt += 1 + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + +class ExpHistogramCircuit(Elaboratable): + def __init__(self, bucket_cnt: int, sample_width: int): + self.sample_width = sample_width + + self.method = Method(i=data_layout(32)) + self.histogram = HwExpHistogram("histogram", bucket_count=bucket_cnt, sample_width=sample_width) + + def elaborate(self, platform): + m = TModule() + + m.submodules.histogram = self.histogram + + @def_method(m, self.method) + def _(data): + self.histogram.add(m, data[0 : self.sample_width]) + + return m + + +@parameterized_class( + ("bucket_count", "sample_width"), + [ + (5, 5), # last bucket is [8, inf), max sample=31 + (8, 5), # last bucket is [64, inf), max sample=31 + (8, 6), # last bucket is [64, inf), max sample=63 + (8, 20), # last bucket is [64, inf), max sample=big + ], +) +class TestHwHistogram(TestCaseWithSimulator): + bucket_count: int + sample_width: int + + def test_histogram(self): + random.seed(42) + + m = SimpleTestCircuit(ExpHistogramCircuit(bucket_cnt=self.bucket_count, sample_width=self.sample_width)) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + max_sample_value = 2**self.sample_width - 1 + + def test_process(): + min = max_sample_value + 1 + max = 0 + sum = 0 + count = 0 + + buckets = [0] * self.bucket_count + + for _ in range(500): + if random.randrange(3) == 0: + value = random.randint(0, max_sample_value) + if value < min: + min = value + if value > max: + max = value + sum += value + count += 1 + for i in range(self.bucket_count): + if value < 2**i or i == self.bucket_count - 1: + buckets[i] += 1 + break + yield from m.method.call(data=value) + yield + else: + yield + + histogram = m._dut.histogram + # Skip the assertion if the min is still uninitialized + if min != max_sample_value + 1: + self.assertEqual(min, (yield histogram.min.value)) + + self.assertEqual(max, (yield histogram.max.value)) + self.assertEqual(sum, (yield histogram.sum.value)) + self.assertEqual(count, (yield histogram.count.value)) + + total_count = 0 + for i in range(self.bucket_count): + bucket_value = yield histogram.buckets[i].value + total_count += bucket_value + self.assertEqual(buckets[i], bucket_value) + + # Sanity check if all buckets sum up to the total count value + self.assertEqual(total_count, (yield histogram.count.value)) + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) + + +@parameterized_class( + ("slots_number", "expected_consumer_wait"), + [ + (2, 5), + (2, 10), + (5, 10), + (10, 1), + (10, 10), + (5, 5), + ], +) +class TestLatencyMeasurer(TestCaseWithSimulator): + slots_number: int + expected_consumer_wait: float + + def test_latency_measurer(self): + random.seed(42) + + m = SimpleTestCircuit(LatencyMeasurer("latency", slots_number=self.slots_number, 
max_latency=300)) + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + latencies: list[int] = [] + + event_queue = queue.Queue() + + time = 0 + + def ticker(): + nonlocal time + + yield Passive() + + while True: + yield + time += 1 + + finish = False + + def producer(): + nonlocal finish + + for _ in range(200): + yield from m._start.call() + + # Make sure that the time is updated first. + yield Settle() + event_queue.put(time) + yield from self.random_wait_geom(0.8) + + finish = True + + def consumer(): + while not finish: + yield from m._stop.call() + + # Make sure that the time is updated first. + yield Settle() + latencies.append(time - event_queue.get()) + + yield from self.random_wait_geom(1.0 / self.expected_consumer_wait) + + self.assertEqual(min(latencies), (yield m._dut.histogram.min.value)) + self.assertEqual(max(latencies), (yield m._dut.histogram.max.value)) + self.assertEqual(sum(latencies), (yield m._dut.histogram.sum.value)) + self.assertEqual(len(latencies), (yield m._dut.histogram.count.value)) + + for i in range(m._dut.histogram.bucket_count): + bucket_start = 0 if i == 0 else 2 ** (i - 1) + bucket_end = 1e10 if i == m._dut.histogram.bucket_count - 1 else 2**i + + count = sum(1 for x in latencies if bucket_start <= x < bucket_end) + self.assertEqual(count, (yield m._dut.histogram.buckets[i].value)) + + with self.run_simulation(m) as sim: + sim.add_sync_process(producer) + sim.add_sync_process(consumer) + sim.add_sync_process(ticker) + + +class MetricManagerTestCircuit(Elaboratable): + def __init__(self): + self.incr_counters = Method(i=[("counter1", 1), ("counter2", 1), ("counter3", 1)]) + + self.counter1 = HwCounter("foo.counter1", "this is the description") + self.counter2 = HwCounter("bar.baz.counter2") + self.counter3 = HwCounter("bar.baz.counter3", "yet another description") + + def elaborate(self, platform): + m = TModule() + + m.submodules += [self.counter1, self.counter2, self.counter3] + + @def_method(m, self.incr_counters) + def _(counter1, counter2, counter3): + self.counter1.incr(m, cond=counter1) + self.counter2.incr(m, cond=counter2) + self.counter3.incr(m, cond=counter3) + + return m + + +class TestMetricsManager(TestCaseWithSimulator): + def test_metrics_metadata(self): + # We need to initialize the circuit to make sure that metrics are registered + # in the dependency manager. + m = MetricManagerTestCircuit() + metrics_manager = HardwareMetricsManager() + + # Run the simulation so Amaranth doesn't scream that we have unused elaboratables. 
+ with self.run_simulation(m): + pass + + self.assertEqual( + metrics_manager.get_metrics()["foo.counter1"].to_json(), # type: ignore + json.dumps( + { + "fully_qualified_name": "foo.counter1", + "description": "this is the description", + "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, + } + ), + ) + + self.assertEqual( + metrics_manager.get_metrics()["bar.baz.counter2"].to_json(), # type: ignore + json.dumps( + { + "fully_qualified_name": "bar.baz.counter2", + "description": "", + "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, + } + ), + ) + + self.assertEqual( + metrics_manager.get_metrics()["bar.baz.counter3"].to_json(), # type: ignore + json.dumps( + { + "fully_qualified_name": "bar.baz.counter3", + "description": "yet another description", + "regs": {"count": {"name": "count", "description": "the value of the counter", "width": 32}}, + } + ), + ) + + def test_returned_reg_values(self): + random.seed(42) + + m = SimpleTestCircuit(MetricManagerTestCircuit()) + metrics_manager = HardwareMetricsManager() + + DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True) + + def test_process(): + counters = [0] * 3 + for _ in range(200): + rand = [random.randint(0, 1) for _ in range(3)] + + yield from m.incr_counters.call(counter1=rand[0], counter2=rand[1], counter3=rand[2]) + yield + + for i in range(3): + if rand[i] == 1: + counters[i] += 1 + + self.assertEqual(counters[0], (yield metrics_manager.get_register_value("foo.counter1", "count"))) + self.assertEqual(counters[1], (yield metrics_manager.get_register_value("bar.baz.counter2", "count"))) + self.assertEqual(counters[2], (yield metrics_manager.get_register_value("bar.baz.counter3", "count"))) + + with self.run_simulation(m) as sim: + sim.add_sync_process(test_process) diff --git a/transactron/lib/__init__.py b/transactron/lib/__init__.py index c814b5e93..f6dd3ef0a 100644 --- a/transactron/lib/__init__.py +++ b/transactron/lib/__init__.py @@ -6,3 +6,4 @@ from .reqres import * # noqa: F401 from .storage import * # noqa: F401 from .simultaneous import * # noqa: F401 +from .metrics import * # noqa: F401 diff --git a/transactron/lib/metrics.py b/transactron/lib/metrics.py new file mode 100644 index 000000000..0a9e79095 --- /dev/null +++ b/transactron/lib/metrics.py @@ -0,0 +1,555 @@ +from dataclasses import dataclass, field +from dataclasses_json import dataclass_json +from typing import Optional +from abc import ABC + +from amaranth import * +from amaranth.utils import bits_for + +from transactron.utils import ValueLike +from transactron import Method, def_method, TModule +from transactron.utils import SignalBundle +from transactron.lib import FIFO +from transactron.utils.dependencies import ListKey, DependencyContext, SimpleKey + +__all__ = [ + "MetricRegisterModel", + "MetricModel", + "HwMetric", + "HwCounter", + "HwExpHistogram", + "LatencyMeasurer", + "HardwareMetricsManager", + "HwMetricsEnabledKey", +] + + +@dataclass_json +@dataclass(frozen=True) +class MetricRegisterModel: + """ + Represents a single register of a metric, serving as a fundamental + building block that holds a singular value. + + Attributes + ---------- + name: str + The unique identifier for the register (among remaning + registers of a specific metric). + description: str + A brief description of the metric's purpose. + width: int + The bit-width of the register. 
+ """ + + name: str + description: str + width: int + + +@dataclass_json +@dataclass +class MetricModel: + """ + Provides information about a metric exposed by the circuit. Each metric + comprises multiple registers, each dedicated to storing specific values. + + The configuration of registers is internally determined by a + specific metric type and is not user-configurable. + + Attributes + ---------- + fully_qualified_name: str + The fully qualified name of the metric, with name components joined by dots ('.'), + e.g., 'foo.bar.requests'. + description: str + A human-readable description of the metric's functionality. + regs: list[MetricRegisterModel] + A list of registers associated with the metric. + """ + + fully_qualified_name: str + description: str + regs: dict[str, MetricRegisterModel] = field(default_factory=dict) + + +class HwMetricRegister(MetricRegisterModel): + """ + A concrete implementation of a metric register that holds its value as Amaranth signal. + + Attributes + ---------- + value: Signal + Amaranth signal representing the value of the register. + """ + + def __init__(self, name: str, width_bits: int, description: str = "", reset: int = 0): + """ + Parameters + ---------- + name: str + The unique identifier for the register (among remaning + registers of a specific metric). + width: int + The bit-width of the register. + description: str + A brief description of the metric's purpose. + reset: int + The reset value of the register. + """ + super().__init__(name, description, width_bits) + + self.value = Signal(width_bits, reset=reset, name=name) + + +@dataclass(frozen=True) +class HwMetricsListKey(ListKey["HwMetric"]): + """DependencyManager key collecting hardware metrics globally as a list.""" + + pass + + +@dataclass(frozen=True) +class HwMetricsEnabledKey(SimpleKey[bool]): + """ + DependencyManager key for enabling hardware metrics. If metrics are disabled, + none of theirs signals will be synthesized. + """ + + lock_on_get = False + empty_valid = True + default_value = False + + +class HwMetric(ABC, MetricModel): + """ + A base for all metric implementations. It should be only used for declaring + new types of metrics. + + It takes care of registering the metric in the dependency manager. + + Attributes + ---------- + signals: dict[str, Signal] + A mapping from a register name to a Signal containing the value of that register. + """ + + def __init__(self, fully_qualified_name: str, description: str): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. + description: str + A human-readable description of the metric's functionality. + """ + super().__init__(fully_qualified_name, description) + + self.signals: dict[str, Signal] = {} + + # add the metric to the global list of all metrics + DependencyContext.get().add_dependency(HwMetricsListKey(), self) + + def add_registers(self, regs: list[HwMetricRegister]): + """ + Adds registers to a metric. Should be only called by inheriting classes + during initialization. + + Parameters + ---------- + regs: list[HwMetricRegister] + A list of registers to be registered. 
+ """ + for reg in regs: + if reg.name in self.regs: + raise RuntimeError(f"Register {reg.name}' is already added to the metric {self.fully_qualified_name}") + + self.regs[reg.name] = reg + self.signals[reg.name] = reg.value + + def metrics_enabled(self) -> bool: + return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) + + +class HwCounter(Elaboratable, HwMetric): + """Hardware Counter + + The most basic hardware metric that can just increase its value. + """ + + def __init__(self, fully_qualified_name: str, description: str = "", *, width_bits: int = 32): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. + description: str + A human-readable description of the metric's functionality. + width_bits: int + The bit-width of the register. Defaults to 32 bits. + """ + + super().__init__(fully_qualified_name, description) + + self.count = HwMetricRegister("count", width_bits, "the value of the counter") + + self.add_registers([self.count]) + + self._incr = Method() + + def elaborate(self, platform): + if not self.metrics_enabled(): + return TModule() + + m = TModule() + + @def_method(m, self._incr) + def _(): + m.d.sync += self.count.value.eq(self.count.value + 1) + + return m + + def incr(self, m: TModule, *, cond: ValueLike = C(1)): + """ + Increases the value of the counter by 1. + + Should be called in the body of either a transaction or a method. + + Parameters + ---------- + m: TModule + Transactron module + """ + if not self.metrics_enabled(): + return + + with m.If(cond): + self._incr(m) + + +class HwExpHistogram(Elaboratable, HwMetric): + """Hardware Exponential Histogram + + Represents the distribution of sampled data through a histogram. A histogram + samples observations (usually things like request durations or queue sizes) and counts + them in a configurable number of buckets. The buckets are of exponential size. For example, + a histogram with 5 buckets would have the following value ranges: + [0, 1); [1, 2); [2, 4); [4, 8); [8, +inf). + + Additionally, the histogram tracks the number of observations, the sum + of observed values, and the minimum and maximum values. + """ + + def __init__( + self, + fully_qualified_name: str, + description: str = "", + *, + bucket_count: int, + sample_width: int = 32, + registers_width: int = 32, + ): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. + description: str + A human-readable description of the metric's functionality. + max_value: int + The maximum value that the histogram would be able to count. This + value is used to calculate the number of buckets. 
+ """ + + super().__init__(fully_qualified_name, description) + self.bucket_count = bucket_count + self.sample_width = sample_width + + self._add = Method(i=[("sample", self.sample_width)]) + + self.count = HwMetricRegister("count", registers_width, "the count of events that have been observed") + self.sum = HwMetricRegister("sum", registers_width, "the total sum of all observed values") + self.min = HwMetricRegister( + "min", + self.sample_width, + "the minimum of all observed values", + reset=(1 << self.sample_width) - 1, + ) + self.max = HwMetricRegister("max", self.sample_width, "the maximum of all observed values") + + self.buckets = [] + for i in range(self.bucket_count): + bucket_start = 0 if i == 0 else 2 ** (i - 1) + bucket_end = "inf" if i == self.bucket_count - 1 else 2**i + + self.buckets.append( + HwMetricRegister( + f"bucket-{bucket_end}", + registers_width, + f"the cumulative counter for the observation bucket [{bucket_start}, {bucket_end})", + ) + ) + + self.add_registers([self.count, self.sum, self.max, self.min] + self.buckets) + + def elaborate(self, platform): + if not self.metrics_enabled(): + return TModule() + + m = TModule() + + @def_method(m, self._add) + def _(sample): + m.d.sync += self.count.value.eq(self.count.value + 1) + m.d.sync += self.sum.value.eq(self.sum.value + sample) + + with m.If(sample > self.max.value): + m.d.sync += self.max.value.eq(sample) + + with m.If(sample < self.min.value): + m.d.sync += self.min.value.eq(sample) + + # todo: perhaps replace with a recursive implementation of the priority encoder + bucket_idx = Signal(range(self.sample_width)) + for i in range(self.sample_width): + with m.If(sample[i]): + m.d.av_comb += bucket_idx.eq(i) + + for i, bucket in enumerate(self.buckets): + should_incr = C(0) + if i == 0: + # The first bucket has a range [0, 1). + should_incr = sample == 0 + elif i == self.bucket_count - 1: + # The last bucket should count values bigger or equal to 2**(self.bucket_count-1) + should_incr = (bucket_idx >= i - 1) & (sample != 0) + else: + should_incr = (bucket_idx == i - 1) & (sample != 0) + + with m.If(should_incr): + m.d.sync += bucket.value.eq(bucket.value + 1) + + return m + + def add(self, m: TModule, sample: Value): + """ + Adds a new sample to the histogram. + + Should be called in the body of either a transaction or a method. + + Parameters + ---------- + m: TModule + Transactron module + sample: ValueLike + The value that will be added to the histogram + """ + + if not self.metrics_enabled(): + return + + self._add(m, sample) + + +class LatencyMeasurer(Elaboratable): + """ + Measures duration between two events, e.g. request processing latency. + It can track multiple events at the same time, i.e. the second event can + be registered as started, before the first finishes. However, they must be + processed in the FIFO order. + + The module exposes an exponential histogram of the measured latencies. + """ + + def __init__( + self, + fully_qualified_name: str, + description: str = "", + *, + slots_number: int, + max_latency: int, + ): + """ + Parameters + ---------- + fully_qualified_name: str + The fully qualified name of the metric. + description: str + A human-readable description of the metric's functionality. + slots_number: str + A number of events that the module can track simultaneously. + max_latency: int + The maximum latency of an event. Used to set signal widths and + number of buckets in the histogram. 
If a latency turns to be + bigger than the maximum, it will overflow and result in a false + measurement. + """ + self.fully_qualified_name = fully_qualified_name + self.description = description + self.slots_number = slots_number + self.max_latency = max_latency + + self._start = Method() + self._stop = Method() + + # This bucket count gives us the best possible granularity. + bucket_count = bits_for(self.max_latency) + 1 + self.histogram = HwExpHistogram( + self.fully_qualified_name, + self.description, + bucket_count=bucket_count, + sample_width=bits_for(self.max_latency), + ) + + def elaborate(self, platform): + if not self.metrics_enabled(): + return TModule() + + m = TModule() + + epoch_width = bits_for(self.max_latency) + + m.submodules.fifo = self.fifo = FIFO([("epoch", epoch_width)], self.slots_number) + m.submodules.histogram = self.histogram + + epoch = Signal(epoch_width) + + m.d.sync += epoch.eq(epoch + 1) + + @def_method(m, self._start) + def _(): + self.fifo.write(m, epoch) + + @def_method(m, self._stop) + def _(): + ret = self.fifo.read(m) + # The result of substracting two unsigned n-bit is a signed (n+1)-bit value, + # so we need to cast the result and discard the most significant bit. + duration = (epoch - ret.epoch).as_unsigned()[:-1] + self.histogram.add(m, duration) + + return m + + def start(self, m: TModule): + """ + Registers the start of an event. Can be called before the previous events + finish. If there are no slots available, the method will be blocked. + + Should be called in the body of either a transaction or a method. + + Parameters + ---------- + m: TModule + Transactron module + """ + + if not self.metrics_enabled(): + return + + self._start(m) + + def stop(self, m: TModule): + """ + Registers the end of the oldest event (the FIFO order). If there are no + started events in the queue, the method will block. + + Should be called in the body of either a transaction or a method. + + Parameters + ---------- + m: TModule + Transactron module + """ + + if not self.metrics_enabled(): + return + + self._stop(m) + + def metrics_enabled(self) -> bool: + return DependencyContext.get().get_dependency(HwMetricsEnabledKey()) + + +class HardwareMetricsManager: + """ + Collects all metrics registered in the circuit and provides an easy + access to them. + """ + + def __init__(self): + self._metrics: Optional[dict[str, HwMetric]] = None + + def _collect_metrics(self) -> dict[str, HwMetric]: + # We lazily collect all metrics so that the metrics manager can be + # constructed at any time. Otherwise, if a metric object was created + # after the manager object had been created, that metric wouldn't end up + # being registered. + metrics: dict[str, HwMetric] = {} + for metric in DependencyContext.get().get_dependency(HwMetricsListKey()): + if metric.fully_qualified_name in metrics: + raise RuntimeError(f"Metric '{metric.fully_qualified_name}' is already registered") + + metrics[metric.fully_qualified_name] = metric + + return metrics + + def get_metrics(self) -> dict[str, HwMetric]: + """ + Returns all metrics registered in the circuit. + """ + if self._metrics is None: + self._metrics = self._collect_metrics() + return self._metrics + + def get_register_value(self, metric_name: str, reg_name: str) -> Signal: + """ + Returns the signal holding the register value of the given metric. + + Parameters + ---------- + metric_name: str + The fully qualified name of the metric, for example 'frontend.icache.loads'. 
+ reg_name: str + The name of the register from that metric, for example if + the metric is a histogram, the 'reg_name' could be 'min' + or 'bucket-32'. + """ + + metrics = self.get_metrics() + if metric_name not in metrics: + raise RuntimeError(f"Couldn't find metric '{metric_name}'") + return metrics[metric_name].signals[reg_name] + + def debug_signals(self) -> SignalBundle: + """ + Returns tree-like SignalBundle composed of all metric registers. + """ + metrics = self.get_metrics() + + def rec(metric_names: list[str], depth: int = 1): + bundle: list[SignalBundle] = [] + components: dict[str, list[str]] = {} + + for metric in metric_names: + parts = metric.split(".") + + if len(parts) == depth: + signals = metrics[metric].signals + reg_values = [signals[reg_name] for reg_name in signals] + + bundle.append({metric: reg_values}) + + continue + + component_prefix = ".".join(parts[:depth]) + + if component_prefix not in components: + components[component_prefix] = [] + components[component_prefix].append(metric) + + for component_name, elements in components.items(): + bundle.append({component_name: rec(elements, depth + 1)}) + + return bundle + + return {"metrics": rec(list(self.get_metrics().keys()))} diff --git a/transactron/utils/dependencies.py b/transactron/utils/dependencies.py index f365e17b9..e66a7a23b 100644 --- a/transactron/utils/dependencies.py +++ b/transactron/utils/dependencies.py @@ -62,9 +62,19 @@ class SimpleKey(Generic[T], DependencyKey[T, T]): Simple dependency keys are used when there is an one-to-one relation between keys and dependencies. If more than one dependency is added to a simple key, an error is raised. + + Parameters + ---------- + default_value: T + Specifies the default value returned when no dependencies are added. To + enable it `empty_valid` must be True. """ + default_value: T + def combine(self, data: list[T]) -> T: + if len(data) == 0: + return self.default_value if len(data) != 1: raise RuntimeError(f"Key {self} assigned {len(data)} values, expected 1") return data[0]
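
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new metrics could be embedded in a Transactron module, following the same pattern as the test circuits in test_metrics.py above. The module, method, and metric names, as well as the parameter values, are invented for the example; the API calls themselves (HwCounter.incr, LatencyMeasurer.start/stop, HwMetricsEnabledKey, HardwareMetricsManager.get_register_value) are taken from this patch.

    # Illustrative sketch only -- ExampleUnit and all names/values below are made up.
    from amaranth import *

    from transactron import TModule, Method, def_method
    from transactron.lib.metrics import (
        HwCounter,
        LatencyMeasurer,
        HwMetricsEnabledKey,
        HardwareMetricsManager,
    )
    from transactron.utils.dependencies import DependencyContext


    class ExampleUnit(Elaboratable):
        def __init__(self):
            self.issue = Method()
            self.complete = Method()

            # A counter with a dotted fully qualified name and a free-form description.
            self.requests = HwCounter("example.unit.requests", "number of accepted requests")
            # Measures the time between issue and complete, in FIFO order.
            self.latency = LatencyMeasurer("example.unit.latency", slots_number=4, max_latency=256)

        def elaborate(self, platform):
            m = TModule()

            m.submodules += [self.requests, self.latency]

            @def_method(m, self.issue)
            def _():
                self.requests.incr(m)
                self.latency.start(m)

            @def_method(m, self.complete)
            def _():
                self.latency.stop(m)

            return m


    def enable_and_inspect_metrics():
        # Metrics are disabled by default; they are only synthesized if this key
        # is set to True before the circuit is elaborated.
        DependencyContext.get().add_dependency(HwMetricsEnabledKey(), True)

        # Assuming an ExampleUnit has been constructed, its registers can be looked
        # up by fully qualified name; the returned Signal can then be read in simulation.
        manager = HardwareMetricsManager()
        return manager.get_register_value("example.unit.requests", "count")

Because HwMetricsEnabledKey defaults to False (with empty_valid set), instrumented designs pay no hardware cost unless metrics are explicitly enabled through the dependency manager.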