From a5dc003d91ab7930a0933968be67c60b348cfa74 Mon Sep 17 00:00:00 2001 From: lekcyjna123 <34948061+lekcyjna123@users.noreply.github.com> Date: Mon, 15 Apr 2024 14:30:15 +0200 Subject: [PATCH] Move tests after https://github.com/kuznia-rdzeni/coreblocks/pull/620 (https://github.com/kuznia-rdzeni/coreblocks/pull/644) --- test/core/__init__.py | 0 test/core/test_transactions.py | 447 +++++++++++++++++ test/lib/__init__.py | 0 test/lib/test_fifo.py | 79 +++ test/lib/test_transaction_lib.py | 805 +++++++++++++++++++++++++++++++ test/test_adapter.py | 62 +++ test/test_assign.py | 125 +++++ test/test_branches.py | 99 ++++ test/test_methods.py | 644 +++++++++++++++++++++++++ test/test_simultaneous.py | 172 +++++++ test/utils/__init__.py | 0 test/utils/test_onehotswitch.py | 62 +++ test/utils/test_utils.py | 196 ++++++++ 13 files changed, 2691 insertions(+) create mode 100644 test/core/__init__.py create mode 100644 test/core/test_transactions.py create mode 100644 test/lib/__init__.py create mode 100644 test/lib/test_fifo.py create mode 100644 test/lib/test_transaction_lib.py create mode 100644 test/test_adapter.py create mode 100644 test/test_assign.py create mode 100644 test/test_branches.py create mode 100644 test/test_methods.py create mode 100644 test/test_simultaneous.py create mode 100644 test/utils/__init__.py create mode 100644 test/utils/test_onehotswitch.py create mode 100644 test/utils/test_utils.py diff --git a/test/core/__init__.py b/test/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/core/test_transactions.py b/test/core/test_transactions.py new file mode 100644 index 0000000..5d68b77 --- /dev/null +++ b/test/core/test_transactions.py @@ -0,0 +1,447 @@ +from unittest.case import TestCase +import pytest +from amaranth import * +from amaranth.sim import * + +import random +import contextlib + +from collections import deque +from typing import Iterable, Callable +from parameterized import parameterized, parameterized_class + +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout + +from transactron import * +from transactron.lib import Adapter, AdapterTrans +from transactron.utils import Scheduler + +from transactron.core import Priority +from transactron.core.schedulers import trivial_roundrobin_cc_scheduler, eager_deterministic_cc_scheduler +from transactron.core.manager import TransactionScheduler +from transactron.utils.dependencies import DependencyContext + + +class TestNames(TestCase): + def test_names(self): + mgr = TransactionManager() + mgr._MustUse__silence = True # type: ignore + + class T(Elaboratable): + def __init__(self): + self._MustUse__silence = True # type: ignore + Transaction(manager=mgr) + + T() + assert mgr.transactions[0].name == "T" + + t = Transaction(name="x", manager=mgr) + assert t.name == "x" + + t = Transaction(manager=mgr) + assert t.name == "t" + + m = Method(name="x") + assert m.name == "x" + + m = Method() + assert m.name == "m" + + +class TestScheduler(TestCaseWithSimulator): + def count_test(self, sched, cnt): + assert sched.count == cnt + assert len(sched.requests) == cnt + assert len(sched.grant) == cnt + assert len(sched.valid) == 1 + + def sim_step(self, sched, request, expected_grant): + yield sched.requests.eq(request) + yield + + if request == 0: + assert not (yield sched.valid) + else: + assert (yield sched.grant) == expected_grant + assert (yield sched.valid) + + def test_single(self): + sched = Scheduler(1) + self.count_test(sched, 1) + + def process(): + yield from self.sim_step(sched, 0, 0) + yield from self.sim_step(sched, 1, 1) + yield from self.sim_step(sched, 1, 1) + yield from self.sim_step(sched, 0, 0) + + with self.run_simulation(sched) as sim: + sim.add_sync_process(process) + + def test_multi(self): + sched = Scheduler(4) + self.count_test(sched, 4) + + def process(): + yield from self.sim_step(sched, 0b0000, 0b0000) + yield from self.sim_step(sched, 0b1010, 0b0010) + yield from self.sim_step(sched, 0b1010, 0b1000) + yield from self.sim_step(sched, 0b1010, 0b0010) + yield from self.sim_step(sched, 0b1001, 0b1000) + yield from self.sim_step(sched, 0b1001, 0b0001) + + yield from self.sim_step(sched, 0b1111, 0b0010) + yield from self.sim_step(sched, 0b1111, 0b0100) + yield from self.sim_step(sched, 0b1111, 0b1000) + yield from self.sim_step(sched, 0b1111, 0b0001) + + yield from self.sim_step(sched, 0b0000, 0b0000) + yield from self.sim_step(sched, 0b0010, 0b0010) + yield from self.sim_step(sched, 0b0010, 0b0010) + + with self.run_simulation(sched) as sim: + sim.add_sync_process(process) + + +class TransactionConflictTestCircuit(Elaboratable): + def __init__(self, scheduler): + self.scheduler = scheduler + + def elaborate(self, platform): + m = TModule() + tm = TransactionModule(m, DependencyContext.get(), TransactionManager(self.scheduler)) + adapter = Adapter(i=data_layout(32), o=data_layout(32)) + m.submodules.out = self.out = TestbenchIO(adapter) + m.submodules.in1 = self.in1 = TestbenchIO(AdapterTrans(adapter.iface)) + m.submodules.in2 = self.in2 = TestbenchIO(AdapterTrans(adapter.iface)) + return tm + + +@parameterized_class( + ("name", "scheduler"), + [ + ("trivial_roundrobin", trivial_roundrobin_cc_scheduler), + ("eager_deterministic", eager_deterministic_cc_scheduler), + ], +) +class TestTransactionConflict(TestCaseWithSimulator): + scheduler: TransactionScheduler + + def setup_method(self): + random.seed(42) + + def make_process( + self, io: TestbenchIO, prob: float, src: Iterable[int], tgt: Callable[[int], None], chk: Callable[[int], None] + ): + def process(): + for i in src: + while random.random() >= prob: + yield + tgt(i) + r = yield from io.call(data=i) + chk(r["data"]) + + return process + + def make_in1_process(self, prob: float): + def tgt(x: int): + self.out1_expected.append(x) + + def chk(x: int): + assert x == self.in_expected.popleft() + + return self.make_process(self.m.in1, prob, self.in1_stream, tgt, chk) + + def make_in2_process(self, prob: float): + def tgt(x: int): + self.out2_expected.append(x) + + def chk(x: int): + assert x == self.in_expected.popleft() + + return self.make_process(self.m.in2, prob, self.in2_stream, tgt, chk) + + def make_out_process(self, prob: float): + def tgt(x: int): + self.in_expected.append(x) + + def chk(x: int): + if self.out1_expected and x == self.out1_expected[0]: + self.out1_expected.popleft() + elif self.out2_expected and x == self.out2_expected[0]: + self.out2_expected.popleft() + else: + assert False, "%d not found in any of the queues" % x + + return self.make_process(self.m.out, prob, self.out_stream, tgt, chk) + + @parameterized.expand( + [ + ("fullcontention", 1, 1, 1), + ("highcontention", 0.5, 0.5, 0.75), + ("lowcontention", 0.1, 0.1, 0.5), + ] + ) + def test_calls(self, name, prob1, prob2, probout): + self.in1_stream = range(0, 100) + self.in2_stream = range(100, 200) + self.out_stream = range(200, 400) + self.in_expected = deque() + self.out1_expected = deque() + self.out2_expected = deque() + self.m = TransactionConflictTestCircuit(self.__class__.scheduler) + + with self.run_simulation(self.m, add_transaction_module=False) as sim: + sim.add_sync_process(self.make_in1_process(prob1)) + sim.add_sync_process(self.make_in2_process(prob2)) + sim.add_sync_process(self.make_out_process(probout)) + + assert not self.in_expected + assert not self.out1_expected + assert not self.out2_expected + + +class SchedulingTestCircuit(Elaboratable): + def __init__(self): + self.r1 = Signal() + self.r2 = Signal() + self.t1 = Signal() + self.t2 = Signal() + + +class PriorityTestCircuit(SchedulingTestCircuit): + def __init__(self, priority: Priority, unsatisfiable=False): + super().__init__() + self.priority = priority + self.unsatisfiable = unsatisfiable + + def make_relations(self, t1: Transaction | Method, t2: Transaction | Method): + t1.add_conflict(t2, self.priority) + if self.unsatisfiable: + t2.add_conflict(t1, self.priority) + + +class TransactionPriorityTestCircuit(PriorityTestCircuit): + def elaborate(self, platform): + m = TModule() + + transaction1 = Transaction() + transaction2 = Transaction() + + with transaction1.body(m, request=self.r1): + m.d.comb += self.t1.eq(1) + + with transaction2.body(m, request=self.r2): + m.d.comb += self.t2.eq(1) + + self.make_relations(transaction1, transaction2) + + return m + + +class MethodPriorityTestCircuit(PriorityTestCircuit): + def elaborate(self, platform): + m = TModule() + + method1 = Method() + method2 = Method() + + @def_method(m, method1, ready=self.r1) + def _(): + m.d.comb += self.t1.eq(1) + + @def_method(m, method2, ready=self.r2) + def _(): + m.d.comb += self.t2.eq(1) + + with Transaction().body(m): + method1(m) + + with Transaction().body(m): + method2(m) + + self.make_relations(method1, method2) + + return m + + +@parameterized_class( + ("name", "circuit"), [("transaction", TransactionPriorityTestCircuit), ("method", MethodPriorityTestCircuit)] +) +class TestTransactionPriorities(TestCaseWithSimulator): + circuit: type[PriorityTestCircuit] + + def setup_method(self): + random.seed(42) + + @parameterized.expand([(Priority.UNDEFINED,), (Priority.LEFT,), (Priority.RIGHT,)]) + def test_priorities(self, priority: Priority): + m = self.circuit(priority) + + def process(): + to_do = 5 * [(0, 1), (1, 0), (1, 1)] + random.shuffle(to_do) + for r1, r2 in to_do: + yield m.r1.eq(r1) + yield m.r2.eq(r2) + yield Settle() + assert (yield m.t1) != (yield m.t2) + if r1 == 1 and r2 == 1: + if priority == Priority.LEFT: + assert (yield m.t1) + if priority == Priority.RIGHT: + assert (yield m.t2) + + with self.run_simulation(m) as sim: + sim.add_process(process) + + @parameterized.expand([(Priority.UNDEFINED,), (Priority.LEFT,), (Priority.RIGHT,)]) + def test_unsatisfiable(self, priority: Priority): + m = self.circuit(priority, True) + + import graphlib + + if priority != Priority.UNDEFINED: + cm = pytest.raises(graphlib.CycleError) + else: + cm = contextlib.nullcontext() + + with cm: + with self.run_simulation(m): + pass + + +class NestedTransactionsTestCircuit(SchedulingTestCircuit): + def elaborate(self, platform): + m = TModule() + tm = TransactionModule(m) + + with tm.context(): + with Transaction().body(m, request=self.r1): + m.d.comb += self.t1.eq(1) + with Transaction().body(m, request=self.r2): + m.d.comb += self.t2.eq(1) + + return tm + + +class NestedMethodsTestCircuit(SchedulingTestCircuit): + def elaborate(self, platform): + m = TModule() + tm = TransactionModule(m) + + method1 = Method() + method2 = Method() + + @def_method(m, method1, ready=self.r1) + def _(): + m.d.comb += self.t1.eq(1) + + @def_method(m, method2, ready=self.r2) + def _(): + m.d.comb += self.t2.eq(1) + + with tm.context(): + with Transaction().body(m): + method1(m) + + with Transaction().body(m): + method2(m) + + return tm + + +@parameterized_class( + ("name", "circuit"), [("transaction", NestedTransactionsTestCircuit), ("method", NestedMethodsTestCircuit)] +) +class TestNested(TestCaseWithSimulator): + circuit: type[SchedulingTestCircuit] + + def setup_method(self): + random.seed(42) + + def test_scheduling(self): + m = self.circuit() + + def process(): + to_do = 5 * [(0, 1), (1, 0), (1, 1)] + random.shuffle(to_do) + for r1, r2 in to_do: + yield m.r1.eq(r1) + yield m.r2.eq(r2) + yield + assert (yield m.t1) == r1 + assert (yield m.t2) == r1 * r2 + + with self.run_simulation(m) as sim: + sim.add_sync_process(process) + + +class ScheduleBeforeTestCircuit(SchedulingTestCircuit): + def elaborate(self, platform): + m = TModule() + tm = TransactionModule(m) + + method = Method() + + @def_method(m, method) + def _(): + pass + + with tm.context(): + with (t1 := Transaction()).body(m, request=self.r1): + method(m) + m.d.comb += self.t1.eq(1) + + with (t2 := Transaction()).body(m, request=self.r2 & t1.grant): + method(m) + m.d.comb += self.t2.eq(1) + + t1.schedule_before(t2) + + return tm + + +class TestScheduleBefore(TestCaseWithSimulator): + def setup_method(self): + random.seed(42) + + def test_schedule_before(self): + m = ScheduleBeforeTestCircuit() + + def process(): + to_do = 5 * [(0, 1), (1, 0), (1, 1)] + random.shuffle(to_do) + for r1, r2 in to_do: + yield m.r1.eq(r1) + yield m.r2.eq(r2) + yield + assert (yield m.t1) == r1 + assert not (yield m.t2) + + with self.run_simulation(m) as sim: + sim.add_sync_process(process) + + +class SingleCallerTestCircuit(Elaboratable): + def elaborate(self, platform): + m = TModule() + + method = Method(single_caller=True) + + with Transaction().body(m): + method(m) + + with Transaction().body(m): + method(m) + + return m + + +class TestSingleCaller(TestCaseWithSimulator): + def test_single_caller(self): + m = SingleCallerTestCircuit() + + with pytest.raises(RuntimeError): + with self.run_simulation(m): + pass diff --git a/test/lib/__init__.py b/test/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/lib/test_fifo.py b/test/lib/test_fifo.py new file mode 100644 index 0000000..1db7d44 --- /dev/null +++ b/test/lib/test_fifo.py @@ -0,0 +1,79 @@ +from amaranth import * +from amaranth.sim import Settle + +from transactron.lib import AdapterTrans, BasicFifo + +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout +from collections import deque +from parameterized import parameterized_class +import random + + +class BasicFifoTestCircuit(Elaboratable): + def __init__(self, depth): + self.depth = depth + + def elaborate(self, platform): + m = Module() + + m.submodules.fifo = self.fifo = BasicFifo(layout=data_layout(8), depth=self.depth) + + m.submodules.fifo_read = self.fifo_read = TestbenchIO(AdapterTrans(self.fifo.read)) + m.submodules.fifo_write = self.fifo_write = TestbenchIO(AdapterTrans(self.fifo.write)) + m.submodules.fifo_clear = self.fifo_clear = TestbenchIO(AdapterTrans(self.fifo.clear)) + + return m + + +@parameterized_class( + ("name", "depth"), + [ + ("notpower", 5), + ("power", 4), + ], +) +class TestBasicFifo(TestCaseWithSimulator): + depth: int + + def test_randomized(self): + fifoc = BasicFifoTestCircuit(depth=self.depth) + expq = deque() + + cycles = 256 + random.seed(42) + + self.done = False + + def source(): + for _ in range(cycles): + if random.randint(0, 1): + yield # random delay + + v = random.randint(0, (2**fifoc.fifo.width) - 1) + yield from fifoc.fifo_write.call(data=v) + expq.appendleft(v) + + if random.random() < 0.005: + yield from fifoc.fifo_clear.call() + yield Settle() + expq.clear() + + self.done = True + + def target(): + while not self.done or expq: + if random.randint(0, 1): + yield + + yield from fifoc.fifo_read.call_init() + yield + + v = yield from fifoc.fifo_read.call_result() + if v is not None: + assert v["data"] == expq.pop() + + yield from fifoc.fifo_read.disable() + + with self.run_simulation(fifoc) as sim: + sim.add_sync_process(source) + sim.add_sync_process(target) diff --git a/test/lib/test_transaction_lib.py b/test/lib/test_transaction_lib.py new file mode 100644 index 0000000..912c546 --- /dev/null +++ b/test/lib/test_transaction_lib.py @@ -0,0 +1,805 @@ +import pytest +from itertools import product +import random +import itertools +from operator import and_ +from functools import reduce +from amaranth.sim import Settle, Passive +from typing import Optional, TypeAlias +from parameterized import parameterized +from collections import deque + +from amaranth import * +from transactron import * +from transactron.lib import * +from transactron.utils._typing import ModuleLike, MethodStruct, RecordDict +from transactron.utils import ModuleConnector +from transactron.testing import ( + SimpleTestCircuit, + TestCaseWithSimulator, + TestbenchIO, + data_layout, + def_method_mock, +) + + +class RevConnect(Elaboratable): + def __init__(self, layout: MethodLayout): + self.connect = Connect(rev_layout=layout) + self.read = self.connect.write + self.write = self.connect.read + + def elaborate(self, platform): + return self.connect + + +FIFO_Like: TypeAlias = FIFO | Forwarder | Connect | RevConnect | Pipe + + +class TestFifoBase(TestCaseWithSimulator): + def do_test_fifo( + self, fifo_class: type[FIFO_Like], writer_rand: int = 0, reader_rand: int = 0, fifo_kwargs: dict = {} + ): + iosize = 8 + + m = SimpleTestCircuit(fifo_class(data_layout(iosize), **fifo_kwargs)) + + random.seed(1337) + + def writer(): + for i in range(2**iosize): + yield from m.write.call(data=i) + yield from self.random_wait(writer_rand) + + def reader(): + for i in range(2**iosize): + assert (yield from m.read.call()) == {"data": i} + yield from self.random_wait(reader_rand) + + with self.run_simulation(m) as sim: + sim.add_sync_process(reader) + sim.add_sync_process(writer) + + +class TestFIFO(TestFifoBase): + @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) + def test_fifo(self, writer_rand, reader_rand): + self.do_test_fifo(FIFO, writer_rand=writer_rand, reader_rand=reader_rand, fifo_kwargs=dict(depth=4)) + + +class TestConnect(TestFifoBase): + @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) + def test_fifo(self, writer_rand, reader_rand): + self.do_test_fifo(Connect, writer_rand=writer_rand, reader_rand=reader_rand) + + @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) + def test_rev_fifo(self, writer_rand, reader_rand): + self.do_test_fifo(RevConnect, writer_rand=writer_rand, reader_rand=reader_rand) + + +class TestForwarder(TestFifoBase): + @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) + def test_fifo(self, writer_rand, reader_rand): + self.do_test_fifo(Forwarder, writer_rand=writer_rand, reader_rand=reader_rand) + + def test_forwarding(self): + iosize = 8 + + m = SimpleTestCircuit(Forwarder(data_layout(iosize))) + + def forward_check(x): + yield from m.read.call_init() + yield from m.write.call_init(data=x) + yield Settle() + assert (yield from m.read.call_result()) == {"data": x} + assert (yield from m.write.call_result()) is not None + yield + + def process(): + # test forwarding behavior + for x in range(4): + yield from forward_check(x) + + # load the overflow buffer + yield from m.read.disable() + yield from m.write.call_init(data=42) + yield Settle() + assert (yield from m.write.call_result()) is not None + yield + + # writes are not possible now + yield from m.write.call_init(data=84) + yield Settle() + assert (yield from m.write.call_result()) is None + yield + + # read from the overflow buffer, writes still blocked + yield from m.read.enable() + yield from m.write.call_init(data=111) + yield Settle() + assert (yield from m.read.call_result()) == {"data": 42} + assert (yield from m.write.call_result()) is None + yield + + # forwarding now works again + for x in range(4): + yield from forward_check(x) + + with self.run_simulation(m) as sim: + sim.add_sync_process(process) + + +class TestPipe(TestFifoBase): + @parameterized.expand([(0, 0), (2, 0), (0, 2), (1, 1)]) + def test_fifo(self, writer_rand, reader_rand): + self.do_test_fifo(Pipe, writer_rand=writer_rand, reader_rand=reader_rand) + + def test_pipelining(self): + self.do_test_fifo(Pipe, writer_rand=0, reader_rand=0) + + +class TestMemoryBank(TestCaseWithSimulator): + test_conf = [(9, 3, 3, 3, 14), (16, 1, 1, 3, 15), (16, 1, 1, 1, 16), (12, 3, 1, 1, 17)] + + parametrized_input = [tc + sf for tc, sf in itertools.product(test_conf, [(True,), (False,)])] + + @parameterized.expand(parametrized_input) + def test_mem(self, max_addr, writer_rand, reader_req_rand, reader_resp_rand, seed, safe_writes): + test_count = 200 + + data_width = 6 + m = SimpleTestCircuit( + MemoryBank(data_layout=[("data", data_width)], elem_count=max_addr, safe_writes=safe_writes) + ) + + data: list[int] = list(0 for _ in range(max_addr)) + read_req_queue = deque() + addr_queue = deque() + + random.seed(seed) + + def writer(): + for cycle in range(test_count): + d = random.randrange(2**data_width) + a = random.randrange(max_addr) + yield from m.write.call(data=d, addr=a) + for _ in range(2): + yield Settle() + data[a] = d + yield from self.random_wait(writer_rand, min_cycle_cnt=1) + + def reader_req(): + for cycle in range(test_count): + a = random.randrange(max_addr) + yield from m.read_req.call(addr=a) + for _ in range(1): + yield Settle() + if safe_writes: + d = data[a] + read_req_queue.append(d) + else: + addr_queue.append((cycle, a)) + yield from self.random_wait(reader_req_rand, min_cycle_cnt=1) + + def reader_resp(): + for cycle in range(test_count): + while not read_req_queue: + yield from self.random_wait(reader_resp_rand, min_cycle_cnt=1) + d = read_req_queue.popleft() + assert (yield from m.read_resp.call()) == {"data": d} + yield from self.random_wait(reader_resp_rand, min_cycle_cnt=1) + + def internal_reader_resp(): + assert m._dut._internal_read_resp_trans is not None + yield Passive() + while True: + if addr_queue: + instr, a = addr_queue[0] + else: + yield + continue + d = data[a] + # check when internal method has been run to capture + # memory state for tests purposes + if (yield m._dut._internal_read_resp_trans.grant): + addr_queue.popleft() + read_req_queue.append(d) + yield + + with self.run_simulation(m) as sim: + sim.add_sync_process(reader_req) + sim.add_sync_process(reader_resp) + sim.add_sync_process(writer) + if not safe_writes: + sim.add_sync_process(internal_reader_resp) + + def test_pipelined(self): + data_width = 6 + max_addr = 9 + m = SimpleTestCircuit(MemoryBank(data_layout=[("data", data_width)], elem_count=max_addr, safe_writes=False)) + + random.seed(14) + + def process(): + a = 3 + d1 = random.randrange(2**data_width) + yield from m.write.call_init(data=d1, addr=a) + yield from m.read_req.call_init(addr=a) + yield + d2 = random.randrange(2**data_width) + yield from m.write.call_init(data=d2, addr=a) + yield from m.read_resp.call_init() + yield + yield from m.write.disable() + yield from m.read_req.disable() + ret_d1 = (yield from m.read_resp.call_result())["data"] + assert d1 == ret_d1 + yield + ret_d2 = (yield from m.read_resp.call_result())["data"] + assert d2 == ret_d2 + + with self.run_simulation(m) as sim: + sim.add_sync_process(process) + + +class TestAsyncMemoryBank(TestCaseWithSimulator): + @parameterized.expand([(9, 3, 3, 14), (16, 1, 1, 15), (16, 1, 1, 16), (12, 3, 1, 17)]) + def test_mem(self, max_addr, writer_rand, reader_rand, seed): + test_count = 200 + + data_width = 6 + m = SimpleTestCircuit(AsyncMemoryBank(data_layout=[("data", data_width)], elem_count=max_addr)) + + data: list[int] = list(0 for i in range(max_addr)) + + random.seed(seed) + + def writer(): + for cycle in range(test_count): + d = random.randrange(2**data_width) + a = random.randrange(max_addr) + yield from m.write.call(data=d, addr=a) + for _ in range(2): + yield Settle() + data[a] = d + yield from self.random_wait(writer_rand, min_cycle_cnt=1) + + def reader(): + for cycle in range(test_count): + a = random.randrange(max_addr) + d = yield from m.read.call(addr=a) + for _ in range(1): + yield Settle() + expected_d = data[a] + assert d["data"] == expected_d + yield from self.random_wait(reader_rand, min_cycle_cnt=1) + + with self.run_simulation(m) as sim: + sim.add_sync_process(reader) + sim.add_sync_process(writer) + + +class ManyToOneConnectTransTestCircuit(Elaboratable): + def __init__(self, count: int, lay: MethodLayout): + self.count = count + self.lay = lay + self.inputs = [] + + def elaborate(self, platform): + m = TModule() + + get_results = [] + for i in range(self.count): + input = TestbenchIO(Adapter(o=self.lay)) + get_results.append(input.adapter.iface) + m.submodules[f"input_{i}"] = input + self.inputs.append(input) + + # Create ManyToOneConnectTrans, which will serialize results from different inputs + output = TestbenchIO(Adapter(i=self.lay)) + m.submodules.output = output + self.output = output + m.submodules.fu_arbitration = ManyToOneConnectTrans(get_results=get_results, put_result=output.adapter.iface) + + return m + + +class TestManyToOneConnectTrans(TestCaseWithSimulator): + def initialize(self): + f1_size = 14 + f2_size = 3 + self.lay = [("field1", f1_size), ("field2", f2_size)] + + self.m = ManyToOneConnectTransTestCircuit(self.count, self.lay) + random.seed(14) + + self.inputs = [] + # Create list with info if we processed all data from inputs + self.producer_end = [False for i in range(self.count)] + self.expected_output = {} + self.max_wait = 4 + + # Prepare random results for inputs + for i in range(self.count): + data = [] + input_size = random.randint(20, 30) + for j in range(input_size): + t = ( + random.randint(0, 2**f1_size), + random.randint(0, 2**f2_size), + ) + data.append(t) + if t in self.expected_output: + self.expected_output[t] += 1 + else: + self.expected_output[t] = 1 + self.inputs.append(data) + + def generate_producer(self, i: int): + """ + This is an helper function, which generates a producer process, + which will simulate an FU. Producer will insert in random intervals new + results to its output FIFO. This records will be next serialized by FUArbiter. + """ + + def producer(): + inputs = self.inputs[i] + for field1, field2 in inputs: + io: TestbenchIO = self.m.inputs[i] + yield from io.call_init(field1=field1, field2=field2) + yield from self.random_wait(self.max_wait) + self.producer_end[i] = True + + return producer + + def consumer(self): + while reduce(and_, self.producer_end, True): + result = yield from self.m.output.call_do() + + assert result is not None + + t = (result["field1"], result["field2"]) + assert t in self.expected_output + if self.expected_output[t] == 1: + del self.expected_output[t] + else: + self.expected_output[t] -= 1 + yield from self.random_wait(self.max_wait) + + def test_one_out(self): + self.count = 1 + self.initialize() + with self.run_simulation(self.m) as sim: + sim.add_sync_process(self.consumer) + for i in range(self.count): + sim.add_sync_process(self.generate_producer(i)) + + def test_many_out(self): + self.count = 4 + self.initialize() + with self.run_simulation(self.m) as sim: + sim.add_sync_process(self.consumer) + for i in range(self.count): + sim.add_sync_process(self.generate_producer(i)) + + +class MethodMapTestCircuit(Elaboratable): + def __init__(self, iosize: int, use_methods: bool, use_dicts: bool): + self.iosize = iosize + self.use_methods = use_methods + self.use_dicts = use_dicts + + def elaborate(self, platform): + m = TModule() + + layout = data_layout(self.iosize) + + def itransform_rec(m: ModuleLike, v: MethodStruct) -> MethodStruct: + s = Signal.like(v) + m.d.comb += s.data.eq(v.data + 1) + return s + + def otransform_rec(m: ModuleLike, v: MethodStruct) -> MethodStruct: + s = Signal.like(v) + m.d.comb += s.data.eq(v.data - 1) + return s + + def itransform_dict(_, v: MethodStruct) -> RecordDict: + return {"data": v.data + 1} + + def otransform_dict(_, v: MethodStruct) -> RecordDict: + return {"data": v.data - 1} + + if self.use_dicts: + itransform = itransform_dict + otransform = otransform_dict + else: + itransform = itransform_rec + otransform = otransform_rec + + m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout)) + + if self.use_methods: + imeth = Method(i=layout, o=layout) + ometh = Method(i=layout, o=layout) + + @def_method(m, imeth) + def _(arg: MethodStruct): + return itransform(m, arg) + + @def_method(m, ometh) + def _(arg: MethodStruct): + return otransform(m, arg) + + trans = MethodMap(self.target.adapter.iface, i_transform=(layout, imeth), o_transform=(layout, ometh)) + else: + trans = MethodMap( + self.target.adapter.iface, + i_transform=(layout, itransform), + o_transform=(layout, otransform), + ) + + m.submodules.source = self.source = TestbenchIO(AdapterTrans(trans.use(m))) + + return m + + +class TestMethodTransformer(TestCaseWithSimulator): + m: MethodMapTestCircuit + + def source(self): + for i in range(2**self.m.iosize): + v = yield from self.m.source.call(data=i) + i1 = (i + 1) & ((1 << self.m.iosize) - 1) + assert v["data"] == (((i1 << 1) | (i1 >> (self.m.iosize - 1))) - 1) & ((1 << self.m.iosize) - 1) + + @def_method_mock(lambda self: self.m.target) + def target(self, data): + return {"data": (data << 1) | (data >> (self.m.iosize - 1))} + + def test_method_transformer(self): + self.m = MethodMapTestCircuit(4, False, False) + with self.run_simulation(self.m) as sim: + sim.add_sync_process(self.source) + sim.add_sync_process(self.target) + + def test_method_transformer_dicts(self): + self.m = MethodMapTestCircuit(4, False, True) + with self.run_simulation(self.m) as sim: + sim.add_sync_process(self.source) + + def test_method_transformer_with_methods(self): + self.m = MethodMapTestCircuit(4, True, True) + with self.run_simulation(self.m) as sim: + sim.add_sync_process(self.source) + + +class TestMethodFilter(TestCaseWithSimulator): + def initialize(self): + self.iosize = 4 + self.layout = data_layout(self.iosize) + self.target = TestbenchIO(Adapter(i=self.layout, o=self.layout)) + self.cmeth = TestbenchIO(Adapter(i=self.layout, o=data_layout(1))) + + def source(self): + for i in range(2**self.iosize): + v = yield from self.tc.method.call(data=i) + if i & 1: + assert v["data"] == (i + 1) & ((1 << self.iosize) - 1) + else: + assert v["data"] == 0 + + @def_method_mock(lambda self: self.target, sched_prio=2) + def target_mock(self, data): + return {"data": data + 1} + + @def_method_mock(lambda self: self.cmeth, sched_prio=1) + def cmeth_mock(self, data): + return {"data": data % 2} + + @parameterized.expand([(True,), (False,)]) + def test_method_filter_with_methods(self, use_condition): + self.initialize() + self.tc = SimpleTestCircuit( + MethodFilter(self.target.adapter.iface, self.cmeth.adapter.iface, use_condition=use_condition) + ) + m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) + with self.run_simulation(m) as sim: + sim.add_sync_process(self.source) + + @parameterized.expand([(True,), (False,)]) + def test_method_filter(self, use_condition): + self.initialize() + + def condition(_, v): + return v.data[0] + + self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition, use_condition=use_condition)) + m = ModuleConnector(test_circuit=self.tc, target=self.target, cmeth=self.cmeth) + with self.run_simulation(m) as sim: + sim.add_sync_process(self.source) + + +class MethodProductTestCircuit(Elaboratable): + def __init__(self, iosize: int, targets: int, add_combiner: bool): + self.iosize = iosize + self.targets = targets + self.add_combiner = add_combiner + self.target: list[TestbenchIO] = [] + + def elaborate(self, platform): + m = TModule() + + layout = data_layout(self.iosize) + + methods = [] + + for k in range(self.targets): + tgt = TestbenchIO(Adapter(i=layout, o=layout)) + methods.append(tgt.adapter.iface) + self.target.append(tgt) + m.submodules += tgt + + combiner = None + if self.add_combiner: + combiner = (layout, lambda _, vs: {"data": sum(x.data for x in vs)}) + + product = MethodProduct(methods, combiner) + + m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.use(m))) + + return m + + +class TestMethodProduct(TestCaseWithSimulator): + @parameterized.expand([(1, False), (2, False), (5, True)]) + def test_method_product(self, targets: int, add_combiner: bool): + random.seed(14) + + iosize = 8 + m = MethodProductTestCircuit(iosize, targets, add_combiner) + + method_en = [False] * targets + + def target_process(k: int): + @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) + def process(data): + return {"data": data + k} + + return process + + def method_process(): + # if any of the target methods is not enabled, call does not succeed + for i in range(2**targets - 1): + for k in range(targets): + method_en[k] = bool(i & (1 << k)) + + yield + assert (yield from m.method.call_try(data=0)) is None + + # otherwise, the call succeeds + for k in range(targets): + method_en[k] = True + yield + + data = random.randint(0, (1 << iosize) - 1) + val = (yield from m.method.call(data=data))["data"] + if add_combiner: + assert val == (targets * data + (targets - 1) * targets // 2) & ((1 << iosize) - 1) + else: + assert val == data + + with self.run_simulation(m) as sim: + sim.add_sync_process(method_process) + for k in range(targets): + sim.add_sync_process(target_process(k)) + + +class TestSerializer(TestCaseWithSimulator): + def setup_method(self): + self.test_count = 100 + + self.port_count = 2 + self.data_width = 5 + + self.requestor_rand = 4 + + layout = [("field", self.data_width)] + + self.req_method = TestbenchIO(Adapter(i=layout)) + self.resp_method = TestbenchIO(Adapter(o=layout)) + + self.test_circuit = SimpleTestCircuit( + Serializer( + port_count=self.port_count, + serialized_req_method=self.req_method.adapter.iface, + serialized_resp_method=self.resp_method.adapter.iface, + ) + ) + self.m = ModuleConnector( + test_circuit=self.test_circuit, req_method=self.req_method, resp_method=self.resp_method + ) + + random.seed(14) + + self.serialized_data = deque() + self.port_data = [deque() for _ in range(self.port_count)] + + self.got_request = False + + @def_method_mock(lambda self: self.req_method, enable=lambda self: not self.got_request) + def serial_req_mock(self, field): + self.serialized_data.append(field) + self.got_request = True + + @def_method_mock(lambda self: self.resp_method, enable=lambda self: self.got_request) + def serial_resp_mock(self): + self.got_request = False + return {"field": self.serialized_data[-1]} + + def requestor(self, i: int): + def f(): + for _ in range(self.test_count): + d = random.randrange(2**self.data_width) + yield from self.test_circuit.serialize_in[i].call(field=d) + self.port_data[i].append(d) + yield from self.random_wait(self.requestor_rand, min_cycle_cnt=1) + + return f + + def responder(self, i: int): + def f(): + for _ in range(self.test_count): + data_out = yield from self.test_circuit.serialize_out[i].call() + assert self.port_data[i].popleft() == data_out["field"] + yield from self.random_wait(self.requestor_rand, min_cycle_cnt=1) + + return f + + def test_serial(self): + with self.run_simulation(self.m) as sim: + for i in range(self.port_count): + sim.add_sync_process(self.requestor(i)) + sim.add_sync_process(self.responder(i)) + + +class TestMethodTryProduct(TestCaseWithSimulator): + @parameterized.expand([(1, False), (2, False), (5, True)]) + def test_method_try_product(self, targets: int, add_combiner: bool): + random.seed(14) + + iosize = 8 + m = MethodTryProductTestCircuit(iosize, targets, add_combiner) + + method_en = [False] * targets + + def target_process(k: int): + @def_method_mock(lambda: m.target[k], enable=lambda: method_en[k]) + def process(data): + return {"data": data + k} + + return process + + def method_process(): + for i in range(2**targets): + for k in range(targets): + method_en[k] = bool(i & (1 << k)) + + active_targets = sum(method_en) + + yield + + data = random.randint(0, (1 << iosize) - 1) + val = yield from m.method.call(data=data) + if add_combiner: + adds = sum(k * method_en[k] for k in range(targets)) + assert val == {"data": (active_targets * data + adds) & ((1 << iosize) - 1)} + else: + assert val == {} + + with self.run_simulation(m) as sim: + sim.add_sync_process(method_process) + for k in range(targets): + sim.add_sync_process(target_process(k)) + + +class MethodTryProductTestCircuit(Elaboratable): + def __init__(self, iosize: int, targets: int, add_combiner: bool): + self.iosize = iosize + self.targets = targets + self.add_combiner = add_combiner + self.target: list[TestbenchIO] = [] + + def elaborate(self, platform): + m = TModule() + + layout = data_layout(self.iosize) + + methods = [] + + for k in range(self.targets): + tgt = TestbenchIO(Adapter(i=layout, o=layout)) + methods.append(tgt.adapter.iface) + self.target.append(tgt) + m.submodules += tgt + + combiner = None + if self.add_combiner: + combiner = (layout, lambda _, vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)}) + + product = MethodTryProduct(methods, combiner) + + m.submodules.method = self.method = TestbenchIO(AdapterTrans(product.use(m))) + + return m + + +class ConditionTestCircuit(Elaboratable): + def __init__(self, target: Method, *, nonblocking: bool, priority: bool, catchall: bool): + self.target = target + self.source = Method(i=[("cond1", 1), ("cond2", 1), ("cond3", 1)], single_caller=True) + self.nonblocking = nonblocking + self.priority = priority + self.catchall = catchall + + def elaborate(self, platform): + m = TModule() + + @def_method(m, self.source) + def _(cond1, cond2, cond3): + with condition(m, nonblocking=self.nonblocking, priority=self.priority) as branch: + with branch(cond1): + self.target(m, cond=1) + with branch(cond2): + self.target(m, cond=2) + with branch(cond3): + self.target(m, cond=3) + if self.catchall: + with branch(): + self.target(m, cond=0) + + return m + + +class TestCondition(TestCaseWithSimulator): + @pytest.mark.parametrize("nonblocking", [False, True]) + @pytest.mark.parametrize("priority", [False, True]) + @pytest.mark.parametrize("catchall", [False, True]) + def test_condition(self, nonblocking: bool, priority: bool, catchall: bool): + target = TestbenchIO(Adapter(i=[("cond", 2)])) + + circ = SimpleTestCircuit( + ConditionTestCircuit(target.adapter.iface, nonblocking=nonblocking, priority=priority, catchall=catchall) + ) + m = ModuleConnector(test_circuit=circ, target=target) + + selection: Optional[int] + + @def_method_mock(lambda: target) + def target_process(cond): + nonlocal selection + selection = cond + + def process(): + nonlocal selection + for c1, c2, c3 in product([0, 1], [0, 1], [0, 1]): + selection = None + res = yield from circ.source.call_try(cond1=c1, cond2=c2, cond3=c3) + + if catchall or nonblocking: + assert res is not None + + if res is None: + assert selection is None + assert not catchall or nonblocking + assert (c1, c2, c3) == (0, 0, 0) + elif selection is None: + assert nonblocking + assert (c1, c2, c3) == (0, 0, 0) + elif priority: + assert selection == c1 + 2 * c2 * (1 - c1) + 3 * c3 * (1 - c2) * (1 - c1) + else: + assert selection in [c1, 2 * c2, 3 * c3] + + with self.run_simulation(m) as sim: + sim.add_sync_process(process) diff --git a/test/test_adapter.py b/test/test_adapter.py new file mode 100644 index 0000000..901203e --- /dev/null +++ b/test/test_adapter.py @@ -0,0 +1,62 @@ +from amaranth import * + +from transactron import Method, def_method, TModule + + +from transactron.testing import TestCaseWithSimulator, data_layout, SimpleTestCircuit, ModuleConnector + + +class Echo(Elaboratable): + def __init__(self): + self.data_bits = 8 + + self.layout_in = data_layout(self.data_bits) + self.layout_out = data_layout(self.data_bits) + + self.action = Method(i=self.layout_in, o=self.layout_out) + + def elaborate(self, platform): + m = TModule() + + @def_method(m, self.action, ready=C(1)) + def _(arg): + return arg + + return m + + +class Consumer(Elaboratable): + def __init__(self): + self.data_bits = 8 + + self.layout_in = data_layout(self.data_bits) + self.layout_out = [] + + self.action = Method(i=self.layout_in, o=self.layout_out) + + def elaborate(self, platform): + m = TModule() + + @def_method(m, self.action, ready=C(1)) + def _(arg): + return None + + return m + + +class TestAdapterTrans(TestCaseWithSimulator): + def proc(self): + for _ in range(3): + # this would previously timeout if the output layout was empty (as is in this case) + yield from self.consumer.action.call() + for expected in [4, 1, 0]: + obtained = (yield from self.echo.action.call(data=expected))["data"] + assert expected == obtained + + def test_single(self): + self.echo = SimpleTestCircuit(Echo()) + self.consumer = SimpleTestCircuit(Consumer()) + self.m = ModuleConnector(echo=self.echo, consumer=self.consumer) + + with self.run_simulation(self.m, max_cycles=100) as sim: + sim.add_sync_process(self.proc) diff --git a/test/test_assign.py b/test/test_assign.py new file mode 100644 index 0000000..658a52c --- /dev/null +++ b/test/test_assign.py @@ -0,0 +1,125 @@ +import pytest +from typing import Callable +from amaranth import * +from amaranth.lib import data +from amaranth.hdl._ast import ArrayProxy, Slice + +from transactron.utils._typing import MethodLayout +from transactron.utils import AssignType, assign +from transactron.utils.assign import AssignArg, AssignFields + +from unittest import TestCase +from parameterized import parameterized_class, parameterized + + +layout_a = [("a", 1)] +layout_ab = [("a", 1), ("b", 2)] +layout_ac = [("a", 1), ("c", 3)] +layout_a_alt = [("a", 2)] + +params_build_wrap_extr = [ + ("normal", lambda mk, lay: mk(lay), lambda x: x, lambda r: r), + ("rec", lambda mk, lay: mk([("x", lay)]), lambda x: {"x": x}, lambda r: r.x), + ("dict", lambda mk, lay: {"x": mk(lay)}, lambda x: {"x": x}, lambda r: r["x"]), + ("list", lambda mk, lay: [mk(lay)], lambda x: {0: x}, lambda r: r[0]), + ("array", lambda mk, lay: Signal(data.ArrayLayout(reclayout2datalayout(lay), 1)), lambda x: {0: x}, lambda r: r[0]), +] + + +def mkproxy(layout): + arr = Array([Signal(reclayout2datalayout(layout)) for _ in range(4)]) + sig = Signal(2) + return arr[sig] + + +def reclayout2datalayout(layout): + if not isinstance(layout, list): + return layout + return data.StructLayout({k: reclayout2datalayout(lay) for k, lay in layout}) + + +def mkstruct(layout): + return Signal(reclayout2datalayout(layout)) + + +params_mk = [ + ("proxy", mkproxy), + ("struct", mkstruct), +] + + +@parameterized_class( + ["name", "build", "wrap", "extr", "constr", "mk"], + [ + (n, *map(staticmethod, (b, w, e)), c, staticmethod(m)) + for n, b, w, e in params_build_wrap_extr + for c, m in params_mk + ], +) +class TestAssign(TestCase): + # constructs `assign` arguments (views, proxies, dicts) which have an "inner" and "outer" part + # parameterized with a constructor and a layout of the inner part + build: Callable[[Callable[[MethodLayout], AssignArg], MethodLayout], AssignArg] + # constructs field specifications for `assign`, takes field specifications for the inner part + wrap: Callable[[AssignFields], AssignFields] + # extracts the inner part of the structure + extr: Callable[[AssignArg], ArrayProxy] + # constructor, takes a layout + mk: Callable[[MethodLayout], AssignArg] + + def test_rhs_exception(self): + with pytest.raises(KeyError): + list(assign(self.build(self.mk, layout_a), self.build(self.mk, layout_ab), fields=AssignType.RHS)) + with pytest.raises(KeyError): + list(assign(self.build(self.mk, layout_ab), self.build(self.mk, layout_ac), fields=AssignType.RHS)) + + def test_all_exception(self): + with pytest.raises(KeyError): + list(assign(self.build(self.mk, layout_a), self.build(self.mk, layout_ab), fields=AssignType.ALL)) + with pytest.raises(KeyError): + list(assign(self.build(self.mk, layout_ab), self.build(self.mk, layout_a), fields=AssignType.ALL)) + with pytest.raises(KeyError): + list(assign(self.build(self.mk, layout_ab), self.build(self.mk, layout_ac), fields=AssignType.ALL)) + + def test_missing_exception(self): + with pytest.raises(KeyError): + list(assign(self.build(self.mk, layout_a), self.build(self.mk, layout_ab), fields=self.wrap({"b"}))) + with pytest.raises(KeyError): + list(assign(self.build(self.mk, layout_ab), self.build(self.mk, layout_a), fields=self.wrap({"b"}))) + with pytest.raises(KeyError): + list(assign(self.build(self.mk, layout_a), self.build(self.mk, layout_a), fields=self.wrap({"b"}))) + + def test_wrong_bits(self): + with pytest.raises(ValueError): + list(assign(self.build(self.mk, layout_a), self.build(self.mk, layout_a_alt))) + + @parameterized.expand( + [ + ("rhs", layout_ab, layout_a, AssignType.RHS), + ("all", layout_a, layout_a, AssignType.ALL), + ("common", layout_ab, layout_ac, AssignType.COMMON), + ("set", layout_ab, layout_ab, {"a"}), + ("list", layout_ab, layout_ab, ["a", "a"]), + ] + ) + def test_assign_a(self, name, layout1: MethodLayout, layout2: MethodLayout, atype: AssignType): + lhs = self.build(self.mk, layout1) + rhs = self.build(self.mk, layout2) + alist = list(assign(lhs, rhs, fields=self.wrap(atype))) + assert len(alist) == 1 + self.assertIs_AP(alist[0].lhs, self.extr(lhs).a) + self.assertIs_AP(alist[0].rhs, self.extr(rhs).a) + + def assertIs_AP(self, expr1, expr2): # noqa: N802 + if isinstance(expr1, ArrayProxy) and isinstance(expr2, ArrayProxy): + # new proxies are created on each index, structural equality is needed + self.assertIs(expr1.index, expr2.index) + assert len(expr1.elems) == len(expr2.elems) + for x, y in zip(expr1.elems, expr2.elems): + self.assertIs_AP(x, y) + elif isinstance(expr1, Slice) and isinstance(expr2, Slice): + self.assertIs_AP(expr1.value, expr2.value) + assert expr1.start == expr2.start + assert expr1.stop == expr2.stop + else: + self.assertIs(expr1, expr2) diff --git a/test/test_branches.py b/test/test_branches.py new file mode 100644 index 0000000..9af6b26 --- /dev/null +++ b/test/test_branches.py @@ -0,0 +1,99 @@ +from amaranth import * +from itertools import product +from transactron.core import ( + TModule, + Method, + Transaction, + TransactionManager, + TransactionModule, + def_method, +) +from transactron.core.tmodule import CtrlPath +from transactron.core.manager import MethodMap +from unittest import TestCase +from transactron.testing import TestCaseWithSimulator +from transactron.utils.dependencies import DependencyContext + + +class TestExclusivePath(TestCase): + def test_exclusive_path(self): + m = TModule() + m._MustUse__silence = True # type: ignore + + with m.If(0): + cp0 = m.ctrl_path + with m.Switch(3): + with m.Case(0): + cp0a0 = m.ctrl_path + with m.Case(1): + cp0a1 = m.ctrl_path + with m.Default(): + cp0a2 = m.ctrl_path + with m.If(1): + cp0b0 = m.ctrl_path + with m.Else(): + cp0b1 = m.ctrl_path + with m.Elif(1): + cp1 = m.ctrl_path + with m.FSM(): + with m.State("start"): + cp10 = m.ctrl_path + with m.State("next"): + cp11 = m.ctrl_path + with m.Else(): + cp2 = m.ctrl_path + + def mutually_exclusive(*cps: CtrlPath): + return all(cpa.exclusive_with(cpb) for i, cpa in enumerate(cps) for cpb in cps[i + 1 :]) + + def pairwise_exclusive(cps1: list[CtrlPath], cps2: list[CtrlPath]): + return all(cpa.exclusive_with(cpb) for cpa, cpb in product(cps1, cps2)) + + def pairwise_not_exclusive(cps1: list[CtrlPath], cps2: list[CtrlPath]): + return all(not cpa.exclusive_with(cpb) for cpa, cpb in product(cps1, cps2)) + + assert mutually_exclusive(cp0, cp1, cp2) + assert mutually_exclusive(cp0a0, cp0a1, cp0a2) + assert mutually_exclusive(cp0b0, cp0b1) + assert mutually_exclusive(cp10, cp11) + assert pairwise_exclusive([cp0, cp0a0, cp0a1, cp0a2, cp0b0, cp0b1], [cp1, cp10, cp11]) + assert pairwise_not_exclusive([cp0, cp0a0, cp0a1, cp0a2], [cp0, cp0b0, cp0b1]) + + +class ExclusiveConflictRemovalCircuit(Elaboratable): + def __init__(self): + self.sel = Signal() + + def elaborate(self, platform): + m = TModule() + + called_method = Method(i=[], o=[]) + + @def_method(m, called_method) + def _(): + pass + + with m.If(self.sel): + with Transaction().body(m): + called_method(m) + with m.Else(): + with Transaction().body(m): + called_method(m) + + return m + + +class TestExclusiveConflictRemoval(TestCaseWithSimulator): + def test_conflict_removal(self): + circ = ExclusiveConflictRemovalCircuit() + + tm = TransactionManager() + dut = TransactionModule(circ, DependencyContext.get(), tm) + + with self.run_simulation(dut, add_transaction_module=False): + pass + + cgr, _, _ = tm._conflict_graph(MethodMap(tm.transactions)) + + for s in cgr.values(): + assert not s diff --git a/test/test_methods.py b/test/test_methods.py new file mode 100644 index 0000000..03d9a87 --- /dev/null +++ b/test/test_methods.py @@ -0,0 +1,644 @@ +import pytest +import random +from amaranth import * +from amaranth.sim import * + +from transactron.testing import TestCaseWithSimulator, TestbenchIO, data_layout + +from transactron import * +from transactron.utils import MethodStruct +from transactron.lib import * + +from parameterized import parameterized + +from unittest import TestCase + + +class TestDefMethod(TestCaseWithSimulator): + class CircuitTestModule(Elaboratable): + def __init__(self, method_definition): + self.method = Method( + i=[("foo1", 3), ("foo2", [("bar1", 4), ("bar2", 6)])], + o=[("foo1", 3), ("foo2", [("bar1", 4), ("bar2", 6)])], + ) + + self.method_definition = method_definition + + def elaborate(self, platform): + m = TModule() + m._MustUse__silence = True # type: ignore + + def_method(m, self.method)(self.method_definition) + + return m + + def do_test_definition(self, definer): + with self.run_simulation(TestDefMethod.CircuitTestModule(definer)): + pass + + def test_fields_valid1(self): + def definition(arg): + return {"foo1": Signal(3), "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} + + self.do_test_definition(definition) + + def test_fields_valid2(self): + rec = Signal(from_method_layout([("bar1", 4), ("bar2", 6)])) + + def definition(arg): + return {"foo1": Signal(3), "foo2": rec} + + self.do_test_definition(definition) + + def test_fields_valid3(self): + def definition(arg): + return arg + + self.do_test_definition(definition) + + def test_fields_valid4(self): + def definition(arg: MethodStruct): + return arg + + self.do_test_definition(definition) + + def test_fields_valid5(self): + def definition(**arg): + return arg + + self.do_test_definition(definition) + + def test_fields_valid6(self): + def definition(foo1, foo2): + return {"foo1": foo1, "foo2": foo2} + + self.do_test_definition(definition) + + def test_fields_valid7(self): + def definition(foo1, **arg): + return {"foo1": foo1, "foo2": arg["foo2"]} + + self.do_test_definition(definition) + + def test_fields_invalid1(self): + def definition(arg): + return {"foo1": Signal(3), "baz": Signal(4)} + + with pytest.raises(KeyError): + self.do_test_definition(definition) + + def test_fields_invalid2(self): + def definition(arg): + return {"foo1": Signal(3)} + + with pytest.raises(KeyError): + self.do_test_definition(definition) + + def test_fields_invalid3(self): + def definition(arg): + return {"foo1": {"baz1": Signal(), "baz2": Signal()}, "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} + + with pytest.raises(TypeError): + self.do_test_definition(definition) + + def test_fields_invalid4(self): + def definition(arg: Value): + return arg + + with pytest.raises(TypeError): + self.do_test_definition(definition) + + def test_fields_invalid5(self): + def definition(foo): + return foo + + with pytest.raises(TypeError): + self.do_test_definition(definition) + + def test_fields_invalid6(self): + def definition(foo1): + return {"foo1": foo1, "foo2": {"bar1": Signal(4), "bar2": Signal(6)}} + + with pytest.raises(TypeError): + self.do_test_definition(definition) + + +class AdapterCircuit(Elaboratable): + def __init__(self, module, methods): + self.module = module + self.methods = methods + + def elaborate(self, platform): + m = TModule() + + m.submodules += self.module + for method in self.methods: + m.submodules += AdapterTrans(method) + + return m + + +class TestInvalidMethods(TestCase): + def assert_re(self, msg, m): + with pytest.raises(RuntimeError, match=msg): + Fragment.get(TransactionModule(m), platform=None) + + def test_twice(self): + class Twice(Elaboratable): + def __init__(self): + self.meth1 = Method() + self.meth2 = Method() + + def elaborate(self, platform): + m = TModule() + m._MustUse__silence = True # type: ignore + + with self.meth1.body(m): + pass + + with self.meth2.body(m): + self.meth1(m) + self.meth1(m) + + return m + + self.assert_re("called twice", Twice()) + + def test_twice_cond(self): + class Twice(Elaboratable): + def __init__(self): + self.meth1 = Method() + self.meth2 = Method() + + def elaborate(self, platform): + m = TModule() + m._MustUse__silence = True # type: ignore + + with self.meth1.body(m): + pass + + with self.meth2.body(m): + with m.If(1): + self.meth1(m) + with m.Else(): + self.meth1(m) + + return m + + Fragment.get(TransactionModule(Twice()), platform=None) + + def test_diamond(self): + class Diamond(Elaboratable): + def __init__(self): + self.meth1 = Method() + self.meth2 = Method() + self.meth3 = Method() + self.meth4 = Method() + + def elaborate(self, platform): + m = TModule() + + with self.meth1.body(m): + pass + + with self.meth2.body(m): + self.meth1(m) + + with self.meth3.body(m): + self.meth1(m) + + with self.meth4.body(m): + self.meth2(m) + self.meth3(m) + + return m + + m = Diamond() + self.assert_re("called twice", AdapterCircuit(m, [m.meth4])) + + def test_loop(self): + class Loop(Elaboratable): + def __init__(self): + self.meth1 = Method() + + def elaborate(self, platform): + m = TModule() + + with self.meth1.body(m): + self.meth1(m) + + return m + + m = Loop() + self.assert_re("called twice", AdapterCircuit(m, [m.meth1])) + + def test_cycle(self): + class Cycle(Elaboratable): + def __init__(self): + self.meth1 = Method() + self.meth2 = Method() + + def elaborate(self, platform): + m = TModule() + + with self.meth1.body(m): + self.meth2(m) + + with self.meth2.body(m): + self.meth1(m) + + return m + + m = Cycle() + self.assert_re("called twice", AdapterCircuit(m, [m.meth1])) + + def test_redefine(self): + class Redefine(Elaboratable): + def elaborate(self, platform): + m = TModule() + m._MustUse__silence = True # type: ignore + + meth = Method() + + with meth.body(m): + pass + + with meth.body(m): + pass + + self.assert_re("already defined", Redefine()) + + def test_undefined_in_trans(self): + class Undefined(Elaboratable): + def __init__(self): + self.meth = Method(i=data_layout(1)) + + def elaborate(self, platform): + return TModule() + + class Circuit(Elaboratable): + def elaborate(self, platform): + m = TModule() + + m.submodules.undefined = undefined = Undefined() + m.submodules.adapter = AdapterTrans(undefined.meth) + + return m + + self.assert_re("not defined", Circuit()) + + +WIDTH = 8 + + +class Quadruple(Elaboratable): + def __init__(self): + layout = data_layout(WIDTH) + self.id = Method(i=layout, o=layout) + self.double = Method(i=layout, o=layout) + self.quadruple = Method(i=layout, o=layout) + + def elaborate(self, platform): + m = TModule() + + @def_method(m, self.id) + def _(arg): + return arg + + @def_method(m, self.double) + def _(arg): + return {"data": self.id(m, arg).data * 2} + + @def_method(m, self.quadruple) + def _(arg): + return {"data": self.double(m, arg).data * 2} + + return m + + +class QuadrupleCircuit(Elaboratable): + def __init__(self, quadruple): + self.quadruple = quadruple + + def elaborate(self, platform): + m = TModule() + + m.submodules.quadruple = self.quadruple + m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(self.quadruple.quadruple)) + + return m + + +class Quadruple2(Elaboratable): + def __init__(self): + layout = data_layout(WIDTH) + self.quadruple = Method(i=layout, o=layout) + + def elaborate(self, platform): + m = TModule() + + m.submodules.sub = Quadruple() + + @def_method(m, self.quadruple) + def _(arg): + return {"data": 2 * m.submodules.sub.double(m, arg).data} + + return m + + +class TestQuadrupleCircuits(TestCaseWithSimulator): + @parameterized.expand([(Quadruple,), (Quadruple2,)]) + def test(self, quadruple): + circ = QuadrupleCircuit(quadruple()) + + def process(): + for n in range(1 << (WIDTH - 2)): + out = yield from circ.tb.call(data=n) + assert out["data"] == n * 4 + + with self.run_simulation(circ) as sim: + sim.add_sync_process(process) + + +class ConditionalCallCircuit(Elaboratable): + def elaborate(self, platform): + m = TModule() + + meth = Method(i=data_layout(1)) + + m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) + m.submodules.out = self.out = TestbenchIO(Adapter()) + + @def_method(m, meth) + def _(arg): + with m.If(arg): + self.out.adapter.iface(m) + + return m + + +class ConditionalMethodCircuit1(Elaboratable): + def elaborate(self, platform): + m = TModule() + + meth = Method() + + self.ready = Signal() + m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) + + @def_method(m, meth, ready=self.ready) + def _(arg): + pass + + return m + + +class ConditionalMethodCircuit2(Elaboratable): + def elaborate(self, platform): + m = TModule() + + meth = Method() + + self.ready = Signal() + m.submodules.tb = self.tb = TestbenchIO(AdapterTrans(meth)) + + with m.If(self.ready): + + @def_method(m, meth) + def _(arg): + pass + + return m + + +class ConditionalTransactionCircuit1(Elaboratable): + def elaborate(self, platform): + m = TModule() + + self.ready = Signal() + m.submodules.tb = self.tb = TestbenchIO(Adapter()) + + with Transaction().body(m, request=self.ready): + self.tb.adapter.iface(m) + + return m + + +class ConditionalTransactionCircuit2(Elaboratable): + def elaborate(self, platform): + m = TModule() + + self.ready = Signal() + m.submodules.tb = self.tb = TestbenchIO(Adapter()) + + with m.If(self.ready): + with Transaction().body(m): + self.tb.adapter.iface(m) + + return m + + +class TestConditionals(TestCaseWithSimulator): + def test_conditional_call(self): + circ = ConditionalCallCircuit() + + def process(): + yield from circ.out.disable() + yield from circ.tb.call_init(data=0) + yield Settle() + assert not (yield from circ.out.done()) + assert not (yield from circ.tb.done()) + + yield from circ.out.enable() + yield Settle() + assert not (yield from circ.out.done()) + assert (yield from circ.tb.done()) + + yield from circ.tb.call_init(data=1) + yield Settle() + assert (yield from circ.out.done()) + assert (yield from circ.tb.done()) + + # the argument is still 1 but the method is not called + yield from circ.tb.disable() + yield Settle() + assert not (yield from circ.out.done()) + assert not (yield from circ.tb.done()) + + with self.run_simulation(circ) as sim: + sim.add_sync_process(process) + + @parameterized.expand( + [ + (ConditionalMethodCircuit1,), + (ConditionalMethodCircuit2,), + (ConditionalTransactionCircuit1,), + (ConditionalTransactionCircuit2,), + ] + ) + def test_conditional(self, elaboratable): + circ = elaboratable() + + def process(): + yield from circ.tb.enable() + yield circ.ready.eq(0) + yield Settle() + assert not (yield from circ.tb.done()) + + yield circ.ready.eq(1) + yield Settle() + assert (yield from circ.tb.done()) + + with self.run_simulation(circ) as sim: + sim.add_sync_process(process) + + +class NonexclusiveMethodCircuit(Elaboratable): + def elaborate(self, platform): + m = TModule() + + self.ready = Signal() + self.running = Signal() + self.data = Signal(WIDTH) + + method = Method(o=data_layout(WIDTH), nonexclusive=True) + + @def_method(m, method, self.ready) + def _(): + m.d.comb += self.running.eq(1) + return {"data": self.data} + + m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method)) + m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method)) + + return m + + +class TestNonexclusiveMethod(TestCaseWithSimulator): + def test_nonexclusive_method(self): + circ = NonexclusiveMethodCircuit() + + def process(): + for x in range(8): + t1en = bool(x & 1) + t2en = bool(x & 2) + mrdy = bool(x & 4) + + if t1en: + yield from circ.t1.enable() + else: + yield from circ.t1.disable() + + if t2en: + yield from circ.t2.enable() + else: + yield from circ.t2.disable() + + if mrdy: + yield circ.ready.eq(1) + else: + yield circ.ready.eq(0) + + yield circ.data.eq(x) + yield Settle() + + assert bool((yield circ.running)) == ((t1en or t2en) and mrdy) + assert bool((yield from circ.t1.done())) == (t1en and mrdy) + assert bool((yield from circ.t2.done())) == (t2en and mrdy) + + if t1en and mrdy: + assert (yield from circ.t1.get_outputs()) == {"data": x} + + if t2en and mrdy: + assert (yield from circ.t2.get_outputs()) == {"data": x} + + with self.run_simulation(circ) as sim: + sim.add_sync_process(process) + + +class DataDependentConditionalCircuit(Elaboratable): + def __init__(self, n=2, ready_function=lambda arg: arg.data != 3): + self.method = Method(i=data_layout(n)) + self.ready_function = ready_function + + self.in_t1 = Signal(from_method_layout(data_layout(n))) + self.in_t2 = Signal(from_method_layout(data_layout(n))) + self.ready = Signal() + self.req_t1 = Signal() + self.req_t2 = Signal() + + self.out_m = Signal() + self.out_t1 = Signal() + self.out_t2 = Signal() + + def elaborate(self, platform): + m = TModule() + + @def_method(m, self.method, self.ready, validate_arguments=self.ready_function) + def _(data): + m.d.comb += self.out_m.eq(1) + + with Transaction().body(m, request=self.req_t1): + m.d.comb += self.out_t1.eq(1) + self.method(m, self.in_t1) + + with Transaction().body(m, request=self.req_t2): + m.d.comb += self.out_t2.eq(1) + self.method(m, self.in_t2) + + return m + + +class TestDataDependentConditionalMethod(TestCaseWithSimulator): + def setup_method(self): + self.test_number = 200 + self.bad_number = 3 + self.n = 2 + + def base_random(self, f): + random.seed(14) + self.circ = DataDependentConditionalCircuit(n=self.n, ready_function=f) + + def process(): + for _ in range(self.test_number): + in1 = random.randrange(0, 2**self.n) + in2 = random.randrange(0, 2**self.n) + m_ready = random.randrange(2) + req_t1 = random.randrange(2) + req_t2 = random.randrange(2) + + yield self.circ.in_t1.eq(in1) + yield self.circ.in_t2.eq(in2) + yield self.circ.req_t1.eq(req_t1) + yield self.circ.req_t2.eq(req_t2) + yield self.circ.ready.eq(m_ready) + yield Settle() + + out_m = yield self.circ.out_m + out_t1 = yield self.circ.out_t1 + out_t2 = yield self.circ.out_t2 + + if not m_ready or (not req_t1 or in1 == self.bad_number) and (not req_t2 or in2 == self.bad_number): + assert out_m == 0 + assert out_t1 == 0 + assert out_t2 == 0 + continue + # Here method global ready signal is high and we requested one of the transactions + # we also know that one of the transactions request correct input data + + assert out_m == 1 + assert out_t1 ^ out_t2 == 1 + # inX == self.bad_number implies out_tX==0 + assert in1 != self.bad_number or not out_t1 + assert in2 != self.bad_number or not out_t2 + + yield + + with self.run_simulation(self.circ, 100) as sim: + sim.add_sync_process(process) + + def test_random_arg(self): + self.base_random(lambda arg: arg.data != self.bad_number) + + def test_random_kwarg(self): + self.base_random(lambda data: data != self.bad_number) diff --git a/test/test_simultaneous.py b/test/test_simultaneous.py new file mode 100644 index 0000000..eac74b8 --- /dev/null +++ b/test/test_simultaneous.py @@ -0,0 +1,172 @@ +import pytest +from itertools import product +from typing import Optional +from amaranth import * +from amaranth.sim import * + +from transactron.utils import ModuleConnector + +from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, def_method_mock + +from transactron import * +from transactron.lib import Adapter, Connect, ConnectTrans + + +def empty_method(m: TModule, method: Method): + @def_method(m, method) + def _(): + pass + + +class SimultaneousDiamondTestCircuit(Elaboratable): + def __init__(self): + self.method_l = Method() + self.method_r = Method() + self.method_u = Method() + self.method_d = Method() + + def elaborate(self, platform): + m = TModule() + + empty_method(m, self.method_l) + empty_method(m, self.method_r) + empty_method(m, self.method_u) + empty_method(m, self.method_d) + + # the only possibilities for the following are: (l, u, r) or (l, d, r) + self.method_l.simultaneous_alternatives(self.method_u, self.method_d) + self.method_r.simultaneous_alternatives(self.method_u, self.method_d) + + return m + + +class TestSimultaneousDiamond(TestCaseWithSimulator): + def test_diamond(self): + circ = SimpleTestCircuit(SimultaneousDiamondTestCircuit()) + + def process(): + methods = {"l": circ.method_l, "r": circ.method_r, "u": circ.method_u, "d": circ.method_d} + for i in range(1 << len(methods)): + enables: dict[str, bool] = {} + for k, n in enumerate(methods): + enables[n] = bool(i & (1 << k)) + yield from methods[n].set_enable(enables[n]) + yield + dones: dict[str, bool] = {} + for n in methods: + dones[n] = bool((yield from methods[n].done())) + for n in methods: + if not enables[n]: + assert not dones[n] + if enables["l"] and enables["r"] and (enables["u"] or enables["d"]): + assert dones["l"] + assert dones["r"] + assert dones["u"] or dones["d"] + else: + assert not any(dones.values()) + + with self.run_simulation(circ) as sim: + sim.add_sync_process(process) + + +class UnsatisfiableTriangleTestCircuit(Elaboratable): + def __init__(self): + self.method_l = Method() + self.method_u = Method() + self.method_d = Method() + + def elaborate(self, platform): + m = TModule() + + empty_method(m, self.method_l) + empty_method(m, self.method_u) + empty_method(m, self.method_d) + + # the following is unsatisfiable + self.method_l.simultaneous_alternatives(self.method_u, self.method_d) + self.method_u.simultaneous(self.method_d) + + return m + + +class TestUnsatisfiableTriangle(TestCaseWithSimulator): + def test_unsatisfiable(self): + circ = SimpleTestCircuit(UnsatisfiableTriangleTestCircuit()) + + with pytest.raises(RuntimeError): + with self.run_simulation(circ) as _: + pass + + +class HelperConnect(Elaboratable): + def __init__(self, source: Method, target: Method, request: Signal, data: int): + self.source = source + self.target = target + self.request = request + self.data = data + + def elaborate(self, platform): + m = TModule() + + with Transaction().body(m, request=self.request): + self.target(m, self.data ^ self.source(m).data) + + return m + + +class TransitivityTestCircuit(Elaboratable): + def __init__(self, target: Method, req1: Signal, req2: Signal): + self.source1 = Method(i=[("data", 2)]) + self.source2 = Method(i=[("data", 2)]) + self.target = target + self.req1 = req1 + self.req2 = req2 + + def elaborate(self, platform): + m = TModule() + + m.submodules.c1 = c1 = Connect([("data", 2)]) + m.submodules.c2 = c2 = Connect([("data", 2)]) + self.source1.proxy(m, c1.write) + self.source2.proxy(m, c1.write) + m.submodules.ct = ConnectTrans(c2.read, self.target) + m.submodules.hc1 = HelperConnect(c1.read, c2.write, self.req1, 1) + m.submodules.hc2 = HelperConnect(c1.read, c2.write, self.req2, 2) + + return m + + +class TestTransitivity(TestCaseWithSimulator): + def test_transitivity(self): + target = TestbenchIO(Adapter(i=[("data", 2)])) + req1 = Signal() + req2 = Signal() + + circ = SimpleTestCircuit(TransitivityTestCircuit(target.adapter.iface, req1, req2)) + m = ModuleConnector(test_circuit=circ, target=target) + + result: Optional[int] + + @def_method_mock(lambda: target) + def target_process(data): + nonlocal result + result = data + + def process(): + nonlocal result + for source, data, reqv1, reqv2 in product([circ.source1, circ.source2], [0, 1, 2, 3], [0, 1], [0, 1]): + result = None + yield req1.eq(reqv1) + yield req2.eq(reqv2) + call_result = yield from source.call_try(data=data) + + if not reqv1 and not reqv2: + assert call_result is None + assert result is None + else: + assert call_result is not None + possibles = reqv1 * [data ^ 1] + reqv2 * [data ^ 2] + assert result in possibles + + with self.run_simulation(m) as sim: + sim.add_sync_process(process) diff --git a/test/utils/__init__.py b/test/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/utils/test_onehotswitch.py b/test/utils/test_onehotswitch.py new file mode 100644 index 0000000..9d7dc84 --- /dev/null +++ b/test/utils/test_onehotswitch.py @@ -0,0 +1,62 @@ +from amaranth import * +from amaranth.sim import * + +from transactron.utils import OneHotSwitch + +from transactron.testing import TestCaseWithSimulator + +from parameterized import parameterized + + +class OneHotSwitchCircuit(Elaboratable): + def __init__(self, width: int, test_zero: bool): + self.input = Signal(1 << width) + self.output = Signal(width) + self.zero = Signal() + self.test_zero = test_zero + + def elaborate(self, platform): + m = Module() + + with OneHotSwitch(m, self.input) as OneHotCase: + for i in range(len(self.input)): + with OneHotCase(1 << i): + m.d.comb += self.output.eq(i) + + if self.test_zero: + with OneHotCase(): + m.d.comb += self.zero.eq(1) + + return m + + +class TestAssign(TestCaseWithSimulator): + @parameterized.expand([(False,), (True,)]) + def test_onehotswitch(self, test_zero): + circuit = OneHotSwitchCircuit(4, test_zero) + + def switch_test_proc(): + for i in range(len(circuit.input)): + yield circuit.input.eq(1 << i) + yield Settle() + assert (yield circuit.output) == i + + with self.run_simulation(circuit) as sim: + sim.add_process(switch_test_proc) + + def test_onehotswitch_zero(self): + circuit = OneHotSwitchCircuit(4, True) + + def switch_test_proc_zero(): + for i in range(len(circuit.input)): + yield circuit.input.eq(1 << i) + yield Settle() + assert (yield circuit.output) == i + assert not (yield circuit.zero) + + yield circuit.input.eq(0) + yield Settle() + assert (yield circuit.zero) + + with self.run_simulation(circuit) as sim: + sim.add_process(switch_test_proc_zero) diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py new file mode 100644 index 0000000..63c1761 --- /dev/null +++ b/test/utils/test_utils.py @@ -0,0 +1,196 @@ +import unittest +import random + +from amaranth import * +from transactron.testing import * +from transactron.utils import ( + align_to_power_of_two, + align_down_to_power_of_two, + popcount, + count_leading_zeros, + count_trailing_zeros, +) +from parameterized import parameterized_class + + +class TestAlignToPowerOfTwo(unittest.TestCase): + def test_align_to_power_of_two(self): + test_cases = [ + (2, 2, 4), + (2, 1, 2), + (3, 1, 4), + (7, 3, 8), + (8, 3, 8), + (14, 3, 16), + (17, 3, 24), + (33, 3, 40), + (33, 1, 34), + (33, 0, 33), + (33, 4, 48), + (33, 5, 64), + (33, 6, 64), + ] + + for num, power, expected in test_cases: + out = align_to_power_of_two(num, power) + assert expected == out + + def test_align_down_to_power_of_two(self): + test_cases = [ + (3, 1, 2), + (3, 0, 3), + (3, 3, 0), + (8, 3, 8), + (8, 2, 8), + (33, 5, 32), + (29, 5, 0), + (29, 1, 28), + (29, 3, 24), + ] + + for num, power, expected in test_cases: + out = align_down_to_power_of_two(num, power) + assert expected == out + + +class PopcountTestCircuit(Elaboratable): + def __init__(self, size: int): + self.sig_in = Signal(size) + self.sig_out = Signal(size) + + def elaborate(self, platform): + m = Module() + + m.d.comb += self.sig_out.eq(popcount(self.sig_in)) + + return m + + +@parameterized_class( + ("name", "size"), + [("size" + str(s), s) for s in [2, 3, 4, 5, 6, 8, 10, 16, 21, 32, 33, 64, 1025]], +) +class TestPopcount(TestCaseWithSimulator): + size: int + + def setup_method(self): + random.seed(14) + self.test_number = 40 + self.m = PopcountTestCircuit(self.size) + + def check(self, n): + yield self.m.sig_in.eq(n) + yield Settle() + out_popcount = yield self.m.sig_out + assert out_popcount == n.bit_count(), f"{n:x}" + + def process(self): + for i in range(self.test_number): + n = random.randrange(2**self.size) + yield from self.check(n) + yield from self.check(2**self.size - 1) + + def test_popcount(self): + with self.run_simulation(self.m) as sim: + sim.add_process(self.process) + + +class CLZTestCircuit(Elaboratable): + def __init__(self, xlen_log: int): + self.sig_in = Signal(1 << xlen_log) + self.sig_out = Signal(xlen_log + 1) + self.xlen_log = xlen_log + + def elaborate(self, platform): + m = Module() + + m.d.comb += self.sig_out.eq(count_leading_zeros(self.sig_in)) + # dummy signal + s = Signal() + m.d.sync += s.eq(1) + + return m + + +@parameterized_class( + ("name", "size"), + [("size" + str(s), s) for s in range(1, 7)], +) +class TestCountLeadingZeros(TestCaseWithSimulator): + size: int + + def setup_method(self): + random.seed(14) + self.test_number = 40 + self.m = CLZTestCircuit(self.size) + + def check(self, n): + yield self.m.sig_in.eq(n) + yield Settle() + out_clz = yield self.m.sig_out + assert out_clz == (2**self.size) - n.bit_length(), f"{n:x}" + + def process(self): + for i in range(self.test_number): + n = random.randrange(2**self.size) + yield from self.check(n) + yield from self.check(2**self.size - 1) + + def test_count_leading_zeros(self): + with self.run_simulation(self.m) as sim: + sim.add_process(self.process) + + +class CTZTestCircuit(Elaboratable): + def __init__(self, xlen_log: int): + self.sig_in = Signal(1 << xlen_log) + self.sig_out = Signal(xlen_log + 1) + self.xlen_log = xlen_log + + def elaborate(self, platform): + m = Module() + + m.d.comb += self.sig_out.eq(count_trailing_zeros(self.sig_in)) + # dummy signal + s = Signal() + m.d.sync += s.eq(1) + + return m + + +@parameterized_class( + ("name", "size"), + [("size" + str(s), s) for s in range(1, 7)], +) +class TestCountTrailingZeros(TestCaseWithSimulator): + size: int + + def setup_method(self): + random.seed(14) + self.test_number = 40 + self.m = CTZTestCircuit(self.size) + + def check(self, n): + yield self.m.sig_in.eq(n) + yield Settle() + out_ctz = yield self.m.sig_out + + expected = 0 + if n == 0: + expected = 2**self.size + else: + while (n & 1) == 0: + expected += 1 + n >>= 1 + + assert out_ctz == expected, f"{n:x}" + + def process(self): + for i in range(self.test_number): + n = random.randrange(2**self.size) + yield from self.check(n) + yield from self.check(2**self.size - 1) + + def test_count_trailing_zeros(self): + with self.run_simulation(self.m) as sim: + sim.add_process(self.process)