Skip to content

Commit

Permalink
Assertions in cocotb (kuznia-rdzeni/coreblocks#590)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Mar 5, 2024
1 parent a398d34 commit 8de3624
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 26 deletions.
2 changes: 1 addition & 1 deletion test/testing/test_assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def elaborate(self, platform):

m.d.comb += self.output.eq(self.input & ~self.input)

assertion(self.input == self.output)
assertion(m, self.input == self.output)

return m

Expand Down
7 changes: 2 additions & 5 deletions transactron/testing/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,18 @@
from typing import Any
from amaranth.sim import Passive, Tick
from transactron.utils import assert_bit, assert_bits
from transactron.utils.dependencies import DependencyContext


__all__ = ["make_assert_handler"]


def make_assert_handler(my_assert: Callable[[int, str], Any]):
dependency_manager = DependencyContext.get()

def assert_handler():
yield Passive()
while True:
yield Tick("sync_neg")
if not (yield assert_bit(dependency_manager)):
for v, (n, i) in assert_bits(dependency_manager):
if not (yield assert_bit()):
for v, (n, i) in assert_bits():
my_assert((yield v), f"Assertion at {n}:{i}")
yield

Expand Down
32 changes: 13 additions & 19 deletions transactron/utils/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
import operator
from dataclasses import dataclass
from transactron.utils import SrcLoc
from transactron.utils.dependencies import DependencyContext, DependencyManager, ListKey
from transactron.utils._typing import ModuleLike, ValueLike
from transactron.utils.dependencies import DependencyContext, ListKey

__all__ = ["AssertKey", "assertion", "assert_bit", "assert_bits"]


@dataclass(frozen=True)
class AssertKey(ListKey[tuple[Value, SrcLoc]]):
class AssertKey(ListKey[tuple[Signal, SrcLoc]]):
pass


def assertion(value: Value, *, src_loc_at: int = 0):
def assertion(m: ModuleLike, value: ValueLike, *, src_loc_at: int = 0):
"""Define an assertion.
This function might help find some hardware bugs which might otherwise be
Expand All @@ -24,6 +25,8 @@ def assertion(value: Value, *, src_loc_at: int = 0):
Parameters
----------
m: Module
Module in which the assertion is defined.
value : Value
If the value of this Amaranth expression is false, the assertion will
fail.
Expand All @@ -32,35 +35,26 @@ def assertion(value: Value, *, src_loc_at: int = 0):
identify the failing assertion.
"""
src_loc = get_src_loc(src_loc_at)
sig = Signal()
m.d.comb += sig.eq(value)
dependencies = DependencyContext.get()
dependencies.add_dependency(AssertKey(), (value, src_loc))
dependencies.add_dependency(AssertKey(), (sig, src_loc))


def assert_bits(dependencies: DependencyManager) -> list[tuple[Value, SrcLoc]]:
def assert_bits() -> list[tuple[Signal, SrcLoc]]:
"""Gets assertion bits.
This function returns all the assertion signals created by `assertion`,
together with their source locations.
Parameters
----------
dependencies : DependencyManager
The assertion feature uses the `DependencyManager` to store
assertions.
"""
dependencies = DependencyContext.get()
return dependencies.get_dependency(AssertKey())


def assert_bit(dependencies: DependencyManager) -> Value:
def assert_bit() -> Signal:
"""Gets assertion bit.
The signal returned by this function is false if and only if there exists
a false signal among assertion bits created by `assertion`.
Parameters
----------
dependencies : DependencyManager
The assertion feature uses the `DependencyManager` to store
assertions.
"""
return reduce(operator.and_, [a[0] for a in assert_bits(dependencies)], C(1))
return reduce(operator.and_, [a[0] for a in assert_bits()], C(1))
37 changes: 36 additions & 1 deletion transactron/utils/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from amaranth.hdl import Fragment

from transactron.lib.metrics import HardwareMetricsManager
from transactron.utils._typing import SrcLoc
from transactron.utils.assertion import assert_bits

from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -14,6 +17,7 @@

__all__ = [
"MetricLocation",
"AssertLocation",
"GenerationInfo",
"generate_verilog",
]
Expand All @@ -35,6 +39,25 @@ class MetricLocation:
regs: dict[str, list[str]] = field(default_factory=dict)


@dataclass_json
@dataclass
class AssertLocation:
"""Information about an assert signal in the generated Verilog code.
Attributes
----------
location : list[str]
The location of the assert signal. The location is a list of Verilog
identifiers that denote a path consisting of module names (and the
signal name at the end) leading to the signal wire.
src_loc : SrcLoc
Source location of the assertion.
"""

location: list[str]
src_loc: SrcLoc


@dataclass_json
@dataclass
class GenerationInfo:
Expand All @@ -45,9 +68,12 @@ class GenerationInfo:
metrics_location : dict[str, MetricInfo]
Mapping from a metric name to an object storing Verilog locations
of its registers.
asserts : list[AssertLocation]
Locations and metadata for assertion signals.
"""

metrics_location: dict[str, MetricLocation] = field(default_factory=dict)
asserts: list[AssertLocation] = field(default_factory=list)

def encode(self, file_name: str):
"""
Expand Down Expand Up @@ -116,12 +142,21 @@ def collect_metric_locations(name_map: "SignalDict") -> dict[str, MetricLocation
return metrics_location


def collect_asserts(name_map: "SignalDict") -> list[AssertLocation]:
asserts: list[AssertLocation] = []

for v, src_loc in assert_bits():
asserts.append(AssertLocation(get_signal_location(v, name_map), src_loc))

return asserts


def generate_verilog(
top_module: Elaboratable, ports: list[Signal], top_name: str = "top"
) -> tuple[str, GenerationInfo]:
fragment = Fragment.get(top_module, platform=None).prepare(ports=ports)
verilog_text, name_map = verilog.convert_fragment(fragment, name=top_name, emit_src=True, strip_internal_attrs=True)

gen_info = GenerationInfo(metrics_location=collect_metric_locations(name_map)) # type: ignore
gen_info = GenerationInfo(metrics_location=collect_metric_locations(name_map), asserts=collect_asserts(name_map))

return verilog_text, gen_info

0 comments on commit 8de3624

Please sign in to comment.