Skip to content

Commit

Permalink
Assertions in hardware (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Feb 7, 2024
1 parent 6db5cf0 commit c692d7d
Show file tree
Hide file tree
Showing 14 changed files with 236 additions and 71 deletions.
97 changes: 45 additions & 52 deletions test/frontend/test_instr_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,92 +179,85 @@ def setUp(self):
self.decoder = InstrDecoder(self.gen_params)
self.cnt = 1

def do_test(self, test):
def do_test(self, tests: list[InstrTest]):
def process():
yield self.decoder.instr.eq(test.encoding)
yield Settle()
for test in tests:
yield self.decoder.instr.eq(test.encoding)
yield Settle()

self.assertEqual((yield self.decoder.illegal), test.illegal)
if test.illegal:
return
self.assertEqual((yield self.decoder.illegal), test.illegal)
if test.illegal:
return

self.assertEqual((yield self.decoder.opcode), test.opcode)
self.assertEqual((yield self.decoder.opcode), test.opcode)

if test.funct3 is not None:
self.assertEqual((yield self.decoder.funct3), test.funct3)
self.assertEqual((yield self.decoder.funct3_v), test.funct3 is not None)
if test.funct3 is not None:
self.assertEqual((yield self.decoder.funct3), test.funct3)
self.assertEqual((yield self.decoder.funct3_v), test.funct3 is not None)

if test.funct7 is not None:
self.assertEqual((yield self.decoder.funct7), test.funct7)
self.assertEqual((yield self.decoder.funct7_v), test.funct7 is not None)
if test.funct7 is not None:
self.assertEqual((yield self.decoder.funct7), test.funct7)
self.assertEqual((yield self.decoder.funct7_v), test.funct7 is not None)

if test.funct12 is not None:
self.assertEqual((yield self.decoder.funct12), test.funct12)
self.assertEqual((yield self.decoder.funct12_v), test.funct12 is not None)
if test.funct12 is not None:
self.assertEqual((yield self.decoder.funct12), test.funct12)
self.assertEqual((yield self.decoder.funct12_v), test.funct12 is not None)

if test.rd is not None:
self.assertEqual((yield self.decoder.rd), test.rd)
self.assertEqual((yield self.decoder.rd_v), test.rd is not None)
if test.rd is not None:
self.assertEqual((yield self.decoder.rd), test.rd)
self.assertEqual((yield self.decoder.rd_v), test.rd is not None)

if test.rs1 is not None:
self.assertEqual((yield self.decoder.rs1), test.rs1)
self.assertEqual((yield self.decoder.rs1_v), test.rs1 is not None)
if test.rs1 is not None:
self.assertEqual((yield self.decoder.rs1), test.rs1)
self.assertEqual((yield self.decoder.rs1_v), test.rs1 is not None)

if test.rs2 is not None:
self.assertEqual((yield self.decoder.rs2), test.rs2)
self.assertEqual((yield self.decoder.rs2_v), test.rs2 is not None)
if test.rs2 is not None:
self.assertEqual((yield self.decoder.rs2), test.rs2)
self.assertEqual((yield self.decoder.rs2_v), test.rs2 is not None)

if test.imm is not None:
self.assertEqual((yield self.decoder.imm.as_signed()), test.imm)
if test.imm is not None:
self.assertEqual((yield self.decoder.imm.as_signed()), test.imm)

if test.succ is not None:
self.assertEqual((yield self.decoder.succ), test.succ)
if test.succ is not None:
self.assertEqual((yield self.decoder.succ), test.succ)

if test.pred is not None:
self.assertEqual((yield self.decoder.pred), test.pred)
if test.pred is not None:
self.assertEqual((yield self.decoder.pred), test.pred)

if test.fm is not None:
self.assertEqual((yield self.decoder.fm), test.fm)
if test.fm is not None:
self.assertEqual((yield self.decoder.fm), test.fm)

if test.csr is not None:
self.assertEqual((yield self.decoder.csr), test.csr)
if test.csr is not None:
self.assertEqual((yield self.decoder.csr), test.csr)

self.assertEqual((yield self.decoder.optype), test.op)
self.assertEqual((yield self.decoder.optype), test.op)

with self.run_simulation(self.decoder) as sim:
sim.add_process(process)

def test_i(self):
for test in self.DECODER_TESTS_I:
self.do_test(test)
self.do_test(self.DECODER_TESTS_I)

def test_zifencei(self):
for test in self.DECODER_TESTS_ZIFENCEI:
self.do_test(test)
self.do_test(self.DECODER_TESTS_ZIFENCEI)

def test_zicsr(self):
for test in self.DECODER_TESTS_ZICSR:
self.do_test(test)
self.do_test(self.DECODER_TESTS_ZICSR)

def test_m(self):
for test in self.DECODER_TESTS_M:
self.do_test(test)
self.do_test(self.DECODER_TESTS_M)

def test_illegal(self):
for test in self.DECODER_TESTS_ILLEGAL:
self.do_test(test)
self.do_test(self.DECODER_TESTS_ILLEGAL)

def test_xintmachinemode(self):
for test in self.DECODER_TESTS_XINTMACHINEMODE:
self.do_test(test)
self.do_test(self.DECODER_TESTS_XINTMACHINEMODE)

def test_xintsupervisor(self):
for test in self.DECODER_TESTS_XINTSUPERVISOR:
self.do_test(test)
self.do_test(self.DECODER_TESTS_XINTSUPERVISOR)

def test_zbb(self):
for test in self.DECODER_TESTS_ZBB:
self.do_test(test)
self.do_test(self.DECODER_TESTS_ZBB)


class TestDecoderEExtLegal(TestCaseWithSimulator):
Expand Down
5 changes: 3 additions & 2 deletions test/transactions/test_branches.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from unittest import TestCase
from transactron.testing import TestCaseWithSimulator
from transactron.utils.dependencies import DependencyContext


class TestExclusivePath(TestCase):
Expand Down Expand Up @@ -87,9 +88,9 @@ def test_conflict_removal(self):
circ = ExclusiveConflictRemovalCircuit()

tm = TransactionManager()
dut = TransactionModule(circ, tm)
dut = TransactionModule(circ, DependencyContext.get(), tm)

with self.run_simulation(dut):
with self.run_simulation(dut, add_transaction_module=False):
pass

cgr, _, _ = tm._conflict_graph(MethodMap(tm.transactions))
Expand Down
7 changes: 3 additions & 4 deletions test/transactions/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,10 @@ def _(arg):


class TestQuadrupleCircuits(TestCaseWithSimulator):
def test(self):
self.work(QuadrupleCircuit(Quadruple()))
self.work(QuadrupleCircuit(Quadruple2()))
@parameterized.expand([(Quadruple,), (Quadruple2,)])
def test(self, quadruple):
circ = QuadrupleCircuit(quadruple())

def work(self, circ):
def process():
for n in range(1 << (WIDTH - 2)):
out = yield from circ.tb.call(data=n)
Expand Down
3 changes: 2 additions & 1 deletion test/transactions/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
trivial_roundrobin_cc_scheduler,
eager_deterministic_cc_scheduler,
)
from transactron.utils.dependencies import DependencyContext


class TestNames(TestCase):
Expand Down Expand Up @@ -110,7 +111,7 @@ def __init__(self, scheduler):

def elaborate(self, platform):
m = TModule()
tm = TransactionModule(m, TransactionManager(self.scheduler))
tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler))
adapter = Adapter(i=data_layout(32), o=data_layout(32))
m.submodules.out = self.out = TestbenchIO(adapter)
m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface))
Expand Down
Empty file added test/transactron/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions test/transactron/testing/test_assertion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from amaranth import *

from transactron.utils import assertion
from transactron.testing import TestCaseWithSimulator


class AssertionTest(Elaboratable):
def __init__(self):
self.input = Signal()
self.output = Signal()

def elaborate(self, platform):
m = Module()

m.d.comb += self.output.eq(self.input & ~self.input)

assertion(self.input == self.output)

return m


class TestAssertion(TestCaseWithSimulator):
def test_assertion(self):
m = AssertionTest()

def proc():
yield
yield m.input.eq(1)

with self.assertRaises(AssertionError):
with self.run_simulation(m) as sim:
sim.add_sync_process(proc)
27 changes: 20 additions & 7 deletions transactron/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,18 +567,31 @@ class TransactionModule(Elaboratable):
The `TransactionManager` is stored in a `DependencyManager`.
"""

def __init__(self, elaboratable: HasElaborate, manager: Optional[TransactionManager] = None):
def __init__(
self,
elaboratable: HasElaborate,
dependency_manager: Optional[DependencyManager] = None,
transaction_manager: Optional[TransactionManager] = None,
):
"""
Parameters
----------
elaboratable: HasElaborate
The `Elaboratable` which should be wrapped to add support for
transactions and methods.
The `Elaboratable` which should be wrapped to add support for
transactions and methods.
dependency_manager: DependencyManager, optional
The `DependencyManager` to use inside the transaction module.
If omitted, a new one is created.
transaction_manager: TransactionManager, optional
The `TransactionManager` to use inside the transaction module.
If omitted, a new one is created.
"""
if manager is None:
manager = TransactionManager()
self.manager = DependencyManager()
self.manager.add_dependency(TransactionManagerKey(), manager)
if transaction_manager is None:
transaction_manager = TransactionManager()
if dependency_manager is None:
dependency_manager = DependencyManager()
self.manager = dependency_manager
self.manager.add_dependency(TransactionManagerKey(), transaction_manager)
self.elaboratable = elaboratable

def context(self) -> DependencyContext:
Expand Down
1 change: 1 addition & 0 deletions transactron/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .sugar import * # noqa: F401
from .testbenchio import * # noqa: F401
from .profiler import * # noqa: F401
from .assertion import * # noqa: F401
from transactron.utils import data_layout # noqa: F401
23 changes: 23 additions & 0 deletions transactron/testing/assertion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from collections.abc import Callable
from typing import Any
from amaranth.sim import Passive, Tick
from transactron.utils import assert_bit, assert_bits
from transactron.utils.dependencies import DependencyContext


__all__ = ["make_assert_handler"]


def make_assert_handler(my_assert: Callable[[int, str], Any]):
dependency_manager = DependencyContext.get()

def assert_handler():
yield Passive()
while True:
yield Tick("sync_neg")
if not (yield assert_bit(dependency_manager)):
for v, (n, i) in assert_bits(dependency_manager):
my_assert((yield v), f"Assertion at {n}:{i}")
yield

return assert_handler
40 changes: 37 additions & 3 deletions transactron/testing/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from abc import ABC
from amaranth import *
from amaranth.sim import *

from transactron.utils.dependencies import DependencyContext, DependencyManager
from .testbenchio import TestbenchIO
from .profiler import profiler_process, Profile
from .functions import TestGen
from .assertion import make_assert_handler
from .gtkw_extension import write_vcd_ext
from transactron import Method
from transactron.lib import AdapterTrans
Expand Down Expand Up @@ -90,8 +93,12 @@ def debug_signals(self):


class _TestModule(Elaboratable):
def __init__(self, tested_module: HasElaborate, add_transaction_module):
self.tested_module = TransactionModule(tested_module) if add_transaction_module else tested_module
def __init__(self, tested_module: HasElaborate, add_transaction_module: bool):
self.tested_module = (
TransactionModule(tested_module, dependency_manager=DependencyContext.get())
if add_transaction_module
else tested_module
)
self.add_transaction_module = add_transaction_module

def elaborate(self, platform) -> HasElaborate:
Expand All @@ -103,6 +110,8 @@ def elaborate(self, platform) -> HasElaborate:

m.submodules.tested_module = self.tested_module

m.domains.sync_neg = ClockDomain(clk_edge="neg", local=True)

return m


Expand Down Expand Up @@ -154,6 +163,7 @@ def __init__(
super().__init__(test_module)

self.add_clock(clk_period)
self.add_clock(clk_period, domain="sync_neg")

if isinstance(tested_module, HasDebugSignals):
extra_signals = tested_module.debug_signals
Expand Down Expand Up @@ -192,6 +202,27 @@ def run(self) -> bool:


class TestCaseWithSimulator(unittest.TestCase):
dependency_manager: DependencyManager

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.dependency_manager = DependencyManager()

def wrap(f: Callable[[], None]):
@functools.wraps(f)
def wrapper():
with DependencyContext(self.dependency_manager):
f()

return wrapper

for k in dir(self):
if k.startswith("test"):
f = getattr(self, k)
if isinstance(f, Callable):
setattr(self, k, wrap(getattr(self, k)))

def add_class_mocks(self, sim: PysimSimulator) -> None:
for key in dir(self):
val = getattr(self, key)
Expand Down Expand Up @@ -222,15 +253,18 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra
clk_period=clk_period,
)
self.add_all_mocks(sim, sys._getframe(2).f_locals)

yield sim

profile = None
if "__TRANSACTRON_PROFILE" in os.environ and isinstance(sim.tested_module, TransactionModule):
profile = Profile()
sim.add_sync_process(
profiler_process(sim.tested_module.manager.get_dependency(TransactionManagerKey()), profile, clk_period)
profiler_process(sim.tested_module.manager.get_dependency(TransactionManagerKey()), profile)
)

sim.add_sync_process(make_assert_handler(self.assertTrue))

res = sim.run()

if profile is not None:
Expand Down
Loading

0 comments on commit c692d7d

Please sign in to comment.