From 8de36247d9e1b289b7e178b2367e0b1a1152e670 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Tue, 5 Mar 2024 16:09:08 +0100 Subject: [PATCH] Assertions in cocotb (https://github.com/kuznia-rdzeni/coreblocks/pull/590) --- test/testing/test_assertion.py | 2 +- transactron/testing/assertion.py | 7 ++---- transactron/utils/assertion.py | 32 +++++++++++---------------- transactron/utils/gen.py | 37 +++++++++++++++++++++++++++++++- 4 files changed, 52 insertions(+), 26 deletions(-) diff --git a/test/testing/test_assertion.py b/test/testing/test_assertion.py index c5bc128..4becf30 100644 --- a/test/testing/test_assertion.py +++ b/test/testing/test_assertion.py @@ -14,7 +14,7 @@ def elaborate(self, platform): m.d.comb += self.output.eq(self.input & ~self.input) - assertion(self.input == self.output) + assertion(m, self.input == self.output) return m diff --git a/transactron/testing/assertion.py b/transactron/testing/assertion.py index 8ae9bdf..19c5a41 100644 --- a/transactron/testing/assertion.py +++ b/transactron/testing/assertion.py @@ -2,21 +2,18 @@ from typing import Any from amaranth.sim import Passive, Tick from transactron.utils import assert_bit, assert_bits -from transactron.utils.dependencies import DependencyContext __all__ = ["make_assert_handler"] def make_assert_handler(my_assert: Callable[[int, str], Any]): - dependency_manager = DependencyContext.get() - def assert_handler(): yield Passive() while True: yield Tick("sync_neg") - if not (yield assert_bit(dependency_manager)): - for v, (n, i) in assert_bits(dependency_manager): + if not (yield assert_bit()): + for v, (n, i) in assert_bits(): my_assert((yield v), f"Assertion at {n}:{i}") yield diff --git a/transactron/utils/assertion.py b/transactron/utils/assertion.py index 537fbd0..b79a74f 100644 --- a/transactron/utils/assertion.py +++ b/transactron/utils/assertion.py @@ -4,17 +4,18 @@ import operator from dataclasses import dataclass from transactron.utils import SrcLoc -from transactron.utils.dependencies import DependencyContext, DependencyManager, ListKey +from transactron.utils._typing import ModuleLike, ValueLike +from transactron.utils.dependencies import DependencyContext, ListKey __all__ = ["AssertKey", "assertion", "assert_bit", "assert_bits"] @dataclass(frozen=True) -class AssertKey(ListKey[tuple[Value, SrcLoc]]): +class AssertKey(ListKey[tuple[Signal, SrcLoc]]): pass -def assertion(value: Value, *, src_loc_at: int = 0): +def assertion(m: ModuleLike, value: ValueLike, *, src_loc_at: int = 0): """Define an assertion. This function might help find some hardware bugs which might otherwise be @@ -24,6 +25,8 @@ def assertion(value: Value, *, src_loc_at: int = 0): Parameters ---------- + m: Module + Module in which the assertion is defined. value : Value If the value of this Amaranth expression is false, the assertion will fail. @@ -32,35 +35,26 @@ def assertion(value: Value, *, src_loc_at: int = 0): identify the failing assertion. """ src_loc = get_src_loc(src_loc_at) + sig = Signal() + m.d.comb += sig.eq(value) dependencies = DependencyContext.get() - dependencies.add_dependency(AssertKey(), (value, src_loc)) + dependencies.add_dependency(AssertKey(), (sig, src_loc)) -def assert_bits(dependencies: DependencyManager) -> list[tuple[Value, SrcLoc]]: +def assert_bits() -> list[tuple[Signal, SrcLoc]]: """Gets assertion bits. This function returns all the assertion signals created by `assertion`, together with their source locations. - - Parameters - ---------- - dependencies : DependencyManager - The assertion feature uses the `DependencyManager` to store - assertions. """ + dependencies = DependencyContext.get() return dependencies.get_dependency(AssertKey()) -def assert_bit(dependencies: DependencyManager) -> Value: +def assert_bit() -> Signal: """Gets assertion bit. The signal returned by this function is false if and only if there exists a false signal among assertion bits created by `assertion`. - - Parameters - ---------- - dependencies : DependencyManager - The assertion feature uses the `DependencyManager` to store - assertions. """ - return reduce(operator.and_, [a[0] for a in assert_bits(dependencies)], C(1)) + return reduce(operator.and_, [a[0] for a in assert_bits()], C(1)) diff --git a/transactron/utils/gen.py b/transactron/utils/gen.py index d5aa53d..b2cbf6d 100644 --- a/transactron/utils/gen.py +++ b/transactron/utils/gen.py @@ -6,6 +6,9 @@ from amaranth.hdl import Fragment from transactron.lib.metrics import HardwareMetricsManager +from transactron.utils._typing import SrcLoc +from transactron.utils.assertion import assert_bits + from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -14,6 +17,7 @@ __all__ = [ "MetricLocation", + "AssertLocation", "GenerationInfo", "generate_verilog", ] @@ -35,6 +39,25 @@ class MetricLocation: regs: dict[str, list[str]] = field(default_factory=dict) +@dataclass_json +@dataclass +class AssertLocation: + """Information about an assert signal in the generated Verilog code. + + Attributes + ---------- + location : list[str] + The location of the assert signal. The location is a list of Verilog + identifiers that denote a path consisting of module names (and the + signal name at the end) leading to the signal wire. + src_loc : SrcLoc + Source location of the assertion. + """ + + location: list[str] + src_loc: SrcLoc + + @dataclass_json @dataclass class GenerationInfo: @@ -45,9 +68,12 @@ class GenerationInfo: metrics_location : dict[str, MetricInfo] Mapping from a metric name to an object storing Verilog locations of its registers. + asserts : list[AssertLocation] + Locations and metadata for assertion signals. """ metrics_location: dict[str, MetricLocation] = field(default_factory=dict) + asserts: list[AssertLocation] = field(default_factory=list) def encode(self, file_name: str): """ @@ -116,12 +142,21 @@ def collect_metric_locations(name_map: "SignalDict") -> dict[str, MetricLocation return metrics_location +def collect_asserts(name_map: "SignalDict") -> list[AssertLocation]: + asserts: list[AssertLocation] = [] + + for v, src_loc in assert_bits(): + asserts.append(AssertLocation(get_signal_location(v, name_map), src_loc)) + + return asserts + + def generate_verilog( top_module: Elaboratable, ports: list[Signal], top_name: str = "top" ) -> tuple[str, GenerationInfo]: fragment = Fragment.get(top_module, platform=None).prepare(ports=ports) verilog_text, name_map = verilog.convert_fragment(fragment, name=top_name, emit_src=True, strip_internal_attrs=True) - gen_info = GenerationInfo(metrics_location=collect_metric_locations(name_map)) # type: ignore + gen_info = GenerationInfo(metrics_location=collect_metric_locations(name_map), asserts=collect_asserts(name_map)) return verilog_text, gen_info